espnet2.gan_tts.jets.loss.VarianceLoss
class espnet2.gan_tts.jets.loss.VarianceLoss(use_masking: bool = True, use_weighted_masking: bool = False)
Bases: Module
Initialize JETS variance loss module.
- Parameters:
- use_masking (bool) – Whether to apply masking for padded part in loss calculation.
- use_weighted_masking (bool) – Whether to apply weighted masking in loss calculation.
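The two flags select different reductions over a padded batch. The sketch below is a minimal pure-Python illustration of the idea, not ESPnet's implementation. It assumes a squared-error loss and uses hypothetical helper names (masked_mse, weighted_masked_mse). With plain masking, every non-padded element in the batch counts equally. With weighted masking, each utterance is averaged first, so short and long utterances contribute equally.

```python
# Hypothetical helpers illustrating use_masking vs use_weighted_masking
# on padded (B, T_text) batches; not ESPnet's actual code.
from typing import List

Batch = List[List[float]]


def masked_mse(preds: Batch, targets: Batch, ilens: List[int]) -> float:
    """use_masking=True: mean over every non-padded element in the batch."""
    total, count = 0.0, 0
    for p_row, t_row, n in zip(preds, targets, ilens):
        # Only the first n positions are real tokens; the rest is padding.
        for p, t in zip(p_row[:n], t_row[:n]):
            total += (p - t) ** 2
            count += 1
    return total / count


def weighted_masked_mse(preds: Batch, targets: Batch, ilens: List[int]) -> float:
    """use_weighted_masking=True: mean within each utterance, then over the batch."""
    per_utt = []
    for p_row, t_row, n in zip(preds, targets, ilens):
        sq = [(p - t) ** 2 for p, t in zip(p_row[:n], t_row[:n])]
        per_utt.append(sum(sq) / n)
    return sum(per_utt) / len(per_utt)


# Second utterance has length 2; its padded values (9.0) are ignored.
preds = [[1.0, 1.0, 1.0, 1.0], [0.0, 0.0, 9.0, 9.0]]
targets = [[0.0] * 4, [0.0] * 4]
ilens = [4, 2]
print(masked_mse(preds, targets, ilens))           # 4 errors of 1.0 over 6 tokens
print(weighted_masked_mse(preds, targets, ilens))  # (1.0 + 0.0) / 2 = 0.5
```

Here the longer utterance dominates the plain masked mean (about 0.667), while the weighted version gives both utterances equal say (0.5).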
forward(d_outs: Tensor, ds: Tensor, p_outs: Tensor, ps: Tensor, e_outs: Tensor, es: Tensor, ilens: Tensor) → Tuple[Tensor, Tensor, Tensor, Tensor]
Calculate forward propagation.
- Parameters:
- d_outs (Tensor) – Batch of outputs of duration predictor in the log domain (B, T_text).
- ds (LongTensor) – Batch of durations (B, T_text).
- p_outs (Tensor) – Batch of outputs of pitch predictor (B, T_text, 1).
- ps (Tensor) – Batch of target token-averaged pitch (B, T_text, 1).
- e_outs (Tensor) – Batch of outputs of energy predictor (B, T_text, 1).
- es (Tensor) – Batch of target token-averaged energy (B, T_text, 1).
- ilens (LongTensor) – Batch of the lengths of each input (B,).
- Returns:
- Tensor – Duration predictor loss value.
- Tensor – Pitch predictor loss value.
- Tensor – Energy predictor loss value.
- Return type: Tuple[Tensor, Tensor, Tensor]
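The following pure-Python sketch shows how the three returned values relate to the inputs, assuming use_masking=True. It also assumes, as in ESPnet's FastSpeech-style duration loss, that the duration target is compared in the log domain as log(ds + 1.0) while pitch and energy use a plain squared error. The function name variance_losses is hypothetical, and the trailing singleton dimension of the pitch and energy tensors is dropped for brevity.

```python
# Hypothetical sketch of the three losses returned by forward(), assuming
# use_masking=True and a log-domain duration target log(d + 1.0).
import math
from typing import List, Tuple

Batch = List[List[float]]


def _masked_mean_sq(preds: Batch, targets: Batch, ilens: List[int]) -> float:
    # Mean squared error over the first ilens[b] positions of each row.
    errs = [
        (p - t) ** 2
        for p_row, t_row, n in zip(preds, targets, ilens)
        for p, t in zip(p_row[:n], t_row[:n])
    ]
    return sum(errs) / len(errs)


def variance_losses(
    d_outs: Batch, ds: List[List[int]],
    p_outs: Batch, ps: Batch,
    e_outs: Batch, es: Batch,
    ilens: List[int],
) -> Tuple[float, float, float]:
    # Durations: log-domain predictions vs. log(ds + 1.0) (assumed offset).
    log_ds = [[math.log(d + 1.0) for d in row] for row in ds]
    duration_loss = _masked_mean_sq(d_outs, log_ds, ilens)
    # Pitch and energy: token-averaged targets, shapes (B, T_text) here.
    pitch_loss = _masked_mean_sq(p_outs, ps, ilens)
    energy_loss = _masked_mean_sq(e_outs, es, ilens)
    return duration_loss, pitch_loss, energy_loss


dur, pitch, energy = variance_losses(
    [[math.log(3.0), math.log(2.0)]], [[2, 1]],  # perfect duration guesses
    [[1.0, 2.0]], [[1.0, 0.0]],                  # one pitch error of 2.0
    [[0.5, 0.5]], [[0.5, 0.5]],                  # perfect energy guesses
    [2],
)
print(dur, pitch, energy)
```

In a training loop, the three values are typically summed (optionally scaled by a weight, for example a hypothetical lambda_var) into a single variance term of the generator loss.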