"""Custom encoder definition for transducer models."""

from typing import List, Optional, Tuple, Union

import torch

from espnet.nets.pytorch_backend.transducer.blocks import build_blocks
from espnet.nets.pytorch_backend.transducer.vgg2l import VGG2L
from espnet.nets.pytorch_backend.transformer.layer_norm import LayerNorm
from espnet.nets.pytorch_backend.transformer.subsampling import Conv2dSubsampling

class CustomEncoder(torch.nn.Module):
    """Custom encoder module for transducer models.

    Args:
        idim: Input dimension.
        enc_arch: Encoder block architecture (type and parameters).
        input_layer: Input layer type.
        repeat_block: Number of times enc_arch is repeated.
        self_attn_type: Self-attention type.
        positional_encoding_type: Positional encoding type.
        positionwise_layer_type: Positionwise layer type.
        positionwise_activation_type: Positionwise activation type.
        conv_mod_activation_type: Convolutional module activation type.
        aux_enc_output_layers: Layer IDs for auxiliary encoder output sequences.
            IDs are matched against the 0-based block index used in ``forward``.
            ``None`` (default) or an empty list disables auxiliary outputs.
        input_layer_dropout_rate: Dropout rate for input layer.
        input_layer_pos_enc_dropout_rate: Dropout rate for input layer pos. enc.
        padding_idx: Padding symbol ID for embedding layer.

    """

    def __init__(
        self,
        idim: int,
        enc_arch: List,
        input_layer: str = "linear",
        repeat_block: int = 1,
        self_attn_type: str = "selfattn",
        positional_encoding_type: str = "abs_pos",
        positionwise_layer_type: str = "linear",
        positionwise_activation_type: str = "relu",
        conv_mod_activation_type: str = "relu",
        aux_enc_output_layers: Optional[List[int]] = None,
        input_layer_dropout_rate: float = 0.0,
        input_layer_pos_enc_dropout_rate: float = 0.0,
        padding_idx: int = -1,
    ):
        """Construct a CustomEncoder object."""
        super().__init__()

        (
            self.embed,
            self.encoders,
            self.enc_out,
            self.conv_subsampling_factor,
        ) = build_blocks(
            "encoder",
            idim,
            input_layer,
            enc_arch,
            repeat_block=repeat_block,
            self_attn_type=self_attn_type,
            positional_encoding_type=positional_encoding_type,
            positionwise_layer_type=positionwise_layer_type,
            positionwise_activation_type=positionwise_activation_type,
            conv_mod_activation_type=conv_mod_activation_type,
            input_layer_dropout_rate=input_layer_dropout_rate,
            input_layer_pos_enc_dropout_rate=input_layer_pos_enc_dropout_rate,
            padding_idx=padding_idx,
        )

        # The same normalization layer is applied to the final output and to
        # every auxiliary output collected in forward().
        self.after_norm = LayerNorm(self.enc_out)

        self.n_blocks = len(enc_arch) * repeat_block

        # A None default (instead of a literal ``[]``) guarantees that instances
        # built without this argument never share one module-level list object.
        self.aux_enc_output_layers = (
            [] if aux_enc_output_layers is None else aux_enc_output_layers
        )

    def forward(
        self,
        feats: torch.Tensor,
        mask: torch.Tensor,
    ) -> Tuple[
        Union[torch.Tensor, Tuple[torch.Tensor, List[torch.Tensor]]],
        Union[torch.Tensor, Tuple[torch.Tensor, List[torch.Tensor]]],
    ]:
        """Encode feature sequences.

        Args:
            feats: Feature sequences. (B, F, D_feats)
            mask: Feature mask sequences. (B, 1, F)

        Returns:
            enc_out: Encoder output sequences. (B, T, D_enc)
                When auxiliary output layers are configured, this is instead a
                tuple ``(enc_out, aux_outputs)`` where ``aux_outputs`` is a list
                holding the normalized output of each selected block.
            mask: Mask for encoder output sequences. (B, 1, T)
                When auxiliary output layers are configured, this is instead a
                tuple ``(mask, aux_masks)`` where ``aux_masks`` is a list holding
                the mask returned by each selected block.

        """
        # Conv2dSubsampling and VGG2L input layers take the mask and return an
        # updated one (presumably because they subsample the time axis -- TODO
        # confirm); any other input layer only transforms the features.
        if isinstance(self.embed, (Conv2dSubsampling, VGG2L)):
            enc_out, mask = self.embed(feats, mask)
        else:
            enc_out = self.embed(feats)

        if self.aux_enc_output_layers:
            aux_custom_outputs = []
            aux_custom_masks = []

            # Blocks are run one at a time so intermediate outputs can be tapped.
            for b in range(self.n_blocks):
                enc_out, mask = self.encoders[b](enc_out, mask)

                if b in self.aux_enc_output_layers:
                    # A block may return a tuple; only its first item is the
                    # output sequence (the rest is presumably positional
                    # information -- TODO confirm).
                    if isinstance(enc_out, tuple):
                        aux_custom_output = enc_out[0]
                    else:
                        aux_custom_output = enc_out

                    aux_custom_outputs.append(self.after_norm(aux_custom_output))
                    aux_custom_masks.append(mask)
        else:
            enc_out, mask = self.encoders(enc_out, mask)

        if isinstance(enc_out, tuple):
            enc_out = enc_out[0]

        enc_out = self.after_norm(enc_out)

        if self.aux_enc_output_layers:
            return (enc_out, aux_custom_outputs), (mask, aux_custom_masks)

        return enc_out, mask