Source code for espnet2.gan_svs.vits.vits

# Copyright 2021 Tomoki Hayashi
# Copyright 2022 Yifeng Yu
#  Apache 2.0  (http://www.apache.org/licenses/LICENSE-2.0)

"""VITS/VISinger module for GAN-SVS task."""

from contextlib import contextmanager
from distutils.version import LooseVersion
from typing import Any, Dict, Optional

import torch
from torch.nn import functional as F
from typeguard import typechecked

from espnet2.gan_svs.abs_gan_svs import AbsGANSVS
from espnet2.gan_svs.avocodo.avocodo import (
    SBD,
    AvocodoDiscriminator,
    AvocodoDiscriminatorPlus,
    CoMBD,
)
from espnet2.gan_svs.visinger2.visinger2_vocoder import VISinger2Discriminator
from espnet2.gan_svs.vits.generator import VISingerGenerator
from espnet2.gan_tts.hifigan import (
    HiFiGANMultiPeriodDiscriminator,
    HiFiGANMultiScaleDiscriminator,
    HiFiGANMultiScaleMultiPeriodDiscriminator,
    HiFiGANPeriodDiscriminator,
    HiFiGANScaleDiscriminator,
)
from espnet2.gan_tts.hifigan.loss import (
    DiscriminatorAdversarialLoss,
    FeatureMatchLoss,
    GeneratorAdversarialLoss,
    MelSpectrogramLoss,
)
from espnet2.gan_tts.utils import get_segments
from espnet2.gan_tts.vits.loss import KLDivergenceLoss, KLDivergenceLossWithoutFlow
from espnet2.torch_utils.device_funcs import force_gatherable

# Registry used to resolve the ``generator_type`` option to a generator class.
# Both currently registered names map to the same ``VISingerGenerator`` class.
# NOTE: the constant name is misspelled ("GENERATERS"); it is kept unchanged
# because it is a module-level name that other code may reference.
AVAILABLE_GENERATERS = {
    "visinger": VISingerGenerator,
    # TODO(yifeng): add more generators
    "visinger2": VISingerGenerator,
    # "pisinger": PISingerGenerator,
}
# Registry used to resolve the ``discriminator_type`` option to a discriminator
# class. Names containing "avocodo" additionally switch the model into its
# avocodo code path (see the ``"avocodo" in discriminator_type`` checks).
AVAILABLE_DISCRIMINATORS = {
    "hifigan_period_discriminator": HiFiGANPeriodDiscriminator,
    "hifigan_scale_discriminator": HiFiGANScaleDiscriminator,
    "hifigan_multi_period_discriminator": HiFiGANMultiPeriodDiscriminator,
    "hifigan_multi_scale_discriminator": HiFiGANMultiScaleDiscriminator,
    "hifigan_multi_scale_multi_period_discriminator": HiFiGANMultiScaleMultiPeriodDiscriminator,  # NOQA
    "combd": CoMBD,
    "sbd": SBD,
    "avocodo": AvocodoDiscriminator,
    "visinger2": VISinger2Discriminator,
    "avocodo_plus": AvocodoDiscriminatorPlus,
}

# Feature-detect native AMP instead of comparing version strings:
# ``distutils.version.LooseVersion`` is deprecated (PEP 632) and removed from
# the standard library in Python 3.12, while the import below succeeds exactly
# when ``torch.cuda.amp.autocast`` is available.
try:
    from torch.cuda.amp import autocast
except ImportError:
    # Nothing to do if torch<1.6.0
    @contextmanager
    def autocast(enabled=True):  # NOQA
        """No-op stand-in for ``torch.cuda.amp.autocast``.

        Args:
            enabled (bool): Accepted for signature compatibility; ignored.

        """
        yield


class VITS(AbsGANSVS):
    """VITS module (generator + discriminator).

    This is a module of VITS described in `Conditional Variational Autoencoder
    with Adversarial Learning for End-to-End Text-to-Speech`_.

    .. _`Conditional Variational Autoencoder with Adversarial Learning for End-to-End
        Text-to-Speech`: https://arxiv.org/abs/2006.04558

    """

    @typechecked
    def __init__(
        self,
        # generator related
        idim: int,
        odim: int,
        sampling_rate: int = 22050,
        generator_type: str = "visinger",
        vocoder_generator_type: str = "hifigan",
        generator_params: Dict[str, Any] = {
            "hidden_channels": 192,
            "spks": None,
            "langs": None,
            "spk_embed_dim": None,
            "global_channels": -1,
            "segment_size": 32,
            "text_encoder_attention_heads": 2,
            "text_encoder_ffn_expand": 4,
            "text_encoder_blocks": 6,
            "text_encoder_positionwise_layer_type": "conv1d",
            "text_encoder_positionwise_conv_kernel_size": 1,
            "text_encoder_positional_encoding_layer_type": "rel_pos",
            "text_encoder_self_attention_layer_type": "rel_selfattn",
            "text_encoder_activation_type": "swish",
            "text_encoder_normalize_before": True,
            "text_encoder_dropout_rate": 0.1,
            "text_encoder_positional_dropout_rate": 0.0,
            "text_encoder_attention_dropout_rate": 0.0,
            "text_encoder_conformer_kernel_size": 7,
            "use_macaron_style_in_text_encoder": True,
            "use_conformer_conv_in_text_encoder": True,
            "decoder_kernel_size": 7,
            "decoder_channels": 512,
            "decoder_upsample_scales": [8, 8, 2, 2],
            "decoder_upsample_kernel_sizes": [16, 16, 4, 4],
            "decoder_resblock_kernel_sizes": [3, 7, 11],
            "decoder_resblock_dilations": [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
            "projection_filters": [0, 1, 1, 1],
            "projection_kernels": [0, 5, 7, 11],
            "use_weight_norm_in_decoder": True,
            "posterior_encoder_kernel_size": 5,
            "posterior_encoder_layers": 16,
            "posterior_encoder_stacks": 1,
            "posterior_encoder_base_dilation": 1,
            "posterior_encoder_dropout_rate": 0.0,
            "use_weight_norm_in_posterior_encoder": True,
            "flow_flows": 4,
            "flow_kernel_size": 5,
            "flow_base_dilation": 1,
            "flow_layers": 4,
            "flow_dropout_rate": 0.0,
            "use_weight_norm_in_flow": True,
            "use_only_mean_in_flow": True,
            "expand_f0_method": "repeat",
            "use_phoneme_predictor": False,
            "hubert_channels": 0,
        },
        # discriminator related
        discriminator_type: str = "hifigan_multi_scale_multi_period_discriminator",
        discriminator_params: Dict[str, Any] = {
            "hifigan_multi_scale_multi_period_discriminator": {
                "scales": 1,
                "scale_downsample_pooling": "AvgPool1d",
                "scale_downsample_pooling_params": {
                    "kernel_size": 4,
                    "stride": 2,
                    "padding": 2,
                },
                "scale_discriminator_params": {
                    "in_channels": 1,
                    "out_channels": 1,
                    "kernel_sizes": [15, 41, 5, 3],
                    "channels": 128,
                    "max_downsample_channels": 1024,
                    "max_groups": 16,
                    "bias": True,
                    "downsample_scales": [2, 2, 4, 4, 1],
                    "nonlinear_activation": "LeakyReLU",
                    "nonlinear_activation_params": {"negative_slope": 0.1},
                    "use_weight_norm": True,
                    "use_spectral_norm": False,
                },
                "follow_official_norm": False,
                "periods": [2, 3, 5, 7, 11],
                "period_discriminator_params": {
                    "in_channels": 1,
                    "out_channels": 1,
                    "kernel_sizes": [5, 3],
                    "channels": 32,
                    "downsample_scales": [3, 3, 3, 3, 1],
                    "max_downsample_channels": 1024,
                    "bias": True,
                    "nonlinear_activation": "LeakyReLU",
                    "nonlinear_activation_params": {"negative_slope": 0.1},
                    "use_weight_norm": True,
                    "use_spectral_norm": False,
                },
            },
            "avocodo": {
                "combd": {
                    "combd_h_u": [
                        [16, 64, 256, 1024, 1024, 1024],
                        [16, 64, 256, 1024, 1024, 1024],
                        [16, 64, 256, 1024, 1024, 1024],
                    ],
                    "combd_d_k": [
                        [7, 11, 11, 11, 11, 5],
                        [11, 21, 21, 21, 21, 5],
                        [15, 41, 41, 41, 41, 5],
                    ],
                    "combd_d_s": [
                        [1, 1, 4, 4, 4, 1],
                        [1, 1, 4, 4, 4, 1],
                        [1, 1, 4, 4, 4, 1],
                    ],
                    "combd_d_d": [
                        [1, 1, 1, 1, 1, 1],
                        [1, 1, 1, 1, 1, 1],
                        [1, 1, 1, 1, 1, 1],
                    ],
                    "combd_d_g": [
                        [1, 4, 16, 64, 256, 1],
                        [1, 4, 16, 64, 256, 1],
                        [1, 4, 16, 64, 256, 1],
                    ],
                    "combd_d_p": [
                        [3, 5, 5, 5, 5, 2],
                        [5, 10, 10, 10, 10, 2],
                        [7, 20, 20, 20, 20, 2],
                    ],
                    "combd_op_f": [1, 1, 1],
                    "combd_op_k": [3, 3, 3],
                    "combd_op_g": [1, 1, 1],
                },
                "sbd": {
                    "use_sbd": True,
                    "sbd_filters": [
                        [64, 128, 256, 256, 256],
                        [64, 128, 256, 256, 256],
                        [64, 128, 256, 256, 256],
                        [32, 64, 128, 128, 128],
                    ],
                    "sbd_strides": [
                        [1, 1, 3, 3, 1],
                        [1, 1, 3, 3, 1],
                        [1, 1, 3, 3, 1],
                        [1, 1, 3, 3, 1],
                    ],
                    "sbd_kernel_sizes": [
                        [[7, 7, 7], [7, 7, 7], [7, 7, 7], [7, 7, 7], [7, 7, 7]],
                        [[5, 5, 5], [5, 5, 5], [5, 5, 5], [5, 5, 5], [5, 5, 5]],
                        [[3, 3, 3], [3, 3, 3], [3, 3, 3], [3, 3, 3], [3, 3, 3]],
                        [[5, 5, 5], [5, 5, 5], [5, 5, 5], [5, 5, 5], [5, 5, 5]],
                    ],
                    "sbd_dilations": [
                        [[5, 7, 11], [5, 7, 11], [5, 7, 11], [5, 7, 11], [5, 7, 11]],
                        [[3, 5, 7], [3, 5, 7], [3, 5, 7], [3, 5, 7], [3, 5, 7]],
                        [[1, 2, 3], [1, 2, 3], [1, 2, 3], [1, 2, 3], [1, 2, 3]],
                        [[1, 2, 3], [1, 2, 3], [1, 2, 3], [2, 3, 5], [2, 3, 5]],
                    ],
                    "sbd_band_ranges": [[0, 6], [0, 11], [0, 16], [0, 64]],
                    "sbd_transpose": [False, False, False, True],
                    "pqmf_config": {
                        "sbd": [16, 256, 0.03, 10.0],
                        "fsbd": [64, 256, 0.1, 9.0],
                    },
                },
                "pqmf_config": {
                    "lv1": [2, 256, 0.25, 10.0],
                    "lv2": [4, 192, 0.13, 10.0],
                },
            },
        },
        # loss related
        generator_adv_loss_params: Dict[str, Any] = {
            "average_by_discriminators": False,
            "loss_type": "mse",
        },
        discriminator_adv_loss_params: Dict[str, Any] = {
            "average_by_discriminators": False,
            "loss_type": "mse",
        },
        feat_match_loss_params: Dict[str, Any] = {
            "average_by_discriminators": False,
            "average_by_layers": False,
            "include_final_outputs": True,
        },
        mel_loss_params: Dict[str, Any] = {
            "fs": 22050,
            "n_fft": 1024,
            "hop_length": 256,
            "win_length": None,
            "window": "hann",
            "n_mels": 80,
            "fmin": 0,
            "fmax": None,
            "log_base": None,
        },
        lambda_adv: float = 1.0,
        lambda_mel: float = 45.0,
        lambda_feat_match: float = 2.0,
        lambda_dur: float = 0.1,
        lambda_kl: float = 1.0,
        lambda_pitch: float = 10.0,
        lambda_phoneme: float = 1.0,
        lambda_c_yin: float = 45.0,
        cache_generator_outputs: bool = True,
    ):
        """Initialize VITS module.

        Args:
            idim (int): Input vocabrary size.
            odim (int): Acoustic feature dimension. The actual output channels will
                be 1 since VITS is the end-to-end text-to-wave model but for the
                compatibility odim is used to indicate the acoustic feature dimension.
            sampling_rate (int): Sampling rate, not used for the training but it will
                be referred in saving waveform during the inference.
            generator_type (str): Generator type.
            vocoder_generator_type (str): Type of vocoder generator to use in the
                model. Overridden with "avocodo" when ``discriminator_type``
                contains "avocodo".
            generator_params (Dict[str, Any]): Parameter dict for generator.
                The dict is copied before being completed, so neither the caller's
                dict nor the shared default is modified.
            discriminator_type (str): Discriminator type.
            discriminator_params (Dict[str, Any]): Parameter dict for discriminator.
                Copied before being completed, like ``generator_params``.
            generator_adv_loss_params (Dict[str, Any]): Parameter dict for generator
                adversarial loss.
            discriminator_adv_loss_params (Dict[str, Any]): Parameter dict for
                discriminator adversarial loss.
            feat_match_loss_params (Dict[str, Any]): Parameter dict for feat match
                loss.
            mel_loss_params (Dict[str, Any]): Parameter dict for mel loss.
            lambda_adv (float): Loss scaling coefficient for adversarial loss.
            lambda_mel (float): Loss scaling coefficient for mel spectrogram loss.
            lambda_feat_match (float): Loss scaling coefficient for feat match loss.
            lambda_dur (float): Loss scaling coefficient for duration loss.
            lambda_kl (float): Loss scaling coefficient for KL divergence loss.
            lambda_pitch (float): Loss scaling coefficient for pitch loss.
            lambda_phoneme (float): Loss scaling coefficient for phoneme loss.
            lambda_c_yin (float): Loss scaling coefficient for yin loss.
            cache_generator_outputs (bool): Whether to cache generator outputs.

        """
        super().__init__()

        # The parameter dicts are completed below. Work on shallow copies so that
        # the mutable default arguments (shared across all calls) and the dicts
        # owned by the caller are never modified in place.
        generator_params = dict(generator_params)
        discriminator_params = dict(discriminator_params)

        # define modules
        generator_class = AVAILABLE_GENERATERS[generator_type]
        if "visinger" in generator_type or "pisinger" in generator_type:
            # NOTE(kan-bayashi): Update parameters for the compatibility.
            #   The idim and odim is automatically decided from input data,
            #   where idim represents #vocabularies and odim represents
            #   the input acoustic feature dimension.
            generator_params.update(vocabs=idim, aux_channels=odim)
        self.generator_type = generator_type
        self.use_flow = generator_params["flow_flows"] > 0
        self.use_phoneme_predictor = generator_params["use_phoneme_predictor"]
        self.discriminator_type = discriminator_type

        # An avocodo discriminator forces the matching avocodo vocoder generator.
        use_avocodo = "avocodo" in discriminator_type
        if use_avocodo:
            vocoder_generator_type = "avocodo"
        self.use_avocodo = use_avocodo
        self.vocoder_generator_type = vocoder_generator_type

        generator_params.update(
            generator_type=generator_type,
            vocoder_generator_type=vocoder_generator_type,
            fs=mel_loss_params["fs"],
            hop_length=mel_loss_params["hop_length"],
            win_length=mel_loss_params["win_length"],
            n_fft=mel_loss_params["n_fft"],
        )
        # NOTE(review): this condition cannot hold as written, because
        # vocoder_generator_type is set to "avocodo" whenever use_avocodo is True.
        # Kept unchanged to preserve behavior; confirm the intent.
        if vocoder_generator_type == "uhifigan" and use_avocodo:
            generator_params.update(use_avocodo=use_avocodo)
        self.generator = generator_class(
            **generator_params,
        )

        discriminator_class = AVAILABLE_DISCRIMINATORS[self.discriminator_type]
        if use_avocodo:
            discriminator_params["projection_filters"] = generator_params[
                "projection_filters"
            ]
            # Rebuild the nested dict instead of updating it in place so the
            # caller's / default nested dict stays untouched.
            discriminator_params["sbd"] = dict(
                discriminator_params["sbd"],
                segment_size=generator_params["segment_size"]
                * mel_loss_params["hop_length"],
            )
        if "visinger2" in discriminator_type:
            discriminator_params["multi_freq_disc_params"] = dict(
                discriminator_params["multi_freq_disc_params"],
                sample_rate=sampling_rate,
            )
        self.discriminator = discriminator_class(
            **discriminator_params,
        )

        # define losses
        self.generator_adv_loss = GeneratorAdversarialLoss(
            **generator_adv_loss_params,
        )
        self.discriminator_adv_loss = DiscriminatorAdversarialLoss(
            **discriminator_adv_loss_params,
        )
        self.feat_match_loss = FeatureMatchLoss(
            **feat_match_loss_params,
        )
        self.mel_loss = MelSpectrogramLoss(
            **mel_loss_params,
        )
        if self.use_flow:
            self.kl_loss = KLDivergenceLoss()
        else:
            self.kl_loss = KLDivergenceLossWithoutFlow()
        # The last vocabulary index is used as the CTC blank label.
        self.ctc_loss = torch.nn.CTCLoss(idim - 1, reduction="mean")
        self.mse_loss = torch.nn.MSELoss()

        # coefficients
        self.lambda_adv = lambda_adv
        self.lambda_mel = lambda_mel
        self.lambda_kl = lambda_kl
        self.lambda_feat_match = lambda_feat_match
        self.lambda_dur = lambda_dur
        self.lambda_pitch = lambda_pitch
        self.lambda_phoneme = lambda_phoneme
        self.lambda_c_yin = lambda_c_yin

        # cache
        self.cache_generator_outputs = cache_generator_outputs
        self._cache = None

        # store sampling rate for saving wav file
        # (not used for the training)
        self.fs = sampling_rate

        # store parameters for test compatibility
        self.spks = self.generator.spks
        self.langs = self.generator.langs
        self.spk_embed_dim = self.generator.spk_embed_dim

        # hubert alignment
        self.adaptive_pool = torch.nn.AdaptiveAvgPool1d(1)
        self.n_mels = mel_loss_params["n_mels"]

    @property
    def require_raw_singing(self):
        """Return whether or not singing is required."""
        return True

    @property
    def require_vocoder(self):
        """Return whether or not vocoder is required."""
        return False

    @staticmethod
    def _match_ssl_feats_length(
        ssl_feats: torch.Tensor, target_length: int, time_dim: int
    ) -> torch.Tensor:
        """Truncate or zero-pad ``ssl_feats`` along ``time_dim`` to ``target_length``.

        Args:
            ssl_feats (Tensor): SSL feature tensor.
            target_length (int): Required size of the time dimension.
            time_dim (int): Non-negative index of the time dimension.

        Returns:
            Tensor: ``ssl_feats`` whose ``time_dim`` has size ``target_length``.
                Extra frames are dropped from the end; missing frames are
                appended as zeros at the end.

        """
        current_length = ssl_feats.shape[time_dim]
        if current_length > target_length:
            return ssl_feats.narrow(time_dim, 0, target_length)
        if current_length < target_length:
            # F.pad takes (left, right) pairs starting from the LAST dimension,
            # so the final entry of this list is the right padding of time_dim.
            padding = [0, 0] * (ssl_feats.dim() - time_dim)
            padding[-1] = target_length - current_length
            return F.pad(ssl_feats, padding)
        return ssl_feats

    def forward(
        self,
        text: torch.Tensor,
        text_lengths: torch.Tensor,
        feats: torch.Tensor,
        feats_lengths: torch.Tensor,
        singing: torch.Tensor,
        singing_lengths: torch.Tensor,
        ssl_feats: Optional[torch.Tensor] = None,
        ssl_feats_lengths: Optional[torch.Tensor] = None,
        label: Optional[Dict[str, torch.Tensor]] = None,
        label_lengths: Optional[Dict[str, torch.Tensor]] = None,
        melody: Optional[Dict[str, torch.Tensor]] = None,
        pitch: torch.LongTensor = None,
        ying: torch.Tensor = None,
        duration: Optional[Dict[str, torch.Tensor]] = None,
        slur: torch.LongTensor = None,
        spembs: Optional[torch.Tensor] = None,
        sids: Optional[torch.Tensor] = None,
        lids: Optional[torch.Tensor] = None,
        forward_generator: bool = True,
    ) -> Dict[str, Any]:
        """Perform generator forward.

        Args:
            text (LongTensor): Batch of padded character ids (B, T_text).
            text_lengths (LongTensor): Batch of lengths of each input batch (B,).
            feats (Tensor): Batch of padded target features (B, T_feats, odim).
            feats_lengths (LongTensor): Batch of the lengths of each target (B,).
            singing (Tensor): Singing waveform tensor (B, T_wav).
            singing_lengths (Tensor): Singing length tensor (B,).
            ssl_feats (Tensor): SSL feature tensor (B, T_feats, hubert_channels).
                When given, it is truncated / zero-padded to the length of
                ``feats`` and concatenated to it along the channel axis.
            ssl_feats_lengths (Tensor): SSL feature length tensor (B,).
            label (Optional[Dict]): key is "lab" or "score";
                value (LongTensor): Batch of padded label ids (B, T_text).
            label_lengths (Optional[Dict]): key is "lab" or "score";
                value (LongTensor): Batch of the lengths of padded label ids (B, ).
            melody (Optional[Dict]): key is "lab" or "score";
                value (LongTensor): Batch of padded melody (B, T_text).
            pitch (FloatTensor): Batch of padded f0 (B, T_feats).
            ying (Optional[Tensor]): Batch of padded ying (B, T_feats).
            duration (Optional[Dict]): key is "lab", "score_phn" or "score_syb";
                value (LongTensor): Batch of padded duration (B, T_text).
            slur (FloatTensor): Batch of padded slur (B, T_text).
            spembs (Optional[Tensor]): Batch of speaker embeddings (B, spk_embed_dim).
            sids (Optional[Tensor]): Batch of speaker IDs (B, 1).
            lids (Optional[Tensor]): Batch of language IDs (B, 1).
            forward_generator (bool): Whether to forward generator.

        Returns:
            Dict[str, Any]:
                - loss (Tensor): Loss scalar tensor.
                - stats (Dict[str, float]): Statistics to be monitored.
                - weight (Tensor): Weight tensor to summarize losses.
                - optim_idx (int): Optimizer index (0 for G and 1 for D).

        """
        if ssl_feats is not None:
            ssl_feats = self._match_ssl_feats_length(
                ssl_feats, feats.shape[1], time_dim=1
            )
            concatenated_feats = torch.cat([feats, ssl_feats], dim=2)
        else:
            concatenated_feats = feats

        # Only the "score_syb" / "lab" entries of the dict inputs are used.
        score_dur = duration["score_syb"]
        gt_dur = duration["lab"]
        label = label["lab"]
        label_lengths = label_lengths["lab"]
        melody = melody["lab"]

        # Both training steps take exactly the same inputs.
        step_inputs = dict(
            text=text,
            text_lengths=text_lengths,
            feats=concatenated_feats,
            feats_lengths=feats_lengths,
            singing=singing,
            singing_lengths=singing_lengths,
            label=label,
            label_lengths=label_lengths,
            melody=melody,
            gt_dur=gt_dur,
            score_dur=score_dur,
            slur=slur,
            pitch=pitch,
            ying=ying,
            sids=sids,
            spembs=spembs,
            lids=lids,
        )
        if forward_generator:
            return self._forward_generator(**step_inputs)
        return self._forward_discrminator(**step_inputs)

    def _forward_generator(
        self,
        text: torch.Tensor,
        text_lengths: torch.Tensor,
        feats: torch.Tensor,
        feats_lengths: torch.Tensor,
        singing: torch.Tensor,
        singing_lengths: torch.Tensor,
        label: torch.Tensor = None,
        label_lengths: torch.Tensor = None,
        melody: torch.Tensor = None,
        gt_dur: torch.Tensor = None,
        score_dur: torch.Tensor = None,
        slur: torch.Tensor = None,
        pitch: torch.Tensor = None,
        ying: Optional[torch.Tensor] = None,
        sids: Optional[torch.Tensor] = None,
        spembs: Optional[torch.Tensor] = None,
        lids: Optional[torch.Tensor] = None,
    ) -> Dict[str, Any]:
        """Perform generator forward.

        Args:
            text (Tensor): Text index tensor (B, T_text).
            text_lengths (Tensor): Text length tensor (B,).
            feats (Tensor): Feature tensor (B, T_feats, aux_channels).
            feats_lengths (Tensor): Feature length tensor (B,).
            singing (Tensor): Singing waveform tensor (B, T_wav).
            singing_lengths (Tensor): Singing length tensor (B,).
            label (Tensor): Label index tensor (B, T_text).
            label_lengths (Tensor): Label length tensor (B,).
            melody (Tensor): Melody index tensor (B, T_text).
            gt_dur (Tensor): Groundtruth duration tensor (B, T_text).
            score_dur (Tensor): Score duration tensor (B, T_text).
            slur (Tensor): Slur index tensor (B, T_text).
            pitch (FloatTensor): Batch of padded f0 (B, T_feats).
            ying (Optional[Tensor]): Yin pitch tensor (B, T_feats).
            sids (Optional[Tensor]): Speaker index tensor (B,) or (B, 1).
            spembs (Optional[Tensor]): Speaker embedding tensor (B, spk_embed_dim).
            lids (Optional[Tensor]): Language index tensor (B,) or (B, 1).

        Returns:
            Dict[str, Any]:
                * loss (Tensor): Loss scalar tensor.
                * stats (Dict[str, float]): Statistics to be monitored.
                * weight (Tensor): Weight tensor to summarize losses.
                * optim_idx (int): Optimizer index (0 for G and 1 for D).

        """
        # setup
        batch_size = text.size(0)
        feats = feats.transpose(1, 2)
        singing = singing.unsqueeze(1)

        # calculate generator outputs
        reuse_cache = True
        if not self.cache_generator_outputs or self._cache is None:
            reuse_cache = False
            outs = self.generator(
                text=text,
                text_lengths=text_lengths,
                feats=feats,
                feats_lengths=feats_lengths,
                label=label,
                label_lengths=label_lengths,
                melody=melody,
                gt_dur=gt_dur,
                score_dur=score_dur,
                slur=slur,
                pitch=pitch,
                ying=ying,
                sids=sids,
                spembs=spembs,
                lids=lids,
            )
        else:
            outs = self._cache

        # store cache
        if self.training and self.cache_generator_outputs and not reuse_cache:
            self._cache = outs

        # parse outputs
        if "visinger" in self.generator_type:
            singing_hat_, start_idxs, _, z_mask, outs_, *extra_outs = outs
            # The trailing outputs depend on the vocoder / generator variant.
            if (
                self.vocoder_generator_type == "visinger2"
                and self.generator_type == "visinger2"
            ):
                singing_hat_ddsp_, predict_mel = extra_outs
            elif self.vocoder_generator_type == "visinger2":
                singing_hat_ddsp_ = extra_outs[0]
            elif self.generator_type == "visinger2":
                predict_mel = extra_outs[0]
        elif "pisinger" in self.generator_type:
            if self.vocoder_generator_type == "visinger2":
                (
                    singing_hat_,
                    start_idxs,
                    _,
                    z_mask,
                    outs_,
                    singing_hat_ddsp_,
                    outs2_,
                ) = outs
            else:
                singing_hat_, start_idxs, _, z_mask, outs_, outs2_ = outs
            (
                yin_gt_crop,
                yin_gt_shifted_crop,
                yin_dec_crop,
                z_yin_crop_shifted,
                scope_shift,
            ) = outs2_

        # NOTE: gt_dur is rebound here to the value returned by the generator.
        (
            _,
            z_p,
            m_p,
            logs_p,
            m_q,
            logs_q,
            pred_pitch,
            gt_pitch,
            pred_dur,
            gt_dur,
            log_probs,
        ) = outs_

        # Crop the reference waveform to the segment the generator synthesized.
        singing_ = get_segments(
            x=singing,
            start_idxs=start_idxs * self.generator.upsample_factor,
            segment_size=self.generator.segment_size * self.generator.upsample_factor,
        )

        # calculate discriminator outputs
        if "avocodo" in self.discriminator_type:
            p, p_hat, fmaps_real, fmaps_fake = self.discriminator(
                singing_, singing_hat_
            )
        else:
            p_hat = self.discriminator(singing_hat_)
            with torch.no_grad():
                # do not store discriminator gradient in generator turn
                p = self.discriminator(singing_)

        # calculate losses
        with autocast(enabled=False):
            if "pisinger" in self.generator_type:
                yin_dec_loss = (
                    F.l1_loss(yin_gt_shifted_crop, yin_dec_crop) * self.lambda_c_yin
                )
                # TODO(yifeng): add yin shift loss later
                # loss_yin_shift = (
                #     F.l1_loss(torch.exp(-yin_gt_crop), torch.exp(-yin_hat_crop))
                #     * self.lambda_c_yin
                #     + F.l1_loss(
                #         torch.exp(-yin_hat_shifted),
                #         torch.exp(-(torch.chunk(yin_hat_crop, 2, dim=0)[1])),
                #     )
                #     * self.lambda_c_yin
                # )

            if self.use_avocodo:
                # With avocodo, singing_hat_ is a sequence; its last element is
                # compared against the reference waveform.
                mel_loss = self.mel_loss(singing_hat_[-1], singing_)
            elif self.vocoder_generator_type == "visinger2":
                mel_loss = self.mel_loss(singing_hat_, singing_)
                ddsp_mel_loss = self.mel_loss(singing_hat_ddsp_, singing_)
            else:
                mel_loss = self.mel_loss(singing_hat_, singing_)

            if self.use_flow:
                kl_loss = self.kl_loss(z_p, logs_q, m_p, logs_p, z_mask)
            else:
                kl_loss = self.kl_loss(m_q, logs_q, m_p, logs_p)

            adv_loss = self.generator_adv_loss(p_hat)
            if "avocodo" in self.discriminator_type:
                feat_match_loss = self.feat_match_loss(fmaps_fake, fmaps_real)
            else:
                feat_match_loss = self.feat_match_loss(p_hat, p)

            pitch_loss = self.mse_loss(pred_pitch, gt_pitch)
            phoneme_dur_loss = self.mse_loss(
                pred_dur[:, 0, :].squeeze(1), gt_dur.float()
            )
            score_dur_loss = self.mse_loss(pred_dur[:, 1, :].squeeze(1), gt_dur.float())

            if self.use_phoneme_predictor:
                ctc_loss = self.ctc_loss(log_probs, label, feats_lengths, label_lengths)

            # apply loss weights
            mel_loss = mel_loss * self.lambda_mel
            kl_loss = kl_loss * self.lambda_kl
            adv_loss = adv_loss * self.lambda_adv
            feat_match_loss = feat_match_loss * self.lambda_feat_match
            pitch_loss = pitch_loss * self.lambda_pitch
            phoneme_dur_loss = phoneme_dur_loss * self.lambda_dur
            score_dur_loss = score_dur_loss * self.lambda_dur
            if self.use_phoneme_predictor:
                ctc_loss = ctc_loss * self.lambda_phoneme

            # accumulate the total generator loss
            loss = mel_loss + kl_loss + adv_loss + feat_match_loss
            if self.vocoder_generator_type == "visinger2":
                ddsp_mel_loss = ddsp_mel_loss * self.lambda_mel
                loss = loss + ddsp_mel_loss
            if self.generator_type == "visinger2":
                # Only the first n_mels channels of feats are the mel target.
                feats = feats[:, : self.n_mels, :]
                loss_mel_am = self.mse_loss(feats * z_mask, predict_mel * z_mask)
                loss = loss + loss_mel_am
            loss = loss + pitch_loss
            loss = loss + phoneme_dur_loss
            loss = loss + score_dur_loss
            if self.use_phoneme_predictor:
                loss = loss + ctc_loss
            if "pisinger" in self.generator_type:
                loss = loss + yin_dec_loss

        stats = dict(
            generator_loss=loss.item(),
            generator_mel_loss=mel_loss.item(),
            generator_phn_dur_loss=phoneme_dur_loss.item(),
            generator_score_dur_loss=score_dur_loss.item(),
            generator_adv_loss=adv_loss.item(),
            generator_feat_match_loss=feat_match_loss.item(),
            generator_pitch_loss=pitch_loss.item(),
            generator_kl_loss=kl_loss.item(),
        )
        if self.use_phoneme_predictor:
            stats["generator_phoneme_loss"] = ctc_loss.item()
        if self.vocoder_generator_type == "visinger2":
            stats["generator_mel_ddsp_loss"] = ddsp_mel_loss.item()
        if self.generator_type == "visinger2":
            stats["generator_mel_am_loss"] = loss_mel_am.item()
        if "pisinger" in self.generator_type:
            stats["generator_yin_dec_loss"] = yin_dec_loss.item()

        loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)

        # reset cache
        if reuse_cache or not self.training:
            self._cache = None

        return {
            "loss": loss,
            "stats": stats,
            "weight": weight,
            "optim_idx": 0,  # needed for trainer
        }

    def _forward_discrminator(
        self,
        text: torch.Tensor,
        text_lengths: torch.Tensor,
        feats: torch.Tensor,
        feats_lengths: torch.Tensor,
        singing: torch.Tensor,
        singing_lengths: torch.Tensor,
        label: torch.Tensor = None,
        label_lengths: torch.Tensor = None,
        melody: torch.Tensor = None,
        gt_dur: torch.Tensor = None,
        score_dur: torch.Tensor = None,
        slur: torch.Tensor = None,
        pitch: torch.Tensor = None,
        ying: Optional[torch.Tensor] = None,
        sids: Optional[torch.Tensor] = None,
        spembs: Optional[torch.Tensor] = None,
        lids: Optional[torch.Tensor] = None,
    ) -> Dict[str, Any]:
        """Perform discriminator forward.

        Args:
            text (Tensor): Text index tensor (B, T_text).
            text_lengths (Tensor): Text length tensor (B,).
            feats (Tensor): Feature tensor (B, T_feats, aux_channels).
            feats_lengths (Tensor): Feature length tensor (B,).
            singing (Tensor): Singing waveform tensor (B, T_wav).
            singing_lengths (Tensor): Singing length tensor (B,).
            label (Tensor): Label index tensor (B, T_text).
            label_lengths (Tensor): Label length tensor (B,).
            melody (Tensor): Melody index tensor (B, T_text).
            gt_dur (Tensor): Groundtruth duration tensor (B, T_text).
            score_dur (Tensor): Score duration tensor (B, T_text).
            slur (Tensor): Slur index tensor (B, T_text).
            pitch (FloatTensor): Batch of padded f0 (B, T_feats).
            ying (Optional[Tensor]): Yin pitch tensor (B, T_feats).
            sids (Optional[Tensor]): Speaker index tensor (B,) or (B, 1).
            spembs (Optional[Tensor]): Speaker embedding tensor (B, spk_embed_dim).
            lids (Optional[Tensor]): Language index tensor (B,) or (B, 1).

        Returns:
            Dict[str, Any]:
                * loss (Tensor): Loss scalar tensor.
                * stats (Dict[str, float]): Statistics to be monitored.
                * weight (Tensor): Weight tensor to summarize losses.
                * optim_idx (int): Optimizer index (0 for G and 1 for D).

        """
        # setup
        batch_size = text.size(0)
        feats = feats.transpose(1, 2)
        singing = singing.unsqueeze(1)

        # calculate generator outputs
        reuse_cache = True
        if not self.cache_generator_outputs or self._cache is None:
            reuse_cache = False
            outs = self.generator(
                text=text,
                text_lengths=text_lengths,
                feats=feats,
                feats_lengths=feats_lengths,
                gt_dur=gt_dur,
                label=label,
                label_lengths=label_lengths,
                melody=melody,
                score_dur=score_dur,
                slur=slur,
                pitch=pitch,
                ying=ying,
                sids=sids,
                spembs=spembs,
                lids=lids,
            )
        else:
            outs = self._cache

        # store cache
        if self.cache_generator_outputs and not reuse_cache:
            self._cache = outs

        # parse outputs
        # remove dp loss
        singing_hat_, start_idxs, *_ = outs
        singing_ = get_segments(
            x=singing,
            start_idxs=start_idxs * self.generator.upsample_factor,
            segment_size=self.generator.segment_size * self.generator.upsample_factor,
        )

        # calculate discriminator outputs
        # The generated audio is detached so that the discriminator loss does
        # not back-propagate into the generator.
        if "avocodo" in self.discriminator_type:
            detached_singing_hat_ = [x.detach() for x in singing_hat_]
            p, p_hat, fmaps_real, fmaps_fake = self.discriminator(
                singing_, detached_singing_hat_
            )
        else:
            p_hat = self.discriminator(singing_hat_.detach())
            p = self.discriminator(singing_)

        # calculate losses
        with autocast(enabled=False):
            real_loss, fake_loss = self.discriminator_adv_loss(p_hat, p)
            loss = real_loss + fake_loss

        stats = dict(
            discriminator_loss=loss.item(),
            discriminator_real_loss=real_loss.item(),
            discriminator_fake_loss=fake_loss.item(),
        )
        loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)

        # reset cache
        if reuse_cache or not self.training:
            self._cache = None

        return {
            "loss": loss,
            "stats": stats,
            "weight": weight,
            "optim_idx": 1,  # needed for trainer
        }

    def inference(
        self,
        text: torch.Tensor,
        feats: Optional[torch.Tensor] = None,
        ssl_feats: Optional[torch.Tensor] = None,
        label: Optional[Dict[str, torch.Tensor]] = None,
        melody: Optional[Dict[str, torch.Tensor]] = None,
        pitch: Optional[torch.Tensor] = None,
        duration: Optional[Dict[str, torch.Tensor]] = None,
        slur: Optional[Dict[str, torch.Tensor]] = None,
        spembs: Optional[torch.Tensor] = None,
        sids: Optional[torch.Tensor] = None,
        lids: Optional[torch.Tensor] = None,
        noise_scale: float = 0.667,
        noise_scale_dur: float = 0.8,
        alpha: float = 1.0,
        max_len: Optional[int] = None,
        use_teacher_forcing: bool = False,
    ) -> Dict[str, torch.Tensor]:
        """Run inference.

        Args:
            text (Tensor): Input text index tensor (T_text,).
            feats (Tensor): Feature tensor (T_feats, aux_channels).
                Required when ``use_teacher_forcing`` is True.
            ssl_feats (Tensor): SSL Feature tensor (T_feats, hubert_channels).
                Only used with teacher forcing, where it is truncated /
                zero-padded to the length of ``feats`` and concatenated to it.
            label (Optional[Dict]): key is "lab" or "score";
                value (LongTensor): Batch of padded label ids (B, T_text).
            melody (Optional[Dict]): key is "lab" or "score";
                value (LongTensor): Batch of padded melody (B, T_text).
            pitch (FloatTensor): Batch of padded f0 (B, T_feats).
                Required when ``use_teacher_forcing`` is True.
            duration (Optional[Dict]): key is "lab", "score_phn" or "score_syb";
                value (LongTensor): Batch of padded duration (B, T_text).
            slur (LongTensor): Batch of padded slur (B, T_text).
            spembs (Optional[Tensor]): Speaker embedding tensor (spk_embed_dim,).
            sids (Tensor): Speaker index tensor (1,).
            lids (Tensor): Language index tensor (1,).
            noise_scale (float): Noise scale value for flow.
            noise_scale_dur (float): Noise scale value for duration predictor.
            alpha (float): Alpha parameter to control the speed of generated singing.
            max_len (Optional[int]): Maximum length.
            use_teacher_forcing (bool): Whether to use teacher forcing.

        Returns:
            Dict[str, Tensor]:
                * wav (Tensor): Generated waveform tensor (T_wav,).

        """
        # setup
        label = label["lab"]
        melody = melody["lab"]
        score_dur = duration["score_syb"]
        gt_dur = duration["lab"]
        text = text[None]
        text_lengths = torch.tensor(
            [text.size(1)],
            dtype=torch.long,
            device=text.device,
        )
        label_lengths = torch.tensor(
            [label.size(1)],
            dtype=torch.long,
            device=text.device,
        )
        if sids is not None:
            sids = sids.view(1)
        if lids is not None:
            lids = lids.view(1)

        # inference
        if use_teacher_forcing:
            assert feats is not None
            assert pitch is not None
            if ssl_feats is not None:
                # feats / ssl_feats are unbatched here, so time is dimension 0.
                ssl_feats = self._match_ssl_feats_length(
                    ssl_feats, feats.shape[0], time_dim=0
                )
                feats = torch.cat([feats, ssl_feats], dim=1)
            feats = feats[None].transpose(1, 2)
            feats_lengths = torch.tensor(
                [feats.size(2)],
                dtype=torch.long,
                device=feats.device,
            )
            wav = self.generator.inference(
                text=text,
                text_lengths=text_lengths,
                feats=feats,
                feats_lengths=feats_lengths,
                label=label,
                label_lengths=label_lengths,
                melody=melody,
                score_dur=score_dur,
                slur=slur,
                gt_dur=gt_dur,
                pitch=pitch,
                sids=sids,
                spembs=spembs,
                lids=lids,
                noise_scale=noise_scale,
                noise_scale_dur=noise_scale_dur,
                alpha=alpha,
                max_len=max_len,
                use_teacher_forcing=use_teacher_forcing,
            )
        else:
            wav = self.generator.inference(
                text=text,
                text_lengths=text_lengths,
                label=label,
                label_lengths=label_lengths,
                melody=melody,
                score_dur=score_dur,
                slur=slur,
                sids=sids,
                spembs=spembs,
                lids=lids,
                noise_scale=noise_scale,
                noise_scale_dur=noise_scale_dur,
                alpha=alpha,
                max_len=max_len,
            )
        return dict(wav=wav.view(-1))