Source code for espnet2.asr_transducer.decoder.modules.rwkv.feed_forward

"""Feed-forward (channel mixing) module for RWKV block.

Based/Modified from https://github.com/BlinkDL/RWKV-LM/blob/main/RWKV-v4/src/model.py

Some variables are renamed according to https://github.com/huggingface/transformers/blob/main/src/transformers/models/rwkv/modeling_rwkv.py.

"""  # noqa

from typing import List, Optional, Tuple

import torch


[docs]class FeedForward(torch.nn.Module): """FeedForward module definition. Args: size: Input/Output size. hidden_size: Hidden size. block_id: Block index. num_blocks: Number of blocks in the architecture. """ def __init__( self, size: int, hidden_size: int, block_id: int, num_blocks: int ) -> None: """Construct a FeedForward object.""" super().__init__() self.time_shift = torch.nn.ZeroPad2d((0, 0, 1, -1)) self.time_mix_key = torch.nn.Parameter(torch.empty(1, 1, size)) self.time_mix_receptance = torch.nn.Parameter(torch.empty(1, 1, size)) self.proj_key = torch.nn.Linear(size, hidden_size, bias=True) self.proj_value = torch.nn.Linear(hidden_size, size, bias=True) self.proj_receptance = torch.nn.Linear(size, size, bias=True) self.block_id = block_id self.reset_parameters(size, block_id, num_blocks)
[docs] def reset_parameters(self, size: int, block_id: int, num_blocks: int) -> None: """Reset module parameters. Args: size: Block size. block_id: Block index. num_blocks: Number of blocks in the architecture. """ ratio_1_to_almost0 = 1.0 - (block_id / num_blocks) time_weight = torch.ones(1, 1, size) for i in range(size): time_weight[0, 0, i] = i / size with torch.no_grad(): self.time_mix_key.data = torch.pow(time_weight, ratio_1_to_almost0) self.time_mix_receptance.data = torch.pow(time_weight, ratio_1_to_almost0)
[docs] def forward( self, x: torch.Tensor, state: Optional[List[torch.Tensor]] = None ) -> Tuple[torch.Tensor, Optional[List[torch.Tensor]]]: """Compute channel mixing. Args: x: FeedForward input sequences. (B, U, size) state: Decoder hidden state. [5 x (B, 1, size, N)] Returns: x: FeedForward output sequences. (B, U, size) state: Decoder hidden state. [5 x (B, 1, size, N)] """ shifted_x = ( self.time_shift(x) if state is None else state[0][..., self.block_id] ) key = x * self.time_mix_key + shifted_x * (1 - self.time_mix_key) receptance = x * self.time_mix_receptance + shifted_x * ( 1 - self.time_mix_receptance ) key = torch.square(torch.relu(self.proj_key(key))) value = self.proj_value(key) receptance = torch.sigmoid(self.proj_receptance(receptance)) if state is not None: state[0][..., self.block_id] = x x = receptance * value return x, state