"""Custom encoder definition for transducer models."""
from typing import List, Optional, Tuple, Union

import torch

from espnet.nets.pytorch_backend.transducer.blocks import build_blocks
from espnet.nets.pytorch_backend.transducer.vgg2l import VGG2L
from espnet.nets.pytorch_backend.transformer.layer_norm import LayerNorm
from espnet.nets.pytorch_backend.transformer.subsampling import Conv2dSubsampling
class CustomEncoder(torch.nn.Module):
    """Custom encoder module for transducer models.

    Args:
        idim: Input dimension.
        enc_arch: Encoder block architecture (type and parameters).
        input_layer: Input layer type.
        repeat_block: Number of times enc_arch is repeated.
        self_attn_type: Self-attention type.
        positional_encoding_type: Positional encoding type.
        positionwise_layer_type: Positionwise layer type.
        positionwise_activation_type: Positionwise activation type.
        conv_mod_activation_type: Convolutional module activation type.
        aux_enc_output_layers: Layer IDs for auxiliary encoder output sequences,
            compared against 0-based block indices. ``None`` (default) or an
            empty list disables auxiliary outputs.
        input_layer_dropout_rate: Dropout rate for input layer.
        input_layer_pos_enc_dropout_rate: Dropout rate for input layer pos. enc.
        padding_idx: Padding symbol ID for embedding layer.

    """

    def __init__(
        self,
        idim: int,
        enc_arch: List,
        input_layer: str = "linear",
        repeat_block: int = 1,
        self_attn_type: str = "selfattn",
        positional_encoding_type: str = "abs_pos",
        positionwise_layer_type: str = "linear",
        positionwise_activation_type: str = "relu",
        conv_mod_activation_type: str = "relu",
        aux_enc_output_layers: Optional[List[int]] = None,
        input_layer_dropout_rate: float = 0.0,
        input_layer_pos_enc_dropout_rate: float = 0.0,
        padding_idx: int = -1,
    ):
        """Construct a CustomEncoder object."""
        super().__init__()

        (
            self.embed,
            self.encoders,
            self.enc_out,
            self.conv_subsampling_factor,
        ) = build_blocks(
            "encoder",
            idim,
            input_layer,
            enc_arch,
            repeat_block=repeat_block,
            self_attn_type=self_attn_type,
            positional_encoding_type=positional_encoding_type,
            positionwise_layer_type=positionwise_layer_type,
            positionwise_activation_type=positionwise_activation_type,
            conv_mod_activation_type=conv_mod_activation_type,
            input_layer_dropout_rate=input_layer_dropout_rate,
            input_layer_pos_enc_dropout_rate=input_layer_pos_enc_dropout_rate,
            padding_idx=padding_idx,
        )

        self.after_norm = LayerNorm(self.enc_out)

        # Total number of blocks after repetition; used to step through
        # self.encoders one block at a time when auxiliary outputs are requested.
        self.n_blocks = len(enc_arch) * repeat_block

        # Store a fresh list: a shared mutable default (or a caller-owned list)
        # would otherwise be aliased across every encoder instance.
        self.aux_enc_output_layers = (
            [] if aux_enc_output_layers is None else list(aux_enc_output_layers)
        )

    @staticmethod
    def _get_sequences(
        enc_out: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
    ) -> torch.Tensor:
        """Return the sequence tensor from an embed/block output.

        Outputs may be a plain tensor or a tuple whose first element is the
        sequence tensor. NOTE(review): the tuple's second element is presumably
        a positional embedding produced for relative positional encoding;
        confirm against build_blocks.
        """
        return enc_out[0] if isinstance(enc_out, tuple) else enc_out

    def forward(
        self,
        feats: torch.Tensor,
        mask: torch.Tensor,
    ) -> Tuple[
        Union[torch.Tensor, Tuple[torch.Tensor, List[torch.Tensor]]],
        Union[torch.Tensor, Tuple[torch.Tensor, List[torch.Tensor]]],
    ]:
        """Encode feature sequences.

        Args:
            feats: Feature sequences. (B, F, D_feats)
            mask: Feature mask sequences. (B, 1, F)

        Returns:
            enc_out: Encoder output sequences. (B, T, D_enc)
                If aux_enc_output_layers is set, this is instead a tuple
                (enc_out, aux_outputs), where aux_outputs is a list of
                normalized auxiliary encoder output sequences (B, T, D_enc_aux),
                one per requested layer, in block order.
            enc_out_mask: Mask for encoder output sequences. (B, 1, T)
                If aux_enc_output_layers is set, this is instead a tuple
                (enc_out_mask, aux_masks), where aux_masks is a list holding the
                mask returned by each requested block. (B, 1, T)

        """
        # Only subsampling input layers consume (and update) the mask.
        if isinstance(self.embed, (Conv2dSubsampling, VGG2L)):
            enc_out, mask = self.embed(feats, mask)
        else:
            enc_out = self.embed(feats)

        if not self.aux_enc_output_layers:
            enc_out, mask = self.encoders(enc_out, mask)

            return self.after_norm(self._get_sequences(enc_out)), mask

        # Auxiliary path: run the blocks one by one so intermediate outputs can
        # be captured. Build the lookup set once instead of scanning a list.
        aux_layers = set(self.aux_enc_output_layers)
        aux_outputs = []
        aux_masks = []

        for block_id in range(self.n_blocks):
            enc_out, mask = self.encoders[block_id](enc_out, mask)

            if block_id in aux_layers:
                aux_outputs.append(self.after_norm(self._get_sequences(enc_out)))
                aux_masks.append(mask)

        enc_out = self.after_norm(self._get_sequences(enc_out))

        return (enc_out, aux_outputs), (mask, aux_masks)