{ "nbformat": 4, "nbformat_minor": 0, "metadata": { "accelerator": "GPU", "colab": { "name": "ESPnet Speech Enhancement Demonstration.ipynb", "provenance": [], "collapsed_sections": [], "toc_visible": true }, "kernelspec": { "display_name": "Python 3", "name": "python3" } }, "cells": [ { "cell_type": "markdown", "metadata": { "id": "MuhqhYSToxl7" }, "source": [ "# ESPnet Speech Enhancement Demonstration\n", "\n", "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1fjRJCh96SoYLZPRxsjF9VDv4Q2VoIckI?usp=sharing)\n", "\n", "\n", "This notebook provides a demonstration of the speech enhancement and separation using ESPnet2-SE.\n", "\n", "- ESPnet2-SE: https://github.com/espnet/espnet/tree/master/egs2/TEMPLATE/enh1\n", "\n", "Author: Chenda Li ([@LiChenda](https://github.com/LiChenda)), Wangyou Zhang ([@Emrys365](https://github.com/Emrys365))\n" ] }, { "cell_type": "markdown", "metadata": { "id": "aVqAfa2GiUG4" }, "source": [ "## Install" ] }, { "cell_type": "code", "metadata": { "id": "LDMeuhHZg7aL" }, "source": [ "%pip install -q espnet==0.10.1\n", "%pip install -q espnet_model_zoo" ], "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "metadata": { "id": "4-MWmiJrarUD" }, "source": [ "## Speech Enhancement" ] }, { "cell_type": "markdown", "metadata": { "id": "qcWmAw6GnP-g" }, "source": [ "### Single-Channel Enhancement, the CHiME example\n" ] }, { "cell_type": "code", "metadata": { "id": "FfU84I7GnZaQ" }, "source": [ "# Download one utterance from real noisy speech of CHiME4\n", "!gdown --id 1SmrN5NFSg6JuQSs2sfy3ehD8OIcqK6wS -O /content/M05_440C0213_PED_REAL.wav\n", "import os\n", "\n", "import soundfile\n", "from IPython.display import display, Audio\n", "mixwav_mc, sr = soundfile.read(\"/content/M05_440C0213_PED_REAL.wav\")\n", "# mixwav.shape: num_samples, num_channels\n", "mixwav_sc = mixwav_mc[:,4]\n", "display(Audio(mixwav_mc.T, rate=sr))" ], "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "metadata": { "id": "GH5CUnsuxA6B" }, "source": [ "#### Download and load the pretrained Conv-Tasnet\n" ] }, { "cell_type": "code", "metadata": { "id": "5ScWB5qaxBmN" }, "source": [ "!gdown --id 17DMWdw84wF3fz3t7ia1zssdzhkpVQGZm -O /content/chime_tasnet_singlechannel.zip\n", "!unzip /content/chime_tasnet_singlechannel.zip -d /content/enh_model_sc" ], "execution_count": null, "outputs": [] }, { "cell_type": "code", "metadata": { "id": "YfZgR81AxemW" }, "source": [ "# Load the model\n", "# If you encounter error \"No module named 'espnet2'\", please re-run the 1st Cell. This might be a colab bug.\n", "import sys\n", "import soundfile\n", "from espnet2.bin.enh_inference import SeparateSpeech\n", "\n", "\n", "separate_speech = {}\n", "# For models downloaded from GoogleDrive, you can use the following script:\n", "enh_model_sc = SeparateSpeech(\n", " enh_train_config=\"/content/enh_model_sc/exp/enh_train_enh_conv_tasnet_raw/config.yaml\",\n", " enh_model_file=\"/content/enh_model_sc/exp/enh_train_enh_conv_tasnet_raw/5epoch.pth\",\n", " # for segment-wise process on long speech\n", " normalize_segment_scale=False,\n", " show_progressbar=True,\n", " ref_channel=4,\n", " normalize_output_wav=True,\n", " device=\"cuda:0\",\n", ")" ], "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "metadata": { "id": "I4Pysnfu0ySo" }, "source": [ "#### Enhance the single-channel real noisy speech in CHiME4\n", "\n" ] }, { "cell_type": "code", "metadata": { "id": "ZC7ott8N0InK" }, "source": [ "# play the enhanced single-channel speech\n", "wave = enh_model_sc(mixwav_sc[None, ...], sr)\n", "print(\"Input real noisy speech\", flush=True)\n", "display(Audio(mixwav_sc, rate=sr))\n", "print(\"Enhanced speech\", flush=True)\n", "display(Audio(wave[0].squeeze(), rate=sr))" ], "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "metadata": { "id": "apykW9eM03SC" }, "source": [ "### Enhance your own pre-recordings\n" ] }, { "cell_type": "code", "metadata": { "id": "FAbaNE4K02z8" }, "source": [ "from google.colab import files\n", "from IPython.display import display, Audio\n", "import soundfile\n", "\n", "uploaded = files.upload()\n", "\n", "for file_name in uploaded.keys():\n", " speech, rate = soundfile.read(file_name)\n", " assert rate == sr, \"mismatch in sampling rate\"\n", " wave = enh_model_sc(speech[None, ...], sr)\n", " print(f\"Your input speech {file_name}\", flush=True)\n", " display(Audio(speech, rate=sr))\n", " print(f\"Enhanced speech for {file_name}\", flush=True)\n", " display(Audio(wave[0].squeeze(), rate=sr))\n", "\n" ], "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "metadata": { "id": "yLBUR9r9hBSe" }, "source": [ "### Multi-Channel Enhancement\n" ] }, { "cell_type": "markdown", "metadata": { "id": "xtd1vNmYf7S8" }, "source": [ "#### Download and load the pretrained mvdr neural beamformer.\n", "\n" ] }, { "cell_type": "code", "metadata": { "id": "Zs-ckGZsfku3" }, "source": [ "# Download the pretained enhancement model\n", "\n", "!gdown --id 1FohDfBlOa7ipc9v2luY-QIFQ_GJ1iW_i -O /content/mvdr_beamformer_16k_se_raw_valid.zip\n", "!unzip /content/mvdr_beamformer_16k_se_raw_valid.zip -d /content/enh_model_mc " ], "execution_count": null, "outputs": [] }, { "cell_type": "code", "metadata": { "id": "1ZmOjl_kiJCH" }, "source": [ "# Load the model\n", "# If you encounter error \"No module named 'espnet2'\", please re-run the 1st Cell. This might be a colab bug.\n", "import sys\n", "import soundfile\n", "from espnet2.bin.enh_inference import SeparateSpeech\n", "\n", "\n", "separate_speech = {}\n", "# For models downloaded from GoogleDrive, you can use the following script:\n", "enh_model_mc = SeparateSpeech(\n", " enh_train_config=\"/content/enh_model_mc/exp/enh_train_enh_beamformer_mvdr_raw/config.yaml\",\n", " enh_model_file=\"/content/enh_model_mc/exp/enh_train_enh_beamformer_mvdr_raw/11epoch.pth\",\n", " # for segment-wise process on long speech\n", " normalize_segment_scale=False,\n", " show_progressbar=True,\n", " ref_channel=4,\n", " normalize_output_wav=True,\n", " device=\"cuda:0\",\n", ")" ], "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "metadata": { "id": "xfcY-Yj5lH7X" }, "source": [ "#### Enhance the multi-channel real noisy speech in CHiME4\n", "\n" ] }, { "cell_type": "code", "metadata": { "id": "0kIjHfagi4T1" }, "source": [ "wave = enh_model_mc(mixwav_mc[None, ...], sr)\n", "print(\"Input real noisy speech\", flush=True)\n", "display(Audio(mixwav_mc.T, rate=sr))\n", "print(\"Enhanced speech\", flush=True)\n", "display(Audio(wave[0].squeeze(), rate=sr))" ], "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "metadata": { "id": "Zwz0Iu2ZT2vd" }, "source": [ "## Speech Separation" ] }, { "cell_type": "markdown", "metadata": { "id": "cQAiGZyPhAko" }, "source": [ "\n", "### Model Selection\n", "\n", "Please select model shown in [espnet_model_zoo](https://github.com/espnet/espnet_model_zoo/blob/master/espnet_model_zoo/table.csv)\n", "\n", "In this demonstration, we will show different speech separation models on wsj0_2mix.\n" ] }, { "cell_type": "code", "metadata": { "id": "9v9A8Ge4PMA4" }, "source": [ "#@title Choose Speech Separation model { run: \"auto\" }\n", "\n", "fs = 8000 #@param {type:\"integer\"}\n", "tag = \"Chenda Li/wsj0_2mix_enh_train_enh_conv_tasnet_raw_valid.si_snr.ave\" #@param [\"Chenda Li/wsj0_2mix_enh_train_enh_conv_tasnet_raw_valid.si_snr.ave\", \"Chenda Li/wsj0_2mix_enh_train_enh_rnn_tf_raw_valid.si_snr.ave\", \"https://zenodo.org/record/4688000/files/enh_train_enh_dprnn_tasnet_raw_valid.si_snr.ave.zip\"]" ], "execution_count": null, "outputs": [] }, { "cell_type": "code", "metadata": { "id": "NlGegk63dE3u" }, "source": [ "# For models uploaded to Zenodo, you can use the following python script instead:\n", "import sys\n", "import soundfile\n", "from espnet_model_zoo.downloader import ModelDownloader\n", "from espnet2.bin.enh_inference import SeparateSpeech\n", "\n", "d = ModelDownloader()\n", "\n", "cfg = d.download_and_unpack(tag)\n", "separate_speech = SeparateSpeech(\n", " enh_train_config=cfg[\"train_config\"],\n", " enh_model_file=cfg[\"model_file\"],\n", " # for segment-wise process on long speech\n", " segment_size=2.4,\n", " hop_size=0.8,\n", " normalize_segment_scale=False,\n", " show_progressbar=True,\n", " ref_channel=None,\n", " normalize_output_wav=True,\n", " device=\"cuda:0\",\n", ")" ], "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "metadata": { "id": "iOtdLFrng8DW" }, "source": [ "### Separate Speech Mixture\n" ] }, { "cell_type": "markdown", "metadata": { "id": "OPIO8rbihD4T" }, "source": [ "#### Separate the example in wsj0_2mix testing set" ] }, { "cell_type": "code", "metadata": { "id": "WSXQ39LIjYZl" }, "source": [ "!gdown --id 1ZCUkd_Lb7pO2rpPr4FqYdtJBZ7JMiInx -O /content/447c020t_1.2106_422a0112_-1.2106.wav\n", "\n", "import os\n", "import soundfile\n", "from IPython.display import display, Audio\n", "\n", "mixwav, sr = soundfile.read(\"447c020t_1.2106_422a0112_-1.2106.wav\")\n", "waves_wsj = separate_speech(mixwav[None, ...], fs=sr)\n", "\n", "print(\"Input mixture\", flush=True)\n", "display(Audio(mixwav, rate=sr))\n", "print(f\"========= Separated speech with model {tag} =========\", flush=True)\n", "print(\"Separated spk1\", flush=True)\n", "display(Audio(waves_wsj[0].squeeze(), rate=sr))\n", "print(\"Separated spk2\", flush=True)\n", "display(Audio(waves_wsj[1].squeeze(), rate=sr))" ], "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "metadata": { "id": "3B_DTVV8oJZz" }, "source": [ "#### Separate your own recordings" ] }, { "cell_type": "code", "metadata": { "id": "ZcdPWN4MocBd" }, "source": [ "from google.colab import files\n", "from IPython.display import display, Audio\n", "import soundfile\n", "\n", "uploaded = files.upload()\n", "\n", "for file_name in uploaded.keys():\n", " mixwav_yours, rate = soundfile.read(file_name)\n", " assert rate == sr, \"mismatch in sampling rate\"\n", " waves_yours = separate_speech(mixwav_yours[None, ...], fs=sr)\n", " print(\"Input mixture\", flush=True)\n", " display(Audio(mixwav_yours, rate=sr))\n", " print(f\"========= Separated speech with model {tag} =========\", flush=True)\n", " print(\"Separated spk1\", flush=True)\n", " display(Audio(waves_yours[0].squeeze(), rate=sr))\n", " print(\"Separated spk2\", flush=True)\n", " display(Audio(waves_yours[1].squeeze(), rate=sr))" ], "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "metadata": { "id": "A-LrSAKPgzSZ" }, "source": [ "#### Show spectrums of separated speech" ] }, { "cell_type": "code", "metadata": { "id": "MaYTZqXtGJEJ" }, "source": [ "import matplotlib.pyplot as plt\n", "import torch\n", "from torch_complex.tensor import ComplexTensor\n", "\n", "from espnet.asr.asr_utils import plot_spectrogram\n", "from espnet2.layers.stft import Stft\n", "\n", "\n", "stft = Stft(\n", " n_fft=512,\n", " win_length=None,\n", " hop_length=128,\n", " window=\"hann\",\n", ")\n", "ilens = torch.LongTensor([len(mixwav)])\n", "# specs: (T, F)\n", "spec_mix = ComplexTensor(\n", " *torch.unbind(\n", " stft(torch.as_tensor(mixwav).unsqueeze(0), ilens)[0].squeeze(),\n", " dim=-1\n", " )\n", ")\n", "spec_sep1 = ComplexTensor(\n", " *torch.unbind(\n", " stft(torch.as_tensor(waves_wsj[0]), ilens)[0].squeeze(),\n", " dim=-1\n", " )\n", ")\n", "spec_sep2 = ComplexTensor(\n", " *torch.unbind(\n", " stft(torch.as_tensor(waves_wsj[1]), ilens)[0].squeeze(),\n", " dim=-1\n", " )\n", ")\n", "\n", "# freqs = torch.linspace(0, sr / 2, spec_mix.shape[1])\n", "# frames = torch.linspace(0, len(mixwav) / sr, spec_mix.shape[0])\n", "samples = torch.linspace(0, len(mixwav) / sr, len(mixwav))\n", "plt.figure(figsize=(24, 12))\n", "plt.subplot(3, 2, 1)\n", "plt.title('Mixture Spectrogram')\n", "plot_spectrogram(\n", " plt, abs(spec_mix).transpose(-1, -2).numpy(), fs=sr,\n", " mode='db', frame_shift=None,\n", " bottom=False, labelbottom=False\n", ")\n", "plt.subplot(3, 2, 2)\n", "plt.title('Mixture Wavform')\n", "plt.plot(samples, mixwav)\n", "plt.xlim(0, len(mixwav) / sr)\n", "\n", "plt.subplot(3, 2, 3)\n", "plt.title('Separated Spectrogram (spk1)')\n", "plot_spectrogram(\n", " plt, abs(spec_sep1).transpose(-1, -2).numpy(), fs=sr,\n", " mode='db', frame_shift=None,\n", " bottom=False, labelbottom=False\n", ")\n", "plt.subplot(3, 2, 4)\n", "plt.title('Separated Wavform (spk1)')\n", "plt.plot(samples, waves_wsj[0].squeeze())\n", "plt.xlim(0, len(mixwav) / sr)\n", "\n", "plt.subplot(3, 2, 5)\n", "plt.title('Separated Spectrogram (spk2)')\n", "plot_spectrogram(\n", " plt, abs(spec_sep2).transpose(-1, -2).numpy(), fs=sr,\n", " mode='db', frame_shift=None,\n", " bottom=False, labelbottom=False\n", ")\n", "plt.subplot(3, 2, 6)\n", "plt.title('Separated Wavform (spk2)')\n", "plt.plot(samples, waves_wsj[1].squeeze())\n", "plt.xlim(0, len(mixwav) / sr)\n", "plt.xlabel(\"Time (s)\")\n", "plt.show()" ], "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "metadata": { "id": "BnuxcpSicZsx" }, "source": [ "## Evluate separated speech with pretrained ASR model\n", "\n", "The ground truths are:\n", "\n", "`text_1: SOME CRITICS INCLUDING HIGH REAGAN ADMINISTRATION OFFICIALS ARE RAISING THE ALARM THAT THE FED'S POLICY IS TOO TIGHT AND COULD CAUSE A RECESSION NEXT YEAR`\n", "\n", "`text_2: THE UNITED STATES UNDERTOOK TO DEFEND WESTERN EUROPE AGAINST SOVIET ATTACK`\n", "\n", "(This may take a while for the speech recognition.)" ] }, { "cell_type": "code", "metadata": { "id": "SZMUsNEQUAEU" }, "source": [ "import espnet_model_zoo\n", "from espnet_model_zoo.downloader import ModelDownloader\n", "from espnet2.bin.asr_inference import Speech2Text\n", "\n", "wsj_8k_model_url=\"https://zenodo.org/record/4012264/files/asr_train_asr_transformer_raw_char_1gpu_valid.acc.ave.zip?download=1\"\n", "\n", "d = ModelDownloader()\n", "speech2text = Speech2Text(\n", " **d.download_and_unpack(wsj_8k_model_url),\n", " device=\"cuda:0\",\n", ")\n", "\n", "text_est = [None, None]\n", "text_est[0], *_ = speech2text(waves_wsj[0].squeeze())[0]\n", "text_est[1], *_ = speech2text(waves_wsj[1].squeeze())[0]\n", "text_m, *_ = speech2text(mixwav)[0]\n", "print(\"Mix Speech to Text: \", text_m)\n", "print(\"Separated Speech 1 to Text: \", text_est[0])\n", "print(\"Separated Speech 2 to Text: \", text_est[1])\n" ], "execution_count": null, "outputs": [] }, { "cell_type": "code", "metadata": { "id": "GrAbO58MXrji" }, "source": [ "import difflib\n", "from itertools import permutations\n", "\n", "import editdistance\n", "import numpy as np\n", "\n", "colors = dict(\n", " red=lambda text: f\"\\033[38;2;255;0;0m{text}\\033[0m\" if text else \"\",\n", " green=lambda text: f\"\\033[38;2;0;255;0m{text}\\033[0m\" if text else \"\",\n", " yellow=lambda text: f\"\\033[38;2;225;225;0m{text}\\033[0m\" if text else \"\",\n", " white=lambda text: f\"\\033[38;2;255;255;255m{text}\\033[0m\" if text else \"\",\n", " black=lambda text: f\"\\033[38;2;0;0;0m{text}\\033[0m\" if text else \"\",\n", ")\n", "\n", "def diff_strings(ref, est):\n", " \"\"\"Reference: https://stackoverflow.com/a/64404008/7384873\"\"\"\n", " ref_str, est_str, err_str = [], [], []\n", " matcher = difflib.SequenceMatcher(None, ref, est)\n", " for opcode, a0, a1, b0, b1 in matcher.get_opcodes():\n", " if opcode == \"equal\":\n", " txt = ref[a0:a1]\n", " ref_str.append(txt)\n", " est_str.append(txt)\n", " err_str.append(\" \" * (a1 - a0))\n", " elif opcode == \"insert\":\n", " ref_str.append(\"*\" * (b1 - b0))\n", " est_str.append(colors[\"green\"](est[b0:b1]))\n", " err_str.append(colors[\"black\"](\"I\" * (b1 - b0)))\n", " elif opcode == \"delete\":\n", " ref_str.append(ref[a0:a1])\n", " est_str.append(colors[\"red\"](\"*\" * (a1 - a0)))\n", " err_str.append(colors[\"black\"](\"D\" * (a1 - a0)))\n", " elif opcode == \"replace\":\n", " diff = a1 - a0 - b1 + b0\n", " if diff >= 0:\n", " txt_ref = ref[a0:a1]\n", " txt_est = colors[\"yellow\"](est[b0:b1]) + colors[\"red\"](\"*\" * diff)\n", " txt_err = \"S\" * (b1 - b0) + \"D\" * diff\n", " elif diff < 0:\n", " txt_ref = ref[a0:a1] + \"*\" * -diff\n", " txt_est = colors[\"yellow\"](est[b0:b1]) + colors[\"green\"](\"*\" * -diff)\n", " txt_err = \"S\" * (b1 - b0) + \"I\" * -diff\n", "\n", " ref_str.append(txt_ref)\n", " est_str.append(txt_est)\n", " err_str.append(colors[\"black\"](txt_err))\n", " return \"\".join(ref_str), \"\".join(est_str), \"\".join(err_str)\n", "\n", "\n", "text_ref = [\n", " \"SOME CRITICS INCLUDING HIGH REAGAN ADMINISTRATION OFFICIALS ARE RAISING THE ALARM THAT THE FED'S POLICY IS TOO TIGHT AND COULD CAUSE A RECESSION NEXT YEAR\",\n", " \"THE UNITED STATES UNDERTOOK TO DEFEND WESTERN EUROPE AGAINST SOVIET ATTACK\",\n", "]\n", "\n", "print(\"=====================\" , flush=True)\n", "perms = list(permutations(range(2)))\n", "string_edit = [\n", " [\n", " editdistance.eval(text_ref[m], text_est[n])\n", " for m, n in enumerate(p)\n", " ]\n", " for p in perms\n", "]\n", "\n", "dist = [sum(edist) for edist in string_edit]\n", "perm_idx = np.argmin(dist)\n", "perm = perms[perm_idx]\n", "\n", "for i, p in enumerate(perm):\n", " print(\"\\n--------------- Text %d ---------------\" % (i + 1), flush=True)\n", " ref, est, err = diff_strings(text_ref[i], text_est[p])\n", " print(\"REF: \" + ref + \"\\n\" + \"HYP: \" + est + \"\\n\" + \"ERR: \" + err, flush=True)\n", " print(\"Edit Distance = {}\\n\".format(string_edit[perm_idx][i]), flush=True)" ], "execution_count": null, "outputs": [] } ] }