Source code for espnet2.lm.huggingface_pretrained_opt_lm

import copy
import logging
from typing import Any, List, Tuple

import torch
import torch.nn as nn
from typeguard import typechecked

from espnet2.lm.abs_model import AbsLM
from espnet.nets.pytorch_backend.transformer.mask import subsequent_mask


class HuggingfaceOPTModel(AbsLM):
    """Language model built on a pretrained Hugging Face OPT decoder.

    The decoder is constructed from the configuration of the checkpoint named
    by ``opt_name``, with the vocabulary resized to ``vocab_size``.  The
    pretrained weights (minus the token embeddings) are stored in
    ``self.pretrained_params`` and can be copied into the decoder with
    ``reload_pretrained_parameters()``.
    """

    @typechecked
    def __init__(
        self,
        vocab_size: int,
        opt_name: str,
    ):
        super().__init__()
        try:
            from transformers import OPTModel
        except Exception as e:
            print("Error: transformers is not properly installed.")
            print("Please install transformers")
            raise e

        # opt_model_name_pattern = re.compile(r"facebook/opt-\d+m")
        # assert opt_model_name_pattern.match(opt_name) is not None

        # Keep the pretrained weights, dropping the token embeddings, which
        # are re-created below with the new vocabulary size.
        pretrained_opt_model = OPTModel.from_pretrained(opt_name)
        pretrained_opt_model_dict = pretrained_opt_model.state_dict()
        pretrained_opt_model_dict.pop("decoder.embed_tokens.weight")
        self.pretrained_params = copy.deepcopy(pretrained_opt_model_dict)

        # Reuse the pretrained configuration, but resize the vocabulary and
        # remap the special token ids to the ESPnet token list.
        config = pretrained_opt_model.config
        config.vocab_size = vocab_size
        config.bos_token_id = vocab_size - 1
        config.eos_token_id = vocab_size - 1
        config.pad_token_id = 0

        self.decoder = OPTModel(config)
        self.lm_head = nn.Linear(
            config.word_embed_proj_dim, config.vocab_size, bias=False
        )

    def _target_mask(self, ys_in_pad):
        # Combine the padding mask (pad id is 0) with a causal mask.
        ys_mask = ys_in_pad != 0
        m = subsequent_mask(ys_mask.size(-1), device=ys_mask.device).unsqueeze(0)
        return ys_mask.unsqueeze(-2) & m

    def forward(
        self, input: torch.Tensor, hidden: None
    ) -> Tuple[torch.Tensor, None]:
        """Compute logits from input token sequences.

        Args:
            input (torch.Tensor): Input ids. (batch, len)
            hidden: Unused; kept for compatibility with the LM interface.

        Returns:
            Tuple[torch.Tensor, None]: Logits with shape
                (batch, len, vocab_size) and ``None``.

        """
        pad_mask = input != 0
        y = self.decoder(
            input,
            attention_mask=pad_mask,
            return_dict=True,
        )
        y = y.last_hidden_state
        logits = self.lm_head(y)
        return logits, None

    def score(
        self, y: torch.Tensor, state: Any, x: torch.Tensor
    ) -> Tuple[torch.Tensor, Any]:
        """Score new token.

        Args:
            y (torch.Tensor): 1D torch.int64 prefix tokens.
            state: Scorer state for prefix tokens.
            x (torch.Tensor): Encoder feature that generates ys.

        Returns:
            tuple[torch.Tensor, Any]: Tuple of
                torch.float32 scores for next token (vocab_size)
                and next state for ys.

        """
        if state is None:
            _use_cache = True
        else:
            _use_cache = False
        y = y.unsqueeze(0)
        output = self.decoder(
            y,
            past_key_values=state,
            use_cache=_use_cache,
            output_hidden_states=True,
            return_dict=True,
        )
        h = output.last_hidden_state[:, -1]
        h = self.lm_head(h)
        cache = output.past_key_values
        logp = h.log_softmax(dim=-1).squeeze(0)
        return logp, cache

    def batch_score(
        self, ys: torch.Tensor, states: List[Any], xs: torch.Tensor
    ) -> Tuple[torch.Tensor, List[Any]]:
        """Score new token batch.

        Args:
            ys (torch.Tensor): torch.int64 prefix tokens (n_batch, ylen).
            states (List[Any]): Scorer states for prefix tokens.
            xs (torch.Tensor):
                The encoder feature that generates ys (n_batch, xlen, n_feat).

        Returns:
            tuple[torch.Tensor, List[Any]]: Tuple of
                batchfied scores for next token with shape of
                `(n_batch, vocab_size)` and next state list for ys.

        """
        # merge states
        n_batch = len(ys)
        n_layers = len(self.decoder.decoder.layers)
        if states[0] is None:
            _use_cache = True
        else:
            _use_cache = False

        # batch decoding
        output = self.decoder(
            ys,
            use_cache=_use_cache,
            output_hidden_states=True,
            return_dict=True,
        )
        h = output.last_hidden_state
        h = self.lm_head(h[:, -1])
        logp = h.log_softmax(dim=-1)

        state_list = [[[] for i in range(n_layers)] for b in range(n_batch)]
        return logp, state_list

    def reload_pretrained_parameters(self):
        self.decoder.load_state_dict(self.pretrained_params, strict=False)
        logging.info("Pretrained OPT model parameters reloaded!")
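
A minimal usage sketch follows. The checkpoint name "facebook/opt-125m", the
vocabulary size, the dummy token ids, and the external cross-entropy loss are
illustrative assumptions, not part of this module.

import torch

from espnet2.lm.huggingface_pretrained_opt_lm import HuggingfaceOPTModel

# Assumed checkpoint and token-list size (adjust to your setup).
lm = HuggingfaceOPTModel(vocab_size=5000, opt_name="facebook/opt-125m")
lm.reload_pretrained_parameters()  # copy the stored OPT weights into the decoder

# forward() returns raw logits; a training loss would be computed by the
# caller, e.g. cross entropy that ignores the pad id 0.
tokens = torch.randint(1, 5000, (2, 16))          # (batch, len), dummy ids
logits, _ = lm(tokens, None)                      # (batch, len, vocab_size)
targets = torch.roll(tokens, shifts=-1, dims=1)   # illustrative next-token targets
loss = torch.nn.functional.cross_entropy(
    logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=0
)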
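
During decoding, ESPnet's beam search calls score() with the full token prefix
of one hypothesis and batch_score() for several hypotheses at once. Continuing
from the sketch above, the prefix values and the placeholder encoder feature
are assumptions; the encoder tensor is not used by this LM but is required by
the scorer interface.

# Single-hypothesis scoring: log-probabilities over the next token.
prefix = torch.tensor([4999, 12, 57], dtype=torch.long)  # sos = vocab_size - 1
dummy_x = torch.zeros(1, 10)  # placeholder encoder feature (unused here)
logp, state = lm.score(prefix, None, dummy_x)             # logp: (vocab_size,)
next_token = int(logp.argmax())

# Batched scoring of several hypotheses; this implementation re-decodes the
# whole prefixes and returns empty per-layer state lists.
ys = torch.stack([prefix, prefix])                        # (n_batch, ylen)
batch_logp, states = lm.batch_score(ys, [None, None], dummy_x)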