"""DDSP helper functions (module ``espnet2.gan_svs.visinger2.ddsp``)."""

import math

import librosa as li
import numpy as np
import torch
import torch.fft as fft
import torch.nn as nn
from scipy.signal import get_window


def safe_log(x):
    """Return the natural log of ``x`` after adding a small epsilon.

    The 1e-7 offset keeps the result finite where ``x`` is zero.
    """
    eps = 1e-7
    shifted = x + eps
    return torch.log(shifted)
@torch.no_grad()
def mean_std_loudness(dataset):
    """Average the per-item mean and std of the last field of each dataset item.

    Each element of ``dataset`` is unpacked as a 3-tuple; the third entry
    (presumably a loudness tensor) contributes ``.mean()`` and ``.std()``.
    Both statistics are averaged across items with an incremental running
    mean, so the second value is the mean of per-item stds, not a global std.
    An empty dataset yields ``(0, 0)``.
    """
    running_mean = 0
    running_std = 0
    for count, (_, _, loudness) in enumerate(dataset, start=1):
        running_mean += (loudness.mean().item() - running_mean) / count
        running_std += (loudness.std().item() - running_std) / count
    return running_mean, running_std
def multiscale_fft(signal, scales, overlap):
    """Compute magnitude STFTs of ``signal`` at several FFT sizes.

    Args:
        signal: Tensor passed directly to ``torch.stft``.
        scales: Iterable of FFT sizes; each is used as both n_fft and
            win_length.
        overlap: Fraction of window overlap; the hop is
            ``int(scale * (1 - overlap))``.

    Returns:
        List with one magnitude spectrogram per scale (Hann window,
        centered, normalized), in the order of ``scales``.
    """

    def _magnitude(n_fft):
        # Single-resolution normalized magnitude spectrogram.
        hop = int(n_fft * (1 - overlap))
        window = torch.hann_window(n_fft).to(signal)
        spec = torch.stft(
            signal,
            n_fft=n_fft,
            hop_length=hop,
            win_length=n_fft,
            window=window,
            center=True,
            normalized=True,
            return_complex=True,
        )
        return spec.abs()

    return [_magnitude(n_fft) for n_fft in scales]
def resample(x, factor: int):
    """Upsample ``x`` along its frame axis by ``factor`` with a Hann kernel.

    ``x`` is unpacked as (batch, frame, channel). Input samples are written
    every ``factor`` steps into a zero buffer; the buffer's final slot also
    repeats the last input frame. The buffer is then smoothed by convolving
    with a length ``2 * factor`` Hann window.

    Returns:
        Tensor of shape (batch, frame * factor, channel).
    """
    n_batch, n_frame, n_chan = x.shape
    # Fold batch and channel together so conv1d sees a single input channel.
    flat = x.permute(0, 2, 1).reshape(n_batch * n_chan, 1, n_frame)
    kernel = torch.hann_window(
        2 * factor,
        dtype=flat.dtype,
        device=flat.device,
    ).reshape(1, 1, -1)

    sparse = flat.new_zeros(flat.shape[0], flat.shape[1], flat.shape[2] * factor)
    sparse[..., ::factor] = flat
    # Hold the last input value at the very end of the upsampled buffer.
    sparse[..., -1:] = flat[..., -1:]

    padded = torch.nn.functional.pad(sparse, [factor, factor])
    # The valid convolution is one sample longer than needed; drop the tail.
    smoothed = torch.nn.functional.conv1d(padded, kernel)[..., :-1]
    return smoothed.reshape(n_batch, n_chan, n_frame * factor).permute(0, 2, 1)
def upsample(signal, factor):
    """Nearest-neighbour upsample ``signal`` along axis 1 by ``factor``.

    The tensor is permuted to (dim0, dim2, dim1) so that ``interpolate``
    stretches the original axis 1, then permuted back.
    """
    channels_first = signal.permute(0, 2, 1)
    target_len = channels_first.shape[-1] * factor
    stretched = nn.functional.interpolate(channels_first, size=target_len)
    return stretched.permute(0, 2, 1)
def remove_above_nyquist(amplitudes, pitch, sampling_rate):
    """Attenuate harmonic amplitudes whose frequency reaches Nyquist.

    Harmonic k (1-based, along the last axis of ``amplitudes``) has
    frequency ``k * pitch``. Harmonics strictly below ``sampling_rate / 2``
    are multiplied by ``1 + 1e-4``; the rest are multiplied by ``1e-4``
    (strongly attenuated, not zeroed).
    """
    nyquist = sampling_rate / 2
    harmonic_idx = torch.arange(1, amplitudes.shape[-1] + 1).to(pitch)
    below = (pitch * harmonic_idx < nyquist).float()
    return amplitudes * (below + 1e-4)
def scale_function(x):
    """Apply an exponentiated sigmoid: ``2 * sigmoid(x) ** ln(10) + 1e-7``.

    The output lies between 1e-7 and 2 + 1e-7. The small floor keeps it
    strictly positive.
    """
    exponent = math.log(10)
    return 2 * (torch.sigmoid(x) ** exponent) + 1e-7
def extract_loudness(signal, sampling_rate, block_size, n_fft=2048):
    """Compute a frame-wise, A-weighted log-magnitude loudness curve.

    Args:
        signal: Audio waveform passed to ``librosa.stft``.
        sampling_rate: Sampling rate of ``signal`` in Hz.
        block_size: Hop length between STFT frames, in samples.
        n_fft: FFT size, also used as the window length.

    Returns:
        numpy array holding the average over axis 0 of the weighted log
        spectrogram (the frequency axis for a 1-D ``signal``), with the last
        frame removed.
    """
    S = li.stft(
        signal,
        n_fft=n_fft,
        hop_length=block_size,
        win_length=n_fft,
        center=True,
    )
    S = np.log(abs(S) + 1e-7)
    # librosa >= 0.10 makes ``sr``/``n_fft`` keyword-only; passing them
    # positionally raises TypeError there. Keywords also work on older versions.
    f = li.fft_frequencies(sr=sampling_rate, n_fft=n_fft)
    a_weight = li.A_weighting(f)
    # NOTE(review): natural-log magnitudes are offset by an A-weighting that
    # librosa expresses in dB. This is kept as-is for compatibility with
    # existing models/statistics.
    S = S + a_weight.reshape(-1, 1)
    # Average over axis 0 and drop the final frame (center=True yields one
    # more frame than len(signal) // block_size).
    S = np.mean(S, 0)[..., :-1]
    return S
# TODO(Yifeng): Some functions here are not used (e.g. the one depending on
# crepe); consider removing them later or importing only the used functions.
def extract_pitch(signal, sampling_rate, block_size):
    """Estimate an f0 contour with CREPE, aligned to ``block_size`` frames.

    Args:
        signal: Audio waveform passed to ``crepe.predict``.
        sampling_rate: Sampling rate of ``signal`` in Hz.
        block_size: Hop size in samples. The result has
            ``signal.shape[-1] // block_size`` entries.

    Returns:
        1-D numpy array taken from element 1 of crepe's output (the frequency
        track) with its last frame dropped. It is linearly interpolated to
        the target length when the frame counts differ.

    Raises:
        ImportError: If the optional ``crepe`` package is not installed.
    """
    # crepe is not imported at module level, so without this import the call
    # below would always fail with NameError. Import it lazily to keep it an
    # optional dependency.
    try:
        import crepe
    except ImportError as err:
        raise ImportError(
            "extract_pitch requires the optional 'crepe' package "
            "(pip install crepe)."
        ) from err

    length = signal.shape[-1] // block_size
    f0 = crepe.predict(
        signal,
        sampling_rate,
        # crepe's step_size is expressed in milliseconds.
        step_size=int(1000 * block_size / sampling_rate),
        verbose=1,
        center=True,
        viterbi=True,
    )
    f0 = f0[1].reshape(-1)[:-1]

    if f0.shape[-1] != length:
        # Linearly resample the f0 track onto exactly ``length`` frames.
        f0 = np.interp(
            np.linspace(0, 1, length, endpoint=False),
            np.linspace(0, 1, f0.shape[-1], endpoint=False),
            f0,
        )

    return f0
def mlp(in_size, hidden_size, n_layers):
    """Build a stack of ``n_layers`` (Linear -> LayerNorm -> LeakyReLU) units.

    The first Linear maps ``in_size`` to ``hidden_size``; every later one maps
    ``hidden_size`` to ``hidden_size``. ``n_layers == 0`` gives an empty
    ``nn.Sequential``.
    """
    layers = []
    for idx in range(n_layers):
        fan_in = in_size if idx == 0 else hidden_size
        layers.extend(
            [
                nn.Linear(fan_in, hidden_size),
                nn.LayerNorm(hidden_size),
                nn.LeakyReLU(),
            ]
        )
    return nn.Sequential(*layers)
def gru(n_input, hidden_size):
    """Create a batch-first ``nn.GRU`` with input width ``n_input * hidden_size``.

    Presumably the input is ``n_input`` concatenated features, each
    ``hidden_size`` wide. Confirm against callers.
    """
    input_width = n_input * hidden_size
    return nn.GRU(input_width, hidden_size, batch_first=True)
def harmonic_synth(pitch, amplitudes, sampling_rate):
    """Additively synthesize a sum of harmonics of ``pitch``.

    Phase accumulates along axis 1 in steps of
    ``2 * pi * pitch / sampling_rate``. Harmonic k (1-based, last axis of
    ``amplitudes``) uses k times that phase. The weighted sines are summed
    over the last axis, which is kept as a singleton dimension.
    """
    n_harmonics = amplitudes.shape[-1]
    phase_step = 2 * math.pi * pitch / sampling_rate
    phase = torch.cumsum(phase_step, 1)
    multipliers = torch.arange(1, n_harmonics + 1).to(phase)
    weighted = torch.sin(phase * multipliers) * amplitudes
    return weighted.sum(-1, keepdim=True)
def amp_to_impulse_response(amp, target_size):
    """Convert a real magnitude response into a windowed FIR kernel.

    ``amp`` is treated as a real (zero-phase) spectrum over its last axis.
    It is inverse-FFT'd, centred, tapered with a Hann window, zero-padded
    (or cropped) to ``target_size`` samples, and rolled back so the kernel's
    centre sits at index 0.
    """
    spectrum = torch.view_as_complex(torch.stack([amp, torch.zeros_like(amp)], -1))
    impulse = fft.irfft(spectrum)
    n_taps = impulse.shape[-1]

    # Move the wrapped-around impulse to the middle so the window tapers both tails.
    centered = torch.roll(impulse, n_taps // 2, -1)
    taper = torch.hann_window(n_taps, dtype=centered.dtype, device=centered.device)
    centered = centered * taper

    padded = nn.functional.pad(centered, (0, int(target_size) - int(n_taps)))
    return torch.roll(padded, -n_taps // 2, -1)
def fft_convolve(signal, kernel):
    """Convolve ``signal`` with ``kernel`` via FFT, keeping ``signal``'s length.

    ``signal`` is zero-padded on the right and ``kernel`` on the left to
    twice their lengths. The circular convolution of these buffers is
    computed, and its second half is returned. For equal-length inputs this
    equals the first N samples of the linear convolution. Signal and kernel
    are expected to share the same last-dimension length; otherwise their
    spectra do not line up.
    """
    n_sig = signal.shape[-1]
    n_ker = kernel.shape[-1]
    sig_padded = nn.functional.pad(signal, (0, n_sig))
    ker_padded = nn.functional.pad(kernel, (n_ker, 0))
    full = fft.irfft(fft.rfft(sig_padded) * fft.rfft(ker_padded))
    return full[..., full.shape[-1] // 2 :]
def init_kernels(win_len, win_inc, fft_len, win_type=None, invers=False):
    """Build conv1d weights for a windowed real DFT (or its pseudo-inverse).

    Args:
        win_len: Window length (kernel width).
        win_inc: Not used by this function; kept for interface compatibility.
        fft_len: DFT size N.
        win_type: ``scipy.signal.get_window`` spec, or None / "None" for a
            rectangular (all-ones) window.
        invers: If True, use the transposed pseudo-inverse of the DFT basis.

    Returns:
        Tuple ``(kernel, window)`` of float32 tensors: the kernel has shape
        (2 * (N // 2 + 1), 1, win_len) and the window has shape (1, win_len, 1).
    """
    rectangular = win_type is None or win_type == "None"
    if rectangular:
        window = np.ones(win_len)
    else:
        window = get_window(win_type, win_len, fftbins=True)

    # Row i of rfft(eye(N)) is the spectrum of a unit impulse at sample i.
    basis = np.fft.rfft(np.eye(fft_len))[:win_len]
    kernel = np.hstack((basis.real, basis.imag)).T
    if invers:
        kernel = np.linalg.pinv(kernel).T
    kernel = (kernel * window)[:, np.newaxis, :]

    weights = torch.from_numpy(kernel.astype(np.float32))
    win = torch.from_numpy(window[np.newaxis, :, np.newaxis].astype(np.float32))
    return weights, win