espnet2.train.trainer.Trainer
class espnet2.train.trainer.Trainer
Bases: object
Trainer with a single optimizer.
To use multiple optimizers, inherit from this class and override methods as necessary, at minimum "train_one_epoch()".
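>>> class TwoOptimizerTrainer(Trainer):
...     @classmethod
...     def add_arguments(cls, parser):
...         ...
...
...     @classmethod
...     def train_one_epoch(cls, model, iterator, optimizers, schedulers,
...                         scaler, reporter, summary_writer, options,
...                         distributed_option):
...         for _, batch in iterator:
...             # model1/model2 are illustrative sub-models returning a loss
...             optimizers[0].zero_grad()
...             loss1 = model.model1(**batch)
...             loss1.backward()
...             optimizers[0].step()
...
...             optimizers[1].zero_grad()
...             loss2 = model.model2(**batch)
...             loss2.backward()
...             optimizers[1].step()
...         return False  # i.e. not every step in this epoch was invalid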
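Overriding works because these methods are classmethods, so a subclass override takes effect wherever the base class dispatches through `cls`. The toy sketch below mimics that dispatch with hypothetical `MiniTrainer` and `FakeOptimizer` classes (not part of ESPnet) so it runs without PyTorch. It assumes, as a simplification, that `run()` calls `cls.train_one_epoch()` once per epoch.

```python
# Toy stand-in for the Trainer pattern. MiniTrainer and FakeOptimizer are
# hypothetical and only illustrate classmethod dispatch; this is not ESPnet code.
from typing import List, Sequence


class FakeOptimizer:
    """Counts how many times step() was called."""

    def __init__(self, name: str) -> None:
        self.name = name
        self.steps = 0

    def zero_grad(self) -> None:
        pass  # nothing to clear in this toy

    def step(self) -> None:
        self.steps += 1


class MiniTrainer:
    @classmethod
    def run(cls, batches: List[int], optimizers: Sequence[FakeOptimizer], max_epoch: int) -> None:
        for _ in range(max_epoch):
            # Dispatch through cls, so a subclass override is picked up here.
            cls.train_one_epoch(batches, optimizers)

    @classmethod
    def train_one_epoch(cls, batches: List[int], optimizers: Sequence[FakeOptimizer]) -> bool:
        for _ in batches:
            # The base class only updates the first optimizer.
            optimizers[0].zero_grad()
            optimizers[0].step()
        return False


class TwoOptimizerTrainer(MiniTrainer):
    @classmethod
    def train_one_epoch(cls, batches: List[int], optimizers: Sequence[FakeOptimizer]) -> bool:
        for _ in batches:
            # Update each sub-model with its own optimizer.
            for opt in optimizers[:2]:
                opt.zero_grad()
                opt.step()
        return False


opts = [FakeOptimizer("model1"), FakeOptimizer("model2")]
TwoOptimizerTrainer.run(batches=[0, 1, 2], optimizers=opts, max_epoch=2)
print([o.steps for o in opts])  # [6, 6]
```

Calling `MiniTrainer.run()` with the same inputs would step only the first optimizer, which is why the subclass must override at least `train_one_epoch()`.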
classmethod add_arguments(parser: ArgumentParser)
Reserved for future development of other Trainer subclasses, which can override it to register their own command-line options.
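A minimal sketch of such an override, assuming the base implementation registers nothing. `Trainer` here is a hypothetical stand-in for the real class, and the `--optim2_lr` option is invented purely for illustration.

```python
import argparse


class Trainer:
    """Hypothetical stand-in for the base class: add_arguments() is a no-op."""

    @classmethod
    def add_arguments(cls, parser: argparse.ArgumentParser) -> None:
        pass


class TwoOptimizerTrainer(Trainer):
    @classmethod
    def add_arguments(cls, parser: argparse.ArgumentParser) -> None:
        # Hypothetical option for the second optimizer's learning rate.
        group = parser.add_argument_group("TwoOptimizerTrainer related")
        group.add_argument("--optim2_lr", type=float, default=1e-3)


parser = argparse.ArgumentParser()
TwoOptimizerTrainer.add_arguments(parser)
args = parser.parse_args(["--optim2_lr", "5e-4"])
print(args.optim2_lr)  # 0.0005
```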
classmethod build_options(args: Namespace) → TrainerOptions
Build the TrainerOptions consumed by train_one_epoch(), validate_one_epoch(), and plot_attention().
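A minimal sketch of building an options dataclass from an `argparse.Namespace`, assuming the options are formed by copying the Namespace attributes whose names match the dataclass fields. `TrainerOptionsSketch` and its four fields are a hypothetical subset, not the real `TrainerOptions` definition.

```python
import argparse
import dataclasses


@dataclasses.dataclass
class TrainerOptionsSketch:
    # Hypothetical subset of fields; the real TrainerOptions has more.
    ngpu: int
    max_epoch: int
    accum_grad: int
    grad_clip: float


def build_options(args: argparse.Namespace) -> TrainerOptionsSketch:
    # Copy only the attributes that match dataclass fields; ignore the rest.
    values = vars(args)
    kwargs = {f.name: values[f.name] for f in dataclasses.fields(TrainerOptionsSketch)}
    return TrainerOptionsSketch(**kwargs)


args = argparse.Namespace(
    ngpu=1, max_epoch=40, accum_grad=2, grad_clip=5.0, output_dir="exp"
)
print(build_options(args))
# TrainerOptionsSketch(ngpu=1, max_epoch=40, accum_grad=2, grad_clip=5.0)
```

A missing required attribute raises `KeyError` here, which surfaces configuration mistakes early instead of silently using a default.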
classmethod plot_attention(model: Module, output_dir: Path | None, summary_writer, iterator: Iterable[Tuple[List[str], Dict[str, Tensor]]], reporter: SubReporter, options: TrainerOptions) → None
static resume(checkpoint: str | Path, model: Module, reporter: Reporter, optimizers: Sequence[Optimizer], schedulers: Sequence[AbsScheduler | None], scaler: GradScaler | None, ngpu: int = 0, strict: bool = True)
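A minimal sketch of restoring training state from a checkpoint, assuming the checkpoint stores state dicts under the keys `"model"`, `"optimizers"`, and `"schedulers"`, with `None` entries for absent schedulers. `Stateful` and `resume_sketch` are hypothetical, `pickle` stands in for PyTorch's loader, and the `reporter`, `scaler`, `ngpu`, and `strict` handling of the real method is omitted.

```python
import os
import pickle
import tempfile
from typing import Any, Dict, Optional, Sequence


class Stateful:
    """Hypothetical object exposing state_dict()/load_state_dict() like PyTorch."""

    def __init__(self, value: int = 0) -> None:
        self.value = value

    def state_dict(self) -> Dict[str, Any]:
        return {"value": self.value}

    def load_state_dict(self, state: Dict[str, Any]) -> None:
        self.value = state["value"]


def resume_sketch(
    checkpoint: str,
    model: Stateful,
    optimizers: Sequence[Stateful],
    schedulers: Sequence[Optional[Stateful]],
) -> None:
    # pickle stands in for PyTorch's checkpoint loader in this sketch.
    with open(checkpoint, "rb") as f:
        states = pickle.load(f)
    model.load_state_dict(states["model"])
    for optimizer, state in zip(optimizers, states["optimizers"]):
        optimizer.load_state_dict(state)
    for scheduler, state in zip(schedulers, states["schedulers"]):
        # A scheduler slot may be None; there is nothing to restore for it.
        if scheduler is not None:
            scheduler.load_state_dict(state)


# Write a checkpoint from "trained" objects, then restore it into fresh ones.
path = os.path.join(tempfile.mkdtemp(), "checkpoint.pkl")
with open(path, "wb") as f:
    pickle.dump(
        {
            "model": Stateful(7).state_dict(),
            "optimizers": [Stateful(3).state_dict(), Stateful(4).state_dict()],
            "schedulers": [Stateful(5).state_dict(), None],
        },
        f,
    )
model, optims, scheds = Stateful(), [Stateful(), Stateful()], [Stateful(), None]
resume_sketch(path, model, optims, scheds)
print(model.value, [o.value for o in optims], scheds[0].value)  # 7 [3, 4] 5
```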
classmethod run(model: AbsESPnetModel, optimizers: Sequence[Optimizer], schedulers: Sequence[AbsScheduler | None], train_iter_factory: AbsIterFactory, valid_iter_factory: AbsIterFactory, plot_attention_iter_factory: AbsIterFactory | None, trainer_options, distributed_option: DistributedOption) → None
Perform training. This method runs the main training loop, calling train_one_epoch() and validate_one_epoch() for each epoch.
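The epoch loop can be sketched as follows. This is a simplified, hypothetical `run_sketch`, assuming that `train_one_epoch()` returns `True` when every step in the epoch was invalid, that epoch-level schedulers step after validation, and that training stops early when no update happened. Checkpointing, reporting, and distributed handling are omitted.

```python
import logging
from typing import Callable, List


def run_sketch(
    max_epoch: int,
    train_one_epoch: Callable[[int], bool],
    validate_one_epoch: Callable[[int], None],
    epoch_schedulers: List[Callable[[], None]],
) -> int:
    """Return the last epoch that was run."""
    epoch = 0
    for epoch in range(1, max_epoch + 1):
        all_steps_are_invalid = train_one_epoch(epoch)
        validate_one_epoch(epoch)
        for step in epoch_schedulers:
            step()  # epoch-level scheduler update
        if all_steps_are_invalid:
            # No optimizer update happened in this epoch: stop early.
            logging.warning("All steps were invalid at epoch %d; stopping.", epoch)
            break
    return epoch


validated: List[int] = []
last = run_sketch(
    max_epoch=10,
    train_one_epoch=lambda e: e == 3,  # pretend epoch 3 has only invalid steps
    validate_one_epoch=validated.append,
    epoch_schedulers=[],
)
print(last, validated)  # 3 [1, 2, 3]
```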
classmethod train_one_epoch(model: Module, iterator: Iterable[Tuple[List[str], Dict[str, Tensor]]], optimizers: Sequence[Optimizer], schedulers: Sequence[AbsScheduler | None], scaler: GradScaler | None, reporter: SubReporter, summary_writer, options: TrainerOptions, distributed_option: DistributedOption) → bool
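The boolean return value can be illustrated with a minimal sketch, assuming it reports whether every update in the epoch was skipped, for example because the gradient norm was not finite, and that with gradient accumulation only every `accum_grad`-th batch triggers an update. `CountingOptimizer` and `train_one_epoch_sketch` are hypothetical; each input value stands for the gradient norm observed at that batch.

```python
import math
from typing import Iterable


class CountingOptimizer:
    """Hypothetical optimizer that only counts step() calls."""

    def __init__(self) -> None:
        self.steps = 0

    def step(self) -> None:
        self.steps += 1


def train_one_epoch_sketch(
    grad_norms: Iterable[float], optimizer: CountingOptimizer, accum_grad: int = 1
) -> bool:
    """Return True if no optimizer step was taken (all steps invalid)."""
    all_steps_are_invalid = True
    for iiter, grad_norm in enumerate(grad_norms, start=1):
        # With gradient accumulation, only every accum_grad-th batch updates.
        if iiter % accum_grad != 0:
            continue
        if not math.isfinite(grad_norm):
            # Skip the update when the gradient norm is inf or nan.
            continue
        all_steps_are_invalid = False
        optimizer.step()
    return all_steps_are_invalid


opt = CountingOptimizer()
print(train_one_epoch_sketch([1.0, float("nan"), 2.0, 3.0], opt), opt.steps)  # False 3
```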
classmethod validate_one_epoch(model: Module, iterator: Iterable[Dict[str, Tensor]], reporter: SubReporter, options: TrainerOptions, distributed_option: DistributedOption) → None
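A minimal sketch of how per-batch validation statistics might be aggregated, assuming each batch yields a `(stats, weight)` pair (for example, weight equal to the batch size) and that the epoch value is the weighted average. `validate_one_epoch_sketch` is hypothetical; running the model without gradients and registering results with `SubReporter` are omitted.

```python
from typing import Dict, Iterable, Tuple


def validate_one_epoch_sketch(
    batches: Iterable[Tuple[Dict[str, float], float]],
) -> Dict[str, float]:
    """Average per-batch stats, weighting each batch by its weight."""
    totals: Dict[str, float] = {}
    total_weight = 0.0
    for stats, weight in batches:
        for key, value in stats.items():
            totals[key] = totals.get(key, 0.0) + value * weight
        total_weight += weight
    if total_weight == 0:
        return {}
    return {key: total / total_weight for key, total in totals.items()}


print(validate_one_epoch_sketch([({"loss": 2.0}, 3.0), ({"loss": 4.0}, 1.0)]))
# {'loss': 2.5}
```

Weighting by batch size keeps a small final batch from counting as much as a full one.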