#!/usr/bin/env python3
# Copyright 2017 Johns Hopkins University (Shinji Watanabe)
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
# This code is ported from the following implementation written in Torch.
# https://github.com/chainer/chainer/blob/master/examples/ptb/train_ptb_custom_loop.py
"""LM training in pytorch."""
import copy
import json
import logging
import numpy as np
import torch
import torch.nn as nn
from chainer import Chain, reporter, training
from chainer.dataset import convert
from chainer.training import extensions
from torch.nn.parallel import data_parallel
from espnet.asr.asr_utils import (
snapshot_object,
torch_load,
torch_resume,
torch_snapshot,
)
from espnet.lm.lm_utils import (
MakeSymlinkToBestModel,
ParallelSentenceIterator,
count_tokens,
load_dataset,
read_tokens,
)
from espnet.nets.lm_interface import LMInterface, dynamic_import_lm
from espnet.optimizer.factory import dynamic_import_optimizer
from espnet.scheduler.pytorch import PyTorchScheduler
from espnet.scheduler.scheduler import dynamic_import_scheduler
from espnet.utils.deterministic_utils import set_deterministic_pytorch
from espnet.utils.training.evaluator import BaseEvaluator
from espnet.utils.training.iterators import ShufflingEnabler
from espnet.utils.training.tensorboard_logger import TensorboardLogger
from espnet.utils.training.train_utils import check_early_stop, set_early_stop
def compute_perplexity(result):
    """Add perplexity entries to a LogReport observation dict.

    Perplexity is exp(total negative log-likelihood / token count), computed
    from the training stats and, when validation stats are present, from
    those as well.  The dict is mutated in place.

    :param dict result: The current observations
    """
    train_nll = result["main/nll"]
    train_count = result["main/count"]
    result["perplexity"] = np.exp(train_nll / train_count)
    if "validation/main/nll" in result:
        val_nll = result["validation/main/nll"]
        val_count = result["validation/main/count"]
        result["val_perplexity"] = np.exp(val_nll / val_count)
class Reporter(Chain):
    """Dummy chainer link so the trainer machinery has a target to report to."""

    def report(self, loss):
        """Discard the reported loss; real reporting happens elsewhere."""
def concat_examples(batch, device=None, padding=None):
    """Concatenate a minibatch into padded input/target torch tensors.

    :param np.ndarray batch: The batch to concatenate
    :param int device: The device to send to (GPU id, or None/-1 for CPU)
    :param Tuple[int,int] padding: The padding to use
    :return: (inputs, targets)
    :rtype (torch.Tensor, torch.Tensor)
    """
    inputs, targets = convert.concat_examples(batch, padding=padding)
    inputs = torch.from_numpy(inputs)
    targets = torch.from_numpy(targets)
    on_gpu = device is not None and device >= 0
    if on_gpu:
        inputs = inputs.cuda(device)
        targets = targets.cuda(device)
    return inputs, targets
class BPTTUpdater(training.StandardUpdater):
    """An updater for a pytorch LM.

    Runs one optimizer step per `update_core` call, accumulating gradients
    over `accum_grad` minibatches before clipping and stepping.
    """

    def __init__(
        self,
        train_iter,
        model,
        optimizer,
        schedulers,
        device,
        gradclip=None,
        use_apex=False,
        accum_grad=1,
    ):
        """Initialize class.

        Args:
            train_iter (chainer.dataset.Iterator): The train iterator
            model (LMInterface) : The model to update
            optimizer (torch.optim.Optimizer): The optimizer for training
            schedulers (espnet.scheduler.scheduler.SchedulerInterface):
                The schedulers of `optimizer`
            device (list[int]): The device ids to use; `[-1]` means CPU only
                (see `update_core`, which indexes `self.device[0]`)
            gradclip (float): The gradient clipping value to use
            use_apex (bool): The flag to use Apex in backprop.
            accum_grad (int): The number of gradient accumulation.
        """
        super(BPTTUpdater, self).__init__(train_iter, optimizer)
        self.model = model
        self.device = device
        self.gradclip = gradclip
        self.use_apex = use_apex
        # Wraps the espnet schedulers so they can drive the torch optimizer.
        self.scheduler = PyTorchScheduler(schedulers, optimizer)
        self.accum_grad = accum_grad

    # The core part of the update routine can be customized by overriding.
    def update_core(self):
        """Update the model."""
        # When we pass one iterator and optimizer to StandardUpdater.__init__,
        # they are automatically named 'main'.
        train_iter = self.get_iterator("main")
        optimizer = self.get_optimizer("main")
        # Progress the dataset iterator for sentences at each iteration.
        self.model.zero_grad()  # Clear the parameter gradients
        accum = {"loss": 0.0, "nll": 0.0, "count": 0}
        for _ in range(self.accum_grad):
            batch = train_iter.__next__()
            # Concatenate the token IDs to matrices and send them to the device
            # self.converter does this job
            # (it is chainer.dataset.concat_examples by default)
            # NOTE(review): target padding value -100 presumably matches the
            # loss's ignore_index inside the model — confirm against the model.
            x, t = concat_examples(batch, device=self.device[0], padding=(0, -100))
            if self.device[0] == -1:
                loss, nll, count = self.model(x, t)
            else:
                # apex does not support torch.nn.DataParallel
                loss, nll, count = data_parallel(self.model, (x, t), self.device)
            # backward
            # Dividing by accum_grad here makes the gradients summed over the
            # accumulation loop equivalent to one step on the combined batch.
            loss = loss.mean() / self.accum_grad
            if self.use_apex:
                from apex import amp

                with amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
            else:
                loss.backward()  # Backprop
            # accumulate stats
            accum["loss"] += float(loss)
            accum["nll"] += float(nll.sum())
            accum["count"] += int(count.sum())
        # optimizer.target is the dummy Reporter attached in train(); the
        # stats land in the chainer observation under 'main/...'.
        for k, v in accum.items():
            reporter.report({k: v}, optimizer.target)
        if self.gradclip is not None:
            nn.utils.clip_grad_norm_(self.model.parameters(), self.gradclip)
        optimizer.step()  # Update the parameters
        self.scheduler.step(n_iter=self.iteration)
class LMEvaluator(BaseEvaluator):
    """A custom evaluator for a pytorch LM."""

    def __init__(self, val_iter, eval_model, reporter, device):
        """Initialize class.

        :param chainer.dataset.Iterator val_iter : The validation iterator
        :param LMInterface eval_model : The model to evaluate
        :param chainer.Reporter reporter : The observations reporter
        :param int device : The device id to use
        """
        # NOTE(review): device=-1 is handed to the base class; the actual
        # device list is kept on self.device and used in evaluate().
        super(LMEvaluator, self).__init__(val_iter, reporter, device=-1)
        self.model = eval_model
        self.device = device

    def evaluate(self):
        """Evaluate the model."""
        val_iter = self.get_iterator("main")
        total_loss = 0
        total_nll = 0
        total_count = 0
        self.model.eval()
        with torch.no_grad():
            # Iterate a shallow copy so the iterator itself is not consumed.
            for batch in copy.copy(val_iter):
                x, t = concat_examples(batch, device=self.device[0], padding=(0, -100))
                if self.device[0] == -1:
                    batch_loss, batch_nll, batch_count = self.model(x, t)
                else:
                    # apex does not support torch.nn.DataParallel
                    batch_loss, batch_nll, batch_count = data_parallel(
                        self.model, (x, t), self.device
                    )
                total_loss += float(batch_loss.sum())
                total_nll += float(batch_nll.sum())
                total_count += int(batch_count.sum())
        self.model.train()
        # report validation loss
        observation = {}
        with reporter.report_scope(observation):
            reporter.report({"loss": total_loss}, self.model.reporter)
            reporter.report({"nll": total_nll}, self.model.reporter)
            reporter.report({"count": total_count}, self.model.reporter)
        return observation
def train(args):
    """Train an LM with the given args.

    Builds the model, optimizer and chainer trainer, runs training with
    gradient accumulation and optional apex/amp, and finally evaluates the
    best model on the test set if one is given.

    :param Namespace args: The program arguments
    """
    model_class = dynamic_import_lm(args.model_module, args.backend)
    assert issubclass(model_class, LMInterface), "model should implement LMInterface"
    # display torch version
    logging.info("torch version = " + torch.__version__)
    set_deterministic_pytorch(args)
    # check cuda and cudnn availability
    if not torch.cuda.is_available():
        logging.warning("cuda is not available")
    # get special label ids
    unk = args.char_list_dict["<unk>"]
    eos = args.char_list_dict["<eos>"]
    # read tokens as a sequence of sentences
    val, n_val_tokens, n_val_oovs = load_dataset(
        args.valid_label, args.char_list_dict, args.dump_hdf5_path
    )
    train, n_train_tokens, n_train_oovs = load_dataset(
        args.train_label, args.char_list_dict, args.dump_hdf5_path
    )
    logging.info("#vocab = " + str(args.n_vocab))
    logging.info("#sentences in the training data = " + str(len(train)))
    logging.info("#tokens in the training data = " + str(n_train_tokens))
    logging.info(
        "oov rate in the training data = %.2f %%"
        % (n_train_oovs / n_train_tokens * 100)
    )
    logging.info("#sentences in the validation data = " + str(len(val)))
    logging.info("#tokens in the validation data = " + str(n_val_tokens))
    logging.info(
        "oov rate in the validation data = %.2f %%" % (n_val_oovs / n_val_tokens * 100)
    )
    # sortagrad == -1 means "sorted for all epochs"; > 0 means sorted for
    # that many initial epochs (shuffling is enabled later, below).
    use_sortagrad = args.sortagrad == -1 or args.sortagrad > 0
    # Create the dataset iterators
    batch_size = args.batchsize * max(args.ngpu, 1)
    # NOTE(review): this condition is true whenever ngpu > 1 or
    # accum_grad > 1; the logged effective batch size is
    # batchsize * ngpu * accum_grad — confirm that is the intent.
    if batch_size * args.accum_grad > args.batchsize:
        logging.info(
            f"batch size is automatically increased "
            f"({args.batchsize} -> {batch_size * args.accum_grad})"
        )
    # <eos> doubles as the start-of-sentence symbol (sos=eos).
    train_iter = ParallelSentenceIterator(
        train,
        batch_size,
        max_length=args.maxlen,
        sos=eos,
        eos=eos,
        shuffle=not use_sortagrad,
    )
    val_iter = ParallelSentenceIterator(
        val, batch_size, max_length=args.maxlen, sos=eos, eos=eos, repeat=False
    )
    # One trainer iteration consumes accum_grad minibatches (see BPTTUpdater).
    epoch_iters = int(len(train_iter.batch_indices) / args.accum_grad)
    logging.info("#iterations per epoch = %d" % epoch_iters)
    logging.info("#total iterations = " + str(args.epoch * epoch_iters))
    # Prepare an RNNLM model
    if args.train_dtype in ("float16", "float32", "float64"):
        dtype = getattr(torch, args.train_dtype)
    else:
        # apex opt-levels ("O0".."O3") or anything else: keep float32 here
        # and let amp.initialize handle the casting below.
        dtype = torch.float32
    model = model_class(args.n_vocab, args).to(dtype=dtype)
    if args.ngpu > 0:
        model.to("cuda")
        gpu_id = list(range(args.ngpu))
    else:
        gpu_id = [-1]
    # Save model conf to json
    model_conf = args.outdir + "/model.json"
    with open(model_conf, "wb") as f:
        logging.info("writing a model config file to " + model_conf)
        f.write(
            json.dumps(vars(args), indent=4, ensure_ascii=False, sort_keys=True).encode(
                "utf_8"
            )
        )
    logging.warning(
        "num. model params: {:,} (num. trained: {:,} ({:.1f}%))".format(
            sum(p.numel() for p in model.parameters()),
            sum(p.numel() for p in model.parameters() if p.requires_grad),
            sum(p.numel() for p in model.parameters() if p.requires_grad)
            * 100.0
            / sum(p.numel() for p in model.parameters()),
        )
    )
    # Set up an optimizer
    opt_class = dynamic_import_optimizer(args.opt, args.backend)
    optimizer = opt_class.from_args(model.parameters(), args)
    if args.schedulers is None:
        schedulers = []
    else:
        schedulers = [dynamic_import_scheduler(v)(k, args) for k, v in args.schedulers]
    # setup apex.amp
    if args.train_dtype in ("O0", "O1", "O2", "O3"):
        try:
            from apex import amp
        except ImportError as e:
            logging.error(
                f"You need to install apex for --train-dtype {args.train_dtype}. "
                "See https://github.com/NVIDIA/apex#linux"
            )
            raise e
        model, optimizer = amp.initialize(model, optimizer, opt_level=args.train_dtype)
        use_apex = True
    else:
        use_apex = False
    # FIXME: TOO DIRTY HACK
    # Attach a dummy Reporter so the chainer trainer/updater can report to
    # model.reporter / optimizer.target (see BPTTUpdater and LMEvaluator).
    reporter = Reporter()
    setattr(model, "reporter", reporter)
    setattr(optimizer, "target", reporter)
    setattr(optimizer, "serialize", lambda s: reporter.serialize(s))
    updater = BPTTUpdater(
        train_iter,
        model,
        optimizer,
        schedulers,
        gpu_id,
        gradclip=args.gradclip,
        use_apex=use_apex,
        accum_grad=args.accum_grad,
    )
    trainer = training.Trainer(updater, (args.epoch, "epoch"), out=args.outdir)
    trainer.extend(LMEvaluator(val_iter, model, reporter, device=gpu_id))
    trainer.extend(
        extensions.LogReport(
            postprocess=compute_perplexity,
            trigger=(args.report_interval_iters, "iteration"),
        )
    )
    trainer.extend(
        extensions.PrintReport(
            [
                "epoch",
                "iteration",
                "main/loss",
                "perplexity",
                "val_perplexity",
                "elapsed_time",
            ]
        ),
        trigger=(args.report_interval_iters, "iteration"),
    )
    trainer.extend(extensions.ProgressBar(update_interval=args.report_interval_iters))
    # Save best models
    trainer.extend(torch_snapshot(filename="snapshot.ep.{.updater.epoch}"))
    trainer.extend(snapshot_object(model, "rnnlm.model.{.updater.epoch}"))
    # T.Hori: MinValueTrigger should be used, but it fails when resuming
    trainer.extend(MakeSymlinkToBestModel("validation/main/loss", "rnnlm.model"))
    if use_sortagrad:
        # Switch from length-sorted batches to shuffled batches after the
        # configured number of epochs (all epochs stay sorted when -1).
        trainer.extend(
            ShufflingEnabler([train_iter]),
            trigger=(args.sortagrad if args.sortagrad != -1 else args.epoch, "epoch"),
        )
    if args.resume:
        logging.info("resumed from %s" % args.resume)
        torch_resume(args.resume, trainer)
    set_early_stop(trainer, args, is_lm=True)
    if args.tensorboard_dir is not None and args.tensorboard_dir != "":
        from torch.utils.tensorboard import SummaryWriter

        writer = SummaryWriter(args.tensorboard_dir)
        trainer.extend(
            TensorboardLogger(writer), trigger=(args.report_interval_iters, "iteration")
        )
    trainer.run()
    check_early_stop(trainer, args.epoch)
    # compute perplexity for test set
    if args.test_label:
        logging.info("test the best model")
        # "rnnlm.model.best" is the symlink maintained by MakeSymlinkToBestModel.
        torch_load(args.outdir + "/rnnlm.model.best", model)
        test = read_tokens(args.test_label, args.char_list_dict)
        n_test_tokens, n_test_oovs = count_tokens(test, unk)
        logging.info("#sentences in the test data = " + str(len(test)))
        logging.info("#tokens in the test data = " + str(n_test_tokens))
        logging.info(
            "oov rate in the test data = %.2f %%" % (n_test_oovs / n_test_tokens * 100)
        )
        test_iter = ParallelSentenceIterator(
            test, batch_size, max_length=args.maxlen, sos=eos, eos=eos, repeat=False
        )
        evaluator = LMEvaluator(test_iter, model, reporter, device=gpu_id)
        result = evaluator()
        compute_perplexity(result)
        logging.info(f"test perplexity: {result['perplexity']}")