espnet.nets.pytorch_backend.nets_utils.mask_by_length
espnet.nets.pytorch_backend.nets_utils.mask_by_length(xs, lengths, fill=0)
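Mask a tensor according to sequence lengths. For each batch element `i`, the positions from index `lengths[i]` onward along the second dimension are set to `fill`.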
- Parameters:
- xs (Tensor) – Batch of input tensors (B, *).
- lengths (LongTensor or List) – Batch of lengths (B,).
- fill (int or float) – Value used to fill the masked positions.
- Returns: Batch of masked input tensors (B, *).
- Return type: Tensor
Examples
>>> import torch
>>> from espnet.nets.pytorch_backend.nets_utils import mask_by_length
>>> x = torch.arange(5).repeat(3, 1) + 1
>>> x
tensor([[1, 2, 3, 4, 5],
        [1, 2, 3, 4, 5],
        [1, 2, 3, 4, 5]])
>>> lengths = [5, 3, 2]
>>> mask_by_length(x, lengths)
tensor([[1, 2, 3, 4, 5],
        [1, 2, 3, 0, 0],
        [1, 2, 0, 0, 0]])
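For readers who want the same behavior without depending on ESPnet, the following is a minimal vectorized sketch in PyTorch. It assumes, as the example above suggests, that masking applies along the second (time) dimension and that any trailing dimensions are left intact. The name `mask_by_length_sketch` is hypothetical, and this is not claimed to be how ESPnet implements `mask_by_length`.

```python
import torch


def mask_by_length_sketch(xs: torch.Tensor, lengths, fill=0) -> torch.Tensor:
    """Hypothetical re-implementation: set xs[i, lengths[i]:] to `fill`.

    Assumes xs has shape (B, T, *) and lengths has one entry per batch element.
    """
    lengths = torch.as_tensor(lengths, device=xs.device)
    assert xs.size(0) == lengths.numel(), "expected one length per batch element"
    max_len = xs.size(1)
    # (B, T) boolean mask: True where the time index is at or past the length
    pad_mask = torch.arange(max_len, device=xs.device).unsqueeze(0) >= lengths.unsqueeze(1)
    # Add singleton dims so the mask broadcasts over trailing feature dims
    pad_mask = pad_mask.view(*pad_mask.shape, *([1] * (xs.dim() - 2)))
    # masked_fill is out-of-place, so the input tensor is not modified
    return xs.masked_fill(pad_mask, fill)


x = torch.arange(5).repeat(3, 1) + 1
print(mask_by_length_sketch(x, [5, 3, 2]))

# A (B, T, D) batch of frames: whole frames past each length are filled
feats = torch.ones(2, 4, 3)
print(mask_by_length_sketch(feats, [4, 1], fill=-1.0))
```

Building the (B, T) boolean mask once and calling `Tensor.masked_fill` avoids a Python loop over the batch; reshaping the mask lets the same code handle 2-D inputs like the doctest and 3-D feature batches alike.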