"""Finetuning methods."""
import logging
import os
import re
from collections import OrderedDict
import torch
from espnet.asr.asr_utils import get_model_conf, torch_load
from espnet.nets.asr_interface import ASRInterface
from espnet.nets.mt_interface import MTInterface
from espnet.nets.pytorch_backend.transducer.utils import custom_torch_load
from espnet.nets.tts_interface import TTSInterface
from espnet.utils.dynamic_import import dynamic_import


def freeze_modules(model, modules):
    """Freeze model parameters according to modules list.

    Args:
        model (torch.nn.Module): Main model.
        modules (List): Specified module(s) to freeze.

    Return:
        model (torch.nn.Module): Updated main model.
        model_params (filter): Filtered model parameters.

    """
    for mod, param in model.named_parameters():
        if any(mod.startswith(m) for m in modules):
            logging.warning(f"Freezing {mod}. It will not be updated during training.")
            param.requires_grad = False

    model_params = filter(lambda x: x.requires_grad, model.parameters())

    return model, model_params
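
# Minimal usage sketch (the "enc." prefix and the optimizer settings are
# illustrative, not prescribed by this module): freeze every encoder
# parameter and build the optimizer over the remaining trainable ones.
#
#     model, model_params = freeze_modules(model, ["enc."])
#     optimizer = torch.optim.Adam(model_params, lr=1e-3)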


def transfer_verification(model_state_dict, partial_state_dict, modules):
    """Verify tuples (key, shape) for input model modules match specified modules.

    Args:
        model_state_dict (Dict): Main model state dict.
        partial_state_dict (Dict): Pre-trained model state dict.
        modules (List): Specified module(s) to transfer.

    Return:
        (bool): Whether transfer learning is allowed.

    """
    model_modules = []
    partial_modules = []

    for key_m, value_m in model_state_dict.items():
        if any(key_m.startswith(m) for m in modules):
            model_modules += [(key_m, value_m.shape)]
    model_modules = sorted(model_modules, key=lambda x: (x[0], x[1]))

    for key_p, value_p in partial_state_dict.items():
        if any(key_p.startswith(m) for m in modules):
            partial_modules += [(key_p, value_p.shape)]
    partial_modules = sorted(partial_modules, key=lambda x: (x[0], x[1]))

    module_match = model_modules == partial_modules

    if not module_match:
        logging.error(
            "Some specified modules from the pre-trained model "
            "don't match with the new model modules:"
        )
        logging.error(f"Pre-trained: {set(partial_modules) - set(model_modules)}")
        logging.error(f"New model: {set(model_modules) - set(partial_modules)}")
        exit(1)

    return module_match
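
# Illustration only (tensors and key names are hypothetical): transfer is
# allowed when every selected (key, shape) pair is identical on both sides.
#
#     main = {"dec.embed.weight": torch.zeros(500, 320)}
#     pre = {"dec.embed.weight": torch.zeros(500, 320)}
#     transfer_verification(main, pre, ["dec."])  # -> True
#
# A shape mismatch (e.g. a different vocabulary size for "dec.embed.weight")
# logs both difference sets and terminates via exit(1).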


def get_partial_state_dict(model_state_dict, modules):
    """Create state dict with specified modules matching input model modules.

    Args:
        model_state_dict (Dict): Pre-trained model state dict.
        modules (List): Specified module(s) to transfer.

    Return:
        new_state_dict (Dict): State dict with specified modules weights.

    """
    new_state_dict = OrderedDict()

    for key, value in model_state_dict.items():
        if any(key.startswith(m) for m in modules):
            new_state_dict[key] = value

    return new_state_dict
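
# Illustration only (key names are hypothetical): keep only the entries whose
# keys start with one of the specified prefixes.
#
#     full = {"enc.enc.0.weight": ..., "dec.embed.weight": ...}
#     get_partial_state_dict(full, ["dec."])
#     # -> OrderedDict([("dec.embed.weight", ...)])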


def get_lm_state_dict(lm_state_dict):
    """Create compatible ASR decoder state dict from LM state dict.

    Args:
        lm_state_dict (Dict): Pre-trained LM state dict.

    Return:
        new_state_dict (Dict): State dict with compatible key names.

    """
    new_state_dict = OrderedDict()

    for key, value in list(lm_state_dict.items()):
        if key == "predictor.embed.weight":
            new_state_dict["dec.embed.weight"] = value
        elif key.startswith("predictor.rnn."):
            _split = key.split(".")

            new_key = "dec.decoder." + _split[2] + "." + _split[3] + "_l0"
            new_state_dict[new_key] = value

    return new_state_dict
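
# Example of the key mapping above, assuming the LM stores its recurrent
# cells as "predictor.rnn.<layer>.<param>" (layout assumed, not checked here):
#
#     "predictor.embed.weight"    -> "dec.embed.weight"
#     "predictor.rnn.0.weight_ih" -> "dec.decoder.0.weight_ih_l0"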


def filter_modules(model_state_dict, modules):
    """Filter specified module(s), keeping only those matching the model state dict.

    Args:
        model_state_dict (Dict): Pre-trained model state dict.
        modules (List): Specified module(s) to transfer.

    Return:
        new_mods (List): Filtered module list.

    """
    new_mods = []
    incorrect_mods = []

    mods_model = list(model_state_dict.keys())
    for mod in modules:
        if any(key.startswith(mod) for key in mods_model):
            new_mods += [mod]
        else:
            incorrect_mods += [mod]

    if incorrect_mods:
        logging.error(
            "Specified module(s) don't match (or only partially match) "
            f"the available modules in the model. You specified: {incorrect_mods}."
        )
        logging.error("The existing modules in the model are:")
        logging.error(f"{mods_model}")

        exit(1)

    return new_mods
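
# Illustration only (key names are hypothetical):
#
#     filter_modules({"dec.embed.weight": ...}, ["dec."])  # -> ["dec."]
#
# A prefix with no matching key (e.g. "foo.") is reported together with the
# available modules and the call terminates via exit(1).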


def create_transducer_compatible_state_dict(
    model_state_dict, encoder_type, encoder_units
):
    """Create a compatible transducer model state dict for transfer learning.

    If RNN encoder modules from a non-Transducer model are found in
    the pre-trained model state dict, the corresponding module keys are
    renamed for compatibility.

    Args:
        model_state_dict (Dict): Pre-trained model state dict.
        encoder_type (str): Type of pre-trained encoder.
        encoder_units (int): Number of encoder units in pre-trained model.

    Returns:
        new_state_dict (Dict): Transducer compatible pre-trained model state dict.

    """
    if encoder_type.endswith("p") or not encoder_type.endswith(("lstm", "gru")):
        return model_state_dict

    new_state_dict = OrderedDict()
    rnn_key_name = "birnn" if "b" in encoder_type else "rnn"

    for key, value in list(model_state_dict.items()):
        if any(k in key for k in ["l_last", "nbrnn"]):
            if "nbrnn" in key:
                layer_name = rnn_key_name + re.search("_l([0-9]+)", key).group(1)

                key = re.sub(
                    "_l([0-9]+)",
                    "_l0",
                    key.replace("nbrnn", layer_name),
                )

            if (encoder_units * 2) == value.size(-1):
                value = value[:, :encoder_units] + value[:, encoder_units:]

        new_state_dict[key] = value

    return new_state_dict
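
# Illustrative renaming, assuming the usual non-Transducer RNN encoder key
# layout ("nbrnn" modules plus an "l_last" projection) and a bidirectional
# "blstm" encoder with 320 units:
#
#     "enc.enc.0.nbrnn.weight_ih_l1" -> "enc.enc.0.birnn1.weight_ih_l0"
#
# Weights whose last dimension equals 2 * encoder_units additionally have
# their forward and backward halves summed.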


def load_trained_model(model_path, training=True):
    """Load the trained model for recognition.

    Args:
        model_path (str): Path to model.***.best
        training (bool): Training mode specification for transducer model.

    Returns:
        model (torch.nn.Module): Trained model.
        train_args (Namespace): Trained model arguments.

    """
    idim, odim, train_args = get_model_conf(
        model_path, os.path.join(os.path.dirname(model_path), "model.json")
    )

    logging.info(f"Reading model parameters from {model_path}")

    if hasattr(train_args, "model_module"):
        model_module = train_args.model_module
    else:
        model_module = "espnet.nets.pytorch_backend.e2e_asr:E2E"

    # CTC loss is not needed; default to builtin to prevent import errors
    if hasattr(train_args, "ctc_type"):
        train_args.ctc_type = "builtin"

    model_class = dynamic_import(model_module)

    if "transducer" in model_module:
        model = model_class(idim, odim, train_args, training=training)
        custom_torch_load(model_path, model, training=training)
    else:
        model = model_class(idim, odim, train_args)
        torch_load(model_path, model)

    return model, train_args
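
# Typical decoding-time usage (the path is illustrative):
#
#     model, train_args = load_trained_model(
#         "exp/train/results/model.acc.best", training=False
#     )
#     model.eval()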


def get_trained_model_state_dict(model_path, new_is_transducer):
    """Extract the trained model state dict for pre-initialization.

    Args:
        model_path (str): Path to trained model.
        new_is_transducer (bool): Whether the new model is Transducer-based.

    Return:
        (Dict): Trained model state dict.

    """
    logging.info(f"Reading model parameters from {model_path}")

    conf_path = os.path.join(os.path.dirname(model_path), "model.json")

    if "rnnlm" in model_path:
        return get_lm_state_dict(torch.load(model_path))

    idim, odim, args = get_model_conf(model_path, conf_path)

    if hasattr(args, "model_module"):
        model_module = args.model_module
    else:
        model_module = "espnet.nets.pytorch_backend.e2e_asr:E2E"

    model_class = dynamic_import(model_module)
    model = model_class(idim, odim, args)
    torch_load(model_path, model)

    assert (
        isinstance(model, MTInterface)
        or isinstance(model, ASRInterface)
        or isinstance(model, TTSInterface)
    )

    if new_is_transducer and "transducer" not in model_module:
        return create_transducer_compatible_state_dict(
            model.state_dict(),
            args.etype,
            args.eunits,
        )

    return model.state_dict()
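
# Usage sketch (illustrative path): fetch a state dict suitable for
# pre-initializing a Transducer model from a non-Transducer checkpoint.
#
#     state_dict = get_trained_model_state_dict(
#         "exp/pretrained/results/model.acc.best", new_is_transducer=True
#     )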


def load_trained_modules(idim, odim, args, interface=ASRInterface):
    """Load ASR/MT/TTS model with pre-trained weights for specified modules.

    Args:
        idim (int): Input dimension.
        odim (int): Output dimension.
        args (Namespace): Model arguments.
        interface (ASRInterface|MTInterface|TTSInterface): Model interface.

    Return:
        main_model (torch.nn.Module): Model with pre-initialized weights.

    """

    def print_new_keys(state_dict, modules, model_path):
        logging.info(f"Loading {modules} from model: {model_path}")

        for k in state_dict.keys():
            logging.warning(f"Overriding module {k}")

    enc_model_path = args.enc_init
    dec_model_path = args.dec_init
    enc_modules = args.enc_init_mods
    dec_modules = args.dec_init_mods

    model_class = dynamic_import(args.model_module)
    main_model = model_class(idim, odim, args)
    assert isinstance(main_model, interface)

    main_state_dict = main_model.state_dict()

    logging.warning("Model(s) found for pre-initialization.")

    for model_path, modules in [
        (enc_model_path, enc_modules),
        (dec_model_path, dec_modules),
    ]:
        if model_path is not None:
            if os.path.isfile(model_path):
                model_state_dict = get_trained_model_state_dict(
                    model_path, "transducer" in args.model_module
                )

                modules = filter_modules(model_state_dict, modules)

                partial_state_dict = get_partial_state_dict(model_state_dict, modules)

                if partial_state_dict:
                    if transfer_verification(
                        main_state_dict, partial_state_dict, modules
                    ):
                        print_new_keys(partial_state_dict, modules, model_path)

                        main_state_dict.update(partial_state_dict)
            else:
                logging.error(f"Specified model was not found: {model_path}")
                exit(1)

    main_model.load_state_dict(main_state_dict)

    return main_model
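
# Usage sketch (paths and module prefixes are illustrative): `args` is the
# training Namespace, typically filled from the --enc-init, --enc-init-mods,
# --dec-init and --dec-init-mods command-line options.
#
#     args.enc_init = "exp/pretrained_asr/results/model.acc.best"
#     args.enc_init_mods = ["enc."]
#     args.dec_init = "exp/pretrained_lm/rnnlm.model.best"
#     args.dec_init_mods = ["dec.", "att."]
#     model = load_trained_modules(idim, odim, args)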