espnet2.beats.tokenizer.NormEMAVectorQuantizer
class espnet2.beats.tokenizer.NormEMAVectorQuantizer(n_embed, embedding_dim, beta, decay=0.99, kmeans_init=False, eps=1e-05, statistic_code_usage=True)
Bases: Module
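Vector quantizer with an L2-normalized codebook that is updated by an exponential moving average (EMA) of the inputs assigned to each code. decay is the EMA decay rate, beta weights the commitment loss, and statistic_code_usage enables tracking of how often each code is selected.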
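To illustrate the "NormEMA" idea behind this class, here is a minimal sketch of a single normalized EMA codebook update in plain PyTorch. It is not ESPnet's implementation: `norm_ema_update` is a hypothetical helper, the codebook and inputs are assumed to be unit-norm already, the nearest-code indices are assumed to be given, and k-means initialization (kmeans_init) as well as any cross-device synchronization of the statistics are omitted.

```python
import torch
import torch.nn.functional as F


def norm_ema_update(codebook: torch.Tensor, z: torch.Tensor,
                    indices: torch.Tensor, decay: float = 0.99) -> torch.Tensor:
    """Return an updated copy of an L2-normalized codebook (hypothetical helper).

    codebook: (K, D) unit-norm code vectors
    z:        (N, D) unit-norm inputs
    indices:  (N,)   int64 index of the nearest code for each row of z
    """
    num_codes = codebook.shape[0]
    one_hot = F.one_hot(indices, num_codes).to(z.dtype)   # (N, K) assignments
    counts = one_hot.sum(dim=0)                            # (K,) vectors per code
    unused = counts == 0
    # Mean of the inputs assigned to each code (clamp avoids division by zero).
    sums = one_hot.t() @ z                                 # (K, D)
    means = F.normalize(sums / counts.clamp(min=1.0).unsqueeze(1), dim=1)
    # Codes that received no inputs keep their current value.
    target = torch.where(unused.unsqueeze(1), codebook, means)
    # EMA blend, then project back onto the unit sphere.
    updated = decay * codebook + (1.0 - decay) * target
    return F.normalize(updated, dim=1)
```

In this sketch, codes that attract no inputs in a batch are left unchanged, and renormalizing after the blend keeps every code on the unit sphere.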
forward(z)
Encode the input with the vector quantizer.
- Parameters: z – (B, T, D) input tensor
- Returns:
  - z_q – (B, T, D) quantized tensor
  - loss – scalar quantization loss
  - encoding_indices – (B, T) indices of the quantized embeddings
- Return type: tuple (z_q, loss, encoding_indices)
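The forward pass can be sketched as a stand-alone function with the same input and output shapes. This is an illustration under stated assumptions, not ESPnet's code: `quantize` is a hypothetical name, the codebook is passed in explicitly instead of being a module buffer, the EMA update is left out, and the loss is assumed to be a beta-weighted commitment term combined with a straight-through estimator, as is common for EMA-based quantizers.

```python
import torch
import torch.nn.functional as F


def quantize(z: torch.Tensor, codebook: torch.Tensor, beta: float):
    """Hypothetical sketch of a normalized VQ forward pass.

    z:        (B, T, D) input tensor
    codebook: (K, D) unit-norm code vectors
    Returns z_q (B, T, D), a scalar loss, and (B, T) code indices.
    """
    z = F.normalize(z, dim=-1)                 # project inputs onto the unit sphere
    flat = z.reshape(-1, z.shape[-1])          # (B*T, D)
    # Squared Euclidean distance to every code: |z|^2 + |e|^2 - 2 z.e
    dist = (flat.pow(2).sum(1, keepdim=True)
            + codebook.pow(2).sum(1)
            - 2 * flat @ codebook.t())         # (B*T, K)
    indices = dist.argmin(dim=1)               # (B*T,)
    z_q = codebook[indices].view_as(z)         # (B, T, D)
    # Commitment loss: pull inputs toward their (detached) codes.
    loss = beta * F.mse_loss(z_q.detach(), z)
    # Straight-through estimator: values come from z_q, gradients flow to z.
    z_q = z + (z_q - z).detach()
    return z_q, loss, indices.view(z.shape[:-1])
```

Because both the inputs and the codes are unit-norm, the Euclidean argmin selects the same code as the largest cosine similarity, since |z - e|^2 = 2 - 2 z·e for unit vectors.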
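reset_cluster_size(device)
Reset the accumulated code-usage counts (cluster_size) to zeros on device. This only has an effect when statistic_code_usage=True.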
