"""Learning-rate schedulers (espnet.scheduler.scheduler)."""


import argparse

from espnet.utils.dynamic_import import dynamic_import
from espnet.utils.fill_missing_args import fill_missing_args

class _PrefixParser:
    def __init__(self, parser, prefix):
        self.parser = parser
        self.prefix = prefix

    def add_argument(self, name, **kwargs):
        assert name.startswith("--")
        self.parser.add_argument(self.prefix + name[2:], **kwargs)

class SchedulerInterface:
    """Scheduler interface.

    Subclasses set a unique ``alias``, declare their options in
    ``_add_arguments`` and implement ``scale``. Options are stored on the
    namespace as ``{key}_{alias}_<name>`` and copied onto the instance as
    plain ``<name>`` attributes by ``__init__``.

    """

    alias = ""

    def __init__(self, key: str, args: argparse.Namespace):
        """Initialize class.

        Args:
            key (str): key of hyper parameter; the first part of every
                option name handled by this scheduler.
            args (argparse.Namespace): parsed options; every attribute named
                ``{key}_{alias}_<name>`` is copied to ``self.<name>``.

        """
        self.key = key
        # ``get_arg`` reads from the raw namespace, so it must be kept.
        self.args = args
        prefix = key + "_" + self.alias + "_"
        for k, v in vars(args).items():
            if k.startswith(prefix):
                # NOTE: an option literally named "key" or "args" would
                # overwrite the attributes set above.
                setattr(self, k[len(prefix) :], v)

    def get_arg(self, name):
        """Get argument without prefix.

        Args:
            name (str): option name without the ``{key}_{alias}_`` prefix.

        Returns:
            The value stored in ``self.args`` under the prefixed name.

        Raises:
            AttributeError: if ``self.args`` has no such option.

        """
        return getattr(self.args, f"{self.key}_{self.alias}_{name}")

    @classmethod
    def add_arguments(cls, key: str, parser: argparse.ArgumentParser):
        """Add arguments for CLI.

        Options declared by ``_add_arguments`` as ``--<name>`` are added to a
        new argument group as ``--{key}-{alias}-<name>``. argparse converts
        dashes in the destination to underscores, which yields the
        ``{key}_{alias}_`` prefix ``__init__`` expects (as long as ``key`` and
        ``alias`` themselves contain no dashes).

        Args:
            key (str): key of hyper parameter.
            parser (argparse.ArgumentParser): parser to extend.

        Returns:
            argparse.ArgumentParser: ``parser`` itself, for chaining.

        """
        group = parser.add_argument_group(f"{cls.alias} scheduler")
        cls._add_arguments(_PrefixParser(parser=group, prefix=f"--{key}-{cls.alias}-"))
        return parser

    @staticmethod
    def _add_arguments(parser: "_PrefixParser"):
        """Declare scheduler-specific options; the base interface has none.

        Subclasses override this and call ``parser.add_argument("--name", ...)``.
        """

    @classmethod
    def build(cls, key: str, **kwargs):
        """Initialize this class with python-level args.

        Keyword arguments are given without prefix (e.g. ``warmup=1000``).
        Options not supplied are completed by ``fill_missing_args`` using the
        parser built by ``add_arguments`` (presumably from its defaults).

        Args:
            key (str): key of hyper parameter
            **kwargs: scheduler options without the ``{key}_{alias}_`` prefix.

        Returns:
            SchedulerInterface: A new instance of ``cls``.

        """

        def add(parser):
            return cls.add_arguments(key, parser)

        kwargs = {f"{key}_{cls.alias}_" + k: v for k, v in kwargs.items()}
        args = argparse.Namespace(**kwargs)
        args = fill_missing_args(args, add)
        return cls(key, args)

    def scale(self, n_iter: int) -> float:
        """Scale at `n_iter`.

        Args:
            n_iter (int): number of current iterations.

        Returns:
            float: current scale of learning rate.

        Raises:
            NotImplementedError: always; subclasses must override.

        """
        raise NotImplementedError()
# Registry mapping a scheduler alias (e.g. "noam") to its
# "module_name:class_name" import path; populated by ``register_scheduler``.
SCHEDULER_DICT = {}


def register_scheduler(cls):
    """Register scheduler.

    Records ``cls`` in ``SCHEDULER_DICT`` under ``cls.alias`` as
    ``"<module>:<class name>"``. Registering a second class with the same
    alias overwrites the earlier entry.

    Args:
        cls (type): scheduler class; must define an ``alias`` attribute.

    Returns:
        type: ``cls`` unchanged, so this can be used as a class decorator.

    """
    SCHEDULER_DICT[cls.alias] = cls.__module__ + ":" + cls.__name__
    return cls
def dynamic_import_scheduler(module):
    """Resolve a scheduler class from an import path or a registered alias.

    Args:
        module (str): either ``"module_name:class_name"`` or an alias present
            in ``SCHEDULER_DICT``.

    Returns:
        type: the resolved scheduler class.

    Raises:
        AssertionError: if the resolved class is not a ``SchedulerInterface``
            subclass.

    """
    scheduler_cls = dynamic_import(module, SCHEDULER_DICT)
    is_scheduler = issubclass(scheduler_cls, SchedulerInterface)
    assert is_scheduler, f"{module} does not implement SchedulerInterface"
    return scheduler_cls
@register_scheduler
class NoScheduler(SchedulerInterface):
    """Constant scheduler: the learning-rate scale never changes."""

    alias = "none"

    def scale(self, n_iter):
        """Return a constant scale of 1.0 regardless of ``n_iter``."""
        constant = 1.0
        return constant
@register_scheduler
class NoamScheduler(SchedulerInterface):
    """Warmup + InverseSqrt decay scheduler.

    The scale grows linearly during the first ``warmup`` iterations, then
    decays with the inverse square root of the iteration count. It is
    normalized so that its peak (at ``step + 1 == warmup``) is about 1.0.

    Args:
        noam_warmup (int): number of warmup iterations.

    """

    alias = "noam"

    @staticmethod
    def _add_arguments(parser: _PrefixParser):
        """Register the ``--warmup`` option."""
        parser.add_argument(
            "--warmup",
            type=int,
            default=1000,
            help="Number of warmup iterations.",
        )

    def __init__(self, key, args):
        """Initialize class and precompute the normalization factor."""
        super().__init__(key, args)
        # 1 / (w * w**-1.5) == w**0.5, which cancels the w**-0.5 reached at
        # the peak so the maximum scale is ~1.0.
        self.normalize = 1 / (self.warmup * self.warmup**-1.5)

    def scale(self, step):
        """Return the learning-rate scale for the 0-based ``step``."""
        current = step + 1  # shift to 1-based to avoid 0**-0.5
        decay = current**-0.5
        ramp = current * self.warmup**-1.5
        return self.normalize * min(decay, ramp)
@register_scheduler
class CyclicCosineScheduler(SchedulerInterface):
    """Cyclic cosine annealing.

    Args:
        cosine_warmup (int): number of warmup iterations.
        cosine_total (int): number of total annealing iterations.

    Notes:
        Proposed in https://openreview.net/pdf?id=BJYwwY9ll
        (and https://arxiv.org/pdf/1608.03983.pdf).
        Used in the GPT2 config of Megatron-LM
        https://github.com/NVIDIA/Megatron-LM

    """

    alias = "cosine"

    @staticmethod
    def _add_arguments(parser: _PrefixParser):
        """Add scheduler args."""
        parser.add_argument(
            "--warmup", type=int, default=1000, help="Number of warmup iterations."
        )
        parser.add_argument(
            "--total",
            type=int,
            default=100000,
            help="Number of total annealing iterations.",
        )

    def scale(self, n_iter):
        """Scale of lr.

        Computes ``0.5 * (cos(pi * (n_iter - warmup) / total) + 1)``. This is
        1.0 at ``n_iter == warmup`` and 0.0 at ``n_iter == warmup + total``.
        It then rises again, with period ``2 * total``.

        Args:
            n_iter (int): number of current iterations.

        Returns:
            float: current scale of learning rate.

        """
        import math

        # NOTE(review): cos is even, so for n_iter < warmup the scale mirrors
        # the curve after warmup (close to 1.0 when warmup << total) rather
        # than ramping up from 0 -- confirm this is the intended "warmup".
        phase = math.pi * (n_iter - self.warmup) / self.total
        return 0.5 * (math.cos(phase) + 1)