espnet2.spk.loss.aamsoftmax_subcenter_intertopk.ArcMarginProduct_intertopk_subcenter
class espnet2.spk.loss.aamsoftmax_subcenter_intertopk.ArcMarginProduct_intertopk_subcenter(nout: int, nclasses: int, scale: float = 32.0, margin: float = 0.2, easy_margin: bool = False, K: int = 3, mp: float = 0.06, k_top: int = 5, do_lm: bool = False)
Bases: AbsLoss
ArcFace loss (AAMSoftmax loss) with Inter-TopK penalty and Sub-center.
This loss function combines three techniques:
- ArcFace: Additive angular margin loss for better feature discrimination
- Sub-center: Multiple prototypes per class to handle intra-class variation
- Inter-TopK: Additional penalty on hardest negative samples
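To make the interplay of the three techniques concrete, the following is a minimal pure-Python sketch for a single sample. It is not the ESPnet implementation: the function name `penalized_logits` and its input layout (one list of K sub-center cosines per class) are assumptions made for illustration. The sketch also leaves out the `easy_margin` variant and the fallback used when the angle plus margin exceeds pi.

```python
import math


def penalized_logits(sub_cos, target, scale=32.0, margin=0.2, mp=0.06, k_top=5):
    """Hypothetical helper: scaled logits for one sample.

    sub_cos: list over classes, each a list of K cosine similarities
             between the embedding and that class's sub-centers.
    target:  index of the ground-truth class.
    """
    # Sub-center: keep only the best-matching sub-center of every class.
    cos = [max(per_class) for per_class in sub_cos]
    sin = [math.sqrt(max(0.0, 1.0 - c * c)) for c in cos]

    # Inter-TopK: the k_top non-target classes with the highest cosine.
    negatives = [i for i in range(len(cos)) if i != target]
    hard = set(sorted(negatives, key=lambda i: cos[i], reverse=True)[:k_top])

    logits = []
    for i, (c, s) in enumerate(zip(cos, sin)):
        if i == target:
            # ArcFace: cos(theta + margin) lowers the target logit.
            value = c * math.cos(margin) - s * math.sin(margin)
        elif i in hard:
            # Penalty: cos(theta - mp) raises the hard-negative logit.
            value = c * math.cos(mp) + s * math.sin(mp)
        else:
            value = c
        logits.append(scale * value)
    return logits
```

Feeding these logits to a standard softmax cross-entropy gives a loss that is harder to minimize for the target class and for the closest competing classes, which is the intended effect of the margin and the penalty.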
Reference:
- Multi-Query Multi-Head Attention Pooling and Inter-TopK Penalty for Speaker Verification. https://arxiv.org/pdf/2110.05042.pdf
- Sub-Center ArcFace: Boosting Face Recognition by Large-Scale Noisy Web Faces. https://www.ecva.net/papers/eccv_2020/papers_ECCV/papers/123560715.pdf
- Parameters:
- nout – Dimension of input features (embedding size)
- nclasses – Number of output classes
- scale – Feature scaling factor
- margin – Angular margin for positive samples
- easy_margin – Whether to use easy margin variant
- K – Number of sub-centers per class
- mp – Margin penalty for hard negative samples
- k_top – Number of hardest negative samples to penalize
- do_lm – Whether to enable Large Margin Fine-tuning mode
Initializes internal Module state, shared by both nn.Module and ScriptModule.
forward(input: Tensor, label: Tensor | None = None) → Tuple[Tensor | None, Tensor | None, Tensor]
Forward pass of ArcFace (AAMSoftmax) with sub-center and inter-topk penalty.
- Parameters:
- input – Input embeddings, shape (batch_size, embedding_dim)
- label – Ground truth labels, shape (batch_size,)
- Returns: loss: Cross-entropy loss with angular margins; accuracy: Classification accuracy; pred_lids: Predicted class indices
- Return type: Tuple[Tensor | None, Tensor | None, Tensor]
update(margin: float = 0.2)
Update margin and related trigonometric values during training.
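As an illustration of what "related trigonometric values" usually means in ArcFace-style losses, the sketch below recomputes the constants that such implementations commonly cache whenever the margin changes. The helper name `margin_constants` and the dictionary keys are assumptions; the attributes ESPnet actually stores may differ, and any rescaling of `mp` alongside the margin is not shown.

```python
import math


def margin_constants(margin: float = 0.2) -> dict:
    """Hypothetical helper: constants derived from the angular margin."""
    return {
        # Used to expand cos(theta + margin) without calling acos.
        "cos_m": math.cos(margin),
        "sin_m": math.sin(margin),
        # Threshold on cos(theta): below it, theta + margin would exceed pi.
        "th": math.cos(math.pi - margin),
        # Linear penalty commonly subtracted from cos(theta) past that threshold.
        "mm": math.sin(math.pi - margin) * margin,
    }
```

Caching these values avoids recomputing them on every forward pass, which matters when a margin scheduler calls `update` with a new margin over the course of training.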