Source code for espnet.nets.pytorch_backend.transformer.decoder

#!/usr/bin/env python3
# -*- coding: utf-8 -*-

# Copyright 2019 Shigeki Karita
#  Apache 2.0  (http://www.apache.org/licenses/LICENSE-2.0)

"""Decoder definition."""

import logging
from typing import Any, List, Tuple

import torch

from espnet.nets.pytorch_backend.nets_utils import rename_state_dict
from espnet.nets.pytorch_backend.transformer.attention import MultiHeadedAttention
from espnet.nets.pytorch_backend.transformer.decoder_layer import DecoderLayer
from espnet.nets.pytorch_backend.transformer.dynamic_conv import DynamicConvolution
from espnet.nets.pytorch_backend.transformer.dynamic_conv2d import DynamicConvolution2D
from espnet.nets.pytorch_backend.transformer.embedding import PositionalEncoding
from espnet.nets.pytorch_backend.transformer.layer_norm import LayerNorm
from espnet.nets.pytorch_backend.transformer.lightconv import LightweightConvolution
from espnet.nets.pytorch_backend.transformer.lightconv2d import LightweightConvolution2D
from espnet.nets.pytorch_backend.transformer.mask import subsequent_mask
from espnet.nets.pytorch_backend.transformer.positionwise_feed_forward import (
    PositionwiseFeedForward,
)
from espnet.nets.pytorch_backend.transformer.repeat import repeat
from espnet.nets.scorer_interface import BatchScorerInterface


def _pre_hook(
    state_dict,
    prefix,
    local_metadata,
    strict,
    missing_keys,
    unexpected_keys,
    error_msgs,
):
    """Map legacy ``output_norm.*`` checkpoint keys onto ``after_norm.*``.

    Registered by ``Decoder.__init__`` as a ``load_state_dict`` pre-hook.
    Only ``state_dict`` and ``prefix`` are used; the remaining parameters
    exist solely to satisfy the hook calling convention.

    The return value of ``rename_state_dict`` is discarded, so it is
    expected to update ``state_dict`` in place (see ``nets_utils``).
    """
    # Background on the rename:
    # https://github.com/espnet/espnet/commit/3d422f6de8d4f03673b89e1caef698745ec749ea#diff-bffb1396f038b317b2b64dd96e6d3563
    legacy_key_prefix = prefix + "output_norm."
    current_key_prefix = prefix + "after_norm."
    rename_state_dict(legacy_key_prefix, current_key_prefix, state_dict)


class Decoder(BatchScorerInterface, torch.nn.Module):
    """Transformer decoder module.

    Args:
        odim (int): Output dimension; also the number of embeddings when
            input_layer == "embed".
        selfattention_layer_type (str): Self-attention layer type. One of
            "selfattn", "lightconv", "lightconv2d", "dynamicconv" or
            "dynamicconv2d".
        attention_dim (int): Dimension of attention.
        attention_heads (int): The number of heads of multi head attention.
        conv_wshare (int): The number of kernel of convolution. Only used in
            selfattention_layer_type == "lightconv*" or "dynamicconv*".
        conv_kernel_length (Union[int, str]): Kernel size of convolution.
            Either an int shared by every block, or a "_"-separated string
            with one entry per block (e.g. "71_71_71_71_71_71"); a
            single-entry string is shared by every block. Only used in
            selfattention_layer_type == "lightconv*" or "dynamicconv*".
        conv_usebias (bool): Whether to use bias in convolution. Only used in
            selfattention_layer_type == "lightconv*" or "dynamicconv*".
        linear_units (int): The number of units of position-wise feed forward.
        num_blocks (int): The number of decoder blocks.
        dropout_rate (float): Dropout rate.
        positional_dropout_rate (float): Dropout rate after adding positional
            encoding.
        self_attention_dropout_rate (float): Dropout rate in self-attention.
        src_attention_dropout_rate (float): Dropout rate in source-attention.
        input_layer (Union[str, torch.nn.Module]): Input layer type: "embed",
            "linear", or a module instance (followed by positional encoding).
        use_output_layer (bool): Whether to use output layer.
        pos_enc_class (torch.nn.Module): Positional encoding module class.
            `PositionalEncoding` or `ScaledPositionalEncoding`.
        normalize_before (bool): Passed to every DecoderLayer; when True, a
            final LayerNorm (``after_norm``) is also applied to the output of
            the last block.
        concat_after (bool): Whether to concat attention layer's input and
            output. If True, additional linear will be applied,
            i.e. x -> x + linear(concat(x, att(x))).
            If False, no additional linear will be applied,
            i.e. x -> x + att(x).

    Raises:
        NotImplementedError: If ``input_layer`` or
            ``selfattention_layer_type`` is not supported.
        ValueError: If ``conv_kernel_length`` is a multi-entry string with
            fewer than ``num_blocks`` entries, or an entry is not an integer.

    """

    def __init__(
        self,
        odim,
        selfattention_layer_type="selfattn",
        attention_dim=256,
        attention_heads=4,
        conv_wshare=4,
        conv_kernel_length=11,
        conv_usebias=False,
        linear_units=2048,
        num_blocks=6,
        dropout_rate=0.1,
        positional_dropout_rate=0.1,
        self_attention_dropout_rate=0.0,
        src_attention_dropout_rate=0.0,
        input_layer="embed",
        use_output_layer=True,
        pos_enc_class=PositionalEncoding,
        normalize_before=True,
        concat_after=False,
    ):
        """Construct a Decoder object."""
        torch.nn.Module.__init__(self)
        # Lets checkpoints that still use the old "output_norm." key names load.
        self._register_load_state_dict_pre_hook(_pre_hook)
        if input_layer == "embed":
            self.embed = torch.nn.Sequential(
                torch.nn.Embedding(odim, attention_dim),
                pos_enc_class(attention_dim, positional_dropout_rate),
            )
        elif input_layer == "linear":
            self.embed = torch.nn.Sequential(
                torch.nn.Linear(odim, attention_dim),
                torch.nn.LayerNorm(attention_dim),
                torch.nn.Dropout(dropout_rate),
                torch.nn.ReLU(),
                pos_enc_class(attention_dim, positional_dropout_rate),
            )
        elif isinstance(input_layer, torch.nn.Module):
            self.embed = torch.nn.Sequential(
                input_layer, pos_enc_class(attention_dim, positional_dropout_rate)
            )
        else:
            raise NotImplementedError(
                "only `embed`, `linear` or torch.nn.Module is supported."
            )
        self.normalize_before = normalize_before

        # self-attention module definition
        # Convolution-based replacements for self-attention:
        # type -> (module class, log message). All share one constructor layout.
        conv_selfattn_layers = {
            "lightconv": (
                LightweightConvolution,
                "decoder self-attention layer type = lightweight convolution",
            ),
            "lightconv2d": (
                LightweightConvolution2D,
                "decoder self-attention layer "
                "type = lightweight convolution 2-dimensional",
            ),
            "dynamicconv": (
                DynamicConvolution,
                "decoder self-attention layer type = dynamic convolution",
            ),
            "dynamicconv2d": (
                DynamicConvolution2D,
                "decoder self-attention layer type = dynamic convolution 2-dimensional",
            ),
        }
        if selfattention_layer_type == "selfattn":
            logging.info("decoder self-attention layer type = self-attention")
            decoder_selfattn_layer = MultiHeadedAttention
            decoder_selfattn_layer_args = [
                (
                    attention_heads,
                    attention_dim,
                    self_attention_dropout_rate,
                )
            ] * num_blocks
        elif selfattention_layer_type in conv_selfattn_layers:
            decoder_selfattn_layer, message = conv_selfattn_layers[
                selfattention_layer_type
            ]
            logging.info(message)
            kernel_lengths = self._parse_kernel_lengths(conv_kernel_length, num_blocks)
            decoder_selfattn_layer_args = [
                (
                    conv_wshare,
                    attention_dim,
                    self_attention_dropout_rate,
                    kernel_lengths[lnum],
                    # NOTE(review): presumably enables the causal kernel mask
                    # (use_kernel_mask) — confirm in the conv module signature.
                    True,
                    conv_usebias,
                )
                for lnum in range(num_blocks)
            ]
        else:
            # Previously fell through and failed later with a NameError inside
            # the repeat() lambda; fail early with an explicit message instead.
            raise NotImplementedError(
                f"unsupported selfattention_layer_type: {selfattention_layer_type!r}"
            )

        self.decoders = repeat(
            num_blocks,
            lambda lnum: DecoderLayer(
                attention_dim,
                decoder_selfattn_layer(*decoder_selfattn_layer_args[lnum]),
                MultiHeadedAttention(
                    attention_heads, attention_dim, src_attention_dropout_rate
                ),
                PositionwiseFeedForward(attention_dim, linear_units, dropout_rate),
                dropout_rate,
                normalize_before,
                concat_after,
            ),
        )
        self.selfattention_layer_type = selfattention_layer_type
        if self.normalize_before:
            self.after_norm = LayerNorm(attention_dim)
        if use_output_layer:
            self.output_layer = torch.nn.Linear(attention_dim, odim)
        else:
            self.output_layer = None

    @staticmethod
    def _parse_kernel_lengths(conv_kernel_length, num_blocks):
        """Return a list with one int convolution kernel size per block.

        Args:
            conv_kernel_length (Union[int, str]): An int shared by all blocks,
                or a "_"-separated string. A single-entry string is shared by
                all blocks; otherwise only the first ``num_blocks`` entries
                are used (extra entries are ignored, as before).
            num_blocks (int): The number of decoder blocks.

        Returns:
            List[int]: Kernel sizes, length ``num_blocks``.

        Raises:
            ValueError: If a multi-entry string has fewer than ``num_blocks``
                entries, or a used entry is not an integer.

        """
        if isinstance(conv_kernel_length, int):
            return [conv_kernel_length] * num_blocks
        parts = str(conv_kernel_length).split("_")
        if len(parts) == 1:
            return [int(parts[0])] * num_blocks
        if len(parts) < num_blocks:
            raise ValueError(
                f"conv_kernel_length {conv_kernel_length!r} has {len(parts)} "
                f"entries but num_blocks={num_blocks}"
            )
        # Parse only the entries actually used, matching the original lazy
        # per-block parsing (e.g. a trailing "_" beyond num_blocks is ignored).
        return [int(part) for part in parts[:num_blocks]]

    def forward(self, tgt, tgt_mask, memory, memory_mask):
        """Forward decoder.

        Args:
            tgt (torch.Tensor): Input token ids, int64 (#batch, maxlen_out) if
                input_layer == "embed". In the other case, input tensor
                (#batch, maxlen_out, odim).
            tgt_mask (torch.Tensor): Input token mask (#batch, maxlen_out).
                dtype=torch.uint8 in PyTorch 1.2- and
                dtype=torch.bool in PyTorch 1.2+ (include 1.2).
            memory (torch.Tensor): Encoded memory, float32
                (#batch, maxlen_in, feat).
            memory_mask (torch.Tensor): Encoded memory mask
                (#batch, maxlen_in).
                dtype=torch.uint8 in PyTorch 1.2- and
                dtype=torch.bool in PyTorch 1.2+ (include 1.2).

        Returns:
            torch.Tensor: Decoded token score before softmax
                (#batch, maxlen_out, odim) if use_output_layer is True. In the
                other case, final block outputs
                (#batch, maxlen_out, attention_dim).
            torch.Tensor: Score mask before softmax (#batch, maxlen_out).

        """
        x = self.embed(tgt)
        x, tgt_mask, memory, memory_mask = self.decoders(
            x, tgt_mask, memory, memory_mask
        )
        if self.normalize_before:
            x = self.after_norm(x)
        if self.output_layer is not None:
            x = self.output_layer(x)
        return x, tgt_mask

    def forward_one_step(self, tgt, tgt_mask, memory, *, cache=None):
        """Forward one step, scoring only the last target position.

        Args:
            tgt (torch.Tensor): Input token ids, int64 (#batch, maxlen_out).
            tgt_mask (torch.Tensor): Input token mask (#batch, maxlen_out).
                dtype=torch.uint8 in PyTorch 1.2- and
                dtype=torch.bool in PyTorch 1.2+ (include 1.2).
            memory (torch.Tensor): Encoded memory, float32
                (#batch, maxlen_in, feat).
            cache (List[torch.Tensor]): One cached tensor per decoder layer
                (the value returned as the second output of a previous call),
                or None to run without a cache. Each tensor shape should be
                (#batch, maxlen_out - 1, size).

        Returns:
            torch.Tensor: Output for the last position only: log-probabilities
                (#batch, odim) when the output layer is used, otherwise the
                (optionally layer-normalized) final block output
                (#batch, attention_dim).
            List[torch.Tensor]: Output of each decoder layer (all positions),
                to be passed back as ``cache`` on the next step.

        """
        x = self.embed(tgt)
        if cache is None:
            cache = [None] * len(self.decoders)
        new_cache = []
        for c, decoder in zip(cache, self.decoders):
            # None is passed as memory_mask; presumably no encoder frame is
            # masked — confirm against DecoderLayer.
            x, tgt_mask, memory, _ = decoder(x, tgt_mask, memory, None, cache=c)
            new_cache.append(x)

        if self.normalize_before:
            y = self.after_norm(x[:, -1])
        else:
            y = x[:, -1]
        if self.output_layer is not None:
            y = torch.log_softmax(self.output_layer(y), dim=-1)

        return y, new_cache

    # beam search API (see ScorerInterface)
    def score(self, ys, state, x):
        """Score the next token for a single (unbatched) hypothesis.

        Args:
            ys (torch.Tensor): Prefix token ids of one hypothesis; its length
                sets the size of the subsequent mask.
            state: Per-layer cache from the previous call, or None.
            x (torch.Tensor): Encoder output for one utterance; a batch axis
                of size 1 is added before decoding.

        Returns:
            tuple: Next-token scores with the batch axis removed, and the new
                per-layer cache.

        """
        ys_mask = subsequent_mask(len(ys), device=x.device).unsqueeze(0)
        if self.selfattention_layer_type != "selfattn":
            # TODO(karita): implement cache
            logging.warning(
                f"{self.selfattention_layer_type} does not support cached decoding."
            )
            state = None
        logp, state = self.forward_one_step(
            ys.unsqueeze(0), ys_mask, x.unsqueeze(0), cache=state
        )
        return logp.squeeze(0), state

    # batch beam search API (see BatchScorerInterface)
    def batch_score(
        self, ys: torch.Tensor, states: List[Any], xs: torch.Tensor
    ) -> Tuple[torch.Tensor, List[Any]]:
        """Score new token batch (required).

        Args:
            ys (torch.Tensor): torch.int64 prefix tokens (n_batch, ylen).
            states (List[Any]): Scorer states for prefix tokens.
            xs (torch.Tensor):
                The encoder feature that generates ys (n_batch, xlen, n_feat).

        Returns:
            tuple[torch.Tensor, List[Any]]: Tuple of
                batchfied scores for next token with shape of `(n_batch, n_vocab)`
                and next state list for ys.

        """
        # merge states
        n_batch = len(ys)
        n_layers = len(self.decoders)
        if states[0] is None:
            batch_state = None
        else:
            # transpose state of [batch, layer] into [layer, batch]
            batch_state = [
                torch.stack([states[b][i] for b in range(n_batch)])
                for i in range(n_layers)
            ]

        # batch decoding
        ys_mask = subsequent_mask(ys.size(-1), device=xs.device).unsqueeze(0)
        logp, states = self.forward_one_step(ys, ys_mask, xs, cache=batch_state)

        # transpose state of [layer, batch] into [batch, layer]
        state_list = [[states[i][b] for i in range(n_layers)] for b in range(n_batch)]
        return logp, state_list