espnet2.train.abs_espnet_model.AbsESPnetModel
class espnet2.train.abs_espnet_model.AbsESPnetModel(*args, **kwargs)
Bases: Module, ABC
The common abstract base class shared by all tasks.
“ESPnetModel” refers to a class that inherits from torch.nn.Module, holds the DNN models used in its forward pass as member fields (a.k.a. the delegate pattern), and defines “loss”, “stats”, and “weight” for the task.
If you intend to implement a new task in ESPnet, the model must inherit from this class. In other words, the only “mediator” objects between our training system and your task class are these three values: loss, stats, and weight.
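A minimal sketch of this contract, assuming only PyTorch. So that it runs without ESPnet installed, `AbsESPnetModelStandIn` is a hypothetical local class that mirrors the documented bases (`Module`, `ABC`) and the two abstract methods; it is not the ESPnet class. `ToyRegressionModel` and the two-batch loop are also illustrative. Returning the batch size as `weight` and using it to weight-average `stats` across batches are assumptions about how a trainer might consume the three values.

```python
from abc import ABC, abstractmethod
from typing import Dict, Tuple

import torch


class AbsESPnetModelStandIn(torch.nn.Module, ABC):
    """Hypothetical stand-in mirroring the documented interface (not ESPnet's class)."""

    @abstractmethod
    def collect_feats(self, **batch: torch.Tensor) -> Dict[str, torch.Tensor]:
        raise NotImplementedError

    @abstractmethod
    def forward(
        self, **batch: torch.Tensor
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
        raise NotImplementedError


class ToyRegressionModel(AbsESPnetModelStandIn):
    """Hypothetical task model."""

    def __init__(self, dim: int = 4):
        super().__init__()
        # Delegate pattern: the DNN lives in a member field.
        self.net = torch.nn.Linear(dim, 1)

    def collect_feats(self, x: torch.Tensor, y: torch.Tensor) -> Dict[str, torch.Tensor]:
        return {"feats": x}

    def forward(self, x: torch.Tensor, y: torch.Tensor):
        pred = self.net(x).squeeze(-1)
        loss = torch.nn.functional.mse_loss(pred, y)
        stats = {"loss": loss.detach()}
        # Assumption: weight is the batch size, used to weight-average stats.
        weight = torch.tensor(x.size(0))
        return loss, stats, weight


torch.manual_seed(0)
model = ToyRegressionModel()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

total, count = 0.0, 0
for batch_size in (8, 3):
    batch = {"x": torch.randn(batch_size, 4), "y": torch.randn(batch_size)}
    # The training loop only sees these three values.
    loss, stats, weight = model(**batch)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    # Accumulate a batch-size-weighted average of the reported loss.
    total += stats["loss"].item() * weight.item()
    count += weight.item()
avg_loss = total / count
```

Because the stand-in uses `ABC`, trying to instantiate a subclass that leaves `forward` or `collect_feats` unimplemented raises `TypeError`.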
Example
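>>> import argparse
>>> from espnet2.tasks.abs_task import AbsTask
>>> from espnet2.train.abs_espnet_model import AbsESPnetModel
>>> class YourESPnetModel(AbsESPnetModel):
...     def forward(self, input, input_lengths):
...         ...  # compute loss, stats, and weight here
...         return loss, stats, weight
...     def collect_feats(self, input, input_lengths):
...         # placeholder: also abstract, so it must be implemented
...         return {"feats": input, "feats_lengths": input_lengths}
>>> class YourTask(AbsTask):
...     @classmethod
...     def build_model(cls, args: argparse.Namespace) -> YourESPnetModel:
...         return YourESPnetModel()  # other AbsTask methods omitted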
Initializes internal Module state, shared by both nn.Module and ScriptModule.
abstract collect_feats(**batch: Tensor) → Dict[str, Tensor]
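This page gives only the signature of `collect_feats`. In ESPnet it is typically used when collecting feature statistics (for example, for global mean/variance normalization); the sketch below assumes that use. `ToyFeatModel` and `accumulate_stats` are hypothetical names, the identity "feature extraction" is a placeholder, and a plain `torch.nn.Module` stands in for a full model. Padded frames are masked out using the returned lengths.

```python
from typing import Dict, Iterable

import torch


class ToyFeatModel(torch.nn.Module):
    """Hypothetical model; only collect_feats is sketched."""

    def collect_feats(
        self, speech: torch.Tensor, speech_lengths: torch.Tensor
    ) -> Dict[str, torch.Tensor]:
        # Placeholder "feature extraction": identity on (batch, time, dim) input.
        return {"feats": speech, "feats_lengths": speech_lengths}


def accumulate_stats(model: ToyFeatModel, batches: Iterable[Dict[str, torch.Tensor]]):
    """Per-dimension mean and variance over valid (unpadded) frames."""
    total, total_sq, count = None, None, 0
    for batch in batches:
        out = model.collect_feats(**batch)
        feats, lengths = out["feats"], out["feats_lengths"]
        # mask[b, t] is True for frames inside the b-th utterance.
        mask = torch.arange(feats.size(1))[None, :] < lengths[:, None]
        valid = feats[mask]  # (num_valid_frames, dim)
        s, sq = valid.sum(0), (valid ** 2).sum(0)
        total = s if total is None else total + s
        total_sq = sq if total_sq is None else total_sq + sq
        count += valid.size(0)
    mean = total / count
    var = total_sq / count - mean ** 2
    return mean, var


model = ToyFeatModel()
batch = {
    "speech": torch.tensor([[[1.0], [3.0], [0.0]], [[5.0], [0.0], [0.0]]]),
    "speech_lengths": torch.tensor([2, 1]),
}
mean, var = accumulate_stats(model, [batch])
```

Only the frames 1, 3, and 5 are counted, so the padding zeros do not bias the statistics.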
abstract forward(**batch: Tensor) → Tuple[Tensor, Dict[str, Tensor], Tensor]
Defines the computation performed at every call.
Should be overridden by all subclasses.
NOTE
Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.
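A small check with plain PyTorch illustrates the difference the note describes: a forward hook fires when the module instance is called, but not when `forward` is invoked directly. `torch.nn.Linear` stands in for an ESPnet model here; the behavior comes from `torch.nn.Module` itself.

```python
import torch

calls = []

layer = torch.nn.Linear(2, 2)
# The hook records each time it is run.
handle = layer.register_forward_hook(lambda mod, inp, out: calls.append("hook"))

x = torch.randn(1, 2)
layer(x)          # runs forward and the registered hook
layer.forward(x)  # runs forward only; the hook is skipped
handle.remove()
```

After both calls, `calls` holds a single entry, which is why training code should use `model(**batch)` rather than `model.forward(**batch)`.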