import argparse
import copy
import logging
from typing import Callable, Collection, Dict, List, Optional, Tuple

import numpy as np
import torch
from typeguard import typechecked

from espnet2.asr.ctc import CTC
from espnet2.asr.espnet_model import ESPnetASRModel
from espnet2.diar.espnet_model import ESPnetDiarizationModel
from espnet2.enh.espnet_enh_s2t_model import ESPnetEnhS2TModel
from espnet2.enh.espnet_model import ESPnetEnhancementModel
from espnet2.st.espnet_model import ESPnetSTModel
from espnet2.tasks.abs_task import AbsTask
from espnet2.tasks.asr import ASRTask
from espnet2.tasks.asr import decoder_choices as asr_decoder_choices_
from espnet2.tasks.asr import encoder_choices as asr_encoder_choices_
from espnet2.tasks.asr import frontend_choices, normalize_choices
from espnet2.tasks.asr import postencoder_choices as asr_postencoder_choices_
from espnet2.tasks.asr import preencoder_choices as asr_preencoder_choices_
from espnet2.tasks.asr import specaug_choices
from espnet2.tasks.diar import DiarizationTask
from espnet2.tasks.diar import attractor_choices as diar_attractor_choices_
from espnet2.tasks.diar import decoder_choices as diar_decoder_choices_
from espnet2.tasks.diar import encoder_choices as diar_encoder_choices_
from espnet2.tasks.diar import frontend_choices as diar_front_end_choices_
from espnet2.tasks.diar import label_aggregator_choices
from espnet2.tasks.diar import normalize_choices as diar_normalize_choices_
from espnet2.tasks.diar import specaug_choices as diar_specaug_choices_
from espnet2.tasks.enh import EnhancementTask
from espnet2.tasks.enh import decoder_choices as enh_decoder_choices_
from espnet2.tasks.enh import encoder_choices as enh_encoder_choices_
from espnet2.tasks.enh import mask_module_choices as enh_mask_module_choices_
from espnet2.tasks.enh import separator_choices as enh_separator_choices_
from espnet2.tasks.st import STTask
from espnet2.tasks.st import decoder_choices as st_decoder_choices_
from espnet2.tasks.st import encoder_choices as st_encoder_choices_
from espnet2.tasks.st import extra_asr_decoder_choices as st_extra_asr_decoder_choices_
from espnet2.tasks.st import extra_mt_decoder_choices as st_extra_mt_decoder_choices_
from espnet2.tasks.st import postencoder_choices as st_postencoder_choices_
from espnet2.tasks.st import preencoder_choices as st_preencoder_choices_
from espnet2.text.phoneme_tokenizer import g2p_choices
from espnet2.torch_utils.initialize import initialize
from espnet2.train.collate_fn import CommonCollateFn
from espnet2.train.preprocessor import (
    CommonPreprocessor,
    CommonPreprocessor_multi,
    MutliTokenizerCommonPreprocessor,
)
from espnet2.train.trainer import Trainer
from espnet2.utils.get_default_kwargs import get_default_kwargs
from espnet2.utils.nested_dict_action import NestedDictAction
from espnet2.utils.types import int_or_none, str2bool, str_or_none
# Give each shared ClassChoices object a subtask-specific copy so that every
# submodel exposes its own prefixed CLI options (e.g. --enh_encoder vs
# --asr_encoder) without mutating the originals imported from other tasks.


def _prefixed(choices, name):
    """Deep-copy *choices* and rename it so its CLI flags use *name* as prefix."""
    renamed = copy.deepcopy(choices)
    renamed.name = name
    return renamed


# Enhancement
enh_encoder_choices = _prefixed(enh_encoder_choices_, "enh_encoder")
enh_decoder_choices = _prefixed(enh_decoder_choices_, "enh_decoder")
enh_separator_choices = _prefixed(enh_separator_choices_, "enh_separator")
enh_mask_module_choices = _prefixed(enh_mask_module_choices_, "enh_mask_module")

# ASR (also SLU)
asr_preencoder_choices = _prefixed(asr_preencoder_choices_, "asr_preencoder")
asr_encoder_choices = _prefixed(asr_encoder_choices_, "asr_encoder")
asr_postencoder_choices = _prefixed(asr_postencoder_choices_, "asr_postencoder")
asr_decoder_choices = _prefixed(asr_decoder_choices_, "asr_decoder")

# ST
st_preencoder_choices = _prefixed(st_preencoder_choices_, "st_preencoder")
st_encoder_choices = _prefixed(st_encoder_choices_, "st_encoder")
st_postencoder_choices = _prefixed(st_postencoder_choices_, "st_postencoder")
st_decoder_choices = _prefixed(st_decoder_choices_, "st_decoder")
st_extra_asr_decoder_choices = _prefixed(
    st_extra_asr_decoder_choices_, "st_extra_asr_decoder"
)
st_extra_mt_decoder_choices = _prefixed(
    st_extra_mt_decoder_choices_, "st_extra_mt_decoder"
)

# DIAR
diar_frontend_choices = _prefixed(diar_front_end_choices_, "diar_frontend")
diar_specaug_choices = _prefixed(diar_specaug_choices_, "diar_specaug")
diar_normalize_choices = _prefixed(diar_normalize_choices_, "diar_normalize")
diar_encoder_choices = _prefixed(diar_encoder_choices_, "diar_encoder")
diar_decoder_choices = _prefixed(diar_decoder_choices_, "diar_decoder")
diar_attractor_choices = _prefixed(diar_attractor_choices_, "diar_attractor")
# Upper bound for numbered data keys (speech_ref*, noise_ref*, text_spk*).
MAX_REFERENCE_NUM = 100

# Map each subtask name to the Task class whose build_model() constructs it.
name2task = {
    "enh": EnhancementTask,
    "asr": ASRTask,
    "st": STTask,
    "diar": DiarizationTask,
}


def _with_conf(*names):
    """Expand each name into the pair (name, name + "_conf")."""
    expanded = []
    for name in names:
        expanded += [name, name + "_conf"]
    return expanded


# Per-subtask argparse attributes copied onto each submodel's namespace in
# build_model(). More can be added to the following lists.
enh_attributes = _with_conf(
    "encoder", "separator", "mask_module", "decoder"
) + ["criterions"]
asr_attributes = (
    ["token_list", "input_size"]
    + _with_conf(
        "frontend",
        "specaug",
        "normalize",
        "preencoder",
        "encoder",
        "postencoder",
        "decoder",
    )
    + ["ctc_conf"]
)
st_attributes = (
    ["token_list", "src_token_list", "input_size"]
    + _with_conf(
        "frontend",
        "specaug",
        "normalize",
        "preencoder",
        "encoder",
        "postencoder",
        "decoder",
    )
    + ["ctc_conf"]
    + _with_conf("extra_asr_decoder", "extra_mt_decoder")
)
diar_attributes = (
    ["input_size", "num_spk"]
    + _with_conf(
        "frontend",
        "specaug",
        "normalize",
        "encoder",
        "decoder",
        "attractor",
        "label_aggregator",
    )
)
class EnhS2TTask(AbsTask):
    """Task definition for joint speech enhancement and S2T (ASR/ST/diarization)."""

    # If you need more than one optimizer, change this value
    num_optimizers: int = 1

    # Variable-object configurations: each entry contributes a --<name> and
    # --<name>_conf option pair via add_task_arguments().
    class_choices_list = [
        # Enhancement submodel
        enh_encoder_choices,
        enh_separator_choices,
        enh_decoder_choices,
        enh_mask_module_choices,
        # Shared S2T front-end (--frontend, --specaug, --normalize)
        frontend_choices,
        specaug_choices,
        normalize_choices,
        # ASR submodel
        asr_preencoder_choices,
        asr_encoder_choices,
        asr_postencoder_choices,
        asr_decoder_choices,
        # ST submodel
        st_preencoder_choices,
        st_encoder_choices,
        st_postencoder_choices,
        st_decoder_choices,
        st_extra_asr_decoder_choices,
        st_extra_mt_decoder_choices,
        # Diarization submodel
        diar_frontend_choices,
        diar_specaug_choices,
        diar_normalize_choices,
        diar_encoder_choices,
        diar_decoder_choices,
        label_aggregator_choices,
        diar_attractor_choices,
    ]

    # To customize the train()/eval() procedures, swap in another Trainer here
    trainer = Trainer
[docs] @classmethod
def add_task_arguments(cls, parser: argparse.ArgumentParser):
group = parser.add_argument_group(description="Task related")
group.add_argument(
"--token_list",
type=str_or_none,
default=None,
help="A text mapping int-id to token",
)
group.add_argument(
"--src_token_list",
type=str_or_none,
default=None,
help="A text mapping int-id to token (for source language)",
)
group.add_argument(
"--init",
type=lambda x: str_or_none(x.lower()),
default=None,
help="The initialization method",
choices=[
"chainer",
"xavier_uniform",
"xavier_normal",
"kaiming_uniform",
"kaiming_normal",
None,
],
)
group.add_argument(
"--input_size",
type=int_or_none,
default=None,
help="The number of input dimension of the feature",
)
group.add_argument(
"--ctc_conf",
action=NestedDictAction,
default=get_default_kwargs(CTC),
help="The keyword arguments for CTC class.",
)
group.add_argument(
"--enh_criterions",
action=NestedDictAction,
default=[
{
"name": "si_snr",
"conf": {},
"wrapper": "fixed_order",
"wrapper_conf": {},
},
],
help="The criterions binded with the loss wrappers.",
)
group.add_argument(
"--diar_num_spk",
type=int_or_none,
default=None,
help="The number of speakers (for each recording) for diar submodel class",
)
group.add_argument(
"--diar_input_size",
type=int_or_none,
default=None,
help="The number of input dimension of the feature",
)
group.add_argument(
"--enh_model_conf",
action=NestedDictAction,
default=get_default_kwargs(ESPnetEnhancementModel),
help="The keyword arguments for enh submodel class.",
)
group.add_argument(
"--asr_model_conf",
action=NestedDictAction,
default=get_default_kwargs(ESPnetASRModel),
help="The keyword arguments for asr submodel class.",
)
group.add_argument(
"--st_model_conf",
action=NestedDictAction,
default=get_default_kwargs(ESPnetEnhancementModel),
help="The keyword arguments for st submodel class.",
)
group.add_argument(
"--diar_model_conf",
action=NestedDictAction,
default=get_default_kwargs(ESPnetDiarizationModel),
help="The keyword arguments for diar submodel class.",
)
group.add_argument(
"--subtask_series",
type=str,
nargs="+",
default=("enh", "asr"),
choices=["enh", "asr", "st", "diar"],
help="The series of subtasks in the pipeline.",
)
group.add_argument(
"--model_conf",
action=NestedDictAction,
default=get_default_kwargs(ESPnetEnhS2TModel),
help="The keyword arguments for model class.",
)
group = parser.add_argument_group(description="Preprocess related")
group.add_argument(
"--use_preprocessor",
type=str2bool,
default=False,
help="Apply preprocessing to data or not",
)
group.add_argument(
"--token_type",
type=str,
default="bpe",
choices=["bpe", "char", "word", "phn"],
help="The text will be tokenized " "in the specified level token",
)
group.add_argument(
"--bpemodel",
type=str_or_none,
default=None,
help="The model file of sentencepiece",
)
group.add_argument(
"--src_token_type",
type=str,
default="bpe",
choices=["bpe", "char", "word", "phn"],
help="The source text will be tokenized " "in the specified level token",
)
group.add_argument(
"--src_bpemodel",
type=str_or_none,
default=None,
help="The model file of sentencepiece (for source language)",
)
group.add_argument(
"--non_linguistic_symbols",
type=str_or_none,
help="non_linguistic_symbols file path",
)
group.add_argument(
"--cleaner",
type=str_or_none,
choices=[None, "tacotron", "jaconv", "vietnamese"],
default=None,
help="Apply text cleaning",
)
group.add_argument(
"--g2p",
type=str_or_none,
choices=g2p_choices,
default=None,
help="Specify g2p method if --token_type=phn",
)
group.add_argument(
"--text_name",
nargs="+",
default=["text"],
type=str,
help="Specify the text_name attribute used in the preprocessor",
)
for class_choices in cls.class_choices_list:
# Append --<name> and --<name>_conf.
# e.g. --encoder and --encoder_conf
class_choices.add_arguments(group)
[docs] @classmethod
@typechecked
def build_collate_fn(cls, args: argparse.Namespace, train: bool) -> Callable[
[Collection[Tuple[str, Dict[str, np.ndarray]]]],
Tuple[List[str], Dict[str, torch.Tensor]],
]:
# NOTE(kamo): int value = 0 is reserved by CTC-blank symbol
return CommonCollateFn(float_pad_value=0.0, int_pad_value=-1)
[docs] @classmethod
@typechecked
def build_preprocess_fn(
cls, args: argparse.Namespace, train: bool
) -> Optional[Callable[[str, Dict[str, np.array]], Dict[str, np.ndarray]]]:
if args.use_preprocessor:
if "st" in args.subtask_series:
retval = MutliTokenizerCommonPreprocessor(
train=train,
token_type=[args.token_type, args.src_token_type],
token_list=[args.token_list, args.src_token_list],
bpemodel=[args.bpemodel, args.src_bpemodel],
non_linguistic_symbols=args.non_linguistic_symbols,
text_cleaner=args.cleaner,
g2p_type=args.g2p,
# NOTE(kamo): Check attribute existence for backward compatibility
rir_scp=getattr(args, "rir_scp", None),
rir_apply_prob=getattr(args, "rir_apply_prob", 1.0),
noise_scp=getattr(args, "noise_scp", None),
noise_apply_prob=getattr(args, "noise_apply_prob", 1.0),
noise_db_range=getattr(args, "noise_db_range", "13_15"),
short_noise_thres=getattr(args, "short_noise_thres", 0.5),
speech_volume_normalize=getattr(
args, "speech_volume_normalize", None
),
speech_name="speech",
text_name=["text", "src_text"],
**getattr(args, "preprocessor_conf", {}),
)
elif "diar" in args.subtask_series:
retval = CommonPreprocessor(
train=train, **getattr(args, "preprocessor_conf", {})
)
else:
retval = CommonPreprocessor_multi(
train=train,
token_type=args.token_type,
token_list=args.token_list,
bpemodel=args.bpemodel,
non_linguistic_symbols=args.non_linguistic_symbols,
text_name=getattr(args, "text_name", ["text"]),
text_cleaner=args.cleaner,
g2p_type=args.g2p,
**getattr(args, "preprocessor_conf", {}),
)
else:
retval = None
return retval
[docs] @classmethod
def required_data_names(
cls, train: bool = True, inference: bool = False
) -> Tuple[str, ...]:
if not inference:
retval = ("speech", "speech_ref1")
else:
# Recognition mode
retval = ("speech",)
return retval
[docs] @classmethod
def optional_data_names(
cls, train: bool = True, inference: bool = False
) -> Tuple[str, ...]:
retval = ["text", "dereverb_ref1"]
st = 2 if "speech_ref1" in retval else 1
retval += ["speech_ref{}".format(n) for n in range(st, MAX_REFERENCE_NUM + 1)]
retval += ["noise_ref{}".format(n) for n in range(1, MAX_REFERENCE_NUM + 1)]
retval += ["text_spk{}".format(n) for n in range(1, MAX_REFERENCE_NUM + 1)]
retval += ["src_text"]
retval = tuple(retval)
return retval
[docs] @classmethod
@typechecked
def build_model(cls, args: argparse.Namespace) -> ESPnetEnhS2TModel:
# Build submodels in the order of subtask_series
model_conf = args.model_conf.copy()
for _, subtask in enumerate(args.subtask_series):
subtask_conf = dict(
init=None, model_conf=eval(f"args.{subtask}_model_conf")
)
for attr in eval(f"{subtask}_attributes"):
subtask_conf[attr] = (
getattr(args, subtask + "_" + attr, None)
if getattr(args, subtask + "_" + attr, None) is not None
else getattr(args, attr, None)
)
if subtask in ["asr", "st", "diar"]:
m_subtask = "s2t"
elif subtask in ["enh"]:
m_subtask = subtask
else:
raise ValueError(f"{subtask} not supported.")
logging.info(f"Building {subtask} task model, using config: {subtask_conf}")
model_conf[f"{m_subtask}_model"] = name2task[subtask].build_model(
argparse.Namespace(**subtask_conf)
)
# 8. Build model
model = ESPnetEnhS2TModel(**model_conf)
# FIXME(kamo): Should be done in model?
# 9. Initialize
if args.init is not None:
initialize(model, args.init)
return model