Fine-Tuning VISinger 2 for Singing Voice Synthesis on a New Dataset
This Jupyter notebook provides a step-by-step guide to fine-tuning a pretrained VISinger 2 model with the ESPnetEZ module. In this demonstration, we will use one of ESPnet's singing corpora hosted on Hugging Face (ACE-KiSing) for the SVS task. The demo covers data preparation, fine-tuning, inference, and evaluation.
Overview
- Task: Singing Voice Synthesis
- Dataset: ACE-KiSing
- Model: VISinger 2 model trained on ACE-Opencpop - espnet/aceopencpop_svs_visinger2_40singer_pretrain
License Reminder
Before proceeding, please note that the datasets and models used in this tutorial come with specific licensing terms:
ACE-KiSing Dataset: The ACE-KiSing dataset is distributed under the Creative Commons Attribution Non Commercial 4.0 (CC BY-NC 4.0) license. This means you are free to use, share, and adapt the data, but only for non-commercial purposes. Any commercial use of this dataset is prohibited without explicit permission from the dataset creators.
Pretrained VISinger 2 Model: The VISinger 2 model used in this tutorial is distributed under the Creative Commons Attribution 4.0 (CC BY 4.0) license. This means you can use, modify, and redistribute the model, even for commercial purposes, as long as proper credit is given to the creators.
Prepare Environment
Clone ESPnet's Repository
!git clone https://github.com/espnet/espnet.git
Install ESPnet and Dependencies
!cd espnet && pip install .
!pip install espnet_model_zoo tensorboard
!pip install datasets
Import ESPnetEZ
import espnetez as ez
Data Preparation
We will use ESPnet's ACE-KiSing dataset available on Hugging Face: espnet/ace-kising-segments. Let's begin by loading the dataset, resampling the audio to match the model's requirements, and wrapping it with ESPnetEZDataset.
Load Dataset
To start, load the ACE-KiSing dataset using the datasets library.
from datasets import load_dataset
dataset = load_dataset("espnet/ace-kising-segments", cache_dir="cache")
train_dataset = dataset["train"]
valid_dataset = dataset["validation"]
test_dataset = dataset["test"]
Display the first two instances from the training dataset
it = iter(train_dataset)
next(it), next(it)
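Each instance contains the audio together with its score and phoneme annotations. A quick way to see all available fields is to list the column names; the score and label mappings defined below rely on these columns.
# List the dataset columns used by the mapping functions below
print(train_dataset.column_names)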
Resample Audio
Resample the audio to 44.1 kHz to match the pretrained model's expected sampling rate. For more details, refer to the model's SVS configuration.
from datasets import Audio
train_dataset = train_dataset.cast_column("audio", Audio(sampling_rate=44100))
valid_dataset = valid_dataset.cast_column("audio", Audio(sampling_rate=44100))
test_dataset = test_dataset.cast_column("audio", Audio(sampling_rate=44100))
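As a quick sanity check, you can decode one example and confirm the new sampling rate (casting is lazy; the audio is decoded on access):
# Should print 44100 after the cast above
print(train_dataset[0]["audio"]["sampling_rate"])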
Define Dataset Information
# Map speaker names in the KiSing dataset to the speaker IDs used by the pretrained model
singer2sid = {
"barber": 3,
"blanca": 30,
"changge": 5,
"chuci": 19,
"chuming": 4,
"crimson": 1,
"david": 28,
"ghost": 27,
"growl": 25,
"hiragi-yuki": 22,
"huolian": 13,
"kuro": 2,
"lien": 29,
"liyuan": 9,
"luanming": 21,
"luotianyi": 31,
"namine": 8,
"orange": 12,
"original": 32,
"qifu": 16,
"qili": 15,
"qixuan": 7,
"quehe": 6,
"ranhuhu": 11,
"steel": 26,
"tangerine": 23,
"tarara": 20,
"tuyuan": 24,
"wenli": 10,
"xiaomo": 17,
"xiaoye": 14,
"yanhe": 33,
"yuezhengling": 34,
"yunhao": 18,
}
import numpy as np
# Define data mapping functions
data_info = {
"singing": lambda d: d['audio']['array'].astype(np.float32),
"score": lambda d: (d['tempo'], list(zip(*[d[key] for key in ('note_start_times', 'note_end_times', 'note_lyrics', 'note_midi', 'note_phns')]))),
"text": lambda d: d['transcription'],
"label": lambda d: (np.array(list(zip(*[d[key] for key in ('phn_start_times', 'phn_end_times')]))), d['phns']),
"sids": lambda d: np.array([singer2sid[d['singer']]]),
}
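As an optional sanity check, you can apply these mapping functions to a single raw instance before wrapping the dataset; the snippet below just prints the shape or type of each mapped field:
# Apply each mapping function to the first raw instance and inspect the result
example = train_dataset[0]
for name, fn in data_info.items():
    value = fn(example)
    print(name, getattr(value, "shape", type(value)))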
Load as ESPnetEZ Dataset
train_dataset = ez.dataset.ESPnetEZDataset(train_dataset, data_info=data_info)
valid_dataset = ez.dataset.ESPnetEZDataset(valid_dataset, data_info=data_info)
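Iterating over the wrapped dataset yields (key, data) pairs, mirroring the format consumed by the trainer and by the inference loop later in this notebook:
# Peek at the first wrapped entry: an utterance key plus the mapped fields
key, entry = next(iter(train_dataset))
print(key, list(entry.keys()))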
Fine-Tuning
Download Pretrained VISinger 2 Model
We'll use ESPnet's model zoo to download the VISinger 2 model pretrained on the ACE-Opencpop dataset.
from espnet_model_zoo.downloader import ModelDownloader
PRETRAIN_MODEL = "espnet/aceopencpop_svs_visinger2_40singer_pretrain"
d = ModelDownloader()
pretrain_downloaded = d.download_and_unpack(PRETRAIN_MODEL)
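download_and_unpack returns a dictionary of local paths; the two entries we use below are the training config and the model weights:
# Local paths to the pretrained config and checkpoint
print(pretrain_downloaded["train_config"])
print(pretrain_downloaded["model_file"])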
Configure Fine-Tuning
Load the pretrained model's configuration and set it up for fine-tuning.
TASK = "gan_svs"
pretrain_config = ez.config.from_yaml(TASK, pretrain_downloaded["train_config"])
# Update the configuration with the downloaded model file path
pretrain_config["model_file"] = pretrain_downloaded["model_file"]
# Modify configuration for fine-tuning
finetune_config = pretrain_config.copy()
finetune_config["batch_size"] = 1
finetune_config["num_workers"] = 1
finetune_config["max_epoch"] = 40
finetune_config["save_lora_only"] = False
finetune_config["num_iters_per_epoch"] = None
finetune_config["use_preprocessor"] = True # Use SVS preprocessor for loading the dataset
# Clear the original local file paths in the config
finetune_config["train_data_path_and_name_and_type"] = []
finetune_config["valid_data_path_and_name_and_type"] = []
finetune_config["train_shape_file"] = []
finetune_config["valid_shape_file"] = []
finetune_config["output_dir"] = None
Initialize Trainer
Define the trainer for the fine-tuning process.
dataset_name = "ace-kising"
EXP_DIR = f"exp/finetune_{dataset_name}_{TASK}"
STATS_DIR = f"exp/stats_{dataset_name}"
trainer = ez.Trainer(
task=TASK,
train_config=finetune_config,
train_dataset=train_dataset,
valid_dataset=valid_dataset,
data_info=data_info,
output_dir=EXP_DIR,
stats_dir=STATS_DIR,
ngpu=1,
)
Collect Statistics
Before training, we need to collect data statistics (e.g., normalization stats).
# Temporarily set to None, as we need to collect stats first
trainer.train_config.normalize = None
trainer.train_config.pitch_normalize = None
trainer.train_config.energy_normalize = None
# Collect stats
trainer.collect_stats()
# Restore normalization configs with collected stats
trainer.train_config.normalize = finetune_config["normalize"]
trainer.train_config.pitch_normalize = finetune_config["pitch_normalize"]
trainer.train_config.normalize_conf["stats_file"] = f"{STATS_DIR}/train/feats_stats.npz"
trainer.train_config.pitch_normalize_conf["stats_file"] = f"{STATS_DIR}/train/pitch_stats.npz"
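If you want to be sure the statistics files exist before training starts, here is a simple check against the paths used above:
from pathlib import Path

# Confirm that collect_stats() produced the normalization statistics
for fname in ("feats_stats.npz", "pitch_stats.npz"):
    path = Path(STATS_DIR) / "train" / fname
    print(path, "exists:", path.exists())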
Start Training
Now, let's start the fine-tuning process.
trainer.train()
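Training progress can be monitored with TensorBoard (installed earlier). ESPnet writes event files under the experiment directory, so pointing TensorBoard at exp/ recursively is usually enough; the exact subdirectory can vary by ESPnet version.
%load_ext tensorboard
%tensorboard --logdir exp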
Inference
Once the model is fine-tuned, you can synthesize singing voices from the test dataset.
Set Up the Model for Inference
Load the trained model and prepare it for inference.
from espnet2.bin.svs_inference import SingingGenerate
ckpt_name = "train.total_count.ave_10best"
m = SingingGenerate(
f"{EXP_DIR}/config.yaml",
f"{EXP_DIR}/{ckpt_name}.pth",
)
m.model.eval()
Wrap Dataset with ESPnetEZDataset
test_dataset = ez.dataset.ESPnetEZDataset(test_dataset, data_info=data_info)
Run Inference with Test Data
Here, we will demonstrate how to perform inference using a single data instance from the test dataset.
# Get the first instance
(key, batch) = next(iter(test_dataset))
# Remove unnecessary data from batch
batch.pop("singing")
batch.pop("text")
sids = batch.pop("sids")
# Generate the output
output_dict = m(batch, sids=sids)
Save the generated singing voice to a WAV file.
import soundfile as sf
sf.write(
f"{EXP_DIR}/{key}.wav",
output_dict["wav"].cpu().numpy(),
44100,
"PCM_16",
)
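In a notebook, you can listen to the synthesized segment directly (assuming the WAV file was written as above):
from IPython.display import Audio, display

# Play back the generated singing voice inline
display(Audio(f"{EXP_DIR}/{key}.wav"))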
Evaluation
In this section, we will assess the model's performance in terms of speaker similarity, Mel-cepstral distortion (MCD), the root mean square error (RMSE) of the fundamental frequency (f0), and the Pearson correlation coefficient of f0.
from speech_evaluation import speaker_metric, speaker_model_setup, mcd_f0
from datasets import load_dataset, Audio
import soundfile as sf
from pathlib import Path
ckpt_name = "train.total_count.ave_10best"
sr = 44100
EXP_DIR = Path(f"exp/finetune_{dataset_name}_{TASK}")
inference_dir = EXP_DIR / f"inference_test_{ckpt_name}"
(inference_dir / "wav").mkdir(exist_ok=True, parents=True)
test_dataset = load_dataset(
"espnet/ace-kising-segments", cache_dir="cache", split="test"
)
test_dataset = test_dataset.cast_column("audio", Audio(sampling_rate=sr))
test_dataset = ez.dataset.ESPnetEZDataset(test_dataset, data_info=data_info)
loader = iter(test_dataset)
model = speaker_model_setup()
spk_similarities = []
mcd_f0s = []
f0rmses = []
f0corrs = []
for key, batch in loader:
gt = batch.pop("singing")
sids = batch.pop("sids")
batch.pop("text")
output_dict = m(batch, sids=sids)
pred = output_dict["wav"].cpu().numpy()
sf.write(
f"{inference_dir}/wav/{key}.wav",
pred,
sr,
"PCM_16",
)
ret = speaker_metric(model, pred, gt, sr)
with open(f"{inference_dir}/spk_similarity", "a") as f:
f.write(f"{ret['spk_similarity']}\n")
spk_similarities.append(ret["spk_similarity"])
ret = mcd_f0(pred, gt, sr, 1, 800, dtw=True)
with open(f"{inference_dir}/mcd_f0", "a") as f:
f.write(f"{ret['mcd']}\n")
with open(f"{inference_dir}/f0rmse", "a") as f:
f.write(f"{ret['f0rmse']}\n")
with open(f"{inference_dir}/f0corr", "a") as f:
f.write(f"{ret['f0corr']}\n")
mcd_f0s.append(ret["mcd"])
f0rmses.append(ret["f0rmse"])
f0corrs.append(ret["f0corr"])
print("Averaged speaker similarity:", sum(spk_similarities) / len(spk_similarities))
print("Averaged MCD:", sum(mcd_f0s) / len(mcd_f0s))
print("Averaged F0 RMSE:", sum(f0rmses) / len(f0rmses))
print("Averaged F0 Corr:", sum(f0corrs) / len(f0corrs))
References
[1] S. Someki, K. Choi, S. Arora, W. Chen, S. Cornell, J. Han, Y. Peng, J. Shi, V. Srivastav, and S. Watanabe, “ESPnet-EZ: Python-only ESPnet for Easy Fine-tuning and Integration,” arXiv preprint arXiv:2409.09506, 2024.
[2] J. Shi, Y. Lin, X. Bai, K. Zhang, Y. Wu, Y. Tang, Y. Yu, Q. Jin, and S. Watanabe, “Singing Voice Data Scaling-up: An Introduction to ACE-Opencpop and ACE-KiSing,” arXiv preprint arXiv:2401.17619, 2024.