import torch
from typeguard import typechecked
from espnet2.s2st.losses.abs_loss import AbsS2STLoss
from espnet2.utils.types import str2bool
from espnet.nets.pytorch_backend.e2e_tts_tacotron2 import Tacotron2Loss
class S2STTacotron2Loss(AbsS2STLoss):
    """Tacotron-based loss for S2ST.

    Wraps ``Tacotron2Loss`` and combines its L1 / MSE / stop-token BCE terms
    into a single total loss according to ``loss_type``.
    """

    # Values accepted for ``loss_type``; anything else raises in ``forward``.
    _LOSS_TYPES = ("L1+L2", "L1", "L2")

    @typechecked
    def __init__(
        self,
        weight: float = 1.0,
        loss_type: str = "L1+L2",
        use_masking: str2bool = True,
        use_weighted_masking: str2bool = False,
        bce_pos_weight: float = 20.0,
    ):
        """Initialize the loss.

        Args:
            weight: Loss weight. When it is not > 0, ``forward`` skips the
                computation and returns four ``None`` values. The weight is
                NOT multiplied into the returned loss here; presumably the
                caller applies ``self.weight`` (confirm against the model).
            loss_type: How the spectrogram terms are combined: "L1+L2",
                "L1", or "L2". The BCE stop-token loss is always added.
            use_masking: Forwarded to ``Tacotron2Loss``.
            use_weighted_masking: Forwarded to ``Tacotron2Loss``.
            bce_pos_weight: Forwarded to ``Tacotron2Loss``.
        """
        super().__init__()
        self.weight = weight
        self.loss_type = loss_type
        self.loss = Tacotron2Loss(
            use_masking=use_masking,
            use_weighted_masking=use_weighted_masking,
            bce_pos_weight=bce_pos_weight,
        )

    def forward(
        self,
        after_outs: torch.Tensor,
        before_outs: torch.Tensor,
        logits: torch.Tensor,
        ys: torch.Tensor,
        labels: torch.Tensor,
        olens: torch.Tensor,
    ):
        """Forward.

        Args:
            after_outs (Tensor): Batch of outputs after postnets (B, Lmax, odim).
            before_outs (Tensor): Batch of outputs before postnets (B, Lmax, odim).
            logits (Tensor): Batch of stop logits (B, Lmax).
            ys (Tensor): Batch of padded target features (B, Lmax, odim).
            labels (LongTensor): Batch of the sequences of stop token labels (B, Lmax).
            olens (LongTensor): Batch of the lengths of each target (B,).

        Returns:
            A 4-tuple ``(loss, l1_loss, mse_loss, bce_loss)``:
                Tensor: Total loss selected by ``loss_type`` (always includes BCE).
                Tensor: L1 loss value.
                Tensor: Mean square error loss value.
                Tensor: Binary cross entropy loss value.
            When ``weight`` is not > 0, ``(None, None, None, None)`` instead.

        Raises:
            ValueError: If ``loss_type`` is not one of "L1+L2", "L1", "L2".
        """
        # Written as ``not > 0`` (rather than ``<= 0``) so a NaN weight is
        # treated as disabled, exactly like the original ``if weight > 0`` gate.
        if not self.weight > 0:
            return None, None, None, None

        # Validate before doing the (comparatively expensive) loss computation.
        if self.loss_type not in self._LOSS_TYPES:
            raise ValueError(f"unknown --loss-type {self.loss_type}")

        l1_loss, mse_loss, bce_loss = self.loss(
            after_outs, before_outs, logits, ys, labels, olens
        )
        if self.loss_type == "L1+L2":
            loss = l1_loss + mse_loss + bce_loss
        elif self.loss_type == "L1":
            loss = l1_loss + bce_loss
        else:  # "L2"
            loss = mse_loss + bce_loss
        return loss, l1_loss, mse_loss, bce_loss