{ "nbformat": 4, "nbformat_minor": 0, "metadata": { "accelerator": "GPU", "colab": { "name": "ESPnet SE Demonstration for WASPAA 2021.ipynb", "provenance": [], "collapsed_sections": [], "toc_visible": true }, "kernelspec": { "display_name": "Python 3", "name": "python3" } }, "cells": [ { "cell_type": "markdown", "metadata": { "id": "MuhqhYSToxl7" }, "source": [ "# ESPnet Speech Enhancement Demonstration\n", "\n", "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1fjRJCh96SoYLZPRxsjF9VDv4Q2VoIckI?usp=sharing)\n", "\n", "\n", "\n", "\n", "This notebook provides a demonstration of the speech enhancement and separation using ESPnet2-SE.\n", "\n", "- ESPnet2-SE: https://github.com/espnet/espnet/tree/master/egs2/TEMPLATE/enh1\n", "\n", "Presenters:\n", "\n", "\n", "* Shinji Watanabe (shinjiw@cmu.edu)\n", "* Chenda Li (lichenda1996@sjtu.edu.cn)\n", "* Jing Shi (shijing2014@ia.ac.cn)\n", "* Wangyou Zhang (wyz-97@sjtu.edu.cn)\n", "* Yen-Ju Lu (neil.lu@citi.sinica.edu.tw)\n", "\n", "\n", "This notebook is created by: Chenda Li ([@LiChenda](https://github.com/LiChenda)) and Wangyou Zhang ([@Emrys365](https://github.com/Emrys365))\n" ] }, { "cell_type": "markdown", "metadata": { "id": "AQdm0Fbd0D2z" }, "source": [ "# Contents\n", "\n", "(1) Tutorials on the Basic Usage\n", "\n", "1. Install\n", "\n", "2. Speech Enhancement with Pretrained Models\n", "\n", " > We support various interfaces, e.g. Python API, HuggingFace API, portable speech enhancement scripts for other tasks, etc.\n", "\n", " 2.1 Single-channel Enhancement (CHiME-4)\n", "\n", " 2.2 Enhance Your Own Recordings\n", "\n", " 2.3 Multi-channel Enhancement (CHiME-4)\n", "\n", "3. Speech Separation with Pretrained Models\n", "\n", " 3.1 Model Selection\n", " \n", " 3.2 Separate Speech Mixture\n", "\n", "4. Evaluate Separated Speech with the Pretrained ASR Model\n", "\n", "(2) Tutorials for Adding New Recipe and Contributing to ESPnet-SE Project\n", "\n", "1. Creating a New Recipe\n", "\n", "2. 
, { "cell_type": "markdown", "metadata": { "id": "ud_F-tlx2XYb" }, "source": [ "# (1) Tutorials on the Basic Usage" ] }, { "cell_type": "markdown", "metadata": { "id": "aVqAfa2GiUG4" }, "source": [ "## Install" ] }, { "cell_type": "code", "metadata": { "id": "LDMeuhHZg7aL" }, "source": [ "%pip install -q espnet==0.10.1\n", "%pip install -q espnet_model_zoo\n" ], "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "metadata": { "id": "4-MWmiJrarUD" }, "source": [ "## Speech Enhancement with Pretrained Models" ] }, { "cell_type": "markdown", "metadata": { "id": "qcWmAw6GnP-g" }, "source": [ "### Single-Channel Enhancement, the CHiME-4 Example\n" ] }, { "cell_type": "code", "metadata": { "id": "FfU84I7GnZaQ" }, "source": [ "# Download one utterance from the real noisy speech of CHiME-4\n", "!gdown --id 1SmrN5NFSg6JuQSs2sfy3ehD8OIcqK6wS -O /content/M05_440C0213_PED_REAL.wav\n", "import os\n", "\n", "import soundfile\n", "from IPython.display import display, Audio\n", "mixwav_mc, sr = soundfile.read(\"/content/M05_440C0213_PED_REAL.wav\")\n", "# mixwav_mc.shape: (num_samples, num_channels)\n", "# use the 5th channel as the single-channel input\n", "mixwav_sc = mixwav_mc[:, 4]\n", "display(Audio(mixwav_mc.T, rate=sr))" ], "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "metadata": { "id": "GH5CUnsuxA6B" }, "source": [ "#### Download and load the pretrained Conv-TasNet\n" ] }, { "cell_type": "code", "metadata": { "id": "5ScWB5qaxBmN" }, "source": [ "!gdown --id 17DMWdw84wF3fz3t7ia1zssdzhkpVQGZm -O /content/chime_tasnet_singlechannel.zip\n", "!unzip /content/chime_tasnet_singlechannel.zip -d /content/enh_model_sc" ], "execution_count": null, "outputs": [] }
, { "cell_type": "code", "metadata": { "id": "YfZgR81AxemW" }, "source": [ "# Load the model\n", "# If you encounter the error \"No module named 'espnet2'\", re-run the first cell (this may be a Colab bug).\n", "import sys\n", "import soundfile\n", "from espnet2.bin.enh_inference import SeparateSpeech\n", "\n", "\n", "separate_speech = {}\n", "# For models downloaded from Google Drive, you can use the following script:\n", "enh_model_sc = SeparateSpeech(\n", "    train_config=\"/content/enh_model_sc/exp/enh_train_enh_conv_tasnet_raw/config.yaml\",\n", "    model_file=\"/content/enh_model_sc/exp/enh_train_enh_conv_tasnet_raw/5epoch.pth\",\n", "    # for segment-wise processing on long speech\n", "    normalize_segment_scale=False,\n", "    show_progressbar=True,\n", "    ref_channel=4,\n", "    normalize_output_wav=True,\n", "    device=\"cuda:0\",\n", ")" ], "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "metadata": { "id": "I4Pysnfu0ySo" }, "source": [ "#### Enhance the single-channel real noisy speech in CHiME-4\n", "\n" ] }, { "cell_type": "code", "metadata": { "id": "ZC7ott8N0InK" }, "source": [ "# play the enhanced single-channel speech\n", "wave = enh_model_sc(mixwav_sc[None, ...], sr)\n", "\n", "print(\"Input real noisy speech\", flush=True)\n", "display(Audio(mixwav_sc, rate=sr))\n", "print(\"Enhanced speech\", flush=True)\n", "display(Audio(wave[0].squeeze(), rate=sr))" ], "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "metadata": { "id": "apykW9eM03SC" }, "source": [ "### Enhance your own recordings\n" ] }, { "cell_type": "code", "metadata": { "id": "FAbaNE4K02z8" }, "source": [ "from google.colab import files\n", "from IPython.display import display, Audio\n", "import soundfile\n", "\n", "uploaded = files.upload()\n", "\n", "for file_name in uploaded.keys():\n", "    speech, rate = soundfile.read(file_name)\n", "    assert rate == sr, \"mismatch in sampling rate\"\n", "    wave = enh_model_sc(speech[None, ...], sr)\n", "    print(f\"Your input speech {file_name}\", flush=True)\n", "    display(Audio(speech, rate=sr))\n", "    print(f\"Enhanced speech for {file_name}\", flush=True)\n", "    display(Audio(wave[0].squeeze(), rate=sr))\n" ], "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "metadata": { "id": "yLBUR9r9hBSe" }, "source": [ "### Multi-Channel Enhancement\n" ] }, { "cell_type": "markdown", "metadata": { "id": "xtd1vNmYf7S8" }, "source": [ "#### Download and load the pretrained MVDR neural beamformer\n", "\n" ] }, { "cell_type": "code", "metadata": { "id": "Zs-ckGZsfku3" }, "source": [ "# Download the pretrained enhancement model\n", "\n", "!gdown --id 1FohDfBlOa7ipc9v2luY-QIFQ_GJ1iW_i -O /content/mvdr_beamformer_16k_se_raw_valid.zip\n", "!unzip /content/mvdr_beamformer_16k_se_raw_valid.zip -d /content/enh_model_mc" ], "execution_count": null, "outputs": [] }
, { "cell_type": "code", "metadata": { "id": "1ZmOjl_kiJCH" }, "source": [ "# Load the model\n", "# If you encounter the error \"No module named 'espnet2'\", re-run the first cell (this may be a Colab bug).\n", "import sys\n", "import soundfile\n", "from espnet2.bin.enh_inference import SeparateSpeech\n", "\n", "\n", "separate_speech = {}\n", "# For models downloaded from Google Drive, you can use the following script:\n", "enh_model_mc = SeparateSpeech(\n", "    train_config=\"/content/enh_model_mc/exp/enh_train_enh_beamformer_mvdr_raw/config.yaml\",\n", "    model_file=\"/content/enh_model_mc/exp/enh_train_enh_beamformer_mvdr_raw/11epoch.pth\",\n", "    # for segment-wise processing on long speech\n", "    normalize_segment_scale=False,\n", "    show_progressbar=True,\n", "    ref_channel=4,\n", "    normalize_output_wav=True,\n", "    device=\"cuda:0\",\n", ")" ], "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "metadata": { "id": "xfcY-Yj5lH7X" }, "source": [ "#### Enhance the multi-channel real noisy speech in CHiME-4\n", "\n" ] }, { "cell_type": "code", "metadata": { "id": "0kIjHfagi4T1" }, "source": [ "wave = enh_model_mc(mixwav_mc[None, ...], sr)\n", "print(\"Input real noisy speech\", flush=True)\n", "display(Audio(mixwav_mc.T, rate=sr))\n", "print(\"Enhanced speech\", flush=True)\n", "display(Audio(wave[0].squeeze(), rate=sr))" ], "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "metadata": { "id": "NTYVOwkt6E-R" }, "source": [ "#### Portable speech enhancement scripts for other tasks\n", "\n", "For an ESPnet ASR or TTS dataset like the one below:\n", "\n", "```\n", "data\n", "`-- et05_real_isolated_6ch_track\n", "    |-- spk2utt\n", "    |-- text\n", "    |-- utt2spk\n", "    |-- utt2uniq\n", "    `-- wav.scp\n", "```\n", "\n", "Run the following script to create an enhanced dataset:\n", "\n", "```\n", "scripts/utils/enhance_dataset.sh \\\n", "    --spk_num 1 \\\n", "    --gpu_inference true \\\n", "    --inference_nj 4 \\\n", "    --fs 16k \\\n", "    --id_prefix \"\" \\\n", "    dump/raw/et05_real_isolated_6ch_track \\\n", "    data/et05_real_isolated_6ch_track_enh \\\n", "    exp/enh_train_enh_beamformer_mvdr_raw/valid.loss.best.pth\n", "```\n", "\n", "The above script will generate a new directory `data/et05_real_isolated_6ch_track_enh`:\n", "\n", "```\n", "data\n", "`-- et05_real_isolated_6ch_track_enh\n", "    |-- spk2utt\n", "    |-- text\n", "    |-- utt2spk\n", "    |-- utt2uniq\n", "    |-- wav.scp\n", "    `-- wavs/\n", "```\n", "where `wav.scp` contains paths to the enhanced audio files (stored in `wavs/`).\n" ] }, { "cell_type": "markdown", "metadata": { "id": "Zwz0Iu2ZT2vd" }, "source": [ "## Speech Separation" ] }, { "cell_type": "markdown", "metadata": { "id": "cQAiGZyPhAko" }, "source": [ "\n", "### Model Selection\n", "\n", "In this demonstration, we will show different speech separation models on wsj0_2mix.\n", "\n", "The pretrained models can be downloaded from a direct URL, or from [Zenodo](https://zenodo.org/) and [HuggingFace](https://huggingface.co/) with a model ID.\n" ] }, { "cell_type": "code", "metadata": { "id": "9v9A8Ge4PMA4" }, "source": [ "#@title Choose Speech Separation model { run: \"auto\" }\n", "\n", "fs = 8000 #@param {type:\"integer\"}\n", "tag = \"espnet/Chenda_Li_wsj0_2mix_enh_train_enh_conv_tasnet_raw_valid.si_snr.ave\" #@param [\"Chenda Li/wsj0_2mix_enh_train_enh_conv_tasnet_raw_valid.si_snr.ave\", \"Chenda Li/wsj0_2mix_enh_train_enh_rnn_tf_raw_valid.si_snr.ave\", \"https://zenodo.org/record/4688000/files/enh_train_enh_dprnn_tasnet_raw_valid.si_snr.ave.zip\", \"espnet/Chenda_Li_wsj0_2mix_enh_train_enh_conv_tasnet_raw_valid.si_snr.ave\"]" ], "execution_count": null, "outputs": [] }
, { "cell_type": "code", "metadata": { "id": "NlGegk63dE3u" }, "source": [ "# For models uploaded to Zenodo or HuggingFace, you can use the following Python script instead:\n", "import sys\n", "import soundfile\n", "from espnet_model_zoo.downloader import ModelDownloader\n", "from espnet2.bin.enh_inference import SeparateSpeech\n", "\n", "d = ModelDownloader()\n", "\n", "cfg = d.download_and_unpack(tag)\n", "separate_speech = SeparateSpeech(\n", "    enh_train_config=cfg[\"train_config\"],\n", "    enh_model_file=cfg[\"model_file\"],\n", "    # for segment-wise processing on long speech\n", "    segment_size=2.4,\n", "    hop_size=0.8,\n", "    normalize_segment_scale=False,\n", "    show_progressbar=True,\n", "    ref_channel=None,\n", "    normalize_output_wav=True,\n", "    device=\"cuda:0\",\n", ")" ], "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "metadata": { "id": "iOtdLFrng8DW" }, "source": [ "### Separate Speech Mixture\n" ] }, { "cell_type": "markdown", "metadata": { "id": "OPIO8rbihD4T" }, "source": [ "#### Separate the example in the wsj0_2mix test set" ] }, { "cell_type": "code", "metadata": { "id": "WSXQ39LIjYZl" }, "source": [ "!gdown --id 1ZCUkd_Lb7pO2rpPr4FqYdtJBZ7JMiInx -O /content/447c020t_1.2106_422a0112_-1.2106.wav\n", "\n", "import os\n", "import soundfile\n", "from IPython.display import display, Audio\n", "\n", "mixwav, sr = soundfile.read(\"447c020t_1.2106_422a0112_-1.2106.wav\")\n", "waves_wsj = separate_speech(mixwav[None, ...], fs=sr)\n", "\n", "print(\"Input mixture\", flush=True)\n", "display(Audio(mixwav, rate=sr))\n", "print(f\"========= Separated speech with model {tag} =========\", flush=True)\n", "print(\"Separated spk1\", flush=True)\n", "display(Audio(waves_wsj[0].squeeze(), rate=sr))\n", "print(\"Separated spk2\", flush=True)\n", "display(Audio(waves_wsj[1].squeeze(), rate=sr))" ], "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "metadata": { "id": "3B_DTVV8oJZz" }, "source": [ "#### Separate your own recordings" ] }, { "cell_type": "code", "metadata": { "id": "ZcdPWN4MocBd" }, "source": [ "from google.colab import files\n", "from IPython.display import display, Audio\n", "import soundfile\n", "\n", "uploaded = files.upload()\n", "\n", "for file_name in uploaded.keys():\n", "    mixwav_yours, rate = soundfile.read(file_name)\n", "    assert rate == sr, \"mismatch in sampling rate\"\n", "    waves_yours = separate_speech(mixwav_yours[None, ...], fs=sr)\n", "    print(\"Input mixture\", flush=True)\n", "    display(Audio(mixwav_yours, rate=sr))\n", "    print(f\"========= Separated speech with model {tag} =========\", flush=True)\n", "    print(\"Separated spk1\", flush=True)\n", "    display(Audio(waves_yours[0].squeeze(), rate=sr))\n", "    print(\"Separated spk2\", flush=True)\n", "    display(Audio(waves_yours[1].squeeze(), rate=sr))" ], "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "metadata": { "id": "A-LrSAKPgzSZ" }, "source": [ "#### Show spectrograms of separated speech" ] }
, { "cell_type": "code", "metadata": { "id": "MaYTZqXtGJEJ" }, "source": [ "import matplotlib.pyplot as plt\n", "import torch\n", "from torch_complex.tensor import ComplexTensor\n", "\n", "from espnet.asr.asr_utils import plot_spectrogram\n", "from espnet2.layers.stft import Stft\n", "\n", "\n", "stft = Stft(\n", "    n_fft=512,\n", "    win_length=None,\n", "    hop_length=128,\n", "    window=\"hann\",\n", ")\n", "ilens = torch.LongTensor([len(mixwav)])\n", "# specs: (T, F)\n", "spec_mix = ComplexTensor(\n", "    *torch.unbind(\n", "        stft(torch.as_tensor(mixwav).unsqueeze(0), ilens)[0].squeeze(),\n", "        dim=-1\n", "    )\n", ")\n", "spec_sep1 = ComplexTensor(\n", "    *torch.unbind(\n", "        stft(torch.as_tensor(waves_wsj[0]), ilens)[0].squeeze(),\n", "        dim=-1\n", "    )\n", ")\n", "spec_sep2 = ComplexTensor(\n", "    *torch.unbind(\n", "        stft(torch.as_tensor(waves_wsj[1]), ilens)[0].squeeze(),\n", "        dim=-1\n", "    )\n", ")\n", "\n", "# freqs = torch.linspace(0, sr / 2, spec_mix.shape[1])\n", "# frames = torch.linspace(0, len(mixwav) / sr, spec_mix.shape[0])\n", "samples = torch.linspace(0, len(mixwav) / sr, len(mixwav))\n", "plt.figure(figsize=(24, 12))\n", "plt.subplot(3, 2, 1)\n", "plt.title('Mixture Spectrogram')\n", "plot_spectrogram(\n", "    plt, abs(spec_mix).transpose(-1, -2).numpy(), fs=sr,\n", "    mode='db', frame_shift=None,\n", "    bottom=False, labelbottom=False\n", ")\n", "plt.subplot(3, 2, 2)\n", "plt.title('Mixture Waveform')\n", "plt.plot(samples, mixwav)\n", "plt.xlim(0, len(mixwav) / sr)\n", "\n", "plt.subplot(3, 2, 3)\n", "plt.title('Separated Spectrogram (spk1)')\n", "plot_spectrogram(\n", "    plt, abs(spec_sep1).transpose(-1, -2).numpy(), fs=sr,\n", "    mode='db', frame_shift=None,\n", "    bottom=False, labelbottom=False\n", ")\n", "plt.subplot(3, 2, 4)\n", "plt.title('Separated Waveform (spk1)')\n", "plt.plot(samples, waves_wsj[0].squeeze())\n", "plt.xlim(0, len(mixwav) / sr)\n", "\n", "plt.subplot(3, 2, 5)\n", "plt.title('Separated Spectrogram (spk2)')\n", "plot_spectrogram(\n", "    plt, abs(spec_sep2).transpose(-1, -2).numpy(), fs=sr,\n", "    mode='db', frame_shift=None,\n", "    bottom=False, labelbottom=False\n", ")\n", "plt.subplot(3, 2, 6)\n", "plt.title('Separated Waveform (spk2)')\n", "plt.plot(samples, waves_wsj[1].squeeze())\n", "plt.xlim(0, len(mixwav) / sr)\n", "plt.xlabel(\"Time (s)\")\n", "plt.show()" ], "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "metadata": { "id": "BnuxcpSicZsx" }, "source": [ "## Evaluate separated speech with a pretrained ASR model\n", "\n", "The ground truths are:\n", "\n", "`text_1: SOME CRITICS INCLUDING HIGH REAGAN ADMINISTRATION OFFICIALS ARE RAISING THE ALARM THAT THE FED'S POLICY IS TOO TIGHT AND COULD CAUSE A RECESSION NEXT YEAR`\n", "\n", "`text_2: THE UNITED STATES UNDERTOOK TO DEFEND WESTERN EUROPE AGAINST SOVIET ATTACK`\n", "\n", "(The speech recognition may take a while.)" ] }, { "cell_type": "code", "metadata": { "id": "Lb799sYpoOWO" }, "source": [ "%pip install -q https://github.com/kpu/kenlm/archive/master.zip # ASR needs kenlm" ], "execution_count": null, "outputs": [] }, { "cell_type": "code", "metadata": { "id": "SZMUsNEQUAEU" }, "source": [ "import espnet_model_zoo\n", "from espnet_model_zoo.downloader import ModelDownloader\n", "from espnet2.bin.asr_inference import Speech2Text\n", "\n", "wsj_8k_model_url = \"https://zenodo.org/record/4012264/files/asr_train_asr_transformer_raw_char_1gpu_valid.acc.ave.zip?download=1\"\n", "\n", "d = ModelDownloader()\n", "speech2text = Speech2Text(\n", "    **d.download_and_unpack(wsj_8k_model_url),\n", "    device=\"cuda:0\",\n", ")\n", "\n", "text_est = [None, None]\n", "text_est[0], *_ = speech2text(waves_wsj[0].squeeze())[0]\n", "text_est[1], *_ = speech2text(waves_wsj[1].squeeze())[0]\n", "text_m, *_ = speech2text(mixwav)[0]\n", "print(\"Mix Speech to Text: \", text_m)\n", "print(\"Separated Speech 1 to Text: \", text_est[0])\n", "print(\"Separated Speech 2 to Text: \", text_est[1])\n" ], "execution_count": null, "outputs": [] }
f\"\\033[38;2;255;0;0m{text}\\033[0m\" if text else \"\",\n", " green=lambda text: f\"\\033[38;2;0;255;0m{text}\\033[0m\" if text else \"\",\n", " yellow=lambda text: f\"\\033[38;2;225;225;0m{text}\\033[0m\" if text else \"\",\n", " white=lambda text: f\"\\033[38;2;255;255;255m{text}\\033[0m\" if text else \"\",\n", " black=lambda text: f\"\\033[38;2;0;0;0m{text}\\033[0m\" if text else \"\",\n", ")\n", "\n", "def diff_strings(ref, est):\n", " \"\"\"Reference: https://stackoverflow.com/a/64404008/7384873\"\"\"\n", " ref_str, est_str, err_str = [], [], []\n", " matcher = difflib.SequenceMatcher(None, ref, est)\n", " for opcode, a0, a1, b0, b1 in matcher.get_opcodes():\n", " if opcode == \"equal\":\n", " txt = ref[a0:a1]\n", " ref_str.append(txt)\n", " est_str.append(txt)\n", " err_str.append(\" \" * (a1 - a0))\n", " elif opcode == \"insert\":\n", " ref_str.append(\"*\" * (b1 - b0))\n", " est_str.append(colors[\"green\"](est[b0:b1]))\n", " err_str.append(colors[\"black\"](\"I\" * (b1 - b0)))\n", " elif opcode == \"delete\":\n", " ref_str.append(ref[a0:a1])\n", " est_str.append(colors[\"red\"](\"*\" * (a1 - a0)))\n", " err_str.append(colors[\"black\"](\"D\" * (a1 - a0)))\n", " elif opcode == \"replace\":\n", " diff = a1 - a0 - b1 + b0\n", " if diff >= 0:\n", " txt_ref = ref[a0:a1]\n", " txt_est = colors[\"yellow\"](est[b0:b1]) + colors[\"red\"](\"*\" * diff)\n", " txt_err = \"S\" * (b1 - b0) + \"D\" * diff\n", " elif diff < 0:\n", " txt_ref = ref[a0:a1] + \"*\" * -diff\n", " txt_est = colors[\"yellow\"](est[b0:b1]) + colors[\"green\"](\"*\" * -diff)\n", " txt_err = \"S\" * (b1 - b0) + \"I\" * -diff\n", "\n", " ref_str.append(txt_ref)\n", " est_str.append(txt_est)\n", " err_str.append(colors[\"black\"](txt_err))\n", " return \"\".join(ref_str), \"\".join(est_str), \"\".join(err_str)\n", "\n", "\n", "text_ref = [\n", " \"SOME CRITICS INCLUDING HIGH REAGAN ADMINISTRATION OFFICIALS ARE RAISING THE ALARM THAT THE FED'S POLICY IS TOO TIGHT AND COULD CAUSE A RECESSION NEXT YEAR\",\n", " \"THE UNITED STATES UNDERTOOK TO DEFEND WESTERN EUROPE AGAINST SOVIET ATTACK\",\n", "]\n", "\n", "print(\"=====================\" , flush=True)\n", "perms = list(permutations(range(2)))\n", "string_edit = [\n", " [\n", " editdistance.eval(text_ref[m], text_est[n])\n", " for m, n in enumerate(p)\n", " ]\n", " for p in perms\n", "]\n", "\n", "dist = [sum(edist) for edist in string_edit]\n", "perm_idx = np.argmin(dist)\n", "perm = perms[perm_idx]\n", "\n", "for i, p in enumerate(perm):\n", " print(\"\\n--------------- Text %d ---------------\" % (i + 1), flush=True)\n", " ref, est, err = diff_strings(text_ref[i], text_est[p])\n", " print(\"REF: \" + ref + \"\\n\" + \"HYP: \" + est + \"\\n\" + \"ERR: \" + err, flush=True)\n", " print(\"Edit Distance = {}\\n\".format(string_edit[perm_idx][i]), flush=True)" ], "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "metadata": { "id": "L0oHEt8O2hwI" }, "source": [ "# (2) Tutorials on Contributing to ESPNet-SE Project" ] }, { "cell_type": "markdown", "metadata": { "id": "oyCBv8j74VB9" }, "source": [ "If you would like to contribute to the ESPnet-SE project, or if you would like to make modifications based on the current speech enhancement/separation functionality, the following tutorials will provide you detailed information about how to creating new recipes or new models in ESPnet-SE." 
] }, { "cell_type": "markdown", "metadata": { "id": "I-4b_OsU2kJS" }, "source": [ "## Creating a New Recipe" ] }, { "cell_type": "markdown", "metadata": { "id": "Jifpg1Al2sSx" }, "source": [ "### Step 1 Create recipe directory\n", "First, run the following command to create the directory for the new recipe from our template:\n", "```bash\n", "egs2/TEMPLATE/enh1/setup.sh egs2//enh1\n", "```\n", "\n", "> For the following steps, we assume the operations are done under the directory `egs2//enh1/`.\n" ] }, { "cell_type": "markdown", "metadata": { "id": "AKs9UK9I5TyP" }, "source": [ "### Step 2 Write scripts for data preparation\n", "Prepare `local/data.sh`, which will be used in stage 1 in `enh.sh`.\n", "It can take some arguments as input, see [egs2/wsj0_2mix/enh1/local/data.sh](https://github.com/espnet/espnet/blob/master/egs2/wsj0_2mix/enh1/local/data.sh) for reference.\n", "\n", "The script `local/data.sh` should finally generate Kaldi-style data directories under `/data/`.\n", "Each subset directory should contains at least 4 files:\n", "```\n", "/data//\n", "├── spk{1,2,3...}.scp (clean speech references)\n", "├── spk2utt\n", "├── utt2spk\n", "└── wav.scp (noisy speech)\n", "```\n", "Optionally, it can also contain `noise{}.scp` and `dereverb{}.scp`, which point to the corresponding noise and dereverberated references respectively. {} can be 1, 2, ..., depending on the number of noise types (dereverberated signals) in the input signal in `wav.scp`.\n", "\n", "Make sure to sort the scp and other related files as in Kaldi. Also, remember to run `. ./path.sh` in `local/data.sh` before sorting, because it will force sorting to be byte-wise, i.e. `export LC_ALL=C`.\n", "\n", "> Remember to check your new scripts with [shellcheck](https://www.shellcheck.net/), otherwise they may fail the tests in [ci/test_shell.sh](https://github.com/espnet/espnet/blob/master/ci/test_shell.sh).\n" ] }, { "cell_type": "markdown", "metadata": { "id": "vMLSlfxC5V7B" }, "source": [ "### Step 3 Prepare training configuration\n", "Prepare training configuration files (e.g. 
, { "cell_type": "markdown", "metadata": { "id": "vMLSlfxC5V7B" }, "source": [ "### Step 3 Prepare training configuration\n", "Prepare training configuration files (e.g. [train.yaml](https://github.com/espnet/espnet/blob/master/egs2/wsj0_2mix/enh1/conf/tuning/train_enh_rnn_tf.yaml)) under `conf/`.\n", "\n", "> If you have multiple configuration files, it is recommended to put them under `conf/tuning/`, and create a symbolic link `conf/train.yaml` pointing to the config file with the best performance.\n" ] }, { "cell_type": "markdown", "metadata": { "id": "RaPTjGXi5XqQ" }, "source": [ "### Step 4 Prepare run.sh\n", "Write `run.sh` to provide a template entry script, so that users can easily run your recipe with `./run.sh`.\n", "See [egs2/wsj0_2mix/enh1/run.sh](https://github.com/espnet/espnet/blob/master/egs2/wsj0_2mix/enh1/run.sh) for reference.\n", "\n", "> If your recipe provides references for noise and/or dereverberation, you can add the argument `--use_noise_ref true` and/or `--use_dereverb_ref true` in `run.sh`.\n" ] }, { "cell_type": "markdown", "metadata": { "id": "OmcW83Xp2q-N" }, "source": [ "## Implementing a New Speech Enhancement/Separation Model\n", "\n", "The current ESPnet-SE tool adopts an encoder-separator-decoder architecture for all models, e.g.\n", "\n", "> For Time-Frequency masking models, the encoder and decoder would be [stft_encoder.py](https://github.com/espnet/espnet/blob/master/espnet2/enh/encoder/stft_encoder.py) and [stft_decoder.py](https://github.com/espnet/espnet/blob/master/espnet2/enh/decoder/stft_decoder.py) respectively, and the separator can be any of [dprnn_separator.py](https://github.com/espnet/espnet/blob/master/espnet2/enh/separator/dprnn_separator.py), [rnn_separator.py](https://github.com/espnet/espnet/blob/master/espnet2/enh/separator/rnn_separator.py), [tcn_separator.py](https://github.com/espnet/espnet/blob/master/espnet2/enh/separator/tcn_separator.py), and [transformer_separator.py](https://github.com/espnet/espnet/blob/master/espnet2/enh/separator/transformer_separator.py).\n", "> For TasNet, the encoder and decoder are [conv_encoder.py](https://github.com/espnet/espnet/blob/master/espnet2/enh/encoder/conv_encoder.py) and [conv_decoder.py](https://github.com/espnet/espnet/blob/master/espnet2/enh/decoder/conv_decoder.py) respectively. The separator is [tcn_separator.py](https://github.com/espnet/espnet/blob/master/espnet2/enh/separator/tcn_separator.py)." ] }, { "cell_type": "markdown", "metadata": { "id": "MHoEHOtr2s3p" }, "source": [ "\n", "### Step 1 Create model scripts\n", "For the encoder, separator, and decoder models, create new scripts under [espnet2/enh/encoder/](https://github.com/espnet/espnet/tree/master/espnet2/enh/encoder), [espnet2/enh/separator/](https://github.com/espnet/espnet/tree/master/espnet2/enh/separator), and [espnet2/enh/decoder/](https://github.com/espnet/espnet/tree/master/espnet2/enh/decoder), respectively.\n", "\n", "For a separator model, please make sure it implements the `num_spk` property, as in the sketch below. See [espnet2/enh/separator/rnn_separator.py](https://github.com/espnet/espnet/blob/master/espnet2/enh/separator/rnn_separator.py) for reference.\n", "\n", "> Remember to format your new scripts to match the styles enforced by `black` and `flake8`, otherwise they may fail the tests in [ci/test_python.sh](https://github.com/espnet/espnet/blob/master/ci/test_python.sh).\n" ] }
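, { "cell_type": "markdown", "metadata": {}, "source": [ "Here is a minimal, hypothetical separator sketch illustrating the interface. The class `NewSeparator` and its per-speaker mask estimator are made up for illustration; the `AbsSeparator` base class and the `(masked, ilens, others)` return convention follow existing separators such as `rnn_separator.py`:" ] }, { "cell_type": "code", "metadata": {}, "source": [ "from collections import OrderedDict\n", "\n", "import torch\n", "\n", "from espnet2.enh.separator.abs_separator import AbsSeparator\n", "\n", "\n", "class NewSeparator(AbsSeparator):\n", "    \"\"\"A made-up masking separator; only the interface matters here.\"\"\"\n", "\n", "    def __init__(self, input_dim: int, num_spk: int = 2):\n", "        super().__init__()\n", "        self._num_spk = num_spk\n", "        # one mask estimator per speaker (illustrative)\n", "        self.maskers = torch.nn.ModuleList(\n", "            [torch.nn.Linear(input_dim, input_dim) for _ in range(num_spk)]\n", "        )\n", "\n", "    def forward(self, input, ilens):\n", "        # input: (Batch, T, F) encoded feature; ilens: (Batch,) input lengths\n", "        feature = abs(input)\n", "        masks = [torch.sigmoid(m(feature)) for m in self.maskers]\n", "        masked = [input * m for m in masks]\n", "        # \"others\" carries auxiliary outputs, e.g. the estimated masks\n", "        others = OrderedDict((f\"mask_spk{i + 1}\", m) for i, m in enumerate(masks))\n", "        return masked, ilens, others\n", "\n", "    @property\n", "    def num_spk(self):\n", "        return self._num_spk" ], "execution_count": null, "outputs": [] }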
, { "cell_type": "markdown", "metadata": { "id": "dlHoi-TS4iB6" }, "source": [ "### Step 2 Add the new model to related scripts\n", "In [espnet2/tasks/enh.py](https://github.com/espnet/espnet/blob/master/espnet2/tasks/enh.py#L37-L62), add your new model to the corresponding `ClassChoices`, e.g.\n", "* For encoders, add `<name>=<YourEncoderClass>` to `encoder_choices`.\n", "* For decoders, add `<name>=<YourDecoderClass>` to `decoder_choices`.\n", "* For separators, add `<name>=<YourSeparatorClass>` to `separator_choices`.\n" ] }, { "cell_type": "markdown", "metadata": { "id": "57rcjJod4iIv" }, "source": [ "### Step 3 [Optional] Create new loss functions\n", "If you want to use a new loss function for your model, you can add it to [espnet2/enh/espnet_model.py](https://github.com/espnet/espnet/blob/master/espnet2/enh/espnet_model.py), such as:\n", "```python\n", "    @staticmethod\n", "    def new_loss(ref, inf):\n", "        \"\"\"Your new loss\n", "        Args:\n", "            ref: (Batch, samples)\n", "            inf: (Batch, samples)\n", "        Returns:\n", "            loss: (Batch,)\n", "        \"\"\"\n", "        ...\n", "        return loss\n", "```\n", "\n", "Then add your loss name to [ALL_LOSS_TYPES](https://github.com/espnet/espnet/blob/master/espnet2/enh/espnet_model.py#L21), and handle the loss calculation in [_compute_loss](https://github.com/espnet/espnet/blob/master/espnet2/enh/espnet_model.py#L246).\n" ] }, { "cell_type": "markdown", "metadata": { "id": "FKcyPZXO4iMb" }, "source": [ "### Step 4 Create unit tests for the new model\n", "Finally, it would be nice to make some unit tests for your new model under [test/espnet2/enh/encoder](https://github.com/espnet/espnet/tree/master/test/espnet2/enh/encoder), [test/espnet2/enh/decoder](https://github.com/espnet/espnet/tree/master/test/espnet2/enh/decoder), or [test/espnet2/enh/separator](https://github.com/espnet/espnet/tree/master/test/espnet2/enh/separator). A small sketch follows below.\n" ] }
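, { "cell_type": "markdown", "metadata": {}, "source": [ "For instance, a minimal test in the style of the existing separator tests might look like the following (a hedged sketch: it assumes the `RNNSeparator` constructor arguments shown here, which may differ across ESPnet versions):" ] }, { "cell_type": "code", "metadata": {}, "source": [ "import pytest\n", "import torch\n", "\n", "from espnet2.enh.separator.rnn_separator import RNNSeparator\n", "\n", "\n", "@pytest.mark.parametrize(\"num_spk\", [1, 2])\n", "def test_rnn_separator_forward(num_spk):\n", "    separator = RNNSeparator(input_dim=17, layer=1, unit=10, num_spk=num_spk)\n", "    x = torch.rand(2, 10, 17)  # (Batch, T, F)\n", "    ilens = torch.LongTensor([10, 8])\n", "    masked, olens, others = separator(x, ilens)\n", "    # one output stream per speaker, each with the input shape\n", "    assert len(masked) == num_spk\n", "    assert masked[0].shape == x.shape" ], "execution_count": null, "outputs": [] } ] }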