espnetez.trainer.Trainer
class espnetez.trainer.Trainer(task, train_config, output_dir, stats_dir, data_info=None, train_dump_dir=None, valid_dump_dir=None, train_dataset=None, valid_dataset=None, train_dataloader=None, valid_dataloader=None, build_model_fn=None, **kwargs)
Bases: object
Generic trainer class for ESPnet training.
This class is responsible for managing the training process of ESPnet models. It handles the configuration, dataset preparation, and the training loop. The Trainer class supports multiple input methods including dump directories, custom datasets, and dataloaders. It ensures that the provided arguments are consistent and valid before starting the training process.
train_config
Configuration for training. A dictionary passed to the constructor is converted to a Namespace object.
- Type: Namespace
task_class
Task class instantiated from the provided task identifier.
- Type: Task
stats_dir
Directory where statistics for training and validation will be stored.
- Type: str
output_dir
Directory where model outputs will be saved.
- Type: str
Parameters:
- task (str) – The task identifier used to retrieve the corresponding task class.
- train_config (Union[dict, Namespace]) – Configuration for training.
- output_dir (str) – Directory for saving model outputs.
- stats_dir (str) – Directory for storing training statistics.
- data_info (dict, optional) – Information about the dataset paths and types.
- train_dump_dir (str, optional) – Directory containing training dump files.
- valid_dump_dir (str, optional) – Directory containing validation dump files.
- train_dataset (Dataset, optional) – Custom training dataset.
- valid_dataset (Dataset, optional) – Custom validation dataset.
- train_dataloader (DataLoader, optional) – DataLoader for training data.
- valid_dataloader (DataLoader, optional) – DataLoader for validation data.
- build_model_fn (callable, optional) – Function to build the model.
- **kwargs – Additional keyword arguments for configuring the training.
Raises: ValueError – If any of the argument validation checks fail.
######### Examples
>>> from espnetez.trainer import Trainer
>>> trainer = Trainer(
...     task='asr',
...     train_config={'batch_size': 32, 'learning_rate': 0.001},
...     output_dir='./output',
...     stats_dir='./stats',
...     train_dump_dir='./train_dump',
...     valid_dump_dir='./valid_dump',
... )
>>> trainer.collect_stats()  # Collect statistics from the dataset
>>> trainer.train()  # Start the training process
NOTE
Ensure that either dump directories, datasets, or dataloaders are specified as input parameters, but not a combination of them in conflicting ways.
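The consistency rule in the note can be sketched as a small pre-flight check. The helper `detect_input_mode` below is hypothetical and not part of espnetez. It assumes that exactly one input method is chosen and that training and validation inputs are supplied as a pair; the library's own validation may be more permissive or differ in detail.

```python
def detect_input_mode(train_dump_dir=None, valid_dump_dir=None,
                      train_dataset=None, valid_dataset=None,
                      train_dataloader=None, valid_dataloader=None):
    """Hypothetical helper: report which input method the arguments select."""
    modes = {
        "dump": (train_dump_dir, valid_dump_dir),
        "dataset": (train_dataset, valid_dataset),
        "dataloader": (train_dataloader, valid_dataloader),
    }
    selected = []
    for name, (train_arg, valid_arg) in modes.items():
        given = (train_arg is not None, valid_arg is not None)
        # Assumption: each method needs both its train and valid input.
        if any(given) and not all(given):
            raise ValueError(f"{name}: train and valid inputs must be given together")
        if all(given):
            selected.append(name)
    # Assumption: methods are alternatives, so exactly one may be selected.
    if len(selected) != 1:
        raise ValueError(f"expected exactly one input method, got {selected or 'none'}")
    return selected[0]


mode = detect_input_mode(train_dump_dir="./train_dump", valid_dump_dir="./valid_dump")
```

Running such a check before constructing the Trainer surfaces a conflicting combination early, with a message that names the offending method.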
Initialize an EZ training environment.
- Parameters:
  - task (str) – Identifier of the EZ task to be trained.
  - train_config (dict | argparse.Namespace) – Configuration for the training run. If a dictionary, the key/value pairs are converted into an argparse.Namespace. Any additional keyword arguments passed via **kwargs are merged into train_config.
  - output_dir (str | pathlib.Path) – Directory where the trained model and other artifacts will be written.
  - stats_dir (str | pathlib.Path) – Directory where training statistics and logs should be stored.
  - data_info (dict, optional) – Metadata describing the training and validation datasets. The structure can be either {"train": {...}, "valid": {...}} or a flat mapping where the same items are used for both splits. Each value must be a tuple (file_name, name, type). The types are found at: espnet2/train/dataset.py.
  - train_dump_dir (str | pathlib.Path, optional) – Path to the directory containing the training data dump files. Required if data_info is provided.
  - valid_dump_dir (str | pathlib.Path, optional) – Path to the directory containing the validation data dump files. Required if data_info is provided.
  - train_dataset (Dataset, optional) – A custom training Dataset instance supplied directly to the task.
  - valid_dataset (Dataset, optional) – A custom validation Dataset instance supplied directly to the task.
  - train_dataloader (DataLoader, optional) – A custom training DataLoader instance. Mutually exclusive with train_dataset.
  - valid_dataloader (DataLoader, optional) – A custom validation DataLoader instance. Mutually exclusive with valid_dataset.
  - build_model_fn (Callable, optional) – Function that builds the model used by the task. If provided, it is stored on the task instance as build_model_fn.
  - **kwargs – Additional configuration values that will be merged into train_config if it is a dictionary, or set as attributes on the resulting argparse.Namespace.
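The handling of train_config and **kwargs described in the parameter list can be illustrated with a minimal sketch. `merge_train_config` is a hypothetical helper written for this page, not the library's code. It assumes that keyword arguments take precedence over keys already present in the configuration, which the documentation does not state.

```python
from argparse import Namespace


def merge_train_config(train_config, **kwargs):
    """Hypothetical helper: normalize a config to a Namespace and merge kwargs."""
    if isinstance(train_config, dict):
        merged = dict(train_config)
        merged.update(kwargs)  # assumption: kwargs win on key collisions
        return Namespace(**merged)
    if isinstance(train_config, Namespace):
        # Shallow copy so the caller's Namespace is left untouched.
        merged = Namespace(**vars(train_config))
        for key, value in kwargs.items():
            setattr(merged, key, value)
        return merged
    raise ValueError("train_config must be a dict or an argparse.Namespace")


config = merge_train_config({"batch_size": 32}, print_config=True)
```

Either input form ends up as an attribute-style object, so later code can read `config.batch_size` without caring how the configuration was supplied.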
- Raises:
  - ValueError – If train_config is neither a dict nor an argparse.Namespace.
  - AssertionError – If required arguments are missing (e.g. data_info, train_dump_dir, valid_dump_dir when custom datasets are not supplied).
- Side Effects:
  - Instantiates self.task_class by calling get_ez_task() with the provided task identifier. When custom datasets or dataloaders are supplied, get_ez_task is called with use_custom_dataset=True.
  - Sets self.train_config.train_data_path_and_name_and_type and self.train_config.valid_data_path_and_name_and_type when data_info is used.
  - Adds print_config and required attributes to self.train_config based on kwargs (default values are False and ["output_dir", "token_list"] respectively).
  - Stores stats_dir and output_dir on the instance.
  - If build_model_fn is provided, it is attached to the task class.
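As a rough illustration of how data_info might be turned into the *_data_path_and_name_and_type entries, the sketch below follows the tuple layout (file_name, name, type) given in the parameter description. The helpers `split_data_info` and `to_path_name_type`, the mapping keys, the file names, and the rule that file_name is joined onto the dump directory are all assumptions made for this example; only the type names "sound" and "text" come from espnet2/train/dataset.py.

```python
from pathlib import Path


def split_data_info(data_info):
    """Return (train_info, valid_info) for the nested or the flat layout."""
    if set(data_info) == {"train", "valid"}:
        return data_info["train"], data_info["valid"]
    return data_info, data_info  # flat mapping: same items for both splits


def to_path_name_type(info, dump_dir):
    """Expand (file_name, name, type) values into (path, name, type) tuples."""
    return [
        # Assumption: file_name is relative to the split's dump directory.
        (str(Path(dump_dir) / file_name), name, data_type)
        for file_name, name, data_type in info.values()
    ]


# Illustrative flat mapping; keys and file names are hypothetical.
flat_info = {
    "speech": ("wav.scp", "speech", "sound"),
    "text": ("text", "text", "text"),
}
train_info, valid_info = split_data_info(flat_info)
train_entries = to_path_name_type(train_info, "dump/train")
valid_entries = to_path_name_type(valid_info, "dump/valid")
```

With a flat mapping, both splits share the same items and differ only in the dump directory they are resolved against.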
collect_stats()
Collect statistics for training and validation datasets.
This method initializes the process of gathering statistical data from the training and validation datasets. It creates the necessary directories to store the statistics if they do not already exist and sets the configuration parameters for collecting statistics. The statistics are used to define the shape files required for training.
The method will call the main function of the task_class with the updated configuration, which includes the output directory set to the statistics directory.
- Raises: OSError – If the directory for storing statistics cannot be created.
######### Examples
>>> trainer = Trainer(task='example_task', train_config=some_config,
...                   output_dir='/path/to/output',
...                   stats_dir='/path/to/stats')
>>> trainer.collect_stats()
NOTE
This method must be called before training to ensure that the shape files are defined properly. After running this method, the train_shape_file and valid_shape_file attributes of train_config will be populated based on the collected statistics.
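The preparation steps described for collect_stats() (create the statistics directory, then run the task with the output directory pointing at it) could look roughly like the sketch below. `prepare_collect_stats_config` is a hypothetical helper, and the `collect_stats` attribute name is an assumption about how the statistics mode is switched on; neither is taken from the espnetez source.

```python
import tempfile
from argparse import Namespace
from pathlib import Path


def prepare_collect_stats_config(train_config, stats_dir):
    """Hypothetical helper mirroring the steps described for collect_stats()."""
    # Raises an OSError subclass if the directory cannot be created.
    Path(stats_dir).mkdir(parents=True, exist_ok=True)
    config = Namespace(**vars(train_config))  # copy; leave the original intact
    config.collect_stats = True               # assumed flag name
    config.output_dir = str(stats_dir)        # statistics go to stats_dir
    return config


with tempfile.TemporaryDirectory() as tmp:
    stats_dir = Path(tmp) / "stats"
    base_config = Namespace(batch_size=32)
    stats_config = prepare_collect_stats_config(base_config, stats_dir)
    created = stats_dir.is_dir()
```

Working on a copy keeps the statistics run from leaking its output directory into the configuration later used for training.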
train()
Train the model using the specified training configuration.
This method orchestrates the training process by first ensuring that the necessary shape files are available. It checks for the presence of shape files in the specified statistics directory, and if they are found, it proceeds to invoke the main training routine of the task class.
- Raises: AssertionError – If no shape files are found in the statistics directory for either training or validation.
######### Examples
>>> trainer = Trainer(task='my_task', train_config=my_train_config,
...                   output_dir='output/', stats_dir='stats/')
>>> trainer.collect_stats()  # Must run first so the shape files exist
>>> trainer.train()  # Starts the training process
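The shape-file check that train() performs can be approximated as follows. This is a sketch under stated assumptions, not the library's implementation: the directory layout `<stats_dir>/<split>/*shape` and the helper names `find_shape_files` and `require_shape_files` are hypothetical.

```python
import tempfile
from pathlib import Path


def find_shape_files(stats_dir, split):
    """List shape files for one split; the layout is an assumption."""
    return sorted(str(p) for p in (Path(stats_dir) / split).glob("*shape"))


def require_shape_files(stats_dir):
    """Return (train, valid) shape files or fail as train() is documented to."""
    train_shapes = find_shape_files(stats_dir, "train")
    valid_shapes = find_shape_files(stats_dir, "valid")
    assert train_shapes, f"no training shape files under {stats_dir}"
    assert valid_shapes, f"no validation shape files under {stats_dir}"
    return train_shapes, valid_shapes


# Simulate a statistics directory that collect_stats() has already filled.
with tempfile.TemporaryDirectory() as tmp:
    for split in ("train", "valid"):
        split_dir = Path(tmp) / split
        split_dir.mkdir()
        (split_dir / "speech_shape").touch()
    train_shapes, valid_shapes = require_shape_files(tmp)
```

A check of this kind is a quick way to confirm that collect_stats() has been run against the same stats_dir before a long training job is launched.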