espnet2.tasks.enh.EnhancementTask
class espnet2.tasks.enh.EnhancementTask
Bases: AbsTask
classmethod add_task_arguments(parser: ArgumentParser)
classmethod build_collate_fn(args: Namespace, train: bool) → Callable[[Collection[Tuple[str, Dict[str, ndarray]]]], Tuple[List[str], Dict[str, Tensor]]]
Return “collate_fn”, a callable object that is passed to DataLoader:
>>> from torch.utils.data import DataLoader
>>> loader = DataLoader(collate_fn=cls.build_collate_fn(args, train=True), ...)
In many cases, you can use our common collate_fn.
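To illustrate what such a collate_fn does, here is a minimal, hypothetical sketch of its contract: given a collection of `(uid, {name: ndarray})` pairs, it pads each named array to the longest length in the batch, stacks them, and records the original lengths. This is not ESPnet’s actual implementation (the real common collate_fn returns `torch.Tensor` and has more options); numpy is used here only to keep the sketch dependency-free, and the function name is invented.

```python
import numpy as np

def simple_collate_fn(data):
    """Hypothetical stand-in for a common collate_fn: pads each named
    array in a batch to the longest length along axis 0 and stacks them.
    (ESPnet's real collate_fn returns torch.Tensor, not ndarray.)"""
    uids = [uid for uid, _ in data]
    batch = {}
    for name in data[0][1].keys():
        arrays = [d[name] for _, d in data]
        max_len = max(a.shape[0] for a in arrays)
        batch[name] = np.stack([
            np.pad(a, [(0, max_len - a.shape[0])] + [(0, 0)] * (a.ndim - 1))
            for a in arrays
        ])
        # Record the unpadded lengths so the model can mask out padding.
        batch[f"{name}_lengths"] = np.array([a.shape[0] for a in arrays])
    return uids, batch
```

The returned `(List[str], Dict[str, array])` pair mirrors the `Tuple[List[str], Dict[str, Tensor]]` return type in the signature above.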
classmethod build_iter_factory(args: Namespace, distributed_option: DistributedOption, mode: str, kwargs: dict | None = None) → AbsIterFactory
Build a factory object of mini-batch iterator.
This object is invoked at every epoch to build the iterator for that epoch as follows:
>>> iter_factory = cls.build_iter_factory(...)
>>> for epoch in range(1, max_epoch):
... for keys, batch in iter_factory.build_iter(epoch):
... model(**batch)
The mini-batches for each epoch are fully controlled by this class. Note that the random seed used for shuffling is derived as “seed + epoch”, so the generated mini-batches can be reproduced when resuming training.
Note that an “epoch” does not necessarily mean one pass over the whole training corpus. The “–num_iters_per_epoch” option restricts the number of iterations per epoch, and the remaining samples of the original epoch are carried over to the next epoch. e.g. If the number of mini-batches equals 4, the following two are equivalent:
- 1 epoch without “–num_iters_per_epoch”
- 4 epochs with “–num_iters_per_epoch” == 1
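The “seed + epoch” shuffling rule described above can be sketched as follows. This is a hypothetical illustration, not ESPnet’s actual iterator factory: the function name and signature are invented, and the carry-over of leftover samples between epochs is omitted for brevity.

```python
import random

def build_batches(sample_ids, batch_size, seed, epoch, num_iters_per_epoch=None):
    """Hypothetical sketch of the epoch semantics: shuffling is seeded
    with (seed + epoch) so resuming at a given epoch reproduces the same
    mini-batches, and num_iters_per_epoch truncates the epoch."""
    rng = random.Random(seed + epoch)  # deterministic per (seed, epoch)
    ids = list(sample_ids)
    rng.shuffle(ids)
    batches = [ids[i:i + batch_size] for i in range(0, len(ids), batch_size)]
    if num_iters_per_epoch is not None:
        batches = batches[:num_iters_per_epoch]
    return batches
```

Because the RNG is reseeded from the epoch number, calling this twice for the same epoch yields identical batches, which is what makes resumed training reproducible.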
classmethod build_model(args: Namespace) → ESPnetEnhancementModel
classmethod build_preprocess_fn(args: Namespace, train: bool) → Callable[[str, Dict[str, array]], Dict[str, ndarray]] | None
class_choices_list : List[[ClassChoices](../train/ClassChoices.md#espnet2.train.class_choices.ClassChoices)] = [<espnet2.train.class_choices.ClassChoices object>, <espnet2.train.class_choices.ClassChoices object>, <espnet2.train.class_choices.ClassChoices object>, <espnet2.train.class_choices.ClassChoices object>, <espnet2.train.class_choices.ClassChoices object>, <espnet2.train.class_choices.ClassChoices object>]
num_optimizers : int = 1
classmethod optional_data_names(train: bool = True, inference: bool = False) → Tuple[str, ...]
Define the optional data names used by the task.
This function is used by cls.check_task_requirements(). If your model is defined as follows,
>>> from espnet2.train.abs_espnet_model import AbsESPnetModel
>>> class Model(AbsESPnetModel):
... def forward(self, input, output, opt=None): pass
then “optional_data_names” should be:
>>> optional_data_names = ('opt',)
classmethod required_data_names(train: bool = True, inference: bool = False) → Tuple[str, ...]
Define the required data names used by the task.
This function is used by cls.check_task_requirements(). If your model is defined as follows,
>>> from espnet2.train.abs_espnet_model import AbsESPnetModel
>>> class Model(AbsESPnetModel):
... def forward(self, input, output, opt=None): pass
then “required_data_names” should be:
>>> required_data_names = ('input', 'output')
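The check that uses these two name tuples can be sketched as a simple set comparison. This is a hypothetical illustration of what cls.check_task_requirements() verifies, not its actual implementation; the function name is invented.

```python
def check_data_against_model(data_names, required, optional):
    """Hypothetical sketch of the requirement check: every required name
    must be present in the dataset, and any extra name must be among the
    optional ones. Returns the offending names instead of raising."""
    data = set(data_names)
    missing = set(required) - data          # required names absent from the data
    unknown = data - set(required) - set(optional)  # names the model cannot accept
    return missing, unknown
```

For the model above, a dataset providing only 'input' and 'opt' would be reported as missing 'output', while any name outside ('input', 'output', 'opt') would be flagged as unknown.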
trainer
alias of Trainer