Source code for espnet2.s2st.losses.attention_loss

import torch
from typeguard import typechecked

from espnet2.s2st.losses.abs_loss import AbsS2STLoss
from espnet2.utils.types import str2bool
from espnet.nets.pytorch_backend.transformer.label_smoothing_loss import (  # noqa: H301
    LabelSmoothingLoss,
)


class S2STAttentionLoss(AbsS2STLoss):
    """Attention-based label smoothing loss for S2ST."""

    @typechecked
    def __init__(
        self,
        vocab_size: int,
        padding_idx: int = -1,
        weight: float = 1.0,
        smoothing: float = 0.0,
        normalize_length: str2bool = False,
        criterion: torch.nn.Module = torch.nn.KLDivLoss(reduction="none"),
    ):
        super().__init__()
        self.weight = weight
        self.loss = LabelSmoothingLoss(
            size=vocab_size,
            padding_idx=padding_idx,
            smoothing=smoothing,
            normalize_length=normalize_length,
            criterion=criterion,
        )
    def forward(
        self,
        dense_y: torch.Tensor,
        token_y: torch.Tensor,
    ):
        """Forward.

        Args:
            dense_y (torch.Tensor): decoder output scores (batch, length, vocab_size).
            token_y (torch.Tensor): target token ids (batch, length).
        """
        if self.weight > 0:
            return self.loss(dense_y, token_y)
        else:
            return None
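For reference, a minimal usage sketch. The vocabulary size, batch shape, and smoothing value below are illustrative assumptions; the tensor shapes follow the (batch, length, vocab_size) convention expected by LabelSmoothingLoss.

import torch

from espnet2.s2st.losses.attention_loss import S2STAttentionLoss

# Hypothetical toy setup: 50-token vocabulary, batch of 2, target length 6.
loss_fn = S2STAttentionLoss(vocab_size=50, padding_idx=0, smoothing=0.1)

dense_y = torch.randn(2, 6, 50)         # decoder output scores
token_y = torch.randint(1, 50, (2, 6))  # reference token ids

loss = loss_fn(dense_y, token_y)        # scalar label-smoothing loss tensor
print(loss)

Since weight defaults to 1.0, forward returns the label smoothing loss; constructing the module with weight=0.0 makes forward return None, which lets a model disable this loss term via configuration.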