Fine-Tuning VITS for Text-to-Speech Synthesis on a New Dataset
In this tutorial, we will walk through text-to-speech (TTS) synthesis by fine-tuning the VITS model on the VCTK dataset. The demo covers data preparation from dump files, model fine-tuning, inference, and evaluation.
Overview
- Task: Text-to-Speech (TTS)
- Dataset: VCTK
- Model: VITS - espnet/kan-bayashi_libritts_xvector_vits
License Reminder
Before proceeding, please note that the dataset and model used in this tutorial come with specific licensing terms:
- VCTK Corpus: Licensed under the Open Data Commons Attribution License (ODC-By) v1.0.
- Model: The pretrained VITS model is under the Creative Commons Attribution 4.0 License.
Prepare Environment
Clone ESPnet's Repository
!git clone https://github.com/espnet/espnet.git
Install ESPnet and Dependencies
!cd espnet && pip install .
!pip install espnet_model_zoo tensorboard
!pip install datasets
Import ESPnetEZ
import espnetez as ez
Data Preparation
In this tutorial, we will use ESPnet-generated dump files as our inputs. Point DUMP_DIR at the directory containing your processed dump folder, and define data_info, which maps each input name to its dump file and loading type.
DUMP_DIR = "dump"
data_info = {
    "speech": ["wav.scp", "sound"],
    "text": ["text", "text"],
}
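As a quick sanity check, you can confirm that the files referenced in data_info are present. This is a minimal sketch that assumes the standard recipe splits (tr_no_dev, dev, eval1) used later in this tutorial.
import os

# Sanity check: each split directory should contain the files named in data_info.
for split in ["tr_no_dev", "dev", "eval1"]:
    for fname, _ in data_info.values():
        path = os.path.join(DUMP_DIR, "raw", split, fname)
        print(path, "ok" if os.path.exists(path) else "MISSING")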
Fine-Tuning
Download Pretrained VITS Model
We'll use ESPnet's model zoo to download a VITS model pretrained on the LibriTTS corpus.
from espnet_model_zoo.downloader import ModelDownloader
PRETRAIN_MODEL = "espnet/kan-bayashi_libritts_xvector_vits"
d = ModelDownloader()
pretrain_downloaded = d.download_and_unpack(PRETRAIN_MODEL)
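download_and_unpack returns a dictionary of local file paths; the two entries used below are the training configuration and the model checkpoint.
# Inspect the downloaded artifacts; the fine-tuning setup below uses
# the "train_config" and "model_file" entries.
print(pretrain_downloaded["train_config"])
print(pretrain_downloaded["model_file"])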
Configure Fine-Tuning
Load the pretrained model's configuration and set it up for fine-tuning.
TASK = "gan_tts"
pretrain_config = ez.config.from_yaml(TASK, pretrain_downloaded["train_config"])
# Update the configuration with the downloaded model file path
pretrain_config["model_file"] = pretrain_downloaded["model_file"]
# Modify configuration for fine-tuning
finetune_config = pretrain_config.copy()
finetune_config["batch_size"] = 1
finetune_config["num_workers"] = 1
finetune_config["max_epoch"] = 100
finetune_config["batch_bins"] = 500000
finetune_config["num_iters_per_epoch"] = None
finetune_config["generator_first"] = True
# Disable distributed training
finetune_config["distributed"] = False
finetune_config["multiprocessing_distributed"] = False
finetune_config["dist_world_size"] = None
finetune_config["dist_rank"] = None
finetune_config["local_rank"] = None
finetune_config["dist_master_addr"] = None
finetune_config["dist_master_port"] = None
finetune_config["dist_launcher"] = None
Initialize Trainer
Define the trainer for the fine-tuning process.
DATASET_NAME = "vctk"
EXP_DIR = f"./exp/finetune_{TASK}_{DATASET_NAME}"
STATS_DIR = f"./exp/stats_{DATASET_NAME}"
ngpu = 1
trainer = ez.Trainer(
    task=TASK,
    train_config=finetune_config,
    train_dump_dir=f"{DUMP_DIR}/raw/tr_no_dev",
    valid_dump_dir=f"{DUMP_DIR}/raw/dev",
    data_info=data_info,
    output_dir=EXP_DIR,
    stats_dir=STATS_DIR,
    ngpu=ngpu,
)
# Add the xvector paths to the configuration
trainer.train_config.train_data_path_and_name_and_type += [
    [f"{DUMP_DIR}/xvector/tr_no_dev/xvector.scp", "spembs", "kaldi_ark"],
]
trainer.train_config.valid_data_path_and_name_and_type += [
    [f"{DUMP_DIR}/xvector/dev/xvector.scp", "spembs", "kaldi_ark"],
]
Collect Statistics
Before training, we need to collect data statistics (e.g., normalization stats).
# Temporarily set to None, as we need to collect stats first
trainer.train_config.normalize = None
trainer.train_config.pitch_normalize = None
trainer.train_config.energy_normalize = None
# Collect stats
trainer.collect_stats()
# Restore normalization configs with collected stats
trainer.train_config.write_collected_feats = False
if finetune_config["normalize"] is not None:
trainer.train_config.normalize = finetune_config["normalize"]
trainer.train_config.normalize_conf["stats_file"] = (
f"{STATS_DIR}/train/feats_stats.npz"
)
if finetune_config["pitch_normalize"] is not None:
trainer.train_config.pitch_normalize = finetune_config["pitch_normalize"]
trainer.train_config.pitch_normalize_conf["stats_file"] = (
f"{STATS_DIR}/train/pitch_stats.npz"
)
if finetune_config["energy_normalize"] is not None:
trainer.train_config.energy_normalize = finetune_config["energy_normalize"]
trainer.train_config.energy_normalize_conf["stats_file"] = (
f"{STATS_DIR}/train/energy_stats.npz"
)
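After collect_stats() finishes, the statistics files should be on disk. A small hedged check, using the same path configured above (some configs without a normalize layer may not produce this file):
import os
import numpy as np

# Peek at the collected feature statistics, if they were produced.
stats_path = f"{STATS_DIR}/train/feats_stats.npz"
if os.path.exists(stats_path):
    print(np.load(stats_path).files)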
Start Training
Now, let's start the fine-tuning process.
trainer.train()
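Since tensorboard was installed earlier, you can monitor the run directly from the notebook. This is a minimal sketch; the exact log subdirectory under the experiment directory may vary by ESPnet version.
# Monitor training curves from the notebook.
%load_ext tensorboard
%tensorboard --logdir ./exp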
Inference
Once training is complete, we can use the inference API to synthesize audio for the evaluation set (eval1).
from espnet2.bin.tts_inference import inference
ckpt_name = "train.total_count.ave_10best"
inference_folder = f"{EXP_DIR}/inference_{ckpt_name}"
model_file = f"{EXP_DIR}/{ckpt_name}.pth"
inference(
    output_dir=inference_folder,
    batch_size=1,
    dtype="float32",
    ngpu=0,
    seed=0,
    num_workers=1,
    log_level="INFO",
    data_path_and_name_and_type=[
        (f"{DUMP_DIR}/raw/eval1/text", "text", "text"),
        (f"{DUMP_DIR}/raw/eval1/wav.scp", "speech", "sound"),
        (f"{DUMP_DIR}/xvector/eval1/xvector.scp", "spembs", "kaldi_ark"),
    ],
    key_file=None,
    train_config=f"{EXP_DIR}/config.yaml",
    model_file=model_file,
    model_tag=None,
    threshold=0.5,
    minlenratio=0.0,
    maxlenratio=10.0,
    use_teacher_forcing=False,
    use_att_constraint=False,
    backward_window=1,
    forward_window=3,
    speed_control_alpha=1.0,
    noise_scale=0.667,
    noise_scale_dur=0.8,
    always_fix_seed=False,
    allow_variable_data_keys=False,
    vocoder_config=None,
    vocoder_file=None,
    vocoder_tag=None,
)
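To spot-check the results, you can listen to one of the synthesized utterances. This is a minimal sketch assuming the output layout used in the evaluation section below (inference_folder/wav/<key>.wav).
import glob
from IPython.display import Audio, display

# Play the first synthesized waveform from the inference output directory.
wav_files = sorted(glob.glob(f"{inference_folder}/wav/*.wav"))
display(Audio(wav_files[0]))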
Evaluation
In this section, we assess the model's performance using speaker similarity, Mel-cepstral distortion (MCD), the root mean square error (RMSE) of the fundamental frequency (F0), and the Pearson correlation coefficient of F0.
import soundfile as sf
from speech_evaluation import speaker_metric, speaker_model_setup, mcd_f0
gt_wav_scp = f"{DUMP_DIR}/raw/eval1/wav.scp"
model = speaker_model_setup()
spk_similarities = []
mcd_f0s = []
f0rmses = []
f0corrs = []
with open(gt_wav_scp, "r") as f:
    for line in f:
        key, path = line.strip().split()
        gt, sr = sf.read(path)
        pred, sr = sf.read(f"{inference_folder}/wav/{key}.wav")
        ret = speaker_metric(model, pred, gt, sr)
        # Use a separate handle for the output files to avoid shadowing the wav.scp handle
        with open(f"{inference_folder}/spk_similarity", "a") as out:
            out.write(f"{ret['spk_similarity']}\n")
        spk_similarities.append(ret["spk_similarity"])
        ret = mcd_f0(pred, gt, sr, 1, 800, dtw=True)
        with open(f"{inference_folder}/mcd_f0", "a") as out:
            out.write(f"{ret['mcd']}\n")
        with open(f"{inference_folder}/f0rmse", "a") as out:
            out.write(f"{ret['f0rmse']}\n")
        with open(f"{inference_folder}/f0corr", "a") as out:
            out.write(f"{ret['f0corr']}\n")
        mcd_f0s.append(ret["mcd"])
        f0rmses.append(ret["f0rmse"])
        f0corrs.append(ret["f0corr"])
print("Averaged speaker similarity:", sum(spk_similarities) / len(spk_similarities))
print("Averaged MCD:", sum(mcd_f0s) / len(mcd_f0s))
print("Averaged F0 RMSE:", sum(f0rmses) / len(f0rmses))
print("Averaged F0 Corr:", sum(f0corrs) / len(f0corrs))
References
[1] S. Someki, K. Choi, S. Arora, W. Chen, S. Cornell, J. Han, Y. Peng, J. Shi, V. Srivastav, and S. Watanabe, “ESPnet-EZ: Python-only ESPnet for Easy Fine-tuning and Integration,” arXiv preprint arXiv:2409.09506, 2024.
[2] C. Veaux, J. Yamagishi, and K. MacDonald, “CSTR VCTK Corpus: English Multi-speaker Corpus for CSTR Voice Cloning Toolkit,” University of Edinburgh, The Centre for Speech Technology Research (CSTR), 2017. [Sound]. https://doi.org/10.7488/ds/1994.