espnet.nets.pytorch_backend.wavenet.WaveNet
class espnet.nets.pytorch_backend.wavenet.WaveNet(n_quantize=256, n_aux=28, n_resch=512, n_skipch=256, dilation_depth=10, dilation_repeat=3, kernel_size=2, upsampling_factor=0)
Bases: Module
Conditional WaveNet.
- Parameters:
- n_quantize (int) – Number of quantization levels.
- n_aux (int) – Dimension of the auxiliary features.
- n_resch (int) – Number of filter channels in each residual block.
- n_skipch (int) – Number of filter channels in each skip connection.
- dilation_depth (int) – Depth of each dilation cycle (e.g. if set to 10, the maximum dilation is 2^(10-1) = 512).
- dilation_repeat (int) – Number of times the dilation cycle is repeated.
- kernel_size (int) – Filter size of the dilated causal convolutions.
- upsampling_factor (int) – Upsampling factor.
Initializes internal Module state, shared by both nn.Module and ScriptModule.
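The dilation parameters above determine how many past samples the network can see. The sketch below works this out under the assumption that the layers follow the usual WaveNet layout: dilations double from 1 up to 2^(dilation_depth-1), and that cycle is repeated dilation_repeat times. The helper names `dilation_schedule` and `receptive_field` are hypothetical and are not part of the ESPnet API.

```python
def dilation_schedule(dilation_depth=10, dilation_repeat=3):
    """Return the assumed per-layer dilations: 1, 2, ..., 2^(depth-1), repeated."""
    return [2 ** i for i in range(dilation_depth)] * dilation_repeat


def receptive_field(dilation_depth=10, dilation_repeat=3, kernel_size=2):
    """Number of input samples (current one included) that reach one output."""
    dilations = dilation_schedule(dilation_depth, dilation_repeat)
    # Each dilated causal layer extends the context by (kernel_size - 1) * dilation.
    return (kernel_size - 1) * sum(dilations) + 1


dilations = dilation_schedule()
print(len(dilations))     # 30 layers with the default arguments
print(max(dilations))     # 512, i.e. 2^(10-1)
print(receptive_field())  # 3070 samples
```

With the default arguments this gives a context of 3070 samples, which is a useful lower bound when choosing how much initial waveform to supply.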
forward(x, h)
Calculate forward propagation.
- Parameters:
- x (LongTensor) – Quantized input waveform tensor with the shape (B, T).
- h (Tensor) – Auxiliary feature tensor with the shape (B, n_aux, T).
- Returns: Logits with the shape (B, T, n_quantize).
- Return type: Tensor
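Because x is a quantized waveform, its entries are presumably integer class indices in the range [0, n_quantize). WaveNet-style models commonly obtain such indices with mu-law companding. The following is a minimal pure-Python sketch of that step, assuming a float waveform normalized to [-1, 1]. The function names are illustrative only, and whether a given pipeline uses exactly this mapping is an assumption.

```python
import math


def encode_mu_law_sample(x, n_quantize=256):
    """Map a float sample in [-1, 1] to an integer class in [0, n_quantize)."""
    mu = n_quantize - 1
    x = max(-1.0, min(1.0, x))  # clip to the valid input range
    # Logarithmic compression that keeps the sign of the sample.
    compressed = math.copysign(math.log1p(mu * abs(x)) / math.log1p(mu), x)
    # Shift from [-1, 1] to [0, mu] and round to the nearest class.
    return int(math.floor((compressed + 1.0) / 2.0 * mu + 0.5))


def decode_mu_law_sample(y, n_quantize=256):
    """Approximate inverse of encode_mu_law_sample."""
    mu = n_quantize - 1
    compressed = 2.0 * y / mu - 1.0
    return math.copysign(((1.0 + mu) ** abs(compressed) - 1.0) / mu, compressed)


waveform = [-1.0, -0.25, 0.0, 0.5, 1.0]
quantized = [encode_mu_law_sample(s) for s in waveform]
restored = [decode_mu_law_sample(q) for q in quantized]
```

A list of such indices, converted to a LongTensor and batched to (B, T), has the form forward expects for x.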
generate(x, h, n_samples, interval=None, mode='sampling')
Generate a waveform with the fast generation algorithm.
This generation is based on the Fast WaveNet Generation Algorithm.
- Parameters:
- x (LongTensor) – Initial waveform tensor with the shape (T,).
- h (Tensor) – Auxiliary feature tensor with the shape (n_samples + T, n_aux).
- n_samples (int) – Number of samples to be generated.
- interval (int, optional) – Logging interval.
- mode (str, optional) – Decoding mode, either "sampling" or "argmax".
- Returns: Generated quantized waveform with the shape (n_samples,).
- Return type: ndarray
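The two values of mode differ in how the next quantized sample is chosen from the per-step logits. The sketch below illustrates the distinction in pure Python, assuming that "argmax" takes the most probable class and "sampling" draws from the softmax distribution. The function `pick_next_sample` is hypothetical and only mirrors the idea; it is not the library's internal code.

```python
import math
import random


def pick_next_sample(logits, mode="sampling", rng=random):
    """Choose the next quantized sample from a list of n_quantize logits."""
    if mode == "argmax":
        # Deterministic: index of the largest logit.
        return max(range(len(logits)), key=lambda i: logits[i])
    if mode == "sampling":
        # Stochastic: draw from softmax(logits); subtract the max for stability.
        peak = max(logits)
        weights = [math.exp(v - peak) for v in logits]
        return rng.choices(range(len(logits)), weights=weights, k=1)[0]
    raise ValueError(f"unknown mode: {mode}")


logits = [0.1, 2.0, -1.0, 0.5]
greedy = pick_next_sample(logits, mode="argmax")
sampled = pick_next_sample(logits, mode="sampling", rng=random.Random(0))
```

Under these assumptions, "argmax" gives a reproducible output for a fixed input, while "sampling" varies from run to run unless the random seed is fixed.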