espnet2.speechlm.core_lm.ar_multiscale.MultiScaleLM
class espnet2.speechlm.core_lm.ar_multiscale.MultiScaleLM(vocab_size: int, nq: int, share_emb: bool = True, g_att_unit: int = 256, g_head: int = 2, g_layer: int = 4, l_att_unit: int = 256, l_head: int = 2, l_layer: int = 4, n_ctx: int = 3000, first_layer_weight: int = 1.0)
Bases: AbsCoreLM
Initialize MultiScaleLM
- Parameters:
- vocab_size (int) – Size of the vocabulary.
- nq (int) – Number of codes per token/frame; multiple codes per frame usually come from a multi-codebook speech codec.
- share_emb (bool) – If True, share weights between the embedding layer and lm_head.
- g_att_unit (int) – Dimension of the global Transformer attention.
- g_head (int) – Number of heads in the global Transformer attention.
- g_layer (int) – Number of layers in the global Transformer.
- l_att_unit (int) – Dimension of the local Transformer attention.
- l_head (int) – Number of heads in the local Transformer attention.
- l_layer (int) – Number of layers in the local Transformer.
- n_ctx (int) – Maximum context length of the global Transformer.
- first_layer_weight (int) – Factor that scales the gradient for the first-layer codes.
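To make "scale the gradient for the first-layer codes" concrete, the sketch below combines per-code losses so that the code at index 0 of each frame is weighted by first_layer_weight. This is a minimal sketch under an assumption, not ESPnet's code: it assumes the nq per-code losses are combined as a weighted average with weight first_layer_weight for code 0 and 1.0 for the rest. The helper name combine_code_losses is hypothetical.

```python
from typing import List


def combine_code_losses(per_code_losses: List[float], first_layer_weight: float = 1.0) -> float:
    """Hypothetical helper: weighted average of per-code (per-codebook) losses.

    Assumption, not taken from ESPnet: the loss of code index 0 is multiplied by
    first_layer_weight, every other code index has weight 1.0, and the result is
    normalized by the sum of the weights.
    """
    if not per_code_losses:
        raise ValueError("need at least one per-code loss")
    # Index 0 is the first-layer code; the remaining nq - 1 codes keep weight 1.0.
    weights = [first_layer_weight] + [1.0] * (len(per_code_losses) - 1)
    denom = sum(weights)
    if denom <= 0:
        raise ValueError("weights must sum to a positive value")
    # Relative to the other codes, the code-0 term (and so its gradient)
    # is scaled by first_layer_weight.
    return sum(w * loss for w, loss in zip(weights, per_code_losses)) / denom


# With nq = 3 and first_layer_weight = 2.0: (2*2 + 1 + 1) / 4 = 1.5
print(combine_code_losses([2.0, 1.0, 1.0], first_layer_weight=2.0))
```

With the default first_layer_weight of 1.0 this reduces to a plain mean over the nq codes; values above 1.0 emphasize the first-layer codes, which in multi-codebook codecs typically carry the coarsest information.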
forward(dec_seq: Tensor, dec_seq_lengths: Tensor | None = None, enc_seq: Tensor | None = None, enc_seq_lengths: Tensor | None = None, prefix_len: Tensor | None = None) → Tuple[Tensor, Dict, Tensor]
Autoregressive multi-scale forward pass for training.
- Parameters:
- dec_seq (LongTensor) – Batch of decoder sequences (B, T, nq).
- dec_seq_lengths (LongTensor) – Lengths of batched decoder sequences (B,).
- enc_seq (LongTensor) – Batch of encoder sequences (B, T, nq); kept for interface compatibility and may be unused.
- enc_seq_lengths (LongTensor) – Lengths of batched encoder sequences (B,); kept for interface compatibility and may be unused.
- prefix_len (LongTensor) – Lengths of the condition (prefix) part of dec_seq (B,).
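The dec_seq_lengths and prefix_len arguments together separate each decoder sequence into a condition part, a target part, and padding. The sketch below builds a per-frame mask of which frames would count toward the training loss. It is a minimal sketch assuming that frames in [prefix_len, dec_seq_lengths) are the targets; the actual module may additionally shift targets by one position for next-token prediction. The helper name target_mask is hypothetical, and plain lists stand in for the (B,) tensors.

```python
from typing import List


def target_mask(dec_seq_lengths: List[int], prefix_len: List[int], max_len: int) -> List[List[bool]]:
    """Hypothetical helper: mark the frames of dec_seq that would be loss targets.

    Assumption: frame t of example b is a target when
    prefix_len[b] <= t < dec_seq_lengths[b], so condition (prefix) frames and
    padding frames are both excluded.
    """
    if len(dec_seq_lengths) != len(prefix_len):
        raise ValueError("dec_seq_lengths and prefix_len must have the same batch size (B,)")
    return [[p <= t < n for t in range(max_len)] for n, p in zip(dec_seq_lengths, prefix_len)]


# B = 2, T = 5: example 0 has 2 prefix frames out of 5, example 1 has 1 out of 3.
for row in target_mask([5, 3], [2, 1], max_len=5):
    # "." = excluded (prefix or padding), "x" = target frame
    print("".join("x" if m else "." for m in row))
# ..xxx
# .xx..
```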
inference(prefix: Tensor, opts: SpeechLMInferenceOptions, enc_seq: Tensor = None, suffix: Tensor = None)
Autoregressive multi-scale inference.
- Parameters:
- prefix (LongTensor) – Prefix part of dec_seq (B, T_dec, nq).
- opts (SpeechLMInferenceOptions) – Inference options.
- enc_seq (LongTensor) – Encoder token sequence (B, T_enc, nq).
- suffix (LongTensor) – Suffix part of dec_seq (B, T_dec, nq); usually the target sequence, used for teacher forcing.
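A multi-scale LM splits generation into two nested loops: the global Transformer takes one step per frame, and the local Transformer then emits the nq codes of that frame one at a time. The toy loop below illustrates that structure for a single example, starting from a prefix. It is a minimal sketch, not the ESPnet implementation: global_step and local_step are hypothetical stand-ins for the two Transformers, decoding is greedy, and sampling, end-of-sequence handling, enc_seq, suffix, and the fields of SpeechLMInferenceOptions are deliberately not modeled.

```python
from typing import Callable, List, Sequence

Frame = List[int]  # the nq code indices of one frame


def multiscale_greedy_decode(
    prefix: Sequence[Frame],
    global_step: Callable[[Sequence[Frame]], List[float]],
    local_step: Callable[[List[float], Frame], List[float]],
    nq: int,
    max_frames: int,
) -> List[Frame]:
    """Toy greedy multi-scale decoding loop for a single example (no batching).

    global_step and local_step are hypothetical stand-ins for the global and
    local Transformers: global_step summarizes all frames so far into a context
    vector, and local_step returns logits for the next code of the current frame.
    """
    history = [list(f) for f in prefix]
    generated: List[Frame] = []
    for _ in range(max_frames):
        context = global_step(history)  # one global step per frame
        frame: Frame = []
        for _ in range(nq):  # nq local steps, one per code in the frame
            logits = local_step(context, frame)
            frame.append(max(range(len(logits)), key=logits.__getitem__))  # greedy argmax
        history.append(frame)  # the new frame conditions the next global step
        generated.append(frame)
    return generated


# Toy "models" with a vocabulary of 4 codes, only to exercise the loop.
def toy_global(history: Sequence[Frame]) -> List[float]:
    return [float(len(history))]


def toy_local(context: List[float], partial: Frame) -> List[float]:
    best = (int(context[0]) + len(partial)) % 4
    return [1.0 if v == best else 0.0 for v in range(4)]


print(multiscale_greedy_decode([[0, 0, 0]], toy_global, toy_local, nq=3, max_frames=2))
# [[1, 2, 3], [2, 3, 0]]
```

The design point the loop shows is cost: the expensive global model runs once per frame rather than once per code, while the smaller local model handles the nq codes inside each frame.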