OWSM Fine-tuning for Speech Translation
This Jupyter notebook provides a step-by-step guide on using the ESPnetEZ module to fine-tune the OWSM model. In this demonstration, we will use the MuST-C-v2 dataset (English-to-German subset) to fine-tune an OWSM model for the speech translation (ST) task.
In this notebook, we assume that you have already downloaded the MuST-C-v2 dataset and created the dump files using the recipe. If you haven't done this and are unfamiliar with the recipes provided in ESPnet, refer to the data preparation sections of the train_from_scratch.ipynb or finetune_owsm.ipynb notebooks in the ASR demos.
Author: Masao Someki @Masao-Someki
First, let's install ESPnet if you haven't already.
!pip install -U espnet
Then import the necessary libraries and define the pretrained model name and the directory paths.
import torch
import numpy as np
import librosa
from pathlib import Path
from espnet2.layers.create_adapter_fn import create_lora_adapter
from espnet2.bin.s2t_inference import Speech2Text  # used by build_model_fn below
import espnetez as ez
FINETUNE_MODEL = "espnet/owsm_v3.1_ebf_base"
DATA_PATH = "./data"
DUMP_DIR = "./dump/raw"
EXP_DIR = "./exp/train_owsm_base_finetune"
STATS_DIR = "./exp/stats_owsm"
Model Preparation
Let's prepare the build_model_fn function for the Trainer. We will use the OWSM-v3.1-ebf-base model, which has approximately 100 million parameters.
Note that the model must be created inside a function, not at the top level of the notebook. In a multi-GPU environment, the model has to be initialized separately for each GPU. Running the same initialization function once per GPU is simpler than copying an already-built model, so the initialization code goes inside a function that we pass to the Trainer.
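def count_parameters(model):
    # Count only the parameters that will receive gradient updates.
    return sum(p.numel() for p in model.parameters() if p.requires_grad)


def build_model_fn(args):
    # `args` is passed in by the Trainer but is not used here.
    pretrained_model = Speech2Text.from_pretrained(
        FINETUNE_MODEL,
        beam_size=10,
    )
    # Keep only the underlying S2T model, which is what gets fine-tuned.
    model = pretrained_model.s2t_model
    model.train()
    print(f'Trainable parameters: {count_parameters(model)}')
    return model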
The build_model_fn function loads a pretrained speech-to-text model and puts it into training mode. It also prints the number of trainable parameters in the model.
Data Preparation
Since we assume that you have already generated the dump file using the recipe, we only need to write a simple dictionary for the data preparation.
However, note that ESPnetEZ currently does not support combining lambda functions with lists of data files. Therefore, we need to prepare text_prev manually. It simply contains the <na> symbol for every utterance.
## This should be executed outside of this notebook to prepare <na> for all data
from pathlib import Path


def rewrite(tp):
    # Read the Kaldi-style `text` file: one "<utt_id> <transcript>" per line.
    with open(tp / "text", "r", encoding="utf-8") as f:
        lines = f.readlines()
    nas = []
    for line in lines:
        line = line.strip()
        if not line:
            continue  # skip blank lines
        # Keep the utterance ID and replace the content with <na>.
        id_utt = line.split(maxsplit=1)[0]
        nas.append(f'{id_utt} <na>')
    with open(tp / "text_na", "w", encoding="utf-8") as f:
        f.write("\n".join(nas) + "\n")


rewrite(Path("dump/raw/train.en-de_sp"))
rewrite(Path("dump/raw/dev.en-de"))
rewrite(Path("dump/raw/tst-COMMON.en-de"))
rewrite(Path("dump/raw/tst-HE.en-de"))
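Because text_na must match text utterance for utterance, a quick consistency check can catch mistakes before training starts. The following is a minimal sketch. read_utt_ids and check_text_na are hypothetical helpers and are not part of ESPnet or ESPnetEZ. The sketch assumes the Kaldi-style `<utt_id> <content>` line format that rewrite reads and writes.

```python
from pathlib import Path


def read_utt_ids(path):
    """Return the utterance IDs (first whitespace-separated field) of a Kaldi-style text file."""
    ids = []
    with open(path, "r", encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if line:  # skip blank lines
                ids.append(line.split(maxsplit=1)[0])
    return ids


def check_text_na(dump_subdir):
    """Hypothetical check: text_na must list the same utterance IDs as text,
    in the same order, and every entry must be exactly '<utt_id> <na>'.
    Returns the number of utterances; raises ValueError otherwise."""
    dump_subdir = Path(dump_subdir)
    na_path = dump_subdir / "text_na"
    text_ids = read_utt_ids(dump_subdir / "text")
    if text_ids != read_utt_ids(na_path):
        raise ValueError(f"utterance IDs differ between text and text_na in {dump_subdir}")
    with open(na_path, "r", encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            # Everything after the utterance ID must be the single token <na>.
            if line and line.split(maxsplit=1)[1:] != ["<na>"]:
                raise ValueError(f"unexpected entry in {na_path}: {line!r}")
    return len(text_ids)
```

For example, running check_text_na(Path("dump/raw/dev.en-de")) after rewrite should return the number of dev utterances without raising.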
# Each entry: name -> [file name inside the dump directory, loader type]
data_info = {
    "speech": ["wav.scp", "kaldi_ark"],
    "text": ["text.tc.de", "text"],
    "text_prev": ["text_na", "text"],
    "text_ctc": ["text", "text"],
}
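Each data_info entry pairs a data name with a file name and the loader type used to read that file (kaldi_ark for the speech archive, text for Kaldi-style text). The sketch below is a hypothetical pre-flight check, not an ESPnetEZ API. It assumes that the file names are resolved relative to the train and valid dump directories passed to the Trainer.

```python
from pathlib import Path


def resolve_data_info(dump_dir, data_info):
    """Map each data name to (file path, loader type) and collect missing files.

    Assumes every data_info value is a [file_name, loader_type] pair whose
    file_name is relative to dump_dir (an assumption about how the dump
    directory is used, not verified against ESPnetEZ internals).
    """
    dump_dir = Path(dump_dir)
    resolved, missing = {}, []
    for name, (file_name, loader_type) in data_info.items():
        path = dump_dir / file_name
        resolved[name] = (path, loader_type)
        if not path.is_file():
            missing.append(str(path))
    return resolved, missing
```

For example, you could call resolve_data_info(f"{DUMP_DIR}/train.en-de_sp", data_info) and confirm that the returned missing list is empty before running trainer.collect_stats().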
Training Configuration
Now let's set up the training configuration for OWSM fine-tuning. Most settings are inherited from the original OWSM training configuration, and we override only a few parameters for fine-tuning.
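One way to think about inheriting a configuration and overriding a few keys is as a recursive dictionary merge. This is an illustrative sketch only. merge_config is a hypothetical name, and the sketch is not ESPnetEZ's implementation of ez.config.update_finetune_config. In particular, this sketch does not establish whether ESPnetEZ merges nested sections such as optim_conf key by key or replaces them wholesale.

```python
import copy


def merge_config(base, overrides):
    """Return a new config in which values from `overrides` replace those in `base`.

    Nested dicts (e.g. an optim_conf section) are merged key by key, so
    overriding only the learning rate keeps the other entries from `base`.
    Neither input dict is modified.
    """
    merged = copy.deepcopy(base)
    for key, value in overrides.items():
        if isinstance(value, dict) and isinstance(merged.get(key), dict):
            merged[key] = merge_config(merged[key], value)
        else:
            merged[key] = copy.deepcopy(value)
    return merged
```

Under this merge rule, any key absent from the fine-tuning YAML keeps its pretrained value.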
# Extract the training config from the pretrained model.
from espnet2.bin.s2t_inference import Speech2Text

pretrained_model = Speech2Text.from_pretrained(
    FINETUNE_MODEL,
    # category_sym="<en>",
    beam_size=10,
    device="cpu"
)
training_config = vars(pretrained_model.s2t_train_args)
del pretrained_model

# Update the config for fine-tuning if needed.
# For the configuration, please refer to the last cell in this notebook.
finetune_config = ez.config.update_finetune_config(
    "s2t",
    training_config,
    "owsm_finetune_base.yaml"
)

# If you don't use a YAML file, you can build finetune_config as follows.
# Note that dict.update() modifies the dict in place and returns None.
# task_class = ez.task.get_ez_task("s2t")
# default_config = task_class.get_default_config()
# default_config.update(your_config_in_dict)
# finetune_config = default_config

# Currently ESPnetEZ does not work with the `multiple-iterator` mode.
finetune_config['multiple_iterator'] = False
Training
Now that everything is prepared, we can start fine-tuning the OWSM model for the ST task.
trainer = ez.Trainer(
    task="s2t",
    train_config=finetune_config,
    train_dump_dir=f"{DUMP_DIR}/train.en-de_sp",
    valid_dump_dir=f"{DUMP_DIR}/dev.en-de",
    data_info=data_info,
    build_model_fn=build_model_fn,
    output_dir=EXP_DIR,
    stats_dir=STATS_DIR,
    ngpu=1,
)
trainer.collect_stats()
trainer.train()
Training configuration
Contents of owsm_finetune_base.yaml, the fine-tuning config passed to update_finetune_config:
seed: 2022
num_workers: 8
batch_type: unsorted
batch_size: 1
batch_bins: 500000
accum_grad: 1
max_epoch: 10
patience: none
init: none
best_model_criterion:
-   - valid
    - acc
    - max
keep_nbest_models: 3
use_amp: true
ngpu: 1

optim: adamw
optim_conf:
    lr: 0.0001
    weight_decay: 0.000001
scheduler: warmuplr
scheduler_conf:
    warmup_steps: 15000

specaug: specaug
specaug_conf:
    apply_time_warp: true
    time_warp_window: 5
    time_warp_mode: bicubic
    apply_freq_mask: true
    freq_mask_width_range:
    - 0
    - 27
    num_freq_mask: 2
    apply_time_mask: true
    time_mask_width_ratio_range:
    - 0.
    - 0.05
    num_time_mask: 5

use_preprocessor: false
preprocessor: default
preprocessor_conf:
    fs: 16000
    text_name:
    - "text"
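With scheduler: warmuplr and warmup_steps: 15000, the learning rate rises linearly to optim_conf.lr (0.0001) and then decays with the inverse square root of the step. The function below reproduces the Noam-style formula as we understand ESPnet's WarmupLR. Treat it as a sketch for reasoning about the schedule, not as the library code itself. The name warmup_lr is ours.

```python
def warmup_lr(step, base_lr=1e-4, warmup_steps=15000):
    """Learning rate at a 1-indexed optimizer step under a Noam-style warmup.

    Increases linearly while step <= warmup_steps, reaches base_lr exactly at
    step == warmup_steps, then decays proportionally to step ** -0.5.
    """
    if step < 1:
        raise ValueError("step must be >= 1")
    return base_lr * warmup_steps ** 0.5 * min(step ** -0.5, step * warmup_steps ** -1.5)
```

Under this formula the rate is half of its peak both at step 7500 (during warmup) and at step 60000 (after warmup). With batch_size: 1 and accum_grad: 1 on a single GPU, each utterance corresponds to one optimizer step. The number of steps per epoch is therefore the number of training utterances, and that number determines how much of the 10 epochs is spent in warmup.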