espnet2.spk.espnet_model.ESPnetSpeakerModel
class espnet2.spk.espnet_model.ESPnetSpeakerModel(frontend: AbsFrontend | None, specaug: AbsSpecAug | None, normalize: AbsNormalize | None, encoder: AbsEncoder | None, pooling: AbsPooling | None, projector: AbsProjector | None, loss: AbsLoss | None)
Bases: AbsESPnetModel
Speaker embedding extraction model.
Core model for diverse speaker-related tasks (e.g., verification, open-set identification, diarization).
The model architecture mainly comprises an ‘encoder’, a ‘pooling’ layer, and a ‘projector’. In the speaker recognition field, the combination of the three is usually called a ‘speaker_encoder’ (or speaker embedding extractor). We split it into three for flexibility in future extensions:
- ‘encoder’ : extracts frame-level speaker embeddings.
- ‘pooling’ : aggregates them into a single utterance-level embedding.
- ‘projector’ : (optional) additional processing (e.g., one fully-connected layer) to derive the speaker embedding.
Possibly, in the future, ‘pooling’ and/or ‘projector’ can be integrated as a ‘decoder’, depending on the extension for joint usage of different tasks (e.g., ASR, SE, target speaker extraction).
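The encoder, pooling, and projector stages described above can be sketched with plain PyTorch modules. This is only an illustration of the data flow under stated assumptions: the `nn.Linear` stand-ins, the `mean_pooling` helper, the toy dimensions, and the (Batch, Frames, Dim) layout are all hypothetical and are not the ESPnet classes or their actual tensor layout.

```python
import torch
from torch import nn

torch.manual_seed(0)

# Toy dimensions, chosen only for illustration.
batch, frames, feat_dim, hidden_dim, embd_dim = 2, 50, 80, 64, 32

# Hypothetical stand-ins for two of the three components (not ESPnet classes).
encoder = nn.Linear(feat_dim, hidden_dim)    # produces frame-level embeddings
projector = nn.Linear(hidden_dim, embd_dim)  # produces the final speaker embedding


def mean_pooling(frame_level_feats: torch.Tensor) -> torch.Tensor:
    """Average over the time axis: (Batch, Frames, Dim) -> (Batch, Dim)."""
    return frame_level_feats.mean(dim=1)


feats = torch.randn(batch, frames, feat_dim)       # stand-in for acoustic features
frame_level_feats = encoder(feats)                 # shape (2, 50, 64)
utt_level_feat = mean_pooling(frame_level_feats)   # shape (2, 64)
spk_embd = projector(utt_level_feat)               # shape (2, 32)
```

Keeping the stages separate means the pooling function can be swapped (for example, for an attention-based variant) without touching the encoder or the projector.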
Initializes internal Module state, shared by both nn.Module and ScriptModule.
aggregate(frame_level_feats: Tensor) → Tensor
collect_feats(speech: Tensor, speech_lengths: Tensor, spk_labels: Tensor | None = None, **kwargs) → Dict[str, Tensor]
encode_frame(feats: Tensor) → Tensor
extract_feats(speech: Tensor, speech_lengths: Tensor) → Tuple[Tensor, Tensor]
forward(speech: Tensor, spk_labels: Tensor | None = None, task_tokens: Tensor | None = None, extract_embd: bool = False, **kwargs) → Tuple[Tensor, Dict[str, Tensor], Tensor] | Tensor
Feed-forward through encoder layers and aggregate into an utterance-level feature.
- Parameters:
- speech – (Batch, samples)
- speech_lengths – (Batch,)
- extract_embd – a flag; when set to True, the forward pass does not go through the classification head
- spk_labels – (Batch, ) one-hot speaker labels used in the training phase
- task_tokens – (Batch, ) task tokens used in case of token-based training
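The two return shapes in the `forward` signature (a single Tensor, or a three-element tuple) can be illustrated with a toy module. This is a hedged sketch, not the ESPnet implementation: `ToySpeakerModel`, its layers, the use of integer class indices as labels, and the meaning of the third tuple element (here, the batch size) are assumptions made for illustration.

```python
import torch
from torch import nn


class ToySpeakerModel(nn.Module):
    """Hypothetical stand-in that mimics the two return modes of forward()."""

    def __init__(self, embd_dim: int = 16, num_speakers: int = 4):
        super().__init__()
        # Placeholder for the whole encoder + pooling + projector chain.
        self.embed = nn.Linear(1, embd_dim)
        # Classification head, used only when a loss is computed.
        self.classifier = nn.Linear(embd_dim, num_speakers)

    def forward(self, speech, spk_labels=None, extract_embd=False):
        # (Batch, samples) -> (Batch, 1) -> (Batch, embd_dim)
        spk_embd = self.embed(speech.mean(dim=1, keepdim=True))
        if extract_embd:
            # Skip the classification head and return the embedding only.
            return spk_embd
        if spk_labels is None:
            raise ValueError("spk_labels is required to compute the loss")
        logits = self.classifier(spk_embd)
        loss = nn.functional.cross_entropy(logits, spk_labels)
        stats = {"loss": loss.detach()}
        weight = torch.tensor(speech.shape[0])  # assumed: batch size
        return loss, stats, weight


model = ToySpeakerModel()
speech = torch.randn(3, 16000)  # (Batch, samples)

# Inference-style call: only the embedding comes back.
embd = model(speech, extract_embd=True)

# Training-style call: labels are given as class indices of shape (Batch,).
loss, stats, weight = model(speech, spk_labels=torch.tensor([0, 1, 2]))
```

In this sketch, the embedding-only path never touches the classification head, which is why labels are optional when `extract_embd=True`.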
project_spk_embd(utt_level_feat: Tensor) → Tensor