espnet2.train.collate_fn.common_collate_fn
espnet2.train.collate_fn.common_collate_fn(data: Collection[Tuple[str, Dict[str, ndarray]]], float_pad_value: float | int = 0.0, int_pad_value: int = -32768, not_sequence: Collection[str] = ()) → Tuple[List[str], Dict[str, Tensor]]
Concatenate a list of ndarrays into a single array per key, padding as needed, and convert the result to torch.Tensor.
Examples
>>> from espnet2.samplers.constant_batch_sampler import ConstantBatchSampler
>>> import espnet2.tasks.abs_task
>>> from espnet2.train.collate_fn import common_collate_fn
>>> from espnet2.train.dataset import ESPnetDataset
>>> sampler = ConstantBatchSampler(...)
>>> dataset = ESPnetDataset(...)
>>> keys = next(iter(sampler))
>>> batch = [dataset[key] for key in keys]
>>> # The return value is a tuple: (utterance IDs, dict of tensors).
>>> uttids, batch = common_collate_fn(batch)
>>> model(**batch)
Note that the dict keys of the batch are propagated unchanged from those of the dataset.
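The padding behavior implied by `float_pad_value`, `int_pad_value`, and `not_sequence` can be illustrated with a NumPy-only sketch. This is not ESPnet's implementation: `sketch_collate` is a hypothetical stand-in that returns ndarrays rather than torch.Tensor, the sample keys `speech` and `text` are made up for illustration, and the `<key>_lengths` entries reflect an assumption about how sequence lengths are reported for keys not listed in `not_sequence`.

```python
from typing import Collection, Dict, List, Tuple

import numpy as np


def sketch_collate(
    data: Collection[Tuple[str, Dict[str, np.ndarray]]],
    float_pad_value: float = 0.0,
    int_pad_value: int = -32768,
    not_sequence: Collection[str] = (),
) -> Tuple[List[str], Dict[str, np.ndarray]]:
    """Hypothetical stand-in for illustration; not ESPnet's code."""
    uttids = [uid for uid, _ in data]
    samples = [sample for _, sample in data]

    output: Dict[str, np.ndarray] = {}
    for key in samples[0]:
        arrays = [sample[key] for sample in samples]
        # Integer arrays are padded with int_pad_value, all others
        # with float_pad_value.
        is_int = arrays[0].dtype.kind == "i"
        pad = int_pad_value if is_int else float_pad_value

        # Treat the first axis as the (variable) length axis.
        max_len = max(a.shape[0] for a in arrays)
        shape = (len(arrays), max_len) + arrays[0].shape[1:]
        padded = np.full(shape, pad, dtype=arrays[0].dtype)
        for i, a in enumerate(arrays):
            padded[i, : a.shape[0]] = a
        output[key] = padded

        # Assumption: original lengths are stored under "<key>_lengths"
        # unless the key is declared as not a sequence.
        if key not in not_sequence:
            lengths = [a.shape[0] for a in arrays]
            output[key + "_lengths"] = np.array(lengths, dtype=np.int64)

    return uttids, output


# Two utterances with different lengths (made-up data).
batch = [
    ("utt1", {"speech": np.array([0.1, 0.2, 0.3], dtype=np.float32),
              "text": np.array([5, 6], dtype=np.int64)}),
    ("utt2", {"speech": np.array([0.4], dtype=np.float32),
              "text": np.array([7, 8, 9], dtype=np.int64)}),
]
uttids, collated = sketch_collate(batch)
```

Under these assumptions, `collated` keeps the dataset keys `speech` and `text` as they are, with shorter entries padded up to the longest one in the batch. The large negative default for `int_pad_value` makes padded positions in integer arrays easy to tell apart from valid IDs.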