espnet2.asr_transducer.decoder.blocks.rwkv.RWKV
class espnet2.asr_transducer.decoder.blocks.rwkv.RWKV(size: int, linear_size: int, attention_size: int, context_size: int, block_id: int, num_blocks: int, normalization_class: torch.nn.Module = torch.nn.LayerNorm, normalization_args: Dict = {}, att_dropout_rate: float = 0.0, ffn_dropout_rate: float = 0.0)
Bases: Module
RWKV block: a self-attention (time-mixing) module followed by a feed-forward (channel-mixing) module.
- Parameters:
  - size – Input/Output size.
  - linear_size – Feed-forward hidden size.
  - attention_size – SelfAttention hidden size.
  - context_size – Context size for WKV computation.
  - block_id – Block index.
  - num_blocks – Number of blocks in the architecture.
  - normalization_class – Normalization layer class.
  - normalization_args – Normalization layer arguments.
  - att_dropout_rate – Dropout rate for the attention module.
  - ffn_dropout_rate – Dropout rate for the feed-forward module.
Construct a RWKV object.
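The constructor arguments also fix the layout of the decoder state that forward() reads and updates, documented as [5 x (B, D_att/size, N)]. The sketch below is a hypothetical helper, not part of ESPnet, that lists the five state shapes under stated assumptions: D_att is attention_size, N equals num_blocks (each block reads its own slice via block_id), the first two slots store previous-token inputs of width size, and the last three store the WKV numerator, denominator and running maximum of width attention_size. Check these assumptions against the ESPnet version you use.

```python
from typing import List, Tuple


def rwkv_state_shapes(
    batch_size: int, size: int, attention_size: int, num_blocks: int
) -> List[Tuple[int, int, int]]:
    """Hypothetical helper: shapes of the 5 decoder state tensors.

    Assumed layout: slots 0 and 1 hold previous-token inputs used for
    token shifting (width ``size``); slots 2-4 hold the WKV numerator,
    denominator and running maximum (width ``attention_size``). The last
    axis indexes the block, so block ``block_id`` would use
    ``state[i][..., block_id]``.
    """
    widths = [size, size, attention_size, attention_size, attention_size]
    return [(batch_size, width, num_blocks) for width in widths]


shapes = rwkv_state_shapes(batch_size=2, size=256, attention_size=512, num_blocks=4)
for slot, shape in enumerate(shapes):
    print(slot, shape)
```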
forward(x: Tensor, state: Tensor | None = None) → Tuple[Tensor, Tensor | None]
Compute receptance weighted key value.
- Parameters:
  - x – RWKV input sequences. (B, L, size)
  - state – Decoder hidden states: a list of 5 tensors, each of shape (B, D_att, N) or (B, size, N), where D_att is attention_size. [5 x (B, D_att/size, N)]
- Returns: A tuple (x, state): the RWKV output sequences (B, L, size) and the updated decoder hidden states [5 x (B, D_att/size, N)], or None when no state was given.
- Return type: Tuple[Tensor, Tensor | None]
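To illustrate what "receptance weighted key value" computes and why forward() threads a state through, here is a minimal single-channel sketch of the RWKV-4 WKV recurrence in plain Python. It assumes the formulation from the RWKV paper (a per-channel decay w > 0 applied as exp(-w) and a current-token bonus u); the names w, u, wkv_recurrent and wkv_naive are ours, and this is not ESPnet's CUDA kernel, whose parametrization may differ. The carried (numerator, denominator, max) triple mirrors the kind of per-channel attention state the block keeps, and the max term keeps the exponentials from overflowing.

```python
import math
from typing import List, Optional, Tuple

# (scaled numerator, scaled denominator, running max exponent) for one channel
State = Tuple[float, float, float]


def wkv_recurrent(
    w: float, u: float, k: List[float], v: List[float], state: Optional[State] = None
) -> Tuple[List[float], State]:
    """Numerically stable WKV recurrence; returns outputs and the new state."""
    a, b, p = state if state is not None else (0.0, 0.0, float("-inf"))
    out = []
    for k_t, v_t in zip(k, v):
        # Output: decayed past contributions plus the current token boosted by u.
        ww = u + k_t
        q = max(p, ww)
        e1, e2 = math.exp(p - q), math.exp(ww - q)
        out.append((e1 * a + e2 * v_t) / (e1 * b + e2))
        # State update: decay the past by exp(-w), then add the current token.
        ww = p - w
        q = max(ww, k_t)
        e1, e2 = math.exp(ww - q), math.exp(k_t - q)
        a, b, p = e1 * a + e2 * v_t, e1 * b + e2, q
    return out, (a, b, p)


def wkv_naive(w: float, u: float, k: List[float], v: List[float]) -> List[float]:
    """Direct O(L^2) evaluation of the WKV formula, used as a reference."""
    out = []
    for t in range(len(k)):
        weights = [math.exp(-(t - 1 - i) * w + k[i]) for i in range(t)]
        bonus = math.exp(u + k[t])
        num = sum(wi * vi for wi, vi in zip(weights, v)) + bonus * v[t]
        den = sum(weights) + bonus
        out.append(num / den)
    return out


k = [0.5, -1.0, 2.0, 0.3, 1.2]
v = [1.0, 2.0, -0.5, 3.0, 0.0]
full, _ = wkv_recurrent(0.8, 0.2, k, v)
# Processing in two chunks while carrying the state gives the same result,
# which is what passing `state` between decoding steps relies on.
first, carried = wkv_recurrent(0.8, 0.2, k[:2], v[:2])
rest, _ = wkv_recurrent(0.8, 0.2, k[2:], v[2:], state=carried)
print(full)
```

The recurrent form costs O(L) per sequence and only needs the small carried state, which is why step-by-step transducer decoding can reuse it instead of recomputing over the full history.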