Source code for espnet2.asr.transducer.rnnt_multi_blank.utils.cuda_utils.gpu_rnnt_kernel

# Copyright (c) 2021, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Copyright 2018-2019, Mingkun Huang
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math

import torch
from numba import cuda

from espnet2.asr.transducer.rnnt_multi_blank.utils import rnnt_helper

GPU_RNNT_THREAD_SIZE = 256


@cuda.jit(device=True, inline=True)
def logp(
    denom: torch.Tensor,
    acts: torch.Tensor,
    maxT: int,
    maxU: int,
    alphabet_size: int,
    mb: int,
    t: int,
    u: int,
    v: int,
):
    """Compute the sum of log probability from the activation tensor and its denominator.

    Args:
        denom: Tensor of shape [B, T, U] flattened. Represents the denominator of
            the logprobs activation tensor across entire vocabulary.
        acts: Tensor of shape [B, T, U, V+1] flattened. Represents the logprobs
            activation tensor.
        maxT: The maximum possible acoustic sequence length. Represents T in the
            logprobs tensor.
        maxU: The maximum possible target sequence length. Represents U in the
            logprobs tensor.
        alphabet_size: The vocabulary dimension V+1 (inclusive of RNNT blank).
        mb: Batch indexer.
        t: Acoustic sequence timestep indexer.
        u: Target sequence timestep indexer.
        v: Vocabulary token indexer.

    Returns:
        The sum of logprobs[mb, t, u, v] + denom[mb, t, u]
    """
    col = (mb * maxT + t) * maxU + u
    return denom[col] + acts[col * alphabet_size + v]

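# Reference sketch (an assumption added for exposition, not used by the kernels):
# the flattened lookup in ``logp`` above is equivalent to indexing a dense
# [B, T, U, V+1] activation tensor and a dense [B, T, U] denominator directly.
# The helper below mirrors that pointer arithmetic on unflattened arrays; names
# and shapes are hypothetical.
def _logp_dense_reference(denom_btu, acts_btuv, mb, t, u, v):
    """Illustrative CPU equivalent of logp() on unflattened tensors (not used)."""
    return denom_btu[mb, t, u] + acts_btuv[mb, t, u, v]
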
@cuda.jit()
def compute_alphas_kernel(
    acts: torch.Tensor,
    denom: torch.Tensor,
    alphas: torch.Tensor,
    llForward: torch.Tensor,
    xlen: torch.Tensor,
    ylen: torch.Tensor,
    mlabels: torch.Tensor,  # [B]
    minibatch: int,
    maxT: int,
    maxU: int,
    alphabet_size: int,
    blank_: int,
):
    """Compute alpha (forward variable) probabilities over the transduction step.

    Args:
        acts: Tensor of shape [B, T, U, V+1] flattened. Represents the logprobs
            activation tensor.
        denom: Tensor of shape [B, T, U] flattened. Represents the denominator of
            the logprobs activation tensor across entire vocabulary.
        alphas: Zero tensor of shape [B, T, U]. Will be updated inside the kernel
            with the forward variable probabilities.
        llForward: Zero tensor of shape [B]. Represents the log-likelihood of the
            forward pass. Returned as the forward pass loss that is reduced by the
            optimizer.
        xlen: Vector of length B which contains the actual acoustic sequence lengths
            in the padded activation tensor.
        ylen: Vector of length B which contains the actual target sequence lengths
            in the padded activation tensor.
        mlabels: Matrix of shape [B, U+1] (+1 here is due to <SOS> token - usually
            the RNNT blank). The matrix contains the padded target transcription
            that must be predicted.
        minibatch: Int representing the batch size.
        maxT: The maximum possible acoustic sequence length. Represents T in the
            logprobs tensor.
        maxU: The maximum possible target sequence length. Represents U in the
            logprobs tensor.
        alphabet_size: The vocabulary dimension V+1 (inclusive of RNNT blank).
        blank_: Index of the RNNT blank token in the vocabulary. Generally the
            first or last token in the vocab.

    Updates:
        Kernel inplace updates the following inputs:
        - alphas: forward variable scores.
        - llForward: log-likelihood of forward variable.
    """
    # // launch B blocks, each block has U threads
    b = cuda.blockIdx.x  # // batch id
    u = cuda.threadIdx.x  # label id, u
    T = xlen[b]  # select AM length of current sample
    U = ylen[b] + 1  # select target length of current sample, +1 for the blank token

    labels: torch.Tensor = mlabels[
        b
    ]  # mb label start point, equivalent to mlabels + b * (maxU - 1)
    offset = b * maxT * maxU  # pointer indexing offset

    # alphas += offset  # pointer offset, ignored since we explicitly add offset

    # Initialize alpha[b, t=0, u=0] for all b in B
    if u == 0:
        alphas[offset] = 0

    # sync until all alphas are initialized
    cuda.syncthreads()

    # Ordinary alpha calculations, broadcast across B=b and U=u
    # Look up forward variable calculation from rnnt_numpy.forward_pass()
    for n in range(1, T + U - 1):
        t = n - u

        if u == 0:
            # for t in range(1, T) step to initialize alphas[b, t, 0]
            if t > 0 and t < T:
                alphas[offset + t * maxU + u] = alphas[
                    offset + (t - 1) * maxU + u
                ] + logp(denom, acts, maxT, maxU, alphabet_size, b, t - 1, 0, blank_)
        elif u < U:
            # for u in range(1, U) step to initialize alphas[b, 0, u]
            if t == 0:
                alphas[offset + u] = alphas[offset + u - 1] + logp(
                    denom, acts, maxT, maxU, alphabet_size, b, 0, u - 1, labels[u - 1]
                )

            # for t in range(1, T) for u in range(1, U) step to compute alphas[b, t, u]
            elif t > 0 and t < T:
                no_emit = alphas[offset + (t - 1) * maxU + u] + logp(
                    denom, acts, maxT, maxU, alphabet_size, b, t - 1, u, blank_
                )
                emit = alphas[offset + t * maxU + u - 1] + logp(
                    denom, acts, maxT, maxU, alphabet_size, b, t, u - 1, labels[u - 1]
                )

                alphas[offset + t * maxU + u] = rnnt_helper.log_sum_exp(emit, no_emit)

        # sync across all B=b and U=u
        cuda.syncthreads()

    # After final sync, alphas[b, T-1, U-1] + logprobs[b, T-1, U-1, blank]
    # + denom[b, T-1, U-1] gives log-likelihood of forward pass.
    if u == 0:
        loglike = alphas[offset + (T - 1) * maxU + U - 1] + logp(
            denom, acts, maxT, maxU, alphabet_size, b, T - 1, U - 1, blank_
        )

        llForward[b] = loglike

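# Reference sketch (an assumption added for exposition, not used by the CUDA path):
# a single-utterance CPU version of the recursion computed by
# ``compute_alphas_kernel``, operating on a dense log_probs tensor of shape
# [T, U, V+1]. It mirrors the rnnt_numpy.forward_pass() routine referenced in the
# comments above; names are hypothetical.
def _alphas_reference(log_probs, labels, blank_):
    """Forward variable recursion for one utterance (illustrative only)."""
    import numpy as np

    T, U, _ = log_probs.shape
    alphas = np.zeros((T, U), dtype=np.float64)
    # initialize first column / first row
    for t in range(1, T):
        alphas[t, 0] = alphas[t - 1, 0] + log_probs[t - 1, 0, blank_]
    for u in range(1, U):
        alphas[0, u] = alphas[0, u - 1] + log_probs[0, u - 1, labels[u - 1]]
    # interior cells combine the blank (no_emit) and label (emit) transitions
    for t in range(1, T):
        for u in range(1, U):
            no_emit = alphas[t - 1, u] + log_probs[t - 1, u, blank_]
            emit = alphas[t, u - 1] + log_probs[t, u - 1, labels[u - 1]]
            alphas[t, u] = np.logaddexp(emit, no_emit)
    # per-utterance log-likelihood of the forward pass
    loglike = alphas[T - 1, U - 1] + log_probs[T - 1, U - 1, blank_]
    return alphas, loglike
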
@cuda.jit()
def compute_betas_kernel(
    acts: torch.Tensor,
    denom: torch.Tensor,
    betas: torch.Tensor,
    llBackward: torch.Tensor,
    xlen: torch.Tensor,
    ylen: torch.Tensor,
    mlabels: torch.Tensor,  # [B, U]
    minibatch: int,
    maxT: int,
    maxU: int,
    alphabet_size: int,
    blank_: int,
):
    """Compute beta (backward variable) probabilities over the transduction step.

    Args:
        acts: Tensor of shape [B, T, U, V+1] flattened. Represents the logprobs
            activation tensor.
        denom: Tensor of shape [B, T, U] flattened. Represents the denominator of
            the logprobs activation tensor across entire vocabulary.
        betas: Zero tensor of shape [B, T, U]. Will be updated inside the kernel
            with the backward variable probabilities.
        llBackward: Zero tensor of shape [B]. Represents the log-likelihood of the
            backward pass. Returned as the backward pass loss that is reduced by
            the optimizer.
        xlen: Vector of length B which contains the actual acoustic sequence lengths
            in the padded activation tensor.
        ylen: Vector of length B which contains the actual target sequence lengths
            in the padded activation tensor.
        mlabels: Matrix of shape [B, U+1] (+1 here is due to <SOS> token - usually
            the RNNT blank). The matrix contains the padded target transcription
            that must be predicted.
        minibatch: Int representing the batch size.
        maxT: The maximum possible acoustic sequence length. Represents T in the
            logprobs tensor.
        maxU: The maximum possible target sequence length. Represents U in the
            logprobs tensor.
        alphabet_size: The vocabulary dimension V+1 (inclusive of RNNT blank).
        blank_: Index of the RNNT blank token in the vocabulary. Generally the
            first or last token in the vocab.

    Updates:
        Kernel inplace updates the following inputs:
        - betas: backward variable scores.
        - llBackward: log-likelihood of backward variable.
    """
    # // launch B blocks, each block has U threads
    b = cuda.blockIdx.x  # // batch id
    u = cuda.threadIdx.x  # label id, u
    T = xlen[b]  # select AM length of current sample
    U = ylen[b] + 1  # select target length of current sample, +1 for the blank token

    labels: torch.Tensor = mlabels[
        b
    ]  # mb label start point, equivalent to mlabels + b * (maxU - 1)
    offset = b * maxT * maxU  # pointer indexing offset

    # betas += offset  # pointer offset, ignored since we explicitly add offset

    # Initialize beta[b, t=T-1, u=U-1] for all b in B
    # with log_probs[b, t=T-1, u=U-1, blank]
    if u == 0:
        betas[offset + (T - 1) * maxU + U - 1] = logp(
            denom, acts, maxT, maxU, alphabet_size, b, T - 1, U - 1, blank_
        )

    # sync until all betas are initialized
    cuda.syncthreads()

    # Ordinary beta calculations, broadcast across B=b and U=u
    # Look up backward variable calculation from rnnt_numpy.backward_pass()
    for n in range(T + U - 2, -1, -1):
        t = n - u

        if u == (U - 1):
            # for t in reversed(range(T - 1)) step to initialize betas[b, t, U-1]
            if t >= 0 and t < (T - 1):
                betas[offset + t * maxU + U - 1] = betas[
                    offset + (t + 1) * maxU + U - 1
                ] + logp(denom, acts, maxT, maxU, alphabet_size, b, t, U - 1, blank_)
        elif u < U:
            if t == T - 1:
                # for u in reversed(range(U - 1)) step to initialize betas[b, T-1, u]
                betas[offset + (T - 1) * maxU + u] = betas[
                    offset + (T - 1) * maxU + u + 1
                ] + logp(denom, acts, maxT, maxU, alphabet_size, b, T - 1, u, labels[u])
            elif (t >= 0) and (t < T - 1):
                # for t in reversed(range(T - 1)) for u in reversed(range(U - 1))
                # step to compute betas[b, t, u]
                no_emit = betas[offset + (t + 1) * maxU + u] + logp(
                    denom, acts, maxT, maxU, alphabet_size, b, t, u, blank_
                )
                emit = betas[offset + t * maxU + u + 1] + logp(
                    denom, acts, maxT, maxU, alphabet_size, b, t, u, labels[u]
                )

                betas[offset + t * maxU + u] = rnnt_helper.log_sum_exp(emit, no_emit)

        # sync across all B=b and U=u
        cuda.syncthreads()

    # After final sync, betas[b, 0, 0] gives
    # log-likelihood of backward pass.
    if u == 0:
        llBackward[b] = betas[offset]

@cuda.jit()
def compute_grad_kernel(
    grads: torch.Tensor,
    acts: torch.Tensor,
    denom: torch.Tensor,
    alphas: torch.Tensor,
    betas: torch.Tensor,
    logll: torch.Tensor,
    xlen: torch.Tensor,
    ylen: torch.Tensor,
    mlabels: torch.Tensor,  # [B, U]
    minibatch: int,
    maxT: int,
    maxU: int,
    alphabet_size: int,
    blank_: int,
    fastemit_lambda: float,
    clamp: float,
):
    """Compute gradients over the transduction step.

    Args:
        grads: Zero Tensor of shape [B, T, U, V+1]. Is updated by this kernel to
            contain the gradients of this batch of samples.
        acts: Tensor of shape [B, T, U, V+1] flattened. Represents the logprobs
            activation tensor.
        denom: Tensor of shape [B, T, U] flattened. Represents the denominator of
            the logprobs activation tensor across entire vocabulary.
        alphas: Alpha variable, contains forward probabilities. A tensor of shape
            [B, T, U].
        betas: Beta variable, contains backward probabilities. A tensor of shape
            [B, T, U].
        logll: Log-likelihood of the forward variable, represented as a vector of
            shape [B]. Represents the log-likelihood of the forward pass.
        xlen: Vector of length B which contains the actual acoustic sequence lengths
            in the padded activation tensor.
        ylen: Vector of length B which contains the actual target sequence lengths
            in the padded activation tensor.
        mlabels: Matrix of shape [B, U+1] (+1 here is due to <SOS> token - usually
            the RNNT blank). The matrix contains the padded target transcription
            that must be predicted.
        minibatch: Int representing the batch size.
        maxT: The maximum possible acoustic sequence length. Represents T in the
            logprobs tensor.
        maxU: The maximum possible target sequence length. Represents U in the
            logprobs tensor.
        alphabet_size: The vocabulary dimension V+1 (inclusive of RNNT blank).
        blank_: Index of the RNNT blank token in the vocabulary. Generally the
            first or last token in the vocab.
        fastemit_lambda: Float scaling factor for FastEmit regularization. Refer to
            FastEmit: Low-latency Streaming ASR with Sequence-level Emission
            Regularization.
        clamp: Float value. When set to value >= 0.0, will clamp the gradient to
            [-clamp, clamp].

    Updates:
        Kernel inplace updates the following inputs:
        - grads: Gradients with respect to the log likelihood (logll).
    """
    # Kernel call:
    # blocks_per_grid = minibatch (b) * maxT (t) * maxU (u)
    # threads_per_block = constant buffer size of parallel threads (v :: Constant)
    tid = cuda.threadIdx.x  # represents v, taking steps of some constant size
    idx = tid  # index of v < V+1; in steps of constant buffer size
    col = cuda.blockIdx.x  # represents a fused index of b * t * u

    # Decompose original indices from fused `col`
    u = col % maxU  # (b * t * u) % u = u
    bt = (col - u) // maxU  # (b * t * u - u) // U = b * t
    t = bt % maxT  # (b * t) % t = t
    mb = (bt - t) // maxT  # (b * t - t) // T = b

    # constants
    T = xlen[mb]  # select AM length of current sample
    U = ylen[mb] + 1  # select target length of current sample, +1 for the blank token
    labels: torch.Tensor = mlabels[mb]  # labels = mlabels + mb * (maxU - 1);

    # Buffered gradient calculations, broadcast across B=b, T=t and U=u,
    # looped over V with some constant stride.
    # Look up gradient calculation from rnnt_numpy.compute_gradient()
    if t < T and u < U:
        # For cuda kernels, maximum number of threads per block is limited to some value
        # However, it may be the case that vocabulary size is larger than this limit
        # To work around this, an arbitrary thread buffer size is chosen such that,
        # 1) each element within the thread pool operates independently of the other
        # 2) An inner while loop moves the index of each buffer element by the size
        #    of the buffer itself, such that all elements of the vocabulary size are
        #    covered in (V + 1 // thread_buffer) number of steps.
        # As such, each thread will perform the while loop at least
        # (V + 1 // thread_buffer) number of times
        while idx < alphabet_size:
            # remember, `col` represents the tri-index [b, t, u]
            # therefore; logpk = denom[b, t, u] + acts[b, t, u, v]
            logpk = denom[col] + acts[col * alphabet_size + idx]

            # initialize the grad of the sample acts[b, t, u, v]
            grad = math.exp(alphas[col] + betas[col] + logpk - logll[mb])

            # If FastEmit regularization is enabled, calculate the gradient of
            # probability of predicting the next label at the current timestep.
            # The formula for this is Equation 9 in https://arxiv.org/abs/2010.11148,
            # multiplied by the log probability of the current step (t, u),
            # normalized by the total log likelihood. Once the gradient has been
            # calculated, scale it by `fastemit_lambda`, as in Equation 10.
            if fastemit_lambda > 0.0 and u < U - 1:
                fastemit_grad = fastemit_lambda * math.exp(
                    alphas[col]  # alphas(t, u)
                    + (
                        denom[col] + acts[col * alphabet_size + labels[u]]
                    )  # y_hat(t, u)
                    + betas[col + 1]  # betas(t, u+1)
                    + logpk  # log Pr(k|t, u)
                    - logll[mb]  # total log likelihood for normalization
                )
            else:
                fastemit_grad = 0.0

            # Update the gradient of act[b, t, u, v] with the gradient from
            # FastEmit regularization
            grad = grad + fastemit_grad

            # // grad to last blank transition
            # grad[b, T-1, U-1, v=blank] -= exp(alphas[b, t, u] + logpk - logll[b])
            if (idx == blank_) and (t == T - 1) and (u == U - 1):
                grad -= math.exp(alphas[col] + logpk - logll[mb])

            # grad of blank across t < T;
            # grad[b, t<T-1, u, v=blank] -= exp(alphas[b, t, u]
            # + logpk - logll[b] + betas[b, t + 1, u])
            if (idx == blank_) and (t < T - 1):
                grad -= math.exp(alphas[col] + logpk - logll[mb] + betas[col + maxU])

            # grad of correct token across u < U;
            # grad[b, t, u<U-1, v=label[u]] -= exp(alphas[b, t, u]
            # + logpk - logll[b] + betas[b, t, u+1])
            # Scale the gradient by (1.0 + FastEmit_lambda) in log space,
            # then exponentiate
            if (u < U - 1) and (idx == labels[u]):
                # exp(log(1 + fastemit_lambda) + ...) is numerically more stable than
                # multiplying (1.0 + fastemit_lambda) with result.
                grad -= math.exp(
                    math.log1p(fastemit_lambda)
                    + alphas[col]
                    + logpk
                    - logll[mb]
                    + betas[col + 1]
                )

            # update grads[b, t, u, v] = grad
            grads[col * alphabet_size + idx] = grad

            # clamp gradient (if needed)
            if clamp > 0.0:
                g = grads[col * alphabet_size + idx]
                g = min(g, clamp)
                g = max(g, -clamp)
                grads[col * alphabet_size + idx] = g

            # update internal index through the thread_buffer;
            # until idx < V + 1, such that entire vocabulary has been updated.
            idx += GPU_RNNT_THREAD_SIZE

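# Launch sketch (an assumption added for exposition; the real launch configuration
# lives in the host-side wrapper, not in this file). The kernel comments above
# describe a grid of ``minibatch * maxT * maxU`` blocks with ``GPU_RNNT_THREAD_SIZE``
# threads per block, which a Numba launch would express as below. All argument names
# here are placeholders for device tensors/ints prepared by the caller.
def _example_compute_grad_launch(
    grads, acts, denom, alphas, betas, logll, xlen, ylen, mlabels,
    minibatch, maxT, maxU, alphabet_size, blank_, fastemit_lambda, clamp,
):
    """Illustrative host-side launch of compute_grad_kernel (not used by ESPnet)."""
    blocks_per_grid = minibatch * maxT * maxU  # one block per fused (b, t, u) cell
    threads_per_block = GPU_RNNT_THREAD_SIZE  # strided inner loop covers V+1 tokens
    compute_grad_kernel[blocks_per_grid, threads_per_block](
        grads, acts, denom, alphas, betas, logll, xlen, ylen, mlabels,
        minibatch, maxT, maxU, alphabet_size, blank_, fastemit_lambda, clamp,
    )
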
@cuda.jit()
def compute_multiblank_alphas_kernel(
    acts: torch.Tensor,
    denom: torch.Tensor,
    sigma: float,
    alphas: torch.Tensor,
    llForward: torch.Tensor,
    xlen: torch.Tensor,
    ylen: torch.Tensor,
    mlabels: torch.Tensor,
    minibatch: int,
    maxT: int,
    maxU: int,
    alphabet_size: int,
    blank_: int,
    big_blank_duration: torch.Tensor,
    num_big_blanks: int,
):
    """Compute alpha (forward variable) probabilities for multi-blank transducer
    loss (https://arxiv.org/pdf/2211.03541).

    Args:
        acts: Tensor of shape [B, T, U, V + 1 + num_big_blanks] flattened.
            Represents the logprobs activation tensor.
        denom: Tensor of shape [B, T, U] flattened. Represents the denominator of
            the logprobs activation tensor across entire vocabulary.
        sigma: Hyper-parameter for the logit under-normalization technique for
            training multi-blank transducers.
        alphas: Zero tensor of shape [B, T, U]. Will be updated inside the kernel
            with the forward variable probabilities.
        llForward: Zero tensor of shape [B]. Represents the log-likelihood of the
            forward pass. Returned as the forward pass loss that is reduced by the
            optimizer.
        xlen: Vector of length B which contains the actual acoustic sequence lengths
            in the padded activation tensor.
        ylen: Vector of length B which contains the actual target sequence lengths
            in the padded activation tensor.
        mlabels: Matrix of shape [B, U+1] (+1 here is due to <SOS> token - usually
            the RNNT blank). The matrix contains the padded target transcription
            that must be predicted.
        minibatch: Int representing the batch size.
        maxT: The maximum possible acoustic sequence length. Represents T in the
            logprobs tensor.
        maxU: The maximum possible target sequence length. Represents U in the
            logprobs tensor.
        alphabet_size: The vocabulary dimension V+1 (inclusive of RNNT blank).
        blank_: Index of the RNNT standard blank token in the vocabulary.
        big_blank_duration: Vector of supported big blank durations of the model.
        num_big_blanks: Number of big blanks of the model.

    Updates:
        Kernel inplace updates the following inputs:
        - alphas: forward variable scores.
        - llForward: log-likelihood of forward variable.
    """
    # // launch B blocks, each block has U threads
    b = cuda.blockIdx.x  # // batch id
    u = cuda.threadIdx.x  # label id, u
    T = xlen[b]  # select AM length of current sample
    U = ylen[b] + 1  # select target length of current sample, +1 for the blank token

    labels: torch.Tensor = mlabels[
        b
    ]  # mb label start point, equivalent to mlabels + b * (maxU - 1)
    offset = b * maxT * maxU  # pointer indexing offset

    # Initialize alpha[b, t=0, u=0] for all b in B
    if u == 0:
        alphas[offset] = 0

    # sync until all alphas are initialized
    cuda.syncthreads()

    # Ordinary alpha calculations, broadcast across B=b and U=u
    # Look up forward variable calculation from rnnt_numpy.forward_pass()
    # Note: because of the logit under-normalization, every time logp() is called,
    # it is always followed by a `-sigma` term.
    for n in range(1, T + U - 1):
        t = n - u

        if u == 0:
            # for t in range(1, T) step to initialize alphas[b, t, 0]
            if t > 0 and t < T:
                alphas[offset + t * maxU + u] = (
                    alphas[offset + (t - 1) * maxU + u]
                    + logp(denom, acts, maxT, maxU, alphabet_size, b, t - 1, 0, blank_)
                    - sigma
                )

                # Now add the weights for big blanks.
                for i in range(num_big_blanks):
                    if t >= big_blank_duration[i]:
                        alphas[offset + t * maxU + u] = rnnt_helper.log_sum_exp(
                            alphas[offset + t * maxU + u],
                            alphas[offset + (t - big_blank_duration[i]) * maxU + u]
                            + logp(
                                denom,
                                acts,
                                maxT,
                                maxU,
                                alphabet_size,
                                b,
                                t - big_blank_duration[i],
                                0,
                                blank_ - 1 - i,
                            )
                            - sigma,
                        )
        elif u < U:
            # for u in range(1, U) step to initialize alphas[b, 0, u]
            if t == 0:
                alphas[offset + u] = (
                    alphas[offset + u - 1]
                    + logp(
                        denom,
                        acts,
                        maxT,
                        maxU,
                        alphabet_size,
                        b,
                        0,
                        u - 1,
                        labels[u - 1],
                    )
                    - sigma
                )

            # for t in range(1, T) for u in range(1, U) step to compute alphas[b, t, u]
            elif t > 0 and t < T:
                no_emit = (
                    alphas[offset + (t - 1) * maxU + u]
                    + logp(denom, acts, maxT, maxU, alphabet_size, b, t - 1, u, blank_)
                    - sigma
                )
                emit = (
                    alphas[offset + t * maxU + u - 1]
                    + logp(
                        denom,
                        acts,
                        maxT,
                        maxU,
                        alphabet_size,
                        b,
                        t,
                        u - 1,
                        labels[u - 1],
                    )
                    - sigma
                )

                alphas[offset + t * maxU + u] = rnnt_helper.log_sum_exp(emit, no_emit)

                # Now add the weights for big blanks.
                for i in range(num_big_blanks):
                    if t >= big_blank_duration[i]:
                        # big-blank weight here is
                        # alpha(t - duration, u) * p(big-blank | t - duration, u)
                        # / exp(sigma), in log domain
                        # do this for all big-blanks if the above condition is met
                        big_blank_no_emit = (
                            alphas[offset + (t - big_blank_duration[i]) * maxU + u]
                            + logp(
                                denom,
                                acts,
                                maxT,
                                maxU,
                                alphabet_size,
                                b,
                                t - big_blank_duration[i],
                                u,
                                blank_ - 1 - i,
                            )
                            - sigma
                        )
                        alphas[offset + t * maxU + u] = rnnt_helper.log_sum_exp(
                            alphas[offset + t * maxU + u], big_blank_no_emit
                        )

        # sync across all B=b and U=u
        cuda.syncthreads()

    # After final sync, alphas[b, T-1, U-1] + logprobs[b, T-1, U-1, blank]
    # + denom[b, T-1, U-1] gives log-likelihood of forward pass.
    if u == 0:
        loglike = (
            alphas[offset + (T - 1) * maxU + U - 1]
            + logp(denom, acts, maxT, maxU, alphabet_size, b, T - 1, U - 1, blank_)
            - sigma
        )

        # Now add the weights for big blanks for the final weight computation.
        for i in range(num_big_blanks):
            if T >= big_blank_duration[i]:
                big_blank_loglike = (
                    alphas[offset + (T - big_blank_duration[i]) * maxU + U - 1]
                    + logp(
                        denom,
                        acts,
                        maxT,
                        maxU,
                        alphabet_size,
                        b,
                        T - big_blank_duration[i],
                        U - 1,
                        blank_ - 1 - i,
                    )
                    - sigma
                )
                loglike = rnnt_helper.log_sum_exp(loglike, big_blank_loglike)

        llForward[b] = loglike

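# Vocabulary-layout sketch (an assumption added for exposition, not used by the
# kernels): the multi-blank kernels above address the i-th big blank at vocabulary
# index ``blank_ - 1 - i``, i.e. the big blanks occupy the indices immediately
# preceding the standard blank. The helper below only documents that index mapping.
def _big_blank_token_index(blank_: int, i: int) -> int:
    """Vocabulary index of the i-th big blank, as addressed by the kernels."""
    return blank_ - 1 - i
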
@cuda.jit()
def compute_multiblank_betas_kernel(
    acts: torch.Tensor,
    denom: torch.Tensor,
    sigma: float,
    betas: torch.Tensor,
    llBackward: torch.Tensor,
    xlen: torch.Tensor,
    ylen: torch.Tensor,
    mlabels: torch.Tensor,  # [B, U]
    minibatch: int,
    maxT: int,
    maxU: int,
    alphabet_size: int,
    blank_: int,
    big_blank_duration: torch.Tensor,
    num_big_blanks: int,
):
    """Compute beta (backward variable) probabilities for multi-blank transducer
    loss (https://arxiv.org/pdf/2211.03541).

    Args:
        acts: Tensor of shape [B, T, U, V + 1 + num_big_blanks] flattened.
            Represents the logprobs activation tensor.
        denom: Tensor of shape [B, T, U] flattened. Represents the denominator of
            the logprobs activation tensor across entire vocabulary.
        sigma: Hyper-parameter for the logit under-normalization technique for
            training multi-blank transducers.
        betas: Zero tensor of shape [B, T, U]. Will be updated inside the kernel
            with the backward variable probabilities.
        llBackward: Zero tensor of shape [B]. Represents the log-likelihood of the
            backward pass. Returned as the backward pass loss that is reduced by
            the optimizer.
        xlen: Vector of length B which contains the actual acoustic sequence lengths
            in the padded activation tensor.
        ylen: Vector of length B which contains the actual target sequence lengths
            in the padded activation tensor.
        mlabels: Matrix of shape [B, U+1] (+1 here is due to <SOS> token - usually
            the RNNT blank). The matrix contains the padded target transcription
            that must be predicted.
        minibatch: Int representing the batch size.
        maxT: The maximum possible acoustic sequence length. Represents T in the
            logprobs tensor.
        maxU: The maximum possible target sequence length. Represents U in the
            logprobs tensor.
        alphabet_size: The vocabulary dimension V+1 (inclusive of RNNT blank).
        blank_: Index of the RNNT standard blank token in the vocabulary.
        big_blank_duration: Vector of supported big blank durations of the model.
        num_big_blanks: Number of big blanks of the model.

    Updates:
        Kernel inplace updates the following inputs:
        - betas: backward variable scores.
        - llBackward: log-likelihood of backward variable.
    """
    # // launch B blocks, each block has U threads
    b = cuda.blockIdx.x  # // batch id
    u = cuda.threadIdx.x  # label id, u
    T = xlen[b]  # select AM length of current sample
    U = ylen[b] + 1  # select target length of current sample, +1 for the blank token

    labels: torch.Tensor = mlabels[
        b
    ]  # mb label start point, equivalent to mlabels + b * (maxU - 1)
    offset = b * maxT * maxU  # pointer indexing offset

    # Note: just like the alphas, because of the logit under-normalization, every
    # time logp() is called, it is always followed by a `-sigma` term.

    # Initialize beta[b, t=T-1, u=U-1] for all b in B with
    # log_probs[b, t=T-1, u=U-1, blank]
    if u == 0:
        betas[offset + (T - 1) * maxU + U - 1] = (
            logp(denom, acts, maxT, maxU, alphabet_size, b, T - 1, U - 1, blank_)
            - sigma
        )

    # sync until all betas are initialized
    cuda.syncthreads()

    # Ordinary beta calculations, broadcast across B=b and U=u
    # Look up backward variable calculation from rnnt_numpy.backward_pass()
    for n in range(T + U - 2, -1, -1):
        t = n - u

        if u == (U - 1):
            # for t in reversed(range(T - 1)) step to initialize betas[b, t, U-1]
            if t >= 0 and t < (T - 1):
                # beta[t, U-1] = beta[t + 1, U-1] * p(blank | t, U-1) / exp(sigma)
                # this part is the same as regular RNN-T.
                betas[offset + t * maxU + U - 1] = (
                    betas[offset + (t + 1) * maxU + U - 1]
                    + logp(denom, acts, maxT, maxU, alphabet_size, b, t, U - 1, blank_)
                    - sigma
                )

                # now add the weights from big blanks
                for i in range(num_big_blanks):
                    if t + big_blank_duration[i] < T:
                        # adding to beta[t, U - 1] of weight (in log domain),
                        # beta[t + duration, U - 1] *
                        # p(big-blank | t, U - 1) / exp(sigma)
                        betas[offset + t * maxU + U - 1] = rnnt_helper.log_sum_exp(
                            betas[offset + t * maxU + U - 1],
                            betas[offset + (t + big_blank_duration[i]) * maxU + U - 1]
                            + logp(
                                denom,
                                acts,
                                maxT,
                                maxU,
                                alphabet_size,
                                b,
                                t,
                                U - 1,
                                blank_ - 1 - i,
                            )
                            - sigma,
                        )
                    elif t + big_blank_duration[i] == T and big_blank_duration[i] != 1:
                        # adding to beta[T - duration, U - 1] of weight (in log
                        # domain), p(big-blank | T - duration, U - 1) / exp(sigma)
                        betas[offset + t * maxU + U - 1] = rnnt_helper.log_sum_exp(
                            betas[offset + t * maxU + U - 1],
                            logp(
                                denom,
                                acts,
                                maxT,
                                maxU,
                                alphabet_size,
                                b,
                                t,
                                U - 1,
                                blank_ - 1 - i,
                            )
                            - sigma,
                        )
        elif u < U:
            if t == T - 1:
                # for u in reversed(range(U - 1)) step to initialize betas[b, T-1, u]
                betas[offset + (T - 1) * maxU + u] = (
                    betas[offset + (T - 1) * maxU + u + 1]
                    + logp(
                        denom, acts, maxT, maxU, alphabet_size, b, T - 1, u, labels[u]
                    )
                    - sigma
                )
            elif (t >= 0) and (t < T - 1):
                # for t in reversed(range(T - 1)) for u in reversed(range(U - 1))
                # step to compute betas[b, t, u]
                no_emit = (
                    betas[offset + (t + 1) * maxU + u]
                    + logp(denom, acts, maxT, maxU, alphabet_size, b, t, u, blank_)
                    - sigma
                )
                emit = (
                    betas[offset + t * maxU + u + 1]
                    + logp(denom, acts, maxT, maxU, alphabet_size, b, t, u, labels[u])
                    - sigma
                )

                betas[offset + t * maxU + u] = rnnt_helper.log_sum_exp(emit, no_emit)

                # now add the weights from big blanks
                for i in range(num_big_blanks):
                    if t < T - big_blank_duration[i]:
                        # added weight for the big-blank,
                        # beta[t + duration, u] * p(big-blank | t, u) / exp(sigma)
                        big_blank_no_emit = (
                            betas[offset + (t + big_blank_duration[i]) * maxU + u]
                            + logp(
                                denom,
                                acts,
                                maxT,
                                maxU,
                                alphabet_size,
                                b,
                                t,
                                u,
                                blank_ - 1 - i,
                            )
                            - sigma
                        )
                        betas[offset + t * maxU + u] = rnnt_helper.log_sum_exp(
                            betas[offset + t * maxU + u], big_blank_no_emit
                        )

        # sync across all B=b and U=u
        cuda.syncthreads()

    # After final sync, betas[b, 0, 0] gives
    # log-likelihood of backward pass.
    if u == 0:
        llBackward[b] = betas[offset]

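# Under-normalization sketch (an assumption added for exposition, not used by the
# kernels): in both multi-blank kernels, every logp() lookup is immediately followed
# by ``- sigma``, i.e. each transition weight is p(k | t, u) / exp(sigma) in the log
# domain. The helper below restates that pattern on flattened CPU arrays; names are
# hypothetical.
def _undernormalized_logp_reference(
    denom_flat, acts_flat, maxT, maxU, alphabet_size, b, t, u, v, sigma
):
    """CPU restatement of logp() - sigma on flattened arrays (illustrative only)."""
    col = (b * maxT + t) * maxU + u
    return denom_flat[col] + acts_flat[col * alphabet_size + v] - sigma
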
@cuda.jit()
def compute_multiblank_grad_kernel(
    grads: torch.Tensor,
    acts: torch.Tensor,
    denom: torch.Tensor,
    sigma: float,
    alphas: torch.Tensor,
    betas: torch.Tensor,
    logll: torch.Tensor,
    xlen: torch.Tensor,
    ylen: torch.Tensor,
    mlabels: torch.Tensor,  # [B, U]
    minibatch: int,
    maxT: int,
    maxU: int,
    alphabet_size: int,
    blank_: int,
    big_blank_duration: torch.Tensor,
    num_big_blanks: int,
    fastemit_lambda: float,
    clamp: float,
):
    """Compute gradients for multi-blank transducer loss
    (https://arxiv.org/pdf/2211.03541).

    Args:
        grads: Zero Tensor of shape [B, T, U, V + 1 + num_big_blanks]. Is updated by
            this kernel to contain the gradients of this batch of samples.
        acts: Tensor of shape [B, T, U, V + 1 + num_big_blanks] flattened.
            Represents the logprobs activation tensor.
        denom: Tensor of shape [B, T, U] flattened. Represents the denominator of
            the logprobs activation tensor across entire vocabulary.
        sigma: Hyper-parameter for the logit under-normalization technique for
            training multi-blank transducers.
        alphas: Alpha variable, contains forward probabilities. A tensor of shape
            [B, T, U].
        betas: Beta variable, contains backward probabilities. A tensor of shape
            [B, T, U].
        logll: Log-likelihood of the forward variable, represented as a vector of
            shape [B]. Represents the log-likelihood of the forward pass.
        xlen: Vector of length B which contains the actual acoustic sequence lengths
            in the padded activation tensor.
        ylen: Vector of length B which contains the actual target sequence lengths
            in the padded activation tensor.
        mlabels: Matrix of shape [B, U+1] (+1 here is due to <SOS> token - usually
            the RNNT blank). The matrix contains the padded target transcription
            that must be predicted.
        minibatch: Int representing the batch size.
        maxT: The maximum possible acoustic sequence length. Represents T in the
            logprobs tensor.
        maxU: The maximum possible target sequence length. Represents U in the
            logprobs tensor.
        alphabet_size: The vocabulary dimension V+1 (inclusive of RNNT blank).
        blank_: Index of the RNNT blank token in the vocabulary. Generally the
            first or last token in the vocab.
        fastemit_lambda: Float scaling factor for FastEmit regularization. Refer to
            FastEmit: Low-latency Streaming ASR with Sequence-level Emission
            Regularization.
        clamp: Float value. When set to value >= 0.0, will clamp the gradient to
            [-clamp, clamp].
        big_blank_duration: Vector of supported big blank durations of the model.
        num_big_blanks: Number of big blanks of the model.

    Updates:
        Kernel inplace updates the following inputs:
        - grads: Gradients with respect to the log likelihood (logll).
    """
    # Kernel call:
    # blocks_per_grid = minibatch (b) * maxT (t) * maxU (u)
    # threads_per_block = constant buffer size of parallel threads (v :: Constant)
    tid = cuda.threadIdx.x  # represents v, taking steps of some constant size
    idx = tid  # index of v < V+1; in steps of constant buffer size
    col = cuda.blockIdx.x  # represents a fused index of b * t * u

    # Decompose original indices from fused `col`
    u = col % maxU  # (b * t * u) % u = u
    bt = (col - u) // maxU  # (b * t * u - u) // U = b * t
    t = bt % maxT  # (b * t) % t = t
    mb = (bt - t) // maxT  # (b * t - t) // T = b

    # constants
    T = xlen[mb]  # select AM length of current sample
    U = ylen[mb] + 1  # select target length of current sample, +1 for the blank token
    labels: torch.Tensor = mlabels[mb]  # labels = mlabels + mb * (maxU - 1);

    # Buffered gradient calculations, broadcast across B=b, T=t and U=u,
    # looped over V with some constant stride.
    # Look up gradient calculation from rnnt_numpy.compute_gradient()
    if t < T and u < U:
        # For cuda kernels, maximum number of threads per block is limited to some value
        # However, it may be the case that vocabulary size is larger than this limit
        # To work around this, an arbitrary thread buffer size is chosen such that,
        # 1) each element within the thread pool operates independently of the other
        # 2) An inner while loop moves the index of each buffer element by the size
        #    of the buffer itself, such that all elements of the vocabulary size are
        #    covered in (V + 1 // thread_buffer) number of steps.
        # As such, each thread will perform the while loop at least
        # (V + 1 // thread_buffer) number of times
        while idx < alphabet_size:
            # remember, `col` represents the tri-index [b, t, u]
            # therefore; logpk = denom[b, t, u] + acts[b, t, u, v]
            logpk = denom[col] + acts[col * alphabet_size + idx]

            # initialize the grad of the sample acts[b, t, u, v]
            grad = math.exp(alphas[col] + betas[col] + logpk - logll[mb])

            # In all of the following computation, whenever logpk is used, we
            # need to subtract sigma based on our derivation of the gradient of
            # the logit under-normalization method.

            # If FastEmit regularization is enabled, calculate the gradient of
            # probability of predicting the next label at the current timestep.
            # The formula for this is Equation 9 in https://arxiv.org/abs/2010.11148,
            # multiplied by the log probability of the current step (t, u),
            # normalized by the total log likelihood. Once the gradient has been
            # calculated, scale it by `fastemit_lambda`, as in Equation 10.
            if fastemit_lambda > 0.0 and u < U - 1:
                fastemit_grad = fastemit_lambda * math.exp(
                    alphas[col]  # alphas(t, u)
                    + (denom[col] + acts[col * alphabet_size + labels[u]])
                    + betas[col + 1]  # betas(t, u+1)
                    + logpk  # log Pr(k|t, u)
                    - sigma
                    - logll[mb]  # total log likelihood for normalization
                )
            else:
                fastemit_grad = 0.0

            # Update the gradient of act[b, t, u, v] with the gradient
            # from FastEmit regularization
            grad = grad + fastemit_grad

            # grad to last blank transition
            # grad[b, T-1, U-1, v=blank] -= exp(alphas[b, t, u]
            # + logpk - sigma - logll[b])
            if (idx == blank_) and (t == T - 1) and (u == U - 1):
                grad -= math.exp(alphas[col] + logpk - sigma - logll[mb])
            else:
                # this is one difference of the multi-blank gradient from standard
                # RNN-T gradient -- basically, wherever the blank_ symbol is addressed
                # in the original code, we need to do similar things to big blanks,
                # and we need to change the if conditions to match the duration of
                # the big-blank.
                # grad[b, T-duration, U-1, v=big-blank] -=
                # exp(alphas[b, t, u] + logpk - sigma - logll[b])
                for i in range(num_big_blanks):
                    if (
                        (idx == blank_ - 1 - i)
                        and (t == T - big_blank_duration[i])
                        and (u == U - 1)
                    ):
                        grad -= math.exp(alphas[col] + logpk - sigma - logll[mb])

            # grad of blank across t < T;
            # grad[b, t<T-1, u, v=blank] -= exp(alphas[b, t, u]
            # + logpk - sigma - logll[b] + betas[b, t + 1, u])
            if (idx == blank_) and (t < T - 1):
                grad -= math.exp(
                    alphas[col] + logpk - sigma - logll[mb] + betas[col + maxU]
                )
            else:
                # This is another difference between multi-blank and RNN-T gradients.
                # Now we consider gradients for big-blanks.
                # grad[b, t<T-duration, u, v=big-blank] -=
                # exp(alphas[b, t, u] + logpk - sigma - logll[b]
                # + betas[b, t + duration, u])
                for i in range(num_big_blanks):
                    if (idx == blank_ - 1 - i) and (t < T - big_blank_duration[i]):
                        grad -= math.exp(
                            alphas[col]
                            + logpk
                            - sigma
                            - logll[mb]
                            + betas[col + big_blank_duration[i] * maxU]
                        )

            # grad of correct token across u < U;
            # grad[b, t, u<U-1, v=label[u]] -=
            # exp(alphas[b, t, u] + logpk - sigma - logll[b] + betas[b, t, u+1])
            # Scale the gradient by (1.0 + FastEmit_lambda) in log space,
            # then exponentiate
            if (u < U - 1) and (idx == labels[u]):
                # exp(log(1 + fastemit_lambda) + ...) is numerically more stable than
                # multiplying (1.0 + fastemit_lambda) with result.
                grad -= math.exp(
                    math.log1p(fastemit_lambda)
                    + alphas[col]
                    + logpk
                    - sigma
                    - logll[mb]
                    + betas[col + 1]
                )

            # update grads[b, t, u, v] = grad
            grads[col * alphabet_size + idx] = grad

            # clamp gradient (if needed)
            if clamp > 0.0:
                g = grads[col * alphabet_size + idx]
                g = min(g, clamp)
                g = max(g, -clamp)
                grads[col * alphabet_size + idx] = g

            # update internal index through the thread_buffer;
            # until idx < V + 1, such that entire vocabulary has been updated.
            idx += GPU_RNNT_THREAD_SIZE

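# Clamp sketch (an assumption added for exposition, not used by the kernels): both
# gradient kernels clamp each gradient element to [-clamp, clamp] when
# ``clamp > 0.0``. The scalar helper below restates that logic for clarity.
def _clamp_gradient(g: float, clamp: float) -> float:
    """Clamp a single gradient value to [-clamp, clamp] (illustrative only)."""
    if clamp > 0.0:
        g = min(g, clamp)
        g = max(g, -clamp)
    return g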