espnet.nets.pytorch_backend.transducer.transducer_tasks.TransducerTasks
class espnet.nets.pytorch_backend.transducer.transducer_tasks.TransducerTasks(encoder_dim: int, decoder_dim: int, joint_dim: int, output_dim: int, joint_activation_type: str = 'tanh', transducer_loss_weight: float = 1.0, ctc_loss: bool = False, ctc_loss_weight: float = 0.5, ctc_loss_dropout_rate: float = 0.0, lm_loss: bool = False, lm_loss_weight: float = 0.5, lm_loss_smoothing_rate: float = 0.0, aux_transducer_loss: bool = False, aux_transducer_loss_weight: float = 0.2, aux_transducer_loss_mlp_dim: int = 320, aux_trans_loss_mlp_dropout_rate: float = 0.0, symm_kl_div_loss: bool = False, symm_kl_div_loss_weight: float = 0.2, fastemit_lambda: float = 0.0, blank_id: int = 0, ignore_id: int = -1, training: bool = False)
Bases: Module
Transducer tasks module.
Initialize module for Transducer tasks.
- Parameters:
- encoder_dim – Encoder outputs dimension.
- decoder_dim – Decoder outputs dimension.
- joint_dim – Joint space dimension.
- output_dim – Output dimension.
- joint_activation_type – Type of activation for joint network.
- transducer_loss_weight – Weight for main transducer loss.
- ctc_loss – Whether to compute the CTC loss.
- ctc_loss_weight – Weight of CTC loss.
- ctc_loss_dropout_rate – Dropout rate for CTC loss inputs.
- lm_loss – Whether to compute the LM loss.
- lm_loss_weight – Weight of LM loss.
- lm_loss_smoothing_rate – Label smoothing rate for the LM loss.
- aux_transducer_loss – Whether to compute the auxiliary transducer loss.
- aux_transducer_loss_weight – Weight of auxiliary transducer loss.
- aux_transducer_loss_mlp_dim – Hidden dimension for aux. transducer MLP.
- aux_trans_loss_mlp_dropout_rate – Dropout rate for aux. transducer MLP.
- symm_kl_div_loss – Whether to compute the symmetric KL divergence loss.
- symm_kl_div_loss_weight – Weight of the symmetric KL divergence loss.
- fastemit_lambda – Regularization parameter for FastEmit.
- blank_id – Blank symbol ID.
- ignore_id – Padding symbol ID.
- training – Whether the model was initialized in training or inference mode.
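The dimension parameters above suggest the conventional additive Transducer joint network: encoder and decoder outputs are each projected to `joint_dim`, summed, passed through the activation named by `joint_activation_type`, and projected to `output_dim`. The pure-Python sketch below illustrates that computation for a single (t, u) pair. It is an assumption about the architecture, not ESPnet's code; `joint_step`, `matvec`, and the `ACTIVATIONS` table are hypothetical names, and biases are omitted.

```python
import math
from typing import Callable, Dict, List

Vector = List[float]
Matrix = List[List[float]]  # row-major, shape (out_dim, in_dim)

# Hypothetical, illustrative subset of activations keyed by joint_activation_type.
ACTIVATIONS: Dict[str, Callable[[float], float]] = {
    "tanh": math.tanh,
    "relu": lambda x: max(0.0, x),
}


def matvec(weight: Matrix, x: Vector) -> Vector:
    """Multiply an (out_dim x in_dim) matrix by a vector of length in_dim."""
    return [sum(w * v for w, v in zip(row, x)) for row in weight]


def joint_step(
    enc_frame: Vector,  # one encoder frame, length encoder_dim
    dec_state: Vector,  # one decoder state, length decoder_dim
    w_enc: Matrix,      # (joint_dim, encoder_dim)
    w_dec: Matrix,      # (joint_dim, decoder_dim)
    w_out: Matrix,      # (output_dim, joint_dim)
    activation: str = "tanh",
) -> Vector:
    """Return output_dim unnormalized scores for one (t, u) pair (no biases)."""
    act = ACTIVATIONS[activation]
    enc_proj = matvec(w_enc, enc_frame)
    dec_proj = matvec(w_dec, dec_state)
    hidden = [act(a + b) for a, b in zip(enc_proj, dec_proj)]  # joint space
    return matvec(w_out, hidden)


# Toy sizes: encoder_dim=2, decoder_dim=1, joint_dim=2, output_dim=3.
logits = joint_step(
    enc_frame=[1.0, 0.0],
    dec_state=[0.5],
    w_enc=[[1.0, 0.0], [0.0, 1.0]],
    w_dec=[[1.0], [-1.0]],
    w_out=[[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]],
)
print(logits)  # three unnormalized scores
```

Summing the two projections (rather than concatenating them) keeps the joint cost linear in `joint_dim`, which matters because the joint is evaluated at every (t, u) lattice point.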
compute_aux_transducer_and_symm_kl_div_losses(aux_enc_out: Tensor, dec_out: Tensor, joint_out: Tensor, target: Tensor, aux_t_len: Tensor, u_len: Tensor) → Tuple[Tensor, Tensor]
Compute auxiliary Transducer loss and symmetric KL divergence loss.
- Parameters:
- aux_enc_out – Encoder auxiliary output sequences. [N x (B, T_aux, D_enc_aux)]
- dec_out – Decoder output sequences. (B, U, D_dec)
- joint_out – Joint output sequences. (B, T, U, D_joint)
- target – Target character ID sequences. (B, L)
- aux_t_len – Auxiliary time lengths. [N x (B,)]
- u_len – True U lengths. (B,)
- Returns: Auxiliary Transducer loss and symmetric KL divergence loss values.
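Going by the method name, the divergence term compares the main and auxiliary joint distributions in both directions, KL(P‖Q) + KL(Q‖P). The sketch below computes that quantity for one output distribution from two logit vectors, then averages it over N auxiliary layers. The averaging over layers, and every function name here, are assumptions for illustration rather than ESPnet's exact reduction.

```python
import math
from typing import List, Sequence


def log_softmax(logits: Sequence[float]) -> List[float]:
    """Numerically stable log-softmax over a 1-D list of logits."""
    m = max(logits)
    log_z = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - log_z for x in logits]


def kl_div(log_p: Sequence[float], log_q: Sequence[float]) -> float:
    """KL(P || Q) = sum_k p_k * (log p_k - log q_k)."""
    return sum(math.exp(lp) * (lp - lq) for lp, lq in zip(log_p, log_q))


def symm_kl_div(main_logits: Sequence[float], aux_logits: Sequence[float]) -> float:
    """Symmetric KL divergence KL(P||Q) + KL(Q||P) between two logit vectors."""
    log_p = log_softmax(main_logits)
    log_q = log_softmax(aux_logits)
    return kl_div(log_p, log_q) + kl_div(log_q, log_p)


def mean_symm_kl_over_layers(
    main_logits: Sequence[float],
    aux_logits_per_layer: Sequence[Sequence[float]],  # N auxiliary layers
) -> float:
    """Average the symmetric KL term over N auxiliary layers (assumed reduction)."""
    terms = [symm_kl_div(main_logits, aux) for aux in aux_logits_per_layer]
    return sum(terms) / len(terms)


print(symm_kl_div([0.0, 0.0], [0.0, math.log(3.0)]))  # equals 0.25 * ln(3)
```

Unlike a one-directional KL, the symmetric form penalizes the main and auxiliary branches equally, so neither distribution is treated as the fixed reference.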
compute_ctc_loss(enc_out: Tensor, target: Tensor, t_len: Tensor, u_len: Tensor)
Compute CTC loss.
- Parameters:
- enc_out – Encoder output sequences. (B, T, D_enc)
- target – Target character ID sequences. (B, U)
- t_len – Time lengths. (B,)
- u_len – Label lengths. (B,)
- Returns: CTC loss value.
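The CTC loss is the negative log of the total probability of all frame-level alignments that collapse (merge repeats, then drop blanks) to the target. The sketch below is the standard CTC forward (alpha) recursion in log space for a single utterance, given per-frame log-probabilities. It illustrates the technique only; it is not ESPnet's implementation, and `ctc_loss` and `log_add` are hypothetical helper names.

```python
import math
from typing import Sequence

NEG_INF = float("-inf")


def log_add(a: float, b: float) -> float:
    """Compute log(exp(a) + exp(b)) stably, handling -inf."""
    if a == NEG_INF:
        return b
    if b == NEG_INF:
        return a
    m = max(a, b)
    return m + math.log(math.exp(a - m) + math.exp(b - m))


def ctc_loss(log_probs: Sequence[Sequence[float]], target: Sequence[int], blank_id: int = 0) -> float:
    """Negative log-likelihood of `target` under CTC.

    log_probs: T x V per-frame log-probabilities (already log-softmaxed).
    target: label IDs without blanks, length U.
    """
    # Extended target with blanks interleaved: [b, y1, b, y2, ..., b]
    ext = [blank_id]
    for y in target:
        ext += [y, blank_id]
    S, T = len(ext), len(log_probs)

    alpha = [NEG_INF] * S
    alpha[0] = log_probs[0][ext[0]]          # start with a blank
    if S > 1:
        alpha[1] = log_probs[0][ext[1]]      # or with the first label

    for t in range(1, T):
        new = [NEG_INF] * S
        for s in range(S):
            acc = alpha[s]                               # stay on the same symbol
            if s >= 1:
                acc = log_add(acc, alpha[s - 1])         # advance by one
            if s >= 2 and ext[s] != blank_id and ext[s] != ext[s - 2]:
                acc = log_add(acc, alpha[s - 2])         # skip a blank between distinct labels
            new[s] = acc + log_probs[t][ext[s]]          # -inf stays -inf
        alpha = new

    # Valid paths end on the last label or on the trailing blank.
    total = alpha[0] if S == 1 else log_add(alpha[S - 1], alpha[S - 2])
    return -total


half = math.log(0.5)
print(ctc_loss([[half, half], [half, half]], [1]))  # -ln(0.75): paths (1,1), (b,1), (1,b)
```

The blank-skip condition is what forces a blank between repeated labels: for target `[1, 1]` over three frames, only the path `(1, b, 1)` survives.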
compute_lm_loss(dec_out: Tensor, target: Tensor) → Tensor
Compute LM loss.
- Parameters:
- dec_out – Decoder output sequences. (B, U, D_dec)
- target – Target label ID sequences. (B, U)
- Returns: LM loss value.
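Given `lm_loss_smoothing_rate` and `ignore_id`, the LM loss is presumably a label-smoothed cross-entropy over the decoder's next-label predictions, skipping padded positions. The sketch below assumes one common smoothing scheme: `1 - smoothing` on the target class and `smoothing / (V - 1)` on each other class. Other variants (for example spreading over all `V` classes, or using KL divergence instead of cross-entropy) differ by small constants. `smoothed_ce` is a hypothetical name.

```python
import math
from typing import List, Sequence


def log_softmax(logits: Sequence[float]) -> List[float]:
    """Numerically stable log-softmax over a 1-D list of logits."""
    m = max(logits)
    log_z = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - log_z for x in logits]


def smoothed_ce(
    logits: Sequence[Sequence[float]],  # U x V decoder logits (V >= 2 assumed)
    target: Sequence[int],              # U label IDs, may contain ignore_id
    smoothing: float = 0.0,
    ignore_id: int = -1,
) -> float:
    """Sum of label-smoothed cross-entropy over non-ignored positions."""
    total = 0.0
    for row, y in zip(logits, target):
        if y == ignore_id:
            continue  # padding position contributes nothing
        log_p = log_softmax(row)
        off = smoothing / (len(row) - 1)  # mass assigned to each non-target class
        for k, lp in enumerate(log_p):
            weight = 1.0 - smoothing if k == y else off
            total -= weight * lp
    return total


print(smoothed_ce([[0.0, math.log(3.0)]], [1]))  # -ln(0.75) without smoothing
```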
compute_transducer_loss(enc_out: Tensor, dec_out: Tensor, target: Tensor, t_len: Tensor, u_len: Tensor) → Tuple[Tensor, Tensor]
Compute Transducer loss.
- Parameters:
- enc_out – Encoder output sequences. (B, T, D_enc)
- dec_out – Decoder output sequences. (B, U, D_dec)
- target – Target label ID sequences. (B, L)
- t_len – Time lengths. (B,)
- u_len – Label lengths. (B,)
- Returns: joint_out – Joint output sequences. (B, T, U, D_joint); loss_trans – Transducer loss value.
- Return type: Tuple[Tensor, Tensor]
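The Transducer loss marginalizes over every monotonic path through the T x (L + 1) lattice of joint outputs, where each step either emits a blank (advance one frame) or emits the next label (advance one label). The sketch below is the standard RNN-T forward recursion in log space for one utterance. It assumes `log_probs[t][u]` holds the log-softmaxed joint output at frame t after u emitted labels, i.e. that the decoder axis has L + 1 positions (so U = L + 1 in the shapes above); that mapping, and the name `transducer_loss`, are assumptions. FastEmit regularization (`fastemit_lambda`) is not modeled.

```python
import math
from typing import Sequence

NEG_INF = float("-inf")


def log_add(a: float, b: float) -> float:
    """Compute log(exp(a) + exp(b)) stably, handling -inf."""
    if a == NEG_INF:
        return b
    if b == NEG_INF:
        return a
    m = max(a, b)
    return m + math.log(math.exp(a - m) + math.exp(b - m))


def transducer_loss(
    log_probs: Sequence[Sequence[Sequence[float]]],  # T x (L + 1) x V
    target: Sequence[int],                           # L label IDs
    blank_id: int = 0,
) -> float:
    """Negative log-likelihood of `target` under the RNN-T alignment lattice."""
    T, L = len(log_probs), len(target)
    # alpha[t][u]: log-prob of reaching lattice node (t, u)
    alpha = [[NEG_INF] * (L + 1) for _ in range(T)]
    alpha[0][0] = 0.0
    for t in range(T):
        for u in range(L + 1):
            if t == 0 and u == 0:
                continue
            # Arrive by emitting blank at (t - 1, u) ...
            from_blank = alpha[t - 1][u] + log_probs[t - 1][u][blank_id] if t > 0 else NEG_INF
            # ... or by emitting label target[u - 1] at (t, u - 1).
            from_label = alpha[t][u - 1] + log_probs[t][u - 1][target[u - 1]] if u > 0 else NEG_INF
            alpha[t][u] = log_add(from_blank, from_label)
    # Every complete path ends with a blank emitted at the last node.
    return -(alpha[T - 1][L] + log_probs[T - 1][L][blank_id])


b, y = math.log(0.6), math.log(0.4)
lattice = [[[b, y], [b, y]] for _ in range(2)]  # T=2, L=1, V=2
print(transducer_loss(lattice, [1]))  # -ln(0.288): two paths of probability 0.144 each
```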
forward(enc_out: Tensor, aux_enc_out: List[Tensor], dec_out: Tensor, labels: Tensor, enc_out_len: Tensor, aux_enc_out_len: Tensor) → Tuple[Tuple[Any], float, float]
Forward main and auxiliary tasks.
- Parameters:
- enc_out – Encoder output sequences. (B, T, D_enc)
- aux_enc_out – Encoder intermediate output sequences. [N x (B, T_aux, D_enc_aux)]
- dec_out – Decoder output sequences. (B, U, D_dec)
- labels – Label ID sequences. (B, U)
- enc_out_len – Time lengths. (B,)
- aux_enc_out_len – Auxiliary time lengths. [N x (B,)]
- Returns:
- Weighted losses – (transducer loss, CTC loss, aux. Transducer loss, KL div loss, LM loss)
- cer – Sentence-level CER score.
- wer – Sentence-level WER score.
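The returned losses are already scaled by their constructor weights (`transducer_loss_weight`, `ctc_loss_weight`, and so on). The sketch below shows that bookkeeping under the assumption that a disabled task contributes 0.0. It also shows a common definition of sentence-level CER and WER, namely Levenshtein edit distance divided by reference length, over characters or whitespace-separated words. All function names are hypothetical, and ESPnet's exact normalization (for example, how spaces are treated in CER) may differ.

```python
from typing import Dict, Sequence, Tuple

LOSS_ORDER = ("transducer", "ctc", "aux_transducer", "symm_kl_div", "lm")


def weighted_losses(
    losses: Dict[str, float],
    weights: Dict[str, float],
    enabled: Dict[str, bool],
) -> Tuple[float, ...]:
    """Return weighted losses in LOSS_ORDER; disabled tasks are assumed to give 0.0."""
    return tuple(
        weights[name] * losses[name] if enabled.get(name, False) else 0.0
        for name in LOSS_ORDER
    )


def edit_distance(ref: Sequence[str], hyp: Sequence[str]) -> int:
    """Levenshtein distance with unit substitution, insertion and deletion costs."""
    prev = list(range(len(hyp) + 1))
    for i, r in enumerate(ref, start=1):
        cur = [i] + [0] * len(hyp)
        for j, h in enumerate(hyp, start=1):
            cur[j] = min(prev[j] + 1, cur[j - 1] + 1, prev[j - 1] + (r != h))
        prev = cur
    return prev[-1]


def sentence_cer(ref: str, hyp: str) -> float:
    """Character edits divided by reference length (raw strings, spaces included)."""
    return edit_distance(ref, hyp) / max(len(ref), 1)


def sentence_wer(ref: str, hyp: str) -> float:
    """Word edits divided by the number of reference words."""
    r, h = ref.split(), hyp.split()
    return edit_distance(r, h) / max(len(r), 1)


print(sentence_wer("the cat sat", "the cat sat down"))  # one insertion over three words
```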
get_target()
Get target label ID sequences.
- Returns: Target label ID sequences. (B, L)
- Return type: Tensor
get_transducer_tasks_io(labels: Tensor, enc_out_len: Tensor, aux_enc_out_len: List | None) → Tuple[Tensor, Tensor, Tensor, Tensor]
Get Transducer tasks inputs and outputs.
- Parameters:
- labels – Label ID sequences. (B, U)
- enc_out_len – Time lengths. (B,)
- aux_enc_out_len – Auxiliary time lengths. [N x (B,)]
- Returns:
- target – Target label ID sequences. (B, L)
- lm_loss_target – LM loss target label ID sequences. (B, U)
- t_len – Time lengths. (B,)
- aux_t_len – Auxiliary time lengths. [N x (B,)]
- u_len – Label lengths. (B,)
- Return type: tuple
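The shapes above, with `target` at (B, L) and `lm_loss_target` at (B, U), suggest that the LM target is one position longer than the Transducer target. The hypothetical sketch below shows one way to derive all five outputs from padded labels. It assumes that `labels` is padded with `ignore_id`, that `target` re-pads the stripped sequences with `blank_id`, that `lm_loss_target` appends `blank_id` as an end marker and pads with `ignore_id`, and that lengths are given as plain integers rather than masks. These conventions are illustrative and may not match ESPnet's code.

```python
from typing import List, Optional, Sequence, Tuple


def pad(seqs: Sequence[Sequence[int]], value: int) -> List[List[int]]:
    """Right-pad integer sequences to a common length with `value`."""
    max_len = max((len(s) for s in seqs), default=0)
    return [list(s) + [value] * (max_len - len(s)) for s in seqs]


def get_tasks_io(
    labels: Sequence[Sequence[int]],  # B x U_max, padded with ignore_id
    enc_out_len: Sequence[int],       # B frame counts
    aux_enc_out_len: Optional[Sequence[Sequence[int]]] = None,  # N x B frame counts
    blank_id: int = 0,
    ignore_id: int = -1,
) -> Tuple[List[List[int]], List[List[int]], List[int], Optional[List[List[int]]], List[int]]:
    """Hypothetical preparation of (target, lm_loss_target, t_len, aux_t_len, u_len)."""
    unpadded = [[y for y in seq if y != ignore_id] for seq in labels]
    target = pad(unpadded, blank_id)                                     # (B, L)
    lm_loss_target = pad([s + [blank_id] for s in unpadded], ignore_id)  # (B, L + 1)
    t_len = [int(n) for n in enc_out_len]
    aux_t_len = [[int(n) for n in lens] for lens in aux_enc_out_len] if aux_enc_out_len else None
    u_len = [len(s) for s in unpadded]
    return target, lm_loss_target, t_len, aux_t_len, u_len


print(get_tasks_io([[5, 3, -1], [7, -1, -1]], [10, 6]))
```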
set_target(target: Tensor)
Set target label ID sequences.
- Parameters:
- target – Target label ID sequences. (B, L)