# Copyright 2022 Dan Lim
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
"""JETS related loss module for ESPnet2."""
from typing import Tuple
import numpy as np
import torch
import torch.nn.functional as F
from typeguard import typechecked
from espnet.nets.pytorch_backend.fastspeech.duration_predictor import ( # noqa: H301
DurationPredictorLoss,
)
from espnet.nets.pytorch_backend.nets_utils import make_non_pad_mask
class VarianceLoss(torch.nn.Module):
    """Variance-predictor loss module for JETS (duration / pitch / energy)."""

    @typechecked
    def __init__(self, use_masking: bool = True, use_weighted_masking: bool = False):
        """Initialize JETS variance loss module.

        Args:
            use_masking (bool): Whether to apply masking for padded part in loss
                calculation.
            use_weighted_masking (bool): Whether to apply weighted masking in
                loss calculation.

        """
        super().__init__()
        # The two masking modes are mutually exclusive: at most one may be on.
        assert (use_masking != use_weighted_masking) or not use_masking
        self.use_masking = use_masking
        self.use_weighted_masking = use_weighted_masking

        # define criterions
        # With weighted masking the per-element losses are reduced manually in
        # forward(), so the criterions must not reduce ("none"); otherwise let
        # them average over all elements ("mean").
        reduction = "none" if self.use_weighted_masking else "mean"
        self.mse_criterion = torch.nn.MSELoss(reduction=reduction)
        self.duration_criterion = DurationPredictorLoss(reduction=reduction)

    def forward(
        self,
        d_outs: torch.Tensor,
        ds: torch.Tensor,
        p_outs: torch.Tensor,
        ps: torch.Tensor,
        e_outs: torch.Tensor,
        es: torch.Tensor,
        ilens: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Calculate forward propagation.

        Args:
            d_outs (LongTensor): Batch of outputs of duration predictor (B, T_text).
            ds (LongTensor): Batch of durations (B, T_text).
            p_outs (Tensor): Batch of outputs of pitch predictor (B, T_text, 1).
            ps (Tensor): Batch of target token-averaged pitch (B, T_text, 1).
            e_outs (Tensor): Batch of outputs of energy predictor (B, T_text, 1).
            es (Tensor): Batch of target token-averaged energy (B, T_text, 1).
            ilens (LongTensor): Batch of the lengths of each input (B,).

        Returns:
            Tensor: Duration predictor loss value.
            Tensor: Pitch predictor loss value.
            Tensor: Energy predictor loss value.

        """
        # apply mask to remove padded part
        if self.use_masking:
            duration_masks = make_non_pad_mask(ilens).to(ds.device)
            d_outs = d_outs.masked_select(duration_masks)
            ds = ds.masked_select(duration_masks)
            # pitch/energy targets carry a trailing feature dim of size 1,
            # so the mask needs an extra axis to broadcast over it
            pitch_masks = make_non_pad_mask(ilens).unsqueeze(-1).to(ds.device)
            p_outs = p_outs.masked_select(pitch_masks)
            e_outs = e_outs.masked_select(pitch_masks)
            ps = ps.masked_select(pitch_masks)
            es = es.masked_select(pitch_masks)

        # calculate loss
        duration_loss = self.duration_criterion(d_outs, ds)
        pitch_loss = self.mse_criterion(p_outs, ps)
        energy_loss = self.mse_criterion(e_outs, es)

        # make weighted mask and apply it
        if self.use_weighted_masking:
            duration_masks = make_non_pad_mask(ilens).to(ds.device)
            # each utterance contributes equally regardless of its length:
            # normalize per utterance, then average over the batch
            duration_weights = (
                duration_masks.float() / duration_masks.sum(dim=1, keepdim=True).float()
            )
            duration_weights /= ds.size(0)

            # apply weight
            duration_loss = (
                duration_loss.mul(duration_weights).masked_select(duration_masks).sum()
            )
            pitch_masks = duration_masks.unsqueeze(-1)
            pitch_weights = duration_weights.unsqueeze(-1)
            pitch_loss = pitch_loss.mul(pitch_weights).masked_select(pitch_masks).sum()
            energy_loss = (
                energy_loss.mul(pitch_weights).masked_select(pitch_masks).sum()
            )

        return duration_loss, pitch_loss, energy_loss
[docs]class ForwardSumLoss(torch.nn.Module):
"""Forwardsum loss described at https://openreview.net/forum?id=0NQwnnwAORi"""
def __init__(self):
"""Initialize forwardsum loss module."""
super().__init__()
[docs] def forward(
self,
log_p_attn: torch.Tensor,
ilens: torch.Tensor,
olens: torch.Tensor,
blank_prob: float = np.e**-1,
) -> torch.Tensor:
"""Calculate forward propagation.
Args:
log_p_attn (Tensor): Batch of log probability of attention matrix
(B, T_feats, T_text).
ilens (Tensor): Batch of the lengths of each input (B,).
olens (Tensor): Batch of the lengths of each target (B,).
blank_prob (float): Blank symbol probability.
Returns:
Tensor: forwardsum loss value.
"""
B = log_p_attn.size(0)
# a row must be added to the attention matrix to account for
# blank token of CTC loss
# (B,T_feats,T_text+1)
log_p_attn_pd = F.pad(log_p_attn, (1, 0, 0, 0, 0, 0), value=np.log(blank_prob))
loss = 0
for bidx in range(B):
# construct target sequnece.
# Every text token is mapped to a unique sequnece number.
target_seq = torch.arange(1, ilens[bidx] + 1).unsqueeze(0)
cur_log_p_attn_pd = log_p_attn_pd[
bidx, : olens[bidx], : ilens[bidx] + 1
].unsqueeze(
1
) # (T_feats,1,T_text+1)
cur_log_p_attn_pd = F.log_softmax(cur_log_p_attn_pd, dim=-1)
loss += F.ctc_loss(
log_probs=cur_log_p_attn_pd,
targets=target_seq,
input_lengths=olens[bidx : bidx + 1],
target_lengths=ilens[bidx : bidx + 1],
zero_infinity=True,
)
loss = loss / B
return loss