Source code for espnet.lm.pytorch_backend.extlm

#!/usr/bin/env python3

# Copyright 2018 Mitsubishi Electric Research Laboratories (Takaaki Hori)
#  Apache 2.0  (http://www.apache.org/licenses/LICENSE-2.0)


import math

import torch
import torch.nn as nn
import torch.nn.functional as F

from espnet.lm.lm_utils import make_lexical_tree
from espnet.nets.pytorch_backend.nets_utils import to_device


# Definition of a multi-level (subword/word) language model
class MultiLevelLM(nn.Module):
    logzero = -10000000000.0
    zero = 1.0e-10

    def __init__(
        self,
        wordlm,
        subwordlm,
        word_dict,
        subword_dict,
        subwordlm_weight=0.8,
        oov_penalty=1.0,
        open_vocab=True,
    ):
        super(MultiLevelLM, self).__init__()
        self.wordlm = wordlm
        self.subwordlm = subwordlm
        self.word_eos = word_dict["<eos>"]
        self.word_unk = word_dict["<unk>"]
        self.var_word_eos = torch.LongTensor([self.word_eos])
        self.var_word_unk = torch.LongTensor([self.word_unk])
        self.space = subword_dict["<space>"]
        self.eos = subword_dict["<eos>"]
        self.lexroot = make_lexical_tree(word_dict, subword_dict, self.word_unk)
        self.log_oov_penalty = math.log(oov_penalty)
        self.open_vocab = open_vocab
        self.subword_dict_size = len(subword_dict)
        self.subwordlm_weight = subwordlm_weight
        self.normalized = True
    def forward(self, state, x):  # update state with input label x
        if state is None:  # make initial states and log-prob vectors
            self.var_word_eos = to_device(x, self.var_word_eos)
            self.var_word_unk = to_device(x, self.var_word_unk)
            wlm_state, z_wlm = self.wordlm(None, self.var_word_eos)
            wlm_logprobs = F.log_softmax(z_wlm, dim=1)
            clm_state, z_clm = self.subwordlm(None, x)
            log_y = F.log_softmax(z_clm, dim=1) * self.subwordlm_weight
            new_node = self.lexroot
            clm_logprob = 0.0
            xi = self.space
        else:
            clm_state, wlm_state, wlm_logprobs, node, log_y, clm_logprob = state
            xi = int(x)
            if xi == self.space:  # inter-word transition
                if node is not None and node[1] >= 0:  # check if the node is word end
                    w = to_device(x, torch.LongTensor([node[1]]))
                else:  # this node is not a word end, which means <unk>
                    w = self.var_word_unk
                # update wordlm state and log-prob vector
                wlm_state, z_wlm = self.wordlm(wlm_state, w)
                wlm_logprobs = F.log_softmax(z_wlm, dim=1)
                new_node = self.lexroot  # move to the tree root
                clm_logprob = 0.0
            elif node is not None and xi in node[0]:  # intra-word transition
                new_node = node[0][xi]
                clm_logprob += log_y[0, xi]
            elif self.open_vocab:  # if no path in the tree, enter open-vocabulary mode
                new_node = None
                clm_logprob += log_y[0, xi]
            else:  # if open_vocab flag is disabled, return 0 probabilities
                log_y = to_device(
                    x, torch.full((1, self.subword_dict_size), self.logzero)
                )
                return (clm_state, wlm_state, wlm_logprobs, None, log_y, 0.0), log_y

            clm_state, z_clm = self.subwordlm(clm_state, x)
            log_y = F.log_softmax(z_clm, dim=1) * self.subwordlm_weight

        # apply word-level probabilities for <space> and <eos> labels
        if xi != self.space:
            if new_node is not None and new_node[1] >= 0:  # if new node is word end
                wlm_logprob = wlm_logprobs[:, new_node[1]] - clm_logprob
            else:
                wlm_logprob = wlm_logprobs[:, self.word_unk] + self.log_oov_penalty
            log_y[:, self.space] = wlm_logprob
            log_y[:, self.eos] = wlm_logprob
        else:
            log_y[:, self.space] = self.logzero
            log_y[:, self.eos] = self.logzero

        return (
            (clm_state, wlm_state, wlm_logprobs, new_node, log_y, float(clm_logprob)),
            log_y,
        )
    def final(self, state):
        clm_state, wlm_state, wlm_logprobs, node, log_y, clm_logprob = state
        if node is not None and node[1] >= 0:  # check if the node is word end
            w = to_device(wlm_logprobs, torch.LongTensor([node[1]]))
        else:  # this node is not a word end, which means <unk>
            w = self.var_word_unk
        wlm_state, z_wlm = self.wordlm(wlm_state, w)
        return float(F.log_softmax(z_wlm, dim=1)[:, self.word_eos])
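
# A minimal usage sketch for MultiLevelLM (hypothetical driver code, not part of
# the module's API): the decoder feeds one subword label at a time, threads the
# opaque state tuple through forward(), and adds final() at end of hypothesis.
# `lm` is assumed to be a MultiLevelLM built from pretrained word/subword RNNLMs
# and the corresponding dictionaries; `sos_id` is typically subword_dict["<eos>"].

def score_with_multilevel_lm(lm, subword_ids, sos_id):
    """Hypothetical helper: sum the LM log-probabilities of a subword id sequence."""
    # forward(state, x) consumes label x and returns log-probabilities of the
    # next subword label, so the LM is primed with a start symbol first.
    state, log_y = lm(None, torch.LongTensor([sos_id]))
    total = 0.0
    for idx in subword_ids:
        total += float(log_y[0, idx])  # score of emitting this label
        state, log_y = lm(state, torch.LongTensor([idx]))
    return total + lm.final(state)  # add the word-level end-of-sentence score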
# Definition of a look-ahead word language model
class LookAheadWordLM(nn.Module):
    logzero = -10000000000.0
    zero = 1.0e-10

    def __init__(
        self, wordlm, word_dict, subword_dict, oov_penalty=0.0001, open_vocab=True
    ):
        super(LookAheadWordLM, self).__init__()
        self.wordlm = wordlm
        self.word_eos = word_dict["<eos>"]
        self.word_unk = word_dict["<unk>"]
        self.var_word_eos = torch.LongTensor([self.word_eos])
        self.var_word_unk = torch.LongTensor([self.word_unk])
        self.space = subword_dict["<space>"]
        self.eos = subword_dict["<eos>"]
        self.lexroot = make_lexical_tree(word_dict, subword_dict, self.word_unk)
        self.oov_penalty = oov_penalty
        self.open_vocab = open_vocab
        self.subword_dict_size = len(subword_dict)
        self.zero_tensor = torch.FloatTensor([self.zero])
        self.normalized = True
    def forward(self, state, x):  # update state with input label x
        if state is None:  # make initial states and cumulative probability vector
            self.var_word_eos = to_device(x, self.var_word_eos)
            self.var_word_unk = to_device(x, self.var_word_unk)
            self.zero_tensor = to_device(x, self.zero_tensor)
            wlm_state, z_wlm = self.wordlm(None, self.var_word_eos)
            cumsum_probs = torch.cumsum(F.softmax(z_wlm, dim=1), dim=1)
            new_node = self.lexroot
            xi = self.space
        else:
            wlm_state, cumsum_probs, node = state
            xi = int(x)
            if xi == self.space:  # inter-word transition
                if node is not None and node[1] >= 0:  # check if the node is word end
                    w = to_device(x, torch.LongTensor([node[1]]))
                else:  # this node is not a word end, which means <unk>
                    w = self.var_word_unk
                # update wordlm state and cumulative probability vector
                wlm_state, z_wlm = self.wordlm(wlm_state, w)
                cumsum_probs = torch.cumsum(F.softmax(z_wlm, dim=1), dim=1)
                new_node = self.lexroot  # move to the tree root
            elif node is not None and xi in node[0]:  # intra-word transition
                new_node = node[0][xi]
            elif self.open_vocab:  # if no path in the tree, enter open-vocabulary mode
                new_node = None
            else:  # if open_vocab flag is disabled, return 0 probabilities
                log_y = to_device(
                    x, torch.full((1, self.subword_dict_size), self.logzero)
                )
                return (wlm_state, None, None), log_y

        if new_node is not None:
            succ, wid, wids = new_node
            # compute parent node probability
            sum_prob = (
                (cumsum_probs[:, wids[1]] - cumsum_probs[:, wids[0]])
                if wids is not None
                else 1.0
            )
            if sum_prob < self.zero:
                log_y = to_device(
                    x, torch.full((1, self.subword_dict_size), self.logzero)
                )
                return (wlm_state, cumsum_probs, new_node), log_y
            # set <unk> probability as a default value
            unk_prob = (
                cumsum_probs[:, self.word_unk] - cumsum_probs[:, self.word_unk - 1]
            )
            y = to_device(
                x,
                torch.full(
                    (1, self.subword_dict_size), float(unk_prob) * self.oov_penalty
                ),
            )
            # compute transition probabilities to child nodes
            for cid, nd in succ.items():
                y[:, cid] = (
                    cumsum_probs[:, nd[2][1]] - cumsum_probs[:, nd[2][0]]
                ) / sum_prob
            # apply word-level probabilities for <space> and <eos> labels
            if wid >= 0:
                wlm_prob = (cumsum_probs[:, wid] - cumsum_probs[:, wid - 1]) / sum_prob
                y[:, self.space] = wlm_prob
                y[:, self.eos] = wlm_prob
            elif xi == self.space:
                y[:, self.space] = self.zero
                y[:, self.eos] = self.zero
            log_y = torch.log(torch.max(y, self.zero_tensor))  # clip to avoid log(0)
        else:  # if no path in the tree, transition probability is one
            log_y = to_device(x, torch.zeros(1, self.subword_dict_size))
        return (wlm_state, cumsum_probs, new_node), log_y
    def final(self, state):
        wlm_state, cumsum_probs, node = state
        if node is not None and node[1] >= 0:  # check if the node is word end
            w = to_device(cumsum_probs, torch.LongTensor([node[1]]))
        else:  # this node is not a word end, which means <unk>
            w = self.var_word_unk
        wlm_state, z_wlm = self.wordlm(wlm_state, w)
        return float(F.log_softmax(z_wlm, dim=1)[:, self.word_eos])
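
# LookAheadWordLM exposes the same (state, x) -> (state, log_y) interface as
# MultiLevelLM, so the scoring loop sketched above can be reused unchanged.
# The sketch below (hypothetical helper, not part of this module) shows the
# typical shallow-fusion use of log_y during beam search: the look-ahead LM
# scores are weighted and added to the decoder's subword log-probabilities.

def fuse_with_lookahead_lm(decoder_logp, lm, lm_state, prev_subword_id, lm_weight=0.5):
    """Hypothetical shallow-fusion step combining decoder and look-ahead LM scores.

    `decoder_logp` is a (1, len(subword_dict)) tensor of decoder log-probabilities,
    `lm` is a LookAheadWordLM instance, and `prev_subword_id` is the label just
    emitted by the hypothesis being extended.
    """
    lm_state, lm_logp = lm(lm_state, torch.LongTensor([prev_subword_id]))
    return lm_state, decoder_logp + lm_weight * lm_logp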