# Source code for espnet.nets.pytorch_backend.lm.seq_rnn

"""Sequential implementation of Recurrent Neural Network Language Model."""

import torch
import torch.nn as nn
import torch.nn.functional as F

from espnet.nets.lm_interface import LMInterface

class SequentialRNNLM(LMInterface, torch.nn.Module):
    """Sequential RNNLM.

    See also:
        https://github.com/pytorch/examples/blob/master/word_language_model/model.py

    """

    @staticmethod
    def add_arguments(parser):
        """Add arguments to command line argument parser."""
        parser.add_argument(
            "--type",
            type=str,
            default="lstm",
            nargs="?",
            choices=["lstm", "gru"],
            help="Which type of RNN to use",
        )
        parser.add_argument(
            "--layer", "-l", type=int, default=2, help="Number of hidden layers"
        )
        parser.add_argument(
            "--unit", "-u", type=int, default=650, help="Number of hidden units"
        )
        parser.add_argument(
            "--dropout-rate", type=float, default=0.5, help="dropout probability"
        )
        return parser

    def __init__(self, n_vocab, args):
        """Initialize class.

        Args:
            n_vocab (int): The size of the vocabulary
            args (argparse.Namespace): configurations. see py:method:`add_arguments`

        """
        torch.nn.Module.__init__(self)
        self._setup(
            rnn_type=args.type.upper(),
            ntoken=n_vocab,
            ninp=args.unit,
            nhid=args.unit,
            nlayers=args.layer,
            dropout=args.dropout_rate,
        )

    def _setup(
        self, rnn_type, ntoken, ninp, nhid, nlayers, dropout=0.5, tie_weights=False
    ):
        """Build the embedding -> RNN -> linear-decoder stack.

        Args:
            rnn_type (str): One of "LSTM", "GRU", "RNN_TANH", "RNN_RELU".
            ntoken (int): Vocabulary size (embedding rows / decoder outputs).
            ninp (int): Embedding (input) dimension.
            nhid (int): Hidden-state dimension.
            nlayers (int): Number of stacked RNN layers.
            dropout (float): Dropout probability applied to embeddings,
                between RNN layers, and to RNN outputs.
            tie_weights (bool): Share decoder weight with the embedding
                (requires nhid == ninp).

        Raises:
            ValueError: If rnn_type is invalid or tie_weights is requested
                with nhid != ninp.

        """
        self.drop = nn.Dropout(dropout)
        self.encoder = nn.Embedding(ntoken, ninp)
        if rnn_type in ["LSTM", "GRU"]:
            self.rnn = getattr(nn, rnn_type)(ninp, nhid, nlayers, dropout=dropout)
        else:
            try:
                nonlinearity = {"RNN_TANH": "tanh", "RNN_RELU": "relu"}[rnn_type]
            except KeyError:
                raise ValueError(
                    "An invalid option for `--model` was supplied, "
                    "options are ['LSTM', 'GRU', 'RNN_TANH' or 'RNN_RELU']"
                )
            self.rnn = nn.RNN(
                ninp, nhid, nlayers, nonlinearity=nonlinearity, dropout=dropout
            )
        self.decoder = nn.Linear(nhid, ntoken)

        # Optionally tie weights as in:
        # "Using the Output Embedding to Improve Language Models" (Press & Wolf 2016)
        # https://arxiv.org/abs/1608.05859
        # and
        # "Tying Word Vectors and Word Classifiers:
        # A Loss Framework for Language Modeling" (Inan et al. 2016)
        # https://arxiv.org/abs/1611.01462
        if tie_weights:
            if nhid != ninp:
                raise ValueError(
                    "When using the tied flag, nhid must be equal to emsize"
                )
            self.decoder.weight = self.encoder.weight

        self._init_weights()

        self.rnn_type = rnn_type
        self.nhid = nhid
        self.nlayers = nlayers

    def _init_weights(self):
        """Initialize all parameters uniformly in [-0.1, 0.1]."""
        # NOTE: original init in pytorch/examples
        # initrange = 0.1
        # self.encoder.weight.data.uniform_(-initrange, initrange)
        # self.decoder.bias.data.zero_()
        # self.decoder.weight.data.uniform_(-initrange, initrange)
        # NOTE: our init — uniform over every parameter tensor
        for param in self.parameters():
            param.data.uniform_(-0.1, 0.1)

    def forward(self, x, t):
        """Compute LM loss value from buffer sequences.

        Args:
            x (torch.Tensor): Input ids. (batch, len)
            t (torch.Tensor): Target ids. (batch, len)

        Returns:
            tuple[torch.Tensor, torch.Tensor, torch.Tensor]: Tuple of
                loss to backward (scalar),
                negative log-likelihood of t: -log p(t) (scalar) and
                the number of elements in x (scalar)

        Notes:
            The last two return values are used
            in perplexity: p(t)^{-n} = exp(-log p(t) / n)

        """
        y = self._before_loss(x, None)[0]
        # Token id 0 is treated as padding and excluded from the loss.
        mask = (x != 0).to(y.dtype)
        loss = F.cross_entropy(y.view(-1, y.shape[-1]), t.view(-1), reduction="none")
        logp = loss * mask.view(-1)
        logp = logp.sum()
        count = mask.sum()
        return logp / count, logp, count

    def _before_loss(self, input, hidden):
        """Run embedding, RNN and decoder; return (logits, next hidden state)."""
        emb = self.drop(self.encoder(input))
        output, hidden = self.rnn(emb, hidden)
        output = self.drop(output)
        # Flatten (seq, batch) for the linear layer, then restore the shape.
        decoded = self.decoder(
            output.view(output.size(0) * output.size(1), output.size(2))
        )
        return decoded.view(output.size(0), output.size(1), decoded.size(1)), hidden

    def init_state(self, x):
        """Get an initial state for decoding.

        Args:
            x (torch.Tensor): The encoded feature tensor

        Returns: initial state

        """
        bsz = 1
        # Borrow dtype/device from an arbitrary parameter tensor.
        weight = next(self.parameters())
        if self.rnn_type == "LSTM":
            # LSTM state is a (hidden, cell) pair.
            return (
                weight.new_zeros(self.nlayers, bsz, self.nhid),
                weight.new_zeros(self.nlayers, bsz, self.nhid),
            )
        else:
            return weight.new_zeros(self.nlayers, bsz, self.nhid)

    def score(self, y, state, x):
        """Score new token.

        Args:
            y (torch.Tensor): 1D torch.int64 prefix tokens.
            state: Scorer state for prefix tokens
            x (torch.Tensor): 2D encoder feature that generates ys.

        Returns:
            tuple[torch.Tensor, Any]: Tuple of
                torch.float32 scores for next token (n_vocab)
                and next state for ys

        """
        # Feed only the newest token; the recurrent state carries the prefix.
        y, new_state = self._before_loss(y[-1].view(1, 1), state)
        logp = y.log_softmax(dim=-1).view(-1)
        return logp, new_state