# Source code for espnet2.tasks.gan_svs

# Copyright 2021 Tomoki Hayashi
# Copyright 2022 Yifeng Yu
#  Apache 2.0  (http://www.apache.org/licenses/LICENSE-2.0)

"""GAN-based Singing-voice-synthesis task."""

import argparse
import logging
from typing import Callable, Collection, Dict, List, Optional, Tuple

import numpy as np
import torch
from typeguard import typechecked

from espnet2.asr.frontend.abs_frontend import AbsFrontend
from espnet2.gan_svs.abs_gan_svs import AbsGANSVS
from espnet2.gan_svs.espnet_model import ESPnetGANSVSModel
from espnet2.gan_svs.joint import JointScore2Wav
from espnet2.gan_svs.post_frontend.fused import FusedPostFrontends
from espnet2.gan_svs.post_frontend.s3prl import S3prlPostFrontend
from espnet2.gan_svs.vits import VITS
from espnet2.layers.abs_normalize import AbsNormalize
from espnet2.layers.global_mvn import GlobalMVN
from espnet2.layers.utterance_mvn import UtteranceMVN
from espnet2.svs.feats_extract.score_feats_extract import (
    FrameScoreFeats,
    SyllableScoreFeats,
)
from espnet2.tasks.abs_task import AbsTask, optim_classes
from espnet2.text.phoneme_tokenizer import g2p_choices
from espnet2.train.class_choices import ClassChoices
from espnet2.train.collate_fn import CommonCollateFn
from espnet2.train.gan_trainer import GANTrainer
from espnet2.train.preprocessor import SVSPreprocessor
from espnet2.tts.feats_extract.abs_feats_extract import AbsFeatsExtract
from espnet2.tts.feats_extract.dio import Dio
from espnet2.tts.feats_extract.energy import Energy
from espnet2.tts.feats_extract.linear_spectrogram import LinearSpectrogram
from espnet2.tts.feats_extract.log_mel_fbank import LogMelFbank
from espnet2.tts.feats_extract.log_spectrogram import LogSpectrogram
from espnet2.tts.feats_extract.ying import Ying
from espnet2.utils.get_default_kwargs import get_default_kwargs
from espnet2.utils.nested_dict_action import NestedDictAction
from espnet2.utils.types import int_or_none, str2bool, str_or_none

# Component registries: each ClassChoices maps a CLI option value to a class
# and is later expanded into ``--<name>`` / ``--<name>_conf`` arguments.

# Optional SSL-based post-frontend.
postfrontend_choices = ClassChoices(
    name="postfrontend",
    classes={
        "s3prl": S3prlPostFrontend,
        "fused": FusedPostFrontends,
    },
    type_check=AbsFrontend,
    default=None,
)

# Acoustic feature extractors (target features for the generator).
feats_extractor_choices = ClassChoices(
    name="feats_extract",
    classes={
        "fbank": LogMelFbank,
        "log_spectrogram": LogSpectrogram,
        "linear_spectrogram": LinearSpectrogram,
    },
    type_check=AbsFeatsExtract,
    default="linear_spectrogram",
)

# Musical-score feature extractors (frame- or syllable-level).
score_feats_extractor_choices = ClassChoices(
    name="score_feats_extract",
    classes={
        "frame_score_feats": FrameScoreFeats,
        "syllable_score_feats": SyllableScoreFeats,
    },
    type_check=AbsFeatsExtract,
    default="frame_score_feats",
)

# Optional pitch (F0) extractor.
pitch_extractor_choices = ClassChoices(
    name="pitch_extract",
    classes={"dio": Dio},
    type_check=AbsFeatsExtract,
    default=None,
    optional=True,
)

# Optional YIN-based pitch feature extractor.
ying_extractor_choices = ClassChoices(
    name="ying_extract",
    classes={"ying": Ying},
    type_check=AbsFeatsExtract,
    default=None,
    optional=True,
)

# Optional energy extractor.
energy_extractor_choices = ClassChoices(
    name="energy_extract",
    classes={"energy": Energy},
    type_check=AbsFeatsExtract,
    default=None,
    optional=True,
)

# Optional normalization layers for features, pitch, and energy.
normalize_choices = ClassChoices(
    name="normalize",
    classes={
        "global_mvn": GlobalMVN,
        "utterance_mvn": UtteranceMVN,
    },
    type_check=AbsNormalize,
    default=None,
    optional=True,
)
pitch_normalize_choices = ClassChoices(
    name="pitch_normalize",
    classes={
        "global_mvn": GlobalMVN,
        "utterance_mvn": UtteranceMVN,
    },
    type_check=AbsNormalize,
    default=None,
    optional=True,
)
energy_normalize_choices = ClassChoices(
    name="energy_normalize",
    classes={
        "global_mvn": GlobalMVN,
        "utterance_mvn": UtteranceMVN,
    },
    type_check=AbsNormalize,
    default=None,
    optional=True,
)

# GAN-based SVS model architectures.
svs_choices = ClassChoices(
    name="svs",
    classes={
        "vits": VITS,
        "joint_score2wav": JointScore2Wav,
    },
    type_check=AbsGANSVS,
    default="vits",
)


class GANSVSTask(AbsTask):
    """GAN-based Singing-voice-synthesis task."""

    # GAN training alternates generator/discriminator updates,
    # so two optimizers are required.
    num_optimizers: int = 2

    # Configurable components; AbsTask turns each entry into a
    # --<name> / --<name>_conf CLI option pair.
    class_choices_list = [
        postfrontend_choices,
        score_feats_extractor_choices,
        feats_extractor_choices,
        normalize_choices,
        svs_choices,
        pitch_extractor_choices,
        pitch_normalize_choices,
        ying_extractor_choices,
        energy_extractor_choices,
        energy_normalize_choices,
    ]

    # GANTrainer (not the plain Trainer) drives the alternating
    # generator/discriminator optimization steps.
    trainer = GANTrainer
[docs] @classmethod @typechecked def add_task_arguments(cls, parser: argparse.ArgumentParser): # NOTE(kamo): Use '_' instead of '-' to avoid confusion group = parser.add_argument_group(description="Task related") # NOTE(kamo): add_arguments(..., required=True) can't be used # to provide --print_config mode. Instead of it, do as required = parser.get_default("required") required += ["token_list"] group.add_argument( "--input_size", type=int_or_none, default=None, help="The number of input dimension of the feature", ) group.add_argument( "--token_list", type=str_or_none, default=None, help="A text mapping int-id to token", ) group.add_argument( "--odim", type=int_or_none, default=None, help="The number of dimension of output feature", ) group.add_argument( "--model_conf", action=NestedDictAction, default=get_default_kwargs(ESPnetGANSVSModel), help="The keyword arguments for model class.", ) group = parser.add_argument_group(description="Preprocess related") group.add_argument( "--use_preprocessor", type=str2bool, default=True, help="Apply preprocessing to data or not", ) group.add_argument( "--token_type", type=str, default="phn", choices=["bpe", "char", "word", "phn"], help="The text will be tokenized in the specified level token", ) group.add_argument( "--bpemodel", type=str_or_none, default=None, help="The model file of sentencepiece", ) parser.add_argument( "--non_linguistic_symbols", type=str_or_none, help="non_linguistic_symbols file path", ) parser.add_argument( "--cleaner", type=str_or_none, choices=[None, "tacotron", "jaconv", "vietnamese", "korean_cleaner"], default=None, help="Apply text cleaning", ) parser.add_argument( "--g2p", type=str_or_none, choices=g2p_choices, default=None, help="Specify g2p method if --token_type=phn", ) parser.add_argument( "--fs", type=int, default=24000, # BUG: another fs in feats_extract_conf help="sample rate", ) for class_choices in cls.class_choices_list: # Append --<name> and --<name>_conf. # e.g. 
--encoder and --encoder_conf class_choices.add_arguments(group)
[docs] @classmethod @typechecked def build_collate_fn(cls, args: argparse.Namespace, train: bool) -> Callable[ [Collection[Tuple[str, Dict[str, np.ndarray]]]], Tuple[List[str], Dict[str, torch.Tensor]], ]: return CommonCollateFn( float_pad_value=0.0, int_pad_value=0, not_sequence=["spembs", "sids", "lids"], )
    @classmethod
    @typechecked
    def build_preprocess_fn(
        cls, args: argparse.Namespace, train: bool
    ) -> Optional[Callable[[str, Dict[str, np.ndarray]], Dict[str, np.ndarray]]]:
        """Return the per-example preprocessing callable, or ``None``.

        When ``--use_preprocessor`` is true, builds an ``SVSPreprocessor``
        configured from the parsed CLI arguments (tokenization, text
        cleaning, g2p, sample rate, hop length); otherwise no
        preprocessing is applied.
        """
        if args.use_preprocessor:
            # NOTE(review): assumes args.feats_extract_conf is a dict that
            # always contains "hop_length" — confirm for configs taking the
            # --odim path, where build_model clears feats_extract_conf.
            retval = SVSPreprocessor(
                train=train,
                token_type=args.token_type,
                token_list=args.token_list,
                bpemodel=args.bpemodel,
                non_linguistic_symbols=args.non_linguistic_symbols,
                text_cleaner=args.cleaner,
                g2p_type=args.g2p,
                fs=args.fs,
                hop_length=args.feats_extract_conf["hop_length"],
            )
        else:
            retval = None
        return retval
# TODO(Yuning): check new names
[docs] @classmethod def required_data_names( cls, train: bool = True, inference: bool = False ) -> Tuple[str, ...]: if not inference: retval = ("text", "singing", "score", "label") else: # Inference mode retval = ("text", "score", "label") return retval
[docs] @classmethod def optional_data_names( cls, train: bool = True, inference: bool = False ) -> Tuple[str, ...]: if not inference: retval = ( "spembs", "durations", "pitch", "energy", "sids", "lids", "feats", "ying", ) else: # Inference mode retval = ("spembs", "singing", "pitch", "durations", "sids", "lids") return retval
[docs] @classmethod @typechecked def build_model(cls, args: argparse.Namespace) -> ESPnetGANSVSModel: if isinstance(args.token_list, str): with open(args.token_list, encoding="utf-8") as f: token_list = [line.rstrip() for line in f] # "args" is saved as it is in a yaml file by BaseTask.main(). # Overwriting token_list to keep it as "portable". args.token_list = token_list.copy() elif isinstance(args.token_list, (tuple, list)): token_list = args.token_list.copy() else: raise RuntimeError("token_list must be str or dict") vocab_size = len(token_list) logging.info(f"Vocabulary size: {vocab_size }") # 1. feats_extract if args.odim is None: # Extract features in the model feats_extract_class = feats_extractor_choices.get_class(args.feats_extract) feats_extract = feats_extract_class(**args.feats_extract_conf) odim = feats_extract.output_size() else: # Give features from data-loader args.feats_extract = None args.feats_extract_conf = None feats_extract = None odim = args.odim # 1. ssl postfrontend if args.input_size is None and args.postfrontend is not None: # Extract features in the model postfrontend_class = postfrontend_choices.get_class(args.postfrontend) postfrontend = postfrontend_class( **args.postfrontend_conf, input_fs=args.svs_conf["sampling_rate"] ) input_size = postfrontend.output_size() else: # Give features from data-loader args.postfrontend = None args.postfrontend_conf = {} postfrontend = None input_size = args.input_size # 2. Normalization layer if args.normalize is not None: normalize_class = normalize_choices.get_class(args.normalize) normalize = normalize_class(**args.normalize_conf) else: normalize = None # 3. SVS svs_class = svs_choices.get_class(args.svs) args.svs_conf["generator_params"].update({"hubert_channels": input_size}) svs = svs_class(idim=vocab_size, odim=odim, **args.svs_conf) # 4. 
Extra components score_feats_extract = None pitch_extract = None ying_extract = None energy_extract = None pitch_normalize = None energy_normalize = None logging.info(f"args:{args}") if getattr(args, "score_feats_extract", None) is not None: score_feats_extract_class = score_feats_extractor_choices.get_class( args.score_feats_extract ) score_feats_extract = score_feats_extract_class( **args.score_feats_extract_conf ) if getattr(args, "pitch_extract", None) is not None: pitch_extract_class = pitch_extractor_choices.get_class( args.pitch_extract, ) pitch_extract = pitch_extract_class( **args.pitch_extract_conf, ) if getattr(args, "ying_extract", None) is not None: ying_extract_class = ying_extractor_choices.get_class( args.ying_extract, ) ying_extract = ying_extract_class( **args.ying_extract_conf, ) if getattr(args, "energy_extract", None) is not None: energy_extract_class = energy_extractor_choices.get_class( args.energy_extract, ) energy_extract = energy_extract_class( **args.energy_extract_conf, ) if getattr(args, "pitch_normalize", None) is not None: pitch_normalize_class = pitch_normalize_choices.get_class( args.pitch_normalize, ) pitch_normalize = pitch_normalize_class( **args.pitch_normalize_conf, ) if getattr(args, "energy_normalize", None) is not None: energy_normalize_class = energy_normalize_choices.get_class( args.energy_normalize, ) energy_normalize = energy_normalize_class( **args.energy_normalize_conf, ) # 5. Build model model = ESPnetGANSVSModel( postfrontend=postfrontend, text_extract=score_feats_extract, feats_extract=feats_extract, score_feats_extract=score_feats_extract, label_extract=score_feats_extract, pitch_extract=pitch_extract, ying_extract=ying_extract, duration_extract=score_feats_extract, energy_extract=energy_extract, normalize=normalize, pitch_normalize=pitch_normalize, energy_normalize=energy_normalize, svs=svs, **args.model_conf, ) return model
[docs] @classmethod def build_optimizers( cls, args: argparse.Namespace, model: ESPnetGANSVSModel, ) -> List[torch.optim.Optimizer]: # check assert hasattr(model.svs, "generator") assert hasattr(model.svs, "discriminator") # define generator optimizer optim_g_class = optim_classes.get(args.optim) if optim_g_class is None: raise ValueError(f"must be one of {list(optim_classes)}: {args.optim}") if args.sharded_ddp: try: import fairscale except ImportError: raise RuntimeError("Requiring fairscale. Do 'pip install fairscale'") optim_g = fairscale.optim.oss.OSS( params=model.svs.generator.parameters(), optim=optim_g_class, **args.optim_conf, ) else: optim_g = optim_g_class( model.svs.generator.parameters(), **args.optim_conf, ) optimizers = [optim_g] # define discriminator optimizer optim_d_class = optim_classes.get(args.optim2) if optim_d_class is None: raise ValueError(f"must be one of {list(optim_classes)}: {args.optim2}") if args.sharded_ddp: try: import fairscale except ImportError: raise RuntimeError("Requiring fairscale. Do 'pip install fairscale'") optim_d = fairscale.optim.oss.OSS( params=model.svs.discriminator.parameters(), optim=optim_d_class, **args.optim2_conf, ) else: optim_d = optim_d_class( model.svs.discriminator.parameters(), **args.optim2_conf, ) optimizers += [optim_d] return optimizers