espnetez.trainer.Trainer
class espnetez.trainer.Trainer(task, train_config, output_dir, stats_dir, data_info=None, train_dump_dir=None, valid_dump_dir=None, train_dataset=None, valid_dataset=None, train_dataloader=None, valid_dataloader=None, build_model_fn=None, **kwargs)
Bases: object
Generic trainer class for ESPnet training.
This class is responsible for managing the training process of ESPnet models. It handles the configuration, dataset preparation, and the training loop. The Trainer class supports multiple input methods including dump directories, custom datasets, and dataloaders. It ensures that the provided arguments are consistent and valid before starting the training process.
train_config
Configuration for training. It may be supplied as a dictionary or a Namespace object.
- Type: Namespace
task_class
Task class instantiated from the provided task identifier.
- Type: Task
stats_dir
Directory where statistics for training and validation will be stored.
- Type: str
output_dir
Directory where model outputs will be saved.
- Type: str
Parameters:
- task (str) – The task identifier used to retrieve the corresponding task class.
- train_config (Union[dict, Namespace]) – Configuration for training.
- output_dir (str) – Directory for saving model outputs.
- stats_dir (str) – Directory for storing training statistics.
- data_info (dict, optional) – Information about the dataset paths and types.
- train_dump_dir (str, optional) – Directory containing training dump files.
- valid_dump_dir (str, optional) – Directory containing validation dump files.
- train_dataset (Dataset, optional) – Custom training dataset.
- valid_dataset (Dataset, optional) – Custom validation dataset.
- train_dataloader (DataLoader, optional) – DataLoader for training data.
- valid_dataloader (DataLoader, optional) – DataLoader for validation data.
- build_model_fn (callable, optional) – Function to build the model.
- **kwargs – Additional keyword arguments for configuring the training.
Raises: ValueError – If any of the argument validation checks fail.
######### Examples
>>> from espnetez.trainer import Trainer
>>> trainer = Trainer(
...     task='asr',
...     train_config={'batch_size': 32, 'learning_rate': 0.001},
...     output_dir='./output',
...     stats_dir='./stats',
...     train_dump_dir='./train_dump',
...     valid_dump_dir='./valid_dump',
... )
>>> trainer.collect_stats()  # Collect statistics from the dataset
>>> trainer.train()  # Start the training process
NOTE
Provide the input data through dump directories, datasets, or dataloaders, and do not combine these methods in conflicting ways.
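One plausible reading of this note can be sketched as a small consistency check. The helper below is hypothetical and is not the validation code that espnetez itself runs. It assumes that each input method is given as a complete train/valid pair and that exactly one method is chosen. Both rules are assumptions made for illustration.

```python
def resolve_input_mode(train_dump_dir=None, valid_dump_dir=None,
                       train_dataset=None, valid_dataset=None,
                       train_dataloader=None, valid_dataloader=None):
    """Return the name of the input method in use, or raise ValueError.

    Hypothetical helper: the rules enforced here are assumptions, not the
    checks performed by espnetez.trainer.Trainer.
    """
    pairs = {
        "dump": (train_dump_dir, valid_dump_dir),
        "dataset": (train_dataset, valid_dataset),
        "dataloader": (train_dataloader, valid_dataloader),
    }
    # Assumed rule 1: a train/valid pair is given completely or not at all.
    for name, (train, valid) in pairs.items():
        if (train is None) != (valid is None):
            raise ValueError(f"{name}: train and valid must be given together")
    # Assumed rule 2: exactly one input method is selected.
    chosen = [name for name, (train, _) in pairs.items() if train is not None]
    if len(chosen) != 1:
        raise ValueError(f"expected exactly one input method, got {chosen}")
    return chosen[0]
```

Calling `resolve_input_mode(train_dump_dir='./train_dump', valid_dump_dir='./valid_dump')` returns `'dump'`, while mixing dump directories with datasets raises `ValueError`.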
collect_stats()
Collects statistics for training and validation datasets.
This method initializes the process of gathering statistical data from the training and validation datasets. It creates the necessary directories to store the statistics if they do not already exist and sets the configuration parameters for collecting statistics. The statistics are used to define the shape files required for training.
The method will call the main function of the task_class with the updated configuration, which includes the output directory set to the statistics directory.
- Raises: OSError – If the directory for storing statistics cannot be created.
######### Examples
>>> trainer = Trainer(task='example_task', train_config=some_config,
...                   output_dir='/path/to/output',
...                   stats_dir='/path/to/stats')
>>> trainer.collect_stats()
NOTE
This method must be called before training to ensure that the shape files are defined properly. After running this method, the train_shape_file and valid_shape_file attributes of train_config will be populated based on the collected statistics.
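The preparation that `collect_stats()` is described as doing (create the statistics directory, then point the output directory at it) can be approximated as follows. This is a minimal sketch, not the library's implementation: `prepare_stats_config` is a hypothetical helper, and the `collect_stats` flag name on the configuration is an assumption.

```python
import copy
import os
from argparse import Namespace


def prepare_stats_config(train_config, stats_dir):
    """Return a copy of train_config set up for a statistics-only run.

    Hypothetical helper; the `collect_stats` attribute name is an assumption.
    """
    # Create the statistics directory if it is missing.
    # os.makedirs raises OSError when the directory cannot be created.
    os.makedirs(stats_dir, exist_ok=True)

    # Work on a copy so the caller's configuration is left untouched.
    config = copy.deepcopy(train_config)
    config.collect_stats = True      # assumed flag for a statistics-only run
    config.output_dir = stats_dir    # statistics are written to stats_dir
    return config


if __name__ == "__main__":
    cfg = Namespace(output_dir="./output", batch_size=32)
    stats_cfg = prepare_stats_config(cfg, "./stats")
    print(stats_cfg.output_dir, cfg.output_dir)
```

Copying the configuration keeps the training run's `output_dir` separate from the one used while gathering statistics.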
train()
Train the model using the specified training configuration.
This method orchestrates the training process by first ensuring that the necessary shape files are available. It checks for the presence of shape files in the specified statistics directory, and if they are found, it proceeds to invoke the main training routine of the task class.
- Raises: AssertionError – If no shape files are found in the statistics directory for either training or validation.
######### Examples
>>> trainer = Trainer(task='my_task', train_config=my_train_config,
...                   output_dir='output/', stats_dir='stats/')
>>> trainer.train()  # Starts training; requires shape files from collect_stats()
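The shape-file check that `train()` is described as performing can be sketched as below. The helper is hypothetical, and the layout it searches (`<stats_dir>/train/*shape` and `<stats_dir>/valid/*shape`) is an assumption about where the statistics step writes its files, so adjust the patterns to match the actual output.

```python
import glob
import os


def find_shape_files(stats_dir):
    """Return (train_shapes, valid_shapes) found under stats_dir.

    Hypothetical helper; the train/ and valid/ subdirectories and the
    '*shape' file pattern are assumptions made for illustration.
    """
    train = sorted(glob.glob(os.path.join(stats_dir, "train", "*shape")))
    valid = sorted(glob.glob(os.path.join(stats_dir, "valid", "*shape")))
    # Mirror the documented behavior: fail early when either set is missing.
    assert train, f"no training shape files under {stats_dir}"
    assert valid, f"no validation shape files under {stats_dir}"
    return train, valid
```

Running a check like this before `trainer.train()` gives an early, readable failure when `collect_stats()` has not been run for the given `stats_dir`.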