ESPnet Speech Enhancement Demonstration
This notebook provides a demonstration of speech enhancement and separation using ESPnet2-SE.
- ESPnet2-SE: https://github.com/espnet/espnet/tree/master/egs2/TEMPLATE/enh1
Presenters:
- Shinji Watanabe (shinjiw@cmu.edu)
- Chenda Li (lichenda1996@sjtu.edu.cn)
- Jing Shi (shijing2014@ia.ac.cn)
- Wangyou Zhang (wyz-97@sjtu.edu.cn)
- Yen-Ju Lu (neil.lu@citi.sinica.edu.tw)
This notebook is created by: Chenda Li (@LiChenda) and Wangyou Zhang (@Emrys365)
Contents
(1) Tutorials on the Basic Usage
1. Install
2. Speech Enhancement with Pretrained Models
   We support various interfaces, e.g., Python API, HuggingFace API, and portable speech enhancement scripts for other tasks.
   2.1 Single-channel Enhancement (CHiME-4)
   2.2 Enhance Your Own Recordings
   2.3 Multi-channel Enhancement (CHiME-4)
3. Speech Separation with Pretrained Models
   3.1 Model Selection
   3.2 Separate Speech Mixture
4. Evaluate Separated Speech with the Pretrained ASR Model
(2) Tutorials for Adding a New Recipe and Contributing to the ESPnet-SE Project
1. Creating a New Recipe
2. Implementing a New Speech Enhancement/Separation Model
(1) Tutorials on the Basic Usage
Install
%pip install -q espnet==0.10.1
%pip install -q espnet_model_zoo
Speech Enhancement with Pretrained Models
Single-Channel Enhancement, the CHiME example
# Download one utterance from real noisy speech of CHiME4
!gdown --id 1SmrN5NFSg6JuQSs2sfy3ehD8OIcqK6wS -O /content/M05_440C0213_PED_REAL.wav
import os
import soundfile
from IPython.display import display, Audio
mixwav_mc, sr = soundfile.read("/content/M05_440C0213_PED_REAL.wav")
# mixwav_mc.shape: (num_samples, num_channels)
mixwav_sc = mixwav_mc[:,4]
display(Audio(mixwav_mc.T, rate=sr))
Download and load the pretrained Conv-TasNet
!gdown --id 17DMWdw84wF3fz3t7ia1zssdzhkpVQGZm -O /content/chime_tasnet_singlechannel.zip
!unzip /content/chime_tasnet_singlechannel.zip -d /content/enh_model_sc
# Load the model
# If you encounter the error "No module named 'espnet2'", please re-run the first cell. This might be a Colab bug.
import sys
import soundfile
from espnet2.bin.enh_inference import SeparateSpeech
separate_speech = {}
# For models downloaded from Google Drive, you can use the following script:
enh_model_sc = SeparateSpeech(
train_config="/content/enh_model_sc/exp/enh_train_enh_conv_tasnet_raw/config.yaml",
model_file="/content/enh_model_sc/exp/enh_train_enh_conv_tasnet_raw/5epoch.pth",
# for segment-wise process on long speech
normalize_segment_scale=False,
show_progressbar=True,
ref_channel=4,
normalize_output_wav=True,
device="cuda:0",
)
Enhance the single-channel real noisy speech in CHiME-4
# play the enhanced single-channel speech
wave = enh_model_sc(mixwav_sc[None, ...], sr)
print("Input real noisy speech", flush=True)
display(Audio(mixwav_sc, rate=sr))
print("Enhanced speech", flush=True)
display(Audio(wave[0].squeeze(), rate=sr))
Enhance your own recordings
from google.colab import files
from IPython.display import display, Audio
import soundfile
uploaded = files.upload()
for file_name in uploaded.keys():
    speech, rate = soundfile.read(file_name)
    assert rate == sr, "mismatch in sampling rate"
    wave = enh_model_sc(speech[None, ...], sr)
    print(f"Your input speech {file_name}", flush=True)
    display(Audio(speech, rate=sr))
    print(f"Enhanced speech for {file_name}", flush=True)
    display(Audio(wave[0].squeeze(), rate=sr))
Multi-Channel Enhancement
Download and load the pretrained MVDR neural beamformer.
# Download the pretrained enhancement model
!gdown --id 1FohDfBlOa7ipc9v2luY-QIFQ_GJ1iW_i -O /content/mvdr_beamformer_16k_se_raw_valid.zip
!unzip /content/mvdr_beamformer_16k_se_raw_valid.zip -d /content/enh_model_mc
# Load the model
# If you encounter the error "No module named 'espnet2'", please re-run the first cell. This might be a Colab bug.
import sys
import soundfile
from espnet2.bin.enh_inference import SeparateSpeech
separate_speech = {}
# For models downloaded from Google Drive, you can use the following script:
enh_model_mc = SeparateSpeech(
train_config="/content/enh_model_mc/exp/enh_train_enh_beamformer_mvdr_raw/config.yaml",
model_file="/content/enh_model_mc/exp/enh_train_enh_beamformer_mvdr_raw/11epoch.pth",
# for segment-wise process on long speech
normalize_segment_scale=False,
show_progressbar=True,
ref_channel=4,
normalize_output_wav=True,
device="cuda:0",
)
Enhance the multi-channel real noisy speech in CHiME-4
wave = enh_model_mc(mixwav_mc[None, ...], sr)
print("Input real noisy speech", flush=True)
display(Audio(mixwav_mc.T, rate=sr))
print("Enhanced speech", flush=True)
display(Audio(wave[0].squeeze(), rate=sr))
Portable speech enhancement scripts for other tasks
For an ESPnet ASR or TTS dataset like the one below:
data
`-- et05_real_isolated_6ch_track
|-- spk2utt
|-- text
|-- utt2spk
|-- utt2uniq
`-- wav.scp
Run the following script to create an enhanced dataset:
scripts/utils/enhance_dataset.sh \
--spk_num 1 \
--gpu_inference true \
--inference_nj 4 \
--fs 16k \
--id_prefix "" \
dump/raw/et05_real_isolated_6ch_track \
data/et05_real_isolated_6ch_track_enh \
exp/enh_train_enh_beamformer_mvdr_raw/valid.loss.best.pth
The above script will generate a new directory data/et05_real_isolated_6ch_track_enh:
data
`-- et05_real_isolated_6ch_track_enh
|-- spk2utt
|-- text
|-- utt2spk
|-- utt2uniq
|-- wav.scp
`-- wavs/
where wav.scp contains paths to the enhanced audio files (stored in wavs/).
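The generated wav.scp is a plain two-column text file mapping each utterance ID to an enhanced audio path, so the enhanced data can be inspected directly in Python. Below is a minimal sketch; the path assumes the example directory above:
import soundfile

# Path from the example above; adjust to your setup.
wav_scp = "data/et05_real_isolated_6ch_track_enh/wav.scp"

with open(wav_scp) as f:
    for line in f:
        utt_id, wav_path = line.strip().split(maxsplit=1)
        audio, rate = soundfile.read(wav_path)
        print(utt_id, audio.shape, rate)
        break  # inspect only the first utterance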
Speech Separation
Model Selection
In this demonstration, we will show different speech separation models on wsj0_2mix.
The pretrained models can be downloaded from a direct URL, or from Zenodo and Hugging Face with a model ID.
#@title Choose Speech Separation model { run: "auto" }
fs = 8000 #@param {type:"integer"}
tag = "espnet/Chenda_Li_wsj0_2mix_enh_train_enh_conv_tasnet_raw_valid.si_snr.ave" #@param ["Chenda Li/wsj0_2mix_enh_train_enh_conv_tasnet_raw_valid.si_snr.ave", "Chenda Li/wsj0_2mix_enh_train_enh_rnn_tf_raw_valid.si_snr.ave", "https://zenodo.org/record/4688000/files/enh_train_enh_dprnn_tasnet_raw_valid.si_snr.ave.zip", "espnet/Chenda_Li_wsj0_2mix_enh_train_enh_conv_tasnet_raw_valid.si_snr.ave"]
# For models hosted on Zenodo or Hugging Face, you can use the following Python script instead:
import sys
import soundfile
from espnet_model_zoo.downloader import ModelDownloader
from espnet2.bin.enh_inference import SeparateSpeech
d = ModelDownloader()
cfg = d.download_and_unpack(tag)
separate_speech = SeparateSpeech(
enh_train_config=cfg["train_config"],
enh_model_file=cfg["model_file"],
# for segment-wise process on long speech
segment_size=2.4,
hop_size=0.8,
normalize_segment_scale=False,
show_progressbar=True,
ref_channel=None,
normalize_output_wav=True,
device="cuda:0",
)
Separate Speech Mixture
Separate the example in wsj0_2mix testing set
!gdown --id 1ZCUkd_Lb7pO2rpPr4FqYdtJBZ7JMiInx -O /content/447c020t_1.2106_422a0112_-1.2106.wav
import os
import soundfile
from IPython.display import display, Audio
mixwav, sr = soundfile.read("447c020t_1.2106_422a0112_-1.2106.wav")
waves_wsj = separate_speech(mixwav[None, ...], fs=sr)
print("Input mixture", flush=True)
display(Audio(mixwav, rate=sr))
print(f"========= Separated speech with model {tag} =========", flush=True)
print("Separated spk1", flush=True)
display(Audio(waves_wsj[0].squeeze(), rate=sr))
print("Separated spk2", flush=True)
display(Audio(waves_wsj[1].squeeze(), rate=sr))
Separate your own recordings
from google.colab import files
from IPython.display import display, Audio
import soundfile
uploaded = files.upload()
for file_name in uploaded.keys():
    mixwav_yours, rate = soundfile.read(file_name)
    assert rate == sr, "mismatch in sampling rate"
    waves_yours = separate_speech(mixwav_yours[None, ...], fs=sr)
    print("Input mixture", flush=True)
    display(Audio(mixwav_yours, rate=sr))
    print(f"========= Separated speech with model {tag} =========", flush=True)
    print("Separated spk1", flush=True)
    display(Audio(waves_yours[0].squeeze(), rate=sr))
    print("Separated spk2", flush=True)
    display(Audio(waves_yours[1].squeeze(), rate=sr))
Show spectrograms of the separated speech
import matplotlib.pyplot as plt
import torch
from torch_complex.tensor import ComplexTensor
from espnet.asr.asr_utils import plot_spectrogram
from espnet2.layers.stft import Stft
stft = Stft(
n_fft=512,
win_length=None,
hop_length=128,
window="hann",
)
ilens = torch.LongTensor([len(mixwav)])
# specs: (T, F)
spec_mix = ComplexTensor(
*torch.unbind(
stft(torch.as_tensor(mixwav).unsqueeze(0), ilens)[0].squeeze(),
dim=-1
)
)
spec_sep1 = ComplexTensor(
*torch.unbind(
stft(torch.as_tensor(waves_wsj[0]), ilens)[0].squeeze(),
dim=-1
)
)
spec_sep2 = ComplexTensor(
*torch.unbind(
stft(torch.as_tensor(waves_wsj[1]), ilens)[0].squeeze(),
dim=-1
)
)
# freqs = torch.linspace(0, sr / 2, spec_mix.shape[1])
# frames = torch.linspace(0, len(mixwav) / sr, spec_mix.shape[0])
samples = torch.linspace(0, len(mixwav) / sr, len(mixwav))
plt.figure(figsize=(24, 12))
plt.subplot(3, 2, 1)
plt.title('Mixture Spectrogram')
plot_spectrogram(
plt, abs(spec_mix).transpose(-1, -2).numpy(), fs=sr,
mode='db', frame_shift=None,
bottom=False, labelbottom=False
)
plt.subplot(3, 2, 2)
plt.title('Mixture Waveform')
plt.plot(samples, mixwav)
plt.xlim(0, len(mixwav) / sr)
plt.subplot(3, 2, 3)
plt.title('Separated Spectrogram (spk1)')
plot_spectrogram(
plt, abs(spec_sep1).transpose(-1, -2).numpy(), fs=sr,
mode='db', frame_shift=None,
bottom=False, labelbottom=False
)
plt.subplot(3, 2, 4)
plt.title('Separated Waveform (spk1)')
plt.plot(samples, waves_wsj[0].squeeze())
plt.xlim(0, len(mixwav) / sr)
plt.subplot(3, 2, 5)
plt.title('Separated Spectrogram (spk2)')
plot_spectrogram(
plt, abs(spec_sep2).transpose(-1, -2).numpy(), fs=sr,
mode='db', frame_shift=None,
bottom=False, labelbottom=False
)
plt.subplot(3, 2, 6)
plt.title('Separated Waveform (spk2)')
plt.plot(samples, waves_wsj[1].squeeze())
plt.xlim(0, len(mixwav) / sr)
plt.xlabel("Time (s)")
plt.show()
Evaluate separated speech with a pretrained ASR model
The ground truths are:
text_1: SOME CRITICS INCLUDING HIGH REAGAN ADMINISTRATION OFFICIALS ARE RAISING THE ALARM THAT THE FED'S POLICY IS TOO TIGHT AND COULD CAUSE A RECESSION NEXT YEAR
text_2: THE UNITED STATES UNDERTOOK TO DEFEND WESTERN EUROPE AGAINST SOVIET ATTACK
(The speech recognition may take a while.)
%pip install -q https://github.com/kpu/kenlm/archive/master.zip # ASR needs kenlm
import espnet_model_zoo
from espnet_model_zoo.downloader import ModelDownloader
from espnet2.bin.asr_inference import Speech2Text
wsj_8k_model_url="https://zenodo.org/record/4012264/files/asr_train_asr_transformer_raw_char_1gpu_valid.acc.ave.zip?download=1"
d = ModelDownloader()
speech2text = Speech2Text(
**d.download_and_unpack(wsj_8k_model_url),
device="cuda:0",
)
text_est = [None, None]
text_est[0], *_ = speech2text(waves_wsj[0].squeeze())[0]
text_est[1], *_ = speech2text(waves_wsj[1].squeeze())[0]
text_m, *_ = speech2text(mixwav)[0]
print("Mix Speech to Text: ", text_m)
print("Separated Speech 1 to Text: ", text_est[0])
print("Separated Speech 2 to Text: ", text_est[1])
import difflib
from itertools import permutations
import editdistance
import numpy as np
colors = dict(
red=lambda text: f"\033[38;2;255;0;0m{text}\033[0m" if text else "",
green=lambda text: f"\033[38;2;0;255;0m{text}\033[0m" if text else "",
yellow=lambda text: f"\033[38;2;225;225;0m{text}\033[0m" if text else "",
white=lambda text: f"\033[38;2;255;255;255m{text}\033[0m" if text else "",
black=lambda text: f"\033[38;2;0;0;0m{text}\033[0m" if text else "",
)
def diff_strings(ref, est):
    """Reference: https://stackoverflow.com/a/64404008/7384873"""
    ref_str, est_str, err_str = [], [], []
    matcher = difflib.SequenceMatcher(None, ref, est)
    for opcode, a0, a1, b0, b1 in matcher.get_opcodes():
        if opcode == "equal":
            txt = ref[a0:a1]
            ref_str.append(txt)
            est_str.append(txt)
            err_str.append(" " * (a1 - a0))
        elif opcode == "insert":
            ref_str.append("*" * (b1 - b0))
            est_str.append(colors["green"](est[b0:b1]))
            err_str.append(colors["black"]("I" * (b1 - b0)))
        elif opcode == "delete":
            ref_str.append(ref[a0:a1])
            est_str.append(colors["red"]("*" * (a1 - a0)))
            err_str.append(colors["black"]("D" * (a1 - a0)))
        elif opcode == "replace":
            diff = a1 - a0 - b1 + b0
            if diff >= 0:
                txt_ref = ref[a0:a1]
                txt_est = colors["yellow"](est[b0:b1]) + colors["red"]("*" * diff)
                txt_err = "S" * (b1 - b0) + "D" * diff
            elif diff < 0:
                txt_ref = ref[a0:a1] + "*" * -diff
                txt_est = colors["yellow"](est[b0:b1]) + colors["green"]("*" * -diff)
                txt_err = "S" * (b1 - b0) + "I" * -diff
            ref_str.append(txt_ref)
            est_str.append(txt_est)
            err_str.append(colors["black"](txt_err))
    return "".join(ref_str), "".join(est_str), "".join(err_str)
text_ref = [
"SOME CRITICS INCLUDING HIGH REAGAN ADMINISTRATION OFFICIALS ARE RAISING THE ALARM THAT THE FED'S POLICY IS TOO TIGHT AND COULD CAUSE A RECESSION NEXT YEAR",
"THE UNITED STATES UNDERTOOK TO DEFEND WESTERN EUROPE AGAINST SOVIET ATTACK",
]
print("=====================" , flush=True)
perms = list(permutations(range(2)))
string_edit = [
[
editdistance.eval(text_ref[m], text_est[n])
for m, n in enumerate(p)
]
for p in perms
]
dist = [sum(edist) for edist in string_edit]
perm_idx = np.argmin(dist)
perm = perms[perm_idx]
for i, p in enumerate(perm):
    print("\n--------------- Text %d ---------------" % (i + 1), flush=True)
    ref, est, err = diff_strings(text_ref[i], text_est[p])
    print("REF: " + ref + "\n" + "HYP: " + est + "\n" + "ERR: " + err, flush=True)
    print("Edit Distance = {}\n".format(string_edit[perm_idx][i]), flush=True)
(2) Tutorials on Contributing to the ESPnet-SE Project
If you would like to contribute to the ESPnet-SE project, or to make modifications based on the current speech enhancement/separation functionality, the following tutorials provide detailed information about how to create new recipes or new models in ESPnet-SE.
Creating a New Recipe
Step 1 Create recipe directory
First, run the following command to create the directory for the new recipe from our template:
egs2/TEMPLATE/enh1/setup.sh egs2/<your-recipe-name>/enh1
For the following steps, we assume the operations are done under the directory egs2/<your-recipe-name>/enh1/.
Step 2 Write scripts for data preparation
Prepare local/data.sh, which will be used in stage 1 of enh.sh. It can take some arguments as input; see egs2/wsj0_2mix/enh1/local/data.sh for reference.
The script local/data.sh should finally generate Kaldi-style data directories under <recipe-dir>/data/. Each subset directory should contain at least 4 files:
<recipe-dir>/data/<subset-name>/
├── spk{1,2,3...}.scp (clean speech references)
├── spk2utt
├── utt2spk
└── wav.scp (noisy speech)
Optionally, it can also contain noise{}.scp and dereverb{}.scp, which point to the corresponding noise and dereverberated references, respectively. {} can be 1, 2, ..., depending on the number of noise types (dereverberated signals) in the input signal in wav.scp.
Make sure to sort the scp and other related files as in Kaldi. Also, remember to run . ./path.sh in local/data.sh before sorting, because it forces sorting to be byte-wise, i.e. export LC_ALL=C.
Remember to check your new scripts with shellcheck, otherwise they may fail the tests in ci/test_shell.sh.
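In practice local/data.sh is a shell script, but the file formats are easiest to illustrate with a short Python sketch. All utterance IDs and paths below are hypothetical, and only the single-speaker files are written:
import os

def write_kaldi_dir(out_dir, utts):
    """Write minimal Kaldi-style files. utts: utt_id -> (noisy_path, clean_path, spk_id)."""
    os.makedirs(out_dir, exist_ok=True)
    # Byte-wise (LC_ALL=C) order corresponds to plain sorted() for ASCII utterance IDs.
    utt_ids = sorted(utts)
    with open(os.path.join(out_dir, "wav.scp"), "w") as f_wav, \
         open(os.path.join(out_dir, "spk1.scp"), "w") as f_spk1, \
         open(os.path.join(out_dir, "utt2spk"), "w") as f_u2s:
        for utt_id in utt_ids:
            noisy, clean, spk = utts[utt_id]
            f_wav.write(f"{utt_id} {noisy}\n")
            f_spk1.write(f"{utt_id} {clean}\n")
            f_u2s.write(f"{utt_id} {spk}\n")
    # spk2utt groups utterance IDs by speaker.
    spk2utt = {}
    for utt_id in utt_ids:
        spk2utt.setdefault(utts[utt_id][2], []).append(utt_id)
    with open(os.path.join(out_dir, "spk2utt"), "w") as f_s2u:
        for spk in sorted(spk2utt):
            f_s2u.write(f"{spk} {' '.join(spk2utt[spk])}\n")

write_kaldi_dir(
    "data/tr",
    {"spk001_utt01": ("/path/to/noisy_utt01.wav", "/path/to/clean_utt01.wav", "spk001")},
)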
Step 3 Prepare training configuration
Prepare training configuration files (e.g. train.yaml) under conf/.
If you have multiple configuration files, it is recommended to put them under conf/tuning/ and to create a symbolic link conf/tuning/train.yaml pointing to the config file with the best performance.
Step 4 Prepare run.sh
Write run.sh to provide a template entry script, so that users can easily run your recipe via ./run.sh. See egs2/wsj0_2mix/enh1/run.sh for reference.
If your recipe provides references for noise and/or dereverberation, you can add the arguments --use_noise_ref true and/or --use_dereverb_ref true in run.sh.
Implementing a New Speech Enhancement/Separation Model
The current ESPnet-SE tool adopts an encoder-separator-decoder architecture for all models. For example:
For Time-Frequency masking models, the encoder and decoder would be stft_encoder.py and stft_decoder.py respectively, and the separator can be any of dprnn_separator.py, rnn_separator.py, tcn_separator.py, and transformer_separator.py. For TasNet, the encoder and decoder are conv_encoder.py and conv_decoder.py respectively. The separator is tcn_separator.py.
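To make the data flow concrete, here is a simplified sketch of how the three components are composed at inference time. It is illustrative only; the actual ESPnet2 interfaces (argument names, returned tuples) differ in detail:
def enhance(mix, ilens, encoder, separator, decoder):
    """Conceptual encoder-separator-decoder pipeline (not the actual ESPnet2 code).

    mix: (Batch, Samples) time-domain mixture
    ilens: (Batch,) input lengths in samples
    """
    feature, flens = encoder(mix, ilens)            # e.g. STFT frames or learned conv features
    masked, flens, _ = separator(feature, flens)    # list with one masked feature per speaker
    waves = [decoder(m, ilens)[0] for m in masked]  # back to time domain, one waveform per speaker
    return waves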
Step 1 Create model scripts
For encoder, separator, and decoder models, create new scripts under espnet2/enh/encoder/, espnet2/enh/separator/, and espnet2/enh/decoder/, respectively.
For a separator model, please make sure it implements the num_spk property. See espnet2/enh/separator/rnn_separator.py for reference.
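For orientation, a minimal separator might look roughly like the sketch below. The exact base class and forward signature should be copied from an existing separator such as rnn_separator.py; treat the interface here as an approximation:
from collections import OrderedDict

import torch


class MinimalSeparator(torch.nn.Module):  # in ESPnet2, separators derive from AbsSeparator
    def __init__(self, input_dim: int, num_spk: int = 2):
        super().__init__()
        self._num_spk = num_spk
        # One mask per speaker, predicted from the encoded feature
        self.linear = torch.nn.Linear(input_dim, input_dim * num_spk)

    def forward(self, input, ilens, additional=None):
        # input: (Batch, Time, Freq) real-valued feature from the encoder
        masks = torch.sigmoid(self.linear(input)).chunk(self._num_spk, dim=-1)
        masked = [input * m for m in masks]
        others = OrderedDict((f"mask_spk{i + 1}", m) for i, m in enumerate(masks))
        return masked, ilens, others

    @property
    def num_spk(self):
        return self._num_spk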
Remember to format your new scripts to match the styles of black and flake8, otherwise they may fail the tests in ci/test_python.sh.
Step 2 Add the new model to related scripts
In espnet2/tasks/enh.py, add your new model to the corresponding ClassChoices (a rough sketch of this registration pattern is shown after the list), e.g.
- For encoders, add <key>=<your-model> to encoder_choices.
- For decoders, add <key>=<your-model> to decoder_choices.
- For separators, add <key>=<your-model> to separator_choices.
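The registration follows the ClassChoices pattern used throughout ESPnet2. Roughly (abridged, with the new entry shown as a hypothetical my_sep key):
from espnet2.train.class_choices import ClassChoices
from espnet2.enh.separator.abs_separator import AbsSeparator
from espnet2.enh.separator.rnn_separator import RNNSeparator
# from espnet2.enh.separator.my_separator import MySeparator  # your new model (hypothetical)

separator_choices = ClassChoices(
    name="separator",
    classes=dict(
        rnn=RNNSeparator,
        # my_sep=MySeparator,  # <key>=<your-model>
    ),
    type_check=AbsSeparator,
    default="rnn",
)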
Step 3 [Optional] Create new loss functions
If you want to use a new loss function for your model, you can add it to espnet2/enh/espnet_model.py, such as:
@staticmethod
def new_loss(ref, inf):
    """Your new loss
    Args:
        ref: (Batch, samples)
        inf: (Batch, samples)
    Returns:
        loss: (Batch,)
    """
    ...
    return loss
Then add your loss name to ALL_LOSS_TYPES, and handle the loss calculation in _compute_loss.
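For example, a simple time-domain mean-squared-error loss matching the signature above could look like this (purely illustrative; ESPnet-SE already provides several losses such as SI-SNR):
@staticmethod
def mse_loss(ref, inf):
    """Illustrative time-domain MSE loss (to be placed inside the model class).

    Args:
        ref: (Batch, samples)
        inf: (Batch, samples)
    Returns:
        loss: (Batch,)
    """
    return ((ref - inf) ** 2).mean(dim=-1)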
Step 4 Create unit tests for the new model
Finally, it would be nice to make some unit tests for your new model under test/espnet2/enh/encoder, test/espnet2/enh/decoder, or test/espnet2/enh/separator.
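A typical test simply instantiates the model with a small configuration and checks the output shapes, for example (MySeparator and its import path are hypothetical; adapt the names and constructor arguments to your model):
import pytest
import torch

from espnet2.enh.separator.my_separator import MySeparator  # hypothetical new model


@pytest.mark.parametrize("num_spk", [1, 2])
def test_my_separator_forward(num_spk):
    separator = MySeparator(input_dim=17, num_spk=num_spk)
    x = torch.rand(2, 10, 17)  # (Batch, Time, Freq)
    ilens = torch.LongTensor([10, 8])
    masked, flens, others = separator(x, ilens)
    assert len(masked) == num_spk
    assert masked[0].shape == x.shape
    assert separator.num_spk == num_spk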