Source code for espnet.optimizer.chainer

"""Chainer optimizer builders."""

import argparse

import chainer
from chainer.optimizer_hooks import WeightDecay

from espnet.optimizer.factory import OptimizerFactoryInterface
from espnet.optimizer.parser import adadelta, adam, sgd


class AdamFactory(OptimizerFactoryInterface):
    """Factory building a Chainer Adam optimizer from parsed arguments."""

    @staticmethod
    def add_arguments(parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
        """Register the Adam options on ``parser``.

        Args:
            parser (argparse.ArgumentParser): parser to extend

        Returns:
            argparse.ArgumentParser: whatever ``adam(parser)`` returns
        """
        return adam(parser)

    @staticmethod
    def from_args(target, args: argparse.Namespace):
        """Create an Adam optimizer configured from an argparse Namespace.

        Args:
            target: for pytorch `model.parameters()`, for chainer `model`;
                handed to the optimizer's ``setup``
            args (argparse.Namespace): parsed command-line args; the
                attributes ``lr``, ``beta1``, ``beta2`` and ``weight_decay``
                are read

        Returns:
            The ``chainer.optimizers.Adam`` instance, set up on ``target``
            with a ``WeightDecay`` hook added.
        """
        # Chainer's Adam names its step size ``alpha``; it is fed from ``lr``.
        hyperparams = dict(alpha=args.lr, beta1=args.beta1, beta2=args.beta2)
        optimizer = chainer.optimizers.Adam(**hyperparams)
        optimizer.setup(target)
        decay_hook = WeightDecay(args.weight_decay)
        optimizer.add_hook(decay_hook)
        return optimizer
class SGDFactory(OptimizerFactoryInterface):
    """Factory building a Chainer SGD optimizer from parsed arguments."""

    @staticmethod
    def add_arguments(parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
        """Register the SGD options on ``parser``.

        Args:
            parser (argparse.ArgumentParser): parser to extend

        Returns:
            argparse.ArgumentParser: whatever ``sgd(parser)`` returns
        """
        return sgd(parser)

    @staticmethod
    def from_args(target, args: argparse.Namespace):
        """Create an SGD optimizer configured from an argparse Namespace.

        Args:
            target: for pytorch `model.parameters()`, for chainer `model`;
                handed to the optimizer's ``setup``
            args (argparse.Namespace): parsed command-line args; the
                attributes ``lr`` and ``weight_decay`` are read

        Returns:
            The ``chainer.optimizers.SGD`` instance, set up on ``target``
            with a ``WeightDecay`` hook added.
        """
        hyperparams = dict(lr=args.lr)
        optimizer = chainer.optimizers.SGD(**hyperparams)
        optimizer.setup(target)
        decay_hook = WeightDecay(args.weight_decay)
        optimizer.add_hook(decay_hook)
        return optimizer
class AdadeltaFactory(OptimizerFactoryInterface):
    """Factory building a Chainer AdaDelta optimizer from parsed arguments."""

    @staticmethod
    def add_arguments(parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
        """Register the Adadelta options on ``parser``.

        Args:
            parser (argparse.ArgumentParser): parser to extend

        Returns:
            argparse.ArgumentParser: whatever ``adadelta(parser)`` returns
        """
        return adadelta(parser)

    @staticmethod
    def from_args(target, args: argparse.Namespace):
        """Create an AdaDelta optimizer configured from an argparse Namespace.

        Args:
            target: for pytorch `model.parameters()`, for chainer `model`;
                handed to the optimizer's ``setup``
            args (argparse.Namespace): parsed command-line args; the
                attributes ``rho``, ``eps`` and ``weight_decay`` are read

        Returns:
            The ``chainer.optimizers.AdaDelta`` instance, set up on
            ``target`` with a ``WeightDecay`` hook added.
        """
        hyperparams = dict(rho=args.rho, eps=args.eps)
        optimizer = chainer.optimizers.AdaDelta(**hyperparams)
        optimizer.setup(target)
        decay_hook = WeightDecay(args.weight_decay)
        optimizer.add_hook(decay_hook)
        return optimizer
# Registry mapping an optimizer name to the factory class defined above.
OPTIMIZER_FACTORY_DICT = {
    "adam": AdamFactory,
    "sgd": SGDFactory,
    "adadelta": AdadeltaFactory,
}