espnet2.asr.encoder.branchformer_encoder.BranchformerEncoder
class espnet2.asr.encoder.branchformer_encoder.BranchformerEncoder(input_size: int, output_size: int = 256, use_attn: bool = True, attention_heads: int = 4, attention_layer_type: str = 'rel_selfattn', pos_enc_layer_type: str = 'rel_pos', rel_pos_type: str = 'latest', use_cgmlp: bool = True, cgmlp_linear_units: int = 2048, cgmlp_conv_kernel: int = 31, use_linear_after_conv: bool = False, gate_activation: str = 'identity', merge_method: str = 'concat', cgmlp_weight: float | List[float] = 0.5, attn_branch_drop_rate: float | List[float] = 0.0, num_blocks: int = 12, dropout_rate: float = 0.1, positional_dropout_rate: float = 0.1, attention_dropout_rate: float = 0.0, input_layer: str | None = 'conv2d', zero_triu: bool = False, padding_idx: int = -1, stochastic_depth_rate: float | List[float] = 0.0, qk_norm: bool = False, use_flash_attn: bool = True)
Bases: AbsEncoder
Branchformer encoder module.
Initializes internal Module state, shared by both nn.Module and ScriptModule.
forward(xs_pad: Tensor, ilens: Tensor, prev_states: Tensor | None = None) → Tuple[Tensor, Tensor, Tensor | None]
Calculate forward propagation.
- Parameters:
- xs_pad (torch.Tensor) – Input tensor (#batch, L, input_size).
- ilens (torch.Tensor) – Input length (#batch).
- prev_states (torch.Tensor) – Currently unused.
- Returns: A tuple of three items: the output tensor (#batch, L, output_size), the output lengths (#batch), and a third value that is currently unused.
- Return type: Tuple[torch.Tensor, torch.Tensor, torch.Tensor | None]
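The relationship between xs_pad and ilens can be illustrated with a small plain-Python sketch. The helper name pad_batch and the pad value 0.0 are assumptions made for this illustration, not part of ESPnet; nested lists stand in for tensors so the shapes are easy to inspect.

```python
from typing import List, Tuple

Frame = List[float]  # one feature vector of length input_size


def pad_batch(
    utterances: List[List[Frame]], pad_value: float = 0.0
) -> Tuple[List[List[Frame]], List[int]]:
    """Pad variable-length utterances to a common length L (hypothetical helper).

    Returns (xs_pad, ilens), where xs_pad has layout (#batch, L, input_size)
    and ilens holds the original, unpadded length of each utterance.
    """
    ilens = [len(utt) for utt in utterances]
    max_len = max(ilens)
    input_size = len(utterances[0][0])
    xs_pad = [
        # Copy the real frames, then append padding frames up to max_len.
        [list(frame) for frame in utt]
        + [[pad_value] * input_size for _ in range(max_len - len(utt))]
        for utt in utterances
    ]
    return xs_pad, ilens


# Two utterances with 4 and 2 frames, input_size = 3.
utts = [
    [[0.1, 0.2, 0.3] for _ in range(4)],
    [[0.4, 0.5, 0.6] for _ in range(2)],
]
xs_pad, ilens = pad_batch(utts)
```

In practice both values would be converted to torch.Tensor objects before being passed to forward; ilens lets the encoder tell real frames from padding.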
output_size() → int