# Source code for espnet.bin.tts_train (extracted from Sphinx-rendered docs)

#!/usr/bin/env python3

# Copyright 2018 Nagoya University (Tomoki Hayashi)
#  Apache 2.0  (http://www.apache.org/licenses/LICENSE-2.0)

"""Text-to-speech model training script."""

import logging
import os
import random
import subprocess
import sys

import configargparse
import numpy as np

from espnet import __version__
from espnet.nets.tts_interface import TTSInterface
from espnet.utils.cli_utils import strtobool
from espnet.utils.training.batchfy import BATCH_COUNT_CHOICES


# NOTE: you need this func to generate our sphinx doc
# NOTE: you need this func to generate our sphinx doc
def get_parser():
    """Get parser of training arguments.

    Returns:
        configargparse.ArgumentParser: Parser holding all common TTS training
        options. Model-specific options are added later by the model class
        (see ``main``).

    """
    parser = configargparse.ArgumentParser(
        description="Train a new text-to-speech (TTS) model on one CPU, "
        "one or multiple GPUs",
        config_file_parser_class=configargparse.YAMLConfigFileParser,
        formatter_class=configargparse.ArgumentDefaultsHelpFormatter,
    )
    # general configuration
    parser.add("--config", is_config_file=True, help="config file path")
    parser.add(
        "--config2",
        is_config_file=True,
        help="second config file path that overwrites the settings in `--config`.",
    )
    parser.add(
        "--config3",
        is_config_file=True,
        help="third config file path that overwrites "
        "the settings in `--config` and `--config2`.",
    )
    parser.add_argument(
        "--ngpu",
        default=None,
        type=int,
        help="Number of GPUs. If not given, use all visible devices",
    )
    parser.add_argument(
        "--backend",
        default="pytorch",
        type=str,
        choices=["chainer", "pytorch"],
        help="Backend library",
    )
    parser.add_argument("--outdir", type=str, required=True, help="Output directory")
    parser.add_argument("--debugmode", default=1, type=int, help="Debugmode")
    parser.add_argument("--seed", default=1, type=int, help="Random seed")
    parser.add_argument(
        "--resume",
        "-r",
        default="",
        type=str,
        nargs="?",
        help="Resume the training from snapshot",
    )
    parser.add_argument(
        "--minibatches",
        "-N",
        type=int,
        # BUGFIX: was the string "-1"; argparse coerced it via ``type``, but an
        # int default is the documented, unambiguous form.
        default=-1,
        help="Process only N minibatches (for debug)",
    )
    parser.add_argument("--verbose", "-V", default=0, type=int, help="Verbose option")
    parser.add_argument(
        "--tensorboard-dir",
        default=None,
        type=str,
        nargs="?",
        help="Tensorboard log directory path",
    )
    parser.add_argument(
        "--eval-interval-epochs", default=1, type=int, help="Evaluation interval epochs"
    )
    parser.add_argument(
        "--save-interval-epochs", default=1, type=int, help="Save interval epochs"
    )
    parser.add_argument(
        "--report-interval-iters",
        default=100,
        type=int,
        help="Report interval iterations",
    )
    # task related
    parser.add_argument(
        "--train-json", type=str, required=True, help="Filename of training json"
    )
    parser.add_argument(
        "--valid-json", type=str, required=True, help="Filename of validation json"
    )
    # network architecture
    parser.add_argument(
        "--model-module",
        type=str,
        default="espnet.nets.pytorch_backend.e2e_tts_tacotron2:Tacotron2",
        help="model defined module",
    )
    # minibatch related
    parser.add_argument(
        "--sortagrad",
        default=0,
        type=int,
        nargs="?",
        help="How many epochs to use sortagrad for. 0 = deactivated, -1 = all epochs",
    )
    parser.add_argument(
        "--batch-sort-key",
        default="shuffle",
        type=str,
        choices=["shuffle", "output", "input"],
        nargs="?",
        help='Batch sorting key. "shuffle" only work with --batch-count "seq".',
    )
    parser.add_argument(
        "--batch-count",
        default="auto",
        choices=BATCH_COUNT_CHOICES,
        help="How to count batch_size. "
        "The default (auto) will find how to count by args.",
    )
    parser.add_argument(
        "--batch-size",
        "--batch-seqs",
        "-b",
        default=0,
        type=int,
        help="Maximum seqs in a minibatch (0 to disable)",
    )
    parser.add_argument(
        "--batch-bins",
        default=0,
        type=int,
        help="Maximum bins in a minibatch (0 to disable)",
    )
    parser.add_argument(
        "--batch-frames-in",
        default=0,
        type=int,
        help="Maximum input frames in a minibatch (0 to disable)",
    )
    parser.add_argument(
        "--batch-frames-out",
        default=0,
        type=int,
        help="Maximum output frames in a minibatch (0 to disable)",
    )
    parser.add_argument(
        "--batch-frames-inout",
        default=0,
        type=int,
        help="Maximum input+output frames in a minibatch (0 to disable)",
    )
    parser.add_argument(
        "--maxlen-in",
        "--batch-seq-maxlen-in",
        default=100,
        type=int,
        metavar="ML",
        help="When --batch-count=seq, "
        "batch size is reduced if the input sequence length > ML.",
    )
    parser.add_argument(
        "--maxlen-out",
        "--batch-seq-maxlen-out",
        default=200,
        type=int,
        metavar="ML",
        help="When --batch-count=seq, "
        "batch size is reduced if the output sequence length > ML",
    )
    parser.add_argument(
        "--num-iter-processes",
        default=0,
        type=int,
        help="Number of processes of iterator",
    )
    parser.add_argument(
        "--preprocess-conf",
        type=str,
        default=None,
        help="The configuration file for the pre-processing",
    )
    parser.add_argument(
        "--use-speaker-embedding",
        default=False,
        type=strtobool,
        help="Whether to use speaker embedding",
    )
    parser.add_argument(
        "--use-second-target",
        default=False,
        type=strtobool,
        help="Whether to use second target",
    )
    # optimization related
    parser.add_argument(
        "--opt", default="adam", type=str, choices=["adam", "noam"], help="Optimizer"
    )
    parser.add_argument(
        "--accum-grad", default=1, type=int, help="Number of gradient accumuration"
    )
    parser.add_argument(
        "--lr", default=1e-3, type=float, help="Learning rate for optimizer"
    )
    parser.add_argument("--eps", default=1e-6, type=float, help="Epsilon for optimizer")
    parser.add_argument(
        "--weight-decay",
        default=1e-6,
        type=float,
        help="Weight decay coefficient for optimizer",
    )
    parser.add_argument(
        "--epochs", "-e", default=30, type=int, help="Number of maximum epochs"
    )
    parser.add_argument(
        "--early-stop-criterion",
        default="validation/main/loss",
        type=str,
        nargs="?",
        help="Value to monitor to trigger an early stopping of the training",
    )
    parser.add_argument(
        "--patience",
        default=3,
        type=int,
        nargs="?",
        help="Number of epochs to wait "
        "without improvement before stopping the training",
    )
    parser.add_argument(
        "--grad-clip", default=1, type=float, help="Gradient norm threshold to clip"
    )
    parser.add_argument(
        "--num-save-attention",
        default=5,
        type=int,
        help="Number of samples of attention to be saved",
    )
    parser.add_argument(
        "--keep-all-data-on-mem",
        default=False,
        type=strtobool,
        help="Whether to keep all data on memory",
    )
    # finetuning related
    # BUGFIX (three lambdas below): the filter previously tested the whole
    # input string (``if s != ""``) instead of each element, so inputs with
    # stray commas (e.g. "enc.,") produced empty module names in the list.
    parser.add_argument(
        "--enc-init",
        default=None,
        type=str,
        help="Pre-trained TTS model path to initialize encoder.",
    )
    parser.add_argument(
        "--enc-init-mods",
        default="enc.",
        type=lambda s: [str(mod) for mod in s.split(",") if mod != ""],
        help="List of encoder modules to initialize, separated by a comma.",
    )
    parser.add_argument(
        "--dec-init",
        default=None,
        type=str,
        help="Pre-trained TTS model path to initialize decoder.",
    )
    parser.add_argument(
        "--dec-init-mods",
        default="dec.",
        type=lambda s: [str(mod) for mod in s.split(",") if mod != ""],
        help="List of decoder modules to initialize, separated by a comma.",
    )
    parser.add_argument(
        "--freeze-mods",
        default=None,
        type=lambda s: [str(mod) for mod in s.split(",") if mod != ""],
        help="List of modules to freeze (not to train), separated by a comma.",
    )
    return parser
def main(cmd_args):
    """Run TTS model training.

    Args:
        cmd_args (list): Command-line argument strings
            (typically ``sys.argv[1:]``).

    Raises:
        NotImplementedError: If ``--backend`` is not ``pytorch``.

    """
    # Parse known args first so the model class named by --model-module can
    # register its own options before the strict full parse below.
    parser = get_parser()
    args, _ = parser.parse_known_args(cmd_args)

    from espnet.utils.dynamic_import import dynamic_import

    model_class = dynamic_import(args.model_module)
    assert issubclass(model_class, TTSInterface)
    model_class.add_arguments(parser)
    args = parser.parse_args(cmd_args)

    # add version info in args
    args.version = __version__

    # logging info
    if args.verbose > 0:
        logging.basicConfig(
            level=logging.INFO,
            format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
        )
    else:
        logging.basicConfig(
            level=logging.WARN,
            format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
        )
        logging.warning("Skip DEBUG/INFO messages")

    # If --ngpu is not given,
    # 1. if CUDA_VISIBLE_DEVICES is set, all visible devices
    # 2. if nvidia-smi exists, use all devices
    # 3. else ngpu=0
    if args.ngpu is None:
        cvd = os.environ.get("CUDA_VISIBLE_DEVICES")
        if cvd is not None:
            ngpu = len(cvd.split(","))
        else:
            logging.warning("CUDA_VISIBLE_DEVICES is not set.")
            try:
                # NOTE: without check=True subprocess.run never raises
                # CalledProcessError; only FileNotFoundError (no nvidia-smi)
                # is actually caught here.
                p = subprocess.run(
                    ["nvidia-smi", "-L"], stdout=subprocess.PIPE, stderr=subprocess.PIPE
                )
            except (subprocess.CalledProcessError, FileNotFoundError):
                ngpu = 0
            else:
                # BUGFIX: ``nvidia-smi -L`` lists one GPU per line on stdout;
                # the previous code counted stderr lines (empty on success),
                # so autodetection always yielded 0 GPUs. The trailing
                # newline accounts for the "- 1".
                ngpu = len(p.stdout.decode().split("\n")) - 1
        args.ngpu = ngpu
    else:
        ngpu = args.ngpu
    logging.info(f"ngpu: {ngpu}")

    # set random seed
    logging.info("random seed = %d" % args.seed)
    random.seed(args.seed)
    np.random.seed(args.seed)

    if args.backend == "pytorch":
        from espnet.tts.pytorch_backend.tts import train

        train(args)
    else:
        raise NotImplementedError("Only pytorch is supported.")
# Script entry point: forward the command-line arguments (minus the program
# name) to the training driver.
if __name__ == "__main__":
    main(sys.argv[1:])