{
 "nbformat": 4,
 "nbformat_minor": 0,
 "metadata": {
  "colab": {
   "provenance": []
  },
  "kernelspec": {
   "name": "python3",
   "display_name": "Python 3"
  },
  "language_info": {
   "name": "python"
  }
 },
 "cells": [
  {
   "cell_type": "markdown",
   "source": [
    "# ESPnet Two-Pass SLU Demonstration\n",
    "\n",
    "This notebook provides a demonstration of the Two-Pass End-to-End Spoken Language Understanding (SLU) model.\n",
    "\n",
    "Paper Link: https://arxiv.org/abs/2207.06670\n",
    "\n",
    "ESPnet2-SLU: https://github.com/espnet/espnet/tree/master/egs2/TEMPLATE/slu1\n",
    "\n",
    "Author: Siddhant Arora"
   ],
   "metadata": {
    "id": "dlicd4tC4xjp"
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "6Qvsfvkm4b5k"
   },
   "outputs": [],
   "source": [
    "# Use %pip (not !pip) so packages are installed into the running kernel's environment.\n",
    "%pip install transformers\n",
    "!git clone https://github.com/espnet/espnet /espnet\n",
    "%pip install /espnet\n",
    "%pip install -q espnet_model_zoo\n",
    "%pip install fairseq@git+https://github.com/pytorch/fairseq.git@f2146bdc7abf293186de9449bfa2272775e39e1d#egg=fairseq\n",
    "# Enable the Git LFS filters so that plain `git clone` fetches the model weights below.\n",
    "!git lfs install"
   ]
  },
  {
   "cell_type": "markdown",
   "source": [
    "## Download Audio File"
   ],
   "metadata": {
    "id": "AwRW0VP5nwwJ"
   }
  },
  {
   "cell_type": "code",
   "source": [
    "# Alternative audio file (writes to the same path; uncomment to use it instead):\n",
    "# !gdown --id 1LxoxCoFgx3u8CvKb1loybGFtArKKPcAH -O /content/audio_file.wav\n",
    "!gdown --id 18ANT62ittt7Ai2E8bQRlvT0ZVXXsf1eE -O /content/audio_file.wav"
   ],
   "metadata": {
    "id": "r1oiDIEon9dv"
   },
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "code",
   "source": [
    "import os\n",
    "\n",
    "import soundfile\n",
    "from IPython.display import display, Audio\n",
    "\n",
    "mixwav_mc, sr = soundfile.read(\"/content/audio_file.wav\")\n",
    "display(Audio(mixwav_mc.T, rate=sr))"
   ],
   "metadata": {
    "id": "lMy4Vwoc0MPo"
   },
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "markdown",
   "source": [
    "## Download and Load pretrained First Pass Model"
   ],
   "metadata": {
    "id": "x526Ibtjnv0N"
   }
  },
  {
   "cell_type": "code",
   "source": [
    "!git clone https://huggingface.co/espnet/siddhana_slurp_new_asr_train_asr_conformer_raw_en_word_valid.acc.ave_10best /content/slurp_first_pass_model"
   ],
   "metadata": {
    "id": "aQNFXcMuo4iJ"
   },
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "code",
   "source": [
    "from espnet2.bin.asr_inference import Speech2Text\n",
    "\n",
    "speech2text_slurp = Speech2Text.from_pretrained(\n",
    "    asr_train_config=\"/content/slurp_first_pass_model/exp/asr_train_asr_conformer_raw_en_word/config.yaml\",\n",
    "    asr_model_file=\"/content/slurp_first_pass_model/exp/asr_train_asr_conformer_raw_en_word/valid.acc.ave_10best.pth\",\n",
    "    nbest=1,\n",
    ")"
   ],
   "metadata": {
    "id": "9Avmes1H1Xgo"
   },
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "code",
   "source": [
    "def text_normalizer(sub_word_transcript):\n",
    "    \"\"\"Join a list of sub-word tokens into a plain-text transcript.\n",
    "\n",
    "    A token containing the \"▁\" marker starts a new space-separated word;\n",
    "    any other token is appended directly to the preceding text.\n",
    "    Returns an empty string when the token list is empty.\n",
    "    \"\"\"\n",
    "    if not sub_word_transcript:\n",
    "        return \"\"\n",
    "    transcript = sub_word_transcript[0].replace(\"▁\", \"\")\n",
    "    for sub_word in sub_word_transcript[1:]:\n",
    "        if \"▁\" in sub_word:\n",
    "            transcript = transcript + \" \" + sub_word.replace(\"▁\", \"\")\n",
    "        else:\n",
    "            transcript = transcript + sub_word\n",
    "    return transcript\n",
    "\n",
    "\n",
    "def format_intent(slu_output):\n",
    "    \"\"\"Format the leading `scenario_action` token of an SLU hypothesis.\n",
    "\n",
    "    The first whitespace-separated token is split at its first underscore:\n",
    "    the part before it is the scenario, the remainder is the action.\n",
    "    An empty hypothesis yields empty scenario and action fields.\n",
    "    \"\"\"\n",
    "    tokens = slu_output.split()\n",
    "    intent_token = tokens[0] if tokens else \"\"\n",
    "    scenario, _, action = intent_token.partition(\"_\")\n",
    "    return \"{scenario: \" + scenario + \", action: \" + action + \"}\""
   ],
   "metadata": {},
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "code",
   "source": [
    "nbests_orig = speech2text_slurp(mixwav_mc)\n",
    "text, *_ = nbests_orig[0]\n",
    "intent_text = format_intent(text)\n",
    "print(f\"INTENT: {intent_text}\")\n",
    "# Everything after the leading intent token is the ASR transcript.\n",
    "transcript = text_normalizer(text.split()[1:])\n",
    "print(f\"ASR hypothesis: {transcript}\")\n",
    "# Narrative note for the demo; it is printed unconditionally, not derived from the prediction.\n",
    "print(\"First pass SLU model fails to predict the correct action.\")"
   ],
   "metadata": {
    "id": "8Ope3h7P3aKz"
   },
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "markdown",
   "source": [
    "## Download and Load pretrained Second Pass Model"
   ],
   "metadata": {
    "id": "tKB6r3c-398T"
   }
  },
  {
   "cell_type": "code",
   "source": [
    "!git clone https://huggingface.co/espnet/slurp_slu_2pass /content/slurp_second_pass_model"
   ],
   "metadata": {
    "id": "gFi7KOFE4HzK"
   },
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "code",
   "source": [
    "from espnet2.bin.slu_inference import Speech2Understand\n",
    "from transformers import AutoModel, AutoTokenizer\n",
    "\n",
    "speech2text_second_pass_slurp = Speech2Understand.from_pretrained(\n",
    "    slu_train_config=\"/content/slurp_second_pass_model/exp/slu_train_asr_bert_conformer_deliberation_raw_en_word/config.yaml\",\n",
    "    slu_model_file=\"/content/slurp_second_pass_model/exp/slu_train_asr_bert_conformer_deliberation_raw_en_word/valid.acc.ave_10best.pth\",\n",
    "    nbest=1,\n",
    ")"
   ],
   "metadata": {
    "id": "bt9UqWvj60K7"
   },
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "code",
   "source": [
    "from espnet2.tasks.slu import SLUTask\n",
    "\n",
    "preprocess_fn = SLUTask.build_preprocess_fn(\n",
    "    speech2text_second_pass_slurp.asr_train_args, False\n",
    ")\n"
   ],
   "metadata": {
    "id": "L4CvcTWHMg23"
   },
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "code",
   "source": [
    "import numpy as np\n",
    "\n",
    "# Keep the first-pass `transcript` untouched so this cell can be re-run safely.\n",
    "cleaned_transcript = preprocess_fn.text_cleaner(transcript)\n",
    "tokens = preprocess_fn.transcript_tokenizer.text2tokens(cleaned_transcript)\n",
    "text_ints = np.array(preprocess_fn.transcript_token_id_converter.tokens2ids(tokens), dtype=np.int64)"
   ],
   "metadata": {
    "id": "eXK8jabBM7KL"
   },
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "code",
   "source": [
    "import torch\n",
    "\n",
    "# The second pass receives both the audio and the token ids of the first-pass transcript.\n",
    "nbests = speech2text_second_pass_slurp(mixwav_mc, torch.tensor(text_ints))\n",
    "text1, *_ = nbests[0]\n",
    "intent_text_second_pass = format_intent(text1)\n",
    "print(f\"INTENT: {intent_text_second_pass}\")\n",
    "transcript_second_pass = text_normalizer(text1.split()[1:])\n",
    "print(f\"ASR hypothesis: {transcript_second_pass}\")\n",
    "# Narrative note for the demo; it is printed unconditionally, not derived from the prediction.\n",
    "print(\"Second pass SLU model successfully recognizes the correct action.\")"
   ],
   "metadata": {
    "id": "TCELe9RS_62x"
   },
   "execution_count": null,
   "outputs": []
  }
 ]
}