Source code for espnet.nets.pytorch_backend.transducer.conv1d_nets

"""Convolution networks definition for custom archictecture."""

from typing import Optional, Tuple, Union

import torch


class Conv1d(torch.nn.Module):
    """1D convolution module for custom encoder.

    Args:
        idim: Input dimension.
        odim: Output dimension.
        kernel_size: Size of the convolving kernel.
        stride: Stride of the convolution.
        dilation: Spacing between the kernel points.
        groups: Number of blocked connections from input channels to output channels.
        bias: Whether to add a learnable bias to the output.
        batch_norm: Whether to use batch normalization after convolution.
        relu: Whether to use a ReLU activation after convolution.
        dropout_rate: Dropout rate.

    """

    def __init__(
        self,
        idim: int,
        odim: int,
        kernel_size: Union[int, Tuple],
        stride: Union[int, Tuple] = 1,
        dilation: Union[int, Tuple] = 1,
        groups: Union[int, Tuple] = 1,
        bias: bool = True,
        batch_norm: bool = False,
        relu: bool = True,
        dropout_rate: float = 0.0,
    ):
        """Construct a Conv1d module object."""
        super().__init__()

        self.conv = torch.nn.Conv1d(
            idim,
            odim,
            kernel_size,
            stride=stride,
            dilation=dilation,
            groups=groups,
            bias=bias,
        )

        self.dropout = torch.nn.Dropout(p=dropout_rate)

        if relu:
            self.relu_func = torch.nn.ReLU()

        if batch_norm:
            self.bn = torch.nn.BatchNorm1d(odim)

        self.relu = relu
        self.batch_norm = batch_norm

        self.padding = dilation * (kernel_size - 1)
        self.stride = stride

        self.out_pos = torch.nn.Linear(idim, odim)
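
    # Note: self.out_pos projects the subsampled positional embedding with a
    # Linear(idim, odim), so D_att is expected to equal idim whenever a
    # (sequence, pos_embed) tuple is passed to forward().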

    def forward(
        self,
        sequence: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
        mask: torch.Tensor,
    ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]], torch.Tensor]:
        """Forward Conv1d module object.

        Args:
            sequence: Input sequences.
                      (B, T, D_in) or ((B, T, D_in), (B, 2 * (T - 1), D_att))
            mask: Mask of input sequences. (B, 1, T)

        Returns:
            sequence: Output sequences.
                      (B, sub(T), D_out) or ((B, sub(T), D_out), (B, 2 * (sub(T) - 1), D_att))
            mask: Mask of output sequences. (B, 1, sub(T))

        """
        if isinstance(sequence, tuple):
            sequence, pos_embed = sequence[0], sequence[1]
        else:
            sequence, pos_embed = sequence, None

        sequence = sequence.transpose(1, 2)
        sequence = self.conv(sequence)

        if self.batch_norm:
            sequence = self.bn(sequence)

        sequence = self.dropout(sequence)

        if self.relu:
            sequence = self.relu_func(sequence)

        sequence = sequence.transpose(1, 2)

        mask = self.create_new_mask(mask)

        if pos_embed is not None:
            pos_embed = self.create_new_pos_embed(pos_embed)

            return (sequence, pos_embed), mask

        return sequence, mask
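
    # The tuple input path is, presumably, the relative positional encoding
    # case: the embedding of shape (B, 2 * (T - 1), D_att) is subsampled by
    # create_new_pos_embed() so it stays aligned with the subsampled sequence.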

    def create_new_mask(self, mask: torch.Tensor) -> torch.Tensor:
        """Create new mask.

        Args:
            mask: Mask of input sequences. (B, 1, T)

        Returns:
            mask: Mask of output sequences. (B, 1, sub(T))

        """
        if mask is None:
            return mask

        if self.padding != 0:
            mask = mask[:, :, : -self.padding]

        mask = mask[:, :, :: self.stride]

        return mask
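
    # Worked example for the mask arithmetic above (illustrative numbers): with
    # T = 10, kernel_size = 3, dilation = 1 and stride = 2, self.padding is
    # 1 * (3 - 1) = 2. Trimming the last 2 positions leaves 8 mask entries;
    # keeping every 2nd entry leaves 4, which matches the convolution output
    # length floor((10 - 2 - 1) / 2) + 1 = 4.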

    def create_new_pos_embed(self, pos_embed: torch.Tensor) -> torch.Tensor:
        """Create new positional embedding vector.

        Args:
            pos_embed: Input sequences positional embedding.
                       (B, 2 * (T - 1), D_att)

        Returns:
            pos_embed: Output sequences positional embedding.
                       (B, 2 * (sub(T) - 1), D_att)

        """
        pos_embed_positive = pos_embed[:, : pos_embed.size(1) // 2 + 1, :]
        pos_embed_negative = pos_embed[:, pos_embed.size(1) // 2 :, :]

        if self.padding != 0:
            pos_embed_positive = pos_embed_positive[:, : -self.padding, :]
            pos_embed_negative = pos_embed_negative[:, : -self.padding, :]

        pos_embed_positive = pos_embed_positive[:, :: self.stride, :]
        pos_embed_negative = pos_embed_negative[:, :: self.stride, :]

        pos_embed = torch.cat([pos_embed_positive, pos_embed_negative[:, 1:, :]], dim=1)

        return self.out_pos(pos_embed)
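

# Minimal usage sketch for Conv1d (illustrative, not part of the original
# module; batch size, lengths, and dimensions are assumptions):
#
#   >>> conv = Conv1d(80, 256, kernel_size=3, stride=2)
#   >>> sequence = torch.randn(2, 100, 80)               # (B, T, D_in)
#   >>> mask = torch.ones(2, 1, 100, dtype=torch.bool)   # (B, 1, T)
#   >>> out, out_mask = conv(sequence, mask)
#   >>> out.shape, out_mask.shape
#   (torch.Size([2, 49, 256]), torch.Size([2, 1, 49]))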


class CausalConv1d(torch.nn.Module):
    """1D causal convolution module for custom decoder.

    Args:
        idim: Input dimension.
        odim: Output dimension.
        kernel_size: Size of the convolving kernel.
        stride: Stride of the convolution.
        dilation: Spacing between the kernel points.
        groups: Number of blocked connections from input channels to output channels.
        bias: Whether to add a learnable bias to the output.
        batch_norm: Whether to apply batch normalization.
        relu: Whether to pass final output through ReLU activation.
        dropout_rate: Dropout rate.

    """

    def __init__(
        self,
        idim: int,
        odim: int,
        kernel_size: int,
        stride: int = 1,
        dilation: int = 1,
        groups: int = 1,
        bias: bool = True,
        batch_norm: bool = False,
        relu: bool = True,
        dropout_rate: float = 0.0,
    ):
        """Construct a CausalConv1d object."""
        super().__init__()

        self.padding = (kernel_size - 1) * dilation

        self.causal_conv1d = torch.nn.Conv1d(
            idim,
            odim,
            kernel_size=kernel_size,
            stride=stride,
            padding=self.padding,
            dilation=dilation,
            groups=groups,
            bias=bias,
        )

        self.dropout = torch.nn.Dropout(p=dropout_rate)

        if batch_norm:
            self.bn = torch.nn.BatchNorm1d(odim)

        if relu:
            self.relu_func = torch.nn.ReLU()

        self.batch_norm = batch_norm
        self.relu = relu
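
    # Causality trick: torch.nn.Conv1d pads (kernel_size - 1) * dilation frames
    # on *both* sides, and forward() drops the trailing self.padding output
    # frames. This is equivalent to left-padding only, so the output at label
    # step u never depends on inputs after step u.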

    def forward(
        self,
        sequence: torch.Tensor,
        mask: torch.Tensor,
        cache: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Forward CausalConv1d for custom decoder.

        Args:
            sequence: CausalConv1d input sequences. (B, U, D_in)
            mask: Mask of CausalConv1d input sequences. (B, 1, U)
            cache: Not used by this module.

        Returns:
            sequence: CausalConv1d output sequences. (B, sub(U), D_out)
            mask: Mask of CausalConv1d output sequences. (B, 1, sub(U))

        """
        sequence = sequence.transpose(1, 2)
        sequence = self.causal_conv1d(sequence)

        if self.padding != 0:
            sequence = sequence[:, :, : -self.padding]

        if self.batch_norm:
            sequence = self.bn(sequence)

        sequence = self.dropout(sequence)

        if self.relu:
            sequence = self.relu_func(sequence)

        sequence = sequence.transpose(1, 2)

        return sequence, mask
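

if __name__ == "__main__":
    # Minimal sketch exercising CausalConv1d (illustrative; all shapes and
    # hyperparameters here are assumptions, not ESPnet defaults).
    causal_conv = CausalConv1d(256, 256, kernel_size=3)

    dec_seq = torch.randn(2, 25, 256)  # (B, U, D_in)
    dec_mask = torch.ones(2, 1, 25, dtype=torch.bool)  # (B, 1, U)

    out, out_mask = causal_conv(dec_seq, dec_mask)

    # Trailing padding is trimmed, so the label-length dimension is preserved.
    assert out.shape == (2, 25, 256)
    assert out_mask.shape == (2, 1, 25)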