espnet2.asr.state_spaces.s4.SSKernelNPLR
class espnet2.asr.state_spaces.s4.SSKernelNPLR(w, P, B, C, log_dt, L=None, lr=None, verbose=False, keops=False, real_type='exp', real_tolerance=0.001, bandlimit=None)
Bases: OptimModule
Stores a representation of and computes the SSKernel function K_L(A^dt, B^dt, C) corresponding to a discretized state space, where A is Normal Plus Low-Rank (NPLR).
Initialize kernel.
L: Maximum length; this module computes an SSM kernel of length L
A is represented by diag(w) - PP^*
w: (S, N) diagonal part
P: (R, S, N) low-rank part
B: (S, N)
C: (C, H, N)
dt: (H) timescale per feature
lr: [dict | float | None] hook to set lr of special parameters (A, B, dt)
Dimensions:
N (or d_state): state size
H (or d_model): total SSM copies
S (or n_ssm): number of trainable copies of (A, B, dt); must divide H
R (or rank): rank of low-rank part
C (or channels): system is 1-dim to C-dim
The forward pass of this Module returns a tensor of shape (C, H, L)
Note: tensor shape N here denotes half the true state size, because of conjugate symmetry
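A minimal construction sketch, assuming randomly initialized NPLR tensors chosen only to satisfy the shape contract above; in real use w, P, B, C, and log_dt typically come from a HiPPO/NPLR initialization (e.g. via the SSKernel wrapper in the same module), and the dimension values below are arbitrary.

```python
# Minimal sketch, assuming random NPLR tensors chosen only to match the
# documented shapes; a real kernel uses a HiPPO/NPLR initialization instead.
import torch

from espnet2.asr.state_spaces.s4 import SSKernelNPLR

N = 32        # state size (half the true state size, see the note above)
H = 64        # d_model: total SSM copies
S = 1         # n_ssm: trainable copies of (A, B, dt); must divide H
R = 1         # rank of the low-rank correction
channels = 1  # C: output channels
L = 1024      # maximum kernel length

# A = diag(w) - P P^*; keep Re(w) negative so the continuous system is stable
w = -(0.1 + 0.5 * torch.rand(S, N)) + 1j * torch.randn(S, N)  # (S, N) diagonal part
P = torch.randn(R, S, N, dtype=torch.cfloat)                  # (R, S, N) low-rank part
B = torch.randn(S, N, dtype=torch.cfloat)                     # (S, N)
C = torch.randn(channels, H, N, dtype=torch.cfloat)           # (C, H, N)
log_dt = torch.rand(H) * (-4.0) - 3.0                         # (H,) log of per-feature timescales

kernel = SSKernelNPLR(w, P, B, C, log_dt, L=L)
```

The forward() example below shows how this kernel object is then evaluated.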
default_state(*batch_shape)
forward(state=None, rate=1.0, L=None)
Forward pass.
state: (B, H, N) initial state
rate: sampling rate factor
L: target length
Returns:
(C, H, L) convolution kernel (generally C=1)
(B, H, L) output from initial state
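A hedged sketch of evaluating the kernel, reusing the random placeholder construction from above. The two-element unpacking and the shapes in the comments follow the return description of this docstring; the explicit complex state tensor is an assumption consistent with the (B, H, N) shape documented here.

```python
# Sketch of the forward pass on a randomly initialized kernel (placeholder
# tensors as in the construction example above).
import torch

from espnet2.asr.state_spaces.s4 import SSKernelNPLR

N, H, S, R, channels, L = 32, 64, 1, 1, 1, 1024
w = -(0.1 + 0.5 * torch.rand(S, N)) + 1j * torch.randn(S, N)
P = torch.randn(R, S, N, dtype=torch.cfloat)
B = torch.randn(S, N, dtype=torch.cfloat)
C = torch.randn(channels, H, N, dtype=torch.cfloat)
log_dt = torch.rand(H) * (-4.0) - 3.0
kernel = SSKernelNPLR(w, P, B, C, log_dt, L=L)

# Without an initial state: the first output is the convolution kernel of
# shape (C, H, L) per the docstring; the state contribution is empty.
k, _ = kernel(L=L)
print(k.shape)  # expected (channels, H, L) = (1, 64, 1024)

# With an initial state of the documented shape (B, H, N), the second output
# is the contribution of that state, as described above.
state = torch.randn(4, H, N, dtype=torch.cfloat)
k, k_state = kernel(state=state, L=L)

# The rate argument rescales the sampling rate relative to the default
# (documented above simply as a sampling rate factor).
k_slow, _ = kernel(rate=0.5, L=L)
```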
step(u, state)
Step one time step as a recurrent model.
You must have called self._setup_step() and created a state with self.default_state() before calling this method.
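A hedged sketch of the recurrent interface, repeating the random placeholder setup from the earlier examples. The per-step input shape (B, H) noted in the comments is an assumption that follows the dimension conventions documented for this class.

```python
# Sketch of stepping the kernel as a recurrence, using the same placeholder
# initialization as the earlier examples.
import torch

from espnet2.asr.state_spaces.s4 import SSKernelNPLR

N, H, S, R, channels, L = 32, 64, 1, 1, 1, 1024
w = -(0.1 + 0.5 * torch.rand(S, N)) + 1j * torch.randn(S, N)
P = torch.randn(R, S, N, dtype=torch.cfloat)
B = torch.randn(S, N, dtype=torch.cfloat)
C = torch.randn(channels, H, N, dtype=torch.cfloat)
log_dt = torch.rand(H) * (-4.0) - 3.0
kernel = SSKernelNPLR(w, P, B, C, log_dt, L=L)

# Prerequisites named in the docstring: set up the discretized recurrence,
# then create a state (here for a batch of 4 sequences).
kernel._setup_step()
state = kernel.default_state(4)

# Feed one input per time step; each call returns that step's output together
# with the updated state (u of shape (B, H) is an assumption).
for _ in range(16):
    u = torch.randn(4, H)
    y, state = kernel.step(u, state)
```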