Source code for espnet.nets.pytorch_backend.transformer.optimizer

#!/usr/bin/env python3
# -*- coding: utf-8 -*-

# Copyright 2019 Shigeki Karita
#  Apache 2.0  (http://www.apache.org/licenses/LICENSE-2.0)

"""Optimizer module."""

import torch


class NoamOpt(object):
    """Optimizer wrapper that implements the Noam learning rate schedule."""

    def __init__(self, model_size, factor, warmup, optimizer):
        """Construct a NoamOpt object."""
        self.optimizer = optimizer
        self._step = 0
        self.warmup = warmup
        self.factor = factor
        self.model_size = model_size
        self._rate = 0

    @property
    def param_groups(self):
        """Return param_groups."""
        return self.optimizer.param_groups

    def step(self):
        """Update parameters and rate."""
        self._step += 1
        rate = self.rate()
        for p in self.optimizer.param_groups:
            p["lr"] = rate
        self._rate = rate
        self.optimizer.step()

    def rate(self, step=None):
        """Implement the Noam learning rate schedule."""
        if step is None:
            step = self._step
        return (
            self.factor
            * self.model_size ** (-0.5)
            * min(step ** (-0.5), step * self.warmup ** (-1.5))
        )

    def zero_grad(self):
        """Reset gradient."""
        self.optimizer.zero_grad()

    def state_dict(self):
        """Return state_dict."""
        return {
            "_step": self._step,
            "warmup": self.warmup,
            "factor": self.factor,
            "model_size": self.model_size,
            "_rate": self._rate,
            "optimizer": self.optimizer.state_dict(),
        }

    def load_state_dict(self, state_dict):
        """Load state_dict."""
        for key, value in state_dict.items():
            if key == "optimizer":
                self.optimizer.load_state_dict(state_dict["optimizer"])
            else:
                setattr(self, key, value)


def get_std_opt(model_params, d_model, warmup, factor):
    """Get standard NoamOpt."""
    base = torch.optim.Adam(model_params, lr=0, betas=(0.9, 0.98), eps=1e-9)
    return NoamOpt(d_model, factor, warmup, base)
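
The schedule computed by `rate` is lr(step) = factor * model_size^(-0.5) * min(step^(-0.5), step * warmup^(-1.5)): the learning rate grows linearly for the first `warmup` steps, peaks at step == warmup, and then decays proportionally to step^(-0.5). Note that `get_std_opt` creates the wrapped Adam with `lr=0`; the effective learning rate is written into every param group on each call to `step()`. A minimal usage sketch, assuming a toy `torch.nn.Linear` model and illustrative values for `d_model`, `warmup`, and `factor`:

    import torch

    from espnet.nets.pytorch_backend.transformer.optimizer import get_std_opt

    # Toy stand-in; in practice this is a Transformer whose attention
    # dimension equals d_model. All hyper-parameter values below are
    # illustrative, not ESPnet defaults.
    model = torch.nn.Linear(256, 256)
    opt = get_std_opt(model.parameters(), d_model=256, warmup=25000, factor=5.0)

    x = torch.randn(8, 256)
    loss = model(x).pow(2).mean()

    opt.zero_grad()
    loss.backward()
    opt.step()  # increments the step counter, sets the scheduled lr, then steps Adam

    print(opt.rate())  # current lr: factor * d_model^-0.5 * min(step^-0.5, step * warmup^-1.5)

Because `state_dict` includes `_step` alongside the wrapped optimizer's state, saving and restoring the wrapper resumes the schedule at the same point in training.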