# Source code for espnet2.tts.utils.parallel_wavegan_pretrained_vocoder

# Copyright 2021 Tomoki Hayashi
#  Apache 2.0  (http://www.apache.org/licenses/LICENSE-2.0)

"""Wrapper class for the vocoder model trained with parallel_wavegan repo."""

import logging
import os
from pathlib import Path
from typing import Optional, Union

import torch
import yaml


class ParallelWaveGANPretrainedVocoder(torch.nn.Module):
    """Wrapper class to load the vocoder trained with parallel_wavegan repo.

    The wrapped model is stored in ``self.vocoder`` and the sampling rate read
    from the config is exposed as ``self.fs``.
    """

    def __init__(
        self,
        model_file: Union[Path, str],
        config_file: Optional[Union[Path, str]] = None,
    ):
        """Initialize ParallelWaveGANPretrainedVocoder module.

        Args:
            model_file (Union[Path, str]): Path to the vocoder checkpoint,
                passed to ``parallel_wavegan.utils.load_model``.
            config_file (Optional[Union[Path, str]]): Path to the YAML config.
                If None, ``config.yml`` located in the same directory as
                ``model_file`` is used.

        Raises:
            ImportError: If ``parallel_wavegan`` is not installed.
            KeyError: If the config has no ``sampling_rate`` entry.

        """
        super().__init__()
        try:
            from parallel_wavegan.utils import load_model
        except ImportError:
            logging.error(
                "`parallel_wavegan` is not installed. "
                "Please install via `pip install -U parallel_wavegan`."
            )
            raise
        if config_file is None:
            # Fall back to the config stored next to the checkpoint.
            dirname = os.path.dirname(str(model_file))
            config_file = os.path.join(dirname, "config.yml")
        # Read as UTF-8 explicitly so that decoding does not depend on the
        # locale default encoding of the platform.
        with open(config_file, encoding="utf-8") as f:
            # NOTE: yaml.Loader can construct arbitrary Python objects.
            # Only load config files that come from a trusted source.
            config = yaml.load(f, Loader=yaml.Loader)
        self.fs = config["sampling_rate"]
        self.vocoder = load_model(model_file, config)
        # Not every vocoder class provides remove_weight_norm, hence the check.
        if hasattr(self.vocoder, "remove_weight_norm"):
            self.vocoder.remove_weight_norm()
        # Request normalization at inference only when the vocoder exposes a
        # `mean` attribute (presumably registered feature statistics; confirm
        # against parallel_wavegan).
        self.normalize_before = hasattr(self.vocoder, "mean")

    @torch.no_grad()
    def forward(self, feats: torch.Tensor) -> torch.Tensor:
        """Generate waveform with pretrained vocoder.

        Args:
            feats (Tensor): Feature tensor (T_feats, #mels).

        Returns:
            Tensor: Generated waveform tensor (T_wav).

        """
        # Flatten whatever shape the vocoder returns into a 1-D waveform.
        return self.vocoder.inference(
            feats,
            normalize_before=self.normalize_before,
        ).view(-1)