# Source code for espnet2.svs.xiaoice.XiaoiceSing

# Copyright 2020 Nagoya University (Tomoki Hayashi)
# Copyright 2021 Renmin University of China (Shuai Guo)
#  Apache 2.0  (http://www.apache.org/licenses/LICENSE-2.0)

"""XiaoiceSing related modules."""

import logging
from typing import Dict, Optional, Tuple

import torch
import torch.nn.functional as F
from typeguard import typechecked

from espnet2.svs.abs_svs import AbsSVS
from espnet2.svs.xiaoice.loss import XiaoiceSing2Loss
from espnet2.torch_utils.device_funcs import force_gatherable
from espnet2.torch_utils.initialize import initialize
from espnet.nets.pytorch_backend.conformer.encoder import (  # noqa: H301
    Encoder as ConformerEncoder,
)
from espnet.nets.pytorch_backend.e2e_tts_fastspeech import (
    FeedForwardTransformerLoss as XiaoiceSingLoss,  # NOQA
)
from espnet.nets.pytorch_backend.fastspeech.duration_predictor import DurationPredictor
from espnet.nets.pytorch_backend.fastspeech.length_regulator import LengthRegulator
from espnet.nets.pytorch_backend.nets_utils import make_non_pad_mask, make_pad_mask
from espnet.nets.pytorch_backend.tacotron2.decoder import Postnet
from espnet.nets.pytorch_backend.transformer.embedding import (
    PositionalEncoding,
    ScaledPositionalEncoding,
)
from espnet.nets.pytorch_backend.transformer.encoder import (  # noqa: H301
    Encoder as TransformerEncoder,
)


class XiaoiceSing(AbsSVS):
    """XiaoiceSing module for Singing Voice Synthesis.

    This is a module of XiaoiceSing. A high-quality singing voice synthesis system
    which employs an integrated network for spectrum, F0 and duration modeling. It
    follows the main architecture of FastSpeech while proposing some singing-specific
    design:

        1) Add features from musical score (e.g. note pitch and length)
        2) Add a residual connection in F0 prediction to attenuate off-key issues
        3) The duration of all the phonemes in a musical note is accumulated to
           calculate the syllable duration loss for rhythm enhancement (syllable loss)

    .. _`XiaoiceSing: A High-Quality and Integrated Singing Voice Synthesis System`:
        https://arxiv.org/pdf/2006.06261.pdf

    """

    @typechecked
    def __init__(
        self,
        # network structure related
        idim: int,
        odim: int,
        midi_dim: int = 129,
        duration_dim: int = 500,
        adim: int = 384,
        aheads: int = 4,
        elayers: int = 6,
        eunits: int = 1536,
        dlayers: int = 6,
        dunits: int = 1536,
        postnet_layers: int = 5,
        postnet_chans: int = 512,
        postnet_filts: int = 5,
        postnet_dropout_rate: float = 0.5,
        positionwise_layer_type: str = "conv1d",
        positionwise_conv_kernel_size: int = 1,
        use_scaled_pos_enc: bool = True,
        use_batch_norm: bool = True,
        encoder_normalize_before: bool = True,
        decoder_normalize_before: bool = True,
        encoder_concat_after: bool = False,
        decoder_concat_after: bool = False,
        duration_predictor_layers: int = 2,
        duration_predictor_chans: int = 384,
        duration_predictor_kernel_size: int = 3,
        duration_predictor_dropout_rate: float = 0.1,
        reduction_factor: int = 1,
        encoder_type: str = "transformer",
        decoder_type: str = "transformer",
        transformer_enc_dropout_rate: float = 0.1,
        transformer_enc_positional_dropout_rate: float = 0.1,
        transformer_enc_attn_dropout_rate: float = 0.1,
        transformer_dec_dropout_rate: float = 0.1,
        transformer_dec_positional_dropout_rate: float = 0.1,
        transformer_dec_attn_dropout_rate: float = 0.1,
        # only for conformer
        conformer_rel_pos_type: str = "legacy",
        conformer_pos_enc_layer_type: str = "rel_pos",
        conformer_self_attn_layer_type: str = "rel_selfattn",
        conformer_activation_type: str = "swish",
        use_macaron_style_in_conformer: bool = True,
        use_cnn_in_conformer: bool = True,
        zero_triu: bool = False,
        conformer_enc_kernel_size: int = 7,
        conformer_dec_kernel_size: int = 31,
        # extra embedding related
        spks: Optional[int] = None,
        langs: Optional[int] = None,
        spk_embed_dim: Optional[int] = None,
        spk_embed_integration_type: str = "add",
        # training related
        init_type: str = "xavier_uniform",
        init_enc_alpha: float = 1.0,
        init_dec_alpha: float = 1.0,
        use_masking: bool = False,
        use_weighted_masking: bool = False,
        loss_function: str = "XiaoiceSing2",  # FastSpeech1, XiaoiceSing2
        loss_type: str = "L1",
        lambda_mel: float = 1,
        lambda_dur: float = 0.1,
        lambda_pitch: float = 0.01,
        lambda_vuv: float = 0.01,
    ):
        """Initialize XiaoiceSing module.

        Args:
            idim (int): Vocabulary size of the phone (label) embedding.
            odim (int): Dimension of the output acoustic features.
            midi_dim (int): Vocabulary size of the MIDI note embedding.
            duration_dim (int): Vocabulary size of the score-duration embedding
                (score durations must be smaller than this value).
            adim (int): Attention dimension shared by encoder, decoder and the
                input embeddings.
            aheads (int): Number of attention heads.
            elayers (int): Number of encoder layers.
            eunits (int): Number of encoder hidden units.
            dlayers (int): Number of decoder layers.
            dunits (int): Number of decoder hidden units.
            postnet_layers (int): Number of postnet layers (0 disables the postnet).
            postnet_chans (int): Number of postnet channels.
            postnet_filts (int): Kernel size of postnet.
            postnet_dropout_rate (float): Dropout rate in postnet.
            positionwise_layer_type (str): Position-wise layer type forwarded to
                the encoder/decoder.
            positionwise_conv_kernel_size (int): Kernel size of the position-wise
                layer, forwarded to the encoder/decoder.
            use_scaled_pos_enc (bool): Whether to use trainable scaled pos encoding.
            use_batch_norm (bool): Whether to use batch normalization in postnet.
            encoder_normalize_before (bool): Whether to apply layernorm layer before
                encoder block.
            decoder_normalize_before (bool): Whether to apply layernorm layer before
                decoder block.
            encoder_concat_after (bool): Whether to concatenate attention layer's
                input and output in encoder.
            decoder_concat_after (bool): Whether to concatenate attention layer's
                input and output in decoder.
            duration_predictor_layers (int): Number of duration predictor layers.
            duration_predictor_chans (int): Number of duration predictor channels.
            duration_predictor_kernel_size (int): Kernel size of duration predictor.
            duration_predictor_dropout_rate (float): Dropout rate in duration
                predictor.
            reduction_factor (int): Reduction factor.
            encoder_type (str): Encoder type ("transformer" or "conformer").
            decoder_type (str): Decoder type ("transformer" or "conformer").
            transformer_enc_dropout_rate (float): Dropout rate in encoder except
                attention and positional encoding.
            transformer_enc_positional_dropout_rate (float): Dropout rate after
                encoder positional encoding.
            transformer_enc_attn_dropout_rate (float): Dropout rate in encoder
                self-attention module.
            transformer_dec_dropout_rate (float): Dropout rate in decoder except
                attention & positional encoding.
            transformer_dec_positional_dropout_rate (float): Dropout rate after
                decoder positional encoding.
            transformer_dec_attn_dropout_rate (float): Dropout rate in decoder
                self-attention module.
            conformer_rel_pos_type (str): "legacy" or "latest". With "legacy", the
                default relative-position layer types are swapped for their legacy
                variants for compatibility.
            conformer_pos_enc_layer_type (str): Conformer positional encoding type.
            conformer_self_attn_layer_type (str): Conformer self-attention type.
            conformer_activation_type (str): Conformer activation function type.
            use_macaron_style_in_conformer (bool): Whether to use macaron style FFN.
            use_cnn_in_conformer (bool): Whether to use the conformer CNN module.
            zero_triu (bool): Forwarded to the conformer encoder and decoder.
            conformer_enc_kernel_size (int): Conformer encoder CNN kernel size.
            conformer_dec_kernel_size (int): Conformer decoder CNN kernel size.
            spks (Optional[int]): Number of speakers. If set to > 1, assume that the
                sids will be provided as the input and use sid embedding layer.
            langs (Optional[int]): Number of languages. If set to > 1, assume that the
                lids will be provided as the input and use lid embedding layer.
            spk_embed_dim (Optional[int]): Speaker embedding dimension. If set to > 0,
                assume that spembs will be provided as the input.
            spk_embed_integration_type: How to integrate speaker embedding
                ("add" or "concat").
            init_type (str): How to initialize transformer parameters.
            init_enc_alpha (float): Initial value of alpha in scaled pos encoding of the
                encoder.
            init_dec_alpha (float): Initial value of alpha in scaled pos encoding of the
                decoder.
            use_masking (bool): Whether to apply masking for padded part in loss
                calculation.
            use_weighted_masking (bool): Whether to apply weighted masking in loss
                calculation.
            loss_function (str): Loss functions ("FastSpeech1" or "XiaoiceSing2")
            loss_type (str): Mel loss type ("L1" (MAE), "L2" (MSE) or "L1+L2")
            lambda_mel (float): Loss scaling coefficient for Mel loss.
            lambda_dur (float): Loss scaling coefficient for duration loss.
            lambda_pitch (float): Loss scaling coefficient for pitch loss.
            lambda_vuv (float): Loss scaling coefficient for VUV loss.

        Raises:
            ValueError: If an unknown ``conformer_rel_pos_type``, ``encoder_type``,
                ``decoder_type`` or ``loss_function`` is given.

        """
        super().__init__()

        # store hyperparameters
        self.idim = idim
        self.midi_dim = midi_dim
        self.duration_dim = duration_dim
        self.odim = odim
        self.eos = idim - 1
        self.reduction_factor = reduction_factor
        self.encoder_type = encoder_type
        self.decoder_type = decoder_type
        self.use_scaled_pos_enc = use_scaled_pos_enc
        self.loss_function = loss_function
        self.loss_type = loss_type
        self.lambda_mel = lambda_mel
        self.lambda_dur = lambda_dur
        self.lambda_pitch = lambda_pitch
        self.lambda_vuv = lambda_vuv

        # use idx 0 as padding idx
        self.padding_idx = 0

        # get positional encoding class
        pos_enc_class = (
            ScaledPositionalEncoding if self.use_scaled_pos_enc else PositionalEncoding
        )

        # check relative positional encoding compatibility
        if "conformer" in [encoder_type, decoder_type]:
            if conformer_rel_pos_type == "legacy":
                if conformer_pos_enc_layer_type == "rel_pos":
                    conformer_pos_enc_layer_type = "legacy_rel_pos"
                    # NOTE: the switch to the new scheme is controlled by
                    # conformer_rel_pos_type, so point the user there.
                    logging.warning(
                        "Fallback to conformer_pos_enc_layer_type = 'legacy_rel_pos' "
                        "due to the compatibility. If you want to use the new one, "
                        "please use conformer_rel_pos_type = 'latest'."
                    )
                if conformer_self_attn_layer_type == "rel_selfattn":
                    conformer_self_attn_layer_type = "legacy_rel_selfattn"
                    logging.warning(
                        "Fallback to "
                        "conformer_self_attn_layer_type = 'legacy_rel_selfattn' "
                        "due to the compatibility. If you want to use the new one, "
                        "please use conformer_rel_pos_type = 'latest'."
                    )
            elif conformer_rel_pos_type == "latest":
                assert conformer_pos_enc_layer_type != "legacy_rel_pos"
                assert conformer_self_attn_layer_type != "legacy_rel_selfattn"
            else:
                raise ValueError(f"Unknown rel_pos_type: {conformer_rel_pos_type}")

        # define encoder input embeddings (summed in forward/inference)
        self.phone_encode_layer = torch.nn.Embedding(
            num_embeddings=idim, embedding_dim=adim, padding_idx=self.padding_idx
        )
        self.midi_encode_layer = torch.nn.Embedding(
            num_embeddings=midi_dim,
            embedding_dim=adim,
            padding_idx=self.padding_idx,
        )
        self.duration_encode_layer = torch.nn.Embedding(
            num_embeddings=duration_dim,
            embedding_dim=adim,
            padding_idx=self.padding_idx,
        )

        # define encoder
        if encoder_type == "transformer":
            self.encoder = TransformerEncoder(
                idim=0,
                attention_dim=adim,
                attention_heads=aheads,
                linear_units=eunits,
                num_blocks=elayers,
                input_layer=None,
                dropout_rate=transformer_enc_dropout_rate,
                positional_dropout_rate=transformer_enc_positional_dropout_rate,
                attention_dropout_rate=transformer_enc_attn_dropout_rate,
                pos_enc_class=pos_enc_class,
                normalize_before=encoder_normalize_before,
                concat_after=encoder_concat_after,
                positionwise_layer_type=positionwise_layer_type,
                positionwise_conv_kernel_size=positionwise_conv_kernel_size,
            )
        elif encoder_type == "conformer":
            self.encoder = ConformerEncoder(
                idim=idim,
                attention_dim=adim,
                attention_heads=aheads,
                linear_units=eunits,
                num_blocks=elayers,
                input_layer=None,
                dropout_rate=transformer_enc_dropout_rate,
                positional_dropout_rate=transformer_enc_positional_dropout_rate,
                attention_dropout_rate=transformer_enc_attn_dropout_rate,
                normalize_before=encoder_normalize_before,
                concat_after=encoder_concat_after,
                positionwise_layer_type=positionwise_layer_type,
                positionwise_conv_kernel_size=positionwise_conv_kernel_size,
                macaron_style=use_macaron_style_in_conformer,
                pos_enc_layer_type=conformer_pos_enc_layer_type,
                selfattention_layer_type=conformer_self_attn_layer_type,
                activation_type=conformer_activation_type,
                use_cnn_module=use_cnn_in_conformer,
                cnn_module_kernel=conformer_enc_kernel_size,
                zero_triu=zero_triu,
            )
        else:
            raise ValueError(f"{encoder_type} is not supported.")

        # define spk and lang embedding
        self.spks = None
        if spks is not None and spks > 1:
            self.spks = spks
            self.sid_emb = torch.nn.Embedding(spks, adim)
        self.langs = None
        if langs is not None and langs > 1:
            self.langs = langs
            self.lid_emb = torch.nn.Embedding(langs, adim)

        # define additional projection for speaker embedding
        self.spk_embed_dim = None
        if spk_embed_dim is not None and spk_embed_dim > 0:
            self.spk_embed_dim = spk_embed_dim
            self.spk_embed_integration_type = spk_embed_integration_type
        if self.spk_embed_dim is not None:
            if self.spk_embed_integration_type == "add":
                self.projection = torch.nn.Linear(self.spk_embed_dim, adim)
            else:
                self.projection = torch.nn.Linear(adim + self.spk_embed_dim, adim)

        # define duration predictor
        self.duration_predictor = DurationPredictor(
            idim=adim,
            n_layers=duration_predictor_layers,
            n_chans=duration_predictor_chans,
            kernel_size=duration_predictor_kernel_size,
            dropout_rate=duration_predictor_dropout_rate,
        )

        # define length regulator
        self.length_regulator = LengthRegulator()

        # define decoder
        # NOTE: we use encoder as decoder
        # because fastspeech's decoder is the same as encoder
        if decoder_type == "transformer":
            self.decoder = TransformerEncoder(
                idim=0,
                attention_dim=adim,
                attention_heads=aheads,
                linear_units=dunits,
                num_blocks=dlayers,
                input_layer=None,
                dropout_rate=transformer_dec_dropout_rate,
                positional_dropout_rate=transformer_dec_positional_dropout_rate,
                attention_dropout_rate=transformer_dec_attn_dropout_rate,
                pos_enc_class=pos_enc_class,
                normalize_before=decoder_normalize_before,
                concat_after=decoder_concat_after,
                positionwise_layer_type=positionwise_layer_type,
                positionwise_conv_kernel_size=positionwise_conv_kernel_size,
            )
        elif decoder_type == "conformer":
            self.decoder = ConformerEncoder(
                idim=0,
                attention_dim=adim,
                attention_heads=aheads,
                linear_units=dunits,
                num_blocks=dlayers,
                input_layer=None,
                dropout_rate=transformer_dec_dropout_rate,
                positional_dropout_rate=transformer_dec_positional_dropout_rate,
                attention_dropout_rate=transformer_dec_attn_dropout_rate,
                normalize_before=decoder_normalize_before,
                concat_after=decoder_concat_after,
                positionwise_layer_type=positionwise_layer_type,
                positionwise_conv_kernel_size=positionwise_conv_kernel_size,
                macaron_style=use_macaron_style_in_conformer,
                pos_enc_layer_type=conformer_pos_enc_layer_type,
                selfattention_layer_type=conformer_self_attn_layer_type,
                activation_type=conformer_activation_type,
                use_cnn_module=use_cnn_in_conformer,
                cnn_module_kernel=conformer_dec_kernel_size,
                # keep the decoder consistent with the conformer encoder
                zero_triu=zero_triu,
            )
        else:
            raise ValueError(f"{decoder_type} is not supported.")

        # define final projection: odim spectral features + 1 log-F0 + 1 VUV
        # channel per frame, for each of the reduction_factor frames.
        self.linear_projection = torch.nn.Linear(adim, (odim + 2) * reduction_factor)

        # define postnet
        self.postnet = (
            None
            if postnet_layers == 0
            else Postnet(
                idim=idim,
                odim=odim,
                n_layers=postnet_layers,
                n_chans=postnet_chans,
                n_filts=postnet_filts,
                use_batch_norm=use_batch_norm,
                dropout_rate=postnet_dropout_rate,
            )
        )

        # initialize parameters
        self._reset_parameters(
            init_type=init_type,
            init_enc_alpha=init_enc_alpha,
            init_dec_alpha=init_dec_alpha,
        )

        # define criterions
        if self.loss_function == "FastSpeech1":
            self.criterion = XiaoiceSingLoss(
                use_masking=use_masking, use_weighted_masking=use_weighted_masking
            )
        elif self.loss_function == "XiaoiceSing2":
            self.criterion = XiaoiceSing2Loss(
                use_masking=use_masking, use_weighted_masking=use_weighted_masking
            )
        else:
            raise ValueError(f"{self.loss_function} is not supported.")

    def forward(
        self,
        text: torch.Tensor,
        text_lengths: torch.Tensor,
        feats: torch.Tensor,
        feats_lengths: torch.Tensor,
        label: Optional[Dict[str, torch.Tensor]] = None,
        label_lengths: Optional[Dict[str, torch.Tensor]] = None,
        melody: Optional[Dict[str, torch.Tensor]] = None,
        melody_lengths: Optional[Dict[str, torch.Tensor]] = None,
        pitch: Optional[torch.Tensor] = None,
        pitch_lengths: Optional[torch.Tensor] = None,
        duration: Optional[Dict[str, torch.Tensor]] = None,
        duration_lengths: Optional[Dict[str, torch.Tensor]] = None,
        slur: torch.LongTensor = None,
        slur_lengths: torch.Tensor = None,
        spembs: Optional[torch.Tensor] = None,
        sids: Optional[torch.Tensor] = None,
        lids: Optional[torch.Tensor] = None,
        joint_training: bool = False,
        flag_IsValid=False,
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
        """Calculate forward propagation.

        Args:
            text (LongTensor): Batch of padded character ids (B, T_text).
            text_lengths (LongTensor): Batch of lengths of each input (B,).
            feats (Tensor): Batch of padded target features (B, T_feats, odim).
            feats_lengths (LongTensor): Batch of the lengths of each target (B,).
            label (Optional[Dict]): key is "lab" or "score";
                value (LongTensor): Batch of padded label ids (B, Tmax).
                In joint training, a plain LongTensor instead of a dict.
            label_lengths (Optional[Dict]): key is "lab" or "score";
                value (LongTensor): Batch of the lengths of padded label ids (B, ).
            melody (Optional[Dict]): key is "lab" or "score";
                value (LongTensor): Batch of padded melody (B, Tmax).
            melody_lengths (Optional[Dict]): key is "lab" or "score";
                value (LongTensor): Batch of the lengths of padded melody (B, ).
            pitch (FloatTensor): Batch of padded f0 (B, Tmax).
            pitch_lengths (LongTensor): Batch of the lengths of padded f0 (B, ).
            duration (Optional[Dict]): key is "lab", "score_phn" or "score_syb";
                value (LongTensor): Batch of padded duration (B, Tmax).
            duration_lengths (Optional[Dict]): key is "lab", "score_phn" or
                "score_syb"; value (LongTensor): Batch of the lengths of padded
                duration (B, ).
            slur (LongTensor): Batch of padded slur (B, Tmax). Unused.
            slur_lengths (LongTensor): Batch of the lengths of padded slur (B, ).
                Unused.
            spembs (Optional[Tensor]): Batch of speaker embeddings (B, spk_embed_dim).
            sids (Optional[Tensor]): Batch of speaker IDs (B, 1).
            lids (Optional[Tensor]): Batch of language IDs (B, 1).
            joint_training (bool): Whether to perform joint training with vocoder.
                If True, label/melody/duration (and their lengths) are plain
                tensors rather than dicts.
            flag_IsValid (bool): If not ``False`` (and not joint training), also
                return the generated features, the targets and target lengths.

        Returns:
            Tensor: Loss scalar value.
            Dict: Statistics to be monitored.
            Tensor: Weight value if not joint training else model outputs.
            When ``flag_IsValid`` is not ``False`` (and not joint training), three
            more items follow: generated features (B, T_feats, odim), target
            features and target lengths.

        """
        if joint_training:
            # Inputs are already plain tensors; label / label_lengths /
            # duration_lengths are used as given.
            midi = melody
            midi_lengths = melody_lengths
            duration_ = duration
            ds = duration
        else:
            label = label["score"]
            midi = melody["score"]
            duration_ = duration["score_phn"]
            label_lengths = label_lengths["score"]
            midi_lengths = melody_lengths["score"]
            duration_lengths = duration_lengths["score_phn"]
            ds = duration["lab"]

        feats = feats[:, : feats_lengths.max()]  # for data-parallel
        midi = midi[:, : midi_lengths.max()]  # for data-parallel
        label = label[:, : label_lengths.max()]  # for data-parallel
        duration_ = duration_[:, : duration_lengths.max()]  # for data-parallel
        # Ground-truth durations index the same phone axis as `label`, so they
        # must be truncated identically for the length regulator / duration loss.
        ds = ds[:, : label_lengths.max()]  # for data-parallel
        olens = feats_lengths
        if self.loss_function == "XiaoiceSing2":
            pitch = pitch[:, : pitch_lengths.max()]
            log_f0 = torch.clamp(pitch, min=0)
            # frames with (clamped) zero pitch are treated as unvoiced
            vuv = log_f0 != 0

        batch_size = text.size(0)

        label_emb = self.phone_encode_layer(label)
        midi_emb = self.midi_encode_layer(midi)
        duration_emb = self.duration_encode_layer(duration_)
        input_emb = label_emb + midi_emb + duration_emb

        x_masks = self._source_mask(label_lengths)
        hs, _ = self.encoder(input_emb, x_masks)  # (B, T_text, adim)

        # integrate with SID and LID embeddings
        if self.spks is not None:
            sid_embs = self.sid_emb(sids.view(-1))
            hs = hs + sid_embs.unsqueeze(1)
        if self.langs is not None:
            lid_embs = self.lid_emb(lids.view(-1))
            hs = hs + lid_embs.unsqueeze(1)

        # integrate speaker embedding
        if self.spk_embed_dim is not None:
            hs = self._integrate_with_spk_embed(hs, spembs)

        # forward duration predictor and length regulator
        d_masks = make_pad_mask(label_lengths).to(input_emb.device)
        d_outs = self.duration_predictor(hs, d_masks)  # (B, T_text)
        hs = self.length_regulator(hs, ds)  # (B, T_feats, adim)

        # forward decoder
        if self.reduction_factor > 1:
            olens_in = torch.div(olens, self.reduction_factor, rounding_mode="trunc")
        else:
            olens_in = olens
        h_masks = self._source_mask(olens_in)
        zs, _ = self.decoder(hs, h_masks)  # (B, T_feats, adim)
        before_outs, log_f0_outs, vuv_outs = (
            self.linear_projection(zs)
            .view((zs.size(0), -1, self.odim + 2))
            .split_with_sizes([self.odim, 1, 1], dim=2)
        )  # (B, T_feats, odim), (B, T_feats, 1), (B, T_feats, 1)

        # postnet -> (B, Lmax//r * r, odim)
        if self.postnet is None:
            after_outs = before_outs
        else:
            after_outs = before_outs + self.postnet(
                before_outs.transpose(1, 2)
            ).transpose(1, 2)

        # modify mod part of groundtruth
        if self.reduction_factor > 1:
            assert feats_lengths.ge(
                self.reduction_factor
            ).all(), "Output length must be greater than or equal to reduction factor."
            # round each target length down to a multiple of reduction_factor
            olens = feats_lengths - feats_lengths % self.reduction_factor
            max_olen = int(olens.max())
            ys = feats[:, :max_olen]
            if self.loss_function == "XiaoiceSing2":
                log_f0 = log_f0[:, :max_olen]
                vuv = vuv[:, :max_olen]
        else:
            ys = feats
            olens = feats_lengths

        ilens = label_lengths
        if self.loss_function == "FastSpeech1":
            mel_loss, duration_loss = self.criterion(
                after_outs, before_outs, d_outs, ys, ds, ilens, olens
            )
        elif self.loss_function == "XiaoiceSing2":
            mel_loss, duration_loss, pitch_loss, vuv_loss = self.criterion(
                after_outs=after_outs,
                before_outs=before_outs,
                d_outs=d_outs,
                p_outs=log_f0_outs,
                v_outs=vuv_outs,
                ys=ys,
                ds=ds,
                ps=log_f0,
                vs=vuv,
                ilens=ilens,
                olens=olens,
                loss_type=self.loss_type,
            )

        mel_loss = mel_loss * self.lambda_mel
        duration_loss = duration_loss * self.lambda_dur
        loss = mel_loss + duration_loss
        stats = dict(mel_loss=mel_loss.item(), duration_loss=duration_loss.item())
        if self.loss_function == "XiaoiceSing2":
            pitch_loss = pitch_loss * self.lambda_pitch
            vuv_loss = vuv_loss * self.lambda_vuv
            loss += pitch_loss + vuv_loss
            stats["pitch_loss"] = pitch_loss.item()
            stats["vuv_loss"] = vuv_loss.item()

        stats["loss"] = loss.item()

        # report extra information
        if self.encoder_type == "transformer" and self.use_scaled_pos_enc:
            stats.update(
                encoder_alpha=self.encoder.embed[-1].alpha.data.item(),
            )
        if self.decoder_type == "transformer" and self.use_scaled_pos_enc:
            stats.update(
                decoder_alpha=self.decoder.embed[-1].alpha.data.item(),
            )

        loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)

        if joint_training:
            return loss, stats, after_outs if after_outs is not None else before_outs
        if flag_IsValid is False:
            return loss, stats, weight
        return loss, stats, weight, after_outs[:, : olens.max()], ys, olens

    def inference(
        self,
        text: torch.Tensor,
        feats: Optional[torch.Tensor] = None,
        label: Optional[Dict[str, torch.Tensor]] = None,
        melody: Optional[Dict[str, torch.Tensor]] = None,
        pitch: Optional[torch.Tensor] = None,
        duration: Optional[Dict[str, torch.Tensor]] = None,
        slur: Optional[Dict[str, torch.Tensor]] = None,
        spembs: Optional[torch.Tensor] = None,
        sids: Optional[torch.Tensor] = None,
        lids: Optional[torch.Tensor] = None,
        use_teacher_forcing: bool = False,
        joint_training: bool = False,
    ) -> Dict[str, torch.Tensor]:
        """Generate the sequence of features given the sequences of characters.

        Args:
            text (LongTensor): Input sequence of characters (T_text,). Unused.
            feats (Optional[Tensor]): Feature sequence (N, odim). Unused.
            label (Optional[Dict]): key is "lab" or "score";
                value (LongTensor): Batch of padded label ids (Tmax).
            melody (Optional[Dict]): key is "lab" or "score";
                value (LongTensor): Batch of padded melody (Tmax).
            pitch (FloatTensor): Batch of padded f0 (B, Tmax). Unused.
            duration (Optional[Dict]): key is "lab", "score_phn" or "score_syb";
                value (LongTensor): Batch of padded duration (Tmax). "score_phn"
                (or "lab" when ``joint_training``) is embedded as input; "lab", if
                present, is only logged for debugging.
            slur (LongTensor): Batch of padded slur (B, Tmax). Unused.
            spembs (Optional[Tensor]): Speaker embedding (spk_embed_dim,).
            sids (Optional[Tensor]): Speaker ID (1,).
            lids (Optional[Tensor]): Language ID (1,).
            use_teacher_forcing (bool): Unused; predicted durations are always used.
            joint_training (bool): Whether called from joint training with vocoder.

        Returns:
            Dict[str, Tensor]: Output dict including the following items:
                * feat_gen (Tensor): Output sequence of features (T_feats, odim).
                * prob (None): Placeholder for interface compatibility.
                * att_w (None): Placeholder for interface compatibility.

        """
        label = label["score"]
        midi = melody["score"]
        if joint_training:
            duration_ = duration["lab"]
        else:
            duration_ = duration["score_phn"]
        # Ground-truth durations are optional at inference; only used for logging.
        ds = duration.get("lab")

        label_emb = self.phone_encode_layer(label)
        midi_emb = self.midi_encode_layer(midi)
        duration_emb = self.duration_encode_layer(duration_)
        input_emb = label_emb + midi_emb + duration_emb

        x_masks = None  # no padding at inference, so no mask is needed
        hs, _ = self.encoder(input_emb, x_masks)  # (B, T_text, adim)

        # integrate with SID and LID embeddings
        if self.spks is not None:
            sid_embs = self.sid_emb(sids.view(-1))
            hs = hs + sid_embs.unsqueeze(1)
        if self.langs is not None:
            lid_embs = self.lid_emb(lids.view(-1))
            hs = hs + lid_embs.unsqueeze(1)

        if spembs is not None:
            spembs = spembs.unsqueeze(0)

        # integrate speaker embedding
        if self.spk_embed_dim is not None:
            hs = self._integrate_with_spk_embed(hs, spembs)

        # forward duration predictor and length regulator
        d_masks = None  # no padding at inference
        d_outs = self.duration_predictor.inference(hs, d_masks)  # (B, T_text)
        # round predicted durations to the nearest integer frame count
        d_outs_int = torch.floor(d_outs + 0.5).to(dtype=torch.long)  # (B, T_text)

        if ds is not None:
            logging.debug("ds: %s", ds)
            logging.debug("ds.shape: %s", ds.shape)
        logging.debug("d_outs: %s", d_outs)
        logging.debug("d_outs.shape: %s", d_outs.shape)

        # use duration model output
        hs = self.length_regulator(hs, d_outs_int)  # (B, T_feats, adim)

        # forward decoder
        h_masks = None  # no padding at inference
        zs, _ = self.decoder(hs, h_masks)  # (B, T_feats, adim)
        before_outs, _, _ = (
            self.linear_projection(zs)
            .view((zs.size(0), -1, self.odim + 2))
            .split_with_sizes([self.odim, 1, 1], dim=2)
        )  # (B, T_feats, odim), (B, T_feats, 1), (B, T_feats, 1)

        # postnet -> (B, Lmax//r * r, odim)
        if self.postnet is None:
            after_outs = before_outs
        else:
            after_outs = before_outs + self.postnet(
                before_outs.transpose(1, 2)
            ).transpose(1, 2)

        return dict(
            feat_gen=after_outs[0], prob=None, att_w=None
        )  # outs, probs, att_ws

    def _integrate_with_spk_embed(
        self, hs: torch.Tensor, spembs: torch.Tensor
    ) -> torch.Tensor:
        """Integrate speaker embedding with hidden states.

        Args:
            hs (Tensor): Batch of hidden state sequences (B, T_text, adim).
            spembs (Tensor): Batch of speaker embeddings (B, spk_embed_dim).

        Returns:
            Tensor: Batch of integrated hidden state sequences (B, T_text, adim).

        Raises:
            NotImplementedError: If the integration type is not "add" or "concat".

        """
        if self.spk_embed_integration_type == "add":
            # apply projection and then add to hidden states
            spembs = self.projection(F.normalize(spembs))
            hs = hs + spembs.unsqueeze(1)
        elif self.spk_embed_integration_type == "concat":
            # concat hidden states with spk embeds and then apply projection
            spembs = F.normalize(spembs).unsqueeze(1).expand(-1, hs.size(1), -1)
            hs = self.projection(torch.cat([hs, spembs], dim=-1))
        else:
            raise NotImplementedError("support only add or concat.")

        return hs

    def _source_mask(self, ilens: torch.Tensor) -> torch.Tensor:
        """Make masks for self-attention.

        Args:
            ilens (LongTensor): Batch of lengths (B,).

        Returns:
            Tensor: Mask tensor for self-attention (B, 1, T_max), on the same
                device as the model parameters.
                dtype=torch.uint8 in PyTorch 1.2-
                dtype=torch.bool in PyTorch 1.2+ (including 1.2)

        Examples:
            >>> ilens = [5, 3]
            >>> self._source_mask(ilens)
            tensor([[[1, 1, 1, 1, 1]],
                    [[1, 1, 1, 0, 0]]], dtype=torch.uint8)

        """
        x_masks = make_non_pad_mask(ilens).to(next(self.parameters()).device)
        return x_masks.unsqueeze(-2)

    def _reset_parameters(
        self, init_type: str, init_enc_alpha: float, init_dec_alpha: float
    ):
        """Initialize model parameters and scaled-positional-encoding alphas.

        Args:
            init_type (str): Initialization method; "pytorch" keeps PyTorch's
                default initialization.
            init_enc_alpha (float): Initial alpha of the encoder's scaled
                positional encoding (transformer encoder only).
            init_dec_alpha (float): Initial alpha of the decoder's scaled
                positional encoding (transformer decoder only).

        """
        # initialize parameters
        if init_type != "pytorch":
            initialize(self, init_type)

        # initialize alpha in scaled positional encoding
        if self.encoder_type == "transformer" and self.use_scaled_pos_enc:
            self.encoder.embed[-1].alpha.data = torch.tensor(init_enc_alpha)
        if self.decoder_type == "transformer" and self.use_scaled_pos_enc:
            self.decoder.embed[-1].alpha.data = torch.tensor(init_dec_alpha)