espnet2.asr_transducer.encoder.modules.convolution.ConvolutionalSpatialGatingUnit
class espnet2.asr_transducer.encoder.modules.convolution.ConvolutionalSpatialGatingUnit(size: int, kernel_size: int, norm_class: torch.nn.Module = torch.nn.LayerNorm, norm_args: Dict = {}, dropout_rate: float = 0.0, causal: bool = False)
Bases: Module
Convolutional Spatial Gating Unit module definition.
- Parameters:
- size – Initial size to determine the number of channels.
- kernel_size – Size of the convolving kernel.
- norm_class – Normalization module class.
- norm_args – Normalization module arguments.
- dropout_rate – Dropout rate.
- causal – Whether to use causal convolution (set to True if streaming).
Construct a ConvolutionalSpatialGatingUnit object.
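The `causal` parameter is described only briefly above. As a rough illustration of what causal versus non-causal padding usually means for a depthwise `Conv1d`, the sketch below pads the time axis by hand. It reflects the general technique, not necessarily this class's internals, and all variable names are illustrative.

```python
import torch

kernel_size = 3
channels = 4

# Depthwise 1-D convolution (one filter per channel) with no built-in padding.
conv = torch.nn.Conv1d(channels, channels, kernel_size, groups=channels)

x = torch.randn(2, channels, 10)  # (B, C, T)

# Non-causal: pad both sides so each frame sees past and future context.
sym = (kernel_size - 1) // 2
y_noncausal = conv(torch.nn.functional.pad(x, (sym, sym)))

# Causal: pad only on the left so frame t depends on frames <= t.
y_causal = conv(torch.nn.functional.pad(x, (kernel_size - 1, 0)))

# Perturb the last frame; earlier causal outputs should not change.
x_future = x.clone()
x_future[:, :, -1] += 1.0
y_causal_future = conv(torch.nn.functional.pad(x_future, (kernel_size - 1, 0)))
```

Both variants keep the time length unchanged for an odd `kernel_size`. Only the causal one is safe for streaming, because no output frame depends on input that has not arrived yet.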
forward(x: Tensor, mask: Tensor | None = None, cache: Tensor | None = None) → Tuple[Tensor, Tensor]
Compute the convolutional spatial gating unit output.
- Parameters:
- x – ConvolutionalSpatialGatingUnit input sequences. (B, T, D_hidden)
- mask – Source mask. (B, T_2)
- cache – ConvolutionalSpatialGatingUnit input cache. (1, D_hidden, conv_kernel)
- Returns: x – ConvolutionalSpatialGatingUnit output sequences. (B, ?, D_hidden)
- Return type: Tuple[Tensor, Tensor]
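To make the roles of `x` and `cache` more concrete, here is a minimal sketch of a causal spatial gating unit written in plain PyTorch. `SketchCSGU` is a hypothetical class for illustration, not the ESPnet implementation. It assumes the common pattern of splitting the input into two halves, normalizing one half, filtering it with a depthwise convolution, and using the result to gate the other half. In this sketch the output feature size is `size // 2` and the cache holds `kernel_size - 1` frames; the shapes documented above belong to the real module and may differ.

```python
from typing import Optional, Tuple

import torch


class SketchCSGU(torch.nn.Module):
    """Hypothetical gating unit for illustration; not the ESPnet class."""

    def __init__(self, size: int, kernel_size: int) -> None:
        super().__init__()
        channels = size // 2  # assumption: the input is split into two halves
        self.lorder = kernel_size - 1  # left context needed by a causal kernel
        self.norm = torch.nn.LayerNorm(channels)
        # Depthwise convolution over time, one filter per channel.
        self.conv = torch.nn.Conv1d(channels, channels, kernel_size, groups=channels)

    def forward(
        self, x: torch.Tensor, cache: Optional[torch.Tensor] = None
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        x_r, x_g = x.chunk(2, dim=-1)  # each (B, T, size // 2)
        x_g = self.norm(x_g).transpose(1, 2)  # (B, size // 2, T)
        if cache is None:
            # No history yet: left-pad the time axis with zeros.
            x_g = torch.nn.functional.pad(x_g, (self.lorder, 0))
        else:
            # Streaming: prepend the frames kept from the previous chunk.
            x_g = torch.cat([cache, x_g], dim=2)
        new_cache = x_g[..., -self.lorder:]  # reused by the next call
        gate = self.conv(x_g).transpose(1, 2)  # (B, T, size // 2)
        return x_r * gate, new_cache


torch.manual_seed(0)
unit = SketchCSGU(size=8, kernel_size=3).eval()
x = torch.randn(1, 6, 8)  # (B, T, D_hidden)

with torch.no_grad():
    # Whole utterance in one call.
    full, _ = unit(x)
    # Same utterance in two chunks, carrying the cache between calls.
    first, cache = unit(x[:, :3])
    second, cache = unit(x[:, 3:], cache)

streamed = torch.cat([first, second], dim=1)
```

Under these assumptions, the chunked output matches the full-sequence output, which is why a cache of past frames is enough to run a causal unit in streaming mode.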