Source code for espnet.nets.pytorch_backend.tacotron2.decoder

#!/usr/bin/env python3
# -*- coding: utf-8 -*-

# Copyright 2019 Nagoya University (Tomoki Hayashi)
#  Apache 2.0  (http://www.apache.org/licenses/LICENSE-2.0)

"""Tacotron2 decoder related modules."""

import torch
import torch.nn.functional as F

from espnet.nets.pytorch_backend.rnn.attentions import AttForwardTA


def decoder_init(m):
    """Initialize decoder parameters.

    Intended to be passed to ``torch.nn.Module.apply``. Only 1D convolution
    layers are touched; every other module type is left as it is.

    Args:
        m (torch.nn.Module): Module to be (possibly) re-initialized.

    """
    if not isinstance(m, torch.nn.Conv1d):
        return
    # Xavier-uniform scaled by the gain recommended for tanh nonlinearity.
    tanh_gain = torch.nn.init.calculate_gain("tanh")
    torch.nn.init.xavier_uniform_(m.weight, tanh_gain)
class ZoneOutCell(torch.nn.Module):
    """ZoneOut Cell module.

    This is a module of zoneout described in
    `Zoneout: Regularizing RNNs by Randomly Preserving Hidden Activations`_.
    This code is modified from `eladhoffer/seq2seq.pytorch`_.

    Examples:
        >>> lstm = torch.nn.LSTMCell(16, 32)
        >>> lstm = ZoneOutCell(lstm, 0.5)

    .. _`Zoneout: Regularizing RNNs by Randomly Preserving Hidden Activations`:
        https://arxiv.org/abs/1606.01305

    .. _`eladhoffer/seq2seq.pytorch`:
        https://github.com/eladhoffer/seq2seq.pytorch

    """

    def __init__(self, cell, zoneout_rate=0.1):
        """Initialize zone out cell module.

        Args:
            cell (torch.nn.Module): Pytorch recurrent cell module
                e.g. `torch.nn.LSTMCell`. It must expose ``hidden_size``.
            zoneout_rate (float, optional): Probability of zoneout from 0.0 to 1.0.

        Raises:
            ValueError: If ``zoneout_rate`` is outside of [0.0, 1.0].

        """
        super().__init__()
        self.cell = cell
        self.hidden_size = cell.hidden_size
        self.zoneout_rate = zoneout_rate
        if zoneout_rate < 0.0 or zoneout_rate > 1.0:
            raise ValueError(
                "zoneout probability must be in the range from 0.0 to 1.0."
            )

    def forward(self, inputs, hidden):
        """Calculate forward propagation.

        Args:
            inputs (Tensor): Batch of input tensor (B, input_size).
            hidden (tuple):
                - Tensor: Batch of initial hidden states (B, hidden_size).
                - Tensor: Batch of initial cell states (B, hidden_size).

        Returns:
            tuple:
                - Tensor: Batch of next hidden states (B, hidden_size).
                - Tensor: Batch of next cell states (B, hidden_size).

        """
        updated = self.cell(inputs, hidden)
        return self._zoneout(hidden, updated, self.zoneout_rate)

    def _zoneout(self, h, next_h, prob):
        """Mix previous and updated states, recursing into tuples of states."""
        if isinstance(h, tuple):
            # A scalar probability is shared by every element of the tuple.
            probs = prob if isinstance(prob, tuple) else (prob,) * len(h)
            return tuple(
                self._zoneout(prev, next_h[i], probs[i]) for i, prev in enumerate(h)
            )

        if not self.training:
            # Evaluation: use the expectation of the random mask.
            return prob * h + (1 - prob) * next_h

        # Training: each unit keeps its previous value with probability `prob`.
        keep_mask = h.new(*h.size()).bernoulli_(prob)
        return keep_mask * h + (1 - keep_mask) * next_h
class Prenet(torch.nn.Module):
    """Prenet module for decoder of Spectrogram prediction network.

    This is a module of Prenet in the decoder of Spectrogram prediction network,
    which described in `Natural TTS Synthesis by
    Conditioning WaveNet on Mel Spectrogram Predictions`_.
    The Prenet performs nonlinear conversion
    of inputs before input to auto-regressive lstm,
    which helps to learn diagonal attentions.

    Note:
        This module always applies dropout even in evaluation.
        See the detail in `Natural TTS Synthesis by
        Conditioning WaveNet on Mel Spectrogram Predictions`_.

    .. _`Natural TTS Synthesis by Conditioning WaveNet on Mel Spectrogram Predictions`:
       https://arxiv.org/abs/1712.05884

    """

    def __init__(self, idim, n_layers=2, n_units=256, dropout_rate=0.5):
        """Initialize prenet module.

        Args:
            idim (int): Dimension of the inputs.
            n_layers (int, optional): The number of prenet layers.
            n_units (int, optional): The number of prenet units.
            dropout_rate (float, optional): Dropout rate.

        """
        super().__init__()
        self.dropout_rate = dropout_rate
        self.prenet = torch.nn.ModuleList()
        in_features = idim
        for _ in range(n_layers):
            block = torch.nn.Sequential(
                torch.nn.Linear(in_features, n_units),
                torch.nn.ReLU(),
            )
            self.prenet.append(block)
            # every layer after the first consumes the previous layer's output
            in_features = n_units

    def forward(self, x):
        """Calculate forward propagation.

        Args:
            x (Tensor): Batch of input tensors (B, ..., idim).

        Returns:
            Tensor: Batch of output tensors (B, ..., n_units).

        """
        for block in self.prenet:
            # F.dropout defaults to training=True, so this stays stochastic
            # regardless of the module's train/eval mode (see the class note).
            x = F.dropout(block(x), self.dropout_rate)
        return x
class Postnet(torch.nn.Module):
    """Postnet module for Spectrogram prediction network.

    This is a module of Postnet in Spectrogram prediction network,
    which described in `Natural TTS Synthesis by
    Conditioning WaveNet on Mel Spectrogram Predictions`_.
    The Postnet refines the predicted
    Mel-filterbank of the decoder,
    which helps to compensate the detail structure of spectrogram.

    .. _`Natural TTS Synthesis by Conditioning WaveNet on Mel Spectrogram Predictions`:
       https://arxiv.org/abs/1712.05884

    """

    def __init__(
        self,
        idim,
        odim,
        n_layers=5,
        n_chans=512,
        n_filts=5,
        dropout_rate=0.5,
        use_batch_norm=True,
    ):
        """Initialize postnet module.

        Args:
            idim (int): Dimension of the inputs. Not used by this module:
                the first convolution takes ``odim`` input channels.
            odim (int): Dimension of the outputs.
            n_layers (int, optional): The number of layers.
            n_chans (int, optional): The number of filter channels.
            n_filts (int, optional): The number of filter size.
            dropout_rate (float, optional): Dropout rate.
            use_batch_norm (bool, optional): Whether to use batch normalization.

        """
        super(Postnet, self).__init__()
        self.postnet = torch.nn.ModuleList()

        # hidden layers: conv (-> batch norm) -> tanh -> dropout
        for layer in range(n_layers - 1):
            ichans = odim if layer == 0 else n_chans
            self.postnet.append(
                self._make_layer(
                    ichans, n_chans, n_filts, dropout_rate, use_batch_norm, True
                )
            )

        # final layer projects back to odim and has no tanh
        ichans = n_chans if n_layers != 1 else odim
        self.postnet.append(
            self._make_layer(ichans, odim, n_filts, dropout_rate, use_batch_norm, False)
        )

    @staticmethod
    def _make_layer(ichans, ochans, n_filts, dropout_rate, use_batch_norm, use_tanh):
        """Build one postnet layer: conv [-> batch norm] [-> tanh] -> dropout."""
        modules = [
            torch.nn.Conv1d(
                ichans,
                ochans,
                n_filts,
                stride=1,
                padding=(n_filts - 1) // 2,
                bias=False,
            )
        ]
        if use_batch_norm:
            modules.append(torch.nn.BatchNorm1d(ochans))
        if use_tanh:
            modules.append(torch.nn.Tanh())
        modules.append(torch.nn.Dropout(dropout_rate))
        return torch.nn.Sequential(*modules)

    def forward(self, xs):
        """Calculate forward propagation.

        Args:
            xs (Tensor): Batch of the sequences of padded input tensors
                (B, odim, Tmax).

        Returns:
            Tensor: Batch of padded output tensor. (B, odim, Tmax).

        """
        for layer in self.postnet:
            xs = layer(xs)
        return xs
class Decoder(torch.nn.Module):
    """Decoder module of Spectrogram prediction network.

    This is a module of decoder of Spectrogram prediction network in Tacotron2,
    which described in `Natural TTS
    Synthesis by Conditioning WaveNet on Mel Spectrogram Predictions`_.
    The decoder generates the sequence of
    features from the sequence of the hidden states.

    .. _`Natural TTS Synthesis by Conditioning WaveNet on Mel Spectrogram Predictions`:
       https://arxiv.org/abs/1712.05884

    """

    def __init__(
        self,
        idim,
        odim,
        att,
        dlayers=2,
        dunits=1024,
        prenet_layers=2,
        prenet_units=256,
        postnet_layers=5,
        postnet_chans=512,
        postnet_filts=5,
        output_activation_fn=None,
        cumulate_att_w=True,
        use_batch_norm=True,
        use_concate=True,
        dropout_rate=0.5,
        zoneout_rate=0.1,
        reduction_factor=1,
    ):
        """Initialize Tacotron2 decoder module.

        Args:
            idim (int): Dimension of the inputs.
            odim (int): Dimension of the outputs.
            att (torch.nn.Module): Instance of attention class.
            dlayers (int, optional): The number of decoder lstm layers.
            dunits (int, optional): The number of decoder lstm units.
            prenet_layers (int, optional): The number of prenet layers.
            prenet_units (int, optional): The number of prenet units.
            postnet_layers (int, optional): The number of postnet layers.
            postnet_chans (int, optional): The number of postnet filter channels.
            postnet_filts (int, optional): The number of postnet filter size.
            output_activation_fn (torch.nn.Module, optional):
                Activation function for outputs.
            cumulate_att_w (bool, optional):
                Whether to cumulate previous attention weight.
            use_batch_norm (bool, optional): Whether to use batch normalization.
            use_concate (bool, optional): Whether to concatenate encoder embedding
                with decoder lstm outputs.
            dropout_rate (float, optional): Dropout rate.
            zoneout_rate (float, optional): Zoneout rate.
            reduction_factor (int, optional): Reduction factor.

        """
        super(Decoder, self).__init__()

        # store the hyperparameters
        self.idim = idim
        self.odim = odim
        self.att = att
        self.output_activation_fn = output_activation_fn
        self.cumulate_att_w = cumulate_att_w
        self.use_concate = use_concate
        self.reduction_factor = reduction_factor

        # AttForwardTA additionally takes the previous output frame as input
        self.use_att_extra_inputs = isinstance(self.att, AttForwardTA)

        # define lstm network
        # (without a prenet the previous output frame is fed in directly)
        prenet_units = prenet_units if prenet_layers != 0 else odim
        self.lstm = torch.nn.ModuleList()
        for layer in range(dlayers):
            iunits = idim + prenet_units if layer == 0 else dunits
            lstm = torch.nn.LSTMCell(iunits, dunits)
            if zoneout_rate > 0.0:
                lstm = ZoneOutCell(lstm, zoneout_rate)
            self.lstm += [lstm]

        # define prenet
        if prenet_layers > 0:
            self.prenet = Prenet(
                idim=odim,
                n_layers=prenet_layers,
                n_units=prenet_units,
                dropout_rate=dropout_rate,
            )
        else:
            self.prenet = None

        # define postnet
        if postnet_layers > 0:
            self.postnet = Postnet(
                idim=idim,
                odim=odim,
                n_layers=postnet_layers,
                n_chans=postnet_chans,
                n_filts=postnet_filts,
                use_batch_norm=use_batch_norm,
                dropout_rate=dropout_rate,
            )
        else:
            self.postnet = None

        # define projection layers
        iunits = idim + dunits if use_concate else dunits
        self.feat_out = torch.nn.Linear(iunits, odim * reduction_factor, bias=False)
        self.prob_out = torch.nn.Linear(iunits, reduction_factor)

        # initialize
        self.apply(decoder_init)

    def _zero_state(self, hs):
        """Return a zero state (B, hidden_size) with dtype/device of ``hs``."""
        init_hs = hs.new_zeros(hs.size(0), self.lstm[0].hidden_size)
        return init_hs

    def _init_states(self, hs):
        """Return zero-initialized ``(c_list, z_list)``, one entry per lstm layer."""
        n_layers = len(self.lstm)
        c_list = [self._zero_state(hs) for _ in range(n_layers)]
        z_list = [self._zero_state(hs) for _ in range(n_layers)]
        return c_list, z_list

    def _attend(self, hs, hlens, dec_z, prev_att_w, prev_out, **att_kwargs):
        """Run the attention module and return ``(att_c, att_w)``.

        ``prev_out`` is forwarded only when the attention needs extra inputs.
        If the attention returns a list of weights (multihead attention, used
        for translatotron in s2st), they are stacked along dim=1.
        """
        if self.use_att_extra_inputs:
            att_c, att_w = self.att(
                hs, hlens, dec_z, prev_att_w, prev_out, **att_kwargs
            )
        else:
            att_c, att_w = self.att(hs, hlens, dec_z, prev_att_w, **att_kwargs)
        if isinstance(att_w, list):
            att_w = torch.stack(att_w, dim=1)
        return att_c, att_w

    def _lstm_step(self, att_c, prev_out, z_list, c_list):
        """Advance the lstm stack by one step, updating the state lists in place.

        Returns the hidden state of the last lstm layer.
        """
        prenet_out = self.prenet(prev_out) if self.prenet is not None else prev_out
        xs = torch.cat([att_c, prenet_out], dim=1)
        z_list[0], c_list[0] = self.lstm[0](xs, (z_list[0], c_list[0]))
        for i in range(1, len(self.lstm)):
            z_list[i], c_list[i] = self.lstm[i](z_list[i - 1], (z_list[i], c_list[i]))
        return z_list[-1]

    def _accumulate_att_w(self, prev_att_w, att_w):
        """Return the attention weight to feed to the next attention step."""
        if self.cumulate_att_w and prev_att_w is not None:
            return prev_att_w + att_w  # Note: error when use +=
        return att_w

    def forward(self, hs, hlens, ys):
        """Calculate forward propagation.

        Args:
            hs (Tensor): Batch of the sequences of padded hidden states
                (B, Tmax, idim).
            hlens (LongTensor): Batch of lengths of each input batch (B,).
            ys (Tensor):
                Batch of the sequences of padded target features (B, Lmax, odim).

        Returns:
            Tensor: Batch of output tensors after postnet (B, Lmax, odim).
            Tensor: Batch of output tensors before postnet (B, Lmax, odim).
            Tensor: Batch of logits of stop prediction (B, Lmax).
            Tensor: Batch of attention weights (B, Lmax, Tmax).

        Note:
            This computation is performed in teacher-forcing manner.

        """
        # thin out frames (B, Lmax, odim) ->  (B, Lmax/r, odim)
        if self.reduction_factor > 1:
            ys = ys[:, self.reduction_factor - 1 :: self.reduction_factor]

        # length list should be list of int
        hlens = list(map(int, hlens))

        # initialize hidden states of decoder
        c_list, z_list = self._init_states(hs)
        prev_out = hs.new_zeros(hs.size(0), self.odim)

        # initialize attention
        prev_att_w = None
        self.att.reset()

        # loop for an output sequence
        outs, logits, att_ws = [], [], []
        for y in ys.transpose(0, 1):
            att_c, att_w = self._attend(hs, hlens, z_list[0], prev_att_w, prev_out)
            z_last = self._lstm_step(att_c, prev_out, z_list, c_list)
            zcs = torch.cat([z_last, att_c], dim=1) if self.use_concate else z_last
            outs.append(self.feat_out(zcs).view(hs.size(0), self.odim, -1))
            logits.append(self.prob_out(zcs))
            att_ws.append(att_w)
            prev_out = y  # teacher forcing
            prev_att_w = self._accumulate_att_w(prev_att_w, att_w)

        logits = torch.cat(logits, dim=1)  # (B, Lmax)
        before_outs = torch.cat(outs, dim=2)  # (B, odim, Lmax)
        att_ws = torch.stack(att_ws, dim=1)  # (B, Lmax, Tmax)

        if self.reduction_factor > 1:
            before_outs = before_outs.view(
                before_outs.size(0), self.odim, -1
            )  # (B, odim, Lmax)

        if self.postnet is not None:
            after_outs = before_outs + self.postnet(before_outs)  # (B, odim, Lmax)
        else:
            after_outs = before_outs
        before_outs = before_outs.transpose(2, 1)  # (B, Lmax, odim)
        after_outs = after_outs.transpose(2, 1)  # (B, Lmax, odim)

        # apply activation function for scaling
        if self.output_activation_fn is not None:
            before_outs = self.output_activation_fn(before_outs)
            after_outs = self.output_activation_fn(after_outs)

        return after_outs, before_outs, logits, att_ws

    def inference(
        self,
        h,
        threshold=0.5,
        minlenratio=0.0,
        maxlenratio=10.0,
        use_att_constraint=False,
        backward_window=None,
        forward_window=None,
    ):
        """Generate the sequence of features given the sequences of characters.

        Args:
            h (Tensor): Input sequence of encoder hidden states (T, C).
            threshold (float, optional): Threshold to stop generation.
            minlenratio (float, optional): Minimum length ratio.
                If set to 1.0 and the length of input is 10,
                the minimum length of outputs will be 10 * 1 = 10.
            maxlenratio (float, optional): Maximum length ratio.
                If set to 10 and the length of input is 10,
                the maximum length of outputs will be 10 * 10 = 100.
            use_att_constraint (bool):
                Whether to apply attention constraint introduced in `Deep Voice 3`_.
            backward_window (int): Backward window size in attention constraint.
            forward_window (int): Forward window size in attention constraint.

        Returns:
            Tensor: Output sequence of features (L, odim).
            Tensor: Output sequence of stop probabilities (L,).
            Tensor: Attention weights (L, T).

        Note:
            This computation is performed in auto-regressive manner.

        .. _`Deep Voice 3`: https://arxiv.org/abs/1710.07654

        """
        # setup
        assert len(h.size()) == 2
        hs = h.unsqueeze(0)
        ilens = [h.size(0)]
        maxlen = int(h.size(0) * maxlenratio)
        minlen = int(h.size(0) * minlenratio)

        # initialize hidden states of decoder
        c_list, z_list = self._init_states(hs)
        prev_out = hs.new_zeros(1, self.odim)

        # initialize attention
        prev_att_w = None
        self.att.reset()

        # setup for attention constraint
        last_attended_idx = 0 if use_att_constraint else None

        # loop for an output sequence
        idx = 0
        outs, att_ws, probs = [], [], []
        while True:
            # updated index
            idx += self.reduction_factor

            # decoder calculation
            att_c, att_w = self._attend(
                hs,
                ilens,
                z_list[0],
                prev_att_w,
                prev_out,
                last_attended_idx=last_attended_idx,
                backward_window=backward_window,
                forward_window=forward_window,
            )
            att_ws.append(att_w)
            z_last = self._lstm_step(att_c, prev_out, z_list, c_list)
            zcs = torch.cat([z_last, att_c], dim=1) if self.use_concate else z_last
            outs.append(self.feat_out(zcs).view(1, self.odim, -1))  # [(1, odim, r), ...]
            probs.append(torch.sigmoid(self.prob_out(zcs))[0])  # [(r), ...]

            # the last generated frame becomes the next decoder input
            prev_out = outs[-1][:, :, -1]  # (1, odim)
            if self.output_activation_fn is not None:
                prev_out = self.output_activation_fn(prev_out)
            prev_att_w = self._accumulate_att_w(prev_att_w, att_w)
            if use_att_constraint:
                last_attended_idx = int(att_w.argmax())

            # finish when a stop probability reaches the threshold or the
            # maximum length is hit, but never before the minimum length
            should_stop = bool((probs[-1] >= threshold).any()) or idx >= maxlen
            if should_stop and idx >= minlen:
                break

        outs = torch.cat(outs, dim=2)  # (1, odim, L)
        if self.postnet is not None:
            outs = outs + self.postnet(outs)  # (1, odim, L)
        outs = outs.transpose(2, 1).squeeze(0)  # (L, odim)
        probs = torch.cat(probs, dim=0)
        att_ws = torch.cat(att_ws, dim=0)

        if self.output_activation_fn is not None:
            outs = self.output_activation_fn(outs)

        return outs, probs, att_ws

    def calculate_all_attentions(self, hs, hlens, ys):
        """Calculate all of the attention weights.

        Args:
            hs (Tensor): Batch of the sequences of padded hidden states
                (B, Tmax, idim).
            hlens (LongTensor): Batch of lengths of each input batch (B,).
            ys (Tensor):
                Batch of the sequences of padded target features (B, Lmax, odim).

        Returns:
            Tensor: Batch of attention weights (B, Lmax, Tmax).

        Note:
            This computation is performed in teacher-forcing manner.

        """
        # thin out frames (B, Lmax, odim) ->  (B, Lmax/r, odim)
        if self.reduction_factor > 1:
            ys = ys[:, self.reduction_factor - 1 :: self.reduction_factor]

        # length list should be list of int
        hlens = list(map(int, hlens))

        # initialize hidden states of decoder
        c_list, z_list = self._init_states(hs)
        prev_out = hs.new_zeros(hs.size(0), self.odim)

        # initialize attention
        prev_att_w = None
        self.att.reset()

        # loop for an output sequence
        att_ws = []
        for y in ys.transpose(0, 1):
            # list-valued attention weights are stacked here exactly as in
            # forward(), so the final torch.stack below receives tensors
            att_c, att_w = self._attend(hs, hlens, z_list[0], prev_att_w, prev_out)
            att_ws.append(att_w)
            self._lstm_step(att_c, prev_out, z_list, c_list)
            prev_out = y  # teacher forcing
            prev_att_w = self._accumulate_att_w(prev_att_w, att_w)

        att_ws = torch.stack(att_ws, dim=1)  # (B, Lmax, Tmax)

        return att_ws