espnet2.uasr.espnet_model.ESPnetUASRModel
class espnet2.uasr.espnet_model.ESPnetUASRModel(frontend: AbsFrontend | None, segmenter: AbsSegmenter | None, generator: AbsGenerator, discriminator: AbsDiscriminator, losses: Dict[str, AbsUASRLoss], kenlm_path: str | None, token_list: list | None, max_epoch: int | None, vocab_size: int, cfg: Dict | None = None, pad: int = 1, sil_token: str = '<SIL>', sos_token: str = '<s>', eos_token: str = '</s>', skip_softmax: str2bool = False, use_gumbel: str2bool = False, use_hard_gumbel: str2bool = True, min_temperature: float = 0.1, max_temperature: float = 2.0, decay_temperature: float = 0.99995, use_collected_training_feats: str2bool = False)
Bases: AbsESPnetModel
Unsupervised ASR model.
The implementation is adapted from the FAIRSEQ wav2vec-U example: https://github.com/facebookresearch/fairseq/tree/main/examples/wav2vec/unsupervised
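The constructor wires together pluggable frontend, segmenter, generator, and discriminator components. The following is a minimal, self-contained sketch of the generator/discriminator wiring this class encapsulates, using toy pure-PyTorch stand-ins; ToyGenerator and ToyDiscriminator are hypothetical illustrations, not ESPnet's AbsGenerator/AbsDiscriminator interfaces, and all layer choices are assumptions::

    import torch
    import torch.nn as nn

    class ToyGenerator(nn.Module):
        """Maps frame-level features to per-frame phone logits (downsampled 2x)."""

        def __init__(self, feat_dim: int, vocab_size: int):
            super().__init__()
            self.proj = nn.Conv1d(feat_dim, vocab_size, kernel_size=4, stride=2, padding=1)

        def forward(self, feats: torch.Tensor) -> torch.Tensor:
            # feats: (batch, time, feat_dim) -> logits: (batch, time // 2, vocab)
            return self.proj(feats.transpose(1, 2)).transpose(1, 2)

    class ToyDiscriminator(nn.Module):
        """Scores a phone-distribution sequence as real (text) or generated."""

        def __init__(self, vocab_size: int):
            super().__init__()
            self.net = nn.Sequential(
                nn.Conv1d(vocab_size, 64, kernel_size=5, padding=2),
                nn.GELU(),
                nn.Conv1d(64, 1, kernel_size=5, padding=2),
            )

        def forward(self, dists: torch.Tensor) -> torch.Tensor:
            # dists: (batch, time, vocab) -> one realness score per utterance
            return self.net(dists.transpose(1, 2)).mean(dim=(1, 2))

    vocab_size = 44  # e.g., the size of a phone inventory
    generator = ToyGenerator(feat_dim=512, vocab_size=vocab_size)
    discriminator = ToyDiscriminator(vocab_size)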
collect_feats(speech: Tensor, speech_lengths: Tensor, text: Tensor | None = None, text_lengths: Tensor | None = None, **kwargs) → Dict[str, Tensor]
Collects input features and their lengths, e.g., for the statistics-collection stage that estimates feature normalization statistics.
encode(speech: Tensor, speech_lengths: Tensor) → Tuple[Tensor, Tensor]
Applies the frontend (and the segmenter, if one is configured) to the input speech and returns the resulting features and their lengths.
forward(speech: Tensor, speech_lengths: Tensor, text: Tensor | None = None, text_lengths: Tensor | None = None, pseudo_labels: Tensor | None = None, pseudo_labels_lengths: Tensor | None = None, do_validation: str2bool | None = False, print_hyp: str2bool | None = False, **kwargs) → Tuple[Tensor, Dict[str, Tensor], Tensor]
Frontend + Segmenter + Generator + Discriminator + Calc Loss

Args:
    speech: input speech tensor (Batch, Length, ...)
    speech_lengths: lengths of each speech sequence in the batch (Batch,)
    text: reference text tensor, if available (Batch, Length)
    text_lengths: lengths of each text sequence (Batch,)
    pseudo_labels: pseudo phone labels used during training, if available (Batch, Length)
    pseudo_labels_lengths: lengths of each pseudo label sequence (Batch,)
    do_validation: whether to additionally run validation scoring on this batch
    print_hyp: whether to print decoded hypotheses during validation

Returns:
    A tuple (loss, stats, weight): the scalar loss for the current step, a dictionary of statistics to log, and the batch-size weight used when accumulating statistics across devices.
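As a hedged illustration of what the alternating loss computation amounts to, the sketch below reuses the toy modules defined above with plain binary cross-entropy GAN losses; the real forward() composes its configured losses dictionary (passed via the losses argument), so the specific loss terms here are assumptions::

    import torch
    import torch.nn.functional as F

    feats = torch.randn(2, 100, 512)          # stand-in for frontend output
    fake = generator(feats).softmax(dim=-1)   # generated phone distributions
    real = F.one_hot(
        torch.randint(0, vocab_size, (2, 50)), vocab_size
    ).float()                                 # one-hot "real" phone sequences

    d_real = discriminator(real)
    d_fake = discriminator(fake)

    # Discriminative step: push real scores toward 1, generated scores toward 0.
    loss_d = F.binary_cross_entropy_with_logits(
        d_real, torch.ones_like(d_real)
    ) + F.binary_cross_entropy_with_logits(d_fake, torch.zeros_like(d_fake))

    # Generative step: make generated sequences look real to the discriminator.
    loss_g = F.binary_cross_entropy_with_logits(d_fake, torch.ones_like(d_fake))

    stats = {"loss_d": loss_d.detach(), "loss_g": loss_g.detach()}
    weight = torch.tensor(feats.size(0))      # ESPnet-style batch-size weight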
get_optim_index()
Returns the index of the optimizer to use for the current update, alternating between the generator and discriminator optimizers.
inference(speech: Tensor, speech_lengths: Tensor)
Runs the model on input speech and returns the generator output for decoding.
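With the toy generator above, inference amounts to taking the per-frame argmax of the generator output; the real method additionally applies the frontend/segmenter and handles length bookkeeping, so this is an assumed simplification::

    import torch

    feats = torch.randn(1, 100, 512)        # stand-in for frontend output
    logits = generator(feats)               # (1, time', vocab)
    hyp = logits.argmax(dim=-1).squeeze(0)  # predicted token indices per frame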
is_discriminative_step()
Returns True if the current update trains the discriminator rather than the generator.
property number_updates
The number of parameter updates performed so far, used to schedule the alternation between generative and discriminative steps.
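A sketch of the alternating update pattern that get_optim_index(), is_discriminative_step(), and number_updates support, again reusing the toy modules from above. The exact mapping of update parity to generator versus discriminator steps is internal to the model, so the one-to-one alternation assumed here is illustrative only::

    import torch
    import torch.nn.functional as F

    opt_g = torch.optim.Adam(generator.parameters(), lr=1e-4)
    opt_d = torch.optim.Adam(discriminator.parameters(), lr=1e-4)
    optimizers = [opt_g, opt_d]  # assumed order: generator first

    number_updates = 0
    for _ in range(4):
        discriminative = number_updates % 2 == 1  # cf. is_discriminative_step()
        optim = optimizers[number_updates % 2]    # cf. get_optim_index()

        feats = torch.randn(2, 100, 512)
        fake = generator(feats).softmax(dim=-1)

        if discriminative:
            real = F.one_hot(
                torch.randint(0, vocab_size, (2, 50)), vocab_size
            ).float()
            # Detach so the discriminator step does not backprop into the generator.
            d_real, d_fake = discriminator(real), discriminator(fake.detach())
            loss = F.binary_cross_entropy_with_logits(
                d_real, torch.ones_like(d_real)
            ) + F.binary_cross_entropy_with_logits(d_fake, torch.zeros_like(d_fake))
        else:
            d_fake = discriminator(fake)
            loss = F.binary_cross_entropy_with_logits(d_fake, torch.ones_like(d_fake))

        optim.zero_grad()
        loss.backward()
        optim.step()
        number_updates += 1  # cf. the number_updates property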