Source code for espnet2.main_funcs.average_nbest_models

import copy
import logging
import warnings
from pathlib import Path
from typing import Collection, Optional, Sequence, Union

import torch
from typeguard import typechecked

from espnet2.train.reporter import Reporter


def _replace_symlink(link: Path, target_name: str) -> None:
    """Make ``link`` a symlink to ``target_name``, replacing any existing entry."""
    # exists() follows symlinks and is False for a dangling link, so
    # is_symlink() is checked as well to catch stale links.
    if link.is_symlink() or link.exists():
        link.unlink()
    link.symlink_to(target_name)


@torch.no_grad()
@typechecked
def average_nbest_models(
    output_dir: Path,
    reporter: Reporter,
    best_model_criterion: Sequence[Sequence[str]],
    nbest: Union[Collection[int], int],
    suffix: Optional[str] = None,
) -> None:
    """Generate averaged model from n-best models

    For every criterion known to ``reporter`` and every requested ``n``:

    * ``n == 0`` is skipped.
    * ``n == 1`` creates the symlink ``{ph}.{cr}.ave_1best.{suffix}pth``
      pointing at the best epoch's ``{epoch}epoch.pth``.
    * ``n > 1`` loads the ``n`` best ``{epoch}epoch.pth`` files, sums them
      key by key, divides non-integer entries by ``n`` and saves the result
      as ``{ph}.{cr}.ave_{n}best.{suffix}pth``.

    Finally ``{ph}.{cr}.ave.{suffix}pth`` is (re)created as a symlink to the
    file for the largest usable ``n``.

    Args:
        output_dir: The directory contains the model file for each epoch
        reporter: Reporter instance
        best_model_criterion: Give criterions to decide the best model.
            e.g. [("valid", "loss", "min"), ("train", "acc", "max")]
        nbest: Number of best model files to be averaged. Either a single
            int or a collection of ints; an empty collection triggers a
            warning and is treated as ``[1]``.
        suffix: A suffix added to the averaged model file name
    """
    if isinstance(nbest, int):
        nbests = [nbest]
    else:
        nbests = list(nbest)
    if len(nbests) == 0:
        warnings.warn("At least 1 nbest values are required")
        nbests = [1]

    # Normalise the suffix so it can be spliced directly before "pth".
    suffix = "" if suffix is None else suffix + "."

    # 1. Get nbests: List[Tuple[str, str, List[Tuple[epoch, value]]]]
    max_nbest = max(nbests)
    nbest_epochs = [
        (ph, k, reporter.sort_epochs_and_values(ph, k, m)[:max_nbest])
        for ph, k, m in best_model_criterion
        if reporter.has(ph, k)
    ]

    # Cache of loaded checkpoints keyed by epoch, shared across all
    # criteria and all n so each file is read at most once. Entries must
    # never be modified after loading.
    _loaded = {}
    for ph, cr, epoch_and_values in nbest_epochs:
        # Only keep the n for which enough epochs are available.
        _nbests = [i for i in nbests if i <= len(epoch_and_values)]
        if len(_nbests) == 0:
            _nbests = [1]

        for n in _nbests:
            if n == 0:
                continue
            elif n == 1:
                # The averaged model is same as the best model
                e, _ = epoch_and_values[0]
                op = output_dir / f"{e}epoch.pth"
                sym_op = output_dir / f"{ph}.{cr}.ave_1best.{suffix}pth"
                _replace_symlink(sym_op, op.name)
            else:
                op = output_dir / f"{ph}.{cr}.ave_{n}best.{suffix}pth"
                logging.info(
                    f"Averaging {n}best models: " f'criterion="{ph}.{cr}": {op}'
                )

                avg = None
                # 2.a. Averaging model
                for e, _ in epoch_and_values[:n]:
                    if e not in _loaded:
                        _loaded[e] = torch.load(
                            output_dir / f"{e}epoch.pth",
                            map_location="cpu",
                        )
                    states = _loaded[e]

                    if avg is None:
                        # Shallow-copy so that rebinding avg[k] below does
                        # not overwrite the cached checkpoint in _loaded,
                        # which later averages (other n / other criteria)
                        # may reuse. The arithmetic is out-of-place, so the
                        # cached tensors themselves are never touched.
                        avg = copy.copy(states)
                    else:
                        # Accumulated
                        for k in avg:
                            avg[k] = avg[k] + states[k]

                for k in avg:
                    if str(avg[k].dtype).startswith("torch.int"):
                        # For int type, not averaged, but only accumulated.
                        # e.g. BatchNorm.num_batches_tracked
                        # (If there are any cases that requires averaging
                        #  or the other reducing method, e.g. max/min,
                        #  for integer type, please report.)
                        logging.info(f"Accumulating {k} instead of averaging")
                    else:
                        avg[k] = avg[k] / n

                # 2.b. Save the ave model and create a symlink
                torch.save(avg, op)

        # 3. *.*.ave.pth is a symlink to the max ave model
        # NOTE(review): if every usable n is 0, nothing was written above and
        # this link points at a non-existent "ave_0best" file - confirm
        # whether callers ever pass nbest=0.
        op = output_dir / f"{ph}.{cr}.ave_{max(_nbests)}best.{suffix}pth"
        sym_op = output_dir / f"{ph}.{cr}.ave.{suffix}pth"
        _replace_symlink(sym_op, op.name)