# Source code for espnet2.asvspoof.loss.oc_softmax_loss

import torch

from espnet2.asvspoof.loss.abs_loss import AbsASVSpoofLoss


class ASVSpoofOCSoftmaxLoss(AbsASVSpoofLoss):
    """One-class softmax (OC-softmax) loss for ASV spoofing.

    A single learnable center vector (``self.center``, shape [1, enc_dim]) is
    compared with each utterance embedding by cosine similarity ``s``. Each
    sample contributes

        softplus(alpha * (m_real - s))   if the sample is real
        softplus(alpha * (s - m_fake))   if the sample is fake

    and the loss is the mean of these terms over the batch. Real samples are
    therefore pushed to score above ``m_real`` and fake samples below
    ``m_fake``.

    NOTE(review): assumes ``label == 1`` marks real (bona fide) speech and
    ``label == 0`` marks spoofed speech, matching the upstream ESPnet
    implementation -- confirm against the data pipeline.
    """

    def __init__(
        self,
        weight: float = 1.0,
        enc_dim: int = 128,
        m_real: float = 0.5,
        m_fake: float = 0.2,
        alpha: float = 20.0,
    ):
        """Initialize the loss.

        Args:
            weight (float): loss weight; only stored as ``self.weight``, it is
                not applied inside this module.
            enc_dim (int): dimension of the encoder embedding and the center.
            m_real (float): margin the real samples' cosine score should exceed.
            m_fake (float): margin the fake samples' cosine score should stay
                below.
            alpha (float): scale applied to the margin-adjusted scores before
                the softplus.
        """
        # Zero-argument super() is required: the single-argument form
        # ``super(ASVSpoofOCSoftmaxLoss)`` yields an unbound proxy and never
        # calls the parent class's __init__, so registering the Parameter and
        # submodule below would fail.
        super().__init__()
        self.weight = weight
        self.feat_dim = enc_dim
        self.m_real = m_real
        self.m_fake = m_fake
        self.alpha = alpha
        self.center = torch.nn.Parameter(torch.randn(1, self.feat_dim))
        # Re-initialize the center in place; 0.25 is kaiming_uniform_'s ``a``
        # (negative-slope) argument.
        torch.nn.init.kaiming_uniform_(self.center, 0.25)
        self.softplus = torch.nn.Softplus()

    def _cosine_scores(self, emb: torch.Tensor) -> torch.Tensor:
        """Return cosine similarity between time-pooled embeddings and the center.

        Args:
            emb (torch.Tensor): encoder embedding output [Batch, T, enc_dim]
        Returns:
            torch.Tensor: scores of shape [Batch, 1]
        """
        emb = torch.mean(emb, dim=1)  # pool over T -> [Batch, enc_dim]
        w = torch.nn.functional.normalize(self.center, p=2, dim=1)  # [1, enc_dim]
        x = torch.nn.functional.normalize(emb, p=2, dim=1)  # [Batch, enc_dim]
        return x @ w.transpose(0, 1)

    def forward(self, label: torch.Tensor, emb: torch.Tensor, **kwargs):
        """Forward.

        Args:
            label (torch.Tensor): ground truth label [Batch, 1] (or [Batch]);
                expected to hold 1 for real and 0 for fake samples.
            emb (torch.Tensor): encoder embedding output [Batch, T, enc_dim]
        Returns:
            torch.Tensor: scalar OC-softmax loss averaged over the batch.
        """
        scores = self._cosine_scores(emb)  # [Batch, 1]
        # Align the label with the scores so a [Batch] label cannot broadcast
        # against [Batch, 1] scores into a [Batch, Batch] matrix.
        is_real = label.reshape(scores.shape) == 1
        # Real samples are penalized when their score drops below m_real;
        # fake samples when their score rises above m_fake.
        margin_scores = torch.where(
            is_real, self.m_real - scores, scores - self.m_fake
        )
        loss = self.softplus(self.alpha * margin_scores).mean()
        return loss

    def score(self, emb: torch.Tensor):
        """Prediction.

        Args:
            emb (torch.Tensor): encoder embedding output [Batch, T, enc_dim]
        Returns:
            torch.Tensor: cosine score against the center, [Batch, 1]; the
                loss trains real samples toward higher scores.
        """
        return self._cosine_scores(emb)