Source code for espnet.nets.chainer_backend.transformer.decoder

# encoding: utf-8
"""Class Declaration of Transformer's Decoder."""

import chainer
import chainer.functions as F
import chainer.links as L
import numpy as np

from espnet.nets.chainer_backend.transformer.decoder_layer import DecoderLayer
from espnet.nets.chainer_backend.transformer.embedding import PositionalEncoding
from espnet.nets.chainer_backend.transformer.layer_norm import LayerNorm
from espnet.nets.chainer_backend.transformer.mask import make_history_mask


class Decoder(chainer.Chain):
    """Decoder of the Transformer.

    Args:
        odim (int): Output dimension.
        args (Namespace): Configuration with the fields ``adim`` (number of
            attention units), ``dlayers`` (number of decoder layers),
            ``dunits`` (dimension of the input vector of the decoder),
            ``aheads`` (number of attention heads), and ``dropout_rate``
            (dropout rate).
        initialW (Initializer): Initializer to initialize the weight.
        initial_bias (Initializer): Initializer to initialize the bias.

    """

    def __init__(self, odim, args, initialW=None, initial_bias=None):
        """Initialize Decoder."""
        super(Decoder, self).__init__()
        self.sos = odim - 1
        self.eos = odim - 1
        initialW = chainer.initializers.Uniform if initialW is None else initialW
        initial_bias = (
            chainer.initializers.Uniform if initial_bias is None else initial_bias
        )
        with self.init_scope():
            self.output_norm = LayerNorm(args.adim)
            self.pe = PositionalEncoding(args.adim, args.dropout_rate)
            stvd = 1.0 / np.sqrt(args.adim)
            self.output_layer = L.Linear(
                args.adim,
                odim,
                initialW=initialW(scale=stvd),
                initial_bias=initial_bias(scale=stvd),
            )
            self.embed = L.EmbedID(
                odim,
                args.adim,
                ignore_label=-1,
                initialW=chainer.initializers.Normal(scale=1.0),
            )
        for i in range(args.dlayers):
            name = "decoders." + str(i)
            layer = DecoderLayer(
                args.adim,
                d_units=args.dunits,
                h=args.aheads,
                dropout=args.dropout_rate,
                initialW=initialW,
                initial_bias=initial_bias,
            )
            self.add_link(name, layer)
        self.n_layers = args.dlayers
    def make_attention_mask(self, source_block, target_block):
        """Prepare the attention mask.

        Args:
            source_block (ndarray): Source block with dimensions: (B x S).
            target_block (ndarray): Target block with dimensions: (B x T).

        Returns:
            ndarray: Mask with dimensions (B, S, T).

        """
        mask = (target_block[:, None, :] >= 0) * (source_block[:, :, None] >= 0)
        # (batch, source_length, target_length)
        return mask
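    # A minimal sketch (not part of the original module) of how the mask above
    # broadcasts, assuming ids >= 0 are valid positions and -1 marks padding:
    #
    #     source_block = np.array([[0, 3, -1]])   # (B=1, S=3)
    #     target_block = np.array([[5, -1]])      # (B=1, T=2)
    #     mask = (target_block[:, None, :] >= 0) * (source_block[:, :, None] >= 0)
    #     # mask.shape == (1, 3, 2); True only where neither position is padding.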
    def forward(self, ys_pad, source, x_mask):
        """Forward decoder.

        :param xp.array ys_pad: input token ids, int64 (batch, maxlen_out)
        :param xp.array source: encoded memory, float32 (batch, maxlen_in, feat)
        :param xp.array x_mask: encoded memory mask, uint8 (batch, maxlen_in)
        :return e: decoded token scores before softmax (batch, maxlen_out, token)
        :rtype: chainer.Variable
        """
        xp = self.xp
        sos = np.array([self.sos], np.int32)
        ys = [np.concatenate([sos, y], axis=0) for y in ys_pad]
        e = F.pad_sequence(ys, padding=self.eos).data
        e = xp.array(e)
        # mask preparation
        xy_mask = self.make_attention_mask(e, xp.array(x_mask))
        yy_mask = self.make_attention_mask(e, e)
        yy_mask *= make_history_mask(xp, e)
        e = self.pe(self.embed(e))
        batch, length, dims = e.shape
        e = e.reshape(-1, dims)
        source = source.reshape(-1, dims)
        for i in range(self.n_layers):
            e = self["decoders." + str(i)](e, source, xy_mask, yy_mask, batch)
        return self.output_layer(self.output_norm(e)).reshape(batch, length, -1)
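    # Illustration (not part of the original module) of the target-side mask
    # built in forward above, assuming make_history_mask returns the usual
    # causal mask: for a three-token target with no padding, yy_mask[0] ends
    # up lower triangular,
    #
    #     [[1, 0, 0],
    #      [1, 1, 0],
    #      [1, 1, 1]]
    #
    # so position t may only attend to positions <= t.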
    def recognize(self, e, yy_mask, source):
        """Compute log-probability scores for recognition."""
        e = self.forward(e, source, yy_mask)
        return F.log_softmax(e, axis=-1)
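
# A minimal usage sketch, not part of the original module. The ``args`` fields
# (adim, aheads, dlayers, dunits, dropout_rate) are assumed from the
# constructor above; x_mask simply marks every encoder frame as valid
# (values >= 0), which is the only property make_attention_mask checks.
if __name__ == "__main__":
    from argparse import Namespace

    args = Namespace(adim=4, aheads=2, dlayers=2, dunits=8, dropout_rate=0.0)
    odim = 10
    decoder = Decoder(odim, args)

    batch, maxlen_in = 2, 5
    source = np.random.randn(batch, maxlen_in, args.adim).astype(np.float32)
    x_mask = np.zeros((batch, maxlen_in), dtype=np.int32)
    ys_pad = [np.array([1, 2, 3], np.int32), np.array([4, 5], np.int32)]

    with chainer.using_config("train", False):
        # logits: (batch, max target length + 1, odim)
        scores = decoder.forward(ys_pad, source, x_mask)
        # log-probabilities over the output vocabulary
        log_probs = decoder.recognize(ys_pad, x_mask, source)

    print(scores.shape, log_probs.shape)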