espnet2.asr_transducer.decoder.modules.rwkv.attention.WKVLinearAttention
class espnet2.asr_transducer.decoder.modules.rwkv.attention.WKVLinearAttention(*args, **kwargs)
Bases: Function
WKVLinearAttention function definition.
static backward(ctx, grad_output: Tensor) → Tuple[Tensor, Tensor, Tensor, Tensor]
WKVLinearAttention function backward pass.
- Parameters: grad_output – Output gradient. (B, U, D_att)
- Returns:
  - grad_time_decay – Gradient for channel-wise time decay vector. (D_att)
  - grad_time_first – Gradient for channel-wise time first vector. (D_att)
  - grad_key – Gradient for key tensor. (B, U, D_att)
  - grad_value – Gradient for value tensor. (B, U, D_att)
- Return type: Tuple[Tensor, Tensor, Tensor, Tensor]
static forward(ctx, time_decay: Tensor, time_first: Tensor, key: Tensor, value: Tensor) → Tensor
WKVLinearAttention function forward pass.
- Parameters:
- time_decay – Channel-wise time decay vector. (D_att)
- time_first – Channel-wise time first vector. (D_att)
- key – Key tensor. (B, U, D_att)
- value – Value tensor. (B, U, D_att)
- Returns: Weighted Key-Value tensor. (B, U, D_att)
- Return type: out
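The forward pass above computes the RWKV weighted key-value recurrence, which ESPnet implements with a fused CUDA kernel for efficiency. The following is a hedged, naive O(U²) reference sketch of that computation in plain PyTorch; the function name `naive_wkv` is hypothetical, and the `-exp(time_decay)` parameterization of the decay is an assumption based on common RWKV kernel conventions, not a guaranteed match to the ESPnet kernel's exact numerics (e.g. it omits the kernel's max-subtraction stabilization).

```python
import torch


def naive_wkv(time_decay: torch.Tensor, time_first: torch.Tensor,
              key: torch.Tensor, value: torch.Tensor) -> torch.Tensor:
    """Naive reference sketch of the WKV linear-attention forward pass.

    Args:
        time_decay: Channel-wise time decay vector. (D_att)
        time_first: Channel-wise time first vector. (D_att)
        key: Key tensor. (B, U, D_att)
        value: Value tensor. (B, U, D_att)

    Returns:
        Weighted Key-Value tensor. (B, U, D_att)
    """
    B, U, D = key.shape
    # Assumption: decay is applied as w = -exp(time_decay), as in RWKV kernels.
    w = -torch.exp(time_decay)
    out = torch.empty_like(value)
    for t in range(U):
        if t > 0:
            # Past positions i < t are weighted by exp((t - 1 - i) * w + k_i).
            idx = torch.arange(t, device=key.device)
            decay = (t - 1 - idx).to(key.dtype).unsqueeze(-1) * w  # (t, D)
            e_past = torch.exp(decay + key[:, :t])                 # (B, t, D)
            num = (e_past * value[:, :t]).sum(dim=1)               # (B, D)
            den = e_past.sum(dim=1)                                # (B, D)
        else:
            num = torch.zeros(B, D, dtype=key.dtype, device=key.device)
            den = torch.zeros(B, D, dtype=key.dtype, device=key.device)
        # The current position uses the "time first" bonus u instead of decay.
        e_cur = torch.exp(time_first + key[:, t])
        out[:, t] = (num + e_cur * value[:, t]) / (den + e_cur)
    return out
```

Note that at t = 0 the numerator and denominator both reduce to the current-position term, so the first output equals the first value vector; this is a handy sanity check when comparing against the fused kernel.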