espnet2.uasr.loss.gradient_penalty.UASRGradientPenalty
class espnet2.uasr.loss.gradient_penalty.UASRGradientPenalty(discriminator: AbsDiscriminator, weight: float = 1.0, probabilistic_grad_penalty_slicing: str2bool = False, reduction: str = 'sum')
Bases: AbsUASRLoss
Gradient penalty loss for UASR (unsupervised automatic speech recognition).
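The penalty has the general WGAN-GP form. The discriminator D is evaluated on random interpolations between real and generated samples, and any deviation of its gradient norm from 1 is penalized and scaled by `weight`. The notation below is our own, and it shows only the general form: the axis over which this implementation takes the norm may differ.

```latex
\hat{x}_i = \alpha_i\, x^{\text{real}}_i + (1 - \alpha_i)\, x^{\text{fake}}_i,
\qquad \alpha_i \sim \mathcal{U}(0, 1)

\mathcal{L}_{\text{GP}} = w \sum_{i} \bigl( \lVert \nabla_{\hat{x}_i} D(\hat{x}_i) \rVert_2 - 1 \bigr)^2
```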
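Initializes the loss module.
- Parameters:
  - discriminator – discriminator network applied to the interpolated samples
  - weight – scaling factor applied to the penalty
  - probabilistic_grad_penalty_slicing – if true, real and fake samples are cropped to a common batch size and length at a random offset; otherwise they are cropped from the start
  - reduction – reduction applied to the penalty (default 'sum')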
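The `probabilistic_grad_penalty_slicing` flag in the constructor signature governs how real and fake batches of different sizes are brought to one shape before interpolation. The following is a minimal sketch modeled on the fairseq wav2vec-U recipe, which uses a flag of the same name. It assumes `(batch, time, channels)` tensors. When the flag is off, the sketch keeps the leading elements. When it is on, it takes a window at a random offset. `crop_to` and `align_samples` are hypothetical helpers and are not part of ESPnet.

```python
import torch


def crop_to(sample: torch.Tensor, dim: int, size: int, random_offset: bool) -> torch.Tensor:
    """Crop `sample` along `dim` to `size` elements (hypothetical helper)."""
    extra = sample.size(dim) - size
    if extra <= 0:
        return sample  # already no larger than the target size
    # Random start when random_offset is set, otherwise keep the leading elements.
    start = int(torch.randint(0, extra + 1, (1,)).item()) if random_offset else 0
    return sample.narrow(dim, start, size)


def align_samples(fake: torch.Tensor, real: torch.Tensor, random_offset: bool = False):
    """Crop (batch, time, channels) tensors to a common batch size and length."""
    batch = min(fake.size(0), real.size(0))
    length = min(fake.size(1), real.size(1))
    fake = crop_to(crop_to(fake, 0, batch, random_offset), 1, length, random_offset)
    real = crop_to(crop_to(real, 0, batch, random_offset), 1, length, random_offset)
    return fake, real


fake = torch.randn(4, 50, 8)
real = torch.randn(3, 60, 8)
f, r = align_samples(fake, real, random_offset=True)
print(f.shape, r.shape)  # torch.Size([3, 50, 8]) torch.Size([3, 50, 8])
```

Cropping at a random offset lets every part of a longer sequence contribute to the penalty over many steps, whereas plain truncation always discards the tail.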
forward(fake_sample: Tensor, real_sample: Tensor, is_training: str2bool, is_discrimininative_step: str2bool)
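Compute the gradient penalty from interpolations between `real_sample` and `fake_sample`.
- Parameters:
  - fake_sample – sample produced by the generator
  - real_sample – real sample
  - is_training – whether the model is in a training step
  - is_discrimininative_step – whether the current step updates the discriminator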
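The following is a minimal, self-contained PyTorch sketch of the computation `forward` performs during a discriminator training step. It rests on several assumptions. The inputs are `(batch, time, channels)` tensors that have already been cropped to the same shape. The discriminator is any module that maps such a tensor to scores; here an `nn.Linear` stands in for an ESPnet `AbsDiscriminator`. The gradient norm is taken per sample over all remaining dimensions, as in standard WGAN-GP. `gradient_penalty` is a hypothetical helper and is not the ESPnet API.

```python
import torch
import torch.nn as nn


def gradient_penalty(discriminator: nn.Module, fake: torch.Tensor,
                     real: torch.Tensor, weight: float = 1.0) -> torch.Tensor:
    """WGAN-GP-style penalty on random interpolates of real and fake samples (sketch)."""
    # One interpolation coefficient per batch element, broadcast over time and channels.
    alpha = torch.rand(real.size(0), 1, 1, device=real.device).expand_as(real)
    interpolates = (alpha * real + (1 - alpha) * fake).requires_grad_(True)
    scores = discriminator(interpolates)
    grads = torch.autograd.grad(
        outputs=scores,
        inputs=interpolates,
        grad_outputs=torch.ones_like(scores),
        create_graph=True,  # keep the graph so the penalty can be backpropagated
    )[0]
    # Penalize deviation of each sample's gradient L2 norm from 1, summed over the batch.
    norms = grads.flatten(1).norm(2, dim=1)
    return weight * ((norms - 1) ** 2).sum()


# Toy usage: a linear stand-in discriminator that scores each frame.
disc = nn.Linear(8, 1)
fake = torch.randn(3, 20, 8)
real = torch.randn(3, 20, 8)
loss = gradient_penalty(disc, fake, real, weight=1.0)
loss.backward()  # create_graph=True lets this reach the discriminator parameters
```

Because `create_graph=True` keeps the gradient computation differentiable, the penalty trains the discriminator itself toward unit-norm gradients on points between the two distributions.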