Source code for espnet2.gan_tts.espnet_model

# Copyright 2021 Tomoki Hayashi
#  Apache 2.0  (http://www.apache.org/licenses/LICENSE-2.0)

"""GAN-based text-to-speech ESPnet model."""

from contextlib import contextmanager
from typing import Any, Dict, Optional

import torch
from packaging.version import parse as V
from typeguard import typechecked

from espnet2.gan_tts.abs_gan_tts import AbsGANTTS
from espnet2.layers.abs_normalize import AbsNormalize
from espnet2.layers.inversible_interface import InversibleInterface
from espnet2.train.abs_gan_espnet_model import AbsGANESPnetModel
from espnet2.tts.feats_extract.abs_feats_extract import AbsFeatsExtract

if V(torch.__version__) >= V("1.6.0"):
    from torch.cuda.amp import autocast
else:
    # Nothing to do if torch < 1.6.0
    @contextmanager
    def autocast(enabled=True):  # NOQA
        yield


class ESPnetGANTTSModel(AbsGANESPnetModel):
    """ESPnet model for GAN-based text-to-speech task."""

    @typechecked
    def __init__(
        self,
        feats_extract: Optional[AbsFeatsExtract],
        normalize: Optional[AbsNormalize and InversibleInterface],
        pitch_extract: Optional[AbsFeatsExtract],
        pitch_normalize: Optional[AbsNormalize and InversibleInterface],
        energy_extract: Optional[AbsFeatsExtract],
        energy_normalize: Optional[AbsNormalize and InversibleInterface],
        tts: AbsGANTTS,
    ):
        """Initialize ESPnetGANTTSModel module."""
        super().__init__()
        self.feats_extract = feats_extract
        self.normalize = normalize
        self.pitch_extract = pitch_extract
        self.pitch_normalize = pitch_normalize
        self.energy_extract = energy_extract
        self.energy_normalize = energy_normalize
        self.tts = tts
        assert hasattr(
            tts, "generator"
        ), "generator module must be registered as tts.generator"
        assert hasattr(
            tts, "discriminator"
        ), "discriminator module must be registered as tts.discriminator"
        if feats_extract is not None:
            if hasattr(tts.generator, "vocoder"):
                upsample_factor = tts.generator["vocoder"].upsample_factor
            else:
                upsample_factor = tts.generator.upsample_factor
            assert (
                feats_extract.get_parameters()["n_shift"] == upsample_factor
            ), "n_shift must be equal to upsample_factor"
    def forward(
        self,
        text: torch.Tensor,
        text_lengths: torch.Tensor,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        durations: Optional[torch.Tensor] = None,
        durations_lengths: Optional[torch.Tensor] = None,
        pitch: Optional[torch.Tensor] = None,
        pitch_lengths: Optional[torch.Tensor] = None,
        energy: Optional[torch.Tensor] = None,
        energy_lengths: Optional[torch.Tensor] = None,
        spembs: Optional[torch.Tensor] = None,
        sids: Optional[torch.Tensor] = None,
        lids: Optional[torch.Tensor] = None,
        forward_generator: bool = True,
        **kwargs,
    ) -> Dict[str, Any]:
        """Return generator or discriminator loss with dict format.

        Args:
            text (Tensor): Text index tensor (B, T_text).
            text_lengths (Tensor): Text length tensor (B,).
            speech (Tensor): Speech waveform tensor (B, T_wav).
            speech_lengths (Tensor): Speech length tensor (B,).
            durations (Optional[Tensor]): Duration tensor.
            durations_lengths (Optional[Tensor]): Duration length tensor (B,).
            pitch (Optional[Tensor]): Pitch tensor.
            pitch_lengths (Optional[Tensor]): Pitch length tensor (B,).
            energy (Optional[Tensor]): Energy tensor.
            energy_lengths (Optional[Tensor]): Energy length tensor (B,).
            spembs (Optional[Tensor]): Speaker embedding tensor (B, D).
            sids (Optional[Tensor]): Speaker ID tensor (B, 1).
            lids (Optional[Tensor]): Language ID tensor (B, 1).
            forward_generator (bool): Whether to forward generator.
            kwargs: "utt_id" is among the input.

        Returns:
            Dict[str, Any]:
                - loss (Tensor): Loss scalar tensor.
                - stats (Dict[str, float]): Statistics to be monitored.
                - weight (Tensor): Weight tensor to summarize losses.
                - optim_idx (int): Optimizer index (0 for G and 1 for D).

        """
        with autocast(False):
            # Extract features
            feats = None
            if self.feats_extract is not None:
                feats, feats_lengths = self.feats_extract(
                    speech,
                    speech_lengths,
                )
            if self.pitch_extract is not None and pitch is None:
                pitch, pitch_lengths = self.pitch_extract(
                    speech,
                    speech_lengths,
                    feats_lengths=feats_lengths,
                    durations=durations,
                    durations_lengths=durations_lengths,
                )
            if self.energy_extract is not None and energy is None:
                energy, energy_lengths = self.energy_extract(
                    speech,
                    speech_lengths,
                    feats_lengths=feats_lengths,
                    durations=durations,
                    durations_lengths=durations_lengths,
                )

            # Normalize
            if self.normalize is not None:
                feats, feats_lengths = self.normalize(feats, feats_lengths)
            if self.pitch_normalize is not None:
                pitch, pitch_lengths = self.pitch_normalize(pitch, pitch_lengths)
            if self.energy_normalize is not None:
                energy, energy_lengths = self.energy_normalize(energy, energy_lengths)

        # Make batch for tts inputs
        batch = dict(
            text=text,
            text_lengths=text_lengths,
            forward_generator=forward_generator,
        )

        # Update batch for additional auxiliary inputs
        if feats is not None:
            batch.update(feats=feats, feats_lengths=feats_lengths)
        if self.tts.require_raw_speech:
            batch.update(speech=speech, speech_lengths=speech_lengths)
        if durations is not None:
            batch.update(durations=durations, durations_lengths=durations_lengths)
        if self.pitch_extract is not None and pitch is not None:
            batch.update(pitch=pitch, pitch_lengths=pitch_lengths)
        if self.energy_extract is not None and energy is not None:
            batch.update(energy=energy, energy_lengths=energy_lengths)
        if spembs is not None:
            batch.update(spembs=spembs)
        if sids is not None:
            batch.update(sids=sids)
        if lids is not None:
            batch.update(lids=lids)

        return self.tts(**batch)
    def collect_feats(
        self,
        text: torch.Tensor,
        text_lengths: torch.Tensor,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        durations: Optional[torch.Tensor] = None,
        durations_lengths: Optional[torch.Tensor] = None,
        pitch: Optional[torch.Tensor] = None,
        pitch_lengths: Optional[torch.Tensor] = None,
        energy: Optional[torch.Tensor] = None,
        energy_lengths: Optional[torch.Tensor] = None,
        spembs: Optional[torch.Tensor] = None,
        sids: Optional[torch.Tensor] = None,
        lids: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> Dict[str, torch.Tensor]:
        """Calculate features and return them as a dict.

        Args:
            text (Tensor): Text index tensor (B, T_text).
            text_lengths (Tensor): Text length tensor (B,).
            speech (Tensor): Speech waveform tensor (B, T_wav).
            speech_lengths (Tensor): Speech length tensor (B,).
            durations (Optional[Tensor]): Duration tensor.
            durations_lengths (Optional[Tensor]): Duration length tensor (B,).
            pitch (Optional[Tensor]): Pitch tensor.
            pitch_lengths (Optional[Tensor]): Pitch length tensor (B,).
            energy (Optional[Tensor]): Energy tensor.
            energy_lengths (Optional[Tensor]): Energy length tensor (B,).
            spembs (Optional[Tensor]): Speaker embedding tensor (B, D).
            sids (Optional[Tensor]): Speaker index tensor (B, 1).
            lids (Optional[Tensor]): Language ID tensor (B, 1).

        Returns:
            Dict[str, Tensor]: Dict of features.

        """
        feats = None
        if self.feats_extract is not None:
            feats, feats_lengths = self.feats_extract(
                speech,
                speech_lengths,
            )
        if self.pitch_extract is not None:
            pitch, pitch_lengths = self.pitch_extract(
                speech,
                speech_lengths,
                feats_lengths=feats_lengths,
                durations=durations,
                durations_lengths=durations_lengths,
            )
        if self.energy_extract is not None:
            energy, energy_lengths = self.energy_extract(
                speech,
                speech_lengths,
                feats_lengths=feats_lengths,
                durations=durations,
                durations_lengths=durations_lengths,
            )

        # Store in dict
        feats_dict = {}
        if feats is not None:
            feats_dict.update(feats=feats, feats_lengths=feats_lengths)
        if pitch is not None:
            feats_dict.update(pitch=pitch, pitch_lengths=pitch_lengths)
        if energy is not None:
            feats_dict.update(energy=energy, energy_lengths=energy_lengths)

        return feats_dict
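The docstring of forward() describes the contract used for alternating GAN updates: the same batch is forwarded once with forward_generator=True and once with False, and the returned "optim_idx" (0 for the generator, 1 for the discriminator) selects which optimizer to step. The following is a minimal sketch of such a loop; it is not part of this module, and model, batch, and the two optimizer objects are hypothetical placeholders (in ESPnet this loop is driven by the GAN trainer, not by user code).

# Hypothetical training-step sketch (assumed names: model, batch,
# generator_optimizer, discriminator_optimizer).
optimizers = [generator_optimizer, discriminator_optimizer]

for forward_generator in (True, False):
    # forward() returns {"loss", "stats", "weight", "optim_idx"}.
    output = model(**batch, forward_generator=forward_generator)
    loss = output["loss"]
    optim_idx = output["optim_idx"]  # 0 -> generator, 1 -> discriminator

    optimizers[optim_idx].zero_grad()
    loss.backward()
    optimizers[optim_idx].step()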