ASR + LLM Fine-tuning with ESPnetEZ
This Jupyter notebook provides a step-by-step guide to fine-tuning an ASR + LLM cascade with the ESPnetEZ trainer. In this demonstration, we use the English-to-German subset of the MuST-C-v2 dataset to fine-tune this cascaded Speech Translation (ST) system.
In this notebook, we assume that you have already downloaded the MuST-C-v2 dataset and created the dump files using the ESPnet recipe. If you haven't done this and are unfamiliar with the recipes provided in ESPnet, refer to the data preparation sections of the `train_from_scratch.ipynb` or `finetune_owsm.ipynb` notebooks in the ASR demos.
Author: Masao Someki @Masao-Someki
Let's install espnet and transformers if you haven't installed them yet.
!pip install -U espnet
!pip install transformers
Next, import the necessary libraries and set several hyperparameters.
import html
import warnings
from pathlib import Path

import librosa
import numpy as np
import torch
import torch.nn as nn
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

import espnetez as ez
from espnet2.bin.asr_inference import Speech2Text
from espnet2.layers.create_adapter_fn import create_lora_adapter
from espnet2.train.abs_espnet_model import AbsESPnetModel
from espnet2.train.dataset import kaldi_loader
# Experiment identifier; it names the experiment and stats directories below.
TRAIN_KEY = "huggingface_cascade"

# Pretrained ESPnet ASR model loaded with Speech2Text (first stage of the cascade).
FINETUNE_MODEL = "pyf98/librispeech_100_e_branchformer"
# Hugging Face seq2seq model that is fine-tuned as the translation stage.
# NOTE(review): spelled `HF_MODEl` (lowercase "l"); kept as-is because other
# cells reference this exact name.
HF_MODEl = "google-t5/t5-base"

# NOTE(review): DATA_PATH is not referenced by any of the cells shown here.
DATA_PATH = "./data"
# Dump directory produced by the ESPnet data-preparation recipe.
DUMP_DIR = "./dump/raw"
EXP_DIR = "./exp/train_" + TRAIN_KEY
STATS_DIR = "./exp/stats_" + TRAIN_KEY
Data Preparation
In this demo, let's define a custom dataset on top of the prepared dump files. It loads the audio, the English transcription, and the German translation from the dump files, and cleans up the spacing around apostrophes, quotes, and ampersands in the text before training.
class CustomDataset:
    """Map-style dataset over an ESPnet MuST-C v2 (En-De) dump directory.

    Each item is a dict with the waveform (``speech``), the source
    transcription (``text``) and the German reference translation
    (``translated``). The dict is consumed by the functions in ``data_info``.

    Args:
        data_path: Root of the dump directory (e.g. ``./dump/raw``).
        is_train: Read the ``train.en-de_sp`` subset when True, otherwise
            ``dev.en-de``.
    """

    def __init__(self, data_path, is_train=True):
        self.data_path = data_path
        subset = "train.en-de_sp" if is_train else "dev.en-de"
        subset_dir = f"{data_path}/{subset}"

        translations = self._read_text_file(f"{subset_dir}/text.tc.de")
        transcripts = self._read_text_file(f"{subset_dir}/text")

        # Utterances that lack either side cannot be used for training.
        # Skip them with a warning instead of failing later with a KeyError.
        unpaired = translations.keys() ^ transcripts.keys()
        if unpaired:
            warnings.warn(
                f"{subset_dir}: skipping {len(unpaired)} utterance(s) that lack "
                "either a transcription or a translation"
            )

        # Keeps the utterance order of the translation file.
        self.data = {
            utt_id: {"translated": translated, "text": transcripts[utt_id]}
            for utt_id, translated in translations.items()
            if utt_id in transcripts
        }
        self.keys = list(self.data)
        self.loader = kaldi_loader(f"{subset_dir}/wav.scp")

    @staticmethod
    def _normalize(sentence):
        """Undo tokenizer artifacts in one transcript/translation string.

        HTML entities (``&apos;``, ``&quot;``, ``&amp;``, ``&#124;``, ...)
        are decoded first. Then the space that (presumably) a Moses-style
        tokenizer inserted before apostrophes, double quotes and ampersands
        is removed, e.g. ``"don &apos;t"`` -> ``"don't"``.
        """
        sentence = html.unescape(sentence)
        for spaced, joined in ((" '", "'"), (' "', '"'), (" &", "&")):
            sentence = sentence.replace(spaced, joined)
        return sentence

    @classmethod
    def _read_text_file(cls, path):
        """Read a Kaldi-style ``<utt_id> <sentence>`` file into a dict.

        Blank lines are skipped. An id with no sentence maps to ``""``.
        """
        entries = {}
        with open(path, "r", encoding="utf-8") as f:
            for line in f:
                fields = line.strip().split(maxsplit=1)
                if not fields:
                    continue
                utt_id = fields[0]
                sentence = fields[1] if len(fields) > 1 else ""
                entries[utt_id] = cls._normalize(sentence)
        return entries

    def __len__(self):
        return len(self.keys)

    def __getitem__(self, idx):
        # This output will be fed into the functions in `data_info`.
        utt_id = self.keys[int(idx)]
        entry = self.data[utt_id]
        return {
            "speech": self.loader[utt_id].astype(np.float32),
            "text": entry["text"],
            "translated": entry["translated"],
        }
Model Preparation
Next, let's prepare the `build_model_fn` function for the Trainer. We will define a custom model for ASR + LLM training. In this demo, we fine-tune the LLM using the ASR output as its input and the translated text as its target.
Training runs through the ESPnetEZ ASR task (`task="asr"`, the task that normally trains `ESPnetASRModel`). Currently, there is no dedicated task for custom models, so we reuse the existing ASR task and supply our own model through `build_model_fn`.
The `forward` method takes the fields defined in `data_info` along with their `_lengths` tensors, and returns the loss for training. I have also added logging inside the `forward` method to track the training progress.
class CustomFinetuneModel(AbsESPnetModel):
    """ASR + LLM cascade: an ESPnet ``Speech2Text`` decoder feeding a HF seq2seq LM.

    For each utterance, the best ASR hypothesis prefixed with ``prompt`` is
    the LM input, and the tokenized reference translation (``text``) is the
    LM label. The training loss is the LM loss only.

    Args:
        nbest: Number of hypotheses requested from ``Speech2Text``; only the
            best one is used.
        beam_size: Beam size for ASR decoding.
        log_every: Print the mean loss every ``log_every`` forward calls
            (printing is disabled when ``log_every <= 0``).
        prompt: Task prefix prepended to each capitalized ASR hypothesis.
    """

    def __init__(
        self,
        nbest=5,
        beam_size=10,
        log_every=500,
        prompt="translate English to German: ",
    ):
        super().__init__()
        device = "cuda" if torch.cuda.is_available() else "cpu"
        self.log_every = log_every
        self.prompt = prompt
        self.asr_model = Speech2Text.from_pretrained(
            FINETUNE_MODEL,
            nbest=nbest,
            beam_size=beam_size,
            device=device,
        )
        self.lm = AutoModelForSeq2SeqLM.from_pretrained(
            HF_MODEl,
            device_map=device,
        )
        self.lm_tokenizer = AutoTokenizer.from_pretrained(HF_MODEl)
        # Running sum of the loss since the last printed log line.
        self.log_stats = {"loss": 0.0}
        self.iter_count = 0

    def collect_feats(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        *args,
        **kwargs,
    ):
        """Return the raw waveform unchanged as the "features"."""
        return {"feats": speech, "feats_lengths": speech_lengths}

    def forward(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        text: torch.Tensor,
        text_lengths: torch.Tensor,
        **kwargs,
    ):
        """Transcribe every utterance, then compute the LM translation loss.

        Returns:
            ``(loss, stats, weight)``: the LM loss, a stats dict holding the
            detached loss for the ESPnet reporter, and ``None`` as the weight.
        """
        # 1. ASR: decode each utterance with its padding trimmed by
        #    `speech_lengths`, keeping only the text of the best hypothesis.
        prompts = []
        for utt, utt_len in zip(speech, speech_lengths):
            best_hyp = self.asr_model(utt[: int(utt_len)])[0][0]
            prompts.append(self.prompt + best_hyp.capitalize())

        # 2. LM: pad the prompts to the longest one and pass the attention
        #    mask so the padding is ignored by the encoder.
        inputs = self.lm_tokenizer(
            prompts, return_tensors="pt", padding=True
        ).to(speech.device)

        # Label positions beyond each `text_lengths` entry are padding. -100
        # is the index the Hugging Face seq2seq loss ignores.
        positions = torch.arange(text.size(1), device=text.device)
        pad_mask = positions.unsqueeze(0) >= text_lengths.unsqueeze(1)
        labels = text.long().masked_fill(pad_mask, -100)

        lm_output = self.lm(
            input_ids=inputs.input_ids,
            attention_mask=inputs.attention_mask,
            labels=labels,
        )
        loss = lm_output.loss

        self.log_stats["loss"] += loss.item()
        self.iter_count += 1
        if self.log_every > 0 and self.iter_count % self.log_every == 0:
            mean_loss = self.log_stats["loss"] / self.log_every
            print(f"[{self.iter_count}] - loss: {mean_loss:.3f}")
            self.log_stats["loss"] = 0.0

        stats = {"loss": loss.detach()}
        return loss, stats, None
Then let's define `data_info` and the `build_model_fn` function for the Trainer.
# Tokenizer used to turn the reference translation into LM label ids.
lm_tokenizer = AutoTokenizer.from_pretrained(HF_MODEl)


def _encode_translation(sample):
    # NOTE(review): the German reference is upper-cased before tokenization,
    # so the LM is trained to emit upper-case German, while the model only
    # capitalizes the ASR prompt. Confirm this is intended.
    token_ids = lm_tokenizer(
        sample["translated"].upper(),
        return_tensors="np",
    ).input_ids
    return token_ids[0]


# Maps each model input name to a function that extracts it from a
# CustomDataset item.
data_info = {
    "speech": lambda sample: sample["speech"],
    "text": _encode_translation,
}
def count_parameters(model):
    """Return the number of trainable (``requires_grad``) parameters in *model*."""
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total
def build_model_fn(args):
    """Build the cascade model for the ESPnetEZ Trainer.

    ``args`` (presumably the training configuration supplied by the Trainer)
    is ignored; the model is configured from module-level constants.
    """
    return CustomFinetuneModel(log_every=20)
Training
Finally, let's define the training configuration, instantiate the datasets and the trainer, and start training!
from espnet2.bin.asr_inference import Speech2Text

# Load the pretrained ASR model only to recover its training arguments,
# which serve as the base configuration for fine-tuning.
pretrained_model = Speech2Text.from_pretrained(
    FINETUNE_MODEL,
    beam_size=10,
    device="cpu",
)
training_config = vars(pretrained_model.asr_train_args)
del pretrained_model

# For the configuration, please refer to the last cell in this notebook.
finetune_config = ez.config.update_finetune_config(
    "asr",
    training_config,
    "owsm_finetune_base.yaml",
)
finetune_config["multiple_iterator"] = False

# When you don't use a yaml file, you can build finetune_config as follows.
# dict.update() works in place and returns None, so do not assign its result:
# task_class = ez.task.get_ez_task("asr")
# finetune_config = task_class.get_default_config()
# finetune_config.update(your_config_in_dict)
# Read the recipe dump directory defined at the top of the notebook.
train_dataset = CustomDataset(data_path=DUMP_DIR, is_train=True)
dev_dataset = CustomDataset(data_path=DUMP_DIR, is_train=False)

# Wrap the raw datasets so each item is converted through `data_info`.
train_dataset = ez.dataset.ESPnetEZDataset(train_dataset, data_info=data_info)
dev_dataset = ez.dataset.ESPnetEZDataset(dev_dataset, data_info=data_info)

trainer = ez.Trainer(
    task="asr",
    train_config=finetune_config,
    train_dataset=train_dataset,
    valid_dataset=dev_dataset,
    data_info=data_info,
    build_model_fn=build_model_fn,
    output_dir=EXP_DIR,
    stats_dir=STATS_DIR,
    ngpu=1,
)

# Collect statistics into STATS_DIR, then train with outputs in EXP_DIR.
trainer.collect_stats()
trainer.train()
Training configuration (contents of `owsm_finetune_base.yaml`)
seed: 2022
num_workers: 8
batch_type: unsorted
batch_size: 1
batch_bins: 500000
accum_grad: 1
max_epoch: 10
patience: none
init: none
best_model_criterion:
- - valid
- acc
- max
keep_nbest_models: 3
use_amp: true
ngpu: 1
optim: adamw
optim_conf:
lr: 0.0001
weight_decay: 0.000001
scheduler: warmuplr
scheduler_conf:
warmup_steps: 15000
specaug: specaug
specaug_conf:
apply_time_warp: true
time_warp_window: 5
time_warp_mode: bicubic
apply_freq_mask: true
freq_mask_width_range:
- 0
- 27
num_freq_mask: 2
apply_time_mask: true
time_mask_width_ratio_range:
- 0.
- 0.05
num_time_mask: 5
use_preprocessor: false
preprocessor: default
preprocessor_conf:
fs: 16000
text_name:
- "text"