espnet2.asr.encoder.e_branchformer_ctc_encoder.EBranchformerCTCEncoder
class espnet2.asr.encoder.e_branchformer_ctc_encoder.EBranchformerCTCEncoder(input_size: int, output_size: int = 256, attention_heads: int = 4, attention_layer_type: str = 'rel_selfattn', pos_enc_layer_type: str = 'rel_pos', rel_pos_type: str = 'latest', cgmlp_linear_units: int = 2048, cgmlp_conv_kernel: int = 31, use_linear_after_conv: bool = False, gate_activation: str = 'identity', num_blocks: int = 12, dropout_rate: float = 0.1, positional_dropout_rate: float = 0.1, attention_dropout_rate: float = 0.0, input_layer: str | None = 'conv2d8', zero_triu: bool = False, padding_idx: int = -1, layer_drop_rate: float = 0.0, max_pos_emb_len: int = 5000, use_ffn: bool = False, macaron_ffn: bool = False, ffn_activation_type: str = 'swish', linear_units: int = 2048, positionwise_layer_type: str = 'linear', merge_conv_kernel: int = 3, interctc_layer_idx=None, interctc_use_conditioning: bool = False, use_cross_attention=True, use_flash_attn: bool = False)
Bases: AbsEncoder
E-Branchformer encoder module.
Compared to the original encoder in e_branchformer_encoder.py, this variant supports additional cross-attention modules. It also supports extra prefix tokens prepended to the input, which is useful for language and task conditioning.
Initializes internal Module state, shared by both nn.Module and ScriptModule.
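The default input_layer='conv2d8' subsamples the time axis by roughly a factor of 8 using three kernel-3, stride-2 convolutions. The following is a minimal pure-Python sketch of the resulting length arithmetic (illustrative only, not the actual ESPnet implementation):

```python
# Hedged sketch: approximate output length after a 'conv2d8'-style input
# layer, i.e., three successive kernel-3, stride-2 convolutions with no
# padding. Each convolution maps a length L to (L - 1) // 2.

def subsampled_length(ilen: int, num_stride2_convs: int = 3) -> int:
    """Length after repeated kernel-3, stride-2 convolutions (no padding)."""
    for _ in range(num_stride2_convs):
        ilen = (ilen - 1) // 2
    return ilen

# A 100-frame utterance shrinks to roughly 1/8 of its length:
print(subsampled_length(100))  # -> 11
```

This explains why very short utterances can produce zero-length encoder outputs when heavy subsampling is configured.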
forward(xs_pad: Tensor, ilens: Tensor, prev_states: Tensor | None = None, ctc: CTC | None = None, max_layer: int | None = None, prefix_embeds: Tensor | None = None, memory=None, memory_mask=None) → Tuple[Tensor, Tensor, Tensor | None]
Calculate forward propagation.
- Parameters:
- xs_pad (torch.Tensor) – Input tensor (#batch, L, input_size).
- ilens (torch.Tensor) – Input length (#batch).
- prev_states (torch.Tensor) – Not used currently.
- ctc (CTC) – Intermediate CTC module.
- max_layer (int) – Layer depth below which InterCTC is applied.
- prefix_embeds (torch.Tensor) – Optional embeddings of prefix tokens prepended to the input (e.g., for language and task conditioning).
- memory (torch.Tensor) – Optional memory used as the source for the cross-attention modules.
- memory_mask (torch.Tensor) – Optional mask for memory.
- Returns:
  - torch.Tensor: Output tensor (#batch, L, output_size).
  - torch.Tensor: Output length (#batch).
  - torch.Tensor: Not used currently.
- Return type: Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]
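The constructor's layer_drop_rate applies LayerDrop: during training, each encoder block is skipped independently with that probability, which regularizes deep stacks. A minimal pure-Python sketch of this behavior (names and structure are hypothetical, for illustration only):

```python
import random

# Hedged sketch of LayerDrop (layer_drop_rate): each block is skipped
# independently with probability layer_drop_rate while training; at
# inference all blocks run.

def run_layers(x, layers, layer_drop_rate=0.0, training=True, rng=None):
    rng = rng or random.Random(0)  # seeded for reproducibility in this demo
    for layer in layers:
        if training and rng.random() < layer_drop_rate:
            continue  # skip this block entirely
        x = layer(x)
    return x

# With 12 add-one "layers" and no dropping, all 12 apply:
layers = [lambda v: v + 1] * 12
print(run_layers(0, layers, layer_drop_rate=0.0))  # -> 12
```

With layer_drop_rate=0.0 (the default), this reduces to running every block, matching a standard E-Branchformer stack.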
output_size() → int
Get the output size (the encoder's hidden dimension, output_size).