Source code for espnet2.asr_transducer.decoder.modules.rwkv.attention

"""Attention (time mixing) modules for RWKV block.

Based/Modified from https://github.com/BlinkDL/RWKV-LM/blob/main/RWKV-v4/src/model.py.

Some variables are renamed according to
https://github.com/huggingface/transformers/blob/main/src/transformers/models/rwkv/modeling_rwkv.py.

"""

import math
from importlib.util import find_spec
from pathlib import Path
from typing import List, Optional, Tuple

import torch

# Handle to the compiled WKV CUDA extension. It stays None until
# load_wkv_kernel() builds/loads the extension and stores it here (together
# with a `context_size` attribute); WKVLinearAttention reads it at call time.
wkv_kernel = None


class WKVLinearAttention(torch.autograd.Function):
    """WKVLinearAttention function definition.

    Custom autograd function delegating both passes to the module-level
    ``wkv_kernel`` extension, which has to be loaded (``load_wkv_kernel``)
    before this function is applied.

    """

    @staticmethod
    def forward(
        ctx,
        time_decay: torch.Tensor,
        time_first: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
    ) -> torch.Tensor:
        """WKVLinearAttention function forward pass.

        Args:
            ctx: Autograd context used to stash tensors for the backward pass.
            time_decay: Channel-wise time decay vector. (D_att)
            time_first: Channel-wise time first vector. (D_att)
            key: Key tensor. (B, U, D_att)
            value: Value tensor. (B, U, D_att)

        Returns:
            out: Weighted Key-Value tensor. (B, U, D_att)

        Raises:
            AssertionError: If the sequence is longer than the context size the
                kernel was loaded with, or if ``batch * dim`` is not a multiple
                of ``min(dim, 32)``.

        """
        batch, length, dim = key.size()

        assert length <= wkv_kernel.context_size, (
            f"Cannot process key of length {length} while context_size "
            f"is ({wkv_kernel.context_size}). Limit should be increased."
        )

        assert batch * dim % min(dim, 32) == 0, (
            f"batch size ({batch}) by dimension ({dim}) should be a multiple of "
            f"{min(dim, 32)}"
        )

        ctx.input_dtype = key.dtype

        # The kernel receives the decay already mapped to -exp(time_decay).
        time_decay = -torch.exp(time_decay.contiguous())
        time_first = time_first.contiguous()

        key = key.contiguous()
        value = value.contiguous()

        # Output buffer handed to the kernel alongside the inputs.
        out = torch.empty_like(key, memory_format=torch.contiguous_format)

        wkv_kernel.forward(time_decay, time_first, key, value, out)
        ctx.save_for_backward(time_decay, time_first, key, value, out)

        return out

    @staticmethod
    def backward(
        ctx, grad_output: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """WKVLinearAttention function backward pass.

        Args:
            ctx: Autograd context holding the tensors saved by ``forward``.
            grad_output: Output gradient. (B, U, D_att)

        Returns:
            grad_time_decay: Gradient for channel-wise time decay vector. (D_att)
            grad_time_first: Gradient for channel-wise time first vector. (D_att)
            grad_key: Gradient for key tensor. (B, U, D_att)
            grad_value: Gradient for value tensor. (B, U, D_att)

        """
        time_decay, time_first, key, value, output = ctx.saved_tensors

        batch, _, dim = key.size()

        # Per-sample gradient buffers for the two channel-wise vectors; they
        # are reduced over the batch dimension after the kernel call.
        grad_time_decay = torch.empty(
            (batch, dim),
            memory_format=torch.contiguous_format,
            dtype=time_decay.dtype,
            device=time_decay.device,
        )

        grad_time_first = torch.empty(
            (batch, dim),
            memory_format=torch.contiguous_format,
            dtype=time_decay.dtype,
            device=time_decay.device,
        )

        grad_key = torch.empty_like(key, memory_format=torch.contiguous_format)
        grad_value = torch.empty_like(value, memory_format=torch.contiguous_format)

        wkv_kernel.backward(
            time_decay,
            time_first,
            key,
            value,
            output,
            grad_output.contiguous(),
            grad_time_decay,
            grad_time_first,
            grad_key,
            grad_value,
        )

        grad_time_decay = torch.sum(grad_time_decay, dim=0)
        grad_time_first = torch.sum(grad_time_first, dim=0)

        return (
            grad_time_decay,
            grad_time_first,
            grad_key,
            grad_value,
        )
def load_wkv_kernel(context_size: int) -> None:
    """Load WKV CUDA kernel.

    The compiled extension is stored in the module-level ``wkv_kernel`` and
    tagged with the context size it was built for. Nothing is done when a
    kernel built for the same context size is already loaded.

    Args:
        context_size: Context size.

    Raises:
        ImportError: If the ninja package cannot be found or CUDA is not
            available.

    """
    from torch.utils.cpp_extension import load

    global wkv_kernel

    if wkv_kernel is not None:
        if wkv_kernel.context_size == context_size:
            return

    ninja_spec = find_spec("ninja")

    if ninja_spec is None:
        raise ImportError(
            "Ninja package was not found. WKV kernel module can't be loaded "
            "for training. Please, 'pip install ninja' in your environment."
        )

    if not torch.cuda.is_available():
        raise ImportError(
            "CUDA is currently a requirement for WKV kernel loading. "
            "Please set your devices properly and launch again."
        )

    # Kernel sources live in the "cuda" folder next to this module.
    cuda_dir = Path(__file__).resolve().parent / "cuda"
    sources = [cuda_dir / "wkv_op.cpp", cuda_dir / "wkv_cuda.cu"]

    cuda_flags = [
        "-t 4",
        "-std=c++17",
        "-res-usage",
        "--maxrregcount 60",
        "--use_fast_math",
        "-O3",
        "-Xptxas -O3",
        "--extra-device-vectorization",
    ]
    # The context size is passed to the compiler as the Tmax macro.
    cuda_flags.append(f"-DTmax={context_size}")

    wkv_kernel = load(
        name=f"wkv_{context_size}",
        sources=sources,
        verbose=False,
        extra_cuda_cflags=cuda_flags,
    )
    wkv_kernel.context_size = context_size
class SelfAttention(torch.nn.Module):
    """SelfAttention module definition.

    Args:
        size: Input/Output size.
        attention_size: Attention hidden size.
        context_size: Context size for WKV kernel.
        block_id: Block index.
        num_blocks: Number of blocks in the architecture.

    """

    def __init__(
        self,
        size: int,
        attention_size: int,
        context_size: int,
        block_id: int,
        num_blocks: int,
    ) -> None:
        """Construct a SelfAttention object."""
        super().__init__()

        load_wkv_kernel(context_size)

        # Pads one zero frame at the start of the time axis and drops the last
        # one, i.e. shifts the sequence by one step.
        self.time_shift = torch.nn.ZeroPad2d((0, 0, 1, -1))

        self.time_decay = torch.nn.Parameter(torch.empty(attention_size))
        self.time_first = torch.nn.Parameter(torch.empty(attention_size))

        self.time_mix_key = torch.nn.Parameter(torch.empty(1, 1, size))
        self.time_mix_value = torch.nn.Parameter(torch.empty(1, 1, size))
        self.time_mix_receptance = torch.nn.Parameter(torch.empty(1, 1, size))

        self.proj_key = torch.nn.Linear(size, attention_size, bias=True)
        self.proj_value = torch.nn.Linear(size, attention_size, bias=True)
        self.proj_receptance = torch.nn.Linear(size, attention_size, bias=True)

        self.proj_output = torch.nn.Linear(attention_size, size, bias=True)

        self.block_id = block_id

        self.reset_parameters(size, attention_size, block_id, num_blocks)

    def reset_parameters(
        self, size: int, attention_size: int, block_id: int, num_blocks: int
    ) -> None:
        """Reset module parameters.

        Args:
            size: Block size.
            attention_size: Attention hidden size.
            block_id: Block index.
            num_blocks: Number of blocks in the architecture.

        """
        # Denominators are clamped to 1 so that a single block or a single
        # attention channel does not raise ZeroDivisionError.
        ratio_0_to_1 = block_id / max(num_blocks - 1, 1)
        ratio_1_to_almost0 = 1.0 - (block_id / num_blocks)

        # Linear ramp i / size over the feature axis, shaped (1, 1, size).
        time_weight = torch.tensor([i / size for i in range(size)]).reshape(
            1, 1, size
        )

        decay_speed = [
            -5 + 8 * (h / max(attention_size - 1, 1)) ** (0.7 + 1.3 * ratio_0_to_1)
            for h in range(attention_size)
        ]
        decay_speed = torch.tensor(
            decay_speed, dtype=self.time_decay.dtype, device=self.time_decay.device
        )

        # Repeating (0, 0.5, -0.5) pattern over the attention channels.
        zigzag = (
            torch.tensor(
                [(i + 1) % 3 - 1 for i in range(attention_size)],
                dtype=self.time_first.dtype,
                device=self.time_first.device,
            )
            * 0.5
        )

        with torch.no_grad():
            self.time_decay.data = decay_speed
            # ones * log(0.3) + zigzag. The scaling and offset must be applied
            # outside of ones_like, otherwise they are discarded and the
            # parameter ends up filled with ones.
            self.time_first.data = (
                torch.ones_like(self.time_first) * math.log(0.3) + zigzag
            )

            self.time_mix_key.data = torch.pow(time_weight, ratio_1_to_almost0)
            self.time_mix_value.data = (
                torch.pow(time_weight, ratio_1_to_almost0) + 0.3 * ratio_0_to_1
            )
            self.time_mix_receptance.data = torch.pow(
                time_weight, 0.5 * ratio_1_to_almost0
            )

    @torch.no_grad()
    def wkv_linear_attention(
        self,
        time_decay: torch.Tensor,
        time_first: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        state: Tuple[torch.Tensor, torch.Tensor, torch.Tensor],
    ) -> Tuple[torch.Tensor, List[torch.Tensor]]:
        """Compute WKV with state (i.e.: for inference).

        Args:
            time_decay: Channel-wise time decay vector. (D_att)
            time_first: Channel-wise time first vector. (D_att)
            key: Key tensor. (B, 1, D_att)
            value: Value tensor. (B, 1, D_att)
            state: Numerator, denominator and running maximum states.
                [3 x (B, 1, D_att)]

        Returns:
            output: Weighted Key-Value. (B, 1, D_att)
            state: Updated numerator, denominator and running maximum states,
                as a list. [3 x (B, 1, D_att)]

        """
        num_state, den_state, max_state = state

        time_decay = -torch.exp(time_decay)

        # Output for the current step. The running maximum is subtracted
        # before every exponential so that exp() stays in a safe range.
        max_for_output = torch.maximum(max_state, (time_first + key))

        e1 = torch.exp(max_state - max_for_output)
        e2 = torch.exp((time_first + key) - max_for_output)

        numerator = e1 * num_state + e2 * value
        denominator = e1 * den_state + e2

        # State for the next step: the previous state is decayed and the
        # current key/value pair is accumulated.
        max_for_state = torch.maximum(key, (max_state + time_decay))

        e1 = torch.exp((max_state + time_decay) - max_for_state)
        e2 = torch.exp(key - max_for_state)

        wkv = numerator / denominator

        state = [e1 * num_state + e2 * value, e1 * den_state + e2, max_for_state]

        return wkv, state

    def forward(
        self,
        x: torch.Tensor,
        state: Optional[List[torch.Tensor]] = None,
    ) -> Tuple[torch.Tensor, Optional[List[torch.Tensor]]]:
        """Compute time mixing.

        Args:
            x: SelfAttention input sequences. (B, U, size)
            state: Decoder hidden states. [5 x (B, 1, D_att, N)]

        Returns:
            x: SelfAttention output sequences. (B, U, size)
            state: Decoder hidden states, updated in place for this block
                (None if no state was given). [5 x (B, 1, D_att, N)]

        """
        # Without state, the previous frame comes from shifting the sequence;
        # with state, it is the input stored for this block at the last step.
        shifted_x = (
            self.time_shift(x) if state is None else state[1][..., self.block_id]
        )

        key = x * self.time_mix_key + shifted_x * (1 - self.time_mix_key)
        value = x * self.time_mix_value + shifted_x * (1 - self.time_mix_value)
        receptance = x * self.time_mix_receptance + shifted_x * (
            1 - self.time_mix_receptance
        )

        key = self.proj_key(key)
        value = self.proj_value(value)
        receptance = torch.sigmoid(self.proj_receptance(receptance))

        if state is not None:
            state[1][..., self.block_id] = x

            wkv, att_state = self.wkv_linear_attention(
                self.time_decay,
                self.time_first,
                key,
                value,
                tuple(s[..., self.block_id] for s in state[2:]),
            )

            state[2][..., self.block_id] = att_state[0]
            state[3][..., self.block_id] = att_state[1]
            state[4][..., self.block_id] = att_state[2]
        else:
            wkv = WKVLinearAttention.apply(self.time_decay, self.time_first, key, value)

        x = self.proj_output(receptance * wkv)

        return x, state