espnet2.asr.transducer.rnnt_multi_blank.rnnt_multi_blank.RNNTLossNumba
class espnet2.asr.transducer.rnnt_multi_blank.rnnt_multi_blank.RNNTLossNumba(blank=0, reduction='mean', fastemit_lambda: float = 0.0, clamp: float = -1)
Bases: Module
RNN-T (RNN Transducer) loss implemented with Numba.
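For reference, the per-utterance quantity an RNN-T loss computes can be written with the standard forward (α) recursion over the T × (U+1) lattice. This is the textbook definition, not a description of how the Numba kernels are organized, and it leaves out FastEmit regularization and gradient clamping. Here ∅ is the blank label, y_1, …, y_U are the target labels, t runs over 0, …, T-1, u over 0, …, U, and P(k | t, u) is the softmax-normalized network output at frame t and label position u; terms with a negative index are treated as probability zero.

```latex
\begin{aligned}
\alpha(0,0) &= 0,\\
\alpha(t,u) &= \log\Big(\exp\big(\alpha(t-1,u) + \log P(\varnothing \mid t-1,u)\big)
             + \exp\big(\alpha(t,u-1) + \log P(y_u \mid t,u-1)\big)\Big),\\
\mathcal{L} &= -\big(\alpha(T-1,U) + \log P(\varnothing \mid T-1,U)\big).
\end{aligned}
```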
- Parameters:
- blank (int, optional) – Index of the blank label. Default: 0.
- reduction (string, optional) – Specifies the reduction to apply to the output: ‘none’ | ‘mean’ | ‘sum’. ‘none’: no reduction will be applied, ‘mean’: the output losses will be divided by the target lengths and then the mean over the batch is taken, ‘sum’: the output losses will be summed. Default: ‘mean’.
- fastemit_lambda (float, optional) – Scaling factor for FastEmit regularization. Refer to FastEmit: Low-latency Streaming ASR with Sequence-level Emission Regularization. Default: 0.0.
- clamp (float, optional) – When set to a value >= 0.0, the gradient is clamped to [-clamp, clamp]. Default: -1.
forward(acts, labels, act_lens, label_lens)
Compute the RNN-T loss for a batch.
acts: Tensor of shape (batch x seqLength x labelLength x outputDim) containing the output from the network.
labels: 2-dimensional Tensor containing all the targets of the batch, zero-padded.
act_lens: Tensor of size (batch) containing the length of each output sequence from the network.
label_lens: Tensor of size (batch) containing the label length of each example.
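To make the roles of `acts`, `labels`, `act_lens`, and `label_lens` concrete, here is a minimal pure-Python reference sketch of the per-example quantity an RNN-T loss computes, using the standard forward (alpha) recursion and mirroring the argument names of `forward`. It is not the Numba implementation: it ignores `fastemit_lambda` and `clamp` and returns unreduced losses, comparable to `reduction='none'`. Two assumptions are not stated on this page: that `acts` holds unnormalized scores (so the sketch applies a log-softmax over the last axis itself), and that `labelLength` equals the maximum label length plus one. The helper names `rnnt_loss_reference`, `log_softmax`, and `log_add` are hypothetical.

```python
import math
from typing import List, Sequence


def log_softmax(scores: Sequence[float]) -> List[float]:
    """Numerically stable log-softmax over one vector of vocabulary scores."""
    m = max(scores)
    log_z = m + math.log(sum(math.exp(s - m) for s in scores))
    return [s - log_z for s in scores]


def log_add(a: float, b: float) -> float:
    """Return log(exp(a) + exp(b)), treating -inf as probability zero."""
    if a == -math.inf:
        return b
    if b == -math.inf:
        return a
    return max(a, b) + math.log1p(math.exp(-abs(a - b)))


def rnnt_loss_reference(acts, labels, act_lens, label_lens, blank=0):
    """Hypothetical helper: unreduced RNN-T loss per example (cf. reduction='none').

    acts:       nested lists [batch][T_max][U_max + 1][V] of raw scores
                (assumption: unnormalized, so log-softmax is applied here)
    labels:     nested lists [batch][U_max], zero-padded
    act_lens:   number of valid frames T_b per example
    label_lens: number of valid labels U_b per example
    """
    losses = []
    for b in range(len(acts)):
        T, U = act_lens[b], label_lens[b]
        y = labels[b][:U]  # drop zero padding
        # logp[t][u][k]: log P(k | t, u) on the valid part of the lattice
        logp = [[log_softmax(acts[b][t][u]) for u in range(U + 1)] for t in range(T)]
        # alpha[t][u]: log-probability of all partial alignments reaching (t, u)
        alpha = [[-math.inf] * (U + 1) for _ in range(T)]
        alpha[0][0] = 0.0
        for t in range(T):
            for u in range(U + 1):
                if t == 0 and u == 0:
                    continue
                # Arrive by emitting blank at (t - 1, u): advance one frame
                via_blank = alpha[t - 1][u] + logp[t - 1][u][blank] if t > 0 else -math.inf
                # Arrive by emitting label y_u at (t, u - 1): advance one label
                via_label = alpha[t][u - 1] + logp[t][u - 1][y[u - 1]] if u > 0 else -math.inf
                alpha[t][u] = log_add(via_blank, via_label)
        # Every complete alignment ends with a blank emitted at (T - 1, U)
        losses.append(-(alpha[T - 1][U] + logp[T - 1][U][blank]))
    return losses


# Batch of two, V = 2, all-zero scores (uniform distribution):
# example 0 has T=2, U=1; example 1 has T=1, U=0 and is padded.
acts = [[[[0.0, 0.0] for _ in range(2)] for _ in range(2)] for _ in range(2)]
print(rnnt_loss_reference(acts, labels=[[1], [0]], act_lens=[2, 1], label_lens=[1, 0]))
# approximately [log(4), log(2)] = [1.386, 0.693]
```

Plain lists keep the sketch dependency-free and make the padding logic explicit: positions beyond `act_lens[b]` and `label_lens[b]` are simply never read. For training, the Numba-backed module documented here is the intended tool.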