espnet2.asr_transducer.normalization.RMSNorm
class espnet2.asr_transducer.normalization.RMSNorm(normalized_shape: int, eps: float = 1e-05, partial: float = 0.0)
Bases: Module
RMSNorm module definition.
Reference: https://arxiv.org/pdf/1910.07467.pdf
- Parameters:
- normalized_shape – Expected size of the last (feature) dimension of the input, i.e. D_hidden.
- eps – Value added to the denominator for numerical stability.
- partial – Fraction of the input features used to compute the RMS statistics (partial RMSNorm). With the default 0.0, all features are used.
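As a sketch of the normalization these parameters control, following the referenced paper: here d is normalized_shape, p is partial, ε is eps (placed in the denominator, as the eps description states), and g is a per-feature gain. The gain g and the floor rounding of k are assumptions of this sketch, not confirmed by this page.

$$
k = \begin{cases} \max(1, \lfloor p\,d \rfloor) & 0 < p < 1 \\ d & \text{otherwise} \end{cases}, \qquad
\mathrm{RMS}(x) = \sqrt{\frac{1}{k}\sum_{i=1}^{k} x_i^2}, \qquad
y_i = g_i \, \frac{x_i}{\mathrm{RMS}(x) + \epsilon}, \quad i = 1,\dots,d
$$

Only the first k features feed the RMS estimate, but all d features are rescaled.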
Construct an RMSNorm object.
forward(x: Tensor) → Tensor
Compute RMS normalization.
- Parameters: x – Input sequences. (B, T, D_hidden)
- Returns: Output sequences. (B, T, D_hidden)
- Return type: Tensor
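The following is a framework-free sketch of what forward computes on a (B, T, D_hidden) input. It uses plain Python lists instead of PyTorch tensors so it runs without dependencies, and it is not ESPnet's implementation. The function name `rms_norm`, the nested-list layout, and the `scale` argument (standing in for a learnable per-feature gain) are hypothetical. It also assumes that a `partial` value in (0, 1) selects the leading int(D * partial) features for the RMS estimate, as in the pRMSNorm variant of the referenced paper.

```python
import math
from typing import List

Vector = List[float]


def rms_norm(
    x: List[List[Vector]],
    scale: Vector,
    eps: float = 1e-5,
    partial: float = 0.0,
) -> List[List[Vector]]:
    """Hypothetical RMS normalization over the last dim of a (B, T, D) nested list.

    Assumption: if 0 < partial < 1, only the first int(D * partial) features
    (at least one) are used to estimate the RMS; otherwise all D features are.
    """
    out: List[List[Vector]] = []
    for seq in x:                      # iterate over the batch (B)
        out_seq: List[Vector] = []
        for frame in seq:              # iterate over time steps (T)
            d = len(frame)
            k = max(1, int(d * partial)) if 0.0 < partial < 1.0 else d
            # Root mean square of the first k features.
            rms = math.sqrt(sum(v * v for v in frame[:k]) / k)
            # eps is added to the denominator for numerical stability;
            # every feature (not only the first k) is rescaled by the gain.
            out_seq.append([g * v / (rms + eps) for g, v in zip(scale, frame)])
        out.append(out_seq)
    return out


# B=1, T=2, D=2: the first frame is normalized, the all-zero frame stays zero.
y = rms_norm([[[3.0, 4.0], [0.0, 0.0]]], scale=[1.0, 1.0])
print(y)
```

After normalization with unit gains, each non-zero frame has a mean square close to 1; with `partial=0.5` the same frame is instead divided by the RMS of its first feature only.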