espnet2.asr_transducer.encoder.blocks.branchformer.Branchformer
class espnet2.asr_transducer.encoder.blocks.branchformer.Branchformer(block_size: int, linear_size: int, self_att: torch.nn.Module, conv_mod: torch.nn.Module, norm_class: torch.nn.Module = torch.nn.LayerNorm, norm_args: Dict = {}, dropout_rate: float = 0.0)
Bases: Module
Branchformer module definition.
Reference: https://arxiv.org/pdf/2207.02971.pdf
- Parameters:
- block_size – Input/output size.
- linear_size – Linear layers’ hidden size.
- self_att – Self-attention module instance.
- conv_mod – Convolution module instance.
- norm_class – Normalization class.
- norm_args – Normalization module arguments.
- dropout_rate – Dropout rate.
Construct a Branchformer object.
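Example — a minimal construction sketch. The attention and convolution modules, their import paths, and their constructor signatures below are assumptions drawn from the espnet2.asr_transducer.encoder.modules package; verify them against your installed ESPnet version.

import torch

# Assumed import paths; check your ESPnet version.
from espnet2.asr_transducer.encoder.blocks.branchformer import Branchformer
from espnet2.asr_transducer.encoder.modules.attention import RelPositionMultiHeadedAttention
from espnet2.asr_transducer.encoder.modules.convolution import ConvolutionalSpatialGatingUnit

block_size, linear_size = 256, 1024

block = Branchformer(
    block_size,
    linear_size,
    RelPositionMultiHeadedAttention(4, block_size, 0.1),  # assumed signature: (num_heads, embed_size, dropout_rate)
    ConvolutionalSpatialGatingUnit(linear_size, 31),      # assumed signature: (size, kernel_size)
    norm_args={"eps": 1e-12},
    dropout_rate=0.1,
)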
chunk_forward(x: Tensor, pos_enc: Tensor, mask: Tensor, left_context: int = 0) → Tuple[Tensor, Tensor]
Encode a chunk of the input sequence.
- Parameters:
- x – Branchformer input sequences. (B, T, D_block)
- pos_enc – Positional embedding sequences. (B, 2 * (T - 1), D_block)
- mask – Source mask. (B, T_2)
- left_context – Number of previous frames the attention module can see in the current chunk.
- Returns:
- x – Branchformer output sequences. (B, T, D_block)
- pos_enc – Positional embedding sequences. (B, 2 * (T - 1), D_block)
- Return type: Tuple[Tensor, Tensor]
forward(x: Tensor, pos_enc: Tensor, mask: Tensor, chunk_mask: Tensor | None = None) → Tuple[Tensor, Tensor, Tensor]
Encode input sequences.
- Parameters:
- x – Branchformer input sequences. (B, T, D_block)
- pos_enc – Positional embedding sequences. (B, 2 * (T - 1), D_block)
- mask – Source mask. (B, T)
- chunk_mask – Chunk mask. (T_2, T_2)
- Returns:
- x – Branchformer output sequences. (B, T, D_block)
- mask – Source mask. (B, T)
- pos_enc – Positional embedding sequences. (B, 2 * (T - 1), D_block)
- Return type: Tuple[Tensor, Tensor, Tensor]
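Example — a hedged sketch of an offline forward pass, reusing the block instance from the construction sketch above. The positional embeddings are random tensors of the documented shape; in ESPnet they come from the encoder's positional-encoding module.

import torch

B, T, D = 2, 50, 256  # batch, frames, block_size

x = torch.randn(B, T, D)
pos_enc = torch.randn(B, 2 * (T - 1), D)   # shape only; real values come from the pos-enc module
mask = torch.ones(B, T, dtype=torch.bool)  # source mask (no padding in this toy batch)

x, mask, pos_enc = block(x, pos_enc, mask)
print(x.shape)  # torch.Size([2, 50, 256])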
reset_streaming_cache(left_context: int, device: device) → None
Initialize/reset the self-attention and convolution module caches for streaming.
- Parameters:
- left_context – Number of previous frames the attention module can see in the current chunk.
- device – Device to use for cache tensor.
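Example — a hedged streaming sketch combining reset_streaming_cache with chunk_forward above. The chunk size, left_context, and mask construction are illustrative only; the exact mask sizing (T_2) depends on the attention module and the cached left context in your ESPnet version.

import torch

B, D = 1, 256
chunk_size, left_context = 16, 32

# Caches must be initialized/reset before the first chunk of a new utterance.
block.reset_streaming_cache(left_context, torch.device("cpu"))

stream = torch.randn(B, 64, D)  # stand-in for incoming encoder features

for start in range(0, stream.size(1), chunk_size):
    chunk = stream[:, start : start + chunk_size]
    T = chunk.size(1)
    pos_enc = torch.randn(B, 2 * (T - 1), D)   # shape only; see pos-enc note above
    mask = torch.ones(B, T, dtype=torch.bool)  # illustrative; actual size is (B, T_2)
    chunk, pos_enc = block.chunk_forward(chunk, pos_enc, mask, left_context=left_context)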