Source code for espnet2.asr.encoder.transformer_encoder_multispkr

# Copyright 2019 Shigeki Karita
#  Apache 2.0  (http://www.apache.org/licenses/LICENSE-2.0)

"""Encoder definition."""
from typing import Optional, Tuple

import torch
from typeguard import typechecked

from espnet2.asr.encoder.abs_encoder import AbsEncoder
from espnet.nets.pytorch_backend.nets_utils import make_pad_mask
from espnet.nets.pytorch_backend.transformer.attention import MultiHeadedAttention
from espnet.nets.pytorch_backend.transformer.embedding import PositionalEncoding
from espnet.nets.pytorch_backend.transformer.encoder_layer import EncoderLayer
from espnet.nets.pytorch_backend.transformer.layer_norm import LayerNorm
from espnet.nets.pytorch_backend.transformer.multi_layer_conv import (
    Conv1dLinear,
    MultiLayeredConv1d,
)
from espnet.nets.pytorch_backend.transformer.positionwise_feed_forward import (
    PositionwiseFeedForward,
)
from espnet.nets.pytorch_backend.transformer.repeat import repeat
from espnet.nets.pytorch_backend.transformer.subsampling import (
    Conv2dSubsampling,
    Conv2dSubsampling1,
    Conv2dSubsampling2,
    Conv2dSubsampling6,
    Conv2dSubsampling8,
    TooShortUttError,
    check_short_utt,
)


class TransformerEncoder(AbsEncoder):
    """Transformer encoder module with speaker-dependent branches.

    The input is embedded once, then fed to ``num_inf`` separate
    speaker-dependent encoder stacks (``encoders_sd``). The output of every
    branch is subsequently passed through one shared recognition encoder
    stack (``encoders``) and, if ``normalize_before`` is set, a shared final
    layer norm.

    Args:
        input_size: input dim
        output_size: dimension of attention
        attention_heads: the number of heads of multi head attention
        linear_units: the number of units of position-wise feed forward
        num_blocks: the number of recognition encoder blocks
        num_blocks_sd: the number of speaker dependent encoder blocks
        dropout_rate: dropout rate
        positional_dropout_rate: dropout rate after adding positional encoding
        attention_dropout_rate: dropout rate in attention
        input_layer: input layer type; one of "linear", "conv2d", "conv2d1",
            "conv2d2", "conv2d6", "conv2d8", "embed" or None
        pos_enc_class: PositionalEncoding or ScaledPositionalEncoding
        normalize_before: whether to use layer_norm before the first block
        concat_after: whether to concat attention layer's input and output
            if True, additional linear will be applied.
            i.e. x -> x + linear(concat(x, att(x)))
            if False, no additional linear will be applied.
            i.e. x -> x + att(x)
        positionwise_layer_type: "linear", "conv1d" or "conv1d-linear"
        positionwise_conv_kernel_size: kernel size of positionwise conv1d layer
        padding_idx: padding_idx for input_layer=embed
        num_inf: number of inference outputs, i.e. the number of
            speaker-dependent encoder branches that are built and evaluated

    Raises:
        ValueError: if ``input_layer`` is not one of the supported values.
        NotImplementedError: if ``positionwise_layer_type`` is not supported.
    """

    # input_layer names that select a Conv2d subsampling front-end. Unlike the
    # other front-ends, these modules are called as embed(xs, masks) and
    # return an (xs, masks) pair, so forward() must treat them specially.
    _CONV2D_SUBSAMPLING = {
        "conv2d": Conv2dSubsampling,
        "conv2d1": Conv2dSubsampling1,
        "conv2d2": Conv2dSubsampling2,
        "conv2d6": Conv2dSubsampling6,
        "conv2d8": Conv2dSubsampling8,
    }

    @typechecked
    def __init__(
        self,
        input_size: int,
        output_size: int = 256,
        attention_heads: int = 4,
        linear_units: int = 2048,
        num_blocks: int = 6,
        num_blocks_sd: int = 6,
        dropout_rate: float = 0.1,
        positional_dropout_rate: float = 0.1,
        attention_dropout_rate: float = 0.0,
        input_layer: Optional[str] = "conv2d",
        pos_enc_class=PositionalEncoding,
        normalize_before: bool = True,
        concat_after: bool = False,
        positionwise_layer_type: str = "linear",
        positionwise_conv_kernel_size: int = 1,
        padding_idx: int = -1,
        num_inf: int = 1,
    ):
        super().__init__()
        self._output_size = output_size

        # --- input embedding / subsampling front-end ---
        if input_layer == "linear":
            self.embed = torch.nn.Sequential(
                torch.nn.Linear(input_size, output_size),
                torch.nn.LayerNorm(output_size),
                torch.nn.Dropout(dropout_rate),
                torch.nn.ReLU(),
                pos_enc_class(output_size, positional_dropout_rate),
            )
        elif input_layer in self._CONV2D_SUBSAMPLING:
            subsampling_class = self._CONV2D_SUBSAMPLING[input_layer]
            self.embed = subsampling_class(input_size, output_size, dropout_rate)
        elif input_layer == "embed":
            self.embed = torch.nn.Sequential(
                torch.nn.Embedding(input_size, output_size, padding_idx=padding_idx),
                pos_enc_class(output_size, positional_dropout_rate),
            )
        elif input_layer is None:
            self.embed = torch.nn.Sequential(
                pos_enc_class(output_size, positional_dropout_rate)
            )
        else:
            raise ValueError("unknown input_layer: " + input_layer)

        self.normalize_before = normalize_before

        # --- position-wise feed-forward flavour shared by all blocks ---
        if positionwise_layer_type == "linear":
            positionwise_layer = PositionwiseFeedForward
            positionwise_layer_args = (
                output_size,
                linear_units,
                dropout_rate,
            )
        elif positionwise_layer_type == "conv1d":
            positionwise_layer = MultiLayeredConv1d
            positionwise_layer_args = (
                output_size,
                linear_units,
                positionwise_conv_kernel_size,
                dropout_rate,
            )
        elif positionwise_layer_type == "conv1d-linear":
            positionwise_layer = Conv1dLinear
            positionwise_layer_args = (
                output_size,
                linear_units,
                positionwise_conv_kernel_size,
                dropout_rate,
            )
        else:
            raise NotImplementedError("Support only linear or conv1d.")

        def build_encoder_layer(lnum):
            # Each call builds fresh sub-modules, so no parameters are shared
            # between layers or between branches.
            return EncoderLayer(
                output_size,
                MultiHeadedAttention(
                    attention_heads, output_size, attention_dropout_rate
                ),
                positionwise_layer(*positionwise_layer_args),
                dropout_rate,
                normalize_before,
                concat_after,
            )

        # Shared recognition encoder, applied to the output of every branch.
        self.encoders = repeat(num_blocks, build_encoder_layer)
        if self.normalize_before:
            self.after_norm = LayerNorm(output_size)

        # One speaker-dependent encoder stack per inference output.
        self.num_inf = num_inf
        self.encoders_sd = torch.nn.ModuleList(
            [repeat(num_blocks_sd, build_encoder_layer) for _ in range(num_inf)]
        )

    def output_size(self) -> int:
        """Return the encoder output dimension given at construction."""
        return self._output_size

    def forward(
        self,
        xs_pad: torch.Tensor,
        ilens: torch.Tensor,
        prev_states: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
        """Encode the input once per speaker-dependent branch.

        Args:
            xs_pad: input tensor (B, L, D)
            ilens: input length (B)
            prev_states: Not to be used now.

        Returns:
            A 3-tuple ``(xs, olens, None)``:

            * ``xs``: the ``num_inf`` branch outputs stacked along dim 1
              (presumably (B, num_inf, T, output_size) -- confirm against
              EncoderLayer's output shape).
            * ``olens``: per-branch output lengths, obtained by summing each
              branch's mask, stacked along dim 1.
            * ``None``: placeholder for states, which are not produced here.

        Raises:
            TooShortUttError: if a Conv2d subsampling front-end is used and
                ``check_short_utt`` reports the input as too short for it.
        """
        masks = (~make_pad_mask(ilens)[:, None, :]).to(xs_pad.device)

        if isinstance(self.embed, tuple(self._CONV2D_SUBSAMPLING.values())):
            short_status, limit_size = check_short_utt(self.embed, xs_pad.size(1))
            if short_status:
                raise TooShortUttError(
                    f"has {xs_pad.size(1)} frames and is too short for subsampling "
                    + f"(it needs more than {limit_size} frames), return empty results",
                    xs_pad.size(1),
                    limit_size,
                )
            # Subsampling front-ends also return the updated mask.
            xs_pad, masks = self.embed(xs_pad, masks)
        else:
            xs_pad = self.embed(xs_pad)

        xs_sd = []
        olens = []
        for ns in range(self.num_inf):
            # Speaker-dependent stack first, then the shared recognition stack.
            xs, xs_masks = self.encoders_sd[ns](xs_pad, masks)
            xs, xs_masks = self.encoders(xs, xs_masks)  # Enc_rec
            if self.normalize_before:
                xs = self.after_norm(xs)
            xs_sd.append(xs)
            olens.append(xs_masks.squeeze(1).sum(1))

        return torch.stack(xs_sd, dim=1), torch.stack(olens, dim=1), None