Source code for espnet2.spk.loss.aamsoftmax_subcenter_intertopk

# code from WeSpeaker: https://github.com/wenet-e2e/wespeaker/blob/
# c9ec537b53fe1e04525be74b2550ee95bed3a891/wespeaker/models/projections.py#L243

import math

import torch
import torch.nn as nn
import torch.nn.functional as F

from espnet2.spk.loss.abs_loss import AbsLoss


class ArcMarginProduct_intertopk_subcenter(AbsLoss):
    r"""Implementation of large margin arc distance with inter-topk and sub-center.

    Reference:
        MULTI-QUERY MULTI-HEAD ATTENTION POOLING AND INTER-TOPK PENALTY
        FOR SPEAKER VERIFICATION.
        https://arxiv.org/pdf/2110.05042.pdf

        Sub-center ArcFace: Boosting Face Recognition by Large-Scale Noisy
        Web Faces.
        https://ibug.doc.ic.ac.uk/media/uploads/documents/eccv_1445.pdf

    Args:
        nout: size of each input sample (embedding dimension)
        nclasses: size of each output sample (number of classes)
        scale: norm of input feature
        margin: additive angular margin m in cos(theta + m)
        K: number of sub-centers
        k_top: number of hard samples
        mp: margin penalty of hard samples
        do_lm: whether to do large margin finetuning
    """

    def __init__(
        self,
        nout,
        nclasses,
        scale=32.0,
        margin=0.2,
        easy_margin=False,
        K=3,
        mp=0.06,
        k_top=5,
        do_lm=False,
    ):
        super().__init__(nout)
        self.in_features = nout
        self.out_features = nclasses
        self.scale = scale
        self.margin = margin
        self.do_lm = do_lm

        # intertopk + subcenter
        self.K = K
        if do_lm:  # if doing large margin finetuning, remove the hard sample penalty
            self.mp = 0.0
            self.k_top = 0
        else:
            self.mp = mp
            self.k_top = k_top

        # initial classifier
        self.weight = nn.Parameter(torch.FloatTensor(self.K * nclasses, nout))
        nn.init.xavier_uniform_(self.weight)

        self.easy_margin = easy_margin
        self.cos_m = math.cos(margin)
        self.sin_m = math.sin(margin)
        self.th = math.cos(math.pi - margin)
        self.mm = math.sin(math.pi - margin) * margin
        self.mmm = 1.0 + math.cos(
            math.pi - margin
        )  # this can make the output more continuous
        ########
        self.m = self.margin
        ########
        self.cos_mp = math.cos(0.0)
        self.sin_mp = math.sin(0.0)

        self.ce = nn.CrossEntropyLoss()
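    # Added commentary (not in the upstream WeSpeaker source): the classifier
    # weight stacks K sub-centers per class into one (K * nclasses, nout)
    # matrix. forward() below reshapes the cosine similarities to
    # (batch, nclasses, K) and keeps only the best-matching sub-center per
    # class via a max over the last dimension, which is what makes this a
    # sub-center variant of AAM-softmax.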
    def update(self, margin=0.2):
        self.margin = margin
        self.cos_m = math.cos(margin)
        self.sin_m = math.sin(margin)
        self.th = math.cos(math.pi - margin)
        self.mm = math.sin(math.pi - margin) * margin
        self.m = self.margin
        self.mmm = 1.0 + math.cos(math.pi - margin)

        # the hard sample margin increases proportionally with the margin
        if margin > 0.001:
            mp = self.mp * (margin / 0.2)
        else:
            mp = 0.0
        self.cos_mp = math.cos(mp)
        self.sin_mp = math.sin(mp)
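    # Added commentary (not in the upstream WeSpeaker source): update() is
    # intended to be called by a margin scheduler during training, re-deriving
    # the cached trigonometric terms whenever the margin is annealed. In
    # forward() below, phi applies the angle-addition identity
    # cos(theta + m) = cos(theta) * cos(m) - sin(theta) * sin(m) to the
    # target class, while phi_mp applies cos(theta - mp) to the top-k
    # hardest non-target classes (the inter-topk penalty).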
    def forward(self, input, label):
        if len(label.size()) == 2:
            label = label.squeeze(1)

        cosine = F.linear(
            F.normalize(input), F.normalize(self.weight)
        )  # (batch, out_dim * k)
        cosine = torch.reshape(
            cosine, (-1, self.out_features, self.K)
        )  # (batch, out_dim, k)
        cosine, _ = torch.max(cosine, 2)  # (batch, out_dim)

        sine = torch.sqrt(1.0 - torch.pow(cosine, 2))
        phi = cosine * self.cos_m - sine * self.sin_m
        phi_mp = cosine * self.cos_mp + sine * self.sin_mp

        if self.easy_margin:
            phi = torch.where(cosine > 0, phi, cosine)
        else:
            ########
            # phi = torch.where(cosine > self.th, phi, cosine - self.mm)
            phi = torch.where(cosine > self.th, phi, cosine - self.mmm)
            ########

        one_hot = torch.zeros_like(cosine)
        one_hot.scatter_(1, label.view(-1, 1), 1)

        if self.k_top > 0:
            # topk (j != y_i)
            _, top_k_index = torch.topk(
                cosine - 2 * one_hot, self.k_top
            )  # exclude j = y_i
            top_k_one_hot = input.new_zeros(cosine.size()).scatter_(
                1, top_k_index, 1
            )

            # sum
            output = (
                (one_hot * phi)
                + (top_k_one_hot * phi_mp)
                + ((1.0 - one_hot - top_k_one_hot) * cosine)
            )
        else:
            output = (one_hot * phi) + ((1.0 - one_hot) * cosine)

        output *= self.scale

        loss = self.ce(output, label)
        return loss
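Usage sketch (an addition, not part of the original module): a minimal example
of instantiating and calling this loss. The embedding dimension (192), number
of speakers (1000), and batch size (8) are illustrative assumptions, not
values mandated by ESPnet; shapes follow the forward() signature above.

if __name__ == "__main__":
    # Hypothetical smoke test: 8 speaker embeddings of dimension 192,
    # classified over 1000 speakers with K=3 sub-centers each.
    loss_fn = ArcMarginProduct_intertopk_subcenter(nout=192, nclasses=1000)
    embeddings = torch.randn(8, 192)        # (batch, embedding_dim)
    labels = torch.randint(0, 1000, (8,))   # speaker indices, shape (batch,)
    loss = loss_fn(embeddings, labels)      # scalar cross-entropy loss
    print(loss.item())

    # Optionally anneal the margin mid-training, e.g. from a scheduler:
    loss_fn.update(margin=0.3)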