import argparse
import logging
from typing import Callable, Collection, Dict, List, Optional, Tuple, Union

import numpy as np
import torch
from typeguard import typechecked

from espnet2.lm.abs_model import AbsLM
from espnet2.lm.espnet_model import ESPnetLanguageModel
from espnet2.lm.espnet_model_multitask import ESPnetMultitaskLanguageModel
from espnet2.lm.huggingface_pretrained_opt_lm import HuggingfaceOPTModel
from espnet2.lm.seq_rnn_lm import SequentialRNNLM
from espnet2.lm.transformer_lm import TransformerLM
from espnet2.tasks.abs_task import AbsTask
from espnet2.text.phoneme_tokenizer import g2p_choices
from espnet2.torch_utils.initialize import initialize
from espnet2.train.abs_espnet_model import AbsESPnetModel
from espnet2.train.class_choices import ClassChoices
from espnet2.train.collate_fn import CommonCollateFn
from espnet2.train.preprocessor import CommonPreprocessor
from espnet2.train.trainer import Trainer
from espnet2.utils.types import str2bool, str_or_none

lm_choices = ClassChoices(
    "lm",
    classes=dict(
        seq_rnn=SequentialRNNLM,
        transformer=TransformerLM,
        transformer_opt=HuggingfaceOPTModel,
    ),
    type_check=AbsLM,
    default="seq_rnn",
)
model_choices = ClassChoices(
    "model",
    classes=dict(
        lm=ESPnetLanguageModel,
        lm_multitask=ESPnetMultitaskLanguageModel,
    ),
    type_check=AbsESPnetModel,
    default="lm",
)
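
# NOTE: each ClassChoices instance above is exposed on the command line as a
# pair of options named after it: "--lm {seq_rnn,transformer,transformer_opt}"
# selects the LM class and "--lm_conf" passes its keyword arguments;
# "--model"/"--model_conf" behave the same way. add_task_arguments() below
# appends these options to the parser.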

class LMTask(AbsTask):
    # If you need more than one optimizer, change this value
    num_optimizers: int = 1

    # Add variable objects configurations
    class_choices_list = [
        # --lm and --lm_conf
        lm_choices,
        # --model and --model_conf
        model_choices,
    ]

    # If you need to modify train() or eval() procedures, change Trainer class here
    trainer = Trainer

    @classmethod
    @typechecked
    def add_task_arguments(cls, parser: argparse.ArgumentParser):
        # NOTE(kamo): Use '_' instead of '-' to avoid confusion
        group = parser.add_argument_group(description="Task related")

        # NOTE(kamo): add_arguments(..., required=True) can't be used
        # to provide --print_config mode. Instead, do it like this:
        required = parser.get_default("required")
        required += ["token_list"]

        group.add_argument(
            "--token_list",
            type=str_or_none,
            default=None,
            help="A text mapping int-id to token",
        )
        group.add_argument(
            "--init",
            type=lambda x: str_or_none(x.lower()),
            default=None,
            help="The initialization method",
            choices=[
                "chainer",
                "xavier_uniform",
                "xavier_normal",
                "kaiming_uniform",
                "kaiming_normal",
                None,
            ],
        )

        group = parser.add_argument_group(description="Preprocess related")
        group.add_argument(
            "--use_preprocessor",
            type=str2bool,
            default=True,
            help="Apply preprocessing to data or not",
        )
        group.add_argument(
            "--token_type",
            type=str,
            default="bpe",
            choices=["bpe", "char", "word"],
            help="The type of token unit used to tokenize the text",
        )
        group.add_argument(
            "--bpemodel",
            type=str_or_none,
            default=None,
            help="The model file of sentencepiece",
        )
        group.add_argument(
            "--non_linguistic_symbols",
            type=str_or_none,
            default=None,
            help="non_linguistic_symbols file path",
        )
        group.add_argument(
            "--cleaner",
            type=str_or_none,
            choices=[None, "tacotron", "jaconv", "vietnamese"],
            default=None,
            help="Apply text cleaning",
        )
        group.add_argument(
            "--g2p",
            type=str_or_none,
            choices=g2p_choices,
            default=None,
            help="Specify g2p method if --token_type=phn",
        )

        for class_choices in cls.class_choices_list:
            # Append --<name> and --<name>_conf.
            # e.g. --lm and --lm_conf
            class_choices.add_arguments(group)

        return parser

    @classmethod
    @typechecked
    def build_collate_fn(cls, args: argparse.Namespace, train: bool) -> Callable[
        [Collection[Tuple[str, Dict[str, np.ndarray]]]],
        Tuple[List[str], Dict[str, torch.Tensor]],
    ]:
        return CommonCollateFn(int_pad_value=0)
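
    # NOTE: CommonCollateFn zero-pads the variable-length "text" token
    # sequences in a mini-batch and adds a matching "text_lengths" tensor,
    # so one batch looks roughly like (shapes illustrative):
    #   (["utt1", "utt2"],
    #    {"text": LongTensor(B, Lmax), "text_lengths": LongTensor(B)})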

    @classmethod
    @typechecked
    def build_preprocess_fn(
        cls, args: argparse.Namespace, train: bool
    ) -> Optional[Callable[[str, Dict[str, np.ndarray]], Dict[str, np.ndarray]]]:
        if args.use_preprocessor:
            retval = CommonPreprocessor(
                train=train,
                token_type=args.token_type,
                token_list=args.token_list,
                bpemodel=args.bpemodel,
                text_cleaner=args.cleaner,
                g2p_type=args.g2p,
                non_linguistic_symbols=args.non_linguistic_symbols,
            )
        else:
            retval = None
        return retval
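
    # NOTE: CommonPreprocessor applies, in order, text cleaning (--cleaner),
    # tokenization (--token_type, using --bpemodel or --g2p where relevant),
    # and token-to-id conversion via --token_list. With --use_preprocessor
    # false, the data are expected to be integer token ids already.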

    @classmethod
    def required_data_names(
        cls, train: bool = True, inference: bool = False
    ) -> Tuple[str, ...]:
        retval = ("text",)
        return retval

    @classmethod
    def optional_data_names(
        cls, train: bool = True, inference: bool = False
    ) -> Tuple[str, ...]:
        retval = ()
        return retval
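
    # NOTE: LM training consumes only the token sequence itself, hence the
    # single required "text" entry and the empty optional tuple.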

    @classmethod
    @typechecked
    def build_model(
        cls, args: argparse.Namespace
    ) -> Union[ESPnetLanguageModel, ESPnetMultitaskLanguageModel]:
        if isinstance(args.token_list, str):
            with open(args.token_list, encoding="utf-8") as f:
                token_list = [line.rstrip() for line in f]

            # "args" is saved as-is in a yaml file by BaseTask.main().
            # Overwrite token_list here to keep the saved config portable.
            args.token_list = token_list.copy()
        elif isinstance(args.token_list, (tuple, list)):
            # list(...) instead of .copy() so that tuples are accepted too
            token_list = list(args.token_list)
        else:
            raise RuntimeError("token_list must be str or list")

        vocab_size = len(token_list)
        logging.info(f"Vocabulary size: {vocab_size}")

        # 1. Build LM model
        lm_class = lm_choices.get_class(args.lm)
        lm = lm_class(vocab_size=vocab_size, **args.lm_conf)

        # 2. Build ESPnetModel
        # Assume the last-id is sos_and_eos
        try:
            model_class = model_choices.get_class(args.model)
            if args.model == "lm_multitask":
                extra_model_conf = dict(token_list=token_list)
            else:
                extra_model_conf = dict()
        except AttributeError:
            model_class = model_choices.get_class("lm")
            extra_model_conf = dict()
        model = model_class(
            lm=lm, vocab_size=vocab_size, **args.model_conf, **extra_model_conf
        )

        # FIXME(kamo): Should be done in model?
        # 3. Initialize
        if args.init is not None:
            initialize(model, args.init)

        if args.lm == "transformer_opt":
            # Load the pretrained OPT parameters
            model.lm.reload_pretrained_parameters()

        return model
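
# A minimal sketch of how this task is typically driven; this mirrors
# espnet2/bin/lm_train.py, and the option values shown are hypothetical:
#
#     from espnet2.tasks.lm import LMTask
#
#     # Builds the parser defined in add_task_arguments() and runs the
#     # whole training loop (data loading, collate, model build, Trainer).
#     LMTask.main(cmd=["--token_list", "tokens.txt", "--lm", "transformer"])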