# Source code for espnet2.spk.encoder.ecapa_tdnn_encoder

# Copyright 2023 Jee-weon Jung
# Apache 2.0


import torch
import torch.nn as nn
from typeguard import typechecked

from espnet2.asr.encoder.abs_encoder import AbsEncoder
from espnet2.spk.layers.ecapa_block import EcapaBlock

class EcapaTdnnEncoder(AbsEncoder):
    """ECAPA-TDNN encoder.

    Extracts frame-level ECAPA-TDNN embeddings from mel-filterbank energy or
    MFCC features.

    Paper: B Desplanques et al., ``ECAPA-TDNN: Emphasized Channel Attention,
        Propagation and Aggregation in TDNN Based Speaker Verification,''
        in Proc. INTERSPEECH, 2020.

    Args:
        input_size: input feature dimension.
        block: type of encoder block class to use. Only ``"EcapaBlock"`` is
            supported.
        model_scale: scale value of the Res2Net architecture.
        ndim: dimensionality of the hidden representation.
        output_size: output embedding dimension.
        **kwargs: accepted and ignored.

    Raises:
        ValueError: if ``block`` is not a supported block name.
    """

    @typechecked
    def __init__(
        self,
        input_size: int,
        block: str = "EcapaBlock",
        model_scale: int = 8,
        ndim: int = 1024,
        output_size: int = 1536,
        **kwargs,
    ):
        super().__init__()

        # Resolve the block name into a class under a separate local name so
        # the ``str`` parameter is never rebound to a different type.
        if block == "EcapaBlock":
            block_cls = EcapaBlock
        else:
            raise ValueError(f"unsupported block, got: {block}")

        self._output_size = output_size

        # Front-end: kernel 5 / padding 2 / stride 1 keeps the frame count.
        self.conv = nn.Conv1d(input_size, ndim, kernel_size=5, stride=1, padding=2)
        self.relu = nn.ReLU()
        self.bn = nn.BatchNorm1d(ndim)

        # Three encoder blocks with increasing dilation (2, 3, 4).
        self.layer1 = block_cls(
            ndim, ndim, kernel_size=3, dilation=2, scale=model_scale
        )
        self.layer2 = block_cls(
            ndim, ndim, kernel_size=3, dilation=3, scale=model_scale
        )
        self.layer3 = block_cls(
            ndim, ndim, kernel_size=3, dilation=4, scale=model_scale
        )

        # Pointwise convolution over the channel-wise concatenation of the
        # three block outputs, hence 3 * ndim input channels.
        self.layer4 = nn.Conv1d(3 * ndim, output_size, kernel_size=1)

        # Not used in forward(); kept so the module's attributes are unchanged.
        self.mp3 = nn.MaxPool1d(3)

    def output_size(self) -> int:
        """Return the output embedding dimension given at construction."""
        return self._output_size

    def forward(self, x: torch.Tensor):
        """Calculate forward propagation.

        Args:
            x (torch.Tensor): Input tensor (#batch, L, input_size).

        Returns:
            torch.Tensor: Channel-first output tensor
                (#batch, output_size, L). The result is NOT permuted back to
                (#batch, L, output_size). The time length equals L as long as
                the encoder blocks preserve it (assumed; the block class is
                defined elsewhere).
        """
        # Conv1d expects channel-first input: (#batch, input_size, L).
        x = self.conv(x.permute(0, 2, 1))
        x = self.relu(x)
        x = self.bn(x)

        # Each block receives the sum of the front-end output and the outputs
        # of all preceding blocks.
        x1 = self.layer1(x)
        x2 = self.layer2(x + x1)
        x3 = self.layer3(x + x1 + x2)

        # Aggregate the three block outputs along the channel dimension.
        x = self.layer4(torch.cat((x1, x2, x3), dim=1))
        x = self.relu(x)

        return x