Source code for espnet2.asr.decoder.transformer_decoder

# Copyright 2019 Shigeki Karita
#  Apache 2.0  (http://www.apache.org/licenses/LICENSE-2.0)

"""Decoder definition."""
from typing import Any, List, Sequence, Tuple

import torch
from typeguard import typechecked

from espnet2.asr.decoder.abs_decoder import AbsDecoder
from espnet.nets.pytorch_backend.nets_utils import make_pad_mask
from espnet.nets.pytorch_backend.transformer.attention import MultiHeadedAttention
from espnet.nets.pytorch_backend.transformer.decoder_layer import DecoderLayer
from espnet.nets.pytorch_backend.transformer.dynamic_conv import DynamicConvolution
from espnet.nets.pytorch_backend.transformer.dynamic_conv2d import DynamicConvolution2D
from espnet.nets.pytorch_backend.transformer.embedding import PositionalEncoding
from espnet.nets.pytorch_backend.transformer.layer_norm import LayerNorm
from espnet.nets.pytorch_backend.transformer.lightconv import LightweightConvolution
from espnet.nets.pytorch_backend.transformer.lightconv2d import LightweightConvolution2D
from espnet.nets.pytorch_backend.transformer.mask import subsequent_mask
from espnet.nets.pytorch_backend.transformer.positionwise_feed_forward import (
    PositionwiseFeedForward,
)
from espnet.nets.pytorch_backend.transformer.repeat import repeat
from espnet.nets.scorer_interface import BatchScorerInterface


class BaseTransformerDecoder(AbsDecoder, BatchScorerInterface):
    """Base class of Transformer decoder module.

    Args:
        vocab_size: output dim
        encoder_output_size: dimension of attention
        attention_heads: the number of heads of multi head attention
        linear_units: the number of units of position-wise feed forward
        num_blocks: the number of decoder blocks
        dropout_rate: dropout rate
        self_attention_dropout_rate: dropout rate for attention
        input_layer: input layer type
        use_output_layer: whether to use output layer
        pos_enc_class: PositionalEncoding or ScaledPositionalEncoding
        normalize_before: whether to use layer_norm before the first block
        concat_after: whether to concat attention layer's input and output
            if True, additional linear will be applied.
            i.e. x -> x + linear(concat(x, att(x)))
            if False, no additional linear will be applied.
            i.e. x -> x + att(x)
    """

    @typechecked
    def __init__(
        self,
        vocab_size: int,
        encoder_output_size: int,
        dropout_rate: float = 0.1,
        positional_dropout_rate: float = 0.1,
        input_layer: str = "embed",
        use_output_layer: bool = True,
        pos_enc_class=PositionalEncoding,
        normalize_before: bool = True,
    ):
        super().__init__()
        attention_dim = encoder_output_size

        if input_layer == "embed":
            self.embed = torch.nn.Sequential(
                torch.nn.Embedding(vocab_size, attention_dim),
                pos_enc_class(attention_dim, positional_dropout_rate),
            )
        elif input_layer == "linear":
            self.embed = torch.nn.Sequential(
                torch.nn.Linear(vocab_size, attention_dim),
                torch.nn.LayerNorm(attention_dim),
                torch.nn.Dropout(dropout_rate),
                torch.nn.ReLU(),
                pos_enc_class(attention_dim, positional_dropout_rate),
            )
        else:
            raise ValueError(f"only 'embed' or 'linear' is supported: {input_layer}")

        self.normalize_before = normalize_before
        if self.normalize_before:
            self.after_norm = LayerNorm(attention_dim)
        if use_output_layer:
            self.output_layer = torch.nn.Linear(attention_dim, vocab_size)
        else:
            self.output_layer = None

        self._output_size_bf_softmax = attention_dim
        # Must be set by the inheriting class
        self.decoders = None
    def forward(
        self,
        hs_pad: torch.Tensor,
        hlens: torch.Tensor,
        ys_in_pad: torch.Tensor,
        ys_in_lens: torch.Tensor,
        return_hs: bool = False,
        return_all_hs: bool = False,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Forward decoder.

        Args:
            hs_pad: encoded memory, float32 (batch, maxlen_in, feat)
            hlens: (batch)
            ys_in_pad: input token ids, int64 (batch, maxlen_out)
                if input_layer == "embed"
                input tensor (batch, maxlen_out, #mels) in the other cases
            ys_in_lens: (batch)
            return_hs: (bool) whether to return the last hidden output
                before the output layer
            return_all_hs: (bool) whether to return all the hidden intermediates
        Returns:
            (tuple): tuple containing:

            x: decoded token score before softmax (batch, maxlen_out, token)
                if use_output_layer is True,
            olens: (batch, )
        """
        tgt = ys_in_pad
        # tgt_mask: (B, 1, L)
        tgt_mask = (~make_pad_mask(ys_in_lens)[:, None, :]).to(tgt.device)
        # m: (1, L, L)
        m = subsequent_mask(tgt_mask.size(-1), device=tgt_mask.device).unsqueeze(0)
        # tgt_mask: (B, L, L)
        tgt_mask = tgt_mask & m

        memory = hs_pad
        memory_mask = (~make_pad_mask(hlens, maxlen=memory.size(1)))[:, None, :].to(
            memory.device
        )
        # Padding for Longformer
        if memory_mask.shape[-1] != memory.shape[1]:
            padlen = memory.shape[1] - memory_mask.shape[-1]
            memory_mask = torch.nn.functional.pad(
                memory_mask, (0, padlen), "constant", False
            )

        x = self.embed(tgt)
        intermediate_outs = []
        for layer_idx, decoder_layer in enumerate(self.decoders):
            x, tgt_mask, memory, memory_mask = decoder_layer(
                x, tgt_mask, memory, memory_mask
            )
            if return_all_hs:
                intermediate_outs.append(x)
        if self.normalize_before:
            x = self.after_norm(x)
        if return_hs:
            hidden = x
        if self.output_layer is not None:
            x = self.output_layer(x)

        olens = tgt_mask.sum(1)
        if return_hs:
            return (x, hidden), olens
        elif return_all_hs:
            return (x, intermediate_outs), olens
        return x, olens
    def forward_one_step(
        self,
        tgt: torch.Tensor,
        tgt_mask: torch.Tensor,
        memory: torch.Tensor,
        memory_mask: torch.Tensor = None,
        *,
        cache: List[torch.Tensor] = None,
        return_hs: bool = False,
    ) -> Tuple[torch.Tensor, List[torch.Tensor]]:
        """Forward one step.

        Args:
            tgt: input token ids, int64 (batch, maxlen_out)
            tgt_mask: input token mask, (batch, maxlen_out)
                dtype=torch.uint8 in PyTorch 1.2-
                dtype=torch.bool in PyTorch 1.2+ (including 1.2)
            memory: encoded memory, float32 (batch, maxlen_in, feat)
            memory_mask: encoded memory mask (batch, 1, maxlen_in)
            cache: cached output list of (batch, max_time_out-1, size)
            return_hs: whether to return the decoder hidden state corresponding
                to ys, used for searchable hidden intermediates
        Returns:
            y, cache: NN output value and cache per `self.decoders`.
                `y.shape` is (batch, maxlen_out, token)
        """
        x = self.embed(tgt)
        if cache is None:
            cache = [None] * len(self.decoders)
        new_cache = []
        for c, decoder in zip(cache, self.decoders):
            x, tgt_mask, memory, memory_mask = decoder(
                x, tgt_mask, memory, memory_mask, cache=c
            )
            new_cache.append(x)

        if self.normalize_before:
            y = self.after_norm(x[:, -1])
        else:
            y = x[:, -1]
        if return_hs:
            hidden = y
        if self.output_layer is not None:
            y = torch.log_softmax(self.output_layer(y), dim=-1)

        if return_hs:
            return (y, hidden), new_cache
        return y, new_cache
    def score(self, ys, state, x, return_hs=False):
        """Score."""
        ys_mask = subsequent_mask(len(ys), device=x.device).unsqueeze(0)
        if return_hs:
            (logp, hs), state = self.forward_one_step(
                ys.unsqueeze(0),
                ys_mask,
                x.unsqueeze(0),
                cache=state,
                return_hs=return_hs,
            )
            return logp.squeeze(0), hs, state
        else:
            logp, state = self.forward_one_step(
                ys.unsqueeze(0),
                ys_mask,
                x.unsqueeze(0),
                cache=state,
                return_hs=return_hs,
            )
        return logp.squeeze(0), state
    def batch_score(
        self,
        ys: torch.Tensor,
        states: List[Any],
        xs: torch.Tensor,
        return_hs: bool = False,
    ) -> Tuple[torch.Tensor, List[Any]]:
        """Score new token batch.

        Args:
            ys (torch.Tensor): torch.int64 prefix tokens (n_batch, ylen).
            states (List[Any]): Scorer states for prefix tokens.
            xs (torch.Tensor):
                The encoder feature that generates ys (n_batch, xlen, n_feat).

        Returns:
            tuple[torch.Tensor, List[Any]]: Tuple of
                batchified scores for the next token with shape of `(n_batch, n_vocab)`
                and the next state list for ys.
        """
        # merge states
        n_batch = len(ys)
        n_layers = len(self.decoders)
        if states[0] is None:
            batch_state = None
        else:
            # transpose state of [batch, layer] into [layer, batch]
            batch_state = [
                torch.stack([states[b][i] for b in range(n_batch)])
                for i in range(n_layers)
            ]

        # batch decoding
        ys_mask = subsequent_mask(ys.size(-1), device=xs.device).unsqueeze(0)
        if return_hs:
            (logp, hs), states = self.forward_one_step(
                ys, ys_mask, xs, cache=batch_state, return_hs=return_hs
            )
        else:
            logp, states = self.forward_one_step(
                ys, ys_mask, xs, cache=batch_state, return_hs=return_hs
            )

        # transpose state of [layer, batch] into [batch, layer]
        state_list = [[states[i][b] for i in range(n_layers)] for b in range(n_batch)]
        if return_hs:
            return (logp, hs), state_list
        return logp, state_list
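
# NOTE (illustrative sketch, not part of the original module): `score` and
# `batch_score` implement `BatchScorerInterface`, so a concrete subclass such as
# `TransformerDecoder` below can be used as a scorer during beam search.
# A minimal example of stepping the decoder one token at a time with the
# incremental cache might look like this; all shapes and ids are assumptions:
#
#     decoder = TransformerDecoder(vocab_size=5000, encoder_output_size=256)
#     enc_out = torch.randn(50, 256)   # encoder feature for one utterance (xlen, n_feat)
#     ys = torch.tensor([1])           # running hypothesis, e.g. a hypothetical <sos> id
#     state = None                     # per-layer cache, returned by `score`
#     logp, state = decoder.score(ys, state, enc_out)
#     next_token = int(logp.argmax())  # greedy pick over the (n_vocab,) log-probs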
class TransformerDecoder(BaseTransformerDecoder):
    """Transformer decoder with multi-headed self-attention and source attention."""

    @typechecked
    def __init__(
        self,
        vocab_size: int,
        encoder_output_size: int,
        attention_heads: int = 4,
        linear_units: int = 2048,
        num_blocks: int = 6,
        dropout_rate: float = 0.1,
        positional_dropout_rate: float = 0.1,
        self_attention_dropout_rate: float = 0.0,
        src_attention_dropout_rate: float = 0.0,
        input_layer: str = "embed",
        use_output_layer: bool = True,
        pos_enc_class=PositionalEncoding,
        normalize_before: bool = True,
        concat_after: bool = False,
        layer_drop_rate: float = 0.0,
    ):
        super().__init__(
            vocab_size=vocab_size,
            encoder_output_size=encoder_output_size,
            dropout_rate=dropout_rate,
            positional_dropout_rate=positional_dropout_rate,
            input_layer=input_layer,
            use_output_layer=use_output_layer,
            pos_enc_class=pos_enc_class,
            normalize_before=normalize_before,
        )

        attention_dim = encoder_output_size
        self.decoders = repeat(
            num_blocks,
            lambda lnum: DecoderLayer(
                attention_dim,
                MultiHeadedAttention(
                    attention_heads, attention_dim, self_attention_dropout_rate
                ),
                MultiHeadedAttention(
                    attention_heads, attention_dim, src_attention_dropout_rate
                ),
                PositionwiseFeedForward(attention_dim, linear_units, dropout_rate),
                dropout_rate,
                normalize_before,
                concat_after,
            ),
            layer_drop_rate,
        )
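
# Usage sketch (illustrative, not part of the original module): training-time
# forward pass of `TransformerDecoder`.  All sizes and tensors below are made up
# for the example only.
#
#     decoder = TransformerDecoder(vocab_size=5000, encoder_output_size=256)
#     hs_pad = torch.randn(2, 50, 256)             # encoder output (batch, maxlen_in, feat)
#     hlens = torch.tensor([50, 45])               # valid encoder lengths
#     ys_in_pad = torch.randint(0, 5000, (2, 10))  # target ids with <sos> prepended
#     ys_in_lens = torch.tensor([10, 8])           # valid target lengths
#     logits, olens = decoder(hs_pad, hlens, ys_in_pad, ys_in_lens)
#     # logits: (2, 10, 5000) pre-softmax token scores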
class LightweightConvolutionTransformerDecoder(BaseTransformerDecoder):
    """Transformer decoder using lightweight convolution in place of self-attention."""

    @typechecked
    def __init__(
        self,
        vocab_size: int,
        encoder_output_size: int,
        attention_heads: int = 4,
        linear_units: int = 2048,
        num_blocks: int = 6,
        dropout_rate: float = 0.1,
        positional_dropout_rate: float = 0.1,
        self_attention_dropout_rate: float = 0.0,
        src_attention_dropout_rate: float = 0.0,
        input_layer: str = "embed",
        use_output_layer: bool = True,
        pos_enc_class=PositionalEncoding,
        normalize_before: bool = True,
        concat_after: bool = False,
        conv_wshare: int = 4,
        conv_kernel_length: Sequence[int] = (11, 11, 11, 11, 11, 11),
        conv_usebias: bool = False,
    ):
        if len(conv_kernel_length) != num_blocks:
            raise ValueError(
                "conv_kernel_length must have equal number of values to num_blocks: "
                f"{len(conv_kernel_length)} != {num_blocks}"
            )
        super().__init__(
            vocab_size=vocab_size,
            encoder_output_size=encoder_output_size,
            dropout_rate=dropout_rate,
            positional_dropout_rate=positional_dropout_rate,
            input_layer=input_layer,
            use_output_layer=use_output_layer,
            pos_enc_class=pos_enc_class,
            normalize_before=normalize_before,
        )

        attention_dim = encoder_output_size
        self.decoders = repeat(
            num_blocks,
            lambda lnum: DecoderLayer(
                attention_dim,
                LightweightConvolution(
                    wshare=conv_wshare,
                    n_feat=attention_dim,
                    dropout_rate=self_attention_dropout_rate,
                    kernel_size=conv_kernel_length[lnum],
                    use_kernel_mask=True,
                    use_bias=conv_usebias,
                ),
                MultiHeadedAttention(
                    attention_heads, attention_dim, src_attention_dropout_rate
                ),
                PositionwiseFeedForward(attention_dim, linear_units, dropout_rate),
                dropout_rate,
                normalize_before,
                concat_after,
            ),
        )
class LightweightConvolution2DTransformerDecoder(BaseTransformerDecoder):
    """Transformer decoder using 2D lightweight convolution in place of self-attention."""

    @typechecked
    def __init__(
        self,
        vocab_size: int,
        encoder_output_size: int,
        attention_heads: int = 4,
        linear_units: int = 2048,
        num_blocks: int = 6,
        dropout_rate: float = 0.1,
        positional_dropout_rate: float = 0.1,
        self_attention_dropout_rate: float = 0.0,
        src_attention_dropout_rate: float = 0.0,
        input_layer: str = "embed",
        use_output_layer: bool = True,
        pos_enc_class=PositionalEncoding,
        normalize_before: bool = True,
        concat_after: bool = False,
        conv_wshare: int = 4,
        conv_kernel_length: Sequence[int] = (11, 11, 11, 11, 11, 11),
        conv_usebias: bool = False,
    ):
        if len(conv_kernel_length) != num_blocks:
            raise ValueError(
                "conv_kernel_length must have equal number of values to num_blocks: "
                f"{len(conv_kernel_length)} != {num_blocks}"
            )
        super().__init__(
            vocab_size=vocab_size,
            encoder_output_size=encoder_output_size,
            dropout_rate=dropout_rate,
            positional_dropout_rate=positional_dropout_rate,
            input_layer=input_layer,
            use_output_layer=use_output_layer,
            pos_enc_class=pos_enc_class,
            normalize_before=normalize_before,
        )

        attention_dim = encoder_output_size
        self.decoders = repeat(
            num_blocks,
            lambda lnum: DecoderLayer(
                attention_dim,
                LightweightConvolution2D(
                    wshare=conv_wshare,
                    n_feat=attention_dim,
                    dropout_rate=self_attention_dropout_rate,
                    kernel_size=conv_kernel_length[lnum],
                    use_kernel_mask=True,
                    use_bias=conv_usebias,
                ),
                MultiHeadedAttention(
                    attention_heads, attention_dim, src_attention_dropout_rate
                ),
                PositionwiseFeedForward(attention_dim, linear_units, dropout_rate),
                dropout_rate,
                normalize_before,
                concat_after,
            ),
        )
class DynamicConvolutionTransformerDecoder(BaseTransformerDecoder):
    """Transformer decoder using dynamic convolution in place of self-attention."""

    @typechecked
    def __init__(
        self,
        vocab_size: int,
        encoder_output_size: int,
        attention_heads: int = 4,
        linear_units: int = 2048,
        num_blocks: int = 6,
        dropout_rate: float = 0.1,
        positional_dropout_rate: float = 0.1,
        self_attention_dropout_rate: float = 0.0,
        src_attention_dropout_rate: float = 0.0,
        input_layer: str = "embed",
        use_output_layer: bool = True,
        pos_enc_class=PositionalEncoding,
        normalize_before: bool = True,
        concat_after: bool = False,
        conv_wshare: int = 4,
        conv_kernel_length: Sequence[int] = (11, 11, 11, 11, 11, 11),
        conv_usebias: bool = False,
    ):
        if len(conv_kernel_length) != num_blocks:
            raise ValueError(
                "conv_kernel_length must have equal number of values to num_blocks: "
                f"{len(conv_kernel_length)} != {num_blocks}"
            )
        super().__init__(
            vocab_size=vocab_size,
            encoder_output_size=encoder_output_size,
            dropout_rate=dropout_rate,
            positional_dropout_rate=positional_dropout_rate,
            input_layer=input_layer,
            use_output_layer=use_output_layer,
            pos_enc_class=pos_enc_class,
            normalize_before=normalize_before,
        )

        attention_dim = encoder_output_size
        self.decoders = repeat(
            num_blocks,
            lambda lnum: DecoderLayer(
                attention_dim,
                DynamicConvolution(
                    wshare=conv_wshare,
                    n_feat=attention_dim,
                    dropout_rate=self_attention_dropout_rate,
                    kernel_size=conv_kernel_length[lnum],
                    use_kernel_mask=True,
                    use_bias=conv_usebias,
                ),
                MultiHeadedAttention(
                    attention_heads, attention_dim, src_attention_dropout_rate
                ),
                PositionwiseFeedForward(attention_dim, linear_units, dropout_rate),
                dropout_rate,
                normalize_before,
                concat_after,
            ),
        )
class DynamicConvolution2DTransformerDecoder(BaseTransformerDecoder):
    """Transformer decoder using 2D dynamic convolution in place of self-attention."""

    @typechecked
    def __init__(
        self,
        vocab_size: int,
        encoder_output_size: int,
        attention_heads: int = 4,
        linear_units: int = 2048,
        num_blocks: int = 6,
        dropout_rate: float = 0.1,
        positional_dropout_rate: float = 0.1,
        self_attention_dropout_rate: float = 0.0,
        src_attention_dropout_rate: float = 0.0,
        input_layer: str = "embed",
        use_output_layer: bool = True,
        pos_enc_class=PositionalEncoding,
        normalize_before: bool = True,
        concat_after: bool = False,
        conv_wshare: int = 4,
        conv_kernel_length: Sequence[int] = (11, 11, 11, 11, 11, 11),
        conv_usebias: bool = False,
    ):
        if len(conv_kernel_length) != num_blocks:
            raise ValueError(
                "conv_kernel_length must have equal number of values to num_blocks: "
                f"{len(conv_kernel_length)} != {num_blocks}"
            )
        super().__init__(
            vocab_size=vocab_size,
            encoder_output_size=encoder_output_size,
            dropout_rate=dropout_rate,
            positional_dropout_rate=positional_dropout_rate,
            input_layer=input_layer,
            use_output_layer=use_output_layer,
            pos_enc_class=pos_enc_class,
            normalize_before=normalize_before,
        )

        attention_dim = encoder_output_size
        self.decoders = repeat(
            num_blocks,
            lambda lnum: DecoderLayer(
                attention_dim,
                DynamicConvolution2D(
                    wshare=conv_wshare,
                    n_feat=attention_dim,
                    dropout_rate=self_attention_dropout_rate,
                    kernel_size=conv_kernel_length[lnum],
                    use_kernel_mask=True,
                    use_bias=conv_usebias,
                ),
                MultiHeadedAttention(
                    attention_heads, attention_dim, src_attention_dropout_rate
                ),
                PositionwiseFeedForward(attention_dim, linear_units, dropout_rate),
                dropout_rate,
                normalize_before,
                concat_after,
            ),
        )
class TransformerMDDecoder(BaseTransformerDecoder):
    """Transformer decoder with an optional additional source attention over speech."""

    @typechecked
    def __init__(
        self,
        vocab_size: int,
        encoder_output_size: int,
        attention_heads: int = 4,
        linear_units: int = 2048,
        num_blocks: int = 6,
        dropout_rate: float = 0.1,
        positional_dropout_rate: float = 0.1,
        self_attention_dropout_rate: float = 0.0,
        src_attention_dropout_rate: float = 0.0,
        input_layer: str = "embed",
        use_output_layer: bool = True,
        pos_enc_class=PositionalEncoding,
        normalize_before: bool = True,
        concat_after: bool = False,
        use_speech_attn: bool = True,
    ):
        super().__init__(
            vocab_size=vocab_size,
            encoder_output_size=encoder_output_size,
            dropout_rate=dropout_rate,
            positional_dropout_rate=positional_dropout_rate,
            input_layer=input_layer,
            use_output_layer=use_output_layer,
            pos_enc_class=pos_enc_class,
            normalize_before=normalize_before,
        )

        attention_dim = encoder_output_size
        self.decoders = repeat(
            num_blocks,
            lambda lnum: DecoderLayer(
                attention_dim,
                MultiHeadedAttention(
                    attention_heads, attention_dim, self_attention_dropout_rate
                ),
                MultiHeadedAttention(
                    attention_heads, attention_dim, src_attention_dropout_rate
                ),
                PositionwiseFeedForward(attention_dim, linear_units, dropout_rate),
                dropout_rate,
                normalize_before,
                concat_after,
                (
                    MultiHeadedAttention(
                        attention_heads, attention_dim, src_attention_dropout_rate
                    )
                    if use_speech_attn
                    else None
                ),
            ),
        )

        self.use_speech_attn = use_speech_attn
    def forward(
        self,
        hs_pad: torch.Tensor,
        hlens: torch.Tensor,
        ys_in_pad: torch.Tensor,
        ys_in_lens: torch.Tensor,
        speech: torch.Tensor = None,
        speech_lens: torch.Tensor = None,
        return_hs: bool = False,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Forward decoder.

        Args:
            hs_pad: encoded memory, float32 (batch, maxlen_in, feat)
            hlens: (batch)
            ys_in_pad: input token ids, int64 (batch, maxlen_out)
                if input_layer == "embed"
                input tensor (batch, maxlen_out, #mels) in the other cases
            ys_in_lens: (batch)
            speech: encoded speech, float32 (batch, maxlen_in, feat)
            speech_lens: (batch)
            return_hs: whether to return the decoder hidden state corresponding
                to ys, used for searchable hidden intermediates
        Returns:
            (tuple): tuple containing:

            x: decoded token score before softmax (batch, maxlen_out, token)
                if use_output_layer is True,
            olens: (batch, )
        """
        tgt = ys_in_pad
        # tgt_mask: (B, 1, L)
        tgt_mask = (~make_pad_mask(ys_in_lens)[:, None, :]).to(tgt.device)
        # m: (1, L, L)
        m = subsequent_mask(tgt_mask.size(-1), device=tgt_mask.device).unsqueeze(0)
        # tgt_mask: (B, L, L)
        tgt_mask = tgt_mask & m

        memory = hs_pad
        memory_mask = (~make_pad_mask(hlens, maxlen=memory.size(1)))[:, None, :].to(
            memory.device
        )

        if speech is not None:
            speech_mask = (~make_pad_mask(speech_lens, maxlen=speech.size(1)))[
                :, None, :
            ].to(speech.device)
        else:
            speech_mask = None

        x = self.embed(tgt)
        if self.use_speech_attn:
            x, tgt_mask, memory, memory_mask, _, speech, speech_mask = self.decoders(
                x, tgt_mask, memory, memory_mask, None, speech, speech_mask
            )
        else:
            x, tgt_mask, memory, memory_mask = self.decoders(
                x, tgt_mask, memory, memory_mask
            )
        if self.normalize_before:
            x = self.after_norm(x)

        if return_hs:
            hs_asr = x

        if self.output_layer is not None:
            x = self.output_layer(x)

        olens = tgt_mask.sum(1)
        if return_hs:
            return x, olens, hs_asr
        return x, olens
    def forward_one_step(
        self,
        tgt: torch.Tensor,
        tgt_mask: torch.Tensor,
        memory: torch.Tensor,
        memory_mask: torch.Tensor = None,
        *,
        speech: torch.Tensor = None,
        speech_mask: torch.Tensor = None,
        cache: List[torch.Tensor] = None,
        return_hs: bool = False,
    ) -> Tuple[torch.Tensor, List[torch.Tensor]]:
        """Forward one step.

        Args:
            tgt: input token ids, int64 (batch, maxlen_out)
            tgt_mask: input token mask, (batch, maxlen_out)
                dtype=torch.uint8 in PyTorch 1.2-
                dtype=torch.bool in PyTorch 1.2+ (including 1.2)
            memory: encoded memory, float32 (batch, maxlen_in, feat)
            memory_mask: encoded memory mask (batch, 1, maxlen_in)
            speech: encoded speech, float32 (batch, maxlen_in, feat)
            speech_mask: encoded speech mask (batch, 1, maxlen_in)
            cache: cached output list of (batch, max_time_out-1, size)
            return_hs: whether to return the decoder hidden state corresponding
                to ys, used for searchable hidden intermediates
        Returns:
            y, cache: NN output value and cache per `self.decoders`.
                `y.shape` is (batch, maxlen_out, token)
        """
        x = self.embed(tgt)
        if cache is None:
            cache = [None] * len(self.decoders)
        new_cache = []
        for c, decoder in zip(cache, self.decoders):
            if self.use_speech_attn:
                x, tgt_mask, memory, memory_mask, _, speech, speech_mask = decoder(
                    x,
                    tgt_mask,
                    memory,
                    memory_mask,
                    cache=c,
                    pre_memory=speech,
                    pre_memory_mask=speech_mask,
                )
            else:
                x, tgt_mask, memory, memory_mask = decoder(
                    x, tgt_mask, memory, memory_mask, cache=c
                )

            new_cache.append(x)

        if self.normalize_before:
            y = self.after_norm(x[:, -1])
        else:
            y = x[:, -1]
        if return_hs:
            h_asr = y

        if self.output_layer is not None:
            y = torch.log_softmax(self.output_layer(y), dim=-1)

        if return_hs:
            return y, h_asr, new_cache
        return y, new_cache
    def score(self, ys, state, x, speech=None):
        """Score."""
        ys_mask = subsequent_mask(len(ys), device=x.device).unsqueeze(0)
        logp, state = self.forward_one_step(
            ys.unsqueeze(0),
            ys_mask,
            x.unsqueeze(0),
            speech=speech.unsqueeze(0) if speech is not None else None,
            cache=state,
        )
        return logp.squeeze(0), state
    def batch_score(
        self,
        ys: torch.Tensor,
        states: List[Any],
        xs: torch.Tensor,
        speech: torch.Tensor = None,
    ) -> Tuple[torch.Tensor, List[Any]]:
        """Score new token batch.

        Args:
            ys (torch.Tensor): torch.int64 prefix tokens (n_batch, ylen).
            states (List[Any]): Scorer states for prefix tokens.
            xs (torch.Tensor):
                The encoder feature that generates ys (n_batch, xlen, n_feat).

        Returns:
            tuple[torch.Tensor, List[Any]]: Tuple of
                batchified scores for the next token with shape of `(n_batch, n_vocab)`
                and the next state list for ys.
        """
        # merge states
        n_batch = len(ys)
        n_layers = len(self.decoders)
        if states[0] is None:
            batch_state = None
        else:
            # transpose state of [batch, layer] into [layer, batch]
            batch_state = [
                torch.stack([states[b][i] for b in range(n_batch)])
                for i in range(n_layers)
            ]

        # batch decoding
        ys_mask = subsequent_mask(ys.size(-1), device=xs.device).unsqueeze(0)
        logp, states = self.forward_one_step(
            ys, ys_mask, xs, speech=speech, cache=batch_state
        )

        # transpose state of [layer, batch] into [batch, layer]
        state_list = [[states[i][b] for i in range(n_layers)] for b in range(n_batch)]
        return logp, state_list
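
# Usage sketch (illustrative, not part of the original module): when
# `use_speech_attn` is True, `TransformerMDDecoder.forward` additionally attends
# over a second "speech" memory.  The tensors and sizes below are dummies
# assumed for the example only.
#
#     md_decoder = TransformerMDDecoder(vocab_size=5000, encoder_output_size=256)
#     hs_pad = torch.randn(2, 40, 256)        # primary encoder memory
#     hlens = torch.tensor([40, 35])
#     speech = torch.randn(2, 120, 256)       # speech memory for the extra attention
#     speech_lens = torch.tensor([120, 100])
#     ys_in_pad = torch.randint(0, 5000, (2, 10))
#     ys_in_lens = torch.tensor([10, 9])
#     logits, olens = md_decoder(
#         hs_pad, hlens, ys_in_pad, ys_in_lens, speech=speech, speech_lens=speech_lens
#     )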