espnet2.gan_codec.shared.quantizer.residual_vq.ResidualVectorQuantizer
class espnet2.gan_codec.shared.quantizer.residual_vq.ResidualVectorQuantizer(dimension: int = 256, codebook_dim: int = 512, n_q: int = 8, bins: int = 1024, decay: float = 0.99, kmeans_init: bool = True, kmeans_iters: int = 50, threshold_ema_dead_code: int = 2, quantizer_dropout: bool = False)
Bases: Module
Residual Vector Quantizer.
- Parameters:
- dimension (int) – Dimension of the codebooks.
- n_q (int) – Number of residual vector quantizers used.
- bins (int) – Codebook size.
- decay (float) – Decay for the exponential moving average over the codebooks.
- kmeans_init (bool) – Whether to use kmeans to initialize the codebooks.
- kmeans_iters (int) – Number of iterations used for kmeans initialization.
- threshold_ema_dead_code (int) – Threshold for dead code expiration. Any code whose exponential moving average cluster size falls below this threshold is replaced with a randomly selected vector from the current batch.
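The `decay` and `threshold_ema_dead_code` parameters control how the codebooks are maintained during training. Below is a minimal pure-Python sketch of that idea under stated assumptions: it tracks only the EMA cluster size of each code (a full EMA codebook also keeps moving averages of the codeword vectors themselves), and the function name `ema_update_and_expire` and its argument layout are hypothetical, not ESPnet's API. ESPnet's actual update order, and whether it resets the counts of expired codes, may differ.

```python
import random


def ema_update_and_expire(cluster_size, assignments, codebook, batch,
                          decay=0.99, threshold_ema_dead_code=2, rng=None):
    """Hypothetical sketch: EMA cluster-size tracking plus dead-code replacement.

    cluster_size: list[float], one running EMA count per code.
    assignments: list[int], the code index chosen for each batch vector.
    codebook: list[list[float]], modified in place when codes expire.
    batch: list[list[float]], the current batch of input vectors.
    Returns the updated EMA sizes and the indices of the replaced codes.
    """
    rng = rng or random.Random(0)
    # Count how many batch vectors were assigned to each code in this step.
    counts = [0] * len(codebook)
    for idx in assignments:
        counts[idx] += 1
    # Exponential moving average of the per-code cluster size.
    new_size = [decay * s + (1.0 - decay) * c for s, c in zip(cluster_size, counts)]
    # Codes whose EMA size is below the threshold are "dead":
    # overwrite each one with a randomly chosen vector from the current batch.
    expired = [k for k, s in enumerate(new_size) if s < threshold_ema_dead_code]
    for k in expired:
        codebook[k] = list(rng.choice(batch))
    return new_size, expired


# Usage: code 2 is never chosen and code 1 only once, so both fall below 2.
codebook = [[0.0, 0.0], [1.0, 1.0], [5.0, 5.0]]
batch = [[0.1, 0.0], [0.2, 0.1], [0.9, 1.1], [0.0, 0.2]]
sizes, dead = ema_update_and_expire([2.0, 2.0, 2.0], [0, 0, 1, 0], codebook, batch, decay=0.5)
print(sizes, dead)  # [2.5, 1.5, 1.0] [1, 2]
```

Replacing dead codes with real batch vectors puts them back where the data lies, so they have a chance of being selected again instead of staying unused.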
decode(codes: Tensor) → Tensor
Decode the given codes to the quantized representation.
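In a residual vector quantizer, decoding sums the codeword selected at each stage. The following minimal sketch handles a single vector, assuming one index per stage ordered from the first quantizer onward. `rvq_decode` is a hypothetical helper, and the real method works on batched tensors whose exact layout is not described in this reference.

```python
def rvq_decode(codes, codebooks):
    """Hypothetical sketch of residual VQ decoding for one vector.

    codes: list[int], one codeword index per stage (stage 0 first).
    codebooks: list of codebooks; codebooks[q][i] is codeword i of stage q.
    Returns the sum of the selected codewords, i.e. the quantized vector.
    """
    dim = len(codebooks[0][0])
    out = [0.0] * dim
    # Only the stages present in `codes` contribute, so fewer codes
    # (a lower bandwidth) yield a coarser reconstruction.
    for q, idx in enumerate(codes):
        for d, v in enumerate(codebooks[q][idx]):
            out[d] += v
    return out


codebooks = [[[0.0, 0.0], [1.0, 1.0]], [[0.5, 0.0], [0.0, 0.5]]]
print(rvq_decode([1, 0], codebooks))  # [1.5, 1.0]
print(rvq_decode([1], codebooks))     # [1.0, 1.0]
```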
encode(x: Tensor, sample_rate: int, bandwidth: float | None = None, st: int | None = None) → Tensor
Encode a given input tensor with the specified sample rate at the given bandwidth. The RVQ encode method sets the appropriate number of quantizers to use and returns the indices for each quantizer.
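Encoding can be read as a greedy residual loop: each stage picks the codeword nearest to the current residual and passes whatever it failed to capture to the next stage, so a lower bandwidth simply means running fewer stages. This is a minimal single-vector sketch, assuming squared Euclidean distance; `rvq_encode` is a hypothetical helper, and the `st` argument of the real method is not modeled.

```python
def rvq_encode(x, codebooks, n_q=None):
    """Hypothetical sketch of greedy residual VQ encoding for one vector.

    x: list[float], the vector to quantize.
    codebooks: list of codebooks; codebooks[q][i] is codeword i of stage q.
    n_q: number of stages to run (all codebooks if None). In the real module
        this count is derived from the target bandwidth.
    Returns one codeword index per stage.
    """
    n_q = len(codebooks) if n_q is None else n_q
    residual = list(x)
    codes = []
    for codebook in codebooks[:n_q]:
        # Nearest codeword to the current residual (squared Euclidean distance).
        idx = min(
            range(len(codebook)),
            key=lambda i: sum((r - c) ** 2 for r, c in zip(residual, codebook[i])),
        )
        codes.append(idx)
        # The next stage quantizes what this stage left over.
        residual = [r - c for r, c in zip(residual, codebook[idx])]
    return codes


codebooks = [[[0.0, 0.0], [1.0, 1.0]], [[0.5, 0.0], [0.0, 0.5]]]
print(rvq_encode([1.4, 1.1], codebooks))         # [1, 0]
print(rvq_encode([1.4, 1.1], codebooks, n_q=1))  # [1]
```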
forward(x: Tensor, sample_rate: int, bandwidth: float | None = None) → QuantizedResult
Apply residual vector quantization to the given input tensor.
- Parameters:
- x (torch.Tensor) – Input tensor.
- sample_rate (int) – Sample rate of the input tensor.
- bandwidth (float) – Target bandwidth.
- Returns: The quantized (or approximately quantized) representation with the associated bandwidth and any penalty term for the loss.
- Return type: QuantizedResult
get_bandwidth_per_quantizer(sample_rate: int)
Return the bandwidth per quantizer for the given input sample rate.
get_num_quantizers_for_bandwidth(sample_rate: int, bandwidth: float | None = None) → int
Return the number of quantizers (n_q) to use for the specified target bandwidth.
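The two bandwidth helpers reduce to simple arithmetic if each quantizer spends log2(bins) bits per frame. The sketch below assumes that the rate argument is the number of quantized frames per second and that bandwidth is expressed in kbps; neither unit is stated in this reference, and the helper names are hypothetical. With the default `bins=1024` (10 bits) and an example rate of 75 frames/s, one quantizer costs 0.75 kbps, so a 6 kbps target selects 8 quantizers.

```python
import math


def bandwidth_per_quantizer(bins, frame_rate):
    """kbps used by one quantizer, assuming log2(bins) bits per frame."""
    return math.log2(bins) * frame_rate / 1000.0


def num_quantizers_for_bandwidth(n_q, bins, frame_rate, bandwidth=None):
    """Largest stage count that fits in `bandwidth` kbps, never below 1 (assumed rule)."""
    if bandwidth is None or bandwidth <= 0:
        return n_q  # no target bandwidth: use every quantizer
    bw_per_q = bandwidth_per_quantizer(bins, frame_rate)
    # Round down so the bitrate stays at or below the target.
    # Capped at n_q in this sketch; the real method may not apply this cap.
    return min(n_q, max(1, math.floor(bandwidth / bw_per_q)))


print(bandwidth_per_quantizer(1024, 75))            # 0.75
print(num_quantizers_for_bandwidth(8, 1024, 75, 6.0))  # 8
print(num_quantizers_for_bandwidth(8, 1024, 75, 1.5))  # 2
```

The floor of 1 guarantees that at least one quantizer is always used, even when the target bandwidth is smaller than the cost of a single stage.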