Source code for espnet2.asr.encoder.linear_encoder

# Copyright 2023 Jiatong Shi
#  Apache 2.0  (http://www.apache.org/licenses/LICENSE-2.0)

"""Linear encoder definition."""

from typing import Optional, Tuple

import torch
from typeguard import typechecked

from espnet2.asr.encoder.abs_encoder import AbsEncoder
from espnet.nets.pytorch_backend.nets_utils import make_pad_mask
from espnet.nets.pytorch_backend.transformer.layer_norm import LayerNorm
from espnet.nets.pytorch_backend.transformer.subsampling import (
    Conv2dSubsampling,
    Conv2dSubsampling2,
    Conv2dSubsampling6,
    Conv2dSubsampling8,
    TooShortUttError,
    check_short_utt,
)


class LinearEncoder(AbsEncoder):
    """Linear encoder module.

    The encoder consists of a single input layer (linear projection,
    convolutional subsampling, or embedding lookup) optionally followed by a
    layer normalization. No further encoder blocks are applied.

    Args:
        input_size: input dim (vocabulary size when ``input_layer="embed"``)
        output_size: output dim of the input layer
        dropout_rate: dropout rate passed to the input layer
        input_layer: input layer type; one of ``"linear"``, ``"conv2d"``,
            ``"conv2d2"``, ``"conv2d6"``, ``"conv2d8"``, ``"embed"`` or
            ``None``
        normalize_before: whether to apply layer_norm to the output of the
            input layer
        padding_idx: padding_idx for input_layer=embed

    Raises:
        ValueError: if ``input_layer`` is not one of the supported values.
    """

    @typechecked
    def __init__(
        self,
        input_size: int,
        output_size: int = 256,
        dropout_rate: float = 0.1,
        input_layer: Optional[str] = "conv2d",
        normalize_before: bool = True,
        padding_idx: int = -1,
    ):
        super().__init__()
        self._output_size = output_size

        if input_layer == "linear":
            self.embed = torch.nn.Sequential(
                torch.nn.Linear(input_size, output_size),
                torch.nn.LayerNorm(output_size),
                torch.nn.Dropout(dropout_rate),
                torch.nn.ReLU(),
            )
        elif input_layer == "conv2d":
            self.embed = Conv2dSubsampling(input_size, output_size, dropout_rate)
        elif input_layer == "conv2d2":
            self.embed = Conv2dSubsampling2(input_size, output_size, dropout_rate)
        elif input_layer == "conv2d6":
            self.embed = Conv2dSubsampling6(input_size, output_size, dropout_rate)
        elif input_layer == "conv2d8":
            self.embed = Conv2dSubsampling8(input_size, output_size, dropout_rate)
        elif input_layer == "embed":
            # Assign the module itself, not a tuple wrapping it: a tuple is
            # neither callable nor registered as a submodule, so the embedding
            # weights would be missing from parameters()/state_dict() and
            # forward() would fail when calling self.embed.
            self.embed = torch.nn.Embedding(
                input_size, output_size, padding_idx=padding_idx
            )
        elif input_layer is None:
            if input_size == output_size:
                # Dimensions already match; pass the input through untouched.
                self.embed = None
            else:
                self.embed = torch.nn.Linear(input_size, output_size)
        else:
            raise ValueError("unknown input_layer: " + input_layer)

        self.normalize_before = normalize_before
        if self.normalize_before:
            # Only created when requested; forward() guards on the same flag.
            self.after_norm = LayerNorm(output_size)

    def output_size(self) -> int:
        """Return the output dimension given at construction."""
        return self._output_size

    def forward(
        self,
        xs_pad: torch.Tensor,
        ilens: torch.Tensor,
        prev_states: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
        """Apply the input layer and the optional layer normalization.

        Args:
            xs_pad: input tensor (B, L, D)
            ilens: input length (B)
            prev_states: Not to be used now.

        Returns:
            A tuple ``(xs_pad, olens, None)``: the output of the input layer
            (normalized when ``normalize_before`` is set), the output lengths
            obtained by summing the (possibly subsampled) non-padding mask,
            and ``None`` in place of states.

        Raises:
            TooShortUttError: if a Conv2dSubsampling* input layer is used and
                ``check_short_utt`` reports the input as too short for it.
        """
        # True at valid (non-padded) positions, with a singleton middle axis
        # as expected by the subsampling layers.
        masks = (~make_pad_mask(ilens)[:, None, :]).to(xs_pad.device)

        if self.embed is not None:
            if isinstance(
                self.embed,
                (
                    Conv2dSubsampling,
                    Conv2dSubsampling2,
                    Conv2dSubsampling6,
                    Conv2dSubsampling8,
                ),
            ):
                short_status, limit_size = check_short_utt(
                    self.embed, xs_pad.size(1)
                )
                if short_status:
                    raise TooShortUttError(
                        f"has {xs_pad.size(1)} frames and is too short for subsampling "
                        + f"(it needs more than {limit_size} frames), return empty results",
                        xs_pad.size(1),
                        limit_size,
                    )
                # Subsampling layers also return the mask for the new length.
                xs_pad, masks = self.embed(xs_pad, masks)
            else:
                xs_pad = self.embed(xs_pad)

        if self.normalize_before:
            xs_pad = self.after_norm(xs_pad)

        olens = masks.squeeze(1).sum(1)
        return xs_pad, olens, None