espnet2.gan_tts.vits.loss.KLDivergenceLoss
class espnet2.gan_tts.vits.loss.KLDivergenceLoss(*args, **kwargs)
Bases: torch.nn.Module
KL divergence loss.
The constructor is inherited from torch.nn.Module: it initializes internal Module state (shared by both nn.Module and ScriptModule) and takes no loss-specific arguments.
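For reference, the quantity this loss is expected to compute can be written as below. This is an assumption based on the KL term in the original VITS implementation, not something this page states. Write $\ell^{q}$ for `logs_q`, $\ell^{p}$ for `logs_p`, $\mu^{p}$ for `m_p`, $z^{p}$ for `z_p`, and $m_{b,t}$ for `z_mask`. The result is a single-sample estimate of $\mathrm{KL}(q \,\|\, p)$: the terms $-\ell^{q} - \tfrac{1}{2}$ are the expectation of $\log q$ under $q$, while $\ell^{p}$ plus the squared term are $-\log p$ evaluated at the sample $z^{p}$. Because of this, individual terms can be negative.

```latex
\mathrm{kl}_{b,h,t} = \ell^{p}_{b,h,t} - \ell^{q}_{b,h,t} - \tfrac{1}{2}
  + \tfrac{1}{2}\bigl(z^{p}_{b,h,t} - \mu^{p}_{b,h,t}\bigr)^{2}\exp\bigl(-2\,\ell^{p}_{b,h,t}\bigr),
\qquad
\mathcal{L}_{\mathrm{kl}} = \frac{\sum_{b,h,t} m_{b,t}\,\mathrm{kl}_{b,h,t}}{\sum_{b,t} m_{b,t}}
```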
forward(z_p: Tensor, logs_q: Tensor, m_p: Tensor, logs_p: Tensor, z_mask: Tensor) → Tensor
Calculate KL divergence loss.
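- Parameters:
  - z_p (Tensor) – Flow hidden representation (B, H, T_feats).
  - logs_q (Tensor) – Posterior encoder projected log-scale, i.e. log standard deviation (B, H, T_feats).
  - m_p (Tensor) – Expanded text encoder projected mean (B, H, T_feats).
  - logs_p (Tensor) – Expanded text encoder projected log-scale, i.e. log standard deviation (B, H, T_feats).
  - z_mask (Tensor) – Mask tensor (B, 1, T_feats), 1 for valid frames and 0 for padding.
- Returns: KL divergence loss, a scalar obtained by summing over the unmasked elements and dividing by the number of valid frames.
- Return type: Tensor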
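The computation in `forward` can be sketched without PyTorch as shown here. This is a minimal, dependency-free illustration, assuming the formula from the original VITS code: a per-element log-scale difference plus a scaled squared error, masked, then divided by the number of valid frames. The function name `kl_divergence_loss` is hypothetical, and so is the use of nested lists shaped (B, H, T_feats) as stand-ins for the tensors. This is not ESPnet's code.

```python
import math
from typing import List

# Hypothetical stand-in for a (B, H, T) tensor: nested Python lists of floats.
Tensor3D = List[List[List[float]]]


def kl_divergence_loss(
    z_p: Tensor3D,
    logs_q: Tensor3D,
    m_p: Tensor3D,
    logs_p: Tensor3D,
    z_mask: Tensor3D,  # shaped (B, 1, T): 1.0 for valid frames, 0.0 for padding
) -> float:
    """Masked single-sample KL estimate, normalized by the number of valid frames."""
    total = 0.0
    for b in range(len(z_p)):
        for h in range(len(z_p[b])):
            for t in range(len(z_p[b][h])):
                mask = z_mask[b][0][t]  # one mask value shared by all H channels
                # Expectation of log q under q (up to a constant), minus log-scale of p.
                kl = logs_p[b][h][t] - logs_q[b][h][t] - 0.5
                # Squared error of the flow sample against the prior mean, scaled by 1/(2 sigma_p^2).
                kl += 0.5 * (z_p[b][h][t] - m_p[b][h][t]) ** 2 * math.exp(-2.0 * logs_p[b][h][t])
                total += kl * mask
    # Divide by the sum of the (B, 1, T) mask, i.e. the count of valid frames (not B*H*T).
    num_valid = sum(sum(row) for batch in z_mask for row in batch)
    return total / num_valid


if __name__ == "__main__":
    # One utterance, one channel, two frames; the first sample sits one std from the prior mean.
    loss = kl_divergence_loss(
        z_p=[[[1.0, 0.0]]],
        logs_q=[[[0.0, 0.0]]],
        m_p=[[[0.0, 0.0]]],
        logs_p=[[[0.0, 0.0]]],
        z_mask=[[[1.0, 1.0]]],
    )
    print(loss)  # -0.25
```

Because the denominator counts frames rather than all B·H·T_feats elements, the value scales with the hidden size H under this assumption. In ESPnet itself, the module would be constructed without arguments and called as `KLDivergenceLoss()(z_p, logs_q, m_p, logs_p, z_mask)` on tensors with the shapes listed above.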