ESPnet Speech Enhancement Demonstration

Open In Colab

This notebook demonstrates speech enhancement and speech separation using ESPnet2-SE.

Author: Chenda Li ([@LiChenda](https://github.com/LiChenda)), Wangyou Zhang ([@Emrys365](https://github.com/Emrys365))

Install

[ ]:
# Install ESPnet (pinned to 0.10.1) and the model-zoo downloader used by later cells.
# NOTE(review): the spectrogram cell splits the Stft output's last axis into real/imag,
# which presumably depends on this pinned version — confirm before bumping the pin.
%pip install -q espnet==0.10.1
%pip install -q espnet_model_zoo

Speech Enhancement

Single-Channel Enhancement, the CHiME example

[ ]:
# Download one utterance from real noisy speech of CHiME4
!gdown --id 1SmrN5NFSg6JuQSs2sfy3ehD8OIcqK6wS -O /content/M05_440C0213_PED_REAL.wav
import os

import soundfile
from IPython.display import display, Audio
# NOTE: `sr` is rebound by later cells that read other files (e.g. the wsj0_2mix example).
mixwav_mc, sr = soundfile.read("/content/M05_440C0213_PED_REAL.wav")
# mixwav_mc.shape: (num_samples, num_channels)
# Column index 4 (the fifth microphone) is the input for the single-channel model below.
# NOTE(review): presumably chosen as CHiME4's CH5 reference mic — confirm.
mixwav_sc = mixwav_mc[:,4]
# Transpose to (num_channels, num_samples) for multi-channel playback.
display(Audio(mixwav_mc.T, rate=sr))

Download and load the pretrained Conv-TasNet model

[ ]:
!gdown --id 17DMWdw84wF3fz3t7ia1zssdzhkpVQGZm -O /content/chime_tasnet_singlechannel.zip
!unzip /content/chime_tasnet_singlechannel.zip -d /content/enh_model_sc
[ ]:
# Load the model
# If you encounter error "No module named 'espnet2'", please re-run the 1st Cell. This might be a colab bug.
import sys
import soundfile
import torch
from espnet2.bin.enh_inference import SeparateSpeech


# For models downloaded from GoogleDrive, you can use the following script:
enh_model_sc = SeparateSpeech(
  enh_train_config="/content/enh_model_sc/exp/enh_train_enh_conv_tasnet_raw/config.yaml",
  enh_model_file="/content/enh_model_sc/exp/enh_train_enh_conv_tasnet_raw/5epoch.pth",
  # for segment-wise process on long speech
  normalize_segment_scale=False,
  show_progressbar=True,
  ref_channel=4,
  normalize_output_wav=True,
  # Use the GPU when one is attached; otherwise fall back to CPU (slower)
  # instead of failing on a CPU-only runtime.
  device="cuda:0" if torch.cuda.is_available() else "cpu",
)

Enhance the single-channel real noisy speech in CHiME4

[ ]:
# Enhance the single-channel CHiME4 utterance, then play the input followed by the output.
# `[None, ...]` adds a leading batch axis for the model call.
wave = enh_model_sc(mixwav_sc[None, ...], sr)
for caption, audio in (
  ("Input real noisy speech", mixwav_sc),
  ("Enhanced speech", wave[0].squeeze()),
):
  print(caption, flush=True)
  display(Audio(audio, rate=sr))

Enhance your own recordings

[ ]:
from google.colab import files
from IPython.display import display, Audio
import soundfile

uploaded = files.upload()

# Enhance each uploaded file with the single-channel model and play input/output.
for file_name in uploaded.keys():
  speech, rate = soundfile.read(file_name)
  assert rate == sr, "mismatch in sampling rate"
  if speech.ndim > 1:
    # soundfile returns (num_samples, num_channels) for multi-channel files.
    # The single-channel model and 1-D Audio playback both need one channel,
    # so keep the first one.
    speech = speech[:, 0]
  wave = enh_model_sc(speech[None, ...], sr)
  print(f"Your input speech {file_name}", flush=True)
  display(Audio(speech, rate=sr))
  print(f"Enhanced speech for {file_name}", flush=True)
  display(Audio(wave[0].squeeze(), rate=sr))


Multi-Channel Enhancement

Download and load the pretrained MVDR neural beamformer.

[ ]:
# Download the pretained enhancement model

!gdown --id 1FohDfBlOa7ipc9v2luY-QIFQ_GJ1iW_i -O /content/mvdr_beamformer_16k_se_raw_valid.zip
!unzip /content/mvdr_beamformer_16k_se_raw_valid.zip -d /content/enh_model_mc
[ ]:
# Load the model
# If you encounter error "No module named 'espnet2'", please re-run the 1st Cell. This might be a colab bug.
import sys
import soundfile
import torch
from espnet2.bin.enh_inference import SeparateSpeech


# For models downloaded from GoogleDrive, you can use the following script:
enh_model_mc = SeparateSpeech(
  enh_train_config="/content/enh_model_mc/exp/enh_train_enh_beamformer_mvdr_raw/config.yaml",
  enh_model_file="/content/enh_model_mc/exp/enh_train_enh_beamformer_mvdr_raw/11epoch.pth",
  # for segment-wise process on long speech
  normalize_segment_scale=False,
  show_progressbar=True,
  # NOTE(review): presumably selects the same fifth channel used for mixwav_sc — confirm.
  ref_channel=4,
  normalize_output_wav=True,
  # Use the GPU when one is attached; otherwise fall back to CPU (slower)
  # instead of failing on a CPU-only runtime.
  device="cuda:0" if torch.cuda.is_available() else "cpu",
)

Enhance the multi-channel real noisy speech in CHiME4

[ ]:
# Run the MVDR model on every channel of the CHiME4 recording, then play
# the input (transposed to (num_channels, num_samples)) followed by the output.
wave = enh_model_mc(mixwav_mc[None, ...], sr)
for caption, audio in (
  ("Input real noisy speech", mixwav_mc.T),
  ("Enhanced speech", wave[0].squeeze()),
):
  print(caption, flush=True)
  display(Audio(audio, rate=sr))

Speech Separation

Model Selection

Please select a model listed in espnet_model_zoo.

In this demonstration, we will show different speech separation models on wsj0_2mix.

[ ]:
#@title Choose Speech Separation model { run: "auto" }

# Sampling rate (Hz) the selected separation model is expected to operate at.
# NOTE(review): the wsj0_2mix models listed below are presumably all 8 kHz — confirm per model.
fs = 8000 #@param {type:"integer"}
# Either an espnet_model_zoo model name or a direct Zenodo URL to a packed model.
tag = "Chenda Li/wsj0_2mix_enh_train_enh_conv_tasnet_raw_valid.si_snr.ave" #@param ["Chenda Li/wsj0_2mix_enh_train_enh_conv_tasnet_raw_valid.si_snr.ave", "Chenda Li/wsj0_2mix_enh_train_enh_rnn_tf_raw_valid.si_snr.ave", "https://zenodo.org/record/4688000/files/enh_train_enh_dprnn_tasnet_raw_valid.si_snr.ave.zip"]
[ ]:
# For models uploaded to Zenodo, you can use the following python script instead:
import sys
import soundfile
import torch
from espnet_model_zoo.downloader import ModelDownloader
from espnet2.bin.enh_inference import SeparateSpeech

d = ModelDownloader()

# `tag` comes from the form above; the unpacked result provides the
# "train_config" and "model_file" entries passed to SeparateSpeech.
cfg = d.download_and_unpack(tag)
separate_speech = SeparateSpeech(
  enh_train_config=cfg["train_config"],
  enh_model_file=cfg["model_file"],
  # for segment-wise process on long speech (values presumably in seconds — confirm)
  segment_size=2.4,
  hop_size=0.8,
  normalize_segment_scale=False,
  show_progressbar=True,
  ref_channel=None,
  normalize_output_wav=True,
  # Use the GPU when one is attached; otherwise fall back to CPU (slower)
  # instead of failing on a CPU-only runtime.
  device="cuda:0" if torch.cuda.is_available() else "cpu",
)

Separate Speech Mixture

Separate the example in wsj0_2mix testing set

[ ]:
!gdown --id 1ZCUkd_Lb7pO2rpPr4FqYdtJBZ7JMiInx -O /content/447c020t_1.2106_422a0112_-1.2106.wav

import os
import soundfile
from IPython.display import display, Audio

mixwav, sr = soundfile.read("447c020t_1.2106_422a0112_-1.2106.wav")
waves_wsj = separate_speech(mixwav[None, ...], fs=sr)

print("Input mixture", flush=True)
display(Audio(mixwav, rate=sr))
print(f"========= Separated speech with model {tag} =========", flush=True)
print("Separated spk1", flush=True)
display(Audio(waves_wsj[0].squeeze(), rate=sr))
print("Separated spk2", flush=True)
display(Audio(waves_wsj[1].squeeze(), rate=sr))

Separate your own recordings

[ ]:
from google.colab import files
from IPython.display import display, Audio
import soundfile

uploaded = files.upload()

# Separate each uploaded mixture with the selected model and play all outputs.
for file_name in uploaded.keys():
  mixwav_yours, rate = soundfile.read(file_name)
  # Compare against the model's configured rate `fs` (model-selection form),
  # not `sr`, which holds the rate of whichever file was read most recently.
  assert rate == fs, "mismatch in sampling rate"
  if mixwav_yours.ndim > 1:
    # soundfile returns (num_samples, num_channels) for multi-channel files;
    # the separation model and 1-D Audio playback need one channel.
    mixwav_yours = mixwav_yours[:, 0]
  waves_yours = separate_speech(mixwav_yours[None, ...], fs=rate)
  print("Input mixture", flush=True)
  display(Audio(mixwav_yours, rate=rate))
  print(f"========= Separated speech with model {tag} =========", flush=True)
  for spk_idx, spk_wave in enumerate(waves_yours, start=1):
    print(f"Separated spk{spk_idx}", flush=True)
    display(Audio(spk_wave.squeeze(), rate=rate))

Show spectrograms of the separated speech

[ ]:
import matplotlib.pyplot as plt
import torch
from torch_complex.tensor import ComplexTensor

from espnet.asr.asr_utils import plot_spectrogram
from espnet2.layers.stft import Stft


stft = Stft(
  n_fft=512,
  win_length=None,
  hop_length=128,
  window="hann",
)
ilens = torch.LongTensor([len(mixwav)])


def complex_spec(wav):
  """Return the STFT of one waveform as a ComplexTensor of shape (T, F).

  `wav` is reshaped to a (1, num_samples) batch, so the 1-D mixture and the
  separated outputs (squeezed elsewhere, so presumably (1, num_samples)) are
  handled alike. The last axis of the Stft output is split into the real and
  imaginary parts.
  """
  batch = torch.as_tensor(wav).reshape(1, -1)
  real_imag = stft(batch, ilens)[0].squeeze()
  return ComplexTensor(*torch.unbind(real_imag, dim=-1))


# specs: (T, F)
spec_mix = complex_spec(mixwav)
spec_sep1 = complex_spec(waves_wsj[0])
spec_sep2 = complex_spec(waves_wsj[1])

duration = len(mixwav) / sr
samples = torch.linspace(0, duration, len(mixwav))
# One row per signal: (spectrogram title, waveform title, spectrum, waveform).
panels = [
  ("Mixture Spectrogram", "Mixture Waveform", spec_mix, mixwav),
  ("Separated Spectrogram (spk1)", "Separated Waveform (spk1)", spec_sep1, waves_wsj[0].squeeze()),
  ("Separated Spectrogram (spk2)", "Separated Waveform (spk2)", spec_sep2, waves_wsj[1].squeeze()),
]

plt.figure(figsize=(24, 12))
for row, (spec_title, wave_title, spec, wav) in enumerate(panels):
  # Left column: magnitude spectrogram, transposed to (F, T) for display.
  plt.subplot(len(panels), 2, 2 * row + 1)
  plt.title(spec_title)
  plot_spectrogram(
    plt, abs(spec).transpose(-1, -2).numpy(), fs=sr,
    mode='db', frame_shift=None,
    bottom=False, labelbottom=False
  )
  # Right column: time-domain waveform on a shared time axis.
  plt.subplot(len(panels), 2, 2 * row + 2)
  plt.title(wave_title)
  plt.plot(samples, wav)
  plt.xlim(0, duration)
# Only the bottom-right (last) subplot gets the time label.
plt.xlabel("Time (s)")
plt.show()

Evaluate the separated speech with a pretrained ASR model

The ground truths are:

text_1: SOME CRITICS INCLUDING HIGH REAGAN ADMINISTRATION OFFICIALS ARE RAISING THE ALARM THAT THE FED'S POLICY IS TOO TIGHT AND COULD CAUSE A RECESSION NEXT YEAR

text_2: THE UNITED STATES UNDERTOOK TO DEFEND WESTERN EUROPE AGAINST SOVIET ATTACK

(This may take a while for the speech recognition.)

[ ]:
import espnet_model_zoo
import torch
from espnet_model_zoo.downloader import ModelDownloader
from espnet2.bin.asr_inference import Speech2Text

wsj_8k_model_url="https://zenodo.org/record/4012264/files/asr_train_asr_transformer_raw_char_1gpu_valid.acc.ave.zip?download=1"

d = ModelDownloader()
speech2text = Speech2Text(
  **d.download_and_unpack(wsj_8k_model_url),
  # Use the GPU when one is attached; otherwise fall back to CPU (slower)
  # instead of failing on a CPU-only runtime.
  device="cuda:0" if torch.cuda.is_available() else "cpu",
)

# Entry 0 of the recognizer output is a tuple whose first field is the text.
text_est = [speech2text(spk_wave.squeeze())[0][0] for spk_wave in waves_wsj]
text_m = speech2text(mixwav)[0][0]
print("Mix Speech to Text: ", text_m)
for spk_idx, text in enumerate(text_est, start=1):
  print(f"Separated Speech {spk_idx} to Text: ", text)

[ ]:
import difflib
from itertools import permutations

import editdistance
import numpy as np

# ANSI 24-bit foreground colors; each wrapper returns "" for empty text so no
# stray escape codes are emitted.
colors = dict(
    red=lambda text: f"\033[38;2;255;0;0m{text}\033[0m" if text else "",
    green=lambda text: f"\033[38;2;0;255;0m{text}\033[0m" if text else "",
    yellow=lambda text: f"\033[38;2;225;225;0m{text}\033[0m" if text else "",
    white=lambda text: f"\033[38;2;255;255;255m{text}\033[0m" if text else "",
    black=lambda text: f"\033[38;2;0;0;0m{text}\033[0m" if text else "",
)

def diff_strings(ref, est):
    """Align `ref` and `est` character by character for a colored error report.

    Reference: https://stackoverflow.com/a/64404008/7384873

    Returns (ref_row, est_row, err_row). Ignoring ANSI color codes, the three
    rows have equal length. err_row marks each position as " " (match),
    "S" (substitution), "D" (deletion) or "I" (insertion); "*" pads whichever
    of ref_row/est_row has no character at that position.
    """
    ref_str, est_str, err_str = [], [], []
    matcher = difflib.SequenceMatcher(None, ref, est)
    for opcode, a0, a1, b0, b1 in matcher.get_opcodes():
        if opcode == "equal":
            txt = ref[a0:a1]
            ref_str.append(txt)
            est_str.append(txt)
            err_str.append(" " * (a1 - a0))
        elif opcode == "insert":
            ref_str.append("*" * (b1 - b0))
            est_str.append(colors["green"](est[b0:b1]))
            err_str.append(colors["black"]("I" * (b1 - b0)))
        elif opcode == "delete":
            ref_str.append(ref[a0:a1])
            est_str.append(colors["red"]("*" * (a1 - a0)))
            err_str.append(colors["black"]("D" * (a1 - a0)))
        elif opcode == "replace":
            n_ref, n_est = a1 - a0, b1 - b0
            n_sub = min(n_ref, n_est)  # positions aligned as substitutions
            if n_ref >= n_est:
                # Leftover reference characters are deletions.
                txt_ref = ref[a0:a1]
                txt_est = colors["yellow"](est[b0:b1]) + colors["red"]("*" * (n_ref - n_est))
                txt_err = "S" * n_sub + "D" * (n_ref - n_est)
            else:
                # Leftover hypothesis characters are insertions: show them in
                # green (not extra "*" padding) so all three rows stay aligned.
                txt_ref = ref[a0:a1] + "*" * (n_est - n_ref)
                txt_est = colors["yellow"](est[b0:b0 + n_sub]) + colors["green"](est[b0 + n_sub:b1])
                txt_err = "S" * n_sub + "I" * (n_est - n_ref)

            ref_str.append(txt_ref)
            est_str.append(txt_est)
            err_str.append(colors["black"](txt_err))
    return "".join(ref_str), "".join(est_str), "".join(err_str)


# Ground-truth transcripts of the two wsj0_2mix speakers.
text_ref = [
  "SOME CRITICS INCLUDING HIGH REAGAN ADMINISTRATION OFFICIALS ARE RAISING THE ALARM THAT THE FED'S POLICY IS TOO TIGHT AND COULD CAUSE A RECESSION NEXT YEAR",
  "THE UNITED STATES UNDERTOOK TO DEFEND WESTERN EUROPE AGAINST SOVIET ATTACK",
]

print("=====================" , flush=True)
# Permutation-invariant scoring: the separated outputs come in arbitrary
# order, so try every assignment of hypotheses to references.
perms = list(permutations(range(2)))

# string_edit[k][m]: edit distance between reference m and the hypothesis
# that permutation k assigns to it.
string_edit = []
for assignment in perms:
  row = []
  for ref_idx, est_idx in enumerate(assignment):
    row.append(editdistance.eval(text_ref[ref_idx], text_est[est_idx]))
  string_edit.append(row)

# Keep the assignment with the smallest total distance (first one on ties).
dist = list(map(sum, string_edit))
perm_idx = np.argmin(dist)
perm = perms[perm_idx]

for ref_idx, est_idx in enumerate(perm):
  header = "\n--------------- Text %d ---------------" % (ref_idx + 1)
  print(header, flush=True)
  ref, est, err = diff_strings(text_ref[ref_idx], text_est[est_idx])
  report = "REF: " + ref + "\n" + "HYP: " + est + "\n" + "ERR: " + err
  print(report, flush=True)
  print("Edit Distance = {}\n".format(string_edit[perm_idx][ref_idx]), flush=True)