espnet2.asr_transducer.decoder.modules.rwkv.attention.SelfAttention
class espnet2.asr_transducer.decoder.modules.rwkv.attention.SelfAttention(size: int, attention_size: int, context_size: int, block_id: int, num_blocks: int)
Bases: Module
SelfAttention module definition.
- Parameters:
- size – Input/Output size.
- attention_size – Attention hidden size.
- context_size – Context size for WKV kernel.
- block_id – Block index.
- num_blocks – Number of blocks in the architecture.
Construct a SelfAttention object.
forward(x: Tensor, state: List[Tensor] | None = None) → Tuple[Tensor, List[Tensor] | None]
Compute time mixing.
- Parameters:
- x – SelfAttention input sequences. (B, U, size)
- state – Decoder hidden states. [5 x (B, 1, D_att, N)]
- Returns: x – SelfAttention output sequences. (B, U, size) state – Updated decoder hidden states. [5 x (B, 1, D_att, N)]
- Return type: (x, state)
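The "time mixing" performed by `forward` interpolates each token with its predecessor (a token shift) before the key/value/receptance projections. A dependency-free sketch of that interpolation, assuming a single illustrative mixing coefficient `mu` (the real module learns separate per-channel coefficients for each projection):

```python
# Hypothetical, dependency-free sketch of the token-shift interpolation
# behind RWKV "time mixing". The name `mu` is illustrative, not an
# actual attribute of SelfAttention.

def time_shift_mix(x_seq, mu):
    """Blend each token with its predecessor: mu * x[t] + (1 - mu) * x[t-1]."""
    mixed = []
    prev = [0.0] * len(x_seq[0])  # shifted state starts at zeros
    for x_t in x_seq:
        mixed.append([mu * c + (1.0 - mu) * p for c, p in zip(x_t, prev)])
        prev = x_t
    return mixed

seq = [[1.0, 2.0], [3.0, 4.0]]
print(time_shift_mix(seq, 0.5))  # [[0.5, 1.0], [2.0, 3.0]]
```

During incremental decoding, the previous token (`prev` above) is what part of the `state` argument carries between calls, so each one-token `forward` step can still mix with its predecessor.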
reset_parameters(size: int, attention_size: int, block_id: int, num_blocks: int) → None
Reset module parameters.
- Parameters:
- size – Block size.
- attention_size – Attention hidden size.
- block_id – Block index.
- num_blocks – Number of blocks in the architecture.
wkv_linear_attention(time_decay: Tensor, time_first: Tensor, key: Tensor, value: Tensor, state: Tuple[Tensor, Tensor, Tensor]) → Tuple[Tensor, Tuple[Tensor, Tensor, Tensor]]
Compute WKV with state (i.e., for inference).
- Parameters:
- time_decay – Channel-wise time decay vector. (D_att)
- time_first – Channel-wise time first vector. (D_att)
- key – Key tensor. (B, 1, D_att)
- value – Value tensor. (B, 1, D_att)
- state – Decoder hidden states. [3 x (B, D_att)]
- Returns: output – Weighted Key-Value. (B, 1, D_att) state – Updated decoder hidden states. [3 x (B, D_att)]
- Return type: (output, state)
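The WKV recurrence that `wkv_linear_attention` steps through can be sketched per channel in plain Python. This is a minimal, numerically stable (max-shifted exponent) form of the standard RWKV inference recurrence, not the module's exact implementation; the three-element `state` corresponds to the running numerator, denominator, and maximum exponent:

```python
import math

def wkv_step(time_decay, time_first, key, value, state):
    """One inference step of the WKV linear-attention recurrence for a
    single channel. state = (num, den, max_exp): running numerator,
    denominator, and the max exponent used for normalization."""
    num, den, max_exp = state

    # Output: blend the running state with the current (key, value) pair,
    # where the current pair gets the "time first" bonus.
    ww = time_first + key
    p = max(max_exp, ww)
    e1 = math.exp(max_exp - p)
    e2 = math.exp(ww - p)
    output = (e1 * num + e2 * value) / (e1 * den + e2)

    # State update: decay the running state, then absorb the current pair.
    ww = max_exp + time_decay  # time_decay is negative: a decay in log space
    p = max(ww, key)
    e1 = math.exp(ww - p)
    e2 = math.exp(key - p)
    return output, (e1 * num + e2 * value, e1 * den + e2, p)

# With an empty state, the first output is exactly the first value.
out, state = wkv_step(-0.5, 0.3, key=0.0, value=2.0, state=(0.0, 0.0, -1e30))
print(out)  # 2.0
```

Because the recurrence only keeps this fixed-size state, each decoding step is O(D_att) in time and memory, independent of how many tokens have been consumed.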