espnet2.asr.state_spaces.s4.SSKernel
class espnet2.asr.state_spaces.s4.SSKernel(H, N=64, L=None, measure='legs', rank=1, channels=1, dt_min=0.001, dt_max=0.1, deterministic=False, lr=None, mode='nplr', n_ssm=None, verbose=False, measure_args={}, **kernel_args)
Bases: Module
Wrapper around SSKernel parameterizations.
The SSKernel is expected to support the interface: forward(), default_state(), _setup_step(), step().
State Space Kernel which computes the convolution kernel $\bar{K}$.
H: Number of independent SSM copies; controls the size of the model. Also called d_model in the config.
N: State size (dimensionality of parameters A, B, C). Also called d_state in the config. Generally shouldn't need to be adjusted and doesn't affect speed much.
L: Maximum length of convolution kernel, if known. Should work in the majority of cases even if not known.
measure: Options for initialization of (A, B). For NPLR mode, recommendations are "legs", "fout", "hippo" (combination of both). For Diag mode, recommendations are "diag-inv", "diag-lin", "diag-legs", and "diag" (combination of diag-inv and diag-lin).
rank: Rank of low-rank correction for NPLR mode. Needs to be increased for measure "legt".
channels: C channels turns the SSM from a 1-dim to C-dim map; can think of it as having C separate "heads" per SSM. This was partly a feature to make it easier to implement bidirectionality; it is recommended to set channels=1 and adjust H to control parameters instead.
dt_min, dt_max: Min and max values for the step size dt (Delta).
mode: Which kernel algorithm to use. 'nplr' is the full S4 model; 'diag' is the simpler S4D; 'slow' is a dense version for testing.
n_ssm: Number of independent trainable (A, B) SSMs, e.g. n_ssm=1 means all A/B parameters are tied across the H different instantiations of C. n_ssm=None means all H SSMs are completely independent. Generally, changing this option can save parameters but doesn't affect performance or speed much. This parameter must divide H.
lr: Passing in a number (e.g. 0.001) sets attributes of SSM parameters (A, B, dt). A custom optimizer hook is needed to configure the optimizer to set the learning rates appropriately for these parameters. A minimal construction sketch follows this parameter list.
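A minimal construction sketch; the hyperparameter values below are illustrative choices, not recommended settings.

from espnet2.asr.state_spaces.s4 import SSKernel

# 256 independent SSM copies (d_model), state size 64 (d_state),
# HiPPO-LegS initialization, full S4 (NPLR) parameterization
kernel = SSKernel(
    H=256,
    N=64,
    L=1024,
    measure="legs",
    mode="nplr",
    channels=1,
    dt_min=0.001,
    dt_max=0.1,
)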
default_state(*args, **kwargs)
forward(state=None, L=None, rate=None)
Defines the computation performed at every call.
Should be overridden by all subclasses.
NOTE: Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
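A usage sketch, continuing with the kernel constructed above. The assumption here is that the call returns the convolution kernel (and an optional state contribution) with the kernel shaped (channels, H, L); the FFT convolution only illustrates how an owning S4 layer would typically apply the kernel and is not part of this class.

import torch

# Evaluate the convolution kernel \bar{K}; L can be overridden at call time.
k, k_state = kernel(L=512)                  # assumed shape: (channels, H, 512)

# Apply it to an input sequence via FFT convolution (illustrative only).
u = torch.randn(8, 256, 512)                # (B, H, L) input
k_f = torch.fft.rfft(k.squeeze(0), n=1024)  # zero-pad to 2L for linear convolution
u_f = torch.fft.rfft(u, n=1024)
y = torch.fft.irfft(u_f * k_f, n=1024)[..., :512]  # (B, H, L) output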
forward_state(u, state)
Forward the state through a sequence, i.e. compute the state after passing a chunk through the SSM.
state: (B, H, N)
u: (B, H, L)
Returns: (B, H, N)
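A sketch of advancing the state over a chunk, continuing with the kernel above. The shapes follow the documentation on this page; treat the exact default_state behavior as an assumption.

import torch

state = kernel.default_state(8)         # assumed to create a (B, H, N) initial state
u = torch.randn(8, 256, 128)            # (B, H, L) chunk of input
state = kernel.forward_state(u, state)  # (B, H, N) state after the chunk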
step(u, state, **kwargs)
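A sketch of recurrent stepping. Per the interface listed above, _setup_step() is expected to be called once beforehand and the state created with default_state(); the per-step input and output shapes here are assumptions.

import torch

kernel._setup_step()                      # prepare the recurrent (stepping) parameters
state = kernel.default_state(8)           # (B, H, N)
outputs = []
for t in range(16):
    u_t = torch.randn(8, 256)             # (B, H) single-step input
    y_t, state = kernel.step(u_t, state)  # y_t assumed (B, channels, H)
    outputs.append(y_t)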