CMU 11751/18781 Fall 2022: ESPnet Tutorial
ESPnet is a widely used end-to-end speech processing toolkit that supports various speech processing tasks. ESPnet uses PyTorch as its main deep learning engine and also follows Kaldi-style data processing, feature extraction/formats, and recipes to provide a complete setup for speech recognition and other speech processing experiments.
Main references:
- ESPnet repository
- ESPnet documentation
- ESPnet tutorial in Speech Recognition and Understanding (Fall 2021)
- Recitation in Multilingual NLP (Spring 2022)
Author: Yifan Peng (yifanpen@andrew.cmu.edu)
❗Important Notes❗
- We are using Colab to show the demo. However, Colab has some constraints on the total GPU runtime. If you use too much GPU, you may fail to connect to a GPU backend for some time.
- There are multiple in-class checkpoints ✅ throughout this tutorial. There will also be some after-class exercises 📗 after the tutorial. Your participation points are based on these tasks. Please try your best to follow all the steps! If you encounter issues, please notify the TAs as soon as possible so that we can make an adjustment for you.
- Please submit PDF files of your completed notebooks to Gradescope. You can print the notebook using File -> Print in the menu bar.
- This tutorial covers the basics of ESPnet, which will be the foundation of the next tutorial on Wednesday.
Objectives
After this tutorial, you are expected to know:
- How to run existing recipes (data prep, training, inference and scoring) in ESPnet2
- How to change the training and decoding configurations
- How to create a new recipe from scratch
- Where to find resources if you encounter an issue
Useful links
- Installation https://espnet.github.io/espnet/installation.html
- Usage https://espnet.github.io/espnet/espnet2_tutorial.html
Install ESPnet
This is the full installation, which supports data preprocessing, training, inference, scoring, and so on.
We provide several installation methods. Please read https://espnet.github.io/espnet/installation.html#step-2-installation-espnet for more details.
Function to print date and time
We first define a function to print the current date and time, which will be used in multiple places below.
def print_date_and_time():
    from datetime import datetime
    import pytz

    now = datetime.now(pytz.timezone("America/New_York"))
    print("=" * 60)
    print(f' Current date and time: {now.strftime("%m/%d/%Y %H:%M:%S")}')
    print("=" * 60)
# example output
print_date_and_time()
Check GPU type
Let's check the GPU type of this allocated environment.
!nvidia-smi
Download ESPnet
We use git clone to download the source code of ESPnet and then check out a specific commit.
Important: In other versions of ESPnet, you may encounter errors related to incompatible package versions (e.g., numba). Please use the same commit to avoid such issues.
# It takes a few seconds
!git clone --depth 5 https://github.com/espnet/espnet
# We use a specific commit just for reproducibility.
%cd /content/espnet
!git checkout 3a22d1584317ae59974aad62feab8719c003ae05
Set up the Python environment based on Anaconda
There are several other installation methods, but we highly recommend the Anaconda-based one.
# It takes 30 seconds
%cd /content/espnet/tools
!./setup_anaconda.sh anaconda espnet 3.9
Install ESPnet
This step installs PyTorch and other required tools.
We specify CUDA_VERSION=11.6
for PyTorch 1.12.1. We also support many other versions. Please check https://github.com/espnet/espnet/blob/master/tools/installers/install_torch.sh for the detailed version list.
# It may take 12 minutes
%cd /content/espnet/tools
!make TH_VERSION=1.12.1 CUDA_VERSION=11.6
If other listed packages are necessary, install them with
. ./activate_python.sh && ./installers/install_xxx.sh
We show two examples below, although they are not used in this demo.
# s3prl and fairseq are necessary if you want to use self-supervised pre-trained models
# It takes 50s
%cd /content/espnet/tools
!. ./activate_python.sh && ./installers/install_s3prl.sh # install s3prl to use SSLRs
!. ./activate_python.sh && ./installers/install_fairseq.sh # install fairseq to use Wav2Vec2 / HuBERT model series
Check installation
Now let's make sure torch, torch cuda, and espnet are successfully installed.
...
[x] torch=1.12.1
[x] torch cuda=11.6
[x] torch cudnn=8302
...
[x] espnet=202207
...
✅ Checkpoint 1 (1 point)
Print the output of check_install.py.
%cd /content/espnet/tools
!. ./activate_python.sh && python3 check_install.py | head -n 40
# NOTE: Checkpoint 1
print_date_and_time()
Run an existing recipe
ESPnet has a number of recipes (130 as of Sep. 11, 2022). Please refer to https://github.com/espnet/espnet/blob/master/egs2/README.md for a complete list.
Please also check the general usage of ESPnet2 recipes at https://espnet.github.io/espnet/espnet2_tutorial.html#recipes-using-espnet2
CMU AN4 recipe
In this tutorial, we will use the CMU an4 recipe. This is a small-scale speech recognition task mainly used for testing.
First, let's go to the recipe directory.
%cd /content/espnet/egs2/an4/asr1
!ls
egs2/an4/asr1/
- conf/ # Configuration files for training, inference, etc.
- scripts/ # Bash utilities of espnet2
- pyscripts/ # Python utilities of espnet2
- steps/ # From Kaldi utilities
- utils/ # From Kaldi utilities
- db.sh # The directory path of each corpus
- path.sh # Setup script for environment variables
- cmd.sh # Configuration for your backend of job scheduler
- run.sh # Entry point
- asr.sh # Invoked by run.sh
ESPnet is designed for various use cases (local machines or clusters) based on Kaldi tools. If you use it on cluster machines, please also check https://kaldi-asr.org/doc/queue.html
The main stages can be parallelized across multiple jobs.
!cat run.sh
run.sh calls asr.sh, which runs the entire speech recognition experiment, including data preparation, training, inference, and scoring. The pipeline is separated into multiple stages (16 in total).
Instead of executing the entire pipeline with run.sh, let's run it stage by stage to understand what each stage does.
Data preparation
Stage 1: Data preparation: download raw data, split the entire set into train/dev/test, and prepare them in the Kaldi format
Note that --stage <N> starts the pipeline from stage N and --stop_stage <N> stops it after stage N. We also need to specify the train, dev, and test sets.
# a few seconds
!./asr.sh --stage 1 --stop_stage 1 --train_set train_nodev --valid_set train_dev --test_sets "train_dev test"
After this stage is finished, please check the newly created data directory:
!ls data
In this recipe, we use train_nodev as the training set and train_dev as the validation set (to monitor the training progress via the validation score). We also use the test and train_dev sets for the final speech recognition evaluation.
Let's check one of the training data directories:
!ls -1 data/train_nodev/
These are the speech and corresponding text and speaker information in the Kaldi format. To understand their meanings, please check https://github.com/espnet/espnet/tree/master/egs2/TEMPLATE#about-kaldi-style-data-directory.
Please also check the official documentation of Kaldi: https://kaldi-asr.org/doc/data_prep.html
spk2utt # Speaker-to-utterance mapping
text # Transcription file
utt2spk # Utterance-to-speaker mapping
wav.scp # Audio file list
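To make the format concrete, here is a minimal Python sketch (not part of the recipe) that parses these files and prints the first few entries. It assumes the working directory is /content/espnet/egs2/an4/asr1, and read_kv_file is a hypothetical helper, not an ESPnet utility.
# Sketch: inspect the Kaldi-style files of the train_nodev set.
from pathlib import Path

data_dir = Path("data/train_nodev")

def read_kv_file(path, num_lines=3):
    # Each line is "<utterance or speaker ID> <value...>".
    entries = []
    with open(path) as f:
        for line in f:
            key, value = line.rstrip("\n").split(maxsplit=1)
            entries.append((key, value))
            if len(entries) >= num_lines:
                break
    return entries

for name in ["wav.scp", "text", "utt2spk"]:
    print(f"--- {name} ---")
    for key, value in read_kv_file(data_dir / name):
        print(key, "->", value)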
Stage 2: Speed perturbation (one of the data augmentation methods)
We do not use speed perturbation in this demo, but you can turn it on by adding the argument --speed_perturb_factors "0.9 1.0 1.1" to the shell script.
Note that speed perturbation is performed before training and the augmented data are saved to disk. Another approach is to perform data augmentation on the fly during training, such as SpecAug.
!./asr.sh --stage 2 --stop_stage 2 --train_set train_nodev --valid_set train_dev --test_sets "train_dev test"
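For intuition, the Python sketch below applies a sox-style speed effect to a single waveform with torchaudio. It only illustrates the idea; it is not the script the recipe uses in stage 2, and "sample.wav" is a placeholder path.
# Illustration of speed perturbation on one waveform (assumes torchaudio is installed).
import torchaudio

waveform, sr = torchaudio.load("sample.wav")  # placeholder file
for factor in [0.9, 1.0, 1.1]:
    # "speed" changes tempo and pitch; "rate" resamples back to the original rate.
    perturbed, _ = torchaudio.sox_effects.apply_effects_tensor(
        waveform, sr, [["speed", str(factor)], ["rate", str(sr)]]
    )
    print(f"factor={factor}: {waveform.shape[-1]} -> {perturbed.shape[-1]} samples")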
Stage 3: Format wav.scp: data/ -> dump/raw
We dump the data in a specified format (flac in this case) for efficient data access.
# ====== Recreating "wav.scp" ======
# A Kaldi wav.scp entry can describe the file path with a unix pipe, like "cat /some/path |",
# but such entries shouldn't be used directly in the training process.
# "format_wav_scp.sh" dumps such pipe-style wav entries to real audio files,
# and it can also change the audio format and sampling rate.
# If nothing needs to be changed, then format_wav_scp.sh does nothing:
# i.e., the input file format and rate are the same as the output.
Note that --nj <N> sets the number of CPU jobs. Please set it appropriately, considering your CPU resources and disk access.
# 25 seconds
!./asr.sh --stage 3 --stop_stage 3 --train_set train_nodev --valid_set train_dev --test_sets "train_dev test" --nj 4
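To see why pipe-style entries are inconvenient for training, here is a rough sketch of how one would have to resolve either kind of wav.scp entry in Python; the dumped entries under dump/raw can be read directly. This is only an illustration, not ESPnet's actual data loader.
# Sketch: resolve a wav.scp entry that is either a plain path or a unix pipe.
import io
import subprocess
import soundfile as sf

def load_wav_scp_entry(value):
    value = value.strip()
    if value.endswith("|"):
        # Pipe-style entry: run the command and read audio from its stdout.
        wav_bytes = subprocess.run(
            value[:-1], shell=True, check=True, capture_output=True
        ).stdout
        return sf.read(io.BytesIO(wav_bytes))
    # Plain path (e.g., a dumped flac file): read it directly.
    return sf.read(value)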
Stage 4: Remove long/short data: dump/raw/org -> dump/raw
Utterances that are too long or too short are harmful for efficient training, so they are removed from the training data. For inference and scoring, however, we still use the full data, which is important for a fair comparison.
!./asr.sh --stage 4 --stop_stage 4 --train_set train_nodev --valid_set train_dev --test_sets "train_dev test"
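A minimal sketch of the underlying idea: compute each utterance's duration and drop outliers. The thresholds below are arbitrary examples and not the exact criterion used by asr.sh.
# Sketch: filter utterances by duration (illustration only, assumes soundfile).
import soundfile as sf

def filter_by_duration(wav_paths, min_sec=0.1, max_sec=20.0):
    kept = []
    for path in wav_paths:
        info = sf.info(path)
        duration = info.frames / info.samplerate
        if min_sec <= duration <= max_sec:
            kept.append(path)
    return kept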
Stage 5: Generate token_list from dump/raw/train_nodev/text using BPE.
This is important for text processing. Here, we simply build a dictionary from the English characters, using the sentencepiece toolkit developed by Google.
!./asr.sh --stage 5 --stop_stage 5 --train_set train_nodev --valid_set train_dev --test_sets "train_dev test"
Let's check the content of the dictionary. There are several special symbols, e.g.,
- <blank>: blank symbol used for CTC
- <unk>: unknown symbols that do not appear in the training data
- <sos/eos>: start- and end-of-sentence symbols
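The Python sketch below shows how sentencepiece builds and applies such a token list, roughly mirroring what stage 5 does; the file names are placeholders and the actual options are set by the recipe.
# Sketch: train a tiny sentencepiece model and tokenize a transcript.
import sentencepiece as spm

# "train_text.txt" is a placeholder: one transcript per line, without utterance IDs.
spm.SentencePieceTrainer.train(
    input="train_text.txt", model_prefix="bpe_demo", vocab_size=30, model_type="unigram"
)
sp = spm.SentencePieceProcessor(model_file="bpe_demo.model")
print(sp.encode("YES NO STOP", out_type=str))  # list of subword tokens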
✅ Checkpoint 2 (1 point)
Print the generated token list.
!cat data/token_list/bpe_unigram30/tokens.txt
# NOTE: Checkpoint 2
print_date_and_time()
Language modeling (skipped in this tutorial)
Stages 6--9: Stages related to language modeling.
We skip the language modeling part in the recipe (stages 6 -- 9) in this tutorial.
End-to-end ASR
Before training, we need to set the training configs, including the front-end, encoder, decoder, optimizer, scheduler, SpecAug, etc. These configs are usually specified in .yaml files: /content/espnet/egs2/an4/asr1/conf/train_asr_xxx.yaml, but you can also overwrite them on the command line.
In this example, we will train a small Transformer model. ESPnet also supports other encoder types, such as RNN (LSTM/GRU), Conformer, and Branchformer.
Please do the following:
- Create a new config file train_asr_demo_transformer.yaml in the directory /content/espnet/egs2/an4/asr1/conf/
- Copy the following lines into the new config
batch_type: folded
batch_size: 64
accum_grad: 1            # gradient accumulation steps
max_epoch: 100
patience: none
init: xavier_uniform
best_model_criterion:    # criterion to save best models
-   - valid
    - acc
    - max
keep_nbest_models: 10    # save nbest models and average these checkpoints
use_amp: true            # whether to use automatic mixed precision
num_att_plot: 0          # do not save attention plots to save time in the demo
num_workers: 2           # number of workers in dataloader

encoder: transformer
encoder_conf:
    output_size: 256
    attention_heads: 4
    linear_units: 1024
    num_blocks: 12
    dropout_rate: 0.1
    positional_dropout_rate: 0.1
    attention_dropout_rate: 0.1
    input_layer: conv2d
    normalize_before: true

decoder: transformer
decoder_conf:
    attention_heads: 4
    linear_units: 1024
    num_blocks: 3
    dropout_rate: 0.1
    positional_dropout_rate: 0.1
    self_attention_dropout_rate: 0.1
    src_attention_dropout_rate: 0.1

model_conf:
    ctc_weight: 0.3      # joint CTC/attention training
    lsm_weight: 0.1      # label smoothing weight
    length_normalized_loss: false

optim: adam
optim_conf:
    lr: 0.001
scheduler: warmuplr      # linearly increase and exponentially decrease
scheduler_conf:
    warmup_steps: 800
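For intuition about the warmuplr comment above, here is a small sketch of a Noam-style warmup schedule: the learning rate increases roughly linearly for warmup_steps and then decays with the inverse square root of the step. This mirrors the general shape of ESPnet's WarmupLR scheduler, though the exact implementation may differ.
# Sketch of a Noam-style warmup schedule (general shape only).
def warmup_lr(step, base_lr=0.001, warmup_steps=800):
    step = max(step, 1)
    return base_lr * warmup_steps**0.5 * min(step**-0.5, step * warmup_steps**-1.5)

for s in [1, 400, 800, 1600, 6400]:
    print(s, f"{warmup_lr(s):.6f}")  # peaks around step == warmup_steps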
Stage 10: ASR collect stats: train_set=dump/raw/train_nodev, valid_set=dump/raw/train_dev
- We estimate the mean and variance of the features to normalize the data.
- We also collect the input and output lengths for efficient mini-batch creation.
ESPnet supports various methods to create mini-batches; please refer to https://espnet.github.io/espnet/espnet2_training_option.html#change-mini-batch-type.
# 15 seconds
!./asr.sh --stage 10 --stop_stage 10 --train_set train_nodev --valid_set train_dev --test_sets "train_dev test" --nj 4 --asr_config conf/train_asr_demo_transformer.yaml
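As a rough sketch of why the collected lengths matter, the snippet below sorts utterances by length and groups them into batches of similar-length utterances, which is the basic idea behind length-aware batch types such as folded; ESPnet's actual batch samplers have many more options.
# Sketch: length-aware batching (not ESPnet's actual batch sampler).
def make_length_sorted_batches(utt2len, batch_size=2):
    # utt2len maps utterance ID -> length (e.g., number of samples or frames).
    sorted_utts = sorted(utt2len, key=utt2len.get)
    return [
        sorted_utts[i : i + batch_size]
        for i in range(0, len(sorted_utts), batch_size)
    ]

# Similar-length utterances in the same batch reduce wasted padding.
print(make_length_sorted_batches({"a": 120, "b": 80, "c": 500, "d": 90}))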
Stage 11: ASR Training: train_set=dump/raw/train_nodev, valid_set=dump/raw/train_dev
This is the main training loop.
During training, please monitor the following files
- log file /content/espnet/egs2/an4/asr1/exp/asr_train_asr_demo_transformer_raw_bpe30/train.log
- loss /content/espnet/egs2/an4/asr1/exp/asr_train_asr_demo_transformer_raw_bpe30/images/loss.png
- accuracy /content/espnet/egs2/an4/asr1/exp/asr_train_asr_demo_transformer_raw_bpe30/images/acc.png
Good examples look like this:
However, bad examples (e.g., with a too-large learning rate) look like this:
ESPnet supports tensorboard, so you can monitor the training status with it.
# Load the TensorBoard notebook extension
%load_ext tensorboard
# Launch tensorboard before training
%tensorboard --logdir /content/espnet/egs2/an4/asr1/exp
# It takes 12 minutes with a single T4 GPU.
!./asr.sh --stage 11 --stop_stage 11 --train_set train_nodev --valid_set train_dev --test_sets "train_dev test" --ngpu 1 --asr_config conf/train_asr_demo_transformer.yaml
The training log contains all the information about the current experiment. Please check it to understand the training status. If your training job fails, this file will show the error messages.
✅ Checkpoint 3 (1 point)
Print the training log.
# NOTE: Checkpoint 3
!tail -20 exp/asr_train_asr_demo_transformer_raw_bpe30/train.log
print_date_and_time()
Stage 12: Decoding
- We need to add --use_lm false since we skip the language model.
- --asr_exp exp/asr_train_asr_demo_transformer_raw_bpe30 specifies the experiment directory name.
- --inference_nj <N> specifies the number of inference jobs.
- We can enable GPU decoding by setting --gpu_inference true. Otherwise, CPU is used for decoding.
- --inference_config conf/decode_asr.yaml specifies the decoding config file.
Let's monitor the log /content/espnet/egs2/an4/asr1/exp/asr_train_asr_demo_transformer_raw_bpe30/decode_asr_asr_model_valid.acc.ave/train_dev/logdir/asr_inference.1.log
# It would take 3 minutes with a single T4 GPU.
!./asr.sh --use_lm false --gpu_inference true --inference_nj 1 --stage 12 --stop_stage 12 --train_set train_nodev --valid_set train_dev --test_sets "train_dev test" --asr_exp exp/asr_train_asr_demo_transformer_raw_bpe30 --inference_config conf/decode_asr.yaml
Stage 13: Scoring
You can find word error rate (WER), character error rate (CER), etc. for each test set.
!./asr.sh --stage 13 --stop_stage 13 --train_set train_nodev --valid_set train_dev --test_sets "train_dev test" --use_lm false --asr_exp exp/asr_train_asr_demo_transformer_raw_bpe30 --inference_config conf/decode_asr.yaml
You can also check the breakdown of the word error rate in /content/espnet/egs2/an4/asr1/exp/asr_train_asr_demo_transformer_raw_bpe30/decode_asr_asr_model_valid.acc.ave/train_dev/score_wer/result.txt
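For reference, WER and CER are edit distances between the hypothesis and the reference, normalized by the reference length. Here is a small self-contained sketch of the computation; it is not the scoring tool invoked by stage 13.
# Sketch: word error rate via Levenshtein distance (illustration only).
def edit_distance(ref, hyp):
    d = [[0] * (len(hyp) + 1) for _ in range(len(ref) + 1)]
    for i in range(len(ref) + 1):
        d[i][0] = i
    for j in range(len(hyp) + 1):
        d[0][j] = j
    for i in range(1, len(ref) + 1):
        for j in range(1, len(hyp) + 1):
            cost = 0 if ref[i - 1] == hyp[j - 1] else 1
            d[i][j] = min(d[i - 1][j] + 1, d[i][j - 1] + 1, d[i - 1][j - 1] + cost)
    return d[-1][-1]

def wer(ref_text, hyp_text):
    ref, hyp = ref_text.split(), hyp_text.split()
    return edit_distance(ref, hyp) / len(ref)

print(wer("GO TO SEVEN THREE", "GO TWO SEVEN THREE"))  # 1 substitution / 4 words = 0.25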
✅ Checkpoint 4 (1 point)
Print the scoring results.
# NOTE: Checkpoint 4
!cat exp/asr_train_asr_demo_transformer_raw_bpe30/RESULTS.md
print_date_and_time()
❗Checkpoint Submission: If you have completed all the four checkpoints, please print your notebook and submit it to Gradescope.
How to change the configs?
Let's revisit the configs, since this is probably the most important part for improving performance.
Config file based
All training options can be changed via the config file.
Please check https://espnet.github.io/espnet/espnet2_training_option.html
Let's first check the config files prepared in the an4 recipe:
- LSTM-based E2E ASR /content/espnet/egs2/an4/asr1/conf/train_asr_rnn.yaml
- Transformer based E2E ASR /content/espnet/egs2/an4/asr1/conf/train_asr_transformer.yaml
You can run
RNN
./asr.sh --stage 10 \
--train_set train_nodev \
--valid_set train_dev \
--test_sets "train_dev test" \
--nj 4 \
--inference_nj 4 \
--use_lm false \
--asr_config conf/train_asr_rnn.yaml
Transformer
./asr.sh --stage 10 \
--train_set train_nodev \
--valid_set train_dev \
--test_sets "train_dev test" \
--nj 4 \
--inference_nj 4 \
--use_lm false \
--asr_config conf/train_asr_transformer.yaml
You can also find various configs in other recipes under espnet/egs2/*/asr1/conf/, including
- Conformer
egs2/librispeech/asr1/conf/tuning/train_asr_conformer10_hop_length160.yaml
- Branchformer
egs2/librispeech/asr1/conf/tuning/train_asr_branchformer_hop_length160_e18_linear3072.yaml
Command line argument based
You can also customize it by passing the command line arguments, e.g.,
./run.sh --stage 10 --asr_args "--model_conf ctc_weight=0.3"
./run.sh --stage 10 --asr_args "--optim_conf lr=0.1"
This approach has the highest priority: arguments passed on the command line overwrite those defined in the config file. This is convenient if you only want to change a few arguments.
Please refer to https://espnet.github.io/espnet/espnet2_tutorial.html#change-the-configuration-for-training for more details.
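Conceptually, the command-line arguments are merged on top of the YAML config, roughly as in the sketch below; ESPnet's actual argument parsing is more involved.
# Sketch: command-line overrides take priority over the YAML config.
import yaml

with open("conf/train_asr_demo_transformer.yaml") as f:
    config = yaml.safe_load(f)

# e.g., --asr_args "--optim_conf lr=0.1" conceptually becomes:
overrides = {"optim_conf": {"lr": 0.1}}
for key, value in overrides.items():
    if isinstance(value, dict) and isinstance(config.get(key), dict):
        config[key].update(value)  # merge nested dicts
    else:
        config[key] = value  # replace scalar values
print(config["optim_conf"])  # {'lr': 0.1}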
📗 Exercise 1 (1 point bonus)
Run training, inference, and scoring on AN4 using a new config. If you achieve a better character error rate (CER) than the following one, you get a bonus point. I suggest tuning the total number of epochs, learning rate, warmup steps, or data augmentation (speed perturbation, SpecAug) to improve the result.
This AN4 dataset is very small, so the result can be unstable even for the same configs. Generally, we should compare different methods using a large dataset such as LibriSpeech 960h.
Here is an example config using Branchformer (Peng et al., ICML 2022). Only the encoder is changed; everything else is identical.
Similarly, we create a config file named train_asr_demo_branchformer.yaml and start training.
encoder: branchformer
encoder_conf:
    output_size: 256
    use_attn: true
    attention_heads: 4
    attention_layer_type: rel_selfattn
    pos_enc_layer_type: rel_pos
    rel_pos_type: latest
    use_cgmlp: true
    cgmlp_linear_units: 1024
    cgmlp_conv_kernel: 31
    use_linear_after_conv: false
    gate_activation: identity
    merge_method: concat
    cgmlp_weight: 0.5            # used only if merge_method is "fixed_ave"
    attn_branch_drop_rate: 0.0   # used only if merge_method is "learned_ave"
    num_blocks: 12
    dropout_rate: 0.1
    positional_dropout_rate: 0.1
    attention_dropout_rate: 0.1
    input_layer: conv2d
    stochastic_depth_rate: 0.0
My result is shown below:
## asr_train_asr_demo_branchformer_raw_bpe30
### WER
|dataset|Snt|Wrd|Corr|Sub|Del|Ins|Err|S.Err|
|---|---|---|---|---|---|---|---|---|
|decode_asr_asr_model_valid.acc.ave/test|130|773|85.4|11.3|3.4|0.4|15.0|46.9|
|decode_asr_asr_model_valid.acc.ave/train_dev|100|591|77.3|15.7|6.9|0.7|23.4|62.0|
### CER
|dataset|Snt|Wrd|Corr|Sub|Del|Ins|Err|S.Err|
|---|---|---|---|---|---|---|---|---|
|decode_asr_asr_model_valid.acc.ave/test|130|2565|93.5|2.9|3.6|1.2|7.8|46.9|
|decode_asr_asr_model_valid.acc.ave/train_dev|100|1915|87.8|5.0|7.2|1.8|14.0|62.0|
### TER
|dataset|Snt|Wrd|Corr|Sub|Del|Ins|Err|S.Err|
|---|---|---|---|---|---|---|---|---|
|decode_asr_asr_model_valid.acc.ave/test|130|2695|93.8|2.8|3.4|1.2|7.4|46.9|
|decode_asr_asr_model_valid.acc.ave/train_dev|100|2015|88.4|4.8|6.8|1.7|13.3|62.0|
# Run multiple stages
!./asr.sh --stage 10 --stop_stage 13 --train_set train_nodev --valid_set train_dev --test_sets "train_dev test" --nj 4 --ngpu 1 --use_lm false --gpu_inference true --inference_nj 1 --asr_config conf/train_asr_demo_branchformer.yaml --inference_config conf/decode_asr.yaml
# NOTE: Exercise 1 Result 1
!scripts/utils/show_asr_result.sh exp
print_date_and_time()
# NOTE: Exercise 1 Result 2
from IPython.display import Image, display
display(
Image('exp/asr_train_asr_demo_transformer_raw_bpe30/images/acc.png', width=400),
Image('exp/asr_train_asr_demo_branchformer_raw_bpe30/images/acc.png', width=400),
)
print_date_and_time()
Make a new recipe
Please carefully read the documentation at https://github.com/espnet/espnet/tree/master/egs2/TEMPLATE to understand how to create a new recipe. The major part is preparing the data directory, which organizes the processed data in the Kaldi format. The other parts can be accomplished by executing shared scripts.
📗 Exercise 2 (1 point)
We use the TIDIGITS dataset and select a subset of it, containing around 2000 utterances for training and another 2000 for testing. Note that this dataset is provided only for private use. Please do not share it elsewhere.
You need to finish the entire pipeline from data preparation to scoring.
1. Create a new directory
We need to create a new directory in egs2/ for the new dataset. The setup script automatically creates several other files and directories:
asr.sh cmd.sh conf db.sh local path.sh pyscripts scripts steps utils
%cd /content/espnet
!egs2/TEMPLATE/asr1/setup.sh egs2/tidigits/asr1
%cd egs2/tidigits/asr1
!ls
2. Download data
Please download the compressed dataset from Google Drive (removed in this public version) and decompress it. Then, specify the absolute path to the dataset in db.sh as follows:
...
TIDIGITS=/content/espnet/egs2/tidigits/asr1/downloads/TIDIGITS_children_boy # our newly added path
...
!mkdir downloads
%cd downloads
!gdown <FILE_ID>
!tar -xzf TIDIGITS_children_boy.tar.gz && ls TIDIGITS_children_boy
%cd /content/espnet/egs2/tidigits/asr1
!echo "" >> db.sh
!echo "TIDIGITS=/content/espnet/egs2/tidigits/asr1/downloads/TIDIGITS_children_boy" >> db.sh
3. Finish the script for data preparation
In Stage 1, asr.sh calls local/data.sh to prepare the dataset in the Kaldi format, so you need to do the following.
- Create a file local/data.sh with the following content. This is task specific, but here we have prepared it for you.
#!/usr/bin/env bash
# Set bash to 'debug' mode, it will exit on :
# -e 'error', -u 'undefined variable', -o ... 'error in pipeline', -x 'print commands',
set -e
set -u
set -o pipefail

log() {
    local fname=${BASH_SOURCE[1]##*/}
    echo -e "$(date '+%Y-%m-%dT%H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
}
SECONDS=0

stage=1
stop_stage=100

log "$0 $*"
. utils/parse_options.sh

. ./db.sh
. ./path.sh
. ./cmd.sh

if [ $# -ne 0 ]; then
    log "Error: No positional arguments are required."
    exit 2
fi

if [ -z "${TIDIGITS}" ]; then
    log "Fill the value of 'TIDIGITS' of db.sh"
    exit 1
fi

train_set="train_nodev"
train_dev="train_dev"
ndev_utt=200

if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
    log "stage 1: Data preparation"
    mkdir -p data/{train,test}

    if [ ! -f ${TIDIGITS}/readme.1st ]; then
        echo "Cannot find TIDIGITS root! Exiting..."
        exit 1
    fi

    # Prepare data in the Kaldi format, including three files:
    # text, wav.scp, utt2spk
    python3 local/data_prep.py ${TIDIGITS} sph2pipe

    for x in test train; do
        for f in text wav.scp utt2spk; do
            sort data/${x}/${f} -o data/${x}/${f}
        done
        utils/utt2spk_to_spk2utt.pl data/${x}/utt2spk > "data/${x}/spk2utt"
    done

    # make a dev set
    utils/subset_data_dir.sh --first data/train "${ndev_utt}" "data/${train_dev}"
    n=$(($(wc -l < data/train/text) - ndev_utt))
    utils/subset_data_dir.sh --last data/train "${n}" "data/${train_set}"
fi

log "Successfully finished. [elapsed=${SECONDS}s]"
- Create a Python script local/data_prep.py, which is used by the previous shell script. Similarly, we have provided most of the code for you. You only need to finish a few lines of code (see the TODO tag).
import os
import glob
import sys

if __name__ == "__main__":
    if len(sys.argv) != 3:
        print("Usage: python data_prep.py [root] [sph2pipe]")
        sys.exit(1)
    root = sys.argv[1]
    sph2pipe = sys.argv[2]

    for x in ["train", "test"]:
        # We only use the data from boy children
        all_audio_list = glob.glob(
            os.path.join(root, "data", "children", x, "boy", "*", "*.wav")
        )
        with open(os.path.join("data", x, "text"), "w") as text_f, open(
            os.path.join("data", x, "wav.scp"), "w"
        ) as wav_scp_f, open(
            os.path.join("data", x, "utt2spk"), "w"
        ) as utt2spk_f:
            for audio_path in all_audio_list:
                filename = os.path.basename(audio_path)  # "o73a.wav" etc.
                speaker = os.path.basename(os.path.dirname(audio_path))  # "lc", "sk", etc.
                transcript = " ".join(list(filename[:-5]))  # "o73" -> "o 7 3"
                uttid = f"{speaker}-{filename[:-4]}"  # "sk-o73a"
                wav_scp_f.write(
                    f"{uttid} {sph2pipe} -f wav -p -c 1 {audio_path} |\n"
                )
                ### TODO: write the other files in the Kaldi format
!touch local/data.sh && chmod +x local/data.sh
!touch local/data_prep.py
## TODO: copy the script and finish it
4. Create a script as the entry point
Now let's create a shell script run.sh as the entry point. You can directly use the following one.
What are the differences compared to the AN4 recipe?
- We use words as modeling units.
- We pass decoding arguments on the command line (see inference_args).
- We perform data preparation, training, decoding, and scoring in this single job.
#!/usr/bin/env bash
# Set bash to 'debug' mode, it will exit on :
# -e 'error', -u 'undefined variable', -o ... 'error in pipeline', -x 'print commands',
set -e
set -u
set -o pipefail
./asr.sh \
    --stage 1 \
    --stop_stage 13 \
    --nj 4 \
    --ngpu 1 \
    --gpu_inference true \
    --inference_nj 1 \
    --use_lm false \
    --lang en \
    --token_type word \
    --asr_config conf/train_asr_demo_branchformer.yaml \
    --inference_args "--beam_size 10 --ctc_weight 0.3" \
    --train_set train_nodev \
    --valid_set train_dev \
    --test_sets "train_dev test" \
    --bpe_train_text "data/train_nodev/text" \
    --lm_train_text "data/train_nodev/text" "$@"
!touch run.sh && chmod +x run.sh
## TODO: copy and paste the code
5. Create the training config
We can use the previous Branchformer config train_asr_demo_branchformer.yaml; it should work well.
Note that we pass the decoding configs on the command line instead of using a separate file.
!touch conf/train_asr_demo_branchformer.yaml
## TODO: copy and paste the config from an4
6. Execute the script
Now we are ready to launch the job, running Stages 1 through 13.
Again, we can monitor the status using TensorBoard.
# Load the TensorBoard notebook extension
%load_ext tensorboard
# Launch tensorboard before training
%tensorboard --logdir /content/espnet/egs2/tidigits/asr1/exp
# It takes 30 minutes with a single T4 GPU
!./run.sh
7. Print the results
We print the results and display the figure. Please modify the path based on your situation.
The example output looks like the following:
## asr_train_asr_demo_branchformer_raw_en_word
### WER
|dataset|Snt|Wrd|Corr|Sub|Del|Ins|Err|S.Err|
|---|---|---|---|---|---|---|---|---|
|inference_beam_size10_ctc_weight0.3_asr_model_valid.acc.ave/test|1925|6325|98.4|1.2|0.3|0.3|1.9|5.9|
|inference_beam_size10_ctc_weight0.3_asr_model_valid.acc.ave/train_dev|200|665|99.8|0.2|0.0|0.3|0.5|1.5|
### CER
|dataset|Snt|Wrd|Corr|Sub|Del|Ins|Err|S.Err|
|---|---|---|---|---|---|---|---|---|
|inference_beam_size10_ctc_weight0.3_asr_model_valid.acc.ave/test|1925|10725|98.9|0.7|0.4|0.4|1.5|5.9|
|inference_beam_size10_ctc_weight0.3_asr_model_valid.acc.ave/train_dev|200|1130|99.9|0.1|0.0|0.4|0.4|1.5|
### TER
|dataset|Snt|Wrd|Corr|Sub|Del|Ins|Err|S.Err|
|---|---|---|---|---|---|---|---|---|
# NOTE: Exercise 2
# Remember to modify the path before running this cell!
!cat exp/asr_train_asr_demo_branchformer_raw_en_word/RESULTS.md
from IPython.display import Image, display
display(Image('exp/asr_train_asr_demo_branchformer_raw_en_word/images/acc.png', width=400))
print_date_and_time()
Additional resources
Finally, we list some resources for more advanced usage of ESPnet. Please check them after the tutorial.
Fine-tune a pre-trained model
We can load pre-trained weights into a new model and then fine-tune it on our target dataset. There is a separate tutorial about fine-tuning: https://espnet.github.io/espnet/notebook/espnet2_asr_transfer_learning_demo.html
Notes:
- Here, I do not use the phrase "transfer learning", because transfer learning is a broad research area; fine-tuning is just one type of transfer learning. However, these concepts are usually not distinguished in the current documentation, so do not get confused by the terminology.
- There are other ways to pass the pre-trained model path, which require modifications of the script asr.sh. For example, you can comment out pretrained_model and ignore_init_mismatch in asr.sh and set them in the training config instead.
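Conceptually, loading pre-trained weights while ignoring mismatched layers looks like the PyTorch sketch below. This only illustrates the idea behind options such as pretrained_model and ignore_init_mismatch; it is not ESPnet's actual code, and the model and checkpoint path are placeholders.
# Sketch: initialize a new model from a pre-trained checkpoint, skipping mismatched keys.
# Assumes the checkpoint file stores a plain state_dict.
import torch

def load_pretrained(model, ckpt_path):
    pretrained = torch.load(ckpt_path, map_location="cpu")
    own_state = model.state_dict()
    # Keep only parameters whose name and shape match the new model.
    filtered = {
        k: v for k, v in pretrained.items()
        if k in own_state and own_state[k].shape == v.shape
    }
    own_state.update(filtered)
    model.load_state_dict(own_state)
    print(f"Loaded {len(filtered)}/{len(own_state)} parameters from {ckpt_path}")

# Example (placeholders): load_pretrained(my_asr_model, "exp/.../valid.acc.ave.pth")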
Use self-supervised pre-trained models as the front-end
In addition to the log Mel filterbank features computed by signal processing algorithms, ESPnet also supports self-supervised speech representations extracted from large pre-trained models such as wav2vec 2.0 and HuBERT. These models are pre-trained on large amounts of unlabeled audio data, so the representations are very powerful. This approach is especially useful for low-resource applications.
Please check recipes that use self-supervised pre-trained models, e.g., https://github.com/espnet/espnet/tree/master/egs2/librispeech/asr1#self-supervised-learning-features-wavlm_large-conformer-utt_mvn-with-transformer-lm
Contribute to ESPnet
Please follow https://github.com/espnet/espnet/blob/master/CONTRIBUTING.md to upload your pre-trained model to Hugging Face and make a pull request in the ESPnet repository.