Source code for espnet.nets.pytorch_backend.e2e_tts_tacotron2

# Copyright 2018 Nagoya University (Tomoki Hayashi)
#  Apache 2.0  (http://www.apache.org/licenses/LICENSE-2.0)

"""Tacotron 2 related modules."""

import logging

import numpy as np
import torch
import torch.nn.functional as F

from espnet.nets.pytorch_backend.nets_utils import make_non_pad_mask
from espnet.nets.pytorch_backend.rnn.attentions import AttForward, AttForwardTA, AttLoc
from espnet.nets.pytorch_backend.tacotron2.cbhg import CBHG, CBHGLoss
from espnet.nets.pytorch_backend.tacotron2.decoder import Decoder
from espnet.nets.pytorch_backend.tacotron2.encoder import Encoder
from espnet.nets.tts_interface import TTSInterface
from espnet.utils.cli_utils import strtobool
from espnet.utils.fill_missing_args import fill_missing_args


class GuidedAttentionLoss(torch.nn.Module):
    """Guided attention loss function module.

    This module calculates the guided attention loss described
    in `Efficiently Trainable Text-to-Speech System Based
    on Deep Convolutional Networks with Guided Attention`_,
    which forces the attention to be diagonal.

    .. _`Efficiently Trainable Text-to-Speech System
        Based on Deep Convolutional Networks with Guided Attention`:
        https://arxiv.org/abs/1710.08969

    """

    def __init__(self, sigma=0.4, alpha=1.0, reset_always=True):
        """Initialize guided attention loss module.

        Args:
            sigma (float, optional): Standard deviation to control
                how close attention to a diagonal.
            alpha (float, optional): Scaling coefficient (lambda).
            reset_always (bool, optional): Whether to always reset masks.

        """
        super(GuidedAttentionLoss, self).__init__()
        self.sigma = sigma
        self.alpha = alpha
        self.reset_always = reset_always
        # Both masks are built lazily in forward() and kept until reset.
        self.guided_attn_masks = None
        self.masks = None

    def _reset_masks(self):
        """Drop the cached masks so that the next forward call rebuilds them."""
        self.guided_attn_masks = None
        self.masks = None

    def forward(self, att_ws, ilens, olens):
        """Calculate forward propagation.

        Args:
            att_ws (Tensor): Batch of attention weights (B, T_max_out, T_max_in).
            ilens (LongTensor): Batch of input lengths (B,).
            olens (LongTensor): Batch of output lengths (B,).

        Returns:
            Tensor: Guided attention loss value.

        """
        if self.guided_attn_masks is None:
            self.guided_attn_masks = self._make_guided_attention_masks(
                ilens, olens
            ).to(att_ws.device)
        if self.masks is None:
            self.masks = self._make_masks(ilens, olens).to(att_ws.device)
        losses = self.guided_attn_masks * att_ws
        # Average only over the non-padded (output, input) positions.
        loss = torch.mean(losses.masked_select(self.masks))
        if self.reset_always:
            self._reset_masks()
        return self.alpha * loss

    def _make_guided_attention_masks(self, ilens, olens):
        """Make a zero-padded batch of guided attention masks.

        Args:
            ilens (LongTensor or List): Batch of input lengths (B,).
            olens (LongTensor or List): Batch of output lengths (B,).

        Returns:
            Tensor: Guided attention masks (B, max(olens), max(ilens)).

        """
        n_batches = len(ilens)
        max_ilen = int(max(ilens))
        max_olen = int(max(olens))
        guided_attn_masks = torch.zeros((n_batches, max_olen, max_ilen))
        for idx, (ilen, olen) in enumerate(zip(ilens, olens)):
            guided_attn_masks[idx, :olen, :ilen] = self._make_guided_attention_mask(
                ilen, olen, self.sigma
            )
        return guided_attn_masks

    @staticmethod
    def _make_guided_attention_mask(ilen, olen, sigma):
        """Make guided attention mask.

        Args:
            ilen (int or Tensor): Input length.
            olen (int or Tensor): Output length.
            sigma (float): Standard deviation of the diagonal band.

        Returns:
            Tensor: Guided attention mask (olen, ilen). It is created on the
                device of ``olen`` (or ``ilen``) when that is a tensor.

        Examples:
            >>> guided_attn_mask = _make_guided_attention_mask(5, 5, 0.4)
            >>> guided_attn_mask.shape
            torch.Size([5, 5])
            >>> guided_attn_mask
            tensor([[0.0000, 0.1175, 0.3935, 0.6753, 0.8647],
                    [0.1175, 0.0000, 0.1175, 0.3935, 0.6753],
                    [0.3935, 0.1175, 0.0000, 0.1175, 0.3935],
                    [0.6753, 0.3935, 0.1175, 0.0000, 0.1175],
                    [0.8647, 0.6753, 0.3935, 0.1175, 0.0000]])
            >>> guided_attn_mask = _make_guided_attention_mask(3, 6, 0.4)
            >>> guided_attn_mask.shape
            torch.Size([6, 3])
            >>> guided_attn_mask
            tensor([[0.0000, 0.2934, 0.7506],
                    [0.0831, 0.0831, 0.5422],
                    [0.2934, 0.0000, 0.2934],
                    [0.5422, 0.0831, 0.0831],
                    [0.7506, 0.2934, 0.0000],
                    [0.8858, 0.5422, 0.0831]])

        """
        # Lengths may be plain ints (as in the examples) or 0-dim tensors;
        # only tensors carry a device.
        device = next(
            (t.device for t in (olen, ilen) if isinstance(t, torch.Tensor)), None
        )
        n_in, n_out = int(ilen), int(olen)
        # Normalized positions; broadcasting replaces the explicit meshgrid.
        in_pos = torch.arange(n_in, dtype=torch.float32, device=device) / n_in
        out_pos = torch.arange(n_out, dtype=torch.float32, device=device) / n_out
        diff = in_pos.unsqueeze(0) - out_pos.unsqueeze(1)  # (olen, ilen)
        return 1.0 - torch.exp(-(diff**2) / (2 * (sigma**2)))

    @staticmethod
    def _make_masks(ilens, olens):
        """Make masks indicating non-padded part.

        Args:
            ilens (LongTensor or List): Batch of lengths (B,).
            olens (LongTensor or List): Batch of lengths (B,).

        Returns:
            Tensor: Mask tensor indicating non-padded part.
                dtype=torch.uint8 in PyTorch 1.2-
                dtype=torch.bool in PyTorch 1.2+ (including 1.2)

        Examples:
            >>> ilens, olens = [5, 2], [8, 5]
            >>> _make_masks(ilens, olens)
            tensor([[[1, 1, 1, 1, 1],
                     [1, 1, 1, 1, 1],
                     [1, 1, 1, 1, 1],
                     [1, 1, 1, 1, 1],
                     [1, 1, 1, 1, 1],
                     [1, 1, 1, 1, 1],
                     [1, 1, 1, 1, 1],
                     [1, 1, 1, 1, 1]],
                    [[1, 1, 0, 0, 0],
                     [1, 1, 0, 0, 0],
                     [1, 1, 0, 0, 0],
                     [1, 1, 0, 0, 0],
                     [1, 1, 0, 0, 0],
                     [0, 0, 0, 0, 0],
                     [0, 0, 0, 0, 0],
                     [0, 0, 0, 0, 0]]], dtype=torch.uint8)

        """
        in_masks = make_non_pad_mask(ilens)  # (B, T_in)
        out_masks = make_non_pad_mask(olens)  # (B, T_out)
        return out_masks.unsqueeze(-1) & in_masks.unsqueeze(-2)  # (B, T_out, T_in)
class Tacotron2Loss(torch.nn.Module):
    """Loss function module for Tacotron2."""

    def __init__(
        self, use_masking=True, use_weighted_masking=False, bce_pos_weight=20.0
    ):
        """Initialize Tacotron2 loss module.

        Args:
            use_masking (bool): Whether to apply masking
                for padded part in loss calculation.
            use_weighted_masking (bool):
                Whether to apply weighted masking in loss calculation.
            bce_pos_weight (float): Weight of positive sample of stop token.

        """
        super().__init__()
        # Plain masking and weighted masking cannot both be enabled.
        assert (use_masking != use_weighted_masking) or not use_masking
        self.use_masking = use_masking
        self.use_weighted_masking = use_weighted_masking

        # Weighted masking needs element-wise losses to apply its weights.
        if self.use_weighted_masking:
            reduction = "none"
        else:
            reduction = "mean"
        self.l1_criterion = torch.nn.L1Loss(reduction=reduction)
        self.mse_criterion = torch.nn.MSELoss(reduction=reduction)
        self.bce_criterion = torch.nn.BCEWithLogitsLoss(
            reduction=reduction, pos_weight=torch.tensor(bce_pos_weight)
        )

        # NOTE(kan-bayashi): register pre hook function for the compatibility
        self._register_load_state_dict_pre_hook(self._load_state_dict_pre_hook)

    def forward(self, after_outs, before_outs, logits, ys, labels, olens):
        """Calculate forward propagation.

        Args:
            after_outs (Tensor): Batch of outputs after postnets (B, Lmax, odim).
            before_outs (Tensor): Batch of outputs before postnets (B, Lmax, odim).
            logits (Tensor): Batch of stop logits (B, Lmax).
            ys (Tensor): Batch of padded target features (B, Lmax, odim).
            labels (LongTensor): Batch of the sequences of stop token labels (B, Lmax).
            olens (LongTensor): Batch of the lengths of each target (B,).

        Returns:
            Tensor: L1 loss value.
            Tensor: Mean square error loss value.
            Tensor: Binary cross entropy loss value.

        """
        if self.use_masking:
            # Keep only the non-padded frames before computing the losses.
            frame_mask = make_non_pad_mask(olens).unsqueeze(-1).to(ys.device)
            step_mask = frame_mask[:, :, 0]
            ys, after_outs, before_outs = (
                t.masked_select(frame_mask) for t in (ys, after_outs, before_outs)
            )
            labels, logits = (t.masked_select(step_mask) for t in (labels, logits))

        l1_after = self.l1_criterion(after_outs, ys)
        l1_before = self.l1_criterion(before_outs, ys)
        l1_loss = l1_after + l1_before
        mse_after = self.mse_criterion(after_outs, ys)
        mse_before = self.mse_criterion(before_outs, ys)
        mse_loss = mse_after + mse_before
        bce_loss = self.bce_criterion(logits, labels)

        if not self.use_weighted_masking:
            return l1_loss, mse_loss, bce_loss

        # Weight each frame by the inverse of its utterance length, then
        # normalize by the batch size (and feature dimension for outputs).
        frame_mask = make_non_pad_mask(olens).unsqueeze(-1).to(ys.device)
        weights = frame_mask.float() / frame_mask.sum(dim=1, keepdim=True).float()
        n_batch, n_feats = ys.size(0), ys.size(2)
        out_weights = weights.div(n_batch * n_feats)
        logit_weights = weights.div(n_batch)

        l1_loss = l1_loss.mul(out_weights).masked_select(frame_mask).sum()
        mse_loss = mse_loss.mul(out_weights).masked_select(frame_mask).sum()
        weighted_bce = bce_loss.mul(logit_weights.squeeze(-1))
        bce_loss = weighted_bce.masked_select(frame_mask.squeeze(-1)).sum()

        return l1_loss, mse_loss, bce_loss

    def _load_state_dict_pre_hook(
        self,
        state_dict,
        prefix,
        local_metadata,
        strict,
        missing_keys,
        unexpected_keys,
        error_msgs,
    ):
        """Apply pre hook function before loading state dict.

        From v.0.6.1 `bce_criterion.pos_weight` param is registered as a parameter but
        old models do not include it and as a result, it causes missing key error when
        loading old model parameter. This function solve the issue by adding param in
        state dict before loading as a pre hook function
        of the `load_state_dict` method.

        """
        # Insert the current value only when the checkpoint lacks the entry.
        state_dict.setdefault(
            prefix + "bce_criterion.pos_weight", self.bce_criterion.pos_weight
        )
class Tacotron2(TTSInterface, torch.nn.Module):
    """Tacotron2 module for end-to-end text-to-speech (E2E-TTS).

    This is a module of Spectrogram prediction network in Tacotron2 described
    in `Natural TTS Synthesis by Conditioning WaveNet on Mel Spectrogram Predictions`_,
    which converts the sequence of characters
    into the sequence of Mel-filterbanks.

    .. _`Natural TTS Synthesis by Conditioning WaveNet on Mel Spectrogram Predictions`:
       https://arxiv.org/abs/1712.05884

    """

    @staticmethod
    def add_arguments(parser):
        """Add model-specific arguments to the parser.

        Args:
            parser (argparse.ArgumentParser): Parser to be extended.

        Returns:
            argparse.ArgumentParser: The same parser, with the
                "tacotron 2 model setting" argument group added.

        """
        group = parser.add_argument_group("tacotron 2 model setting")
        # encoder
        group.add_argument(
            "--embed-dim",
            default=512,
            type=int,
            help="Number of dimension of embedding",
        )
        group.add_argument(
            "--elayers", default=1, type=int, help="Number of encoder layers"
        )
        group.add_argument(
            "--eunits",
            "-u",
            default=512,
            type=int,
            help="Number of encoder hidden units",
        )
        group.add_argument(
            "--econv-layers",
            default=3,
            type=int,
            help="Number of encoder convolution layers",
        )
        group.add_argument(
            "--econv-chans",
            default=512,
            type=int,
            help="Number of encoder convolution channels",
        )
        group.add_argument(
            "--econv-filts",
            default=5,
            type=int,
            help="Filter size of encoder convolution",
        )
        # attention
        group.add_argument(
            "--atype",
            default="location",
            type=str,
            choices=["forward_ta", "forward", "location"],
            help="Type of attention mechanism",
        )
        group.add_argument(
            "--adim",
            default=512,
            type=int,
            help="Number of attention transformation dimensions",
        )
        group.add_argument(
            "--aconv-chans",
            default=32,
            type=int,
            help="Number of attention convolution channels",
        )
        group.add_argument(
            "--aconv-filts",
            default=15,
            type=int,
            help="Filter size of attention convolution",
        )
        group.add_argument(
            "--cumulate-att-w",
            default=True,
            type=strtobool,
            help="Whether or not to cumulate attention weights",
        )
        # decoder
        group.add_argument(
            "--dlayers", default=2, type=int, help="Number of decoder layers"
        )
        group.add_argument(
            "--dunits", default=1024, type=int, help="Number of decoder hidden units"
        )
        group.add_argument(
            "--prenet-layers", default=2, type=int, help="Number of prenet layers"
        )
        group.add_argument(
            "--prenet-units",
            default=256,
            type=int,
            help="Number of prenet hidden units",
        )
        group.add_argument(
            "--postnet-layers", default=5, type=int, help="Number of postnet layers"
        )
        group.add_argument(
            "--postnet-chans", default=512, type=int, help="Number of postnet channels"
        )
        group.add_argument(
            "--postnet-filts", default=5, type=int, help="Filter size of postnet"
        )
        group.add_argument(
            "--output-activation",
            default=None,
            type=str,
            nargs="?",
            help="Output activation function",
        )
        # cbhg
        group.add_argument(
            "--use-cbhg",
            default=False,
            type=strtobool,
            help="Whether to use CBHG module",
        )
        group.add_argument(
            "--cbhg-conv-bank-layers",
            default=8,
            type=int,
            help="Number of convoluional bank layers in CBHG",
        )
        group.add_argument(
            "--cbhg-conv-bank-chans",
            default=128,
            type=int,
            help="Number of convoluional bank channles in CBHG",
        )
        group.add_argument(
            "--cbhg-conv-proj-filts",
            default=3,
            type=int,
            help="Filter size of convoluional projection layer in CBHG",
        )
        group.add_argument(
            "--cbhg-conv-proj-chans",
            default=256,
            type=int,
            help="Number of convoluional projection channels in CBHG",
        )
        group.add_argument(
            "--cbhg-highway-layers",
            default=4,
            type=int,
            help="Number of highway layers in CBHG",
        )
        group.add_argument(
            "--cbhg-highway-units",
            default=128,
            type=int,
            help="Number of highway units in CBHG",
        )
        group.add_argument(
            "--cbhg-gru-units",
            default=256,
            type=int,
            help="Number of GRU units in CBHG",
        )
        # model (parameter) related
        group.add_argument(
            "--use-batch-norm",
            default=True,
            type=strtobool,
            help="Whether to use batch normalization",
        )
        group.add_argument(
            "--use-concate",
            default=True,
            type=strtobool,
            help="Whether to concatenate encoder embedding with decoder outputs",
        )
        group.add_argument(
            "--use-residual",
            default=True,
            type=strtobool,
            help="Whether to use residual connection in conv layer",
        )
        group.add_argument(
            "--dropout-rate", default=0.5, type=float, help="Dropout rate"
        )
        group.add_argument(
            "--zoneout-rate", default=0.1, type=float, help="Zoneout rate"
        )
        group.add_argument(
            "--reduction-factor", default=1, type=int, help="Reduction factor"
        )
        group.add_argument(
            "--spk-embed-dim",
            default=None,
            type=int,
            help="Number of speaker embedding dimensions",
        )
        group.add_argument(
            "--spc-dim", default=None, type=int, help="Number of spectrogram dimensions"
        )
        group.add_argument(
            "--pretrained-model", default=None, type=str, help="Pretrained model path"
        )
        # loss related
        group.add_argument(
            "--use-masking",
            default=False,
            type=strtobool,
            help="Whether to use masking in calculation of loss",
        )
        group.add_argument(
            "--use-weighted-masking",
            default=False,
            type=strtobool,
            help="Whether to use weighted masking in calculation of loss",
        )
        group.add_argument(
            "--bce-pos-weight",
            default=20.0,
            type=float,
            help="Positive sample weight in BCE calculation "
            "(only for use-masking=True)",
        )
        group.add_argument(
            "--use-guided-attn-loss",
            default=False,
            type=strtobool,
            help="Whether to use guided attention loss",
        )
        group.add_argument(
            "--guided-attn-loss-sigma",
            default=0.4,
            type=float,
            help="Sigma in guided attention loss",
        )
        group.add_argument(
            "--guided-attn-loss-lambda",
            default=1.0,
            type=float,
            help="Lambda in guided attention loss",
        )
        return parser

    def __init__(self, idim, odim, args=None):
        """Initialize Tacotron2 module.

        Args:
            idim (int): Dimension of the inputs.
            odim (int): Dimension of the outputs.
            args (Namespace, optional): Missing attributes are filled with
                the defaults declared in `add_arguments`.
                - spk_embed_dim (int): Dimension of the speaker embedding.
                - embed_dim (int): Dimension of character embedding.
                - elayers (int): The number of encoder blstm layers.
                - eunits (int): The number of encoder blstm units.
                - econv_layers (int): The number of encoder conv layers.
                - econv_filts (int): The number of encoder conv filter size.
                - econv_chans (int): The number of encoder conv filter channels.
                - atype (str): Type of attention mechanism
                    ("location", "forward" or "forward_ta").
                - dlayers (int): The number of decoder lstm layers.
                - dunits (int): The number of decoder lstm units.
                - prenet_layers (int): The number of prenet layers.
                - prenet_units (int): The number of prenet units.
                - postnet_layers (int): The number of postnet layers.
                - postnet_filts (int): The number of postnet filter size.
                - postnet_chans (int): The number of postnet filter channels.
                - output_activation (str): The name of activation function
                    for outputs (an attribute of `torch.nn.functional`).
                - adim (int): The number of dimension of mlp in attention.
                - aconv_chans (int): The number of attention conv filter channels.
                - aconv_filts (int): The number of attention conv filter size.
                - cumulate_att_w (bool): Whether to cumulate previous attention
                    weight (disabled for the forward attention types).
                - use_batch_norm (bool): Whether to use batch normalization.
                - use_concate (bool): Whether to concatenate encoder embedding
                    with decoder lstm outputs.
                - use_residual (bool): Whether to use residual connection
                    in conv layer.
                - dropout_rate (float): Dropout rate.
                - zoneout_rate (float): Zoneout rate.
                - reduction_factor (int): Reduction factor.
                - spc_dim (int): Number of spectrogram dimensions
                    (only for use_cbhg=True).
                - use_cbhg (bool): Whether to use CBHG module.
                - cbhg_conv_bank_layers (int): The number of convolutional banks
                    in CBHG.
                - cbhg_conv_bank_chans (int): The number of channels of
                    convolutional bank in CBHG.
                - cbhg_conv_proj_filts (int): The number of filter size of
                    projection layer in CBHG.
                - cbhg_conv_proj_chans (int): The number of channels of
                    projection layer in CBHG.
                - cbhg_highway_layers (int): The number of layers of
                    highway network in CBHG.
                - cbhg_highway_units (int): The number of units of
                    highway network in CBHG.
                - cbhg_gru_units (int): The number of units of GRU in CBHG.
                - use_masking (bool): Whether to apply masking
                    for padded part in loss calculation.
                - use_weighted_masking (bool): Whether to apply weighted
                    masking in loss calculation.
                - bce_pos_weight (float): Weight of positive sample of
                    stop token (only for use_masking=True).
                - use_guided_attn_loss (bool): Whether to use guided
                    attention loss.
                - guided_attn_loss_sigma (float): Sigma in guided attention loss.
                - guided_attn_loss_lambda (float): Lambda in guided
                    attention loss.
                - pretrained_model (str): Pretrained model path.

        Raises:
            ValueError: If `output_activation` is not found in
                `torch.nn.functional`.
            NotImplementedError: If `atype` is not a supported attention type.

        """
        # initialize base classes
        TTSInterface.__init__(self)
        torch.nn.Module.__init__(self)

        # fill missing arguments
        args = fill_missing_args(args, self.add_arguments)

        # store hyperparameters
        self.idim = idim
        self.odim = odim
        self.spk_embed_dim = args.spk_embed_dim
        self.cumulate_att_w = args.cumulate_att_w
        self.reduction_factor = args.reduction_factor
        self.use_cbhg = args.use_cbhg
        self.use_guided_attn_loss = args.use_guided_attn_loss

        # define activation function for the final output
        if args.output_activation is None:
            self.output_activation_fn = None
        elif hasattr(F, args.output_activation):
            self.output_activation_fn = getattr(F, args.output_activation)
        else:
            raise ValueError(
                "there is no such an activation function. (%s)" % args.output_activation
            )

        # set padding idx
        padding_idx = 0

        # define network modules
        self.enc = Encoder(
            idim=idim,
            embed_dim=args.embed_dim,
            elayers=args.elayers,
            eunits=args.eunits,
            econv_layers=args.econv_layers,
            econv_chans=args.econv_chans,
            econv_filts=args.econv_filts,
            use_batch_norm=args.use_batch_norm,
            use_residual=args.use_residual,
            dropout_rate=args.dropout_rate,
            padding_idx=padding_idx,
        )
        # the speaker embedding is concatenated to the encoder outputs
        dec_idim = (
            args.eunits
            if args.spk_embed_dim is None
            else args.eunits + args.spk_embed_dim
        )
        if args.atype == "location":
            att = AttLoc(
                dec_idim, args.dunits, args.adim, args.aconv_chans, args.aconv_filts
            )
        elif args.atype == "forward":
            att = AttForward(
                dec_idim, args.dunits, args.adim, args.aconv_chans, args.aconv_filts
            )
        elif args.atype == "forward_ta":
            att = AttForwardTA(
                dec_idim,
                args.dunits,
                args.adim,
                args.aconv_chans,
                args.aconv_filts,
                odim,
            )
        else:
            raise NotImplementedError("Support only location or forward")
        # both forward attention variants run without cumulated weights
        if args.atype != "location" and self.cumulate_att_w:
            logging.warning(
                "cumulation of attention weights is disabled in forward attention."
            )
            self.cumulate_att_w = False
        self.dec = Decoder(
            idim=dec_idim,
            odim=odim,
            att=att,
            dlayers=args.dlayers,
            dunits=args.dunits,
            prenet_layers=args.prenet_layers,
            prenet_units=args.prenet_units,
            postnet_layers=args.postnet_layers,
            postnet_chans=args.postnet_chans,
            postnet_filts=args.postnet_filts,
            output_activation_fn=self.output_activation_fn,
            cumulate_att_w=self.cumulate_att_w,
            use_batch_norm=args.use_batch_norm,
            use_concate=args.use_concate,
            dropout_rate=args.dropout_rate,
            zoneout_rate=args.zoneout_rate,
            reduction_factor=args.reduction_factor,
        )
        self.taco2_loss = Tacotron2Loss(
            use_masking=args.use_masking,
            use_weighted_masking=args.use_weighted_masking,
            bce_pos_weight=args.bce_pos_weight,
        )
        if self.use_guided_attn_loss:
            self.attn_loss = GuidedAttentionLoss(
                sigma=args.guided_attn_loss_sigma,
                alpha=args.guided_attn_loss_lambda,
            )
        if self.use_cbhg:
            self.cbhg = CBHG(
                idim=odim,
                odim=args.spc_dim,
                conv_bank_layers=args.cbhg_conv_bank_layers,
                conv_bank_chans=args.cbhg_conv_bank_chans,
                conv_proj_filts=args.cbhg_conv_proj_filts,
                conv_proj_chans=args.cbhg_conv_proj_chans,
                highway_layers=args.cbhg_highway_layers,
                highway_units=args.cbhg_highway_units,
                gru_units=args.cbhg_gru_units,
            )
            self.cbhg_loss = CBHGLoss(use_masking=args.use_masking)

        # load pretrained model
        if args.pretrained_model is not None:
            self.load_pretrained_model(args.pretrained_model)

    def forward(
        self, xs, ilens, ys, labels, olens, spembs=None, extras=None, *args, **kwargs
    ):
        """Calculate forward propagation.

        Args:
            xs (Tensor): Batch of padded character ids (B, Tmax).
            ilens (LongTensor): Batch of lengths of each input batch (B,).
            ys (Tensor): Batch of padded target features (B, Lmax, odim).
            labels (LongTensor): Batch of the sequences of stop token labels
                (B, Lmax).
            olens (LongTensor): Batch of the lengths of each target (B,).
            spembs (Tensor, optional): Batch of speaker embedding vectors
                (B, spk_embed_dim). Used when `spk_embed_dim` is set.
            extras (Tensor, optional): Batch of groundtruth spectrograms
                (B, Lmax, spc_dim). Used when `use_cbhg` is enabled.

        Returns:
            Tensor: Loss value.

        """
        # remove unnecessary padded part (for multi-gpus)
        max_in = max(ilens)
        max_out = max(olens)
        if max_in != xs.shape[1]:
            xs = xs[:, :max_in]
        if max_out != ys.shape[1]:
            ys = ys[:, :max_out]
            labels = labels[:, :max_out]

        # calculate tacotron2 outputs
        hs, hlens = self.enc(xs, ilens)
        if self.spk_embed_dim is not None:
            spembs = F.normalize(spembs).unsqueeze(1).expand(-1, hs.size(1), -1)
            hs = torch.cat([hs, spembs], dim=-1)
        after_outs, before_outs, logits, att_ws = self.dec(hs, hlens, ys)

        # modify mod part of groundtruth
        if self.reduction_factor > 1:
            assert olens.ge(
                self.reduction_factor
            ).all(), "Output length must be greater than or equal to reduction factor."
            # truncate every length down to a multiple of the reduction factor
            olens = olens - olens % self.reduction_factor
            max_out = max(olens)
            ys = ys[:, :max_out]
            labels = labels[:, :max_out]
            # re-mark the (new) last frame of each sequence as the stop frame
            labels = torch.scatter(
                labels, 1, (olens - 1).unsqueeze(1), 1.0
            )  # see #3388

        # calculate taco2 loss
        l1_loss, mse_loss, bce_loss = self.taco2_loss(
            after_outs, before_outs, logits, ys, labels, olens
        )
        loss = l1_loss + mse_loss + bce_loss
        report_keys = [
            {"l1_loss": l1_loss.item()},
            {"mse_loss": mse_loss.item()},
            {"bce_loss": bce_loss.item()},
        ]

        # calculate attention loss
        if self.use_guided_attn_loss:
            # NOTE(kan-bayashi):
            # length of output for auto-regressive input will be changed when r > 1
            if self.reduction_factor > 1:
                olens_in = olens // self.reduction_factor
            else:
                olens_in = olens
            attn_loss = self.attn_loss(att_ws, ilens, olens_in)
            loss = loss + attn_loss
            report_keys += [
                {"attn_loss": attn_loss.item()},
            ]

        # calculate cbhg loss
        if self.use_cbhg:
            # remove unnecessary padded part (for multi-gpus)
            if max_out != extras.shape[1]:
                extras = extras[:, :max_out]

            # calculate cbhg outputs & loss and report them
            cbhg_outs, _ = self.cbhg(after_outs, olens)
            cbhg_l1_loss, cbhg_mse_loss = self.cbhg_loss(cbhg_outs, extras, olens)
            loss = loss + cbhg_l1_loss + cbhg_mse_loss
            report_keys += [
                {"cbhg_l1_loss": cbhg_l1_loss.item()},
                {"cbhg_mse_loss": cbhg_mse_loss.item()},
            ]

        report_keys += [{"loss": loss.item()}]
        self.reporter.report(report_keys)

        return loss

    def inference(self, x, inference_args, spemb=None, *args, **kwargs):
        """Generate the sequence of features given the sequences of characters.

        Args:
            x (Tensor): Input sequence of characters (T,).
            inference_args (Namespace):
                - threshold (float): Threshold in inference.
                - minlenratio (float): Minimum length ratio in inference.
                - maxlenratio (float): Maximum length ratio in inference.
                - use_att_constraint (bool, optional): Whether to apply the
                    attention constraint (treated as False when absent).
                - backward_window (int): Backward window, read only when
                    use_att_constraint is true.
                - forward_window (int): Forward window, read only when
                    use_att_constraint is true.
            spemb (Tensor, optional): Speaker embedding vector (spk_embed_dim).

        Returns:
            Tensor: Output sequence of features (L, odim), or the CBHG
                outputs when `use_cbhg` is enabled.
            Tensor: Output sequence of stop probabilities (L,).
            Tensor: Attention weights (L, T).

        """
        # get options
        threshold = inference_args.threshold
        minlenratio = inference_args.minlenratio
        maxlenratio = inference_args.maxlenratio
        use_att_constraint = getattr(
            inference_args, "use_att_constraint", False
        )  # keep compatibility
        backward_window = inference_args.backward_window if use_att_constraint else 0
        forward_window = inference_args.forward_window if use_att_constraint else 0

        # inference
        h = self.enc.inference(x)
        if self.spk_embed_dim is not None:
            spemb = F.normalize(spemb, dim=0).unsqueeze(0).expand(h.size(0), -1)
            h = torch.cat([h, spemb], dim=-1)
        outs, probs, att_ws = self.dec.inference(
            h,
            threshold,
            minlenratio,
            maxlenratio,
            use_att_constraint=use_att_constraint,
            backward_window=backward_window,
            forward_window=forward_window,
        )

        if self.use_cbhg:
            cbhg_outs = self.cbhg.inference(outs)
            return cbhg_outs, probs, att_ws
        else:
            return outs, probs, att_ws

    def calculate_all_attentions(
        self, xs, ilens, ys, spembs=None, keep_tensor=False, *args, **kwargs
    ):
        """Calculate all of the attention weights.

        The module is switched to evaluation mode for the computation and
        its previous training/evaluation mode is restored afterwards.

        Args:
            xs (Tensor): Batch of padded character ids (B, Tmax).
            ilens (LongTensor): Batch of lengths of each input batch (B,).
            ys (Tensor): Batch of padded target features (B, Lmax, odim).
            spembs (Tensor, optional): Batch of speaker embedding vectors
                (B, spk_embed_dim).
            keep_tensor (bool, optional): Whether to keep original tensor.

        Returns:
            Union[ndarray, Tensor]: Batch of attention weights (B, Lmax, Tmax).

        """
        # check ilens type (should be list of int)
        if isinstance(ilens, (torch.Tensor, np.ndarray)):
            ilens = list(map(int, ilens))

        # remember the current mode so that an evaluating model is not
        # silently put back into training mode
        was_training = self.training
        self.eval()
        try:
            with torch.no_grad():
                hs, hlens = self.enc(xs, ilens)
                if self.spk_embed_dim is not None:
                    spembs = F.normalize(spembs).unsqueeze(1).expand(-1, hs.size(1), -1)
                    hs = torch.cat([hs, spembs], dim=-1)
                att_ws = self.dec.calculate_all_attentions(hs, hlens, ys)
        finally:
            self.train(was_training)

        if keep_tensor:
            return att_ws
        else:
            return att_ws.cpu().numpy()

    @property
    def base_plot_keys(self):
        """Return base key names to plot during training.

        keys should match what `chainer.reporter` reports.
        If you add the key `loss`, the reporter will report `main/loss`
        and `validation/main/loss` values.
        also `loss.png` will be created as a figure visualizing `main/loss`
        and `validation/main/loss` values.

        Returns:
            list: List of strings which are base keys to plot during training.

        """
        plot_keys = ["loss", "l1_loss", "mse_loss", "bce_loss"]
        if self.use_guided_attn_loss:
            plot_keys += ["attn_loss"]
        if self.use_cbhg:
            plot_keys += ["cbhg_l1_loss", "cbhg_mse_loss"]
        return plot_keys