Source code for espnet2.asr_transducer.encoder.building

"""Set of methods to build Transducer encoder architecture."""

from typing import Any, Dict, List, Optional, Union

from espnet2.asr_transducer.activation import get_activation
from espnet2.asr_transducer.encoder.blocks.branchformer import Branchformer
from espnet2.asr_transducer.encoder.blocks.conformer import Conformer
from espnet2.asr_transducer.encoder.blocks.conv1d import Conv1d
from espnet2.asr_transducer.encoder.blocks.conv_input import ConvInput
from espnet2.asr_transducer.encoder.blocks.ebranchformer import EBranchformer
from espnet2.asr_transducer.encoder.modules.attention import (  # noqa: H301
    RelPositionMultiHeadedAttention,
)
from espnet2.asr_transducer.encoder.modules.convolution import (  # noqa: H301
    ConformerConvolution,
    ConvolutionalSpatialGatingUnit,
    DepthwiseConvolution,
)
from espnet2.asr_transducer.encoder.modules.multi_blocks import MultiBlocks
from espnet2.asr_transducer.encoder.modules.positional_encoding import (  # noqa: H301
    RelPositionalEncoding,
)
from espnet2.asr_transducer.normalization import get_normalization
from espnet.nets.pytorch_backend.transformer.positionwise_feed_forward import (
    PositionwiseFeedForward,
)


[docs]def build_main_parameters( pos_wise_act_type: str = "swish", conv_mod_act_type: str = "swish", pos_enc_dropout_rate: float = 0.0, pos_enc_max_len: int = 5000, simplified_att_score: bool = False, norm_type: str = "layer_norm", conv_mod_norm_type: str = "layer_norm", after_norm_eps: Optional[float] = None, after_norm_partial: Optional[float] = None, blockdrop_rate: float = 0.0, dynamic_chunk_training: bool = False, short_chunk_threshold: float = 0.75, short_chunk_size: int = 25, num_left_chunks: int = 0, **activation_parameters, ) -> Dict[str, Any]: """Build encoder main parameters. Args: pos_wise_act_type: X-former position-wise feed-forward activation type. conv_mod_act_type: X-former convolution module activation type. pos_enc_dropout_rate: Positional encoding dropout rate. pos_enc_max_len: Positional encoding maximum length. simplified_att_score: Whether to use simplified attention score computation. norm_type: X-former normalization module type. conv_mod_norm_type: Conformer convolution module normalization type. after_norm_eps: Epsilon value for the final normalization. after_norm_partial: Value for the final normalization with RMSNorm. blockdrop_rate: Probability threshold of dropping out each encoder block. dynamic_chunk_training: Whether to use dynamic chunk training. short_chunk_threshold: Threshold for dynamic chunk selection. short_chunk_size: Minimum number of frames during dynamic chunk training. num_left_chunks: Number of left chunks the attention module can see. (null or negative value means full context) **activation_parameters: Parameters of the activation functions. (See espnet2/asr_transducer/activation.py) Returns: : Main encoder parameters """ main_params = {} main_params["pos_wise_act"] = get_activation( pos_wise_act_type, **activation_parameters ) main_params["conv_mod_act"] = get_activation( conv_mod_act_type, **activation_parameters ) main_params["pos_enc_dropout_rate"] = pos_enc_dropout_rate main_params["pos_enc_max_len"] = pos_enc_max_len main_params["simplified_att_score"] = simplified_att_score main_params["norm_type"] = norm_type main_params["conv_mod_norm_type"] = conv_mod_norm_type ( main_params["after_norm_class"], main_params["after_norm_args"], ) = get_normalization(norm_type, eps=after_norm_eps, partial=after_norm_partial) main_params["blockdrop_rate"] = blockdrop_rate main_params["dynamic_chunk_training"] = dynamic_chunk_training main_params["short_chunk_threshold"] = max(0, short_chunk_threshold) main_params["short_chunk_size"] = max(0, short_chunk_size) main_params["num_left_chunks"] = max(0, num_left_chunks) return main_params
[docs]def build_positional_encoding( block_size: int, configuration: Dict[str, Any] ) -> RelPositionalEncoding: """Build positional encoding block. Args: block_size: Input/output size. configuration: Positional encoding configuration. Returns: : Positional encoding module. """ return RelPositionalEncoding( block_size, configuration.get("pos_enc_dropout_rate", 0.0), max_len=configuration.get("pos_enc_max_len", 5000), )
[docs]def build_input_block( input_size: int, configuration: Dict[str, Union[str, int]], ) -> ConvInput: """Build encoder input block. Args: input_size: Input size. configuration: Input block configuration. Returns: : ConvInput block function. """ return ConvInput( input_size, configuration["conv_size"], configuration["subsampling_factor"], vgg_like=configuration["vgg_like"], output_size=configuration["output_size"], )
[docs]def build_branchformer_block( configuration: List[Dict[str, Any]], main_params: Dict[str, Any], ) -> Branchformer: """Build Branchformer block. Args: configuration: Branchformer block configuration. main_params: Encoder main parameters. Returns: : Branchformer block function. """ hidden_size = configuration["hidden_size"] linear_size = configuration["linear_size"] dropout_rate = configuration.get("dropout_rate", 0.0) conv_mod_norm_class, conv_mod_norm_args = get_normalization( main_params["conv_mod_norm_type"], eps=configuration.get("conv_mod_norm_eps"), partial=configuration.get("conv_mod_norm_partial"), ) conv_mod_args = ( linear_size, configuration["conv_mod_kernel_size"], conv_mod_norm_class, conv_mod_norm_args, dropout_rate, main_params["dynamic_chunk_training"], ) mult_att_args = ( configuration.get("heads", 4), hidden_size, configuration.get("att_dropout_rate", 0.0), main_params["simplified_att_score"], ) norm_class, norm_args = get_normalization( main_params["norm_type"], eps=configuration.get("norm_eps"), partial=configuration.get("norm_partial"), ) return lambda: Branchformer( hidden_size, linear_size, RelPositionMultiHeadedAttention(*mult_att_args), ConvolutionalSpatialGatingUnit(*conv_mod_args), norm_class=norm_class, norm_args=norm_args, dropout_rate=dropout_rate, )
[docs]def build_conformer_block( configuration: List[Dict[str, Any]], main_params: Dict[str, Any], ) -> Conformer: """Build Conformer block. Args: configuration: Conformer block configuration. main_params: Encoder main parameters. Returns: : Conformer block function. """ hidden_size = configuration["hidden_size"] linear_size = configuration["linear_size"] pos_wise_args = ( hidden_size, linear_size, configuration.get("pos_wise_dropout_rate", 0.0), main_params["pos_wise_act"], ) conv_mod_norm_args = { "eps": configuration.get("conv_mod_norm_eps", 1e-05), "momentum": configuration.get("conv_mod_norm_momentum", 0.1), } conv_mod_args = ( hidden_size, configuration["conv_mod_kernel_size"], main_params["conv_mod_act"], conv_mod_norm_args, main_params["dynamic_chunk_training"], ) mult_att_args = ( configuration.get("heads", 4), hidden_size, configuration.get("att_dropout_rate", 0.0), main_params["simplified_att_score"], ) norm_class, norm_args = get_normalization( main_params["norm_type"], eps=configuration.get("norm_eps"), partial=configuration.get("norm_partial"), ) return lambda: Conformer( hidden_size, RelPositionMultiHeadedAttention(*mult_att_args), PositionwiseFeedForward(*pos_wise_args), PositionwiseFeedForward(*pos_wise_args), ConformerConvolution(*conv_mod_args), norm_class=norm_class, norm_args=norm_args, dropout_rate=configuration.get("dropout_rate", 0.0), )
[docs]def build_conv1d_block( configuration: List[Dict[str, Any]], causal: bool, ) -> Conv1d: """Build Conv1d block. Args: configuration: Conv1d block configuration. Returns: : Conv1d block function. """ return lambda: Conv1d( configuration["input_size"], configuration["output_size"], configuration["kernel_size"], stride=configuration.get("stride", 1), dilation=configuration.get("dilation", 1), groups=configuration.get("groups", 1), bias=configuration.get("bias", True), relu=configuration.get("relu", True), batch_norm=configuration.get("batch_norm", False), causal=causal, dropout_rate=configuration.get("dropout_rate", 0.0), )
[docs]def build_ebranchformer_block( configuration: List[Dict[str, Any]], main_params: Dict[str, Any], ) -> EBranchformer: """Build E-Branchformer block. Args: configuration: E-Branchformer block configuration. main_params: Encoder main parameters. Returns: : E-Branchformer block function. """ hidden_size = configuration["hidden_size"] linear_size = configuration["linear_size"] dropout_rate = configuration.get("dropout_rate", 0.0) pos_wise_args = ( hidden_size, linear_size, configuration.get("pos_wise_dropout_rate", 0.0), main_params["pos_wise_act"], ) conv_mod_norm_class, conv_mod_norm_args = get_normalization( main_params["conv_mod_norm_type"], eps=configuration.get("conv_mod_norm_eps"), partial=configuration.get("conv_mod_norm_partial"), ) conv_mod_args = ( linear_size, configuration["conv_mod_kernel_size"], conv_mod_norm_class, conv_mod_norm_args, dropout_rate, main_params["dynamic_chunk_training"], ) mult_att_args = ( configuration.get("heads", 4), hidden_size, configuration.get("att_dropout_rate", 0.0), main_params["simplified_att_score"], ) depthwise_conv_args = ( hidden_size, configuration.get( "depth_conv_kernel_size", configuration["conv_mod_kernel_size"] ), main_params["dynamic_chunk_training"], ) norm_class, norm_args = get_normalization( main_params["norm_type"], eps=configuration.get("norm_eps"), partial=configuration.get("norm_partial"), ) return lambda: EBranchformer( hidden_size, linear_size, RelPositionMultiHeadedAttention(*mult_att_args), PositionwiseFeedForward(*pos_wise_args), PositionwiseFeedForward(*pos_wise_args), ConvolutionalSpatialGatingUnit(*conv_mod_args), DepthwiseConvolution(*depthwise_conv_args), norm_class=norm_class, norm_args=norm_args, dropout_rate=dropout_rate, )
[docs]def build_body_blocks( configuration: List[Dict[str, Any]], main_params: Dict[str, Any], output_size: int, ) -> MultiBlocks: """Build encoder body blocks. Args: configuration: Body blocks configuration. main_params: Encoder main parameters. output_size: Architecture output size. Returns: MultiBlocks function encapsulation all encoder blocks. """ fn_modules = [] extended_conf = [] for c in configuration: if c.get("num_blocks") is not None: extended_conf += c["num_blocks"] * [ {c_i: c[c_i] for c_i in c if c_i != "num_blocks"} ] else: extended_conf += [c] for i, c in enumerate(extended_conf): block_type = c["block_type"] if block_type == "branchformer": module = build_branchformer_block(c, main_params) elif block_type == "conformer": module = build_conformer_block(c, main_params) elif block_type == "conv1d": module = build_conv1d_block(c, main_params["dynamic_chunk_training"]) elif block_type == "ebranchformer": module = build_ebranchformer_block(c, main_params) else: raise NotImplementedError fn_modules.append(module) return MultiBlocks( [fn() for fn in fn_modules], output_size, norm_class=main_params["after_norm_class"], norm_args=main_params["after_norm_args"], blockdrop_rate=main_params["blockdrop_rate"], )