OWSM finetuning with custom dataset
This Jupyter notebook provides a step-by-step guide on using the ESPnetEZ module to finetune an OWSM model. In this demonstration, we will use a custom dataset to finetune an OWSM model for the ASR task.
Author: Masao Someki @Masao-Someki
Data Preparation
For this tutorial, we assume that we have a custom dataset of 654 audio files with the following directory structure:
audio
├── 001 [420 files]
└── 002 [234 files]
transcription
└── owsm_v3.1
├── 001.csv
└── 002.csv
The csv files contain the audio path, text, and text_ctc data in Japanese. For example, a csv file contains the following rows:
audio/001/00014.wav,しゃべるたびに追いかけてくるんですけど,なんかしゃべるたびにおいかけてくるんですけど
audio/001/00015.wav,え、どうしよう,えどうしよう
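Before moving on, it can help to sanity-check that the scripts will see the layout above. The following snippet is a minimal, purely illustrative sketch that assumes the audio/ directory shown above and simply counts the wav files per subdirectory (420 + 234 = 654 in total).
import os
from glob import glob

# Minimal sanity check (assumes the audio/ layout shown above):
# count the wav files per subdirectory; we expect 420 + 234 = 654 in total.
total = 0
for subdir in sorted(os.listdir("audio")):
    n_wavs = len(glob(os.path.join("audio", subdir, "*.wav")))
    total += n_wavs
    print(f"audio/{subdir}: {n_wavs} files")
print(f"total: {total} files")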
import os
from glob import glob
import numpy as np
import librosa
import torch
from espnet2.bin.s2t_inference import Speech2Text
from espnet2.layers.create_adapter_fn import create_lora_adapter
import espnetez as ez
# Define hyper parameters
DUMP_DIR = f"./dump"
CSV_DIR = f"./transcription"
EXP_DIR = f"./exp/finetune"
STATS_DIR = f"./exp/stats_finetune"
FINETUNE_MODEL = "espnet/owsm_v3.1_ebf"
LORA_TARGET = [
    "w_1", "w_2", "merge_proj"
]
LANGUAGE = "jpn"
Setup training configs and model
Since we are going to finetune an OWSM model for the ASR task, we will use the tokenizer and TokenIDConverter of the OWSM model. We will also use the pretrained model's training config as the default parameter set and update it with the finetuning configuration.
In this demo, we will apply a LoRA adapter to the model for parameter-efficient fine-tuning.
pretrained_model = Speech2Text.from_pretrained(
    FINETUNE_MODEL,
    category_sym=f"<{LANGUAGE}>",
    beam_size=10,
)  # Load model to extract configs.
pretrain_config = vars(pretrained_model.s2t_train_args)
tokenizer = pretrained_model.tokenizer
converter = pretrained_model.converter
del pretrained_model
# For the configuration, please refer to the last cell in this notebook.
finetune_config = ez.config.update_finetune_config(
    's2t',
    pretrain_config,
    "finetune_with_lora.yaml"
)
# If you don't use a yaml file, you can build finetune_config in the following way.
# task_class = ez.task.get_ez_task("s2t")
# finetune_config = task_class.get_default_config()
# finetune_config.update(your_config_in_dict)
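The tokenizer and converter extracted above are the same ones used later to prepare the training text. As a purely illustrative check (the example sentence is arbitrary), you can inspect how a prompt-prefixed sentence is mapped to token IDs:
# Illustrative only: how the OWSM tokenizer / TokenIDConverter map text to token IDs.
example_text = f"<{LANGUAGE}><asr><notimestamps> こんにちは"  # arbitrary example sentence
example_tokens = tokenizer.text2tokens(example_text)
example_ids = converter.tokens2ids(example_tokens)
print(example_tokens)
print(example_ids)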
# define model loading function
def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)
def build_model_fn(args):
    pretrained_model = Speech2Text.from_pretrained(
        FINETUNE_MODEL,
        category_sym=f"<{LANGUAGE}>",
        beam_size=10,
    )
    model = pretrained_model.s2t_model
    model.train()
    print(f'Trainable parameters: {count_parameters(model)}')
    # apply lora
    create_lora_adapter(model, target_modules=LORA_TARGET)
    print(f'Trainable parameters after LORA: {count_parameters(model)}')
    return model
Wrap with ESPnetEZDataset
Before initiating the training process, it is crucial to adapt the dataset to the ESPnet format. The dataset class should output tokenized text and audio in np.array format.
Then let's define the custom dataset class. OWSM finetuning requires audio, text, text_prev, and text_ctc data. You can use your custom-defined dataset, the huggingface datasets library, the lhotse library, or any other dataloader that you want to use.
When you use a custom-defined dataset, you should define the data_info dictionary. It defines the mapping between the output of your dataset and the input of ESPnet models.
Note:
- Currently we do not support custom dataloaders that feed processed features.
# custom dataset class
class CustomDataset(torch.utils.data.Dataset):
    def __init__(self, data_list):
        # data_list is a list of dicts with keys: audio_path, transcript, text_ctc
        self.data = data_list

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self._parse_single_data(self.data[idx])

    def _parse_single_data(self, d):
        text = f"<{LANGUAGE}><asr><notimestamps> {d['transcript']}"
        return {
            "audio_path": d["audio_path"],
            "text": text,
            "text_prev": "<na>",
            "text_ctc": d['text_ctc'],
        }
data_list = []
for csv_file in sorted(glob(os.path.join(CSV_DIR, "**", "*.csv"), recursive=True)):
    with open(csv_file, "r", encoding="utf-8") as f:
        for line in f.readlines()[1:]:  # skip header
            audio_path, text, text_ctc = line.strip().split(",")
            data_list.append({"audio_path": audio_path, "transcript": text, "text_ctc": text_ctc})
validation_examples = 20
train_dataset = CustomDataset(data_list[:-validation_examples])
valid_dataset = CustomDataset(data_list[-validation_examples:])
def tokenize(text):
    return np.array(converter.tokens2ids(tokenizer.text2tokens(text)))
# The output of CustomDatasetInstance[idx] will be converted to np.array
# with the functions defined in the data_info dictionary.
data_info = {
    "speech": lambda d: librosa.load(d["audio_path"], sr=16000)[0],
    "text": lambda d: tokenize(d["text"]),
    "text_prev": lambda d: tokenize(d["text_prev"]),
    "text_ctc": lambda d: tokenize(d["text_ctc"]),
}
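As a quick, purely illustrative check, you can apply the data_info functions to one example of the custom dataset and confirm that every field comes out as an np.array (this assumes the relative audio paths in the csv resolve from the current working directory):
# Illustrative check: each data_info entry should map one dataset example to an np.array.
example = train_dataset[0]
for name, fn in data_info.items():
    value = fn(example)
    print(name, type(value), getattr(value, "shape", None))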
Or, if you want to use the datasets library from huggingface or the lhotse library (note that the huggingface example below assumes the audiofolder metadata provides transcript and text_ctc columns):
# Datasets library
from datasets import load_dataset, Audio
train_dataset = load_dataset("audiofolder", data_dir=f"/path/to/huggingface_dataset", split=f'train[:-{validation_examples}]')
valid_dataset = load_dataset("audiofolder", data_dir=f"/path/to/huggingface_dataset", split=f'train[-{validation_examples}:]')
train_dataset = train_dataset.cast_column("audio", Audio(sampling_rate=16000))
valid_dataset = valid_dataset.cast_column("audio", Audio(sampling_rate=16000))
data_info = {
    "speech": lambda d: d['audio']['array'],
    "text": lambda d: tokenize(f"<{LANGUAGE}><asr><notimestamps> {d['transcript']}"),
    "text_prev": lambda d: tokenize("<na>"),
    "text_ctc": lambda d: tokenize(d["text_ctc"]),
}
# Or lhotse library. The following code is from the official document.
from pathlib import Path
from lhotse import CutSet
from lhotse.recipes import download_librispeech, prepare_librispeech
def load_audio(audio_path):
    y, _ = librosa.load(audio_path, sr=16000)
    return y
root_dir = Path("data")
tmp_dir = Path("tmp")
tmp_dir.mkdir(exist_ok=True)
num_jobs = os.cpu_count() - 1
libri_variant = "mini_librispeech"
libri_root = download_librispeech(root_dir, dataset_parts=libri_variant)
libri = prepare_librispeech(
    libri_root, dataset_parts=libri_variant, output_dir=root_dir, num_jobs=num_jobs
)
train_dataset = CutSet.from_manifests(**libri["train-clean-5"])
valid_dataset = CutSet.from_manifests(**libri["dev-clean-2"])
data_info = {
    "speech": lambda d: load_audio(d.recording.sources[0].source),
    "text": lambda d: tokenize(f"<{LANGUAGE}><asr><notimestamps> {d.supervisions[0].text}"),
    "text_prev": lambda d: tokenize("<na>"),
    "text_ctc": lambda d: tokenize(d.supervisions[0].text),
}
And finally you need to wrap your custom dataset with ESPnetEZDataset.
# Convert into ESPnet-EZ dataset format
train_dataset = ez.dataset.ESPnetEZDataset(train_dataset, data_info=data_info)
valid_dataset = ez.dataset.ESPnetEZDataset(valid_dataset, data_info=data_info)
Training
While the configuration remains consistent with other notebooks, the instantiation arguments for the Trainer class differ in this case. Since we have not generated dump files, we can skip the dump-related arguments and directly provide the train/valid dataset instances.
trainer = Trainer(
    ...
    train_dataset=your_train_dataset_instance,
    valid_dataset=your_valid_dataset_instance,
    ...
)
trainer = ez.Trainer(
    task='s2t',
    train_config=finetune_config,
    train_dataset=train_dataset,
    valid_dataset=valid_dataset,
    build_model_fn=build_model_fn,  # provide the pre-trained model
    data_info=data_info,
    output_dir=EXP_DIR,
    stats_dir=STATS_DIR,
    ngpu=1
)
trainer.collect_stats()
trainer.train()
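Once training finishes, the checkpoints (e.g. the 1epoch.pth file used in the next section) are written under EXP_DIR. A quick, purely illustrative way to list them:
# Illustrative only: list the checkpoints produced under EXP_DIR.
print(sorted(glob(os.path.join(EXP_DIR, "*.pth"))))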
Inference
When training is done, we can use the inference API to generate the transcription, but don't forget to apply the LoRA adapter before loading the finetuned weights!
DEVICE = "cuda"
model = Speech2Text.from_pretrained(
    "espnet/owsm_v3.1_ebf",
    category_sym="<jpn>",
    beam_size=10,
    device=DEVICE
)
create_lora_adapter(model.s2t_model, target_modules=LORA_TARGET)
model.s2t_model.eval()
d = torch.load("./exp/finetune/1epoch.pth")
model.s2t_model.load_state_dict(d)
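With the LoRA weights loaded, decoding works the same way as with the original pretrained model. The following is a minimal sketch; the wav path is just the first file from the csv example above, and the first element of the top hypothesis is typically the decoded text.
# Illustrative decoding example; replace the path with one of your own wav files.
speech, _ = librosa.load("audio/001/00014.wav", sr=16000)
results = model(speech)
print(results[0][0])  # decoded text of the best hypothesis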
Results
As a result, the finetuned OWSM v3.1 model could successfully transcribe the audio files.
Example
- correct transcription: ダンスでこの世界に彩りを。
- before finetuning: 出してこの時間二のどりを。
- after finetuning: ダンスでこの世界に彩りを。
Finetune configuration
# LoRA finetune related
use_lora: true
rir_scp: null
rir_apply_prob: 1.0
noise_scp: null
noise_apply_prob: 1.0
noise_db_range: '13_15'
speech_volume_normalize: null
non_linguistic_symbols: null
preprocessor_conf:
    speech_name: speech
    text_name: text
# training related
seed: 2022
num_workers: 4
ngpu: 1
batch_type: numel
batch_bins: 1600000
accum_grad: 4
max_epoch: 70
patience: null
init: null
best_model_criterion:
- - valid
  - acc
  - max
keep_nbest_models: 10
use_amp: true
optim: adam
optim_conf:
    lr: 0.002
    weight_decay: 0.000001
scheduler: warmuplr
scheduler_conf:
    warmup_steps: 15000
specaug: specaug
specaug_conf:
    apply_time_warp: true
    time_warp_window: 5
    time_warp_mode: bicubic
    apply_freq_mask: true
    freq_mask_width_range:
    - 0
    - 27
    num_freq_mask: 2
    apply_time_mask: true
    time_mask_width_ratio_range:
    - 0.
    - 0.05
    num_time_mask: 5