# Source code for espnet.nets.pytorch_backend.transducer.transducer_tasks

"""Module implementing Transducer main and auxiliary tasks."""

from typing import Any, List, Optional, Tuple

import torch

from espnet.nets.pytorch_backend.nets_utils import pad_list
from espnet.nets.pytorch_backend.transducer.joint_network import JointNetwork
from espnet.nets.pytorch_backend.transformer.label_smoothing_loss import (  # noqa: H301
    LabelSmoothingLoss,
)


class TransducerTasks(torch.nn.Module):
    """Transducer tasks module.

    Bundles the main Transducer (RNN-T) criterion with its optional
    auxiliary criteria: CTC, LM (label smoothing), auxiliary Transducer
    on intermediate encoder layers, and symmetric KL divergence.
    """

    def __init__(
        self,
        encoder_dim: int,
        decoder_dim: int,
        joint_dim: int,
        output_dim: int,
        joint_activation_type: str = "tanh",
        transducer_loss_weight: float = 1.0,
        ctc_loss: bool = False,
        ctc_loss_weight: float = 0.5,
        ctc_loss_dropout_rate: float = 0.0,
        lm_loss: bool = False,
        lm_loss_weight: float = 0.5,
        lm_loss_smoothing_rate: float = 0.0,
        aux_transducer_loss: bool = False,
        aux_transducer_loss_weight: float = 0.2,
        aux_transducer_loss_mlp_dim: int = 320,
        aux_trans_loss_mlp_dropout_rate: float = 0.0,
        symm_kl_div_loss: bool = False,
        symm_kl_div_loss_weight: float = 0.2,
        fastemit_lambda: float = 0.0,
        blank_id: int = 0,
        ignore_id: int = -1,
        training: bool = False,
    ):
        """Initialize module for Transducer tasks.

        Args:
            encoder_dim: Encoder outputs dimension.
            decoder_dim: Decoder outputs dimension.
            joint_dim: Joint space dimension.
            output_dim: Output dimension.
            joint_activation_type: Type of activation for joint network.
            transducer_loss_weight: Weight for main transducer loss.
            ctc_loss: Compute CTC loss.
            ctc_loss_weight: Weight of CTC loss.
            ctc_loss_dropout_rate: Dropout rate for CTC loss inputs.
            lm_loss: Compute LM loss.
            lm_loss_weight: Weight of LM loss.
            lm_loss_smoothing_rate: Smoothing rate for LM loss' label smoothing.
            aux_transducer_loss: Compute auxiliary transducer loss.
            aux_transducer_loss_weight: Weight of auxiliary transducer loss.
            aux_transducer_loss_mlp_dim: Hidden dimension for aux. transducer MLP.
            aux_trans_loss_mlp_dropout_rate: Dropout rate for aux. transducer MLP.
            symm_kl_div_loss: Compute KL divergence loss.
            symm_kl_div_loss_weight: Weight of KL divergence loss.
            fastemit_lambda: Regularization parameter for FastEmit.
            blank_id: Blank symbol ID.
            ignore_id: Padding symbol ID.
            training: Whether the model was initializated in training or
                      inference mode.

        """
        super().__init__()

        # Auxiliary criteria are only meaningful during training: force them
        # off in inference mode so no training-only sub-modules are built.
        if not training:
            ctc_loss = lm_loss = aux_transducer_loss = symm_kl_div_loss = False

        self.joint_network = JointNetwork(
            output_dim, encoder_dim, decoder_dim, joint_dim, joint_activation_type
        )

        if training:
            # Training-only third-party dependency: imported lazily so that
            # inference does not require the warp-rnnt binding.
            from warprnnt_pytorch import RNNTLoss

            self.transducer_loss = RNNTLoss(
                blank=blank_id,
                reduction="sum",
                fastemit_lambda=fastemit_lambda,
            )

        if ctc_loss:
            # Linear projection to the vocabulary for the CTC branch.
            self.ctc_lin = torch.nn.Linear(encoder_dim, output_dim)

            self.ctc_loss = torch.nn.CTCLoss(
                blank=blank_id,
                reduction="none",
                zero_infinity=True,
            )

        if aux_transducer_loss:
            # MLP mapping intermediate encoder outputs into the joint space.
            self.mlp = torch.nn.Sequential(
                torch.nn.Linear(encoder_dim, aux_transducer_loss_mlp_dim),
                torch.nn.LayerNorm(aux_transducer_loss_mlp_dim),
                torch.nn.Dropout(p=aux_trans_loss_mlp_dropout_rate),
                torch.nn.ReLU(),
                torch.nn.Linear(aux_transducer_loss_mlp_dim, joint_dim),
            )

            if symm_kl_div_loss:
                self.kl_div = torch.nn.KLDivLoss(reduction="sum")

        if lm_loss:
            # Linear projection to the vocabulary for the LM branch.
            self.lm_lin = torch.nn.Linear(decoder_dim, output_dim)

            self.label_smoothing_loss = LabelSmoothingLoss(
                output_dim, ignore_id, lm_loss_smoothing_rate, normalize_length=False
            )

        self.output_dim = output_dim

        self.transducer_loss_weight = transducer_loss_weight

        self.use_ctc_loss = ctc_loss
        self.ctc_loss_weight = ctc_loss_weight
        self.ctc_dropout_rate = ctc_loss_dropout_rate

        self.use_lm_loss = lm_loss
        self.lm_loss_weight = lm_loss_weight

        self.use_aux_transducer_loss = aux_transducer_loss
        self.aux_transducer_loss_weight = aux_transducer_loss_weight

        self.use_symm_kl_div_loss = symm_kl_div_loss
        self.symm_kl_div_loss_weight = symm_kl_div_loss_weight

        self.blank_id = blank_id
        self.ignore_id = ignore_id

        self.target = None
[docs] def compute_transducer_loss( self, enc_out: torch.Tensor, dec_out: torch.tensor, target: torch.Tensor, t_len: torch.Tensor, u_len: torch.Tensor, ) -> Tuple[torch.Tensor, torch.Tensor]: """Compute Transducer loss. Args: enc_out: Encoder output sequences. (B, T, D_enc) dec_out: Decoder output sequences. (B, U, D_dec) target: Target label ID sequences. (B, L) t_len: Time lengths. (B,) u_len: Label lengths. (B,) Returns: (joint_out, loss_trans): Joint output sequences. (B, T, U, D_joint), Transducer loss value. """ joint_out = self.joint_network(enc_out.unsqueeze(2), dec_out.unsqueeze(1)) loss_trans = self.transducer_loss(joint_out, target, t_len, u_len) loss_trans /= joint_out.size(0) return joint_out, loss_trans
[docs] def compute_ctc_loss( self, enc_out: torch.Tensor, target: torch.Tensor, t_len: torch.Tensor, u_len: torch.Tensor, ): """Compute CTC loss. Args: enc_out: Encoder output sequences. (B, T, D_enc) target: Target character ID sequences. (B, U) t_len: Time lengths. (B,) u_len: Label lengths. (B,) Returns: : CTC loss value. """ ctc_lin = self.ctc_lin( torch.nn.functional.dropout( enc_out.to(dtype=torch.float32), p=self.ctc_dropout_rate ) ) ctc_logp = torch.log_softmax(ctc_lin.transpose(0, 1), dim=-1) with torch.backends.cudnn.flags(deterministic=True): loss_ctc = self.ctc_loss(ctc_logp, target, t_len, u_len) return loss_ctc.mean()
    def compute_aux_transducer_and_symm_kl_div_losses(
        self,
        aux_enc_out: List[torch.Tensor],
        dec_out: torch.Tensor,
        joint_out: torch.Tensor,
        target: torch.Tensor,
        aux_t_len: List[torch.Tensor],
        u_len: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Compute auxiliary Transducer loss and symmetric KL divergence loss.

        The auxiliary Transducer loss is the RNN-T loss applied to joint
        outputs built from intermediate encoder layers (projected through
        ``self.mlp``). The symmetric KL term sums KL(main || aux) and
        KL(aux || main) over the joint output distributions.

        NOTE(review): only called from ``forward`` when
        ``use_aux_transducer_loss`` is True — the unconditional division of
        ``aux_trans_loss`` at the end relies on that.

        Args:
            aux_enc_out: Encoder auxiliary output sequences.
                         [N x (B, T_aux, D_enc_aux)]
            dec_out: Decoder output sequences. (B, U, D_dec)
            joint_out: Joint output sequences. (B, T, U, D_joint)
            target: Target character ID sequences. (B, L)
            aux_t_len: Auxiliary time lengths. [N x (B,)]
            u_len: True U lengths. (B,)

        Returns:
            : Auxiliary Transducer loss and KL divergence loss values.

        """
        aux_trans_loss = 0
        symm_kl_div_loss = 0

        num_aux_layers = len(aux_enc_out)
        B, T, U, D = joint_out.shape

        # Freeze the joint network so the auxiliary branches do not update
        # its parameters; only the aux MLP (and upstream encoder) learn here.
        for p in self.joint_network.parameters():
            p.requires_grad = False

        for i, aux_enc_out_i in enumerate(aux_enc_out):
            # Project intermediate encoder outputs into the joint space.
            aux_mlp = self.mlp(aux_enc_out_i)

            aux_joint_out = self.joint_network(
                aux_mlp.unsqueeze(2),
                dec_out.unsqueeze(1),
                is_aux=True,
            )

            if self.use_aux_transducer_loss:
                # Batch-normalized RNN-T loss on the auxiliary lattice.
                aux_trans_loss += (
                    self.transducer_loss(
                        aux_joint_out,
                        target,
                        aux_t_len[i],
                        u_len,
                    )
                    / B
                )

            if self.use_symm_kl_div_loss:
                denom = B * T * U

                # KL(main || aux): KLDivLoss takes log-probs as input and
                # probabilities as target.
                kl_main_aux = (
                    self.kl_div(
                        torch.log_softmax(joint_out, dim=-1),
                        torch.softmax(aux_joint_out, dim=-1),
                    )
                    / denom
                )

                # KL(aux || main), the symmetric counterpart.
                kl_aux_main = (
                    self.kl_div(
                        torch.log_softmax(aux_joint_out, dim=-1),
                        torch.softmax(joint_out, dim=-1),
                    )
                    / denom
                )

                symm_kl_div_loss += kl_main_aux + kl_aux_main

        # Unfreeze the joint network for the main Transducer loss.
        for p in self.joint_network.parameters():
            p.requires_grad = True

        # Average both accumulated losses over the auxiliary layers.
        aux_trans_loss /= num_aux_layers

        if self.use_symm_kl_div_loss:
            symm_kl_div_loss /= num_aux_layers

        return aux_trans_loss, symm_kl_div_loss
[docs] def compute_lm_loss( self, dec_out: torch.Tensor, target: torch.Tensor, ) -> torch.Tensor: """Forward LM loss. Args: dec_out: Decoder output sequences. (B, U, D_dec) target: Target label ID sequences. (B, U) Returns: : LM loss value. """ lm_lin = self.lm_lin(dec_out) lm_loss = self.label_smoothing_loss(lm_lin, target) return lm_loss
[docs] def set_target(self, target: torch.Tensor): """Set target label ID sequences. Args: target: Target label ID sequences. (B, L) """ self.target = target
    def get_target(self):
        """Get the target label ID sequences previously stored via set_target.

        Returns:
            target: Target label ID sequences. (B, L)

        """
        return self.target
[docs] def get_transducer_tasks_io( self, labels: torch.Tensor, enc_out_len: torch.Tensor, aux_enc_out_len: Optional[List], ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: """Get Transducer tasks inputs and outputs. Args: labels: Label ID sequences. (B, U) enc_out_len: Time lengths. (B,) aux_enc_out_len: Auxiliary time lengths. [N X (B,)] Returns: target: Target label ID sequences. (B, L) lm_loss_target: LM loss target label ID sequences. (B, U) t_len: Time lengths. (B,) aux_t_len: Auxiliary time lengths. [N x (B,)] u_len: Label lengths. (B,) """ device = labels.device labels_unpad = [label[label != self.ignore_id] for label in labels] blank = labels[0].new([self.blank_id]) target = pad_list(labels_unpad, self.blank_id).type(torch.int32).to(device) lm_loss_target = ( pad_list( [torch.cat([y, blank], dim=0) for y in labels_unpad], self.ignore_id ) .type(torch.int64) .to(device) ) self.set_target(target) if enc_out_len.dim() > 1: enc_mask_unpad = [m[m != 0] for m in enc_out_len] enc_out_len = list(map(int, [m.size(0) for m in enc_mask_unpad])) else: enc_out_len = list(map(int, enc_out_len)) t_len = torch.IntTensor(enc_out_len).to(device) u_len = torch.IntTensor([label.size(0) for label in labels_unpad]).to(device) if aux_enc_out_len: aux_t_len = [] for i in range(len(aux_enc_out_len)): if aux_enc_out_len[i].dim() > 1: aux_mask_unpad = [aux[aux != 0] for aux in aux_enc_out_len[i]] aux_t_len.append( torch.IntTensor( list(map(int, [aux.size(0) for aux in aux_mask_unpad])) ).to(device) ) else: aux_t_len.append( torch.IntTensor(list(map(int, aux_enc_out_len[i]))).to(device) ) else: aux_t_len = aux_enc_out_len return target, lm_loss_target, t_len, aux_t_len, u_len
[docs] def forward( self, enc_out: torch.Tensor, aux_enc_out: List[torch.Tensor], dec_out: torch.Tensor, labels: torch.Tensor, enc_out_len: torch.Tensor, aux_enc_out_len: torch.Tensor, ) -> Tuple[Tuple[Any], float, float]: """Forward main and auxiliary task. Args: enc_out: Encoder output sequences. (B, T, D_enc) aux_enc_out: Encoder intermediate output sequences. (B, T_aux, D_enc_aux) dec_out: Decoder output sequences. (B, U, D_dec) target: Target label ID sequences. (B, L) t_len: Time lengths. (B,) aux_t_len: Auxiliary time lengths. (B,) u_len: Label lengths. (B,) Returns: : Weighted losses. (transducer loss, ctc loss, aux Transducer loss, KL div loss, LM loss) cer: Sentence-level CER score. wer: Sentence-level WER score. """ if self.use_symm_kl_div_loss: assert self.use_aux_transducer_loss (trans_loss, ctc_loss, lm_loss, aux_trans_loss, symm_kl_div_loss) = ( 0.0, 0.0, 0.0, 0.0, 0.0, ) target, lm_loss_target, t_len, aux_t_len, u_len = self.get_transducer_tasks_io( labels, enc_out_len, aux_enc_out_len, ) joint_out, trans_loss = self.compute_transducer_loss( enc_out, dec_out, target, t_len, u_len ) if self.use_ctc_loss: ctc_loss = self.compute_ctc_loss(enc_out, target, t_len, u_len) if self.use_aux_transducer_loss: ( aux_trans_loss, symm_kl_div_loss, ) = self.compute_aux_transducer_and_symm_kl_div_losses( aux_enc_out, dec_out, joint_out, target, aux_t_len, u_len, ) if self.use_lm_loss: lm_loss = self.compute_lm_loss(dec_out, lm_loss_target) return ( self.transducer_loss_weight * trans_loss, self.ctc_loss_weight * ctc_loss, self.aux_transducer_loss_weight * aux_trans_loss, self.symm_kl_div_loss_weight * symm_kl_div_loss, self.lm_loss_weight * lm_loss, )