espnet2.gan_tts.hifigan.loss.DiscriminatorAdversarialLoss
class espnet2.gan_tts.hifigan.loss.DiscriminatorAdversarialLoss(average_by_discriminators: bool = True, loss_type: str = 'mse')
Bases: Module
Discriminator adversarial loss module.
Initialize DiscriminatorAdversarialLoss module.
- Parameters:
- average_by_discriminators (bool) – Whether to average the loss over the number of discriminators. Defaults to True.
- loss_type (str) – Loss type, either "mse" or "hinge". Defaults to "mse".
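The two `loss_type` options can be illustrated with a small framework-free sketch. It assumes the conventional formulations: least-squares targets of 1 for real and 0 for fake scores for "mse", and the standard margin form for "hinge". Plain Python lists stand in for tensors, and the helper names `mse_losses` and `hinge_losses` are hypothetical, not part of ESPnet.

```python
from typing import List, Tuple


def mse_losses(real: List[float], fake: List[float]) -> Tuple[float, float]:
    """Least-squares criterion: push real scores toward 1 and fake scores toward 0."""
    real_loss = sum((x - 1.0) ** 2 for x in real) / len(real)
    fake_loss = sum(x ** 2 for x in fake) / len(fake)
    return real_loss, fake_loss


def hinge_losses(real: List[float], fake: List[float]) -> Tuple[float, float]:
    """Hinge criterion: penalize real scores below +1 and fake scores above -1."""
    real_loss = sum(max(0.0, 1.0 - x) for x in real) / len(real)
    fake_loss = sum(max(0.0, 1.0 + x) for x in fake) / len(fake)
    return real_loss, fake_loss


# Hypothetical discriminator scores, chosen only for illustration.
real_scores = [1.0, 0.5]   # scores on ground-truth audio
fake_scores = [0.0, -0.5]  # scores on generated audio

print(mse_losses(real_scores, fake_scores))    # (0.125, 0.125)
print(hinge_losses(real_scores, fake_scores))  # (0.25, 0.75)
```

Under these assumptions, the hinge variant stops penalizing a score once it clears the margin, whereas the least-squares variant keeps pulling every score toward its target.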
forward(outputs_hat: List[List[Tensor]] | List[Tensor] | Tensor, outputs: List[List[Tensor]] | List[Tensor] | Tensor) → Tuple[Tensor, Tensor]
Calculate discriminator adversarial loss.
- Parameters:
- outputs_hat (Union[List[List[Tensor]], List[Tensor], Tensor]) – Discriminator outputs, list of discriminator outputs, or list of lists of discriminator outputs calculated from the generator's output.
- outputs (Union[List[List[Tensor]], List[Tensor], Tensor]) – Discriminator outputs, list of discriminator outputs, or list of lists of discriminator outputs calculated from the ground truth.
- Returns: Tensor: Discriminator real loss value. Tensor: Discriminator fake loss value.
- Return type: Tuple[Tensor, Tensor]
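The following sketch shows one way the three accepted input shapes could be reduced to a `(real_loss, fake_loss)` pair. It is not the ESPnet implementation. It uses plain Python, with a tuple of floats standing in for a tensor and a list standing in for a container. It assumes the "mse" criterion with targets 1 (real) and 0 (fake), and it assumes that, in the list-of-lists case, the last entry of each inner list is the discriminator's final output while earlier entries are intermediate feature maps. The function name `discriminator_adv_loss` is hypothetical.

```python
def mse_real(x):
    """Mean squared distance of real scores from the target 1."""
    return sum((v - 1.0) ** 2 for v in x) / len(x)


def mse_fake(x):
    """Mean squared distance of fake scores from the target 0."""
    return sum(v ** 2 for v in x) / len(x)


def discriminator_adv_loss(outputs_hat, outputs, average_by_discriminators=True):
    """Return (real_loss, fake_loss) for any of the three input shapes."""
    if not isinstance(outputs, list):
        # Single discriminator output (a lone "tensor").
        return mse_real(outputs), mse_fake(outputs_hat)

    real_loss, fake_loss = 0.0, 0.0
    for out_hat, out in zip(outputs_hat, outputs):
        if isinstance(out, list):
            # Assumption: the last entry is the final score,
            # earlier entries are feature maps and are ignored here.
            out_hat, out = out_hat[-1], out[-1]
        real_loss += mse_real(out)
        fake_loss += mse_fake(out_hat)

    if average_by_discriminators:
        # Divide the summed losses by the number of discriminators.
        real_loss /= len(outputs)
        fake_loss /= len(outputs)
    return real_loss, fake_loss


# Two hypothetical discriminators, each returning two scores.
outputs = [(1.0, 1.0), (0.0, 0.0)]      # from ground truth
outputs_hat = [(0.0, 0.0), (1.0, 1.0)]  # from generator output

print(discriminator_adv_loss(outputs_hat, outputs))         # (0.5, 0.5)
print(discriminator_adv_loss(outputs_hat, outputs, False))  # (1.0, 1.0)
```

Note the return order: the real loss comes first and the fake loss second, matching the Returns entry above. With `average_by_discriminators=False`, the per-discriminator losses are summed, so the magnitude grows with the number of discriminators.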