espnet2.asr_transducer.encoder.blocks.ebranchformer.EBranchformer
class espnet2.asr_transducer.encoder.blocks.ebranchformer.EBranchformer(block_size: int, linear_size: int, self_att: torch.nn.Module, feed_forward: torch.nn.Module, feed_forward_macaron: torch.nn.Module, conv_mod: torch.nn.Module, depthwise_conv_mod: torch.nn.Module, norm_class: torch.nn.Module = torch.nn.LayerNorm, norm_args: Dict = {}, dropout_rate: float = 0.0)
Bases: Module
E-Branchformer module definition.
Reference: https://arxiv.org/pdf/2210.00077.pdf
- Parameters:
- block_size – Input/output size.
- linear_size – Linear layers’ hidden size.
- self_att – Self-attention module instance.
- feed_forward – Feed-forward module instance.
- feed_forward_macaron – Feed-forward module instance for macaron network.
- conv_mod – ConvolutionalSpatialGatingUnit module instance.
- depthwise_conv_mod – DepthwiseConvolution module instance.
- norm_class – Normalization class.
- norm_args – Normalization module arguments.
- dropout_rate – Dropout rate.
Construct an E-Branchformer object.
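A minimal construction sketch follows. The torch.nn.Identity placeholders only mark which argument fills which slot; a usable block needs real self-attention, feed-forward, and convolution module instances (normally assembled by the encoder builder from its configuration), and the norm_args example assumes those arguments are forwarded to norm_class as keyword arguments.

```python
import torch

from espnet2.asr_transducer.encoder.blocks.ebranchformer import EBranchformer

# Placeholder sub-modules: they only illustrate which argument fills which slot.
# A functional block needs real self-attention, feed-forward, and convolution
# modules; calling forward() with these placeholders is not expected to work.
placeholder = torch.nn.Identity

block = EBranchformer(
    block_size=256,                    # input/output size (D_block)
    linear_size=1024,                  # linear layers' hidden size
    self_att=placeholder(),            # self-attention module in practice
    feed_forward=placeholder(),        # feed-forward module in practice
    feed_forward_macaron=placeholder(),
    conv_mod=placeholder(),            # ConvolutionalSpatialGatingUnit in practice
    depthwise_conv_mod=placeholder(),  # DepthwiseConvolution in practice
    norm_class=torch.nn.LayerNorm,     # default normalization class
    norm_args={"eps": 1e-12},          # assumed to be passed to norm_class as kwargs
    dropout_rate=0.1,
)
```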
chunk_forward(x: Tensor, pos_enc: Tensor, mask: Tensor, left_context: int = 0) → Tuple[Tensor, Tensor]
Encode chunk of input sequence.
- Parameters:
- x – E-Branchformer input sequences. (B, T, D_block)
- pos_enc – Positional embedding sequences. (B, 2 * (T - 1), D_block)
- mask – Source mask. (B, T_2)
- left_context – Number of previous frames the attention module can see in the current chunk.
- Returns:
  - x – E-Branchformer output sequences. (B, T, D_block)
  - pos_enc – Positional embedding sequences. (B, 2 * (T - 1), D_block)
- Return type: Tuple[Tensor, Tensor]
forward(x: Tensor, pos_enc: Tensor, mask: Tensor, chunk_mask: Tensor | None = None) → Tuple[Tensor, Tensor, Tensor]
Encode input sequences.
- Parameters:
- x – E-Branchformer input sequences. (B, T, D_block)
- pos_enc – Positional embedding sequences. (B, 2 * (T - 1), D_block)
- mask – Source mask. (B, T)
- chunk_mask – Chunk mask. (T_2, T_2)
- Returns:
  - x – E-Branchformer output sequences. (B, T, D_block)
  - mask – Source mask. (B, T)
  - pos_enc – Positional embedding sequences. (B, 2 * (T - 1), D_block)
- Return type: Tuple[Tensor, Tensor, Tensor]
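The sketch below illustrates the documented input shapes and the (x, mask, pos_enc) return order of forward(). The _ShapeOnlyBlock stub is a hypothetical stand-in that merely echoes its inputs so the snippet runs without a fully built block; it performs none of the E-Branchformer computation.

```python
import torch

class _ShapeOnlyBlock(torch.nn.Module):
    """Hypothetical stand-in with the documented forward() signature; echoes its inputs."""

    def forward(self, x, pos_enc, mask, chunk_mask=None):
        return x, mask, pos_enc  # documented return order

block = _ShapeOnlyBlock()  # in practice: a fully built EBranchformer instance

B, T, D_block = 2, 64, 256
x = torch.randn(B, T, D_block)                  # input sequences (B, T, D_block)
pos_enc = torch.randn(B, 2 * (T - 1), D_block)  # positional embeddings (B, 2 * (T - 1), D_block)
mask = torch.ones(B, T, dtype=torch.bool)       # source (padding) mask (B, T)

x_out, mask_out, pos_enc_out = block(x, pos_enc, mask, chunk_mask=None)

assert x_out.shape == (B, T, D_block)
```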
reset_streaming_cache(left_context: int, device: device) → None
Initialize or reset the self-attention and convolution modules' cache for streaming.
- Parameters:
- left_context – Number of previous frames the attention module can see in the current chunk.
- device – Device to use for the cache tensor.
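For chunk-wise (streaming) encoding, reset_streaming_cache is called once before the first chunk and chunk_forward is then applied per chunk with the same left_context. The _StreamingStub below is a hypothetical stand-in that only mirrors the documented signatures so the loop runs; a real EBranchformer instance maintains the attention and convolution caches internally.

```python
import torch

class _StreamingStub(torch.nn.Module):
    """Hypothetical stand-in mirroring the documented streaming API; no real computation."""

    def reset_streaming_cache(self, left_context, device):
        self.cache = None  # the real block (re)allocates its attention/conv caches here

    def chunk_forward(self, x, pos_enc, mask, left_context=0):
        return x, pos_enc  # documented return order

block = _StreamingStub()  # in practice: a fully built EBranchformer instance
device = torch.device("cpu")

B, D_block = 1, 256
chunk_size, left_context = 16, 32

block.reset_streaming_cache(left_context, device)

for _ in range(4):  # feed the utterance chunk by chunk
    x = torch.randn(B, chunk_size, D_block)                   # (B, T, D_block)
    pos_enc = torch.randn(B, 2 * (chunk_size - 1), D_block)   # shapes follow the docstring above
    mask = torch.ones(B, chunk_size, dtype=torch.bool)

    x_out, pos_enc_out = block.chunk_forward(x, pos_enc, mask, left_context=left_context)
```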