# Source code for espnet.lm.pytorch_backend.lm

#!/usr/bin/env python3
# Copyright 2017 Johns Hopkins University (Shinji Watanabe)
#  Apache 2.0  (http://www.apache.org/licenses/LICENSE-2.0)
# This code is ported from the following implementation written in Torch.
# https://github.com/chainer/chainer/blob/master/examples/ptb/train_ptb_custom_loop.py

"""LM training in pytorch."""

import copy
import json
import logging

import numpy as np
import torch
import torch.nn as nn
from chainer import Chain, reporter, training
from chainer.dataset import convert
from chainer.training import extensions
from torch.nn.parallel import data_parallel

from espnet.asr.asr_utils import (
    snapshot_object,
    torch_load,
    torch_resume,
    torch_snapshot,
)
from espnet.lm.lm_utils import (
    MakeSymlinkToBestModel,
    ParallelSentenceIterator,
    count_tokens,
    load_dataset,
    read_tokens,
)
from espnet.nets.lm_interface import LMInterface, dynamic_import_lm
from espnet.optimizer.factory import dynamic_import_optimizer
from espnet.scheduler.pytorch import PyTorchScheduler
from espnet.scheduler.scheduler import dynamic_import_scheduler
from espnet.utils.deterministic_utils import set_deterministic_pytorch
from espnet.utils.training.evaluator import BaseEvaluator
from espnet.utils.training.iterators import ShufflingEnabler
from espnet.utils.training.tensorboard_logger import TensorboardLogger
from espnet.utils.training.train_utils import check_early_stop, set_early_stop


def compute_perplexity(result):
    """Add perplexity values to a LogReport observation dictionary.

    Perplexity is computed as ``exp(nll / count)``. The training value is
    stored under ``"perplexity"``. If validation statistics are present, the
    validation value is also stored, under ``"val_perplexity"``.

    :param dict result: The current observations; updated in place
    """
    train_nll = result["main/nll"]
    train_count = result["main/count"]
    result["perplexity"] = np.exp(train_nll / train_count)

    # Validation entries exist only on iterations where the evaluator ran.
    if "validation/main/nll" not in result:
        return
    result["val_perplexity"] = np.exp(
        result["validation/main/nll"] / result["validation/main/count"]
    )
class Reporter(Chain):
    """Placeholder chainer link used as the observer target for statistics.

    It defines no links or parameters of its own. In this module the updater
    reports to it (as ``optimizer.target``), and the evaluator reports to it
    (as ``model.reporter``).
    """

    def report(self, loss):
        """Accept a loss value and ignore it (intentional no-op)."""
        return None
def concat_examples(batch, device=None, padding=None):
    """Concatenate a minibatch into an (inputs, targets) pair of torch tensors.

    :param np.ndarray batch: The batch to concatenate
    :param int device: CUDA device id to move the tensors to; ``None`` or a
        negative value keeps them on the CPU
    :param Tuple[int,int] padding: Padding values forwarded to chainer's
        ``convert.concat_examples``
    :return: (inputs, targets)
    :rtype: (torch.Tensor, torch.Tensor)
    """
    inputs_np, targets_np = convert.concat_examples(batch, padding=padding)
    inputs = torch.from_numpy(inputs_np)
    targets = torch.from_numpy(targets_np)
    if device is None or device < 0:
        return inputs, targets
    return inputs.cuda(device), targets.cuda(device)
class BPTTUpdater(training.StandardUpdater):
    """An updater for a pytorch LM."""

    def __init__(
        self,
        train_iter,
        model,
        optimizer,
        schedulers,
        device,
        gradclip=None,
        use_apex=False,
        accum_grad=1,
    ):
        """Initialize class.

        Args:
            train_iter (chainer.dataset.Iterator): The train iterator
            model (LMInterface) : The model to update
            optimizer (torch.optim.Optimizer): The optimizer for training
            schedulers (espnet.scheduler.scheduler.SchedulerInterface):
                The schedulers of `optimizer`
            device (list[int]): The device ids; ``[-1]`` runs on the CPU,
                otherwise the list is passed to ``data_parallel``
            gradclip (float): The gradient clipping value to use
            use_apex (bool): The flag to use Apex in backprop.
            accum_grad (int): The number of gradient accumulation.

        """
        super(BPTTUpdater, self).__init__(train_iter, optimizer)
        self.model = model
        self.device = device
        self.gradclip = gradclip
        self.use_apex = use_apex
        self.scheduler = PyTorchScheduler(schedulers, optimizer)
        self.accum_grad = accum_grad

    # The core part of the update routine can be customized by overriding.
    def update_core(self):
        """Run one parameter update over ``accum_grad`` minibatches.

        Each minibatch loss is averaged, divided by ``accum_grad`` and
        backpropagated, so gradients accumulate into a mean over the
        minibatches. The summed loss, NLL and token count are reported to
        ``optimizer.target``. Gradients are then optionally clipped, and the
        optimizer and schedulers are stepped.
        """
        # When we pass one iterator and optimizer to StandardUpdater.__init__,
        # they are automatically named 'main'.
        train_iter = self.get_iterator("main")
        optimizer = self.get_optimizer("main")
        if self.use_apex:
            # Imported lazily so that apex is only required when it is used.
            from apex import amp

        self.model.zero_grad()  # Clear the parameter gradients
        accum = {"loss": 0.0, "nll": 0.0, "count": 0}
        for _ in range(self.accum_grad):
            # Progress the dataset iterator for sentences at each iteration.
            batch = next(train_iter)
            # Targets are padded with -100, presumably the loss's
            # ignore_index (confirm in the model implementation).
            x, t = concat_examples(batch, device=self.device[0], padding=(0, -100))
            if self.device[0] == -1:
                loss, nll, count = self.model(x, t)
            else:
                # apex does not support torch.nn.DataParallel
                loss, nll, count = data_parallel(self.model, (x, t), self.device)

            # backward
            loss = loss.mean() / self.accum_grad
            if self.use_apex:
                with amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
            else:
                loss.backward()  # Backprop

            # accumulate stats
            accum["loss"] += float(loss)
            accum["nll"] += float(nll.sum())
            accum["count"] += int(count.sum())

        for k, v in accum.items():
            reporter.report({k: v}, optimizer.target)

        if self.gradclip is not None:
            if self.use_apex:
                # With amp, the optimizer may own FP32 master copies that
                # differ from model.parameters() (e.g. opt level O2). Those
                # master params are the ones stepped, so clip them.
                params = amp.master_params(optimizer)
            else:
                params = self.model.parameters()
            nn.utils.clip_grad_norm_(params, self.gradclip)

        optimizer.step()  # Update the parameters
        self.scheduler.step(n_iter=self.iteration)
class LMEvaluator(BaseEvaluator):
    """Evaluator that measures a pytorch LM over a chainer iterator."""

    def __init__(self, val_iter, eval_model, reporter, device):
        """Store the model and the device list used for evaluation.

        :param chainer.dataset.Iterator val_iter : The validation iterator
        :param LMInterface eval_model : The model to evaluate
        :param chainer.Reporter reporter : The observations reporter
        :param list[int] device : The device ids; ``[-1]`` runs on the CPU
        """
        # The base class always gets device=-1; data placement is done in
        # evaluate() through concat_examples / data_parallel and self.device.
        super(LMEvaluator, self).__init__(val_iter, reporter, device=-1)
        self.model = eval_model
        self.device = device

    def evaluate(self):
        """Run the model over the iterator and report summed statistics.

        :return: The observation dict holding the summed ``loss``, ``nll`` and
            ``count`` reported for ``self.model.reporter``
        :rtype: dict
        """
        val_iter = self.get_iterator("main")
        totals = {"loss": 0, "nll": 0, "count": 0}

        self.model.eval()
        with torch.no_grad():
            # Iterates a shallow copy, presumably so the stored iterator's
            # position is left untouched between evaluations.
            for batch in copy.copy(val_iter):
                x, t = concat_examples(batch, device=self.device[0], padding=(0, -100))
                if self.device[0] != -1:
                    # apex does not support torch.nn.DataParallel
                    outputs = data_parallel(self.model, (x, t), self.device)
                else:
                    outputs = self.model(x, t)
                batch_loss, batch_nll, batch_count = outputs
                totals["loss"] += float(batch_loss.sum())
                totals["nll"] += float(batch_nll.sum())
                totals["count"] += int(batch_count.sum())
        self.model.train()

        # report validation loss
        observation = {}
        with reporter.report_scope(observation):
            for key in ("loss", "nll", "count"):
                reporter.report({key: totals[key]}, self.model.reporter)
        return observation
def train(args):
    """Train an LM with the given args.

    Steps:

    1. Import the model class named by ``args.model_module``.
    2. Load the training and validation data.
    3. Build the model, optimizer and schedulers; apex amp is initialized
       when ``args.train_dtype`` is an apex opt level.
    4. Run a chainer trainer.
    5. If ``args.test_label`` is set, log the test-set perplexity of the
       best saved model.

    :param Namespace args: The program arguments
    """
    model_class = dynamic_import_lm(args.model_module, args.backend)
    assert issubclass(model_class, LMInterface), "model should implement LMInterface"
    # display torch version
    logging.info("torch version = " + torch.__version__)

    set_deterministic_pytorch(args)

    # check cuda and cudnn availability
    if not torch.cuda.is_available():
        logging.warning("cuda is not available")

    # get special label ids
    unk = args.char_list_dict["<unk>"]
    eos = args.char_list_dict["<eos>"]
    # read tokens as a sequence of sentences
    val, n_val_tokens, n_val_oovs = load_dataset(
        args.valid_label, args.char_list_dict, args.dump_hdf5_path
    )
    train, n_train_tokens, n_train_oovs = load_dataset(
        args.train_label, args.char_list_dict, args.dump_hdf5_path
    )
    logging.info("#vocab = " + str(args.n_vocab))
    logging.info("#sentences in the training data = " + str(len(train)))
    logging.info("#tokens in the training data = " + str(n_train_tokens))
    logging.info(
        "oov rate in the training data = %.2f %%"
        % (n_train_oovs / n_train_tokens * 100)
    )
    logging.info("#sentences in the validation data = " + str(len(val)))
    logging.info("#tokens in the validation data = " + str(n_val_tokens))
    logging.info(
        "oov rate in the validation data = %.2f %%"
        % (n_val_oovs / n_val_tokens * 100)
    )

    # sortagrad == -1 or > 0 enables it: the training iterator then starts
    # unshuffled. ShufflingEnabler (registered below) is triggered after
    # args.sortagrad epochs, or after args.epoch epochs when it is -1;
    # presumably it turns shuffling back on.
    use_sortagrad = args.sortagrad == -1 or args.sortagrad > 0
    # Create the dataset iterators
    # The per-GPU batch size is scaled by the number of GPUs. Each update
    # also consumes accum_grad batches, hence the logged effective size.
    batch_size = args.batchsize * max(args.ngpu, 1)
    if batch_size * args.accum_grad > args.batchsize:
        logging.info(
            f"batch size is automatically increased "
            f"({args.batchsize} -> {batch_size * args.accum_grad})"
        )
    # The <eos> id is used as both the start and end symbol of each sentence.
    train_iter = ParallelSentenceIterator(
        train,
        batch_size,
        max_length=args.maxlen,
        sos=eos,
        eos=eos,
        shuffle=not use_sortagrad,
    )
    val_iter = ParallelSentenceIterator(
        val, batch_size, max_length=args.maxlen, sos=eos, eos=eos, repeat=False
    )
    # One optimizer update consumes accum_grad batches.
    epoch_iters = int(len(train_iter.batch_indices) / args.accum_grad)
    logging.info("#iterations per epoch = %d" % epoch_iters)
    logging.info("#total iterations = " + str(args.epoch * epoch_iters))
    # Prepare an RNNLM model
    # train_dtype is either a torch float type name or an apex opt level
    # ("O0".."O3"). Apex levels keep float32 here and are handled below.
    if args.train_dtype in ("float16", "float32", "float64"):
        dtype = getattr(torch, args.train_dtype)
    else:
        dtype = torch.float32
    model = model_class(args.n_vocab, args).to(dtype=dtype)
    # gpu_id is a device list: [-1] means CPU; otherwise it is passed to
    # data_parallel by the updater and the evaluator.
    if args.ngpu > 0:
        model.to("cuda")
        gpu_id = list(range(args.ngpu))
    else:
        gpu_id = [-1]

    # Save model conf to json
    model_conf = args.outdir + "/model.json"
    with open(model_conf, "wb") as f:
        logging.info("writing a model config file to " + model_conf)
        f.write(
            json.dumps(vars(args), indent=4, ensure_ascii=False, sort_keys=True).encode(
                "utf_8"
            )
        )

    logging.warning(
        "num. model params: {:,} (num. trained: {:,} ({:.1f}%))".format(
            sum(p.numel() for p in model.parameters()),
            sum(p.numel() for p in model.parameters() if p.requires_grad),
            sum(p.numel() for p in model.parameters() if p.requires_grad)
            * 100.0
            / sum(p.numel() for p in model.parameters()),
        )
    )

    # Set up an optimizer
    opt_class = dynamic_import_optimizer(args.opt, args.backend)
    optimizer = opt_class.from_args(model.parameters(), args)
    # args.schedulers is iterated as (k, v) pairs: v selects the scheduler
    # class, and k is passed to its constructor together with args.
    if args.schedulers is None:
        schedulers = []
    else:
        schedulers = [dynamic_import_scheduler(v)(k, args) for k, v in args.schedulers]

    # setup apex.amp
    if args.train_dtype in ("O0", "O1", "O2", "O3"):
        try:
            from apex import amp
        except ImportError as e:
            logging.error(
                f"You need to install apex for --train-dtype {args.train_dtype}. "
                "See https://github.com/NVIDIA/apex#linux"
            )
            raise e
        # amp.initialize returns wrapped model/optimizer; both are rebound.
        model, optimizer = amp.initialize(model, optimizer, opt_level=args.train_dtype)
        use_apex = True
    else:
        use_apex = False

    # FIXME: TOO DIRTY HACK
    # A single Reporter link is shared as the chainer observer target:
    # - the updater reports to optimizer.target;
    # - the evaluator reports to model.reporter;
    # - optimizer.serialize delegates to the reporter. NOTE(review): this is
    #   presumably so chainer snapshot serialization can handle the torch
    #   optimizer; confirm.
    reporter = Reporter()
    setattr(model, "reporter", reporter)
    setattr(optimizer, "target", reporter)
    setattr(optimizer, "serialize", lambda s: reporter.serialize(s))

    updater = BPTTUpdater(
        train_iter,
        model,
        optimizer,
        schedulers,
        gpu_id,
        gradclip=args.gradclip,
        use_apex=use_apex,
        accum_grad=args.accum_grad,
    )
    trainer = training.Trainer(updater, (args.epoch, "epoch"), out=args.outdir)
    trainer.extend(LMEvaluator(val_iter, model, reporter, device=gpu_id))
    # compute_perplexity adds "perplexity"/"val_perplexity" to each log entry.
    trainer.extend(
        extensions.LogReport(
            postprocess=compute_perplexity,
            trigger=(args.report_interval_iters, "iteration"),
        )
    )
    trainer.extend(
        extensions.PrintReport(
            [
                "epoch",
                "iteration",
                "main/loss",
                "perplexity",
                "val_perplexity",
                "elapsed_time",
            ]
        ),
        trigger=(args.report_interval_iters, "iteration"),
    )
    trainer.extend(extensions.ProgressBar(update_interval=args.report_interval_iters))
    # Save best models
    trainer.extend(torch_snapshot(filename="snapshot.ep.{.updater.epoch}"))
    trainer.extend(snapshot_object(model, "rnnlm.model.{.updater.epoch}"))
    # T.Hori: MinValueTrigger should be used, but it fails when resuming
    trainer.extend(MakeSymlinkToBestModel("validation/main/loss", "rnnlm.model"))

    if use_sortagrad:
        trainer.extend(
            ShufflingEnabler([train_iter]),
            trigger=(args.sortagrad if args.sortagrad != -1 else args.epoch, "epoch"),
        )
    if args.resume:
        logging.info("resumed from %s" % args.resume)
        torch_resume(args.resume, trainer)

    set_early_stop(trainer, args, is_lm=True)
    if args.tensorboard_dir is not None and args.tensorboard_dir != "":
        from torch.utils.tensorboard import SummaryWriter

        writer = SummaryWriter(args.tensorboard_dir)
        trainer.extend(
            TensorboardLogger(writer), trigger=(args.report_interval_iters, "iteration")
        )

    trainer.run()
    check_early_stop(trainer, args.epoch)

    # compute perplexity for test set
    if args.test_label:
        logging.info("test the best model")
        # NOTE(review): "rnnlm.model.best" is presumably the link written by
        # MakeSymlinkToBestModel above; confirm.
        torch_load(args.outdir + "/rnnlm.model.best", model)
        test = read_tokens(args.test_label, args.char_list_dict)
        n_test_tokens, n_test_oovs = count_tokens(test, unk)
        logging.info("#sentences in the test data = " + str(len(test)))
        logging.info("#tokens in the test data = " + str(n_test_tokens))
        logging.info(
            "oov rate in the test data = %.2f %%" % (n_test_oovs / n_test_tokens * 100)
        )
        test_iter = ParallelSentenceIterator(
            test, batch_size, max_length=args.maxlen, sos=eos, eos=eos, repeat=False
        )
        # The evaluator is called directly, not registered with the trainer.
        # NOTE(review): this relies on its result using "main/..." keys (no
        # "validation/" prefix), which compute_perplexity reads.
        evaluator = LMEvaluator(test_iter, model, reporter, device=gpu_id)
        result = evaluator()
        compute_perplexity(result)
        logging.info(f"test perplexity: {result['perplexity']}")