Source code for espnet.nets.pytorch_backend.e2e_tts_fastspeech

# Copyright 2019 Tomoki Hayashi
#  Apache 2.0  (http://www.apache.org/licenses/LICENSE-2.0)

"""FastSpeech related modules."""

import logging

import torch
import torch.nn.functional as F

from espnet.asr.asr_utils import get_model_conf, torch_load
from espnet.nets.pytorch_backend.fastspeech.duration_calculator import (  # noqa: H301
    DurationCalculator,
)
from espnet.nets.pytorch_backend.fastspeech.duration_predictor import (  # noqa: H301
    DurationPredictor,
    DurationPredictorLoss,
)
from espnet.nets.pytorch_backend.fastspeech.length_regulator import LengthRegulator
from espnet.nets.pytorch_backend.nets_utils import make_non_pad_mask, make_pad_mask
from espnet.nets.pytorch_backend.tacotron2.decoder import Postnet
from espnet.nets.pytorch_backend.transformer.attention import MultiHeadedAttention
from espnet.nets.pytorch_backend.transformer.embedding import (
    PositionalEncoding,
    ScaledPositionalEncoding,
)
from espnet.nets.pytorch_backend.transformer.encoder import Encoder
from espnet.nets.pytorch_backend.transformer.initializer import initialize
from espnet.nets.tts_interface import TTSInterface
from espnet.utils.cli_utils import strtobool
from espnet.utils.fill_missing_args import fill_missing_args


class FeedForwardTransformerLoss(torch.nn.Module):
    """Loss function module for feed-forward Transformer.

    Produces two scalars: an L1 feature-reconstruction loss (computed on the
    pre-postnet outputs and, when available, added for the post-postnet
    outputs) and the duration predictor loss.
    """

    def __init__(self, use_masking=True, use_weighted_masking=False):
        """Initialize feed-forward Transformer loss module.

        Args:
            use_masking (bool):
                Whether to drop padded elements before averaging the losses.
            use_weighted_masking (bool):
                Whether to weight each non-padded element by the inverse of its
                sequence length instead (cannot be combined with use_masking).

        """
        super(FeedForwardTransformerLoss, self).__init__()
        # plain masking and weighted masking are mutually exclusive
        assert (use_masking != use_weighted_masking) or not use_masking
        self.use_masking = use_masking
        self.use_weighted_masking = use_weighted_masking

        # weighted masking needs element-wise losses so the weights can be
        # applied afterwards; otherwise the criteria average directly
        if self.use_weighted_masking:
            reduction = "none"
        else:
            reduction = "mean"
        self.l1_criterion = torch.nn.L1Loss(reduction=reduction)
        self.duration_criterion = DurationPredictorLoss(reduction=reduction)

    def forward(self, after_outs, before_outs, d_outs, ys, ds, ilens, olens):
        """Calculate forward propagation.

        Args:
            after_outs (Tensor or None): Batch of outputs after postnets
                (B, Lmax, odim); None when no postnet is used.
            before_outs (Tensor): Batch of outputs before postnets (B, Lmax, odim).
            d_outs (Tensor): Batch of outputs of duration predictor (B, Tmax).
            ys (Tensor): Batch of target features (B, Lmax, odim).
            ds (Tensor): Batch of durations (B, Tmax).
            ilens (LongTensor): Batch of the lengths of each input (B,).
            olens (LongTensor): Batch of the lengths of each target (B,).

        Returns:
            Tensor: L1 loss value.
            Tensor: Duration predictor loss value.

        """
        device = ys.device

        if self.use_masking:
            # keep only the non-padded tokens / frames
            token_mask = make_non_pad_mask(ilens).to(device)
            d_outs = d_outs.masked_select(token_mask)
            ds = ds.masked_select(token_mask)

            frame_mask = make_non_pad_mask(olens).unsqueeze(-1).to(device)
            before_outs = before_outs.masked_select(frame_mask)
            if after_outs is not None:
                after_outs = after_outs.masked_select(frame_mask)
            ys = ys.masked_select(frame_mask)

        l1_loss = self.l1_criterion(before_outs, ys)
        if after_outs is not None:
            l1_loss = l1_loss + self.l1_criterion(after_outs, ys)
        duration_loss = self.duration_criterion(d_outs, ds)

        if not self.use_weighted_masking:
            return l1_loss, duration_loss

        # per-element weights: 1 / (sequence length), then normalized by the
        # number of sequences (and feature dims for the L1 term)
        frame_mask = make_non_pad_mask(olens).unsqueeze(-1).to(device)
        frame_weights = frame_mask.float() / frame_mask.sum(dim=1, keepdim=True).float()
        frame_weights /= ys.size(0) * ys.size(2)

        token_mask = make_non_pad_mask(ilens).to(device)
        token_weights = token_mask.float() / token_mask.sum(dim=1, keepdim=True).float()
        token_weights /= ds.size(0)

        l1_loss = l1_loss.mul(frame_weights).masked_select(frame_mask).sum()
        duration_loss = (
            duration_loss.mul(token_weights).masked_select(token_mask).sum()
        )
        return l1_loss, duration_loss
class FeedForwardTransformer(TTSInterface, torch.nn.Module):
    """Feed Forward Transformer for TTS a.k.a. FastSpeech.

    This is a module of FastSpeech, feed-forward Transformer with duration
    predictor described in `FastSpeech: Fast, Robust and Controllable Text to
    Speech`_, which does not require any auto-regressive processing during
    inference, resulting in fast decoding compared with auto-regressive
    Transformer.

    .. _`FastSpeech: Fast, Robust and Controllable Text to Speech`:
        https://arxiv.org/pdf/1905.09263.pdf

    """

    @staticmethod
    def add_arguments(parser):
        """Add model-specific arguments to the parser.

        Args:
            parser (argparse.ArgumentParser): Parser to extend.

        Returns:
            argparse.ArgumentParser: The same parser with a
            "feed-forward transformer model setting" argument group added.

        """
        group = parser.add_argument_group("feed-forward transformer model setting")
        # network structure related
        group.add_argument(
            "--adim",
            default=384,
            type=int,
            help="Number of attention transformation dimensions",
        )
        group.add_argument(
            "--aheads",
            default=4,
            type=int,
            help="Number of heads for multi head attention",
        )
        group.add_argument(
            "--elayers", default=6, type=int, help="Number of encoder layers"
        )
        group.add_argument(
            "--eunits", default=1536, type=int, help="Number of encoder hidden units"
        )
        group.add_argument(
            "--dlayers", default=6, type=int, help="Number of decoder layers"
        )
        group.add_argument(
            "--dunits", default=1536, type=int, help="Number of decoder hidden units"
        )
        group.add_argument(
            "--positionwise-layer-type",
            default="linear",
            type=str,
            choices=["linear", "conv1d", "conv1d-linear"],
            help="Positionwise layer type.",
        )
        group.add_argument(
            "--positionwise-conv-kernel-size",
            default=3,
            type=int,
            help="Kernel size of positionwise conv1d layer",
        )
        group.add_argument(
            "--postnet-layers", default=0, type=int, help="Number of postnet layers"
        )
        group.add_argument(
            "--postnet-chans", default=256, type=int, help="Number of postnet channels"
        )
        group.add_argument(
            "--postnet-filts", default=5, type=int, help="Filter size of postnet"
        )
        group.add_argument(
            "--use-batch-norm",
            default=True,
            type=strtobool,
            help="Whether to use batch normalization",
        )
        group.add_argument(
            "--use-scaled-pos-enc",
            default=True,
            type=strtobool,
            help="Use trainable scaled positional encoding "
            "instead of the fixed scale one",
        )
        group.add_argument(
            "--encoder-normalize-before",
            default=False,
            type=strtobool,
            help="Whether to apply layer norm before encoder block",
        )
        group.add_argument(
            "--decoder-normalize-before",
            default=False,
            type=strtobool,
            help="Whether to apply layer norm before decoder block",
        )
        group.add_argument(
            "--encoder-concat-after",
            default=False,
            type=strtobool,
            help="Whether to concatenate attention layer's input and output in encoder",
        )
        group.add_argument(
            "--decoder-concat-after",
            default=False,
            type=strtobool,
            help="Whether to concatenate attention layer's input and output in decoder",
        )
        group.add_argument(
            "--duration-predictor-layers",
            default=2,
            type=int,
            help="Number of layers in duration predictor",
        )
        group.add_argument(
            "--duration-predictor-chans",
            default=384,
            type=int,
            help="Number of channels in duration predictor",
        )
        group.add_argument(
            "--duration-predictor-kernel-size",
            default=3,
            type=int,
            help="Kernel size in duration predictor",
        )
        group.add_argument(
            "--teacher-model",
            default=None,
            type=str,
            nargs="?",
            help="Teacher model file path",
        )
        group.add_argument(
            "--reduction-factor", default=1, type=int, help="Reduction factor"
        )
        group.add_argument(
            "--spk-embed-dim",
            default=None,
            type=int,
            help="Number of speaker embedding dimensions",
        )
        group.add_argument(
            "--spk-embed-integration-type",
            type=str,
            default="add",
            choices=["add", "concat"],
            help="How to integrate speaker embedding",
        )
        # training related
        group.add_argument(
            "--transformer-init",
            type=str,
            default="pytorch",
            choices=[
                "pytorch",
                "xavier_uniform",
                "xavier_normal",
                "kaiming_uniform",
                "kaiming_normal",
            ],
            help="How to initialize transformer parameters",
        )
        group.add_argument(
            "--initial-encoder-alpha",
            type=float,
            default=1.0,
            help="Initial alpha value in encoder's ScaledPositionalEncoding",
        )
        group.add_argument(
            "--initial-decoder-alpha",
            type=float,
            default=1.0,
            help="Initial alpha value in decoder's ScaledPositionalEncoding",
        )
        group.add_argument(
            "--transformer-lr",
            default=1.0,
            type=float,
            help="Initial value of learning rate",
        )
        group.add_argument(
            "--transformer-warmup-steps",
            default=4000,
            type=int,
            help="Optimizer warmup steps",
        )
        group.add_argument(
            "--transformer-enc-dropout-rate",
            default=0.1,
            type=float,
            help="Dropout rate for transformer encoder except for attention",
        )
        group.add_argument(
            "--transformer-enc-positional-dropout-rate",
            default=0.1,
            type=float,
            help="Dropout rate for transformer encoder positional encoding",
        )
        group.add_argument(
            "--transformer-enc-attn-dropout-rate",
            default=0.1,
            type=float,
            help="Dropout rate for transformer encoder self-attention",
        )
        group.add_argument(
            "--transformer-dec-dropout-rate",
            default=0.1,
            type=float,
            help="Dropout rate for transformer decoder except "
            "for attention and pos encoding",
        )
        group.add_argument(
            "--transformer-dec-positional-dropout-rate",
            default=0.1,
            type=float,
            help="Dropout rate for transformer decoder positional encoding",
        )
        group.add_argument(
            "--transformer-dec-attn-dropout-rate",
            default=0.1,
            type=float,
            help="Dropout rate for transformer decoder self-attention",
        )
        group.add_argument(
            "--transformer-enc-dec-attn-dropout-rate",
            default=0.1,
            type=float,
            help="Dropout rate for transformer encoder-decoder attention",
        )
        group.add_argument(
            "--duration-predictor-dropout-rate",
            default=0.1,
            type=float,
            help="Dropout rate for duration predictor",
        )
        group.add_argument(
            "--postnet-dropout-rate",
            default=0.5,
            type=float,
            help="Dropout rate in postnet",
        )
        group.add_argument(
            "--transfer-encoder-from-teacher",
            default=True,
            type=strtobool,
            help="Whether to transfer teacher's parameters",
        )
        group.add_argument(
            "--transferred-encoder-module",
            default="all",
            type=str,
            choices=["all", "embed"],
            help="Encoder modeules to be trasferred from teacher",
        )
        # loss related
        group.add_argument(
            "--use-masking",
            default=True,
            type=strtobool,
            help="Whether to use masking in calculation of loss",
        )
        group.add_argument(
            "--use-weighted-masking",
            default=False,
            type=strtobool,
            help="Whether to use weighted masking in calculation of loss",
        )
        return parser

    def __init__(self, idim, odim, args=None):
        """Initialize feed-forward Transformer module.

        Args:
            idim (int): Dimension of the inputs (size of the input vocabulary).
            odim (int): Dimension of the outputs.
            args (Namespace, optional): Missing fields are filled with the
                defaults of ``add_arguments``.
                - elayers (int): Number of encoder layers.
                - eunits (int): Number of encoder hidden units.
                - adim (int): Number of attention transformation dimensions.
                - aheads (int): Number of heads for multi head attention.
                - dlayers (int): Number of decoder layers.
                - dunits (int): Number of decoder hidden units.
                - use_scaled_pos_enc (bool):
                    Whether to use trainable scaled positional encoding.
                - encoder_normalize_before (bool):
                    Whether to perform layer normalization before encoder block.
                - decoder_normalize_before (bool):
                    Whether to perform layer normalization before decoder block.
                - encoder_concat_after (bool): Whether to concatenate attention
                    layer's input and output in encoder.
                - decoder_concat_after (bool): Whether to concatenate attention
                    layer's input and output in decoder.
                - duration_predictor_layers (int): Number of duration predictor layers.
                - duration_predictor_chans (int): Number of duration predictor channels.
                - duration_predictor_kernel_size (int):
                    Kernel size of duration predictor.
                - spk_embed_dim (int): Number of speaker embedding dimensions.
                - spk_embed_integration_type: How to integrate speaker embedding.
                - teacher_model (str): Teacher auto-regressive transformer model path.
                - reduction_factor (int): Reduction factor.
                - transformer_init (str): How to initialize transformer parameters.
                - transformer_lr (float): Initial value of learning rate.
                - transformer_warmup_steps (int): Optimizer warmup steps.
                - transformer_enc_dropout_rate (float):
                    Dropout rate in encoder except attention & positional encoding.
                - transformer_enc_positional_dropout_rate (float):
                    Dropout rate after encoder positional encoding.
                - transformer_enc_attn_dropout_rate (float):
                    Dropout rate in encoder self-attention module.
                - transformer_dec_dropout_rate (float):
                    Dropout rate in decoder except attention & positional encoding.
                - transformer_dec_positional_dropout_rate (float):
                    Dropout rate after decoder positional encoding.
                - transformer_dec_attn_dropout_rate (float):
                    Dropout rate in decoder self-attention module.
                - transformer_enc_dec_attn_dropout_rate (float):
                    Dropout rate in encoder-decoder attention module.
                - use_masking (bool):
                    Whether to apply masking for padded part in loss calculation.
                - use_weighted_masking (bool):
                    Whether to apply weighted masking in loss calculation.
                - transfer_encoder_from_teacher:
                    Whether to transfer encoder using teacher encoder parameters.
                - transferred_encoder_module:
                    Encoder module to be initialized using teacher parameters.

        """
        # initialize base classes
        TTSInterface.__init__(self)
        torch.nn.Module.__init__(self)

        # fill missing arguments
        args = fill_missing_args(args, self.add_arguments)

        # store hyperparameters
        self.idim = idim
        self.odim = odim
        self.reduction_factor = args.reduction_factor
        self.use_scaled_pos_enc = args.use_scaled_pos_enc
        self.spk_embed_dim = args.spk_embed_dim
        if self.spk_embed_dim is not None:
            self.spk_embed_integration_type = args.spk_embed_integration_type

        # use idx 0 as padding idx
        padding_idx = 0

        # get positional encoding class
        pos_enc_class = (
            ScaledPositionalEncoding if self.use_scaled_pos_enc else PositionalEncoding
        )

        # define encoder
        encoder_input_layer = torch.nn.Embedding(
            num_embeddings=idim, embedding_dim=args.adim, padding_idx=padding_idx
        )
        self.encoder = Encoder(
            idim=idim,
            attention_dim=args.adim,
            attention_heads=args.aheads,
            linear_units=args.eunits,
            num_blocks=args.elayers,
            input_layer=encoder_input_layer,
            dropout_rate=args.transformer_enc_dropout_rate,
            positional_dropout_rate=args.transformer_enc_positional_dropout_rate,
            attention_dropout_rate=args.transformer_enc_attn_dropout_rate,
            pos_enc_class=pos_enc_class,
            normalize_before=args.encoder_normalize_before,
            concat_after=args.encoder_concat_after,
            positionwise_layer_type=args.positionwise_layer_type,
            positionwise_conv_kernel_size=args.positionwise_conv_kernel_size,
        )

        # define additional projection for speaker embedding
        if self.spk_embed_dim is not None:
            if self.spk_embed_integration_type == "add":
                self.projection = torch.nn.Linear(self.spk_embed_dim, args.adim)
            else:
                self.projection = torch.nn.Linear(
                    args.adim + self.spk_embed_dim, args.adim
                )

        # define duration predictor
        self.duration_predictor = DurationPredictor(
            idim=args.adim,
            n_layers=args.duration_predictor_layers,
            n_chans=args.duration_predictor_chans,
            kernel_size=args.duration_predictor_kernel_size,
            dropout_rate=args.duration_predictor_dropout_rate,
        )

        # define length regulator
        self.length_regulator = LengthRegulator()

        # define decoder
        # NOTE: we use encoder as decoder
        # because fastspeech's decoder is the same as encoder
        self.decoder = Encoder(
            idim=0,
            attention_dim=args.adim,
            attention_heads=args.aheads,
            linear_units=args.dunits,
            num_blocks=args.dlayers,
            input_layer=None,
            dropout_rate=args.transformer_dec_dropout_rate,
            positional_dropout_rate=args.transformer_dec_positional_dropout_rate,
            attention_dropout_rate=args.transformer_dec_attn_dropout_rate,
            pos_enc_class=pos_enc_class,
            normalize_before=args.decoder_normalize_before,
            concat_after=args.decoder_concat_after,
            positionwise_layer_type=args.positionwise_layer_type,
            positionwise_conv_kernel_size=args.positionwise_conv_kernel_size,
        )

        # define final projection (one vector of r * odim per decoder step)
        self.feat_out = torch.nn.Linear(args.adim, odim * args.reduction_factor)

        # define postnet
        self.postnet = (
            None
            if args.postnet_layers == 0
            else Postnet(
                idim=idim,
                odim=odim,
                n_layers=args.postnet_layers,
                n_chans=args.postnet_chans,
                n_filts=args.postnet_filts,
                use_batch_norm=args.use_batch_norm,
                dropout_rate=args.postnet_dropout_rate,
            )
        )

        # initialize parameters
        self._reset_parameters(
            init_type=args.transformer_init,
            init_enc_alpha=args.initial_encoder_alpha,
            init_dec_alpha=args.initial_decoder_alpha,
        )

        # define teacher model
        if args.teacher_model is not None:
            self.teacher = self._load_teacher_model(args.teacher_model)
        else:
            self.teacher = None

        # define duration calculator
        if self.teacher is not None:
            self.duration_calculator = DurationCalculator(self.teacher)
        else:
            self.duration_calculator = None

        # transfer teacher parameters
        if self.teacher is not None and args.transfer_encoder_from_teacher:
            self._transfer_from_teacher(args.transferred_encoder_module)

        # define criterions
        self.criterion = FeedForwardTransformerLoss(
            use_masking=args.use_masking, use_weighted_masking=args.use_weighted_masking
        )

    def _forward(
        self,
        xs,
        ilens,
        ys=None,
        olens=None,
        spembs=None,
        ds=None,
        is_inference=False,
        alpha=1.0,
    ):
        """Run the computation shared by training, attention plotting and inference.

        Args:
            xs (Tensor): Batch of padded character ids (B, Tmax).
            ilens (LongTensor): Batch of input lengths (B,).
            ys (Tensor, optional): Batch of padded target features
                (B, Lmax, odim). Only passed to the duration calculator.
            olens (LongTensor, optional): Batch of target lengths (B,).
            spembs (Tensor, optional): Batch of speaker embeddings.
            ds (Tensor, optional): Batch of precalculated durations (B, Tmax).
                When None in training mode, durations are computed with the
                teacher-based duration calculator.
            is_inference (bool): Whether to expand with predicted durations.
            alpha (float): Factor forwarded to the length regulator during
                inference.

        Returns:
            Inference: (before_outs, after_outs, d_outs).
            Training: (before_outs, after_outs, ds, d_outs).

        Raises:
            ValueError: If training without durations and without a teacher.

        """
        # forward encoder
        x_masks = self._source_mask(ilens).to(xs.device)
        hs, _ = self.encoder(xs, x_masks)  # (B, Tmax, adim)

        # integrate speaker embedding
        if self.spk_embed_dim is not None:
            hs = self._integrate_with_spk_embed(hs, spembs)

        # forward duration predictor and length regulator
        d_masks = make_pad_mask(ilens).to(xs.device)
        if is_inference:
            d_outs = self.duration_predictor.inference(hs, d_masks)  # (B, Tmax)
            hs = self.length_regulator(hs, d_outs, alpha)  # (B, Lmax, adim)
        else:
            if ds is None:
                # Without this check a missing teacher surfaces as an opaque
                # "'NoneType' object is not callable" error.
                if self.duration_calculator is None:
                    raise ValueError(
                        "Durations are required for training: provide "
                        "precalculated durations via `extras` or specify "
                        "`--teacher-model` to compute them with a teacher."
                    )
                with torch.no_grad():
                    ds = self.duration_calculator(
                        xs, ilens, ys, olens, spembs
                    )  # (B, Tmax)
            d_outs = self.duration_predictor(hs, d_masks)  # (B, Tmax)
            hs = self.length_regulator(hs, ds)  # (B, Lmax, adim)

        # forward decoder
        if olens is not None:
            if self.reduction_factor > 1:
                olens_in = olens.new([olen // self.reduction_factor for olen in olens])
            else:
                olens_in = olens
            h_masks = self._source_mask(olens_in).to(xs.device)
        else:
            h_masks = None
        zs, _ = self.decoder(hs, h_masks)  # (B, Lmax, adim)
        before_outs = self.feat_out(zs).view(
            zs.size(0), -1, self.odim
        )  # (B, Lmax, odim)

        # postnet -> (B, Lmax//r * r, odim)
        if self.postnet is None:
            after_outs = before_outs
        else:
            after_outs = before_outs + self.postnet(
                before_outs.transpose(1, 2)
            ).transpose(1, 2)

        if is_inference:
            return before_outs, after_outs, d_outs
        else:
            return before_outs, after_outs, ds, d_outs

    def forward(self, xs, ilens, ys, olens, spembs=None, extras=None, *args, **kwargs):
        """Calculate forward propagation.

        Args:
            xs (Tensor): Batch of padded character ids (B, Tmax).
            ilens (LongTensor): Batch of lengths of each input batch (B,).
            ys (Tensor): Batch of padded target features (B, Lmax, odim).
            olens (LongTensor): Batch of the lengths of each target (B,).
            spembs (Tensor, optional):
                Batch of speaker embedding vectors (B, spk_embed_dim).
            extras (Tensor, optional): Batch of precalculated durations (B, Tmax, 1).

        Returns:
            Tensor: Loss value.

        """
        # remove unnecessary padded part (for multi-gpus)
        xs = xs[:, : max(ilens)]
        ys = ys[:, : max(olens)]
        if extras is not None:
            extras = extras[:, : max(ilens)].squeeze(-1)

        # forward propagation
        before_outs, after_outs, ds, d_outs = self._forward(
            xs, ilens, ys, olens, spembs=spembs, ds=extras, is_inference=False
        )

        # modify mod part of groundtruth so targets are a multiple of r
        if self.reduction_factor > 1:
            olens = olens.new([olen - olen % self.reduction_factor for olen in olens])
            max_olen = max(olens)
            ys = ys[:, :max_olen]

        # calculate loss
        if self.postnet is None:
            l1_loss, duration_loss = self.criterion(
                None, before_outs, d_outs, ys, ds, ilens, olens
            )
        else:
            l1_loss, duration_loss = self.criterion(
                after_outs, before_outs, d_outs, ys, ds, ilens, olens
            )
        loss = l1_loss + duration_loss
        report_keys = [
            {"l1_loss": l1_loss.item()},
            {"duration_loss": duration_loss.item()},
            {"loss": loss.item()},
        ]

        # report extra information
        if self.use_scaled_pos_enc:
            report_keys += [
                {"encoder_alpha": self.encoder.embed[-1].alpha.data.item()},
                {"decoder_alpha": self.decoder.embed[-1].alpha.data.item()},
            ]
        self.reporter.report(report_keys)

        return loss

    def calculate_all_attentions(
        self, xs, ilens, ys, olens, spembs=None, extras=None, *args, **kwargs
    ):
        """Calculate all of the attention weights.

        Args:
            xs (Tensor): Batch of padded character ids (B, Tmax).
            ilens (LongTensor): Batch of lengths of each input batch (B,).
            ys (Tensor): Batch of padded target features (B, Lmax, odim).
            olens (LongTensor): Batch of the lengths of each target (B,).
            spembs (Tensor, optional):
                Batch of speaker embedding vectors (B, spk_embed_dim).
            extras (Tensor, optional): Batch of precalculated durations (B, Tmax, 1).

        Returns:
            dict: Dict of attention weights and outputs. Attention modules
            that were not run (their ``attn`` is None) are omitted.

        """
        with torch.no_grad():
            # remove unnecessary padded part (for multi-gpus)
            xs = xs[:, : max(ilens)]
            ys = ys[:, : max(olens)]
            if extras is not None:
                extras = extras[:, : max(ilens)].squeeze(-1)

            # forward propagation (index 1 = after_outs)
            outs = self._forward(
                xs, ilens, ys, olens, spembs=spembs, ds=extras, is_inference=False
            )[1]

            # decoder-side sequences run at the reduced frame rate, matching the
            # olens_in used to build the decoder masks in _forward
            if self.reduction_factor > 1:
                olens_in = olens.new([olen // self.reduction_factor for olen in olens])
            else:
                olens_in = olens

            att_ws_dict = dict()
            for name, m in self.named_modules():
                if not isinstance(m, MultiHeadedAttention):
                    continue
                if getattr(m, "attn", None) is None:
                    # module was not run (e.g. teacher when durations are given)
                    continue
                attn = m.attn.cpu().numpy()
                # Classify by the first exact "encoder"/"decoder" path component.
                # A substring test is wrong here: the FastSpeech decoder is an
                # Encoder instance, so its modules are named
                # "decoder.encoders.N.self_attn", which contains "encoder".
                owner = next(
                    (p for p in name.split(".") if p in ("encoder", "decoder")), None
                )
                if owner == "encoder":
                    attn = [a[:, :l, :l] for a, l in zip(attn, ilens.tolist())]
                elif owner == "decoder":
                    if "src" in name:
                        attn = [
                            a[:, :ol, :il]
                            for a, il, ol in zip(
                                attn, ilens.tolist(), olens_in.tolist()
                            )
                        ]
                    elif "self" in name:
                        attn = [
                            a[:, :l, :l] for a, l in zip(attn, olens_in.tolist())
                        ]
                    else:
                        logging.warning("unknown attention module: " + name)
                else:
                    logging.warning("unknown attention module: " + name)
                att_ws_dict[name] = attn
            att_ws_dict["predicted_fbank"] = [
                feat[:l].T for feat, l in zip(outs.cpu().numpy(), olens.tolist())
            ]

        return att_ws_dict

    def inference(self, x, inference_args, spemb=None, *args, **kwargs):
        """Generate the sequence of features given the sequences of characters.

        Args:
            x (Tensor): Input sequence of characters (T,).
            inference_args (Namespace): Read only for an optional
                ``fastspeech_alpha`` attribute (default 1.0).
            spemb (Tensor, optional): Speaker embedding vector (spk_embed_dim).

        Returns:
            Tensor: Output sequence of features (L, odim).
            None: Dummy for compatibility.
            None: Dummy for compatibility.

        """
        # setup batch axis
        ilens = torch.tensor([x.shape[0]], dtype=torch.long, device=x.device)
        xs = x.unsqueeze(0)
        if spemb is not None:
            spembs = spemb.unsqueeze(0)
        else:
            spembs = None

        # get option
        alpha = getattr(inference_args, "fastspeech_alpha", 1.0)

        # inference
        _, outs, _ = self._forward(
            xs,
            ilens,
            spembs=spembs,
            is_inference=True,
            alpha=alpha,
        )  # (1, L, odim)

        return outs[0], None, None

    def _integrate_with_spk_embed(self, hs, spembs):
        """Integrate speaker embedding with hidden states.

        Args:
            hs (Tensor): Batch of hidden state sequences (B, Tmax, adim).
            spembs (Tensor): Batch of speaker embeddings (B, spk_embed_dim).

        Returns:
            Tensor: Batch of integrated hidden state sequences (B, Tmax, adim)

        Raises:
            NotImplementedError: For an unknown integration type.

        """
        if self.spk_embed_integration_type == "add":
            # apply projection and then add to hidden states
            spembs = self.projection(F.normalize(spembs))
            hs = hs + spembs.unsqueeze(1)
        elif self.spk_embed_integration_type == "concat":
            # concat hidden states with spk embeds and then apply projection
            spembs = F.normalize(spembs).unsqueeze(1).expand(-1, hs.size(1), -1)
            hs = self.projection(torch.cat([hs, spembs], dim=-1))
        else:
            raise NotImplementedError("support only add or concat.")

        return hs

    def _source_mask(self, ilens):
        """Make masks for self-attention.

        Args:
            ilens (LongTensor or List): Batch of lengths (B,).

        Returns:
            Tensor: Mask tensor for self-attention of shape (B, 1, Tmax).
                dtype=torch.uint8 in PyTorch 1.2-
                dtype=torch.bool in PyTorch 1.2+ (including 1.2)

        Examples:
            >>> ilens = [5, 3]
            >>> self._source_mask(ilens)
            tensor([[[1, 1, 1, 1, 1]],
                    [[1, 1, 1, 0, 0]]], dtype=torch.uint8)

        """
        x_masks = make_non_pad_mask(ilens)
        return x_masks.unsqueeze(-2)

    def _load_teacher_model(self, model_path):
        """Load a frozen teacher model whose dimensions match this model.

        Args:
            model_path (str): Path to the teacher model snapshot.

        Returns:
            torch.nn.Module: Teacher model with ``requires_grad=False`` params.

        """
        # get teacher model config
        idim, odim, args = get_model_conf(model_path)

        # assert dimension is the same between teacher and student
        assert idim == self.idim, "teacher idim differs from student idim"
        assert odim == self.odim, "teacher odim differs from student odim"
        assert (
            args.reduction_factor == self.reduction_factor
        ), "teacher reduction factor differs from student reduction factor"

        # load teacher model
        from espnet.utils.dynamic_import import dynamic_import

        model_class = dynamic_import(args.model_module)
        model = model_class(idim, odim, args)
        torch_load(model_path, model)

        # freeze teacher model parameters
        for p in model.parameters():
            p.requires_grad = False

        return model

    def _reset_parameters(self, init_type, init_enc_alpha=1.0, init_dec_alpha=1.0):
        """Initialize parameters and, if used, the scaled positional encoding alphas."""
        # initialize parameters
        initialize(self, init_type)

        # initialize alpha in scaled positional encoding
        if self.use_scaled_pos_enc:
            self.encoder.embed[-1].alpha.data = torch.tensor(init_enc_alpha)
            self.decoder.embed[-1].alpha.data = torch.tensor(init_dec_alpha)

    def _transfer_from_teacher(self, transferred_encoder_module):
        """Copy encoder parameters ("all") or only the embedding ("embed") from the teacher."""
        if transferred_encoder_module == "all":
            # NOTE(review): zip stops at the shorter parameter list, so trailing
            # parameters present in only one encoder are neither compared nor
            # copied; presumably intentional (e.g. an extra final norm layer
            # in the teacher) — confirm before tightening this check.
            for (n1, p1), (n2, p2) in zip(
                self.encoder.named_parameters(), self.teacher.encoder.named_parameters()
            ):
                assert n1 == n2, "It seems that encoder structure is different."
                assert p1.shape == p2.shape, "It seems that encoder size is different."
                p1.data.copy_(p2.data)
        elif transferred_encoder_module == "embed":
            student_shape = self.encoder.embed[0].weight.data.shape
            teacher_shape = self.teacher.encoder.embed[0].weight.data.shape
            assert (
                student_shape == teacher_shape
            ), "It seems that embed dimension is different."
            self.encoder.embed[0].weight.data.copy_(
                self.teacher.encoder.embed[0].weight.data
            )
        else:
            raise NotImplementedError("Support only all or embed.")

    @property
    def attention_plot_class(self):
        """Return plot class for attention weight plot."""
        # Lazy import to avoid chainer dependency
        from espnet.nets.pytorch_backend.e2e_tts_transformer import TTSPlot

        return TTSPlot

    @property
    def base_plot_keys(self):
        """Return base key names to plot during training.

        Keys should match what `chainer.reporter` reports.
        If you add the key `loss`, the reporter will report `main/loss`
        and `validation/main/loss` values.
        Also `loss.png` will be created as a figure visualizing `main/loss`
        and `validation/main/loss` values.

        Returns:
            list: List of strings which are base keys to plot during training.

        """
        plot_keys = ["loss", "l1_loss", "duration_loss"]
        if self.use_scaled_pos_enc:
            plot_keys += ["encoder_alpha", "decoder_alpha"]

        return plot_keys