Source code for espnet2.asr.transducer.rnnt_multi_blank.utils.rnnt_helper

# Copyright (c) 2021, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Copyright 2018-2019, Mingkun Huang
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import math
from typing import Optional, Tuple

import torch
from numba import cuda

from espnet2.asr.transducer.rnnt_multi_blank.utils import global_constants

threshold = global_constants.THRESHOLD


@cuda.jit(device=True, inline=True)
def log_sum_exp(a: float, b: float):
    # Numerically stable log(exp(a) + exp(b)); FP32_NEG_INF acts as the identity element.
    if a == global_constants.FP32_NEG_INF:
        return b

    if b == global_constants.FP32_NEG_INF:
        return a

    if a > b:
        return math.log1p(math.exp(b - a)) + a
    else:
        return math.log1p(math.exp(a - b)) + b

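# A minimal host-side sketch (not part of this module) of the identity the device
# function implements: log(exp(a) + exp(b)) evaluated stably by factoring out the
# larger argument.
#
#     import math
#     a, b = -1.0, -2.5
#     reference = math.log(math.exp(a) + math.exp(b))
#     stable = math.log1p(math.exp(min(a, b) - max(a, b))) + max(a, b)
#     assert abs(reference - stable) < 1e-12
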
@cuda.jit(device=True, inline=True)
def div_up(x: int, y: int):
    # Ceiling division, used to size CUDA grids.
    return (x + y - 1) // y

@cuda.jit(device=True)
def maximum(x, y):
    if x < y:
        return y
    else:
        return x

@cuda.jit(device=True)
def add(x, y):
    return x + y

@cuda.jit(device=True)
def identity(x):
    return x

@cuda.jit(device=True)
def negate(x):
    return -x

@cuda.jit(device=True)
def exponential(x):
    return math.exp(x)

@cuda.jit(device=True)
def log_plus(p1: float, p2: float):
    # Same quantity as log_sum_exp, written via the absolute difference of the inputs.
    if p1 == global_constants.FP32_NEG_INF:
        return p2

    if p2 == global_constants.FP32_NEG_INF:
        return p1

    result = math.log1p(math.exp(-math.fabs(p1 - p2))) + maximum(p1, p2)
    return result

@cuda.jit(device=True, inline=True)
def copy_data_1d(source: torch.Tensor, dest: torch.Tensor, idx: int):
    dest[idx] = source[idx]

@cuda.jit()
def compute_costs_data(
    source: torch.Tensor, dest: torch.Tensor, fastemit_lambda: float
):
    block = cuda.blockIdx.x
    tid = cuda.threadIdx.x
    idx = block * cuda.blockDim.x + tid
    length = source.shape[0]

    if idx < length:
        # Copy the log-likelihood, negate it into a cost, and apply the FastEmit scale.
        copy_data_1d(source, dest, idx)
        dest[idx] *= -1.0
        dest[idx] *= 1.0 + fastemit_lambda

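# Illustrative launch sketch (an assumption, not code from this module): one thread
# per utterance negates the forward log-likelihood and applies the FastEmit scale
# 1 + fastemit_lambda. Here "ll_forward" and "costs" stand for hypothetical 1-D CUDA
# device arrays of length minibatch.
#
#     threads_per_block = 32
#     blocks = (minibatch + threads_per_block - 1) // threads_per_block
#     compute_costs_data[blocks, threads_per_block](ll_forward, costs, fastemit_lambda)
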
def get_workspace_size(
    maxT: int, maxU: int, minibatch: int, gpu: bool
) -> Tuple[Optional[int], global_constants.RNNTStatus]:
    if minibatch <= 0 or maxT <= 0 or maxU <= 0:
        return (None, global_constants.RNNTStatus.RNNT_STATUS_INVALID_VALUE)

    # per-minibatch memory
    per_minibatch_size = 0

    # alphas & betas
    per_minibatch_size += maxT * maxU * 2

    if not gpu:
        # blank & label log probability cache
        per_minibatch_size += maxT * maxU * 2
    else:
        # softmax denominator
        per_minibatch_size += maxT * maxU
        # forward-backward loglikelihood
        per_minibatch_size += 2

    size = per_minibatch_size * minibatch
    return (size, global_constants.RNNTStatus.RNNT_STATUS_SUCCESS)

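# Usage sketch with illustrative values: on GPU each utterance needs
# maxT * maxU * 2 elements for alphas/betas, maxT * maxU for the softmax
# denominator, and 2 for the forward/backward log-likelihoods.
#
#     size, status = get_workspace_size(maxT=100, maxU=20, minibatch=8, gpu=True)
#     # size == (100 * 20 * 2 + 100 * 20 + 2) * 8 == 48016
#     # status == global_constants.RNNTStatus.RNNT_STATUS_SUCCESS
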
def flatten_tensor(x: torch.Tensor):
    original_shape = x.shape
    x = x.view([-1])
    return x, original_shape
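
# Usage sketch (names are illustrative): flatten a tensor into 1-D before handing it
# to the flat CUDA kernels above, then restore the original shape afterwards.
#
#     flat, original_shape = flatten_tensor(acts)
#     restored = flat.view(original_shape)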