espnet2.tasks.hubert.HubertTask
class espnet2.tasks.hubert.HubertTask
Bases: AbsTask
classmethod add_task_arguments(parser: ArgumentParser)
classmethod build_collate_fn(args: Namespace, train: bool) → Callable[[Collection[Tuple[str, Dict[str, ndarray]]]], Tuple[List[str], Dict[str, Tensor]]]
Return “collate_fn”, a callable object that is passed to DataLoader.
>>> from torch.utils.data import DataLoader
>>> loader = DataLoader(..., collate_fn=cls.build_collate_fn(args, train=True))
In many cases, you can use our common collate_fn.
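As an illustration, a padding collate function of this shape could look like the minimal sketch below. It assumes every field is a numpy array whose first axis is the sequence axis. The name `pad_collate`, its pad defaults, and the `<key>_lengths` entries are assumptions of this sketch, not a description of ESPnet's common collate_fn; note also that the return type above calls for torch tensors, whereas this sketch keeps numpy arrays.

```python
from typing import Dict, List, Sequence, Tuple

import numpy as np


def pad_collate(
    batch: Sequence[Tuple[str, Dict[str, np.ndarray]]],
    float_pad_value: float = 0.0,
    int_pad_value: int = -32768,
) -> Tuple[List[str], Dict[str, np.ndarray]]:
    """Hypothetical padding collate_fn (a sketch, not ESPnet's implementation)."""
    uttids = [uid for uid, _ in batch]
    items = [data for _, data in batch]
    output: Dict[str, np.ndarray] = {}
    for key in items[0]:
        arrays = [item[key] for item in items]
        # Integer fields (e.g. token ids) and float fields (e.g. speech)
        # get different padding values.
        if np.issubdtype(arrays[0].dtype, np.integer):
            pad_value = int_pad_value
        else:
            pad_value = float_pad_value
        # Allocate a (batch, max_len, ...) array filled with the pad value,
        # then copy each sequence into the front of its row.
        max_len = max(a.shape[0] for a in arrays)
        padded = np.full(
            (len(arrays), max_len) + arrays[0].shape[1:], pad_value, dtype=arrays[0].dtype
        )
        for i, a in enumerate(arrays):
            padded[i, : a.shape[0]] = a
        output[key] = padded
        # Keep the unpadded lengths so the model can mask the padding.
        output[key + "_lengths"] = np.array([a.shape[0] for a in arrays], dtype=np.int64)
    return uttids, output


batch = [
    ("utt1", {"speech": np.ones(3, dtype=np.float32)}),
    ("utt2", {"speech": np.ones(5, dtype=np.float32)}),
]
uttids, out = pad_collate(batch)
print(uttids, out["speech"].shape, out["speech_lengths"].tolist())
# ['utt1', 'utt2'] (2, 5) [3, 5]
```

Padding to the longest item in each batch, rather than to a global maximum, keeps memory proportional to the batch actually being processed.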
classmethod build_model(args: Namespace) → HubertPretrainModel | TorchAudioHubertPretrainModel
classmethod build_preprocess_fn(args: Namespace, train: bool) → Callable[[str, Dict[str, array]], Dict[str, ndarray]] | None
class_choices_list : List[[ClassChoices](../train/ClassChoices.md#espnet2.train.class_choices.ClassChoices)] = [<espnet2.train.class_choices.ClassChoices object>, <espnet2.train.class_choices.ClassChoices object>, <espnet2.train.class_choices.ClassChoices object>, <espnet2.train.class_choices.ClassChoices object>, <espnet2.train.class_choices.ClassChoices object>, <espnet2.train.class_choices.ClassChoices object>]
num_optimizers : int = 1
classmethod optional_data_names(train: bool = True, inference: bool = False) → Tuple[str, ...]
Define the optional data names for this task.
This function is used by cls.check_task_requirements(). If your model is defined as follows,
>>> from espnet2.train.abs_espnet_model import AbsESPnetModel
>>> class Model(AbsESPnetModel):
...     def forward(self, input, output, opt=None): pass
then “optional_data_names” should be:
>>> optional_data_names = ('opt',)
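The example follows a simple convention: arguments of forward() without a default value are required data, and arguments with a default are optional. A minimal sketch of that convention using Python's inspect module is shown here. `split_data_names` is a hypothetical helper, and the stand-in `Model` class replaces AbsESPnetModel so the sketch runs without ESPnet installed; it is not taken from ESPnet's check_task_requirements().

```python
import inspect
from typing import Callable, Tuple


def split_data_names(forward: Callable) -> Tuple[Tuple[str, ...], Tuple[str, ...]]:
    """Hypothetical helper: split forward()'s arguments into (required, optional)."""
    required, optional = [], []
    for name, param in inspect.signature(forward).parameters.items():
        # Skip the instance argument and *args / **kwargs catch-alls.
        if name == "self" or param.kind in (param.VAR_POSITIONAL, param.VAR_KEYWORD):
            continue
        # No default value -> the data must always be provided.
        if param.default is inspect.Parameter.empty:
            required.append(name)
        else:
            optional.append(name)
    return tuple(required), tuple(optional)


class Model:
    # Stand-in for an AbsESPnetModel subclass (assumption for this sketch).
    def forward(self, input, output, opt=None):
        pass


print(split_data_names(Model.forward))
# (('input', 'output'), ('opt',))
```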
classmethod required_data_names(train: bool = True, inference: bool = False) → Tuple[str, ...]
Define the required data names for this task.
This function is used by cls.check_task_requirements(). If your model is defined as follows,
>>> from espnet2.train.abs_espnet_model import AbsESPnetModel
>>> class Model(AbsESPnetModel):
...     def forward(self, input, output, opt=None): pass
then “required_data_names” should be:
>>> required_data_names = ('input', 'output')
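A check of the kind these two name lists enable can be sketched as follows: every required name must be present in a batch, and any other key must be one of the optional names. The helper `check_batch_keys` and its error messages are hypothetical; this sketch does not reproduce what cls.check_task_requirements() actually verifies.

```python
from typing import Dict, Iterable


def check_batch_keys(
    batch: Dict[str, object],
    required: Iterable[str],
    optional: Iterable[str] = (),
) -> None:
    """Hypothetical check of a batch's keys against required/optional data names."""
    required_set = set(required)
    allowed = required_set | set(optional)
    keys = set(batch)
    # Every required name must be present.
    missing = required_set - keys
    if missing:
        raise RuntimeError(f"Missing required data: {sorted(missing)}")
    # Anything else must be declared as optional.
    unknown = keys - allowed
    if unknown:
        raise RuntimeError(f"Unexpected data: {sorted(unknown)}")


# Both calls pass: 'opt' may be present or absent.
check_batch_keys({"input": 1, "output": 2}, ("input", "output"), ("opt",))
check_batch_keys({"input": 1, "output": 2, "opt": 3}, ("input", "output"), ("opt",))
```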
trainer
alias of Trainer