Source code for espnet.nets.pytorch_backend.rnn.encoders

import logging

import numpy as np
import torch
import torch.nn.functional as F
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

from espnet.nets.e2e_asr_common import get_vgg2l_odim
from espnet.nets.pytorch_backend.nets_utils import make_pad_mask, to_device


class RNNP(torch.nn.Module):
    """RNN with projection layer module

    :param int idim: dimension of inputs
    :param int elayers: number of encoder layers
    :param int cdim: number of rnn units (the output is cdim * 2 if bidirectional)
    :param int hdim: number of projection units
    :param np.ndarray subsample: list of subsampling numbers
    :param float dropout: dropout rate
    :param str typ: The RNN type
    """

    def __init__(self, idim, elayers, cdim, hdim, subsample, dropout, typ="blstm"):
        super(RNNP, self).__init__()
        bidir = typ[0] == "b"
        for i in range(elayers):
            if i == 0:
                inputdim = idim
            else:
                inputdim = hdim
            RNN = torch.nn.LSTM if "lstm" in typ else torch.nn.GRU
            rnn = RNN(
                inputdim, cdim, num_layers=1, bidirectional=bidir, batch_first=True
            )
            setattr(self, "%s%d" % ("birnn" if bidir else "rnn", i), rnn)
            # bottleneck layer to merge the (bidirectional) RNN outputs
            if bidir:
                setattr(self, "bt%d" % i, torch.nn.Linear(2 * cdim, hdim))
            else:
                setattr(self, "bt%d" % i, torch.nn.Linear(cdim, hdim))

        self.elayers = elayers
        self.cdim = cdim
        self.subsample = subsample
        self.typ = typ
        self.bidir = bidir
        self.dropout = dropout

    def forward(self, xs_pad, ilens, prev_state=None):
        """RNNP forward

        :param torch.Tensor xs_pad: batch of padded input sequences (B, Tmax, idim)
        :param torch.Tensor ilens: batch of lengths of input sequences (B)
        :param torch.Tensor prev_state: batch of previous RNN states
        :return: batch of hidden state sequences (B, Tmax, hdim)
        :rtype: torch.Tensor
        """
        logging.debug(self.__class__.__name__ + " input lengths: " + str(ilens))
        elayer_states = []
        for layer in range(self.elayers):
            if not isinstance(ilens, torch.Tensor):
                ilens = torch.tensor(ilens)
            xs_pack = pack_padded_sequence(xs_pad, ilens.cpu(), batch_first=True)
            rnn = getattr(self, ("birnn" if self.bidir else "rnn") + str(layer))
            if self.training:
                rnn.flatten_parameters()
            if prev_state is not None and rnn.bidirectional:
                prev_state = reset_backward_rnn_state(prev_state)
            ys, states = rnn(
                xs_pack, hx=None if prev_state is None else prev_state[layer]
            )
            elayer_states.append(states)
            # ys: utt list of frame x cdim x 2 (2: means bidirectional)
            ys_pad, ilens = pad_packed_sequence(ys, batch_first=True)
            sub = self.subsample[layer + 1]
            if sub > 1:
                ys_pad = ys_pad[:, ::sub]
                ilens = torch.tensor([int(i + 1) // sub for i in ilens])
            # (sum _utt frame_utt) x dim
            projection_layer = getattr(self, "bt%d" % layer)
            projected = projection_layer(ys_pad.contiguous().view(-1, ys_pad.size(2)))
            xs_pad = projected.view(ys_pad.size(0), ys_pad.size(1), -1)
            if layer < self.elayers - 1:
                # apply dropout only while training
                xs_pad = torch.tanh(
                    F.dropout(xs_pad, p=self.dropout, training=self.training)
                )

        return xs_pad, ilens, elayer_states  # x: utt list of frame x dim
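
# Usage sketch (added for illustration; not part of the original module).
# RNNP stacks ``elayers`` single-layer (B)RNNs, each followed by a linear
# projection to ``hdim``; ``subsample[i + 1]`` is the frame subsampling
# applied after layer i. The parameter values below are arbitrary examples.
#
# >>> rnnp = RNNP(idim=83, elayers=4, cdim=320, hdim=320,
# ...             subsample=np.array([1, 2, 2, 1, 1]), dropout=0.0)
# >>> xs_pad = torch.randn(2, 100, 83)           # (B, Tmax, idim)
# >>> ilens = torch.tensor([100, 80])            # sorted in decreasing order
# >>> ys, olens, states = rnnp(xs_pad, ilens)
# >>> ys.shape                                   # time axis reduced by 2 * 2
# torch.Size([2, 25, 320])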


class RNN(torch.nn.Module):
    """RNN module

    :param int idim: dimension of inputs
    :param int elayers: number of encoder layers
    :param int cdim: number of rnn units (the output is cdim * 2 if bidirectional)
    :param int hdim: number of final projection units
    :param float dropout: dropout rate
    :param str typ: The RNN type
    """

    def __init__(self, idim, elayers, cdim, hdim, dropout, typ="blstm"):
        super(RNN, self).__init__()
        bidir = typ[0] == "b"
        self.nbrnn = (
            torch.nn.LSTM(
                idim,
                cdim,
                elayers,
                batch_first=True,
                dropout=dropout,
                bidirectional=bidir,
            )
            if "lstm" in typ
            else torch.nn.GRU(
                idim,
                cdim,
                elayers,
                batch_first=True,
                dropout=dropout,
                bidirectional=bidir,
            )
        )
        if bidir:
            self.l_last = torch.nn.Linear(cdim * 2, hdim)
        else:
            self.l_last = torch.nn.Linear(cdim, hdim)
        self.typ = typ

    def forward(self, xs_pad, ilens, prev_state=None):
        """RNN forward

        :param torch.Tensor xs_pad: batch of padded input sequences (B, Tmax, D)
        :param torch.Tensor ilens: batch of lengths of input sequences (B)
        :param torch.Tensor prev_state: batch of previous RNN states
        :return: batch of hidden state sequences (B, Tmax, eprojs)
        :rtype: torch.Tensor
        """
        logging.debug(self.__class__.__name__ + " input lengths: " + str(ilens))
        if not isinstance(ilens, torch.Tensor):
            ilens = torch.tensor(ilens)
        xs_pack = pack_padded_sequence(xs_pad, ilens.cpu(), batch_first=True)
        if self.training:
            self.nbrnn.flatten_parameters()
        if prev_state is not None and self.nbrnn.bidirectional:
            # We assume that when previous state is passed,
            # it means that we're streaming the input
            # and therefore cannot propagate backward BRNN state
            # (otherwise it goes in the wrong direction)
            prev_state = reset_backward_rnn_state(prev_state)
        ys, states = self.nbrnn(xs_pack, hx=prev_state)
        # ys: utt list of frame x cdim x 2 (2: means bidirectional)
        ys_pad, ilens = pad_packed_sequence(ys, batch_first=True)
        # (sum _utt frame_utt) x dim
        projected = torch.tanh(
            self.l_last(ys_pad.contiguous().view(-1, ys_pad.size(2)))
        )
        xs_pad = projected.view(ys_pad.size(0), ys_pad.size(1), -1)
        return xs_pad, ilens, states  # x: utt list of frame x dim
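
# Usage sketch (added for illustration; not part of the original module).
# Unlike RNNP, RNN wraps a single multi-layer torch.nn.LSTM/GRU and projects
# only the final layer's output to ``hdim``; no subsampling is applied.
#
# >>> rnn = RNN(idim=83, elayers=4, cdim=320, hdim=320, dropout=0.0, typ="blstm")
# >>> ys, olens, states = rnn(torch.randn(2, 100, 83), torch.tensor([100, 80]))
# >>> ys.shape                                   # Tmax is unchanged
# torch.Size([2, 100, 320])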


def reset_backward_rnn_state(states):
    """Sets backward BRNN states to zeroes

    Useful in processing of sliding windows over the inputs
    """
    if isinstance(states, (list, tuple)):
        for state in states:
            state[1::2] = 0.0
    else:
        states[1::2] = 0.0
    return states
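
# Illustration (added; not part of the original module). Bidirectional RNN
# states are stacked as (num_layers * num_directions, B, cdim) with the
# backward direction at odd indices, so ``states[1::2] = 0.0`` zeroes exactly
# the backward-direction states.
#
# >>> h = torch.ones(4, 2, 3)        # 2 layers x 2 directions, B=2, cdim=3
# >>> h = reset_backward_rnn_state(h)
# >>> h[0].sum().item(), h[1].sum().item()
# (6.0, 0.0)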


class VGG2L(torch.nn.Module):
    """VGG-like module

    :param int in_channel: number of input channels
    """

    def __init__(self, in_channel=1):
        super(VGG2L, self).__init__()
        # CNN layer (VGG motivated)
        self.conv1_1 = torch.nn.Conv2d(in_channel, 64, 3, stride=1, padding=1)
        self.conv1_2 = torch.nn.Conv2d(64, 64, 3, stride=1, padding=1)
        self.conv2_1 = torch.nn.Conv2d(64, 128, 3, stride=1, padding=1)
        self.conv2_2 = torch.nn.Conv2d(128, 128, 3, stride=1, padding=1)

        self.in_channel = in_channel

    def forward(self, xs_pad, ilens, **kwargs):
        """VGG2L forward

        :param torch.Tensor xs_pad: batch of padded input sequences (B, Tmax, D)
        :param torch.Tensor ilens: batch of lengths of input sequences (B)
        :return: batch of padded hidden state sequences (B, Tmax // 4, 128 * D // 4)
        :rtype: torch.Tensor
        """
        logging.debug(self.__class__.__name__ + " input lengths: " + str(ilens))

        # x: utt x frame x dim
        # xs_pad = F.pad_sequence(xs_pad)

        # x: utt x 1 (input channel num) x frame x dim
        xs_pad = xs_pad.view(
            xs_pad.size(0),
            xs_pad.size(1),
            self.in_channel,
            xs_pad.size(2) // self.in_channel,
        ).transpose(1, 2)

        # NOTE: max_pool1d ?
        xs_pad = F.relu(self.conv1_1(xs_pad))
        xs_pad = F.relu(self.conv1_2(xs_pad))
        xs_pad = F.max_pool2d(xs_pad, 2, stride=2, ceil_mode=True)

        xs_pad = F.relu(self.conv2_1(xs_pad))
        xs_pad = F.relu(self.conv2_2(xs_pad))
        xs_pad = F.max_pool2d(xs_pad, 2, stride=2, ceil_mode=True)
        if torch.is_tensor(ilens):
            ilens = ilens.cpu().numpy()
        else:
            ilens = np.array(ilens, dtype=np.float32)
        # halve the lengths once per max pooling, rounding up
        ilens = np.array(np.ceil(ilens / 2), dtype=np.int64)
        ilens = np.array(
            np.ceil(np.array(ilens, dtype=np.float32) / 2), dtype=np.int64
        ).tolist()

        # x: utt_list of frame (remove zero-padded frames) x (input channel num x dim)
        xs_pad = xs_pad.transpose(1, 2)
        xs_pad = xs_pad.contiguous().view(
            xs_pad.size(0), xs_pad.size(1), xs_pad.size(2) * xs_pad.size(3)
        )
        return xs_pad, ilens, None  # no state in this layer
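
# Shape sketch (added for illustration; not part of the original module).
# The two stride-2 max poolings reduce both the time and feature axes by a
# factor of 4 (with ceil), while the channel axis grows to 128, so the
# flattened output dimension is 128 * ceil(D / 4).
#
# >>> vgg = VGG2L(in_channel=1)
# >>> ys, olens, _ = vgg(torch.randn(2, 100, 83), torch.tensor([100, 80]))
# >>> ys.shape                       # (B, ceil(100 / 4), 128 * ceil(83 / 4))
# torch.Size([2, 25, 2688])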


class Encoder(torch.nn.Module):
    """Encoder module

    :param str etype: type of encoder network
    :param int idim: number of dimensions of encoder network
    :param int elayers: number of layers of encoder network
    :param int eunits: number of lstm units of encoder network
    :param int eprojs: number of projection units of encoder network
    :param np.ndarray subsample: list of subsampling numbers
    :param float dropout: dropout rate
    :param int in_channel: number of input channels
    """

    def __init__(
        self, etype, idim, elayers, eunits, eprojs, subsample, dropout, in_channel=1
    ):
        super(Encoder, self).__init__()
        typ = etype.lstrip("vgg").rstrip("p")
        if typ not in ["lstm", "gru", "blstm", "bgru"]:
            logging.error("Error: need to specify an appropriate encoder architecture")

        if etype.startswith("vgg"):
            if etype[-1] == "p":
                self.enc = torch.nn.ModuleList(
                    [
                        VGG2L(in_channel),
                        RNNP(
                            get_vgg2l_odim(idim, in_channel=in_channel),
                            elayers,
                            eunits,
                            eprojs,
                            subsample,
                            dropout,
                            typ=typ,
                        ),
                    ]
                )
                logging.info("Use CNN-VGG + " + typ.upper() + "P for encoder")
            else:
                self.enc = torch.nn.ModuleList(
                    [
                        VGG2L(in_channel),
                        RNN(
                            get_vgg2l_odim(idim, in_channel=in_channel),
                            elayers,
                            eunits,
                            eprojs,
                            dropout,
                            typ=typ,
                        ),
                    ]
                )
                logging.info("Use CNN-VGG + " + typ.upper() + " for encoder")
            self.conv_subsampling_factor = 4
        else:
            if etype[-1] == "p":
                self.enc = torch.nn.ModuleList(
                    [RNNP(idim, elayers, eunits, eprojs, subsample, dropout, typ=typ)]
                )
                logging.info(typ.upper() + " with every-layer projection for encoder")
            else:
                self.enc = torch.nn.ModuleList(
                    [RNN(idim, elayers, eunits, eprojs, dropout, typ=typ)]
                )
                logging.info(typ.upper() + " without projection for encoder")
            self.conv_subsampling_factor = 1

    def forward(self, xs_pad, ilens, prev_states=None):
        """Encoder forward

        :param torch.Tensor xs_pad: batch of padded input sequences (B, Tmax, D)
        :param torch.Tensor ilens: batch of lengths of input sequences (B)
        :param torch.Tensor prev_states: batch of previous encoder hidden states (?, ...)
        :return: batch of hidden state sequences (B, Tmax, eprojs)
        :rtype: torch.Tensor
        """
        if prev_states is None:
            prev_states = [None] * len(self.enc)
        assert len(prev_states) == len(self.enc)

        current_states = []
        for module, prev_state in zip(self.enc, prev_states):
            xs_pad, ilens, states = module(xs_pad, ilens, prev_state=prev_state)
            current_states.append(states)

        # make mask to remove bias value in padded part
        mask = to_device(xs_pad, make_pad_mask(ilens).unsqueeze(-1))

        return xs_pad.masked_fill(mask, 0.0), ilens, current_states
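
# Usage sketch (added for illustration; not part of the original module).
# With etype="vggblstmp" the encoder runs VGG2L first (time reduced by 4),
# then an RNNP over the CNN features; padded frames in the output are zeroed
# by the mask. The parameter values below are arbitrary examples.
#
# >>> enc = Encoder("vggblstmp", idim=83, elayers=4, eunits=320, eprojs=320,
# ...               subsample=np.array([1, 1, 1, 1, 1]), dropout=0.0)
# >>> ys, olens, states = enc(torch.randn(2, 100, 83), torch.tensor([100, 80]))
# >>> ys.shape
# torch.Size([2, 25, 320])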


def encoder_for(args, idim, subsample):
    """Instantiates an encoder module given the program arguments

    :param Namespace args: The arguments
    :param int or List of integer idim: dimension of input, e.g. 83, or
        List of dimensions of inputs, e.g. [83, 83]
    :param List or List of List subsample: subsample factors, e.g. [1, 2, 2, 1, 1],
        or List of subsample factors of each encoder,
        e.g. [[1, 2, 2, 1, 1], [1, 2, 2, 1, 1]]
    :rtype: torch.nn.Module
    :return: The encoder module
    """
    num_encs = getattr(args, "num_encs", 1)  # use getattr to keep compatibility
    if num_encs == 1:
        # compatible with single encoder asr mode
        return Encoder(
            args.etype,
            idim,
            args.elayers,
            args.eunits,
            args.eprojs,
            subsample,
            args.dropout_rate,
        )
    elif num_encs > 1:
        enc_list = torch.nn.ModuleList()
        for idx in range(num_encs):
            enc = Encoder(
                args.etype[idx],
                idim[idx],
                args.elayers[idx],
                args.eunits[idx],
                args.eprojs,
                subsample[idx],
                args.dropout_rate[idx],
            )
            enc_list.append(enc)
        return enc_list
    else:
        raise ValueError(
            "Number of encoders needs to be at least one. {}".format(num_encs)
        )
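
# Usage sketch (added for illustration; not part of the original module).
# ``args`` mirrors the fields a training script would provide; when
# ``num_encs`` is absent (or equal to 1), a single Encoder is returned.
#
# >>> from argparse import Namespace
# >>> args = Namespace(etype="blstmp", elayers=4, eunits=320, eprojs=320,
# ...                  dropout_rate=0.0)
# >>> enc = encoder_for(args, idim=83, subsample=np.array([1, 2, 2, 1, 1]))
# >>> isinstance(enc, Encoder)
# True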