espnet2.asr.transducer.rnnt_multi_blank.utils.cuda_utils.gpu_rnnt_kernel.compute_multiblank_grad_kernel
espnet2.asr.transducer.rnnt_multi_blank.utils.cuda_utils.gpu_rnnt_kernel.compute_multiblank_grad_kernel(grads: Tensor, acts: Tensor, denom: Tensor, sigma: float, alphas: Tensor, betas: Tensor, logll: Tensor, xlen: Tensor, ylen: Tensor, mlabels: Tensor, minibatch: int, maxT: int, maxU: int, alphabet_size: int, blank_: int, big_blank_duration: Tensor, num_big_blanks: int, fastemit_lambda: float, clamp: float)
Compute gradients for the multi-blank transducer loss (https://arxiv.org/pdf/2211.03541).
- Parameters:
- grads – Zero-initialized tensor of shape [B, T, U, V + 1 + num_big_blanks]. Updated in place by this kernel to hold the gradients for this batch of samples.
- acts – Tensor of shape [B, T, U, V + 1 + num_big_blanks] flattened. Represents the logprobs activation tensor.
- denom – Tensor of shape [B, T, U] flattened. Represents the denominator of the logprobs activation tensor across entire vocabulary.
- sigma – Hyperparameter for the logit under-normalization technique used when training multi-blank transducers.
- alphas – Alpha variable, contains forward probabilities. A tensor of shape [B, T, U].
- betas – Beta variable, contains backward probabilities. A tensor of shape [B, T, U].
- logll – Log-likelihood of the forward pass, represented as a vector of shape [B].
- xlen – Vector of length B which contains the actual acoustic sequence lengths in the padded activation tensor.
- ylen – Vector of length B which contains the actual target sequence lengths in the padded activation tensor.
- mlabels – Matrix of shape [B, U+1] (the +1 is due to the <SOS> token, usually the RNNT blank). The matrix contains the padded target transcription that must be predicted.
- minibatch – Int representing the batch size.
- maxT – The maximum possible acoustic sequence length. Represents T in the logprobs tensor.
- maxU – The maximum possible target sequence length. Represents U in the logprobs tensor.
- alphabet_size – The vocabulary dimension V+1 (inclusive of RNNT blank).
- blank_ – Index of the RNNT blank token in the vocabulary. Generally the first or last token in the vocab.
- fastemit_lambda – Float scaling factor for FastEmit regularization. Refer to FastEmit: Low-latency Streaming ASR with Sequence-level Emission Regularization.
- clamp – Float value. When set to value >= 0.0, will clamp the gradient to [-clamp, clamp].
- big_blank_duration – Vector of supported big blank durations of the model.
- num_big_blanks – Number of big blanks of the model.
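Several of the tensors above are described as flattened. As a minimal sketch of how a [b, t, u, v] position might map to an offset in such a buffer, the helper below assumes a contiguous row-major layout. The name `flat_index` and the `last_dim` argument are hypothetical and not part of the ESPnet API; the kernel's actual indexing should be checked against the source.

```python
def flat_index(b, t, u, v, maxT, maxU, last_dim):
    """Row-major offset of element [b, t, u, v] in a flattened [B, T, U, last_dim] buffer.

    Hypothetical helper for illustration only; assumes a contiguous layout.
    """
    # Offset into the [B, T, U] tensors (denom, alphas, betas share this shape).
    col = (b * maxT + t) * maxU + u
    # Step into the last (vocabulary) dimension.
    return col * last_dim + v


# Example sizes, chosen arbitrarily for illustration.
B, T, U, V, num_big_blanks = 2, 3, 4, 5, 2
last_dim = V + 1 + num_big_blanks  # size of the last dimension of acts and grads

first = flat_index(0, 0, 0, 0, T, U, last_dim)
last = flat_index(B - 1, T - 1, U - 1, last_dim - 1, T, U, last_dim)
total = B * T * U * last_dim
```

With this layout the last valid position lands on offset `total - 1`, which is a quick sanity check that the buffer size and the index arithmetic agree.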
Updates: The kernel updates the following inputs in place:
- grads: Gradients with respect to the log likelihood (logll).
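The in-place update and the `clamp` parameter can be illustrated with a small pure-Python sketch. It is not the kernel's code: `clamp_gradients_` is a hypothetical name, the buffer is a plain list rather than a device tensor, and the sketch assumes that a non-positive `clamp` leaves gradients untouched. The behaviour at exactly 0.0 should be verified against the source.

```python
def clamp_gradients_(grads, clamp):
    """Clamp every entry of a flat gradient buffer to [-clamp, clamp], in place.

    Hypothetical illustration; assumes clamp <= 0.0 disables clamping.
    """
    if clamp <= 0.0:
        return grads  # assumption: nothing to do when clamping is disabled
    for i, g in enumerate(grads):
        # Bound from above, then from below, and write back into the same buffer.
        grads[i] = max(min(g, clamp), -clamp)
    return grads


grads = [-3.0, 0.5, 2.0]
same_object = clamp_gradients_(grads, 1.0)

untouched = [-3.0, 0.5, 2.0]
clamp_gradients_(untouched, 0.0)
```

Because the function mutates its argument, the caller's buffer holds the clamped values afterwards, which mirrors how `grads` is listed under the kernel's in-place updates.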