espnet2.torch_utils.device_funcs.force_gatherable
espnet2.torch_utils.device_funcs.force_gatherable(data, device)
Recursively convert an object into a form that torch.nn.DataParallel can gather.
Unlike to_device(), this function also converts any float or int value it finds into a torch.Tensor.
DataParallel restricts what can be returned. The object must be either:
- a torch.cuda.Tensor with 1 or more dimensions (a 0-dimensional tensor triggers a warning), or
- a list, tuple, or dict.
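
The conversion rules described above can be illustrated with a minimal sketch. This is not ESPnet's implementation: `force_gatherable_sketch` is a hypothetical name, the `stats` dictionary is made-up input, and the device is assumed to be "cpu" so the example runs without a GPU. The dtypes chosen for float and int values, and the reshaping of 0-dimensional tensors to shape (1,), are also assumptions of the sketch.

```python
import torch


def force_gatherable_sketch(data, device):
    """Hypothetical illustration only; not the ESPnet source."""
    if isinstance(data, dict):
        return {k: force_gatherable_sketch(v, device) for k, v in data.items()}
    # Plain lists and tuples only; namedtuples would need separate handling.
    if isinstance(data, (list, tuple)):
        return type(data)(force_gatherable_sketch(v, device) for v in data)
    if isinstance(data, torch.Tensor):
        # Give a 0-dimensional tensor one dimension so gathering does not warn.
        if data.dim() == 0:
            data = data[None]
        return data.to(device)
    if isinstance(data, float):
        return torch.tensor([data], dtype=torch.float, device=device)
    # bool is a subclass of int; this sketch leaves bools untouched.
    if isinstance(data, int) and not isinstance(data, bool):
        return torch.tensor([data], dtype=torch.long, device=device)
    # Anything else is returned unchanged.
    return data


# Made-up statistics of the kind a model's forward() might return.
stats = {"loss": torch.tensor(0.25), "acc": 0.9, "count": 3, "lens": [torch.ones(2)]}
out = force_gatherable_sketch(stats, "cpu")
for name, value in out.items():
    print(name, value)
```

In this sketch every scalar, whether a Python number or a 0-dimensional tensor, ends up as a 1-dimensional tensor, while containers keep their type and are converted element by element.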