Task class and data input system for training
Task class
In ESPnet1, we have too many duplicated Python modules. One of the major goals of ESPnet2 is to provide a common interface and let us focus on the unique parts of each task.
The Task class is a common system to build training tools for each task (ASR, TTS, LM, etc.), inspired by the Fairseq Task idea. To build your task, all you have to do is inherit the AbsTask class:
import argparse

from espnet2.tasks.abs_task import AbsTask
from espnet2.train.abs_espnet_model import AbsESPnetModel


class NewModel(AbsESPnetModel):
    def forward(self, input, target):
        (...)
        # loss: The loss of the task. Must be a scalar value.
        # stats: A dict object, used for logging and the validation criterion.
        # weight: A scalar value that is used for normalization of the loss
        #   and stats values among mini-batches. In many cases, this value
        #   should be equal to the mini-batch size.
        return loss, stats, weight


class NewTask(AbsTask):
    @classmethod
    def add_task_arguments(cls, parser):
        parser.add_argument(...)
        (...)

    @classmethod
    def build_collate_fn(cls, args: argparse.Namespace):
        (...)

    @classmethod
    def build_preprocess_fn(cls, args, train):
        (...)

    @classmethod
    def required_data_names(cls, inference: bool = False):
        (...)

    @classmethod
    def optional_data_names(cls, inference: bool = False):
        (...)

    @classmethod
    def build_model(cls, args):
        return NewModel(...)


if __name__ == "__main__":
    # Start training
    NewTask.main()
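The returned weight is used to normalize the loss and stats across mini-batches. As a rough sketch of the idea (an assumption based on the comments above, not ESPnet's actual implementation), think of it as a weighted average where weight is usually the mini-batch size:

def average_stats(history):
    # history: a list of (stats, weight) pairs collected over an epoch.
    # Each stat is averaged with the mini-batch weights, so mini-batches
    # of different sizes contribute proportionally.
    total = sum(weight for _, weight in history)
    return {
        key: sum(stats[key] * weight for stats, weight in history) / total
        for key in history[0][0]
    }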
Data input system
ESPnet2 also provides a command line interface to describe the training corpus. Unlike fairseq or training systems such as pytorch-lightning, our Task class doesn't have an explicit interface for building the dataset. This is because we aim only at tasks related to speech/text, so we don't need such a general system so far.
The following is an example of the command line arguments:
python -m espnet2.bin.asr_train \
  --train_data_path_and_name_and_type=/some/path/tr/wav.scp,speech,sound \
  --train_data_path_and_name_and_type=/some/path/tr/token_int,text,text_int \
  --valid_data_path_and_name_and_type=/some/path/dev/wav.scp,speech,sound \
  --valid_data_path_and_name_and_type=/some/path/dev/token_int,text,text_int
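For illustration, the referenced files might contain lines like the following (hypothetical sample ids and paths; the file formats are explained below):

# /some/path/tr/wav.scp ("sound" format)
utt1 /some/path/tr/utt1.flac
utt2 /some/path/tr/utt2.wav

# /some/path/tr/token_int ("text_int" format)
utt1 10 2 4 4
utt2 3 2 0 1 6 2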
First of all, our mini-batch is always a dict object:
# In the training iteration
for batch in iterator:
    # e.g. batch = {"speech": ..., "text": ...}
    # Forward
    model(**batch)
where model is the same as the model built by Task.build_model().
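Note that, because the mini-batch is unpacked with **batch, the key names given by --*_data_path_and_name_and_type must match the argument names of the model's forward(). For the arguments above:

# The "speech" and "text" keys map to forward(self, speech, text)
model(**batch)  # equivalent to model(speech=batch["speech"], text=batch["text"])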
You can flexibly construct this mini-batch object using --*_data_path_and_name_and_type. This option can be repeated as many times as needed, and each occurrence adds an element to the mini-batch. Also, keep in mind that there is no distinction between input and target data.
The argument of --train_data_path_and_name_and_type should be given as three values separated by commas, like <file-path>,<key-name>,<file-format>:

- key-name specifies the key of the dict.
- file-path is a file/directory path for the data source.
- file-format indicates the format of the file specified by file-path, e.g. sound, kaldi_ark, etc.

scp file
You can show the supported file formats using the --help option:
python -m espnet2.bin.asr_train --help
Almost all formats are referred to as scp files, following Kaldi-ASR. An scp file is just a text file which has two columns in each line: the first column indicates the sample id and the second is some value, e.g. a file path, a transcription, or a sequence of numbers.
- format=npy

sample_id_a /some/path/a.npy
sample_id_b /some/path/b.npy

- format=sound

sample_id_a /some/path/a.flac
sample_id_b /some/path/a.wav

- format=kaldi_ark

sample_id_a /some/path/a.ark:1234
sample_id_b /some/path/a.ark:5678

- format=text_int

sample_id_a 10 2 4 4
sample_id_b 3 2 0 1 6 2

- format=text

sample_id_a hello world
sample_id_b It is rainy today
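As a minimal sketch (not ESPnet's actual reader), any of these scp files can be parsed into a dict that maps sample ids to their raw string values:

def read_scp(path):
    # Each line is "<sample_id> <value...>"; the value is kept as a raw
    # string and interpreted later according to the file format.
    data = {}
    with open(path, encoding="utf-8") as f:
        for line in f:
            sample_id, value = line.rstrip("\n").split(maxsplit=1)
            data[sample_id] = value
    return data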
required_data_names() and optional_data_names()
Though an arbitrary dictionary can be created by this system, each task assumes that specific keys are given for specific purposes. e.g. the ASR task requires the speech and text keys, whose values are used as the input data and target data respectively. See again the methods of the Task class, required_data_names() and optional_data_names():
class NewTask(AbsTask):
    @classmethod
    def required_data_names(cls, inference: bool = False):
        if not inference:
            retval = ("input", "target")
        else:
            retval = ("input",)
        return retval

    @classmethod
    def optional_data_names(cls, inference: bool = False):
        retval = ("auxiliary_feature",)
        return retval
required_data_names() determines the mandatory data names and optional_data_names() gives the optional ones. This means that names other than these are not allowed to be given by the command line arguments:
# The following are the expected arguments
python -m new_task \
  --train_data_path_and_name_and_type=filepath,input,sometype \
  --train_data_path_and_name_and_type=filepath,target,sometype \
  --train_data_path_and_name_and_type=filepath,auxiliary_feature,sometype

# The following raises an error
python -m new_task \
  --train_data_path_and_name_and_type=filepath,unknown,sometype
The intention of this system is just an assertion check, so if you feel it is unnecessary, you can turn off this check with --allow_variable_data_keys true:
# Ignore the assertion check for data names
python -m new_task \
  --train_data_path_and_name_and_type=filepath,unknown_name,sometype \
  --allow_variable_data_keys true
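Conceptually, the check behaves like the following sketch (a hypothetical illustration, not ESPnet's actual implementation):

def check_data_keys(keys, required, optional, allow_variable_data_keys=False):
    # Every given key must be required or optional, unless
    # --allow_variable_data_keys true is specified.
    missing = set(required) - set(keys)
    if missing:
        raise RuntimeError(f"Missing required data keys: {missing}")
    if not allow_variable_data_keys:
        unknown = set(keys) - set(required) - set(optional)
        if unknown:
            raise RuntimeError(f"Unknown data keys: {unknown}")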
Customize collate_fn for PyTorch data loader
The Task class has a method to customize collate_fn:
class NewTask(AbsTask):
    @classmethod
    def build_collate_fn(cls, args: argparse.Namespace):
        ...
collate_fn is an argument of torch.utils.data.DataLoader, and it can modify the data received from the dataset before it is given to the model. e.g.:
from torch.utils.data import DataLoader

def collate_fn(data):
    # data is a list of the return values of the Dataset class
    modified_data = (... touch data ...)
    return modified_data

data_loader = DataLoader(dataset, collate_fn=collate_fn)
for modified_data in data_loader:
    ...
The type of the argument is determined by the input dataset class. Our dataset is always espnet2.train.dataset.ESPnetDataset, whose return value is a tuple of a sample id and a dict of tensors:

sample = ("sample_id", {"speech": tensor, "text": tensor})
Therefore, collate_fn receives a list of such tuples:
data = [
    ("sample_id", {"speech": tensor, "text": tensor}),
    ("sample_id2", {"speech": tensor, "text": tensor}),
    ...
]
In ESPnet2, the return type of collate_fn is supposed to be a tuple of a list of sample ids and a dict of batched tensors, so the collate_fn for a Task must transform the data into that type:
for ids, batch in data_loader:
    model(**batch)
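For concreteness, the following is a minimal sketch of such a collate_fn, zero-padding each tensor along its first axis (an illustration only; in practice you would use the common collate_fn described next):

import torch

def collate_fn(data):
    # data: a list of (sample_id, {"name": tensor, ...}) tuples
    ids = [sample_id for sample_id, _ in data]
    batch = {}
    for key in data[0][1]:
        tensors = [d[key] for _, d in data]
        max_len = max(t.shape[0] for t in tensors)
        # Stack with zero padding along the first (length) axis
        padded = torch.zeros(
            len(tensors), max_len, *tensors[0].shape[1:], dtype=tensors[0].dtype
        )
        for i, t in enumerate(tensors):
            padded[i, : t.shape[0]] = t
        batch[key] = padded
    return ids, batch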
We provide a common collate_fn, and this function can support many cases, so you might not need to customize it yourself. This collate_fn is aware of variable-length sequence features for seq2seq tasks:
- The first axis of a sequence tensor from the dataset must be the length axis: e.g. (Length, Dim), (Length, Dim, Dim2), or (Length, ...).
- It's not necessary to unify the lengths of the samples; they are stacked with zero padding.
- The value used for padding can be changed:

from espnet2.train.collate_fn import CommonCollateFn

@classmethod
def build_collate_fn(cls, args):
    # float_pad_value is used for float tensors and int_pad_value for int tensors
    return CommonCollateFn(float_pad_value=0.0, int_pad_value=-1)

- Tensors which represent the length of each sample are also appended:

batch = {"speech": ..., "speech_lengths": ..., "text": ..., "text_lengths": ...}

- If a feature is not sequential data, this behavior can be disabled:

python -m new_task --train_data_path_and_name_and_type=filepath,foo,npy

@classmethod
def build_collate_fn(cls, args):
    return CommonCollateFn(not_sequence=["foo"])