espnet2.train.deepspeed_trainer.DeepSpeedTrainer
class espnet2.train.deepspeed_trainer.DeepSpeedTrainer
Bases: Trainer
classmethod build_options(args: Namespace) → DeepSpeedTrainerOptions
Build options consumed by train(), eval(), and plot_attention()
static resume(model: None, reporter: Reporter, output_dir: Path)
classmethod run(model: AbsESPnetModel | None, train_iter_factory: AbsIterFactory, valid_iter_factory: AbsIterFactory, trainer_options: DeepSpeedTrainerOptions, **kwargs) → None
Perform training. This method is the main entry point of the training process.
classmethod setup_data_dtype(deepspeed_config: Dict)
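The signature above comes without a description. As a rough illustration of what deriving a data dtype from a DeepSpeed configuration could look like, the sketch below reads the `bf16` and `fp16` sections of a config dictionary, which is where DeepSpeed's configuration format places its mixed-precision switches. The helper name `pick_data_dtype` and the returned string names are hypothetical and not part of ESPnet; the actual method may behave differently.

```python
from typing import Any, Dict


def pick_data_dtype(deepspeed_config: Dict[str, Any]) -> str:
    """Hypothetical helper: return the dtype name implied by a DeepSpeed config."""
    # DeepSpeed configs switch on mixed precision with
    # {"bf16": {"enabled": true}} or {"fp16": {"enabled": true}}.
    if deepspeed_config.get("bf16", {}).get("enabled", False):
        return "bfloat16"
    if deepspeed_config.get("fp16", {}).get("enabled", False):
        return "float16"
    # Assumption: full precision when neither section is enabled.
    return "float32"
```

Returning a plain string keeps the sketch free of a PyTorch dependency; real code would more likely map these names to `torch` dtypes.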
classmethod train_one_epoch(model, iterator: Iterable[Tuple[List[str], Dict[str, Tensor]]], reporter: SubReporter, options: DeepSpeedTrainerOptions) → None
classmethod valid_one_epoch(model, iterator: Iterable[Tuple[List[str], Dict[str, Tensor]]], reporter: SubReporter, options: DeepSpeedTrainerOptions) → None
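Both epoch methods take an iterator typed `Iterable[Tuple[List[str], Dict[str, Tensor]]]`. The sketch below shows one way such an iterator can be consumed, assuming the list of strings holds the utterance IDs of a mini-batch and the dictionary holds its named tensors. Plain lists of floats stand in for `Tensor` so the example runs without PyTorch, and `count_utterances` and the keys `speech` and `speech_lengths` are illustrative choices, not taken from this page.

```python
from typing import Dict, Iterable, List, Tuple

# Stand-in for torch.Tensor so the sketch runs without PyTorch installed.
FakeTensor = List[float]
Batch = Tuple[List[str], Dict[str, FakeTensor]]


def count_utterances(iterator: Iterable[Batch]) -> int:
    """Hypothetical helper: walk the iterator the way an epoch loop might."""
    total = 0
    for utt_ids, batch in iterator:
        # utt_ids: IDs of the utterances in this mini-batch (assumed meaning).
        # batch: mapping from input name to its data for the whole mini-batch.
        assert isinstance(batch, dict)
        total += len(utt_ids)
    return total


toy_iterator = [
    (["utt1", "utt2"], {"speech": [0.1, 0.2], "speech_lengths": [2.0]}),
    (["utt3"], {"speech": [0.3], "speech_lengths": [1.0]}),
]
print(count_utterances(toy_iterator))  # prints 3
```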