Open In Colab

ESPnet real time E2E-TTS demonstration

This notebook provides a demonstration of the realtime E2E-TTS using ESPnet-TTS and ParallelWaveGAN (+ MelGAN).

Author: Tomoki Hayashi ([@kan-bayashi](https://github.com/kan-bayashi))

Install

[ ]:
# install minimal components
# Python packages used by the demo cells below (-q keeps pip's output quiet).
!pip install -q parallel_wavegan PyYaml unidecode ConfigArgparse g2p_en espnet_tts_frontend
# Upgrade gdown without using pip's cache.
!pip install --upgrade --no-cache-dir gdown
# Clone the ESPnet sources into ./espnet; later cells append this directory to
# sys.path and run ./espnet/utils/download_from_google_drive.sh from it.
!git clone -q https://github.com/espnet/espnet.git
# Check out the v.0.9.1 tag on a new local branch of the same name.
!cd espnet && git fetch && git checkout -b v.0.9.1 refs/tags/v.0.9.1

English demo

Download pretrained feature generation model

You can select one of three models. Please run only the cell for the model you selected.

(a) Tacotron2

[ ]:
# download pretrained model
import os
# Only download when the target directory does not exist yet, so re-running
# this cell does not fetch the archive again.
if not os.path.exists("downloads/en/tacotron2"):
    !./espnet/utils/download_from_google_drive.sh \
        https://drive.google.com/open?id=1lFfeyewyOsxaNO-DEWy9iSz6qB9ZS1UR downloads/en/tacotron2 tar.gz

# set path
# Globals read by the English Setup cell: trans_type selects the phoneme branch
# of the text frontend, dict_path is the token-to-id table, and model_path is
# the checkpoint given to get_model_conf / torch_load.
trans_type = "phn"
dict_path = "downloads/en/tacotron2/data/lang_1phn/phn_train_no_dev_units.txt"
model_path = "downloads/en/tacotron2/exp/phn_train_no_dev_pytorch_train_pytorch_tacotron2.v3/results/model.last1.avg.best"

# NOTE: printed unconditionally; it does not check that the download succeeded.
print("sucessfully finished download.")

(b) Transformer

[ ]:
# download pretrained model
import os
# Only download when the target directory does not exist yet, so re-running
# this cell does not fetch the archive again.
if not os.path.exists("downloads/en/transformer"):
    !./espnet/utils/download_from_google_drive.sh \
        https://drive.google.com/open?id=1z8KSOWVBjK-_Ws4RxVN4NTx-Buy03-7c downloads/en/transformer tar.gz

# set path
# Globals read by the English Setup cell: trans_type selects the phoneme branch
# of the text frontend, dict_path is the token-to-id table, and model_path is
# the checkpoint given to get_model_conf / torch_load.
trans_type = "phn"
dict_path = "downloads/en/transformer/data/lang_1phn/phn_train_no_dev_units.txt"
model_path = "downloads/en/transformer/exp/phn_train_no_dev_pytorch_train_pytorch_transformer.v3.single/results/model.last1.avg.best"

# NOTE: printed unconditionally; it does not check that the download succeeded.
print("sucessfully finished download.")

(c) FastSpeech

[ ]:
# download pretrained model
import os
# Only download when the target directory does not exist yet, so re-running
# this cell does not fetch the archive again.
if not os.path.exists("downloads/en/fastspeech"):
    !./espnet/utils/download_from_google_drive.sh \
        https://drive.google.com/open?id=1P9I4qag8wAcJiTCPawt6WCKBqUfJFtFp downloads/en/fastspeech tar.gz

# set path
# Globals read by the English Setup cell: trans_type selects the phoneme branch
# of the text frontend, dict_path is the token-to-id table, and model_path is
# the checkpoint given to get_model_conf / torch_load.
trans_type = "phn"
dict_path = "downloads/en/fastspeech/data/lang_1phn/phn_train_no_dev_units.txt"
model_path = "downloads/en/fastspeech/exp/phn_train_no_dev_pytorch_train_tacotron2.v3_fastspeech.v4.single/results/model.last1.avg.best"

# NOTE: printed unconditionally; it does not check that the download succeeded.
print("Sucessfully finished download.")

Download pretrained vocoder model

You can select one of three models. Please run only the cell for the model you selected.

(a) Parallel WaveGAN

[ ]:
# download pretrained model
import os
# Only download when the target directory does not exist yet, so re-running
# this cell does not fetch the archive again.
if not os.path.exists("downloads/en/parallel_wavegan"):
    !./espnet/utils/download_from_google_drive.sh \
        https://drive.google.com/open?id=1Grn7X9wD35UcDJ5F7chwdTqTa4U7DeVB downloads/en/parallel_wavegan tar.gz

# set path
# vocoder_path is the checkpoint the English Setup cell passes to load_model.
vocoder_path = "downloads/en/parallel_wavegan/ljspeech.parallel_wavegan.v2/checkpoint-400000steps.pkl"

# NOTE: printed unconditionally; it does not check that the download succeeded.
print("Sucessfully finished download.")

(b) MelGAN

[ ]:
# download pretrained model
import os
# Only download when the target directory does not exist yet, so re-running
# this cell does not fetch the archive again.
if not os.path.exists("downloads/en/melgan"):
    !./espnet/utils/download_from_google_drive.sh \
        https://drive.google.com/open?id=1_a8faVA5OGCzIcJNw4blQYjfG4oA9VEt downloads/en/melgan tar.gz

# set path
# vocoder_path is the checkpoint the English Setup cell passes to load_model.
vocoder_path = "downloads/en/melgan/train_nodev_ljspeech_melgan.v3.long/checkpoint-4000000steps.pkl"

# NOTE: printed unconditionally; it does not check that the download succeeded.
print("Sucessfully finished download.")

(c) Multi-band MelGAN

This is an EXPERIMENTAL model.

[ ]:
# download pretrained model
import os
# Only download when the target directory does not exist yet, so re-running
# this cell does not fetch the archive again.
if not os.path.exists("downloads/en/mb-melgan"):
    !./espnet/utils/download_from_google_drive.sh \
        https://drive.google.com/open?id=1rGG5y15uy4WZ-lJy8NPVTkmB_6VhC20V downloads/en/mb-melgan tar.gz

# set path
# vocoder_path is the checkpoint the English Setup cell passes to load_model.
vocoder_path = "downloads/en/mb-melgan/train_nodev_ljspeech_multi_band_melgan.v1/checkpoint-1000000steps.pkl"

# NOTE: printed unconditionally; it does not check that the download succeeded.
print("Sucessfully finished download.")

Setup

[ ]:
# add path
# Make the cloned ESPnet checkout importable.
import sys
sys.path.append("espnet")

# define device
# Prefer the GPU, but fall back to the CPU so the demo still runs (more slowly)
# on a runtime without CUDA instead of failing at the first .to(device).
# NOTE(review): assumes torch_load / load_model map checkpoints to CPU before
# the .to(device) calls below -- confirm for the pinned library versions.
import torch
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# define E2E-TTS model
from argparse import Namespace
from espnet.asr.asr_utils import get_model_conf
from espnet.asr.asr_utils import torch_load
from espnet.utils.dynamic_import import dynamic_import
# Rebuild the network from the configuration stored next to the checkpoint,
# load its weights, and switch it to evaluation mode on the chosen device.
idim, odim, train_args = get_model_conf(model_path)
model_class = dynamic_import(train_args.model_module)
model = model_class(idim, odim, train_args)
torch_load(model_path, model)
model = model.eval().to(device)
# Options handed to model.inference() in the Synthesis cell.
inference_args = Namespace(
    threshold=0.5,
    minlenratio=0.0,
    maxlenratio=10.0,
    # Only for Tacotron 2
    use_attention_constraint=True,
    backward_window=1,
    forward_window=3,
    # Only for fastspeech (lower than 1.0 is faster speech, higher than 1.0 is slower speech)
    fastspeech_alpha=1.0,
)

# define neural vocoder
from parallel_wavegan.utils import load_model
# Sampling rate in Hz; the Synthesis cell uses it for the RTF and for playback.
fs = 22050
vocoder = load_model(vocoder_path)
vocoder.remove_weight_norm()
vocoder = vocoder.eval().to(device)

# define text frontend
from tacotron_cleaner.cleaners import custom_english_cleaners
from g2p_en import G2p
# Each dictionary line is "<token> <id>". Read it as UTF-8 regardless of the
# platform default and skip blank lines, which would break the 2-tuple unpack.
with open(dict_path, encoding="utf-8") as f:
    lines = f.readlines()
lines = [line.replace("\n", "").split(" ") for line in lines if line.strip()]
char_to_id = {c: int(i) for c, i in lines}
g2p = G2p()
def frontend(text):
    """Normalize English text and encode it as a 1-D tensor of token ids.

    The text is first passed through ``custom_english_cleaners``. When the
    global ``trans_type`` is ``"phn"`` the cleaned text is converted with
    ``g2p`` and the ``" "`` tokens are dropped; otherwise every character of
    the cleaned text is one symbol. Whitespace symbols map to ``<space>``,
    symbols missing from ``char_to_id`` map to ``<unk>``, and ``idim - 1`` is
    appended as the end-of-sequence id.

    Args:
        text: Raw input sentence.

    Returns:
        ``torch.LongTensor`` of token ids, moved to ``device``.
    """
    cleaned = custom_english_cleaners(text)

    if trans_type != "phn":
        print(f"Cleaned text: {cleaned}")
        symbols = list(cleaned)
    else:
        phonemes = [p for p in g2p(cleaned) if p != " "]
        cleaned = " ".join(phonemes)
        print(f"Cleaned text: {cleaned}")
        # Splitting the joined string (not reusing the list) keeps the
        # original behaviour for an empty phoneme list: one "" symbol.
        symbols = cleaned.split(" ")

    def to_id(symbol):
        """Look up one symbol, falling back to <space> / <unk>."""
        if symbol.isspace():
            return char_to_id["<space>"]
        return char_to_id[symbol] if symbol in char_to_id else char_to_id["<unk>"]

    ids = [to_id(symbol) for symbol in symbols]
    ids.append(idim - 1)  # <eos>
    return torch.LongTensor(ids).view(-1).to(device)

# Download the NLTK "punkt" resource, then report that setup has finished.
import nltk
nltk.download("punkt")
print("Now ready to synthesize!")

Synthesis

[ ]:
# Read one sentence, synthesize it, report the real-time factor, and play it.
import time
print("Input your favorite sentence in English!")
input_text = input()
with torch.no_grad():
    start = time.time()
    x = frontend(input_text)
    c, _, _ = model.inference(x, inference_args)
    y = vocoder.inference(c)
    # CUDA kernels are launched asynchronously; wait for them to finish so the
    # measured time includes the vocoder's GPU work and not just its launch.
    if y.is_cuda:
        torch.cuda.synchronize()
# RTF = processing time / duration of the generated audio.
# assumes len(y) is the number of generated samples -- TODO confirm
rtf = (time.time() - start) / (len(y) / fs)
print(f"RTF = {rtf:5f}")

from IPython.display import display, Audio
display(Audio(y.view(-1).cpu().numpy(), rate=fs))

Japanese demo

Install Japanese dependencies

[ ]:
# pyopenjtalk is imported by the Japanese Setup cell, whose frontend calls pyopenjtalk.g2p.
!pip install pyopenjtalk

Download pretrained models

Here we select Tacotron2 or Transformer. The vocoder model is Parallel WaveGAN.

(a) Tacotron 2

[ ]:
# download pretrained models
import os
# Only download when the target directory does not exist yet. Two archives are
# fetched into the same directory; the paths set below point at both an exp/
# model checkpoint and a jsut.parallel_wavegan.v1 vocoder checkpoint inside it.
if not os.path.exists("downloads/jp/tacotron2"):
    !./espnet/utils/download_from_google_drive.sh \
        https://drive.google.com/open?id=1OwrUQzAmvjj1x9cDhnZPp6dqtsEqGEJM downloads/jp/tacotron2 tar.gz
    !./espnet/utils/download_from_google_drive.sh \
        https://drive.google.com/open?id=1kp5M4VvmagDmYckFJa78WGqh1drb_P9t downloads/jp/tacotron2 tar.gz

# set path
# Globals read by the Japanese Setup cell: token-to-id table, acoustic-model
# checkpoint, and vocoder checkpoint.
dict_path = "downloads/jp/tacotron2/data/lang_1phn/train_no_dev_units.txt"
model_path = "downloads/jp/tacotron2/exp/train_no_dev_pytorch_train_pytorch_tacotron2_phn/results/model.last1.avg.best"
vocoder_path = "downloads/jp/tacotron2/jsut.parallel_wavegan.v1/checkpoint-400000steps.pkl"

# NOTE: printed unconditionally; it does not check that the downloads succeeded.
print("sucessfully finished download.")

(b) Transformer

[ ]:
# download pretrained models
import os
# Only download when the target directory does not exist yet. Two archives are
# fetched into the same directory; the paths set below point at both an exp/
# model checkpoint and a jsut.parallel_wavegan.v1 vocoder checkpoint inside it.
if not os.path.exists("downloads/jp/transformer"):
    !./espnet/utils/download_from_google_drive.sh \
        https://drive.google.com/open?id=1OwrUQzAmvjj1x9cDhnZPp6dqtsEqGEJM downloads/jp/transformer tar.gz
    !./espnet/utils/download_from_google_drive.sh \
        https://drive.google.com/open?id=1mEnZfBKqA4eT6Bn0eRZuP6lNzL-IL3VD downloads/jp/transformer tar.gz

# set path
# Globals read by the Japanese Setup cell: token-to-id table, acoustic-model
# checkpoint, and vocoder checkpoint.
dict_path = "downloads/jp/transformer/data/lang_1phn/train_no_dev_units.txt"
model_path = "downloads/jp/transformer/exp/train_no_dev_pytorch_train_pytorch_transformer_phn/results/model.last1.avg.best"
vocoder_path = "downloads/jp/transformer/jsut.parallel_wavegan.v1/checkpoint-400000steps.pkl"

# NOTE: printed unconditionally; it does not check that the downloads succeeded.
print("sucessfully finished download.")

Setup

[ ]:
# add path
# Make the cloned ESPnet checkout importable.
import sys
sys.path.append("espnet")

# define device
# Prefer the GPU, but fall back to the CPU so the demo still runs (more slowly)
# on a runtime without CUDA instead of failing at the first .to(device).
# NOTE(review): assumes torch_load / load_model map checkpoints to CPU before
# the .to(device) calls below -- confirm for the pinned library versions.
import torch
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# define E2E-TTS model
from argparse import Namespace
from espnet.asr.asr_utils import get_model_conf
from espnet.asr.asr_utils import torch_load
from espnet.utils.dynamic_import import dynamic_import
# Rebuild the network from the configuration stored next to the checkpoint,
# load its weights, and switch it to evaluation mode on the chosen device.
idim, odim, train_args = get_model_conf(model_path)
model_class = dynamic_import(train_args.model_module)
model = model_class(idim, odim, train_args)
torch_load(model_path, model)
model = model.eval().to(device)
# Options handed to model.inference() in the Synthesis cell.
inference_args = Namespace(threshold=0.5, minlenratio=0.0, maxlenratio=10.0)

# define neural vocoder
from parallel_wavegan.utils import load_model
# Sampling rate in Hz; the Synthesis cell uses it for the RTF and for playback.
fs = 24000
vocoder = load_model(vocoder_path)
vocoder.remove_weight_norm()
vocoder = vocoder.eval().to(device)

# define text frontend
import pyopenjtalk
# Each dictionary line is "<token> <id>". Read it as UTF-8 regardless of the
# platform default and skip blank lines, which would break the 2-tuple unpack.
with open(dict_path, encoding="utf-8") as f:
    lines = f.readlines()
lines = [line.replace("\n", "").split(" ") for line in lines if line.strip()]
char_to_id = {c: int(i) for c, i in lines}
def frontend(text):
    """Convert Japanese text to a 1-D tensor of token ids.

    The text is converted with ``pyopenjtalk.g2p(text, kana=False)`` and the
    result is split on single spaces. Whitespace tokens map to ``<space>``,
    tokens missing from ``char_to_id`` map to ``<unk>``, and ``idim - 1`` is
    appended as the end-of-sequence id.

    Args:
        text: Raw input sentence.

    Returns:
        ``torch.LongTensor`` of token ids, moved to ``device``.
    """
    phoneme_str = pyopenjtalk.g2p(text, kana=False)
    print(f"Cleaned text: {phoneme_str}")

    ids = []
    for token in phoneme_str.split(" "):
        # Choose the dictionary key first, then do a single lookup.
        if token.isspace():
            key = "<space>"
        elif token in char_to_id:
            key = token
        else:
            key = "<unk>"
        ids.append(char_to_id[key])
    ids.append(idim - 1)  # <eos>
    return torch.LongTensor(ids).view(-1).to(device)

# Warm-up call to the frontend defined above; its return value is discarded.
# NOTE(review): the literal says, roughly, "the dictionary needs to be installed
# the first time", so this presumably triggers pyopenjtalk's one-time
# dictionary download before the interactive Synthesis cell -- confirm.
frontend("初回の辞書のインストールが必要です")
print("Now ready to synthesize!")

Synthesis

[ ]:
# Read one sentence, synthesize it, report the real-time factor, and play it.
import time
print("日本語で好きな文章を入力してください")
input_text = input()

with torch.no_grad():
    start = time.time()
    x = frontend(input_text)
    c, _, _ = model.inference(x, inference_args)
    y = vocoder.inference(c)
    # CUDA kernels are launched asynchronously; wait for them to finish so the
    # measured time includes the vocoder's GPU work and not just its launch.
    if y.is_cuda:
        torch.cuda.synchronize()
# RTF = processing time / duration of the generated audio.
# assumes len(y) is the number of generated samples -- TODO confirm
rtf = (time.time() - start) / (len(y) / fs)
print(f"RTF = {rtf:5f}")

from IPython.display import display, Audio
display(Audio(y.view(-1).cpu().numpy(), rate=fs))

Mandarin demo

IMPORTANT NOTE: The author cannot understand Mandarin. The text front-end part might have some bugs.

Install Mandarin dependencies

[ ]:
# pypinyin provides pinyin, Style, get_initials and get_finals, which the Mandarin Setup cell imports.
!pip install pypinyin

Download pretrained models

You can select Transformer or FastSpeech.

(a) Transformer

[ ]:
# download pretrained models
import os
# Only download when the target directory does not exist yet. Two archives are
# fetched into the same directory; the paths set below point at both an exp/
# model checkpoint and a csmsc.parallel_wavegan.v1 vocoder checkpoint inside it.
if not os.path.exists("downloads/zh/transformer"):
    !./espnet/utils/download_from_google_drive.sh \
        https://drive.google.com/open?id=10M6H88jEUGbRWBmU1Ff2VaTmOAeL8CEy downloads/zh/transformer tar.gz
    !./espnet/utils/download_from_google_drive.sh \
        https://drive.google.com/open?id=1bTSygvonv5TS6-iuYsOIUWpN2atGnyhZ downloads/zh/transformer tar.gz

# set path
# Globals read by the Mandarin Setup cell: token-to-id table, acoustic-model
# checkpoint, and vocoder checkpoint.
dict_path = "downloads/zh/transformer/data/lang_phn/train_no_dev_units.txt"
model_path = "downloads/zh/transformer/exp/train_no_dev_pytorch_train_pytorch_transformer.v1.single/results/model.last1.avg.best"
vocoder_path = "downloads/zh/transformer/csmsc.parallel_wavegan.v1/checkpoint-400000steps.pkl"

# NOTE: printed unconditionally; it does not check that the downloads succeeded.
print("sucessfully finished download.")

(b) FastSpeech

[ ]:
# download pretrained models
import os
# Only download when the target directory does not exist yet. Two archives are
# fetched into the same directory; the paths set below point at both an exp/
# model checkpoint and a csmsc.parallel_wavegan.v1 vocoder checkpoint inside it.
if not os.path.exists("downloads/zh/fastspeech"):
    !./espnet/utils/download_from_google_drive.sh \
        https://drive.google.com/open?id=10M6H88jEUGbRWBmU1Ff2VaTmOAeL8CEy downloads/zh/fastspeech tar.gz
    !./espnet/utils/download_from_google_drive.sh \
        https://drive.google.com/open?id=1T8thxkAxjGFPXPWPTcKLvHnd6lG0-82R downloads/zh/fastspeech tar.gz

# set path
# Globals read by the Mandarin Setup cell: token-to-id table, acoustic-model
# checkpoint, and vocoder checkpoint.
dict_path = "downloads/zh/fastspeech/data/lang_phn/train_no_dev_units.txt"
model_path = "downloads/zh/fastspeech/exp/train_no_dev_pytorch_train_fastspeech.v3.single/results/model.last1.avg.best"
vocoder_path = "downloads/zh/fastspeech/csmsc.parallel_wavegan.v1/checkpoint-400000steps.pkl"

# NOTE: printed unconditionally; it does not check that the downloads succeeded.
print("sucessfully finished download.")

Setup

[ ]:
# add path
# Make the cloned ESPnet checkout importable.
import sys
sys.path.append("espnet")

# define device
# Prefer the GPU, but fall back to the CPU so the demo still runs (more slowly)
# on a runtime without CUDA instead of failing at the first .to(device).
# NOTE(review): assumes torch_load / load_model map checkpoints to CPU before
# the .to(device) calls below -- confirm for the pinned library versions.
import torch
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# define E2E-TTS model
from argparse import Namespace
from espnet.asr.asr_utils import get_model_conf
from espnet.asr.asr_utils import torch_load
from espnet.utils.dynamic_import import dynamic_import
# Rebuild the network from the configuration stored next to the checkpoint,
# load its weights, and switch it to evaluation mode on the chosen device.
idim, odim, train_args = get_model_conf(model_path)
model_class = dynamic_import(train_args.model_module)
model = model_class(idim, odim, train_args)
torch_load(model_path, model)
model = model.eval().to(device)
# Options handed to model.inference() in the Synthesis cell.
inference_args = Namespace(threshold=0.5, minlenratio=0.0, maxlenratio=10.0)

# define neural vocoder
from parallel_wavegan.utils import load_model
# Sampling rate in Hz; the Synthesis cell uses it for the RTF and for playback.
fs = 24000
vocoder = load_model(vocoder_path)
vocoder.remove_weight_norm()
vocoder = vocoder.eval().to(device)

# define text frontend
from pypinyin import pinyin, Style
# NOTE(review): pypinyin.style._utils is a private module (leading underscore);
# these helpers may move between pypinyin releases.
from pypinyin.style._utils import get_initials, get_finals
# Each dictionary line is "<token> <id>". Read it as UTF-8 regardless of the
# platform default and skip blank lines, which would break the 2-tuple unpack.
with open(dict_path, encoding="utf-8") as f:
    lines = f.readlines()
lines = [line.replace("\n", "").split(" ") for line in lines if line.strip()]
char_to_id = {c: int(i) for c, i in lines}
def frontend(text):
    """Convert Mandarin text to a 1-D tensor of token ids.

    The text is converted with ``pinyin(text, style=Style.TONE3)``, keeping
    the first reading of each item. Every syllable is split into an initial
    and a final (``strict=True``); empty parts are skipped. Each remaining
    part is rewritten (``ü``->``v``, ``ui``->``uei``, ``un``->``uen``,
    ``iu``->``iou``, and any ``5`` moved to the end) and looked up in
    ``char_to_id``. Unknown parts print a warning and map to ``<unk>``.
    ``idim - 1`` is appended as the end-of-sequence id.

    Args:
        text: Raw input sentence.

    Returns:
        ``torch.LongTensor`` of token ids, moved to ``device``.
    """
    syllables = [item[0] for item in pinyin(text, style=Style.TONE3)]
    print(f"Cleaned text: {syllables}")

    # Substitutions are applied in exactly this order.
    rewrites = (("ü", "v"), ("ui", "uei"), ("un", "uen"), ("iu", "iou"))

    ids = []
    for syllable in syllables:
        parts = (
            get_initials(syllable, strict=True),
            get_finals(syllable, strict=True),
        )
        for part in parts:
            if not part:
                continue
            for old, new in rewrites:
                part = part.replace(old, new)
            # Special rule: "e5n" -> "en5"
            if "5" in part:
                part = part.replace("5", "") + "5"
            if part in char_to_id:
                ids.append(char_to_id[part])
            else:
                print(f"WARN: {part} is not included in dict.")
                ids.append(char_to_id["<unk>"])
    ids.append(idim - 1)  # <eos>
    return torch.LongTensor(ids).view(-1).to(device)

# Report that the Mandarin model, vocoder and frontend are set up.
print("now ready to synthesize!")

Synthesis

[ ]:
# Read one sentence, synthesize it, report the real-time factor, and play it.
import time
print("請用中文輸入您喜歡的句子!")
input_text = input()

with torch.no_grad():
    start = time.time()
    x = frontend(input_text)
    c, _, _ = model.inference(x, inference_args)
    y = vocoder.inference(c)
    # CUDA kernels are launched asynchronously; wait for them to finish so the
    # measured time includes the vocoder's GPU work and not just its launch.
    if y.is_cuda:
        torch.cuda.synchronize()
# RTF = processing time / duration of the generated audio.
# assumes len(y) is the number of generated samples -- TODO confirm
rtf = (time.time() - start) / (len(y) / fs)
print(f"RTF = {rtf:5f}")

from IPython.display import display, Audio
display(Audio(y.view(-1).cpu().numpy(), rate=fs))