Source code for espnet2.s2st.losses.guided_attention_loss

import torch
from typeguard import typechecked

from espnet2.s2st.losses.abs_loss import AbsS2STLoss
from espnet.nets.pytorch_backend.e2e_tts_tacotron2 import GuidedAttentionLoss


class S2STGuidedAttentionLoss(AbsS2STLoss):
    """Tacotron-based guided attention loss for S2ST.

    Thin wrapper around ``GuidedAttentionLoss`` from
    ``espnet.nets.pytorch_backend.e2e_tts_tacotron2``. Setting ``weight`` to a
    non-positive value disables it.
    """

    @typechecked
    def __init__(
        self,
        weight: float = 1.0,
        sigma: float = 0.4,
        alpha: float = 1.0,
    ):
        """Initialize the loss module.

        Args:
            weight: Loss weight, stored as ``self.weight``. A value ``<= 0``
                disables the loss in ``forward``. This class never multiplies
                the loss by ``weight``. Presumably the caller applies it;
                confirm against the S2ST model code.
            sigma: Forwarded unchanged to ``GuidedAttentionLoss(sigma=...)``.
            alpha: Forwarded unchanged to ``GuidedAttentionLoss(alpha=...)``.
        """
        super().__init__()
        self.weight = weight
        self.loss = GuidedAttentionLoss(
            sigma=sigma,
            alpha=alpha,
        )

    def forward(
        self,
        att_ws: torch.Tensor,
        ilens: torch.Tensor,
        olens_in: torch.Tensor,
    ):
        """Compute the guided attention loss.

        Args:
            att_ws: Attention weights, passed unchanged to
                ``GuidedAttentionLoss``. Presumably shaped
                ``(B, T_out, T_in)``; confirm.
            ilens: Input lengths, passed unchanged to ``GuidedAttentionLoss``.
            olens_in: Output lengths, passed unchanged to
                ``GuidedAttentionLoss``.

        Returns:
            Optional[Tensor]: The guided attention loss when ``weight > 0``,
            otherwise ``None``.
        """
        if self.weight <= 0:
            # Return a single ``None`` (not a tuple) so both branches yield one
            # value, matching the documented return type and letting callers
            # test the result with ``is None``.
            return None
        return self.loss(att_ws, ilens, olens_in)