espnet2.ssl.espnet_model.ESPnetSSLModel
class espnet2.ssl.espnet_model.ESPnetSSLModel(frontend: AbsFrontend | None, specaug: AbsSpecAug | None, normalize: AbsNormalize | None, preencoder: AbsPreEncoder | None, encoder: AbsEncoder, losses: List[AbsSSLLoss], util_attributes: Set[str], required_inputs: Set[str], util_modules: ModuleDict, token_list: Tuple[str, ...] | List[str] | None = None, extract_feats_in_collect_stats: bool = True, **kwargs)
Bases: AbsESPnetModel
An encoder-only SSL model.
The following SSL objectives are supported or planned:
- HuBERT
- Data2Vec (in development)
- DinoSR (in development)
- wav2vec 2.0 (TODO)
- w2v-BERT (TODO)
- BEST-RQ (TODO)
- Flow Matching (TODO)
Models can be trained with multiple objectives by adding multiple entries under loss_conf in the training configuration.
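The idea of training with several objectives at once can be illustrated with a small, library-independent sketch. It assumes that each objective returns a scalar loss plus a dictionary of statistics and that the per-objective losses are simply summed; the combination rule is not stated in this reference. All names below (mean_square_loss, mean_abs_loss, combine_objectives) are hypothetical and are not part of ESPnet.

```python
from typing import Callable, Dict, List, Tuple

# Hypothetical stand-in for an SSL objective: maps features to (loss, stats).
LossFn = Callable[[List[float]], Tuple[float, Dict[str, float]]]


def mean_square_loss(feats: List[float]) -> Tuple[float, Dict[str, float]]:
    """Toy objective 1: mean of squared feature values."""
    loss = sum(x * x for x in feats) / len(feats)
    return loss, {"loss_mean_square": loss}


def mean_abs_loss(feats: List[float]) -> Tuple[float, Dict[str, float]]:
    """Toy objective 2: mean of absolute feature values."""
    loss = sum(abs(x) for x in feats) / len(feats)
    return loss, {"loss_mean_abs": loss}


def combine_objectives(
    feats: List[float], losses: List[LossFn]
) -> Tuple[float, Dict[str, float]]:
    """Evaluate every objective on the same features and sum the losses.

    Assumption: an unweighted sum. A real setup might weight each term.
    """
    total = 0.0
    stats: Dict[str, float] = {}
    for loss_fn in losses:
        loss, loss_stats = loss_fn(feats)
        total += loss
        stats.update(loss_stats)  # keep per-objective values for logging
    stats["loss"] = total
    return total, stats


total, stats = combine_objectives(
    [1.0, -2.0, 3.0], [mean_square_loss, mean_abs_loss]
)
```

Keeping the per-objective values in the statistics dictionary makes it possible to monitor each objective separately while optimizing the combined loss.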
Initializes internal Module state, shared by both nn.Module and ScriptModule.
collect_feats(speech: Tensor, speech_lengths: Tensor, text: Tensor, text_lengths: Tensor, **kwargs) → Dict[str, Tensor]
encode(speech: Tensor, speech_lengths: Tensor, text: Tensor | None = None, text_lengths: Tensor | None = None, use_final_output: bool = True) → Dict
forward(speech: Tensor, speech_lengths: Tensor, text: Tensor | None = None, text_lengths: Tensor | None = None, **kwargs) → Tuple[Tensor, Dict[str, Tensor], Tensor]
Defines the computation performed at every call.
Should be overridden by all subclasses.
NOTE
Although the recipe for the forward pass needs to be defined within this function, call the Module instance afterwards rather than forward() directly: calling the instance runs the registered hooks, whereas calling forward() directly silently ignores them.
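The difference between calling the Module instance and calling forward() directly can be seen with a generic PyTorch module. This is a minimal sketch that uses a toy TinyModel (a hypothetical class, not ESPnetSSLModel) and the standard register_forward_hook method of torch.nn.Module.

```python
import torch
from torch import nn


class TinyModel(nn.Module):
    """Toy module used only to demonstrate hook behaviour."""

    def __init__(self) -> None:
        super().__init__()
        self.linear = nn.Linear(4, 2)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.linear(x)


model = TinyModel()
calls = []  # output shapes recorded by the hook

# The hook returns None (list.append returns None), so the output is unchanged.
handle = model.register_forward_hook(
    lambda module, inputs, output: calls.append(tuple(output.shape))
)

x = torch.zeros(3, 4)

model(x)  # goes through __call__, so the hook runs
hooks_after_call = len(calls)

model.forward(x)  # bypasses __call__, so the hook is skipped
hooks_after_forward = len(calls)

handle.remove()  # detach the hook when it is no longer needed
```

After both calls the hook has fired only once, for the model(x) call.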
inference_encode(speech: Tensor, speech_lengths: Tensor, use_mask: bool = False, use_final_output: bool = True) → Tuple[List[Tensor], Tensor]