Source code for espnet2.gan_tts.hifigan.loss

# Copyright 2021 Tomoki Hayashi
#  Apache 2.0  (http://www.apache.org/licenses/LICENSE-2.0)

"""HiFiGAN-related loss modules.

This code is modified from https://github.com/kan-bayashi/ParallelWaveGAN.

"""

from typing import List, Optional, Tuple, Union

import torch
import torch.nn.functional as F

from espnet2.tts.feats_extract.log_mel_fbank import LogMelFbank


class GeneratorAdversarialLoss(torch.nn.Module):
    """Generator adversarial loss module."""

    def __init__(
        self,
        average_by_discriminators: bool = True,
        loss_type: str = "mse",
    ):
        """Initialize GeneratorAdversarialLoss module.

        Args:
            average_by_discriminators (bool): Whether to average the loss by
                the number of discriminators.
            loss_type (str): Loss type, "mse" or "hinge".

        """
        super().__init__()
        self.average_by_discriminators = average_by_discriminators
        assert loss_type in ["mse", "hinge"], f"{loss_type} is not supported."
        if loss_type == "mse":
            self.criterion = self._mse_loss
        else:
            self.criterion = self._hinge_loss

    def forward(
        self,
        outputs: Union[List[List[torch.Tensor]], List[torch.Tensor], torch.Tensor],
    ) -> torch.Tensor:
        """Calculate generator adversarial loss.

        Args:
            outputs (Union[List[List[Tensor]], List[Tensor], Tensor]):
                Discriminator outputs, list of discriminator outputs, or list
                of list of discriminator outputs.

        Returns:
            Tensor: Generator adversarial loss value. An empty list yields the
                initial accumulator value ``0.0`` regardless of the averaging
                setting.

        """
        if isinstance(outputs, (tuple, list)):
            adv_loss = 0.0
            # Count the processed discriminators explicitly so that an empty
            # list does not leave the divisor undefined after the loop.
            num_outputs = 0
            for num_outputs, outputs_ in enumerate(outputs, start=1):
                if isinstance(outputs_, (tuple, list)):
                    # NOTE(kan-bayashi): case including feature maps
                    outputs_ = outputs_[-1]
                adv_loss += self.criterion(outputs_)
            if self.average_by_discriminators and num_outputs > 0:
                adv_loss /= num_outputs
        else:
            adv_loss = self.criterion(outputs)

        return adv_loss

    def _mse_loss(self, x: torch.Tensor) -> torch.Tensor:
        """Return the MSE between ``x`` and an all-ones target."""
        return F.mse_loss(x, x.new_ones(x.size()))

    def _hinge_loss(self, x: torch.Tensor) -> torch.Tensor:
        """Return the negated mean of ``x``."""
        return -x.mean()
class DiscriminatorAdversarialLoss(torch.nn.Module):
    """Discriminator adversarial loss module."""

    def __init__(
        self,
        average_by_discriminators: bool = True,
        loss_type: str = "mse",
    ):
        """Initialize DiscriminatorAdversarialLoss module.

        Args:
            average_by_discriminators (bool): Whether to average the loss by
                the number of discriminators.
            loss_type (str): Loss type, "mse" or "hinge".

        """
        super().__init__()
        self.average_by_discriminators = average_by_discriminators
        assert loss_type in ["mse", "hinge"], f"{loss_type} is not supported."
        if loss_type == "mse":
            self.fake_criterion = self._mse_fake_loss
            self.real_criterion = self._mse_real_loss
        else:
            self.fake_criterion = self._hinge_fake_loss
            self.real_criterion = self._hinge_real_loss

    def forward(
        self,
        outputs_hat: Union[List[List[torch.Tensor]], List[torch.Tensor], torch.Tensor],
        outputs: Union[List[List[torch.Tensor]], List[torch.Tensor], torch.Tensor],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Calculate discriminator adversarial loss.

        Args:
            outputs_hat (Union[List[List[Tensor]], List[Tensor], Tensor]):
                Discriminator outputs, list of discriminator outputs, or list
                of list of discriminator outputs calculated from generator.
            outputs (Union[List[List[Tensor]], List[Tensor], Tensor]):
                Discriminator outputs, list of discriminator outputs, or list
                of list of discriminator outputs calculated from groundtruth.

        Returns:
            Tensor: Discriminator real loss value.
            Tensor: Discriminator fake loss value.

            Empty lists yield the initial accumulator values ``(0.0, 0.0)``
            regardless of the averaging setting.

        """
        if isinstance(outputs, (tuple, list)):
            real_loss = 0.0
            fake_loss = 0.0
            # Count the processed pairs explicitly so that empty inputs do not
            # leave the divisor undefined after the loop.
            num_outputs = 0
            for num_outputs, (outputs_hat_, outputs_) in enumerate(
                zip(outputs_hat, outputs), start=1
            ):
                if isinstance(outputs_hat_, (tuple, list)):
                    # NOTE(kan-bayashi): case including feature maps
                    outputs_hat_ = outputs_hat_[-1]
                    outputs_ = outputs_[-1]
                real_loss += self.real_criterion(outputs_)
                fake_loss += self.fake_criterion(outputs_hat_)
            if self.average_by_discriminators and num_outputs > 0:
                fake_loss /= num_outputs
                real_loss /= num_outputs
        else:
            real_loss = self.real_criterion(outputs)
            fake_loss = self.fake_criterion(outputs_hat)

        return real_loss, fake_loss

    def _mse_real_loss(self, x: torch.Tensor) -> torch.Tensor:
        """Return the MSE between ``x`` and an all-ones target."""
        return F.mse_loss(x, x.new_ones(x.size()))

    def _mse_fake_loss(self, x: torch.Tensor) -> torch.Tensor:
        """Return the MSE between ``x`` and an all-zeros target."""
        return F.mse_loss(x, x.new_zeros(x.size()))

    def _hinge_real_loss(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``-mean(min(x - 1, 0))``."""
        return -torch.mean(torch.min(x - 1, x.new_zeros(x.size())))

    def _hinge_fake_loss(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``-mean(min(-x - 1, 0))``."""
        return -torch.mean(torch.min(-x - 1, x.new_zeros(x.size())))
class FeatureMatchLoss(torch.nn.Module):
    """Feature matching loss module."""

    def __init__(
        self,
        average_by_layers: bool = True,
        average_by_discriminators: bool = True,
        include_final_outputs: bool = False,
    ):
        """Initialize FeatureMatchLoss module.

        Args:
            average_by_layers (bool): Whether to average the loss by the number
                of layers.
            average_by_discriminators (bool): Whether to average the loss by
                the number of discriminators.
            include_final_outputs (bool): Whether to include the final output of
                each discriminator for loss calculation.

        """
        super().__init__()
        self.average_by_layers = average_by_layers
        self.average_by_discriminators = average_by_discriminators
        self.include_final_outputs = include_final_outputs

    def forward(
        self,
        feats_hat: Union[List[List[torch.Tensor]], List[torch.Tensor]],
        feats: Union[List[List[torch.Tensor]], List[torch.Tensor]],
    ) -> torch.Tensor:
        """Calculate feature matching loss.

        Args:
            feats_hat (Union[List[List[Tensor]], List[Tensor]]): List of list of
                discriminator outputs or list of discriminator outputs
                calculated from generator's outputs.
            feats (Union[List[List[Tensor]], List[Tensor]]): List of list of
                discriminator outputs or list of discriminator outputs
                calculated from groundtruth.

        Returns:
            Tensor: Feature matching loss value. A discriminator that
                contributes no layers adds ``0.0``, and empty inputs yield the
                initial accumulator value ``0.0``.

        """
        feat_match_loss = 0.0
        # Both counters are tracked explicitly (instead of reusing the loop
        # index after the loop) so that an empty iteration neither raises nor
        # reuses a stale value from a previous discriminator.
        num_discriminators = 0
        for num_discriminators, (feats_hat_, feats_) in enumerate(
            zip(feats_hat, feats), start=1
        ):
            feat_match_loss_ = 0.0
            if not self.include_final_outputs:
                feats_hat_ = feats_hat_[:-1]
                feats_ = feats_[:-1]
            num_layers = 0
            for num_layers, (feat_hat_, feat_) in enumerate(
                zip(feats_hat_, feats_), start=1
            ):
                # The groundtruth features are detached so that they act as a
                # fixed target for the L1 loss.
                feat_match_loss_ += F.l1_loss(feat_hat_, feat_.detach())
            if self.average_by_layers and num_layers > 0:
                feat_match_loss_ /= num_layers
            feat_match_loss += feat_match_loss_
        if self.average_by_discriminators and num_discriminators > 0:
            feat_match_loss /= num_discriminators

        return feat_match_loss
class MelSpectrogramLoss(torch.nn.Module):
    """L1 loss between log-Mel spectrograms of generated and reference audio."""

    def __init__(
        self,
        fs: int = 22050,
        n_fft: int = 1024,
        hop_length: int = 256,
        win_length: Optional[int] = None,
        window: str = "hann",
        n_mels: int = 80,
        fmin: Optional[int] = 0,
        fmax: Optional[int] = None,
        center: bool = True,
        normalized: bool = False,
        onesided: bool = True,
        log_base: Optional[float] = 10.0,
    ):
        """Build the log-Mel feature extractor used by the loss.

        All arguments are forwarded unchanged to ``LogMelFbank``.

        Args:
            fs (int): Sampling rate.
            n_fft (int): FFT points.
            hop_length (int): Hop length.
            win_length (Optional[int]): Window length.
            window (str): Window type.
            n_mels (int): Number of Mel basis.
            fmin (Optional[int]): Minimum frequency for Mel.
            fmax (Optional[int]): Maximum frequency for Mel.
            center (bool): Whether to use center window.
            normalized (bool): Whether to use normalized one.
            onesided (bool): Whether to use onesided one.
            log_base (Optional[float]): Log base value.

        """
        super().__init__()
        fbank_options = dict(
            fs=fs,
            n_fft=n_fft,
            hop_length=hop_length,
            win_length=win_length,
            window=window,
            n_mels=n_mels,
            fmin=fmin,
            fmax=fmax,
            center=center,
            normalized=normalized,
            onesided=onesided,
            log_base=log_base,
        )
        self.wav_to_mel = LogMelFbank(**fbank_options)

    def forward(
        self,
        y_hat: torch.Tensor,
        y: torch.Tensor,
        spec: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Compute the Mel-spectrogram loss.

        Args:
            y_hat (Tensor): Generated waveform tensor (B, 1, T).
            y (Tensor): Groundtruth waveform tensor (B, 1, T).
            spec (Optional[Tensor]): Groundtruth linear amplitude spectrum
                tensor (B, T, n_fft // 2 + 1). When given, the reference Mel is
                derived from it and ``y`` is not used.

        Returns:
            Tensor: Mel-spectrogram loss value.

        """
        # Only the first element of each extractor result is needed.
        predicted_mel, _ = self.wav_to_mel(y_hat.squeeze(1))
        if spec is not None:
            target_mel, _ = self.wav_to_mel.logmel(spec)
        else:
            target_mel, _ = self.wav_to_mel(y.squeeze(1))
        return F.l1_loss(predicted_mel, target_mel)