espnet2.asr_transducer.decoder.modules.mega.multi_head_damped_ema.MultiHeadDampedEMA
class espnet2.asr_transducer.decoder.modules.mega.multi_head_damped_ema.MultiHeadDampedEMA(size: int, num_heads: int = 4, activation: Module = ReLU(), truncation_length: int | None = None)
Bases: Module
MultiHeadDampedEMA module definition.
- Parameters:
- size – Module size.
- num_heads – Number of attention heads.
- activation – Activation function type.
- truncation_length – Maximum length for truncation.
Construct a MultiHeadDampedEMA object.
compute_ema_coefficients() → Tuple[Tensor, Tensor]
Compute EMA coefficients.
- Parameters: None
- Returns:
- damping_factor – Damping factor / P-th order coefficient. (size, num_heads, 1)
- prev_timestep_weight – Previous timestep weight / Q-th order coefficient. (size, num_heads, 1)
- Return type: Tuple[Tensor, Tensor]
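The coefficient computation can be illustrated for a single (channel, head) entry. This is a minimal sketch, assuming a MEGA-style parameterization: raw, unconstrained parameters are squashed with a sigmoid into a damping factor alpha and a decay factor delta, and the previous-timestep weight is 1 - alpha * delta. The function `ema_coefficients` and its arguments are hypothetical and not part of the ESPnet API.

```python
import math


def sigmoid(z: float) -> float:
    """Logistic function mapping any real number into (0, 1)."""
    return 1.0 / (1.0 + math.exp(-z))


def ema_coefficients(raw_damping: float, raw_decay: float) -> tuple[float, float]:
    """Hypothetical scalar stand-in for compute_ema_coefficients().

    Assumes sigmoid-squashed parameters and a previous-timestep weight
    of 1 - alpha * delta (MEGA-style); not the ESPnet implementation.
    """
    alpha = sigmoid(raw_damping)  # damping factor / P-th order coefficient
    delta = sigmoid(raw_decay)  # decay factor
    prev_timestep_weight = 1.0 - alpha * delta  # Q-th order coefficient
    return alpha, prev_timestep_weight


alpha, q = ema_coefficients(0.0, 0.0)
print(alpha, q)  # 0.5 0.75
```

Because alpha and delta both lie in (0, 1) under this assumption, the previous-timestep weight also stays in (0, 1), so past contributions decay instead of growing.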
compute_ema_kernel(length: int) → Tensor
Compute EMA kernel / Vandermonde product.
- Parameters: length – Sequence length.
- Returns: EMA kernel / Vandermonde product. (size, L)
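The kernel taps can be sketched for a single channel. Assuming the MEGA formulation, tap `lag` is a sum over heads n of eta_n * alpha_n * beta_n * q_n ** lag, where q_n is the previous-timestep weight; the powers q_n ** lag form a Vandermonde matrix, evaluated here in log space as exp(lag * log(q_n)). The expansion weights `beta`, projection weights `eta`, the `scale` factor, and the function name `ema_kernel` are assumptions for illustration, not confirmed ESPnet names.

```python
import math


def ema_kernel(
    alpha: list[float],
    q: list[float],
    beta: list[float],
    eta: list[float],
    length: int,
    scale: float = 1.0,
) -> list[float]:
    """Hypothetical single-channel stand-in for compute_ema_kernel().

    alpha: damping factor per head; q: previous-timestep weight per head;
    beta, eta: assumed expansion and projection weights per head.
    Returns `length` kernel taps K[0], ..., K[length - 1].
    """
    kernel = []
    for lag in range(length):
        # Row `lag` of the Vandermonde matrix holds q_n ** lag for every
        # head n, evaluated in log space as exp(lag * log(q_n)).
        tap = sum(
            e * a * b * math.exp(lag * math.log(qn))
            for a, qn, b, e in zip(alpha, q, beta, eta)
        )
        kernel.append(scale * tap)
    return kernel


k = ema_kernel(
    alpha=[0.5, 0.5], q=[0.5, 0.25], beta=[1.0, -1.0], eta=[1.0, 1.0], length=3
)
print([round(v, 6) for v in k])  # [0.0, 0.125, 0.09375]
```

The log-space form requires q_n > 0, which holds whenever q_n lies in (0, 1).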
ema_one_step(x: Tensor, state: Tensor | None = None) → Tuple[Tensor, Tensor]
Perform exponential moving average for a single step.
- Parameters:
- x – MultiHeadDampedEMA input sequences. (B, D, 1)
- state – MultiHeadDampedEMA state. (B, D, num_heads)
- Returns:
- out – MultiHeadDampedEMA output sequences. (B, 1, D)
- new_state – MultiHeadDampedEMA state. (B, D, num_heads)
- Return type: Tuple[Tensor, Tensor]
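The single-step update is the recurrent form of the same filter, which suits incremental decoding. The sketch below assumes the MEGA recurrence h_n = alpha_n * beta_n * x + q_n * h_n(previous) with output sum_n eta_n * h_n, drops the batch and channel dimensions, and checks that stepping through a sequence reproduces a causal convolution with the corresponding kernel taps. The function `ema_step` and the names `beta`, `eta`, and `scale` are hypothetical illustrations, not the ESPnet method itself.

```python
from typing import Optional


def ema_step(
    x: float,
    state: Optional[list[float]],
    alpha: list[float],
    q: list[float],
    beta: list[float],
    eta: list[float],
    scale: float = 1.0,
) -> tuple[float, list[float]]:
    """Hypothetical scalar stand-in for ema_one_step().

    Returns (out, new_state); new_state holds one value per head.
    """
    new_state = []
    for n in range(len(alpha)):
        h = alpha[n] * beta[n] * x  # contribution of the current input
        if state is not None:
            h += q[n] * state[n]  # decayed contribution of earlier inputs
        new_state.append(h)
    out = scale * sum(e * h for e, h in zip(eta, new_state))
    return out, new_state


alpha, q, beta, eta = [0.5, 0.5], [0.5, 0.25], [1.0, -1.0], [1.0, 1.0]
inputs = [1.0, 2.0, 3.0]

# Recurrent mode: feed one timestep at a time, carrying the state.
state, outputs = None, []
for x in inputs:
    y, state = ema_step(x, state, alpha, q, beta, eta)
    outputs.append(y)

# Convolution mode: K[lag] = sum_n eta_n * alpha_n * beta_n * q_n ** lag.
kernel = [
    sum(e * a * b * qn**lag for a, qn, b, e in zip(alpha, q, beta, eta))
    for lag in range(len(inputs))
]
conv = [
    sum(kernel[k] * inputs[t - k] for k in range(t + 1))
    for t in range(len(inputs))
]
print(outputs)  # [0.0, 0.125, 0.34375]
print(conv)  # [0.0, 0.125, 0.34375]
```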
forward(x: Tensor, mask: Tensor | None = None, state: Dict[str, Tensor] | None = None) → Tensor | None
Compute multi-dimensional damped EMA.
- Parameters:
- x – MultiHeadDampedEMA input sequence. (L, B, D)
- mask – Sequence mask. (B, 1, L)
- state – MultiHeadDampedEMA state. (B, D, num_heads)
- Returns:
- x – MultiHeadDampedEMA output sequence. (B, L, D)
- new_state – MultiHeadDampedEMA state. (B, D, num_heads)
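Without a state, the full-sequence path amounts to a causal convolution of each (batch, channel) sequence with the EMA kernel; `forward` takes `x` as (L, B, D) and returns (B, L, D), while the sketch below handles one such sequence as a Python list. It assumes that `mask` entries equal to True mark padded steps to be zeroed and that `truncation_length` caps the number of kernel taps at min(truncation_length, L); the constructor's `activation` is presumably applied to the result as well but is omitted here. `damped_ema_sequence` and `single_head_kernel` are hypothetical helpers.

```python
from typing import Callable, Optional


def single_head_kernel(num_taps: int) -> list[float]:
    """Hypothetical kernel with one head: alpha = q = 0.5, beta = eta = 1."""
    return [0.5 * 0.5**lag for lag in range(num_taps)]


def damped_ema_sequence(
    xs: list[float],
    kernel_fn: Callable[[int], list[float]],
    mask: Optional[list[bool]] = None,
    truncation_length: Optional[int] = None,
) -> list[float]:
    """Hypothetical single-channel sketch of the stateless forward path.

    mask: True marks a padded step whose input is zeroed (assumed).
    truncation_length: cap on the number of kernel taps (assumed).
    """
    length = len(xs)
    if mask is not None:
        # Zero padded inputs so they contribute nothing to the filter.
        xs = [0.0 if m else x for x, m in zip(xs, mask)]
    num_taps = length if truncation_length is None else min(truncation_length, length)
    kernel = kernel_fn(num_taps)
    # Causal convolution: y[t] = sum_k K[k] * x[t - k] over the available taps.
    return [
        sum(kernel[k] * xs[t - k] for k in range(min(t + 1, num_taps)))
        for t in range(length)
    ]


print(damped_ema_sequence([1.0] * 4, single_head_kernel))  # [0.5, 0.75, 0.875, 0.9375]
print(damped_ema_sequence([1.0] * 4, single_head_kernel, truncation_length=2))  # [0.5, 0.75, 0.75, 0.75]
```

The direct convolution costs O(L·T) per channel for T taps; an FFT-based convolution produces the same result in O(L log L), which is why long-sequence EMA implementations often prefer it.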
get_ema_coefficients() → Tuple[Tensor, Tensor]
Get EMA coefficients.
- Parameters: None
- Returns:
- Damping factor / P-th order coefficient. (size, num_heads, 1)
- Previous timestep weight / Q-th order coefficient. (size, num_heads, 1)
- Return type: Tuple[Tensor, Tensor]
reset_parameters(val: float = 0.0, std1: float = 0.2, std2: float = 1.0) → None
Reset module parameters.
- Parameters:
- val – Initialization value.
- std1 – Main standard deviation.
- std2 – Secondary standard deviation.