espnet2.gan_codec.shared.quantizer.residual_vq.ResidualVectorQuantizer
class espnet2.gan_codec.shared.quantizer.residual_vq.ResidualVectorQuantizer(dimension: int = 256, codebook_dim: int = 512, n_q: int = 8, bins: int = 1024, decay: float = 0.99, kmeans_init: bool = True, kmeans_iters: int = 50, threshold_ema_dead_code: int = 2, quantizer_dropout: bool = False)
Bases: Module
Residual Vector Quantizer.

- Parameters:
  - dimension (int): Dimension of the codebooks.
  - n_q (int): Number of residual vector quantizers used.
  - bins (int): Codebook size.
  - decay (float): Decay for the exponential moving average over the codebooks.
  - kmeans_init (bool): Whether to use k-means to initialize the codebooks.
  - kmeans_iters (int): Number of iterations used for k-means initialization.
  - threshold_ema_dead_code (int): Threshold for dead-code expiration. Any code whose exponential-moving-average cluster size falls below this threshold is replaced with a randomly selected vector from the current batch.
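The `decay` and `threshold_ema_dead_code` parameters describe an exponential-moving-average (EMA) codebook update with dead-code expiration. The following is a minimal pure-Python sketch of that idea for a single codebook, not ESPnet's implementation. The function name `ema_update_and_expire`, the `embed_sum` statistic, and the choice to reset a replaced code's statistics to the threshold are assumptions made for illustration.

```python
import random
from typing import List, Optional

Vector = List[float]


def ema_update_and_expire(
    codebook: List[Vector],
    cluster_size: List[float],
    embed_sum: List[Vector],
    batch: List[Vector],
    assignments: List[int],
    decay: float = 0.99,
    threshold_ema_dead_code: float = 2.0,
    eps: float = 1e-5,
    rng: Optional[random.Random] = None,
) -> List[int]:
    """Hypothetical EMA codebook step; mutates codebook, cluster_size, embed_sum in place.

    assignments[i] is the index of the code nearest to batch[i].
    Returns the indices of the codes that were expired and replaced.
    """
    if rng is None:
        rng = random.Random(0)
    dim = len(codebook[0])
    for k in range(len(codebook)):
        members = [v for v, a in zip(batch, assignments) if a == k]
        # EMA of how many batch vectors were assigned to code k.
        cluster_size[k] = decay * cluster_size[k] + (1 - decay) * len(members)
        # EMA of the element-wise sum of those vectors.
        batch_sum = [sum(v[d] for v in members) for d in range(dim)]
        embed_sum[k] = [decay * s + (1 - decay) * b for s, b in zip(embed_sum[k], batch_sum)]
        # Updated centroid: EMA sum divided by EMA count (eps avoids division by zero).
        codebook[k] = [s / (cluster_size[k] + eps) for s in embed_sum[k]]

    # Dead-code expiration: codes whose EMA cluster size fell below the threshold
    # are replaced with a randomly selected vector from the current batch.
    expired = [k for k, c in enumerate(cluster_size) if c < threshold_ema_dead_code]
    for k in expired:
        new_vec = list(rng.choice(batch))
        codebook[k] = new_vec
        # Sketch-specific choice: reset the statistics so the next EMA step starts from new_vec.
        cluster_size[k] = float(threshold_ema_dead_code)
        embed_sum[k] = [x * threshold_ema_dead_code for x in new_vec]
    return expired
```

Resetting the statistics matters in this sketch because the centroid is recomputed from `embed_sum / cluster_size` on every step; without the reset, the replacement vector would be overwritten by the stale averages at the next update.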
decode(codes: Tensor) → Tensor
Decode the given codes to the quantized representation.
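In residual vector quantization, decoding amounts to looking up the selected vector in each quantizer's codebook and summing the lookups. The pure-Python sketch below assumes codes laid out as `codes[q][t]` (quantizer, frame) with the batch dimension dropped, and it ignores any projection the real module may apply between `dimension` and `codebook_dim`. The name `rvq_decode` is hypothetical.

```python
from typing import List, Sequence

Vector = List[float]


def rvq_decode(codes: Sequence[Sequence[int]], codebooks: Sequence[Sequence[Vector]]) -> List[Vector]:
    """Sum the codebook vectors selected by every quantizer, frame by frame."""
    n_frames = len(codes[0])
    dim = len(codebooks[0][0])
    out = [[0.0] * dim for _ in range(n_frames)]
    for q, layer_codes in enumerate(codes):
        for t, idx in enumerate(layer_codes):
            vec = codebooks[q][idx]  # vector chosen by quantizer q at frame t
            for d in range(dim):
                out[t][d] += vec[d]
    return out
```

For example, with codebooks `[[[0.0, 0.0], [1.0, 1.0]], [[0.0, 0.0], [0.5, -0.5]]]`, the codes `[[1, 0], [1, 1]]` decode to `[[1.5, 0.5], [0.5, -0.5]]`.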
encode(x: Tensor, sample_rate: int, bandwidth: float | None = None, st: int | None = None) → Tensor
Encode the given input tensor with the specified sample rate at the given bandwidth. The RVQ encode method selects the appropriate number of quantizers to use and returns the indices for each quantizer.
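The residual scheme behind encoding can be sketched as follows: each quantizer picks the nearest code for the current residual, and the residual passed to the next quantizer is what remains after subtracting that code. In this pure-Python sketch, `n_q` is passed in directly, whereas the documented method derives the number of quantizers from `sample_rate` and `bandwidth`. The function names and the `codes[q][t]` layout are hypothetical, and the `st` argument is not modeled.

```python
from typing import List, Sequence

Vector = List[float]


def nearest_code(codebook: Sequence[Vector], v: Vector) -> int:
    """Index of the codebook entry with the smallest squared Euclidean distance to v."""
    return min(
        range(len(codebook)),
        key=lambda i: sum((c - x) ** 2 for c, x in zip(codebook[i], v)),
    )


def rvq_encode(frames: Sequence[Vector], codebooks: Sequence[Sequence[Vector]], n_q: int) -> List[List[int]]:
    """Return codes[q][t]: the index chosen by quantizer q for frame t."""
    residuals = [list(f) for f in frames]
    codes: List[List[int]] = []
    for q in range(n_q):
        layer: List[int] = []
        for t, r in enumerate(residuals):
            idx = nearest_code(codebooks[q], r)
            layer.append(idx)
            # Pass what this quantizer could not represent on to the next one.
            residuals[t] = [x - c for x, c in zip(r, codebooks[q][idx])]
        codes.append(layer)
    return codes
```

For example, with codebooks `[[[0.0, 0.0], [1.0, 1.0]], [[0.0, 0.0], [0.5, -0.5]]]`, the frames `[[1.5, 0.5], [0.5, -0.5]]` encode to `[[1, 0], [1, 1]]` with `n_q=2`. Using fewer quantizers simply truncates the list of code layers, which is how a lower bandwidth trades off reconstruction accuracy.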
forward(x: Tensor, sample_rate: int, bandwidth: float | None = None) → QuantizedResult
Residual vector quantization on the given input tensor.

- Parameters:
  - x (torch.Tensor): Input tensor.
  - sample_rate (int): Sample rate of the input tensor.
  - bandwidth (float | None): Target bandwidth.
- Returns: The quantized (or approximately quantized) representation, together with the associated bandwidth and any penalty term for the loss.
- Return type: QuantizedResult
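To make the returned fields concrete, here is a hedged sketch of how a bandwidth and a penalty could be derived once quantization is done. `QuantizedResultSketch` and `summarize_forward` are hypothetical stand-ins, not ESPnet's `QuantizedResult`. The bandwidth is assumed to be in kbps and to scale linearly with the number of quantizers used, and the penalty is assumed to be a commitment-style mean squared error between the input and its quantized version; the real module may instead compute it per quantizer and average.

```python
from dataclasses import dataclass
from typing import List, Sequence

Vector = List[float]


@dataclass
class QuantizedResultSketch:
    quantized: List[Vector]
    bandwidth: float  # kbps (assumed unit)
    penalty: float


def summarize_forward(
    frames: Sequence[Vector], quantized: List[Vector], n_q: int, bw_per_q_kbps: float
) -> QuantizedResultSketch:
    # Commitment-style penalty: mean squared error between input and quantized frames.
    sq_err = sum((x - y) ** 2 for f, qf in zip(frames, quantized) for x, y in zip(f, qf))
    n_values = sum(len(f) for f in frames)
    return QuantizedResultSketch(
        quantized=quantized,
        bandwidth=n_q * bw_per_q_kbps,  # each active quantizer adds a fixed bitrate
        penalty=sq_err / n_values,
    )
```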
get_bandwidth_per_quantizer(sample_rate: int)
Return the bandwidth per quantizer for the given input sample rate.
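A common way to compute this, following the EnCodec convention, is that each quantizer emits one index per frame and each index costs log2(bins) bits, so one quantizer's bitrate is log2(bins) multiplied by the frame rate. The sketch below assumes the rate passed in is the frame rate of the sequence being quantized (not necessarily the audio sample rate) and returns bits per second; both points are assumptions, and `bandwidth_per_quantizer_bps` is a hypothetical name.

```python
import math


def bandwidth_per_quantizer_bps(bins: int, frame_rate: float) -> float:
    """Bits per second contributed by one quantizer (hypothetical helper)."""
    # One index per frame; an index into a codebook of `bins` entries carries log2(bins) bits.
    return math.log2(bins) * frame_rate
```

For example, with `bins=1024` and 75 frames per second, one quantizer costs 750 bits/s, so 8 quantizers correspond to 6 kbps.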
get_num_quantizers_for_bandwidth(sample_rate: int, bandwidth: float | None = None) → int
Return the number of quantizers (n_q) to use for the specified target bandwidth.
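One plausible selection rule is to use as many quantizers as fit within the target bandwidth, with at least one, and all `n_q` when no target is given. The sketch below assumes the target is in kbps and the per-quantizer cost in bits per second (the EnCodec convention). The function name and the upper clamp to `n_q` are assumptions, not documented behavior of this class.

```python
import math
from typing import Optional


def num_quantizers_for_bandwidth(
    n_q: int, bw_per_q_bps: float, bandwidth_kbps: Optional[float] = None
) -> int:
    """Hypothetical helper: how many quantizers fit within the target bandwidth."""
    if bandwidth_kbps is None or bandwidth_kbps <= 0:
        # No usable target: fall back to all configured quantizers.
        return n_q
    fitting = math.floor(bandwidth_kbps * 1000 / bw_per_q_bps)
    # Use at least one quantizer; the upper clamp to n_q is a choice of this sketch.
    return max(1, min(n_q, fitting))
```

For example, at 750 bits/s per quantizer, a 1.5 kbps target selects 2 quantizers, while a 12 kbps target is capped at `n_q=8`.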