Source code for espnet.lm.lm_utils

#!/usr/bin/env python3

# Copyright 2017 Johns Hopkins University (Shinji Watanabe)
#  Apache 2.0  (

# This code is ported from the following implementation written in Torch.

import logging
import os
import random

import chainer
import h5py
import numpy as np
from chainer.training import extension
from tqdm import tqdm

def load_dataset(path, label_dict, outdir=None):
    """Load and save HDF5 that contains a dataset and stats for LM.

    Args:
        path (str): The path of an input text dataset file
        label_dict (dict[str, int]): dictionary that maps
            token label string to its ID number
        outdir (str): The path of an output dir

    Returns:
        tuple[list[np.ndarray], int, int]: Tuple of
            token IDs in np.int32 converted by `read_tokens`,
            the number of tokens by `count_tokens`,
            and the number of OOVs by `count_tokens`
    """
    if outdir is not None:
        os.makedirs(outdir, exist_ok=True)
        filename = outdir + "/" + os.path.basename(path) + ".h5"
        if os.path.exists(filename):
  "loading binary dataset: {filename}")
            f = h5py.File(filename, "r")
            return f["data"][:], f["n_tokens"][()], f["n_oovs"][()]
    else:
        logging.warning("skip dump/load HDF5 because the output dir is not specified")"reading text dataset: {path}")
    ret = read_tokens(path, label_dict)
    n_tokens, n_oovs = count_tokens(ret, label_dict["<unk>"])
    if outdir is not None:"saving binary dataset: {filename}")
        with h5py.File(filename, "w") as f:
            # store one variable-length int32 row per sentence
            data = f.create_dataset(
                "data", (len(ret),), dtype=h5py.special_dtype(vlen=np.int32)
            )
            data[:] = ret
            f["n_tokens"] = n_tokens
            f["n_oovs"] = n_oovs
    return ret, n_tokens, n_oovs
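
# Example (added for illustration, not part of the original ESPnet module).
# A minimal sketch of calling load_dataset; the file name "train.txt", the
# contents of label_dict, and the outdir "lm_cache" are assumptions.
def _example_load_dataset():
    label_dict = {"<unk>": 0, "<eos>": 1, "hello": 2, "world": 3}
    # The first call reads the text and dumps train.txt.h5 into ./lm_cache;
    # later calls load the binary cache instead of re-reading the text.
    data, n_tokens, n_oovs = load_dataset("train.txt", label_dict, outdir="lm_cache")
    print("sentences:", len(data), "tokens:", n_tokens, "oovs:", n_oovs)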

def read_tokens(filename, label_dict):
    """Read tokens as a sequence of sentences.

    :param str filename : The name of the input file
    :param dict label_dict : dictionary that maps token label string to its ID number
    :return list of ID sequences
    :rtype list
    """
    data = []
    unk = label_dict["<unk>"]
    for ln in tqdm(open(filename, "r", encoding="utf-8")):
        data.append(
            np.array(
                [label_dict.get(label, unk) for label in ln.split()], dtype=np.int32
            )
        )
    return data

def count_tokens(data, unk_id=None):
    """Count tokens and oovs in token ID sequences.

    Args:
        data (list[np.ndarray]): list of token ID sequences
        unk_id (int): ID of unknown token

    Returns:
        tuple: tuple of number of token occurrences and number of oov tokens
    """
    n_tokens = 0
    n_oovs = 0
    for sentence in data:
        n_tokens += len(sentence)
        if unk_id is not None:
            n_oovs += np.count_nonzero(sentence == unk_id)
    return n_tokens, n_oovs
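
# Example (added for illustration, not part of the original ESPnet module).
# A minimal sketch combining read_tokens and count_tokens; the file name and
# vocabulary are assumptions. Tokens missing from label_dict map to <unk>.
def _example_token_stats():
    label_dict = {"<unk>": 0, "<eos>": 1, "hello": 2, "world": 3}
    data = read_tokens("train.txt", label_dict)  # list of np.int32 ID arrays
    n_tokens, n_oovs = count_tokens(data, unk_id=label_dict["<unk>"])
    print("OOV rate: %.2f%%" % (100.0 * n_oovs / max(n_tokens, 1)))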

def compute_perplexity(result):
    """Compute and add the perplexity to the LogReport.

    :param dict result: The current observations
    """
    # Routine to rewrite the result dictionary of LogReport to add perplexity values
    result["perplexity"] = np.exp(result["main/loss"] / result["main/count"])
    if "validation/main/loss" in result:
        result["val_perplexity"] = np.exp(result["validation/main/loss"])
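
# Example (added for illustration, not part of the original ESPnet module).
# compute_perplexity mutates a LogReport-style observation dict in place; the
# numbers below are made up to show the arithmetic (exp of the mean loss).
def _example_compute_perplexity():
    result = {"main/loss": 460.0, "main/count": 100, "validation/main/loss": 4.2}
    compute_perplexity(result)
    print(result["perplexity"])      # np.exp(460.0 / 100) ~ 99.5
    print(result["val_perplexity"])  # np.exp(4.2) ~ 66.7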

class ParallelSentenceIterator(chainer.dataset.Iterator):
    """Dataset iterator to create a batch of sentences.

    This iterator returns a pair of sentences, where one token is shifted
    between the sentences like '<sos> w1 w2 w3' and 'w1 w2 w3 <eos>'.
    Sentence batches are made in order of longer sentences, and then
    randomly shuffled.
    """

    def __init__(
        self, dataset, batch_size, max_length=0, sos=0, eos=0, repeat=True, shuffle=True
    ):
        self.dataset = dataset
        self.batch_size = batch_size  # batch size
        # Number of completed sweeps over the dataset. In this case, it is
        # incremented if every word is visited at least once after the last
        # increment.
        self.epoch = 0
        # True if the epoch is incremented at the last iteration.
        self.is_new_epoch = False
        self.repeat = repeat
        length = len(dataset)
        self.batch_indices = []
        # make mini-batches
        if batch_size > 1:
            indices = sorted(range(len(dataset)), key=lambda i: -len(dataset[i]))
            bs = 0
            while bs < length:
                be = min(bs + batch_size, length)
                # batch size is automatically reduced if the sentence length
                # is larger than max_length
                if max_length > 0:
                    sent_length = len(dataset[indices[bs]])
                    be = min(
                        be, bs + max(batch_size // (sent_length // max_length + 1), 1)
                    )
                self.batch_indices.append(np.array(indices[bs:be]))
                bs = be
            if shuffle:
                # shuffle batches
                random.shuffle(self.batch_indices)
        else:
            self.batch_indices = [np.array([i]) for i in range(length)]

        # NOTE: this is not a count of parameter updates. It is just a count of
        # calls of ``__next__``.
        self.iteration = 0
        self.sos = sos
        self.eos = eos
        # use -1 instead of None internally
        self._previous_epoch_detail = -1.0

    def __next__(self):
        # This iterator returns a list representing a mini-batch. Each item
        # indicates a sentence pair like '<sos> w1 w2 w3' and 'w1 w2 w3 <eos>'
        # represented by token IDs.
        n_batches = len(self.batch_indices)
        if not self.repeat and self.iteration >= n_batches:
            # If not self.repeat, this iterator stops at the end of the first
            # epoch (i.e., when all words are visited once).
            raise StopIteration
        batch = []
        for idx in self.batch_indices[self.iteration % n_batches]:
            batch.append(
                (
                    np.append([self.sos], self.dataset[idx]),
                    np.append(self.dataset[idx], [self.eos]),
                )
            )

        self._previous_epoch_detail = self.epoch_detail
        self.iteration += 1

        epoch = self.iteration // n_batches
        self.is_new_epoch = self.epoch < epoch
        if self.is_new_epoch:
            self.epoch = epoch

        return batch

    def start_shuffle(self):
        random.shuffle(self.batch_indices)

    @property
    def epoch_detail(self):
        # Floating point version of epoch.
        return self.iteration / len(self.batch_indices)

    @property
    def previous_epoch_detail(self):
        if self._previous_epoch_detail < 0:
            return None
        return self._previous_epoch_detail

    def serialize(self, serializer):
        # It is important to serialize the state to be recovered on resume.
        self.iteration = serializer("iteration", self.iteration)
        self.epoch = serializer("epoch", self.epoch)
        try:
            self._previous_epoch_detail = serializer(
                "previous_epoch_detail", self._previous_epoch_detail
            )
        except KeyError:
            # guess previous_epoch_detail for older version
            self._previous_epoch_detail = self.epoch + (
                self.current_position - 1
            ) / len(self.batch_indices)
            if self.epoch_detail > 0:
                self._previous_epoch_detail = max(self._previous_epoch_detail, 0.0)
            else:
                self._previous_epoch_detail = -1.0
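
# Example (added for illustration, not part of the original ESPnet module).
# A minimal sketch of batching three toy sentences; the sos/eos IDs and the
# data are assumptions. Each batch item is an (input, target) pair of ID
# arrays shifted by one token.
def _example_parallel_sentence_iterator():
    dataset = [
        np.array([2, 3, 4], dtype=np.int32),
        np.array([5, 6], dtype=np.int32),
        np.array([7], dtype=np.int32),
    ]
    it = ParallelSentenceIterator(dataset, batch_size=2, sos=0, eos=1, repeat=False)
    for batch in it:  # stops after one epoch because repeat=False
        for x, t in batch:
            print(x, "->", t)  # e.g. [0 2 3 4] -> [2 3 4 1]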

class MakeSymlinkToBestModel(extension.Extension):
    """Extension that makes a symbolic link to the best model.

    :param str key: Key of value
    :param str prefix: Prefix of model files and link target
    :param str suffix: Suffix of link target
    """

    def __init__(self, key, prefix="model", suffix="best"):
        super(MakeSymlinkToBestModel, self).__init__()
        self.best_model = -1
        self.min_loss = 0.0
        self.key = key
        self.prefix = prefix
        self.suffix = suffix

    def __call__(self, trainer):
        observation = trainer.observation
        if self.key in observation:
            loss = observation[self.key]
            if self.best_model == -1 or loss < self.min_loss:
                self.min_loss = loss
                self.best_model = trainer.updater.epoch
                src = "%s.%d" % (self.prefix, self.best_model)
                dest = os.path.join(trainer.out, "%s.%s" % (self.prefix, self.suffix))
                if os.path.lexists(dest):
                    os.remove(dest)
                os.symlink(src, dest)
      "best model is " + src)

    def serialize(self, serializer):
        if isinstance(serializer, chainer.serializer.Serializer):
            serializer("_best_model", self.best_model)
            serializer("_min_loss", self.min_loss)
            serializer("_key", self.key)
            serializer("_prefix", self.prefix)
            serializer("_suffix", self.suffix)
        else:
            self.best_model = serializer("_best_model", -1)
            self.min_loss = serializer("_min_loss", 0.0)
            self.key = serializer("_key", "")
            self.prefix = serializer("_prefix", "model")
            self.suffix = serializer("_suffix", "best")
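
# Example (added for illustration, not part of the original ESPnet module).
# A minimal sketch of attaching the extension to an existing
# chainer.training.Trainer; `trainer` is assumed to be set up elsewhere and to
# snapshot models as "model.<epoch>" under trainer.out.
def _example_attach_symlink_extension(trainer):
    trainer.extend(
        MakeSymlinkToBestModel("validation/main/loss", prefix="model", suffix="best"),
        trigger=(1, "epoch"),
    )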

# TODO(Hori): currently it only works with character-word level LM.
# need to consider any types of subwords-to-word mapping.
def make_lexical_tree(word_dict, subword_dict, word_unk):
    """Make a lexical tree to compute word-level probabilities"""
    # node [dict(subword_id -> node), word_id, word_set[start-1, end]]
    root = [{}, -1, None]
    for w, wid in word_dict.items():
        if wid > 0 and wid != word_unk:  # skip <blank> and <unk>
            if True in [c not in subword_dict for c in w]:  # skip unknown subword
                continue
            succ = root[0]  # get successors from root node
            for i, c in enumerate(w):
                cid = subword_dict[c]
                if cid not in succ:  # if next node does not exist, make a new node
                    succ[cid] = [{}, -1, (wid - 1, wid)]
                else:
                    prev = succ[cid][2]
                    succ[cid][2] = (min(prev[0], wid - 1), max(prev[1], wid))
                if i == len(w) - 1:  # if word end, set word id
                    succ[cid][1] = wid
                succ = succ[cid][0]  # move to the child successors
    return root
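
# Example (added for illustration, not part of the original ESPnet module).
# A minimal sketch of building a lexical tree from a toy character-level
# subword dictionary; the vocabularies are assumptions. Each node is
# [children, word_id, (start-1, end)] over the word-ID range it covers.
def _example_make_lexical_tree():
    word_dict = {"<blank>": 0, "<unk>": 1, "ab": 2, "ac": 3}
    subword_dict = {"a": 0, "b": 1, "c": 2}
    root = make_lexical_tree(word_dict, subword_dict, word_unk=1)
    node_a = root[0][subword_dict["a"]]
    print(node_a[2])  # (start-1, end) of word IDs reachable through "a": (1, 3)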