Sample demo for ESPnet-EZ!
In this notebook, we will demonstrate how to train an Automatic Speech Recognition (ASR) model on the LibriSpeech-100 dataset. The data preparation in this notebook follows the same approach as the Kaldi-style recipes. If you are interested in fine-tuning pretrained models, please refer to the libri100_finetune.ipynb file.
Before proceeding, please ensure that you have already downloaded the LibriSpeech-100 dataset from OpenSLR and placed it in a directory of your choice. In this notebook, we assume that the dataset is stored under the /hdd/database/ directory. If your dataset is located somewhere else, please make sure to replace /hdd/database/ with the actual path to your dataset.
Author: Masao Someki @Masao-Someki
Data Preparation
This notebook follows the data preparation steps outlined in asr.sh. First, we will create dump files that store information about the data, including the data ID, audio path, and transcription.
ESPnet-EZ supports various types of datasets, including:
Dictionary-based dataset with the following structure:
{ "data_id": { "speech": path_to_speech_file, "text": transcription } }
List of datasets with the following structure:
[ { "speech": path_to_speech_file, "text": transcription } ]
If you choose to use a dictionary-based dataset, it's essential to ensure that each data_id is unique. ESPnet-EZ also accepts a dump file that may have already been created by asr.sh. However, in this notebook, we will create the dump file from scratch.
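For concreteness, here is a minimal sketch of both layouts, using an utterance from the dev-other split that appears later in this notebook (the exact paths depend on where you stored the dataset):
# Dictionary-based dataset: the keys are the unique data IDs.
dict_dataset = {
    "1255-138279-0008": {
        "speech": "/hdd/database/librispeech-100/LibriSpeech/dev-other/1255/138279/1255-138279-0008.flac",
        "text": "TWO THREE",
    },
}
# List-based dataset: each element describes one utterance.
list_dataset = [
    {
        "speech": "/hdd/database/librispeech-100/LibriSpeech/dev-other/1255/138279/1255-138279-0008.flac",
        "text": "TWO THREE",
    },
]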
# Need to install espnet if you don't have it
%pip install -U espnet
Now, let's create dump files!
Please note that you will need to provide a dictionary that specifies the dump file name and format for each data entry. This dictionary should have the following format:
{
    "data_name": ["dump_file_name", "dump_format"]
}
import glob
import os
import espnetez as ez
DUMP_DIR = "./dump/libri100"
LIBRI_100_DIRS = [
    ["/hdd/database/librispeech-100/LibriSpeech/train-clean-100", "train"],
    ["/hdd/database/librispeech-100/LibriSpeech/dev-clean", "dev-clean"],
    ["/hdd/database/librispeech-100/LibriSpeech/dev-other", "dev-other"],
]
data_info = {
    "speech": ["wav.scp", "sound"],
    "text": ["text", "text"],
}
def create_dataset(data_dir):
    # Build a dictionary-based dataset: {utt_id: {"speech": path, "text": transcription}}
    dataset = {}
    for chapter in glob.glob(os.path.join(data_dir, "*/*")):
        # Each chapter directory contains one transcription file and the audio files.
        text_file = glob.glob(os.path.join(chapter, "*.txt"))[0]
        with open(text_file, "r") as f:
            lines = f.readlines()
        ids_text = {
            line.split(" ")[0]: line.split(" ", maxsplit=1)[1].replace("\n", "")
            for line in lines
        }
        # LibriSpeech audio is distributed as FLAC files.
        audio_files = glob.glob(os.path.join(chapter, "*.flac"))
        for audio_file in audio_files:
            audio_id = os.path.basename(audio_file)[: -len(".flac")]
            dataset[audio_id] = {"speech": audio_file, "text": ids_text[audio_id]}
    return dataset
for d, n in LIBRI_100_DIRS:
    dump_dir = os.path.join(DUMP_DIR, n)
    if not os.path.exists(dump_dir):
        os.makedirs(dump_dir)
    dataset = create_dataset(d)
    ez.data.create_dump_file(dump_dir, dataset, data_info)
For the validation data, you have two directories: dev-clean and dev-other. To create a unified dev dataset, you can use the ez.data.join_dumps function.
ez.data.join_dumps(
    ["./dump/libri100/dev-clean", "./dump/libri100/dev-other"], "./dump/libri100/dev"
)
Now you have dataset files in the dump directory. They look like this:
wav.scp
1255-138279-0008 /hdd/database/librispeech-100/LibriSpeech/dev-other/1255/138279/1255-138279-0008.flac
1255-138279-0022 /hdd/database/librispeech-100/LibriSpeech/dev-other/1255/138279/1255-138279-0022.flac
text
1255-138279-0008 TWO THREE
1255-138279-0022 IF I SAID SO OF COURSE I WILL
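As an optional sanity check, you can verify that every utterance ID in wav.scp also has a transcription in text. This is just a small helper for this notebook, assuming the dump layout created above:
def load_dump_file(path):
    # Each line is "<utterance_id> <value>", so split on the first space only.
    with open(path, "r") as f:
        return dict(line.strip().split(" ", maxsplit=1) for line in f if line.strip())

wavs = load_dump_file("./dump/libri100/dev/wav.scp")
texts = load_dump_file("./dump/libri100/dev/text")
assert set(wavs) == set(texts), "wav.scp and text contain different utterance IDs"
print(f"{len(wavs)} utterances in the dev dump")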
Train SentencePiece model
To train a SentencePiece model, we require a text file for training. Let's begin by creating the training file.
# generate training texts from the training data
# you can select several datasets to train sentencepiece.
ez.preprocess.prepare_sentences(["dump/libri100/train/text"], "dump/spm")
ez.preprocess.train_sentencepiece(
    "dump/spm/train.txt",
    "data/bpemodel",
    vocab_size=5000,
)
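If you want to inspect the trained tokenizer, you can load it with the sentencepiece package (installed together with ESPnet) and encode a sample sentence. The paths below follow the bpemodel and token_list entries in preprocess.yaml at the end of this notebook; adjust them if your output locations differ.
import sentencepiece as spm

# Load the BPE model trained above and tokenize one of the dev transcriptions.
sp = spm.SentencePieceProcessor(model_file="data/bpemodel/bpe.model")
print(sp.get_piece_size())  # should match vocab_size=5000
print(sp.encode("IF I SAID SO OF COURSE I WILL", out_type=str))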
Configure Training Process
For configuring the training process, you can utilize the configuration files already provided by ESPnet contributors. To use a configuration file, you'll need to create a YAML file on your local machine. For instance, you can use the E-Branchformer config. In my case, I changed the batch_bins parameter from 16000000 to 1600000 so that training runs on my GPU (RTX 2080 Ti).
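You can either edit the value directly in train.yaml or override it after loading the config. A minimal sketch of the latter, assuming the config has already been loaded into training_config as shown in the Training section below:
# Shrink batch_bins so the batches fit into a smaller GPU's memory.
# 16000000 is the value from the original config; 1600000 worked on my RTX 2080 Ti.
training_config["batch_bins"] = 1600000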
Training
To prepare the stats file before training, you can execute the collect_stats method. This step is required before the training process and ensures accurate statistics for the model.
import espnetez as ez
EXP_DIR = "exp/train_asr_branchformer_e24_amp"
STATS_DIR = "exp/stats"
# Load the configuration.
# For the configuration files, please refer to the last cell in this notebook.
training_config = ez.config.from_yaml(
    "asr",
    "train.yaml",
)
preprocessor_config = ez.utils.load_yaml("preprocess.yaml")
training_config.update(preprocessor_config)

# Read the token list produced by the SentencePiece training step.
with open(preprocessor_config["token_list"], "r") as f:
    training_config["token_list"] = [t.replace("\n", "") for t in f.readlines()]

# If you don't use a YAML file, you can build the training config as follows:
# task_class = ez.task.get_ez_task("asr")
# default_config = task_class.get_default_config()
# training_config = default_config.update(your_config_in_dict)

# Define the Trainer class.
trainer = ez.Trainer(
    task="asr",
    train_config=training_config,
    train_dump_dir="dump/libri100/train",
    valid_dump_dir="dump/libri100/dev",
    data_info=data_info,
    output_dir=EXP_DIR,
    stats_dir=STATS_DIR,
    ngpu=1,
)
trainer.collect_stats()
Finally, we are ready to begin the training process!
trainer.train()
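After training finishes, the experiment directory should contain the resolved config.yaml and the checkpoints (for example valid.acc.best.pth), which are used for inference below. A quick way to check what was written:
import os

# List everything the training run wrote into the experiment directory.
for name in sorted(os.listdir(EXP_DIR)):
    print(name)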
Inference
You can simply use ESPnet's standard inference API.
import librosa
from espnet2.bin.asr_inference import Speech2Text
m = Speech2Text(
    "./exp/train_asr_branchformer_e24_amp/config.yaml",
    "./exp/train_asr_branchformer_e24_amp/valid.acc.best.pth",
    beam_size=10,
)

with open("./dump/libri100/dev/wav.scp", "r") as f:
    sample_path = f.readlines()[0]

y, sr = librosa.load(sample_path.split()[1], sr=16000, mono=True)
output = m(y)
print(output[0][0])
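To compare the hypothesis with the reference transcription, you can look up the same utterance ID in the dev text file (a small sketch, reusing the dump created earlier):
# Look up the reference transcription for the utterance we just decoded.
utt_id = sample_path.split()[0]
with open("./dump/libri100/dev/text", "r") as f:
    refs = dict(line.strip().split(" ", maxsplit=1) for line in f if line.strip())
print("REF:", refs[utt_id])
print("HYP:", output[0][0])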
Training configuration
# Trained with A40 (48 GB) x 1 GPU.
encoder: e_branchformer
encoder_conf:
    output_size: 256
    attention_heads: 4
    attention_layer_type: rel_selfattn
    pos_enc_layer_type: rel_pos
    rel_pos_type: latest
    cgmlp_linear_units: 1024
    cgmlp_conv_kernel: 31
    use_linear_after_conv: false
    gate_activation: identity
    num_blocks: 12
    dropout_rate: 0.1
    positional_dropout_rate: 0.1
    attention_dropout_rate: 0.1
    input_layer: conv2d
    layer_drop_rate: 0.0
    linear_units: 1024
    positionwise_layer_type: linear
    use_ffn: true
    macaron_ffn: true
    merge_conv_kernel: 31

decoder: transformer
decoder_conf:
    attention_heads: 4
    linear_units: 2048
    num_blocks: 6
    dropout_rate: 0.1
    positional_dropout_rate: 0.1
    self_attention_dropout_rate: 0.1
    src_attention_dropout_rate: 0.1
    layer_drop_rate: 0.0

model_conf:
    ctc_weight: 0.3
    lsm_weight: 0.1
    length_normalized_loss: false

frontend_conf:
    n_fft: 512
    win_length: 400
    hop_length: 160

seed: 2022
num_workers: 4
batch_type: numel
batch_bins: 1600000
accum_grad: 4
max_epoch: 70
patience: null
init: null
best_model_criterion:
-   - valid
    - acc
    - max
keep_nbest_models: 10
use_amp: true

optim: adam
optim_conf:
    lr: 0.002
    weight_decay: 0.000001
scheduler: warmuplr
scheduler_conf:
    warmup_steps: 15000

specaug: specaug
specaug_conf:
    apply_time_warp: true
    time_warp_window: 5
    time_warp_mode: bicubic
    apply_freq_mask: true
    freq_mask_width_range:
    - 0
    - 27
    num_freq_mask: 2
    apply_time_mask: true
    time_mask_width_ratio_range:
    - 0.
    - 0.05
    num_time_mask: 5
preprocess.yaml
use_preprocessor: true
token_type: bpe
bpemodel: data/bpemodel/bpe.model
rir_scp: null
rir_apply_prob: 1.0
noise_scp: null
noise_apply_prob: 1.0
noise_db_range: '13_15'
speech_volume_normalize: null
non_linguistic_symbols: null
cleaner: null
g2p: null
preprocessor: default
preprocessor_conf:
    speech_name: speech
    text_name: text
    token_list: data/bpemodel/tokens.txt