
# -*- coding: utf-8 -*-

# Copyright 2019 Tomoki Hayashi (Nagoya University)
#  Apache 2.0  (http://www.apache.org/licenses/LICENSE-2.0)

"""This code is based on https://github.com/kan-bayashi/PytorchWaveNetVocoder."""

import logging
import sys
import time

import numpy as np
import torch
import torch.nn.functional as F
from torch import nn



def encode_mu_law(x, mu=256):
    """Perform mu-law encoding.

    Args:
        x (ndarray): Audio signal with the range from -1 to 1.
        mu (int): Quantized level.

    Returns:
        ndarray: Quantized audio signal with the range from 0 to mu - 1.

    """
    mu = mu - 1
    fx = np.sign(x) * np.log(1 + mu * np.abs(x)) / np.log(1 + mu)
    return np.floor((fx + 1) / 2 * mu + 0.5).astype(np.int64)


def decode_mu_law(y, mu=256):
    """Perform mu-law decoding.

    Args:
        y (ndarray): Quantized audio signal with the range from 0 to mu - 1.
        mu (int): Quantized level.

    Returns:
        ndarray: Audio signal with the range from -1 to 1.

    """
    mu = mu - 1
    fx = (y - 0.5) / mu * 2 - 1
    x = np.sign(fx) / mu * ((1 + mu) ** np.abs(fx) - 1)
    return x


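# A minimal usage sketch (not part of the original module): mu-law companding
# round-trips a signal in [-1, 1] through 256 discrete levels; the
# reconstruction error stays within the (logarithmic) quantization step.
#
#     >>> import numpy as np
#     >>> x = np.linspace(-1.0, 1.0, 5)
#     >>> y = encode_mu_law(x, mu=256)      # int64 indices in [0, 255]
#     >>> x_hat = decode_mu_law(y, mu=256)  # back to the [-1, 1] range
#     >>> bool(np.max(np.abs(x - x_hat)) < 0.05)
#     True

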
def initialize(m):
    """Initialize conv layers with Xavier initialization.

    Args:
        m (torch.nn.Module): Torch module.

    """
    if isinstance(m, nn.Conv1d):
        nn.init.xavier_uniform_(m.weight)
        nn.init.constant_(m.bias, 0.0)

    if isinstance(m, nn.ConvTranspose2d):
        nn.init.constant_(m.weight, 1.0)
        nn.init.constant_(m.bias, 0.0)


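# A minimal usage sketch (not part of the original module): `initialize` is
# designed to be applied recursively via torch.nn.Module.apply, so every
# Conv1d gets Xavier-initialized weights and the ConvTranspose2d used for
# upsampling starts as a plain repetition of each input frame.
#
#     >>> model = WaveNet(upsampling_factor=80)
#     >>> model = model.apply(initialize)

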
class OneHot(nn.Module):
    """Convert to one-hot vector.

    Args:
        depth (int): Dimension of one-hot vector.

    """

    def __init__(self, depth):
        super(OneHot, self).__init__()
        self.depth = depth

    def forward(self, x):
        """Calculate forward propagation.

        Args:
            x (LongTensor): Long tensor variable with the shape (B, T).

        Returns:
            Tensor: Float tensor variable with the shape (B, T, depth).

        """
        x = x % self.depth
        x = torch.unsqueeze(x, 2)
        x_onehot = x.new_zeros(x.size(0), x.size(1), self.depth).float()

        return x_onehot.scatter_(2, x, 1)


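# A minimal usage sketch (not part of the original module): the one-hot axis
# is appended last, giving (B, T, depth); `WaveNet._preprocess` below
# transposes this to (B, depth, T) before the causal convolution.
#
#     >>> onehot = OneHot(depth=256)
#     >>> x = torch.randint(0, 256, (2, 100))  # (B, T) quantized samples
#     >>> onehot(x).shape
#     torch.Size([2, 100, 256])

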
class CausalConv1d(nn.Module):
    """1D dilated causal convolution."""

    def __init__(self, in_channels, out_channels, kernel_size, dilation=1, bias=True):
        super(CausalConv1d, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.dilation = dilation
        self.padding = padding = (kernel_size - 1) * dilation
        self.conv = nn.Conv1d(
            in_channels,
            out_channels,
            kernel_size,
            padding=padding,
            dilation=dilation,
            bias=bias,
        )

    def forward(self, x):
        """Calculate forward propagation.

        Args:
            x (Tensor): Input tensor with the shape (B, in_channels, T).

        Returns:
            Tensor: Tensor with the shape (B, out_channels, T).

        """
        x = self.conv(x)
        if self.padding != 0:
            x = x[:, :, : -self.padding]

        return x


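# A minimal usage sketch (not part of the original module): the symmetric
# padding plus right-side trimming makes the convolution causal, so output
# frame t depends only on input frames <= t, and the output length equals the
# input length.
#
#     >>> conv = CausalConv1d(in_channels=1, out_channels=4, kernel_size=2, dilation=2)
#     >>> x = torch.randn(1, 1, 10)
#     >>> conv(x).shape
#     torch.Size([1, 4, 10])

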
class UpSampling(nn.Module):
    """Upsampling layer with deconvolution.

    Args:
        upsampling_factor (int): Upsampling factor.

    """

    def __init__(self, upsampling_factor, bias=True):
        super(UpSampling, self).__init__()
        self.upsampling_factor = upsampling_factor
        self.bias = bias
        self.conv = nn.ConvTranspose2d(
            1,
            1,
            kernel_size=(1, self.upsampling_factor),
            stride=(1, self.upsampling_factor),
            bias=self.bias,
        )

    def forward(self, x):
        """Calculate forward propagation.

        Args:
            x (Tensor): Input tensor with the shape (B, C, T).

        Returns:
            Tensor: Tensor with the shape (B, C, T') where T' = T * upsampling_factor.

        """
        x = x.unsqueeze(1)  # B x 1 x C x T
        x = self.conv(x)  # B x 1 x C x T'

        return x.squeeze(1)


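# A minimal usage sketch (not part of the original module): the transposed
# convolution stretches only the time axis, e.g. from frame rate to sample
# rate.
#
#     >>> up = UpSampling(upsampling_factor=80)
#     >>> h = torch.randn(1, 28, 10)  # (B, C, T) auxiliary features
#     >>> up(h).shape
#     torch.Size([1, 28, 800])

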
class WaveNet(nn.Module):
    """Conditional WaveNet.

    Args:
        n_quantize (int): Number of quantization levels.
        n_aux (int): Dimension of auxiliary features.
        n_resch (int): Number of filter channels for residual blocks.
        n_skipch (int): Number of filter channels for skip connections.
        dilation_depth (int): Depth of dilation
            (e.g. if set to 10, max dilation = 2^(10-1)).
        dilation_repeat (int): Number of dilation repeats.
        kernel_size (int): Filter size of dilated causal convolution.
        upsampling_factor (int): Upsampling factor.

    """

    def __init__(
        self,
        n_quantize=256,
        n_aux=28,
        n_resch=512,
        n_skipch=256,
        dilation_depth=10,
        dilation_repeat=3,
        kernel_size=2,
        upsampling_factor=0,
    ):
        super(WaveNet, self).__init__()
        self.n_aux = n_aux
        self.n_quantize = n_quantize
        self.n_resch = n_resch
        self.n_skipch = n_skipch
        self.kernel_size = kernel_size
        self.dilation_depth = dilation_depth
        self.dilation_repeat = dilation_repeat
        self.upsampling_factor = upsampling_factor

        self.dilations = [
            2**i for i in range(self.dilation_depth)
        ] * self.dilation_repeat
        self.receptive_field = (self.kernel_size - 1) * sum(self.dilations) + 1

        # for preprocessing
        self.onehot = OneHot(self.n_quantize)
        self.causal = CausalConv1d(self.n_quantize, self.n_resch, self.kernel_size)
        if self.upsampling_factor > 0:
            self.upsampling = UpSampling(self.upsampling_factor)

        # for residual blocks
        self.dil_sigmoid = nn.ModuleList()
        self.dil_tanh = nn.ModuleList()
        self.aux_1x1_sigmoid = nn.ModuleList()
        self.aux_1x1_tanh = nn.ModuleList()
        self.skip_1x1 = nn.ModuleList()
        self.res_1x1 = nn.ModuleList()
        for d in self.dilations:
            self.dil_sigmoid += [
                CausalConv1d(self.n_resch, self.n_resch, self.kernel_size, d)
            ]
            self.dil_tanh += [
                CausalConv1d(self.n_resch, self.n_resch, self.kernel_size, d)
            ]
            self.aux_1x1_sigmoid += [nn.Conv1d(self.n_aux, self.n_resch, 1)]
            self.aux_1x1_tanh += [nn.Conv1d(self.n_aux, self.n_resch, 1)]
            self.skip_1x1 += [nn.Conv1d(self.n_resch, self.n_skipch, 1)]
            self.res_1x1 += [nn.Conv1d(self.n_resch, self.n_resch, 1)]

        # for postprocessing
        self.conv_post_1 = nn.Conv1d(self.n_skipch, self.n_skipch, 1)
        self.conv_post_2 = nn.Conv1d(self.n_skipch, self.n_quantize, 1)

    def forward(self, x, h):
        """Calculate forward propagation.

        Args:
            x (LongTensor): Quantized input waveform tensor with the shape (B, T).
            h (Tensor): Auxiliary feature tensor with the shape (B, n_aux, T).

        Returns:
            Tensor: Logits with the shape (B, T, n_quantize).

        """
        # preprocess
        output = self._preprocess(x)
        if self.upsampling_factor > 0:
            h = self.upsampling(h)

        # residual block
        skip_connections = []
        for i in range(len(self.dilations)):
            output, skip = self._residual_forward(
                output,
                h,
                self.dil_sigmoid[i],
                self.dil_tanh[i],
                self.aux_1x1_sigmoid[i],
                self.aux_1x1_tanh[i],
                self.skip_1x1[i],
                self.res_1x1[i],
            )
            skip_connections.append(skip)

        # skip-connection part
        output = sum(skip_connections)
        output = self._postprocess(output)

        return output

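    # A minimal forward-pass sketch (not part of the original module): with
    # upsampling_factor=0, the auxiliary features must already be at the
    # sample rate, and the model returns per-sample logits over the
    # n_quantize classes.
    #
    #     >>> model = WaveNet(n_quantize=256, n_aux=28, upsampling_factor=0)
    #     >>> x = torch.randint(0, 256, (1, 100))  # (B, T) mu-law indices
    #     >>> h = torch.randn(1, 28, 100)          # (B, n_aux, T)
    #     >>> model(x, h).shape
    #     torch.Size([1, 100, 256])
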
    def generate(self, x, h, n_samples, interval=None, mode="sampling"):
        """Generate a waveform using the fast generation algorithm.

        This generation is based on the `Fast WaveNet Generation Algorithm`_.

        Args:
            x (LongTensor): Initial waveform tensor with the shape (T,).
            h (Tensor): Auxiliary feature tensor with the shape (n_samples + T, n_aux).
            n_samples (int): Number of samples to be generated.
            interval (int, optional): Log interval.
            mode (str, optional): "sampling" or "argmax".

        Returns:
            ndarray: Generated quantized waveform with the shape (n_samples,).

        .. _`Fast WaveNet Generation Algorithm`: https://arxiv.org/abs/1611.09482

        """
        # reshape inputs
        assert len(x.shape) == 1
        assert len(h.shape) == 2 and h.shape[1] == self.n_aux
        x = x.unsqueeze(0)
        h = h.transpose(0, 1).unsqueeze(0)

        # perform upsampling
        if self.upsampling_factor > 0:
            h = self.upsampling(h)

        # padding for shortage
        if n_samples > h.shape[2]:
            h = F.pad(h, (0, n_samples - h.shape[2]), "replicate")

        # padding if the length is less than the receptive field
        n_pad = self.receptive_field - x.size(1)
        if n_pad > 0:
            x = F.pad(x, (n_pad, 0), "constant", self.n_quantize // 2)
            h = F.pad(h, (n_pad, 0), "replicate")

        # prepare buffer
        output = self._preprocess(x)
        h_ = h[:, :, : x.size(1)]
        output_buffer = []
        buffer_size = []
        for i, d in enumerate(self.dilations):
            output, _ = self._residual_forward(
                output,
                h_,
                self.dil_sigmoid[i],
                self.dil_tanh[i],
                self.aux_1x1_sigmoid[i],
                self.aux_1x1_tanh[i],
                self.skip_1x1[i],
                self.res_1x1[i],
            )
            if d == 2 ** (self.dilation_depth - 1):
                buffer_size.append(self.kernel_size - 1)
            else:
                buffer_size.append(d * 2 * (self.kernel_size - 1))
            output_buffer.append(output[:, :, -buffer_size[i] - 1 : -1])

        # generate
        samples = x[0]
        start_time = time.time()
        for i in range(n_samples):
            output = samples[-self.kernel_size * 2 + 1 :].unsqueeze(0)
            output = self._preprocess(output)
            h_ = h[:, :, samples.size(0) - 1].contiguous().view(1, self.n_aux, 1)
            output_buffer_next = []
            skip_connections = []
            for j, d in enumerate(self.dilations):
                output, skip = self._generate_residual_forward(
                    output,
                    h_,
                    self.dil_sigmoid[j],
                    self.dil_tanh[j],
                    self.aux_1x1_sigmoid[j],
                    self.aux_1x1_tanh[j],
                    self.skip_1x1[j],
                    self.res_1x1[j],
                )
                output = torch.cat([output_buffer[j], output], dim=2)
                output_buffer_next.append(output[:, :, -buffer_size[j] :])
                skip_connections.append(skip)

            # update buffer
            output_buffer = output_buffer_next

            # get predicted sample
            output = sum(skip_connections)
            output = self._postprocess(output)[0]
            if mode == "sampling":
                posterior = F.softmax(output[-1], dim=0)
                dist = torch.distributions.Categorical(posterior)
                sample = dist.sample().unsqueeze(0)
            elif mode == "argmax":
                sample = output.argmax(-1)
            else:
                logging.error("mode should be sampling or argmax")
                sys.exit(1)
            samples = torch.cat([samples, sample], dim=0)

            # show progress
            if interval is not None and (i + 1) % interval == 0:
                elapsed_time_per_sample = (time.time() - start_time) / interval
                logging.info(
                    "%d/%d estimated time = %.3f sec (%.3f sec / sample)"
                    % (
                        i + 1,
                        n_samples,
                        (n_samples - i - 1) * elapsed_time_per_sample,
                        elapsed_time_per_sample,
                    )
                )
                start_time = time.time()

        return samples[-n_samples:].cpu().numpy()

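    # A minimal generation sketch (not part of the original module): seed with
    # a single "zero" sample (n_quantize // 2); inputs shorter than the
    # receptive field are padded internally. With the default model size this
    # is slow on CPU but illustrates the API.
    #
    #     >>> model = WaveNet(upsampling_factor=0).eval()
    #     >>> x = torch.full((1,), 128, dtype=torch.long)  # (T,) seed
    #     >>> h = torch.randn(11, 28)                      # (n_samples + T, n_aux)
    #     >>> with torch.no_grad():
    #     ...     y = model.generate(x, h, n_samples=10, mode="argmax")
    #     >>> y.shape
    #     (10,)
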
    def _preprocess(self, x):
        # one-hot encode and apply the first causal convolution
        x = self.onehot(x).transpose(1, 2)  # B x n_quantize x T
        output = self.causal(x)
        return output

    def _postprocess(self, x):
        output = F.relu(x)
        output = self.conv_post_1(output)
        output = F.relu(output)  # B x C x T
        output = self.conv_post_2(output).transpose(1, 2)  # B x T x C
        return output

    def _residual_forward(
        self,
        x,
        h,
        dil_sigmoid,
        dil_tanh,
        aux_1x1_sigmoid,
        aux_1x1_tanh,
        skip_1x1,
        res_1x1,
    ):
        # gated activation unit conditioned on the auxiliary features
        output_sigmoid = dil_sigmoid(x)
        output_tanh = dil_tanh(x)
        aux_output_sigmoid = aux_1x1_sigmoid(h)
        aux_output_tanh = aux_1x1_tanh(h)
        output = torch.sigmoid(output_sigmoid + aux_output_sigmoid) * torch.tanh(
            output_tanh + aux_output_tanh
        )
        # 1x1 convolutions for the skip connection and the residual connection
        skip = skip_1x1(output)
        output = res_1x1(output)
        output = output + x
        return output, skip

    def _generate_residual_forward(
        self,
        x,
        h,
        dil_sigmoid,
        dil_tanh,
        aux_1x1_sigmoid,
        aux_1x1_tanh,
        skip_1x1,
        res_1x1,
    ):
        # compute only the newest frame for incremental (fast) generation
        output_sigmoid = dil_sigmoid(x)[:, :, -1:]
        output_tanh = dil_tanh(x)[:, :, -1:]
        aux_output_sigmoid = aux_1x1_sigmoid(h)
        aux_output_tanh = aux_1x1_tanh(h)
        output = torch.sigmoid(output_sigmoid + aux_output_sigmoid) * torch.tanh(
            output_tanh + aux_output_tanh
        )
        skip = skip_1x1(output)
        output = res_1x1(output)
        output = output + x[:, :, -1:]  # B x C x 1
        return output, skip