# Source code for espnet2.schedulers.piecewise_linear_warmup_lr

"""Piecewise linear warm up learning rate scheduler module."""

from typing import List, Union

import numpy as np
import torch
from torch.optim.lr_scheduler import _LRScheduler
from typeguard import typechecked

from espnet2.schedulers.abs_scheduler import AbsBatchStepScheduler

class PiecewiseLinearWarmupLR(_LRScheduler, AbsBatchStepScheduler):
    """The PiecewiseLinearWarmupLR scheduler.

    This scheduler is similar to the WarmupLR scheduler except that the warmup
    stage is piecewise linear.

    For the current step ``step_num = last_epoch + 1``:

    * up to and including ``warmup_steps_list[-1]``, the learning rate is the
      linear interpolation (``np.interp``) of ``warmup_lr_list`` over
      ``warmup_steps_list``. Steps before ``warmup_steps_list[0]`` are clamped
      to ``warmup_lr_list[0]``.
    * after ``warmup_steps_list[-1]``, it decays as
      ``base_lr * warmup_steps_list[-1] ** 0.5 * step_num ** -0.5``, where
      ``base_lr`` is the initial learning rate of each optimizer param group.

    NOTE(review): the decay branch is anchored to the optimizer's base lr,
    not to ``warmup_lr_list[-1]``. If the two differ, the schedule jumps at
    the end of warmup, so presumably they are meant to be configured equal.

    Args:
        optimizer: Wrapped optimizer.
        warmup_steps_list: Step positions of the warmup knots. Must be
            non-empty, non-decreasing, and the same length as
            ``warmup_lr_list``.
        warmup_lr_list: Learning rate at each knot.
        last_epoch: Index of the last step; forwarded to ``_LRScheduler``.

    Raises:
        ValueError: If the lists are empty, differ in length, or
            ``warmup_steps_list`` is not non-decreasing.
    """

    @typechecked
    def __init__(
        self,
        optimizer: torch.optim.Optimizer,
        warmup_steps_list: List[Union[int, float]] = [0, 25000],
        warmup_lr_list: List[float] = [0.0, 0.001],
        last_epoch: int = -1,
    ):
        if not warmup_steps_list:
            raise ValueError("warmup_steps_list must not be empty")
        if len(warmup_steps_list) != len(warmup_lr_list):
            raise ValueError(
                "warmup_steps_list and warmup_lr_list must have the same "
                f"length, got {len(warmup_steps_list)} and "
                f"{len(warmup_lr_list)}"
            )
        # np.interp does not check monotonicity; an unsorted xp silently
        # produces a wrong schedule, so reject it here.
        if any(a > b for a, b in zip(warmup_steps_list, warmup_steps_list[1:])):
            raise ValueError(
                "warmup_steps_list must be non-decreasing, "
                f"got {warmup_steps_list}"
            )
        # Store copies so that neither the shared default lists nor the
        # caller's lists are aliased by (and mutable through) this instance.
        self.warmup_steps_list = list(warmup_steps_list)
        self.warmup_lr_list = list(warmup_lr_list)
        # These fields must be set BEFORE super().__init__(), because it
        # invokes step(), which calls get_lr() and reads them.
        super().__init__(optimizer, last_epoch)

    def __repr__(self):
        return (
            f"{self.__class__.__name__}"
            f"(warmup_steps_list={self.warmup_steps_list}, "
            f"warmup_lr_list={self.warmup_lr_list})"
        )

    def get_lr(self):
        """Return the learning rate of every param group for the current step.

        Returns:
            A list with one plain ``float`` per entry in ``self.base_lrs``.
        """
        step_num = self.last_epoch + 1
        peak_step = self.warmup_steps_list[-1]
        return [
            # float(): np.interp returns numpy scalars; hand plain floats to
            # the optimizer param groups and get_last_lr().
            float(
                np.interp(
                    step_num,
                    self.warmup_steps_list,
                    self.warmup_lr_list,
                    # Value used past the last knot: inverse-sqrt decay
                    # starting from this param group's base lr.
                    right=lr * peak_step**0.5 * step_num**-0.5,
                )
            )
            for lr in self.base_lrs
        ]