# Source code for espnet.utils.training.train_utils

import logging

import chainer


def check_early_stop(trainer, epochs):
    """Warn the user when training ended before the last epoch.

    Reads the epoch counter of the trainer's ``"main"`` iterator and logs a
    warning when it is lower than ``epochs - 1``. Nothing is returned and the
    trainer is not modified.

    :param trainer: The trainer used for training
    :param epochs: The maximum number of epochs
    """
    end_epoch = trainer.updater.get_iterator("main").epoch
    if end_epoch < (epochs - 1):
        # Lazy %-style argument: the rendered message is the same as the
        # concatenated string, but formatting only happens if it is emitted.
        logging.warning(
            "Hit early stop at epoch %s"
            "\nYou can change the patience or set it to 0 to run all epochs",
            end_epoch,
        )


def set_early_stop(trainer, args, is_lm=False):
    """Set the trainer's early stop trigger from the program arguments.

    When ``args.patience`` is greater than 0, ``trainer.stop_trigger`` is
    replaced by a ``chainer.training.triggers.EarlyStoppingTrigger`` that
    monitors ``args.early_stop_criterion``. Otherwise the trainer is left
    untouched.

    :param trainer: The trainer used for training
    :param args: The program arguments (``patience``,
        ``early_stop_criterion`` and ``epochs``, or ``epoch`` for a LM)
    :param is_lm: If the trainer is for a LM (epoch instead of epochs)
    """
    patience = args.patience
    if patience <= 0:
        # Early stopping disabled: no need to read the other arguments.
        return

    criterion = args.early_stop_criterion
    epochs = args.epoch if is_lm else args.epochs
    # Criteria whose name contains "acc" are maximised; all others minimised.
    mode = "max" if "acc" in criterion else "min"
    # NOTE(review): the keyword is spelled `patients` in the original call;
    # kept as-is -- confirm against the Chainer version in use.
    trainer.stop_trigger = chainer.training.triggers.EarlyStoppingTrigger(
        monitor=criterion,
        mode=mode,
        patients=patience,
        max_trigger=(epochs, "epoch"),
    )