# Source code for espnet.asr.chainer_backend.asr

# Copyright 2017 Johns Hopkins University (Shinji Watanabe)
#  Apache 2.0  (http://www.apache.org/licenses/LICENSE-2.0)

"""Training/decoding definition for the speech recognition task."""

import json
import logging
import os

# chainer related
import chainer
from chainer import training
from chainer.datasets import TransformDataset
from chainer.training import extensions

# rnnlm
import espnet.lm.chainer_backend.extlm as extlm_chainer
import espnet.lm.chainer_backend.lm as lm_chainer

# espnet related
from espnet.asr.asr_utils import (
    CompareValueTrigger,
    adadelta_eps_decay,
    add_results_to_json,
    chainer_load,
    get_model_conf,
    restore_snapshot,
)
from espnet.nets.asr_interface import ASRInterface
from espnet.utils.deterministic_utils import set_deterministic_chainer
from espnet.utils.dynamic_import import dynamic_import
from espnet.utils.io_utils import LoadInputsAndTargets
from espnet.utils.training.batchfy import make_batchset
from espnet.utils.training.evaluator import BaseEvaluator
from espnet.utils.training.iterators import (
    ShufflingEnabler,
    ToggleableShufflingMultiprocessIterator,
    ToggleableShufflingSerialIterator,
)
from espnet.utils.training.tensorboard_logger import TensorboardLogger
from espnet.utils.training.train_utils import check_early_stop, set_early_stop


def train(args):
    """Train an ASR model with the chainer backend.

    Reads dimension info from ``args.valid_json``, builds the model class named by
    ``args.model_module``, writes ``<outdir>/model.json``, and sets up an optimizer
    (adadelta / adam / noam), training and validation iterators (single or
    multi-GPU), and trainer extensions (evaluation, snapshots, plots, eps decay,
    logging, tensorboard). Then it runs the chainer ``Trainer``.

    Args:
        args (namespace): The program arguments.

    Raises:
        NotImplementedError: If ``args.opt`` is not a supported optimizer, or if
            bin/frame-based batching is requested in multi-GPU mode.
        ValueError: If multi-GPU mode leaves a GPU with no training utterances.

    """
    # display chainer version
    logging.info("chainer version = " + chainer.__version__)

    set_deterministic_chainer(args)

    # check cuda and cudnn availability
    if not chainer.cuda.available:
        logging.warning("cuda is not available")
    if not chainer.cuda.cudnn_enabled:
        logging.warning("cudnn is not available")

    # get input and output dimension info
    # (the parsed validation json is reused below instead of re-reading the file)
    with open(args.valid_json, "rb") as f:
        valid_json = json.load(f)["utts"]
    utts = list(valid_json.keys())
    idim = int(valid_json[utts[0]]["input"][0]["shape"][1])
    odim = int(valid_json[utts[0]]["output"][0]["shape"][1])
    logging.info("#input dims : " + str(idim))
    logging.info("#output dims: " + str(odim))

    # specify attention, CTC, hybrid mode
    if args.mtlalpha == 1.0:
        mtl_mode = "ctc"
        logging.info("Pure CTC mode")
    elif args.mtlalpha == 0.0:
        mtl_mode = "att"
        logging.info("Pure attention mode")
    else:
        mtl_mode = "mtl"
        logging.info("Multitask learning mode")

    # specify model architecture
    logging.info("import model module: " + args.model_module)
    model_class = dynamic_import(args.model_module)
    model = model_class(idim, odim, args, flag_return=False)
    assert isinstance(model, ASRInterface)
    total_subsampling_factor = model.get_total_subsampling_factor()

    # write model config (exist_ok avoids the check-then-create race)
    os.makedirs(args.outdir, exist_ok=True)
    model_conf = args.outdir + "/model.json"
    with open(model_conf, "wb") as f:
        logging.info("writing a model config file to " + model_conf)
        f.write(
            json.dumps(
                (idim, odim, vars(args)), indent=4, ensure_ascii=False, sort_keys=True
            ).encode("utf_8")
        )
    for key in sorted(vars(args).keys()):
        logging.info("ARGS: " + key + ": " + str(vars(args)[key]))

    # Set gpu
    ngpu = args.ngpu
    if ngpu == 1:
        gpu_id = 0
        # Make a specified GPU current
        chainer.cuda.get_device_from_id(gpu_id).use()
        model.to_gpu()  # Copy the model to the GPU
        logging.info("single gpu calculation.")
    elif ngpu > 1:
        gpu_id = 0
        # device map consumed by model.custom_parallel_updater below
        devices = {"main": gpu_id}
        for gid in range(1, ngpu):
            devices["sub_%d" % gid] = gid
        logging.info("multi gpu calculation (#gpus = %d)." % ngpu)
        logging.warning(
            "batch size is automatically increased (%d -> %d)"
            % (args.batch_size, args.batch_size * args.ngpu)
        )
    else:
        gpu_id = -1
        logging.info("cpu calculation")

    # Setup an optimizer
    if args.opt == "adadelta":
        optimizer = chainer.optimizers.AdaDelta(eps=args.eps)
    elif args.opt == "adam":
        optimizer = chainer.optimizers.Adam()
    elif args.opt == "noam":
        # alpha starts at 0; it is driven by the VaswaniRule extension added below
        optimizer = chainer.optimizers.Adam(alpha=0, beta1=0.9, beta2=0.98, eps=1e-9)
    else:
        raise NotImplementedError("args.opt={}".format(args.opt))

    optimizer.setup(model)
    optimizer.add_hook(chainer.optimizer.GradientClipping(args.grad_clip))

    # Setup a converter
    converter = model.custom_converter(subsampling_factor=model.subsample[0])

    # read json data (valid_json was already loaded above)
    with open(args.train_json, "rb") as f:
        train_json = json.load(f)["utts"]

    # set up training iterator and updater
    load_tr = LoadInputsAndTargets(
        mode="asr",
        load_output=True,
        preprocess_conf=args.preprocess_conf,
        preprocess_args={"train": True},  # Switch the mode of preprocessing
    )
    load_cv = LoadInputsAndTargets(
        mode="asr",
        load_output=True,
        preprocess_conf=args.preprocess_conf,
        preprocess_args={"train": False},  # Switch the mode of preprocessing
    )

    # sortagrad == -1 means "sort for all epochs"; > 0 means "sort for N epochs"
    use_sortagrad = args.sortagrad == -1 or args.sortagrad > 0
    accum_grad = args.accum_grad
    if ngpu <= 1:
        # make minibatch list (variable length)
        train = make_batchset(
            train_json,
            args.batch_size,
            args.maxlen_in,
            args.maxlen_out,
            args.minibatches,
            min_batch_size=args.ngpu if args.ngpu > 1 else 1,
            shortest_first=use_sortagrad,
            count=args.batch_count,
            batch_bins=args.batch_bins,
            batch_frames_in=args.batch_frames_in,
            batch_frames_out=args.batch_frames_out,
            batch_frames_inout=args.batch_frames_inout,
            iaxis=0,
            oaxis=0,
        )
        # hack to make batchsize argument as 1
        # actual batchsize is included in a list
        if args.n_iter_processes > 0:
            train_iters = [
                ToggleableShufflingMultiprocessIterator(
                    TransformDataset(train, load_tr),
                    batch_size=1,
                    n_processes=args.n_iter_processes,
                    n_prefetch=8,
                    maxtasksperchild=20,
                    shuffle=not use_sortagrad,
                )
            ]
        else:
            train_iters = [
                ToggleableShufflingSerialIterator(
                    TransformDataset(train, load_tr),
                    batch_size=1,
                    shuffle=not use_sortagrad,
                )
            ]

        # set up updater
        updater = model.custom_updater(
            train_iters[0],
            optimizer,
            converter=converter,
            device=gpu_id,
            accum_grad=accum_grad,
        )
    else:
        if args.batch_count not in ("auto", "seq") and args.batch_size == 0:
            raise NotImplementedError(
                "--batch-count 'bin' and 'frame' are not implemented "
                "in chainer multi gpu"
            )
        # set up minibatches
        train_subsets = []
        for gid in range(ngpu):
            # make subset (round-robin split of utterances over GPUs)
            train_json_subset = {
                k: v for i, (k, v) in enumerate(train_json.items()) if i % ngpu == gid
            }
            if not train_json_subset:
                raise ValueError(
                    "not enough training utterances (%d) to split across %d GPUs"
                    % (len(train_json), ngpu)
                )
            # make minibatch list (variable length)
            # shortest_first keeps sortagrad consistent with the single-GPU path
            train_subsets += [
                make_batchset(
                    train_json_subset,
                    args.batch_size,
                    args.maxlen_in,
                    args.maxlen_out,
                    args.minibatches,
                    shortest_first=use_sortagrad,
                )
            ]

        # each subset must have same length for MultiprocessParallelUpdater
        maxlen = max(len(train_subset) for train_subset in train_subsets)
        for train_subset in train_subsets:
            if maxlen != len(train_subset):
                # pad by repeating batches from the front of the (growing) list
                for i in range(maxlen - len(train_subset)):
                    train_subset += [train_subset[i]]

        # hack to make batchsize argument as 1
        # actual batchsize is included in a list
        if args.n_iter_processes > 0:
            train_iters = [
                ToggleableShufflingMultiprocessIterator(
                    TransformDataset(train_subsets[gid], load_tr),
                    batch_size=1,
                    n_processes=args.n_iter_processes,
                    n_prefetch=8,
                    maxtasksperchild=20,
                    shuffle=not use_sortagrad,
                )
                for gid in range(ngpu)
            ]
        else:
            train_iters = [
                ToggleableShufflingSerialIterator(
                    TransformDataset(train_subsets[gid], load_tr),
                    batch_size=1,
                    shuffle=not use_sortagrad,
                )
                for gid in range(ngpu)
            ]

        # set up updater
        updater = model.custom_parallel_updater(
            train_iters, optimizer, converter=converter, devices=devices
        )

    # Set up a trainer
    trainer = training.Trainer(updater, (args.epochs, "epoch"), out=args.outdir)

    if use_sortagrad:
        # re-enable shuffling once the sorted (sortagrad) epochs are over
        trainer.extend(
            ShufflingEnabler(train_iters),
            trigger=(args.sortagrad if args.sortagrad != -1 else args.epochs, "epoch"),
        )
    if args.opt == "noam":
        from espnet.nets.chainer_backend.transformer.training import VaswaniRule

        trainer.extend(
            VaswaniRule(
                "alpha",
                d=args.adim,
                warmup_steps=args.transformer_warmup_steps,
                scale=args.transformer_lr,
            ),
            trigger=(1, "iteration"),
        )
    # Resume from a snapshot
    if args.resume:
        chainer.serializers.load_npz(args.resume, trainer)

    # set up validation iterator
    valid = make_batchset(
        valid_json,
        args.batch_size,
        args.maxlen_in,
        args.maxlen_out,
        args.minibatches,
        min_batch_size=args.ngpu if args.ngpu > 1 else 1,
        count=args.batch_count,
        batch_bins=args.batch_bins,
        batch_frames_in=args.batch_frames_in,
        batch_frames_out=args.batch_frames_out,
        batch_frames_inout=args.batch_frames_inout,
        iaxis=0,
        oaxis=0,
    )

    if args.n_iter_processes > 0:
        valid_iter = chainer.iterators.MultiprocessIterator(
            TransformDataset(valid, load_cv),
            batch_size=1,
            repeat=False,
            shuffle=False,
            n_processes=args.n_iter_processes,
            n_prefetch=8,
            maxtasksperchild=20,
        )
    else:
        valid_iter = chainer.iterators.SerialIterator(
            TransformDataset(valid, load_cv), batch_size=1, repeat=False, shuffle=False
        )

    # Evaluate the model with the validation dataset for each epoch
    trainer.extend(BaseEvaluator(valid_iter, model, converter=converter, device=gpu_id))

    # Save attention weight each epoch
    if args.num_save_attention > 0 and args.mtlalpha != 1.0:
        data = sorted(
            list(valid_json.items())[: args.num_save_attention],
            key=lambda x: int(x[1]["input"][0]["shape"][1]),
            reverse=True,
        )
        if hasattr(model, "module"):
            att_vis_fn = model.module.calculate_all_attentions
            plot_class = model.module.attention_plot_class
        else:
            att_vis_fn = model.calculate_all_attentions
            plot_class = model.attention_plot_class
        logging.info("Using custom PlotAttentionReport")
        att_reporter = plot_class(
            att_vis_fn,
            data,
            args.outdir + "/att_ws",
            converter=converter,
            transform=load_cv,
            device=gpu_id,
            subsampling_factor=total_subsampling_factor,
        )
        trainer.extend(att_reporter, trigger=(1, "epoch"))
    else:
        att_reporter = None

    # Take a snapshot for each specified epoch
    trainer.extend(
        extensions.snapshot(filename="snapshot.ep.{.updater.epoch}"),
        trigger=(1, "epoch"),
    )

    # Make a plot for training and validation values
    trainer.extend(
        extensions.PlotReport(
            [
                "main/loss",
                "validation/main/loss",
                "main/loss_ctc",
                "validation/main/loss_ctc",
                "main/loss_att",
                "validation/main/loss_att",
            ],
            "epoch",
            file_name="loss.png",
        )
    )
    trainer.extend(
        extensions.PlotReport(
            ["main/acc", "validation/main/acc"], "epoch", file_name="acc.png"
        )
    )

    # Save best models
    trainer.extend(
        extensions.snapshot_object(model, "model.loss.best"),
        trigger=training.triggers.MinValueTrigger("validation/main/loss"),
    )
    if mtl_mode != "ctc":
        trainer.extend(
            extensions.snapshot_object(model, "model.acc.best"),
            trigger=training.triggers.MaxValueTrigger("validation/main/acc"),
        )

    # epsilon decay in the optimizer
    # (when the validation criterion gets worse, restore the best model and decay eps)
    if args.opt == "adadelta":
        if args.criterion == "acc" and mtl_mode != "ctc":
            trainer.extend(
                restore_snapshot(model, args.outdir + "/model.acc.best"),
                trigger=CompareValueTrigger(
                    "validation/main/acc",
                    lambda best_value, current_value: best_value > current_value,
                ),
            )
            trainer.extend(
                adadelta_eps_decay(args.eps_decay),
                trigger=CompareValueTrigger(
                    "validation/main/acc",
                    lambda best_value, current_value: best_value > current_value,
                ),
            )
        elif args.criterion == "loss":
            trainer.extend(
                restore_snapshot(model, args.outdir + "/model.loss.best"),
                trigger=CompareValueTrigger(
                    "validation/main/loss",
                    lambda best_value, current_value: best_value < current_value,
                ),
            )
            trainer.extend(
                adadelta_eps_decay(args.eps_decay),
                trigger=CompareValueTrigger(
                    "validation/main/loss",
                    lambda best_value, current_value: best_value < current_value,
                ),
            )

    # Write a log of evaluation statistics for each epoch
    trainer.extend(
        extensions.LogReport(trigger=(args.report_interval_iters, "iteration"))
    )
    report_keys = [
        "epoch",
        "iteration",
        "main/loss",
        "main/loss_ctc",
        "main/loss_att",
        "validation/main/loss",
        "validation/main/loss_ctc",
        "validation/main/loss_att",
        "main/acc",
        "validation/main/acc",
        "elapsed_time",
    ]
    if args.opt == "adadelta":
        trainer.extend(
            extensions.observe_value(
                "eps", lambda trainer: trainer.updater.get_optimizer("main").eps
            ),
            trigger=(args.report_interval_iters, "iteration"),
        )
        report_keys.append("eps")
    trainer.extend(
        extensions.PrintReport(report_keys),
        trigger=(args.report_interval_iters, "iteration"),
    )

    trainer.extend(extensions.ProgressBar(update_interval=args.report_interval_iters))

    set_early_stop(trainer, args)
    if args.tensorboard_dir is not None and args.tensorboard_dir != "":
        try:
            from tensorboardX import SummaryWriter
        except ImportError:
            logging.error("Please install tensorboardx")
            raise
        writer = SummaryWriter(args.tensorboard_dir)
        trainer.extend(
            TensorboardLogger(writer, att_reporter),
            trigger=(args.report_interval_iters, "iteration"),
        )

    # Run the training
    trainer.run()
    check_early_stop(trainer, args.epochs)
def recog(args):
    """Decode utterances with a trained chainer ASR model.

    Loads the model (and optionally a character and/or word RNNLM), decodes
    every utterance in ``args.recog_json`` and writes the n-best results to
    ``args.result_label`` as JSON.

    Args:
        args (namespace): The program arguments.

    """
    # display chainer version
    logging.info("chainer version = " + chainer.__version__)

    set_deterministic_chainer(args)

    # read training config
    idim, odim, train_args = get_model_conf(args.model, args.model_conf)

    for key in sorted(vars(args).keys()):
        logging.info("ARGS: " + key + ": " + str(vars(args)[key]))

    # specify model architecture
    logging.info("reading model parameters from " + args.model)
    # To be compatible with v.0.3.0 models (which did not store model_module)
    model_module = getattr(
        train_args, "model_module", "espnet.nets.chainer_backend.e2e_asr:E2E"
    )
    model_class = dynamic_import(model_module)
    model = model_class(idim, odim, train_args)
    assert isinstance(model, ASRInterface)
    chainer_load(args.model, model)

    # read rnnlm
    if args.rnnlm:
        rnnlm_args = get_model_conf(args.rnnlm, args.rnnlm_conf)
        rnnlm = lm_chainer.ClassifierWithState(
            lm_chainer.RNNLM(
                len(train_args.char_list), rnnlm_args.layer, rnnlm_args.unit
            )
        )
        chainer_load(args.rnnlm, rnnlm)
    else:
        rnnlm = None

    if args.word_rnnlm:
        rnnlm_args = get_model_conf(args.word_rnnlm, args.word_rnnlm_conf)
        word_dict = rnnlm_args.char_list_dict
        char_dict = {x: i for i, x in enumerate(train_args.char_list)}
        word_rnnlm = lm_chainer.ClassifierWithState(
            lm_chainer.RNNLM(len(word_dict), rnnlm_args.layer, rnnlm_args.unit)
        )
        chainer_load(args.word_rnnlm, word_rnnlm)

        # combine with the character LM if one was loaded, otherwise use
        # word-level look-ahead alone
        if rnnlm is not None:
            rnnlm = lm_chainer.ClassifierWithState(
                extlm_chainer.MultiLevelLM(
                    word_rnnlm.predictor, rnnlm.predictor, word_dict, char_dict
                )
            )
        else:
            rnnlm = lm_chainer.ClassifierWithState(
                extlm_chainer.LookAheadWordLM(
                    word_rnnlm.predictor, word_dict, char_dict
                )
            )

    # read json data
    with open(args.recog_json, "rb") as f:
        js = json.load(f)["utts"]

    load_inputs_and_targets = LoadInputsAndTargets(
        mode="asr",
        load_output=False,
        sort_in_input_length=False,
        preprocess_conf=(
            train_args.preprocess_conf
            if args.preprocess_conf is None
            else args.preprocess_conf
        ),
        preprocess_args={"train": False},  # Switch the mode of preprocessing
    )

    # decode each utterance
    new_js = {}
    n_utts = len(js)
    with chainer.no_backprop_mode():
        for idx, name in enumerate(js.keys(), 1):
            # pass the utterance id as an argument so a '%' in it cannot break
            # the log format string
            logging.info("(%d/%d) decoding %s", idx, n_utts, name)
            batch = [(name, js[name])]
            feat = load_inputs_and_targets(batch)[0][0]
            nbest_hyps = model.recognize(feat, args, train_args.char_list, rnnlm)
            new_js[name] = add_results_to_json(
                js[name], nbest_hyps, train_args.char_list
            )

    with open(args.result_label, "wb") as f:
        f.write(
            json.dumps(
                {"utts": new_js}, indent=4, ensure_ascii=False, sort_keys=True
            ).encode("utf_8")
        )