OWSM Fine-tuning for Spoken Language Understanding
This Jupyter notebook provides a step-by-step guide on using the ESPnetEZ module to fine-tune the OWSM model. In this demonstration, we will use the SLURP dataset (intent classification task) to fine-tune an OWSM model for a Spoken Language Understanding (SLU) task.
This demo will focus on how to add new tokens to the pre-trained OWSM model for the intent classification task. For other sections such as data preparation, training, and evaluation, please refer to the other notebooks.
In this notebook, we assume that you have already downloaded the SLURP dataset and created the dump file using the recipe. If you haven't done this before and are unfamiliar with the recipes provided in ESPnet, you can refer to the data preparation sections in the train_from_scratch.ipynb or finetune_owsm.ipynb notebooks in the ASR demos.
Author: Masao Someki @Masao-Someki
First, let's install ESPnet if you haven't already:
!pip install -U espnet
Then import the necessary libraries and set several hyperparameters.
import torch
import numpy as np
import librosa
from pathlib import Path
from espnet2.layers.create_adapter_fn import create_lora_adapter
import argparse
import espnetez as ez
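FINETUNE_MODEL = "espnet/owsm_v3.1_ebf_base"
DATA_PATH = "./data"
DUMP_DIR = "./dump/raw"
# EXP_DIR is passed to the Trainer later in this notebook; the path here is an assumed placeholder.
EXP_DIR = "./exp/finetune_owsm"
STATS_DIR = "./exp/stats_owsm"
ADDITIONAL_SPECIAL_TOKENS = [
    "<intent>"
]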
Adding a New Token to the Pre-trained OWSM Tokenizer
In this section, we will add a new <intent> token to the pre-trained OWSM tokenizer. ESPnetEZ provides the add_special_tokens function in its preprocess module (called below as ez.preprocess.add_special_tokens), which adds new tokens to the pre-trained tokenizer, converter, and Embedding layer.
Here we only keep the updated tokenizer and converter. The updated Embedding layer is created again later, inside the build_model_fn function, where it replaces the pre-trained Embedding layer.
from espnet2.bin.s2t_inference import Speech2Text
pretrained_model = Speech2Text.from_pretrained(
    FINETUNE_MODEL,
    # category_sym="<en>",
    beam_size=10,
    device="cpu"
)
tokenizer = pretrained_model.tokenizer
converter = pretrained_model.converter
# Add the new <intent> token after the ST-related tokens
tokenizer, converter, _ = ez.preprocess.add_special_tokens(
    tokenizer, converter, pretrained_model.s2t_model.decoder.embed[0],
    ADDITIONAL_SPECIAL_TOKENS, insert_after="<st_zho>"
)
# Load the configuration of the pre-trained model before deleting it.
training_config = vars(pretrained_model.s2t_train_args)
del pretrained_model
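To make the effect of insert_after more concrete, here is a minimal pure-Python sketch of inserting a token into a vocabulary list. The helper insert_tokens_after and the six-entry toy vocabulary are hypothetical and only illustrate the idea; this is not the ESPnetEZ implementation, and the real OWSM vocabulary is far larger and ordered differently.

```python
def insert_tokens_after(token_list, new_tokens, insert_after):
    """Return a new list with new_tokens placed right after insert_after."""
    pos = token_list.index(insert_after) + 1
    return token_list[:pos] + list(new_tokens) + token_list[pos:]


# Toy vocabulary (hypothetical order, for illustration only)
vocab = ["<blank>", "<unk>", "<eng>", "<st_zho>", "<notimestamps>", "hello"]
new_vocab = insert_tokens_after(vocab, ["<intent>"], insert_after="<st_zho>")

old_ids = {tok: i for i, tok in enumerate(vocab)}
new_ids = {tok: i for i, tok in enumerate(new_vocab)}

# Tokens before the insertion point keep their ids;
# tokens after it are shifted by the number of inserted tokens.
for tok in vocab:
    print(f"{tok}: {old_ids[tok]} -> {new_ids[tok]}")
```

Because every token after the insertion point gets a new id, the tokenizer, the converter, and the Embedding layer have to be updated together so that ids and embedding rows stay aligned.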
Data Preparation
To create text with the new <intent> token, we need to modify the dump file generated by the recipe. Specifically, we want to change the text format to:
<eng><intent><notimestamps> intent
To achieve this, we need to write a custom dataset class and data_info with appropriate functions.
class CustomDataset:
    def __init__(self, data_path, is_train=True):
        self.data_path = data_path
        if is_train:
            data_path = f"{data_path}/train"
        else:
            data_path = f"{data_path}/devel"
        self.data = {}
        # wav.scp: "<audio_id> <audio_path>"
        with open(f"{data_path}/wav.scp", "r") as f:
            for line in f:
                audio_id, audio_path = line.strip().split(maxsplit=1)
                self.data[audio_id] = {
                    'audio_path': audio_path
                }
        # transcript: "<audio_id> <transcript>"
        with open(f"{data_path}/transcript", "r") as f:
            for line in f:
                audio_id, transcript = line.strip().split(maxsplit=1)
                self.data[audio_id]['transcript'] = transcript
        # text: "<audio_id> <intent> <rest>"; only the intent label is kept
        with open(f"{data_path}/text", "r") as f:
            for line in f:
                audio_id, intent, _ = line.strip().split(maxsplit=2)
                self.data[audio_id]['intent'] = intent
        self.keys = list(self.data.keys())
    def __len__(self):
        return len(self.keys)
    def __getitem__(self, idx):
        # idx is an integer position; map it to the audio_id used as the dict key
        item = self.data[self.keys[idx]]
        return {
            'audio_path': item['audio_path'],
            'intent': item['intent'],
            'transcript': item['transcript']
        }
def tokenize(text):
    return np.array(converter.tokens2ids(tokenizer.text2tokens(text)))
data_info = {
    "speech": lambda d: librosa.load(d['audio_path'], sr=16000)[0],
    "text": lambda d: tokenize(f"<eng><intent><notimestamps>{d['intent']}"),
    "text_prev": lambda d: tokenize("<na>"),
    "text_ctc": lambda d: tokenize(d['transcript'].lower()),
}
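To see how a data_info dictionary turns one raw item into model inputs, here is a minimal sketch built from toy stand-ins. ToyDataset, toy_tokenize, toy_data_info, and the sample utterance are all hypothetical; the final dictionary comprehension only approximates what the dataset wrapper does (look up an item by integer index and apply each function to it), and the "speech" entry is left out so that no audio file is needed.

```python
import re


class ToyDataset:
    """Hypothetical stand-in that maps an integer index to a raw item."""

    def __init__(self, items):
        self.data = items
        self.keys = list(items.keys())

    def __len__(self):
        return len(self.keys)

    def __getitem__(self, idx):
        return self.data[self.keys[idx]]


def toy_tokenize(text):
    # Hypothetical tokenizer: keep <...> special tokens whole,
    # split everything else on whitespace.
    return re.findall(r"<[^>]+>|\S+", text)


toy_data_info = {
    "text": lambda d: toy_tokenize(f"<eng><intent><notimestamps>{d['intent']}"),
    "text_prev": lambda d: toy_tokenize("<na>"),
    "text_ctc": lambda d: toy_tokenize(d["transcript"].lower()),
}

# One made-up utterance, for illustration only
dataset = ToyDataset(
    {"utt1": {"intent": "alarm_set", "transcript": "Wake me up at five"}}
)

# Apply every data_info function to the item at index 0
example = {name: fn(dataset[0]) for name, fn in toy_data_info.items()}
for name, value in example.items():
    print(name, value)
```

The real tokenize function returns a NumPy array of token ids instead of a list of strings, but the flow from raw item to named fields is the same.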
Model Preparation
Let's prepare the build_model_fn function for the Trainer. Inside build_model_fn, we will replace the pre-trained Embedding layer with the new one.
def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)
def build_model_fn(args):
    pretrained_model = Speech2Text.from_pretrained(
        FINETUNE_MODEL,
        beam_size=10,
    )
    model = pretrained_model.s2t_model
    model.train()
    print(f'Trainable parameters: {count_parameters(model)}')
    # Add new <intent> token
    _, _, new_embedding = ez.preprocess.add_special_tokens(
        pretrained_model.tokenizer,
        pretrained_model.converter,
        model.decoder.embed[0],
        ADDITIONAL_SPECIAL_TOKENS,
        insert_after="<st_zho>"
    )
    new_embedding.weight.requires_grad = True
    model.decoder.embed[0] = new_embedding
    # apply lora if you want.
    # create_lora_adapter(model, target_modules=LORA_TARGET)
    # print(f'Trainable parameters after LORA: {count_parameters(model)}')
    return model
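As a rough picture of what replacing the Embedding layer involves, the sketch below builds a larger nn.Embedding and copies the pre-trained rows around the insertion point. The helper insert_embedding_rows and the toy sizes are hypothetical, and this is not necessarily how add_special_tokens is implemented; in particular, how the new row is initialized here (PyTorch's default random initialization) is an assumption.

```python
import torch
import torch.nn as nn


def insert_embedding_rows(old: nn.Embedding, insert_at: int, num_new: int) -> nn.Embedding:
    """Return a new Embedding with num_new fresh rows inserted at insert_at."""
    old_vocab, dim = old.weight.shape
    new = nn.Embedding(old_vocab + num_new, dim)
    with torch.no_grad():
        # Rows before the insertion point keep their position
        new.weight[:insert_at] = old.weight[:insert_at]
        # Rows after it are shifted down by num_new
        new.weight[insert_at + num_new:] = old.weight[insert_at:]
    return new


torch.manual_seed(0)
old_embed = nn.Embedding(6, 4)  # toy vocabulary of 6 tokens, embedding dimension 4
new_embed = insert_embedding_rows(old_embed, insert_at=3, num_new=1)
print(tuple(old_embed.weight.shape), "->", tuple(new_embed.weight.shape))
```

Only the inserted row starts without pre-trained values, so it is the part of the layer that depends most on fine-tuning.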
Training Configuration
Now let's set up the training configuration for OWSM fine-tuning. Most settings are the same as in the original OWSM training configuration, but we override some parameters for this fine-tuning run.
# When you load your configuration from a yaml file, you can use update_finetune_config
finetune_config = ez.config.update_finetune_config(
    "s2t",
    training_config,
    "owsm_finetune_base.yaml",
)
# When you don't use a yaml file, you can build finetune_config in the following way.
# task_class = ez.task.get_ez_task("s2t")
# default_config = task_class.get_default_config()
# default_config.update(your_config_in_dict)  # dict.update works in place and returns None
# finetune_config = default_config
finetune_config['multiple_iterator'] = False
train_dataset = CustomDataset(data_path=DUMP_DIR, is_train=True)
dev_dataset = CustomDataset(data_path=DUMP_DIR, is_train=False)
train_dataset = ez.dataset.ESPnetEZDataset(train_dataset, data_info=data_info)
dev_dataset = ez.dataset.ESPnetEZDataset(dev_dataset, data_info=data_info)
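For the yaml-free alternative mentioned in the comments above, the following sketch shows how merging overrides into a default configuration behaves with plain Python dictionaries. The helper merge_config and the default values are hypothetical and do not reflect the actual ESPnet defaults; the sketch assumes the configuration is an ordinary dict.

```python
import copy


def merge_config(default_config, overrides):
    """Return a new dict: a copy of the defaults updated with top-level overrides."""
    merged = copy.deepcopy(default_config)
    merged.update(overrides)  # in place on the copy; dict.update itself returns None
    return merged


# Made-up defaults, for illustration only
default_config = {
    "max_epoch": 40,
    "batch_size": 20,
    "optim_conf": {"lr": 0.001, "weight_decay": 0.0},
}
overrides = {"max_epoch": 10, "optim_conf": {"lr": 0.0001}}

finetune_config_sketch = merge_config(default_config, overrides)
print(finetune_config_sketch)
```

Note that dict.update replaces nested dictionaries as a whole: in this sketch the merged optim_conf no longer contains weight_decay, so nested sections should be overridden in full.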
Training
Now that everything is prepared, we can start training the OWSM model for the SLU task.
trainer = ez.Trainer(
    task="s2t",
    train_config=finetune_config,
    train_dataset=train_dataset,
    valid_dataset=dev_dataset,
    data_info=data_info,
    build_model_fn=build_model_fn,
    output_dir=EXP_DIR,
    stats_dir=STATS_DIR,
    ngpu=1,
)
trainer.collect_stats()
trainer.train()
Finetuning configuration
seed: 2022
num_workers: 8
batch_type: unsorted
batch_size: 1
batch_bins: 500000
accum_grad: 1
max_epoch: 10
patience: none
init: none
best_model_criterion:
-   - valid
    - acc
    - max
keep_nbest_models: 3
use_amp: true
ngpu: 1
optim: adamw
optim_conf:
    lr: 0.0001
    weight_decay: 0.000001
scheduler: warmuplr
scheduler_conf:
    warmup_steps: 15000
specaug: specaug
specaug_conf:
    apply_time_warp: true
    time_warp_window: 5
    time_warp_mode: bicubic
    apply_freq_mask: true
    freq_mask_width_range:
    - 0
    - 27
    num_freq_mask: 2
    apply_time_mask: true
    time_mask_width_ratio_range:
    - 0.
    - 0.05
    num_time_mask: 5
use_preprocessor: false
preprocessor: default
preprocessor_conf:
    fs: 16000
    text_name:
        - "text"
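To get a feel for the scheduler settings above, the sketch below evaluates a Noam-style warmup curve with lr: 0.0001 and warmup_steps: 15000. The formula is my understanding of ESPnet's warmuplr scheduler and should be checked against the ESPnet source; the function warmup_lr is a hypothetical helper, not part of ESPnet.

```python
def warmup_lr(step, base_lr=1e-4, warmup_steps=15000):
    """Assumed warmuplr rule: linear ramp-up, then inverse square-root decay.

    step must be >= 1.
    """
    return base_lr * warmup_steps**0.5 * min(step**-0.5, step * warmup_steps**-1.5)


for step in (1, 7500, 15000, 60000):
    print(f"step {step:>6}: lr = {warmup_lr(step):.3e}")
```

Under this rule the learning rate rises linearly until it reaches the configured value at step 15000, then decays with the inverse square root of the step. With batch_size: 1 and accum_grad: 1, every utterance counts as one optimizer step.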