"""Tacotron2-based loss for speech-to-speech translation (espnet2.s2st.losses.tacotron_loss)."""

import torch
from typeguard import typechecked

from espnet2.s2st.losses.abs_loss import AbsS2STLoss
from espnet2.utils.types import str2bool
from espnet.nets.pytorch_backend.e2e_tts_tacotron2 import Tacotron2Loss

class S2STTacotron2Loss(AbsS2STLoss):
    """Tacotron-based loss for S2ST.

    Thin wrapper around ``Tacotron2Loss`` that combines its L1, MSE and BCE
    (stop-token) terms into a total loss according to ``loss_type``.
    """

    @typechecked
    def __init__(
        self,
        weight: float = 1.0,
        loss_type: str = "L1+L2",
        use_masking: str2bool = True,
        use_weighted_masking: str2bool = False,
        bce_pos_weight: float = 20.0,
    ):
        """Initialize the loss module.

        Args:
            weight (float): Stored on the module. In this class it only gates
                ``forward``: losses are computed only when ``weight > 0``.
                The returned losses are NOT multiplied by it (presumably the
                caller applies the weight -- verify against the model code).
            loss_type (str): Which feature losses enter the total loss:
                "L1+L2", "L1" or "L2". The BCE stop-token loss is always
                added. Checked in ``forward``.
            use_masking (bool): Forwarded to ``Tacotron2Loss``.
            use_weighted_masking (bool): Forwarded to ``Tacotron2Loss``.
            bce_pos_weight (float): Forwarded to ``Tacotron2Loss``.
        """
        super().__init__()
        self.weight = weight
        self.loss_type = loss_type
        self.loss = Tacotron2Loss(
            use_masking=use_masking,
            use_weighted_masking=use_weighted_masking,
            bce_pos_weight=bce_pos_weight,
        )

    def forward(
        self,
        after_outs: torch.Tensor,
        before_outs: torch.Tensor,
        logits: torch.Tensor,
        ys: torch.Tensor,
        labels: torch.Tensor,
        olens: torch.Tensor,
    ):
        """Forward.

        Args:
            after_outs (Tensor): Batch of outputs after postnets (B, Lmax, odim).
            before_outs (Tensor): Batch of outputs before postnets (B, Lmax, odim).
            logits (Tensor): Batch of stop logits (B, Lmax).
            ys (Tensor): Batch of padded target features (B, Lmax, odim).
            labels (LongTensor): Batch of the sequences of stop token labels
                (B, Lmax).
            olens (LongTensor): Batch of the lengths of each target (B,).

        Returns:
            Tensor: Total loss -- the feature loss selected by ``loss_type``
                ("L1+L2": L1 + MSE, "L1": L1, "L2": MSE) plus the BCE loss.
            Tensor: L1 loss value.
            Tensor: Mean square error loss value.
            Tensor: Binary cross entropy loss value.

            When ``self.weight`` is not positive, nothing is computed and
            ``(None, None, None, None)`` is returned.

        Raises:
            ValueError: If ``self.loss_type`` is not "L1+L2", "L1" or "L2"
                (only checked when ``self.weight > 0``).
        """
        if self.weight > 0:
            # Reject a bad loss_type before running the underlying loss, so
            # no computation is wasted on a call that is going to fail.
            if self.loss_type not in ("L1+L2", "L1", "L2"):
                raise ValueError(f"unknown --loss-type {self.loss_type}")

            l1_loss, mse_loss, bce_loss = self.loss(
                after_outs, before_outs, logits, ys, labels, olens
            )
            if self.loss_type == "L1+L2":
                feat_loss = l1_loss + mse_loss
            elif self.loss_type == "L1":
                feat_loss = l1_loss
            else:  # "L2"
                feat_loss = mse_loss
            # All individual terms are always returned (e.g. for logging),
            # even those excluded from the total.
            return feat_loss + bce_loss, l1_loss, mse_loss, bce_loss

        # Loss disabled (weight not positive): skip computation entirely.
        return None, None, None, None