Source code for espnet2.asr.state_spaces.s4

# This code is derived from https://github.com/HazyResearch/state-spaces

"""Standalone version of Structured (Sequence) State Space (S4) model."""

import logging
import math
import os
from functools import wraps

# from pytorch_lightning.utilities import rank_zero_only
from typing import Any, Callable, Optional

import numpy as np
import opt_einsum as oe
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange, repeat

from espnet2.asr.state_spaces.components import Activation, DropoutNd, LinearActivation

contract = oe.contract
contract_expression = oe.contract_expression


def rank_zero_only(fn: Callable) -> Callable:
    """Decorator function from PyTorch Lightning.

    Function that can be used as a decorator
    to enable a function/method being called only on global rank 0.
    """

    @wraps(fn)
    def wrapped_fn(*args: Any, **kwargs: Any) -> Optional[Any]:
        if rank_zero_only.rank == 0:
            return fn(*args, **kwargs)
        return None

    return wrapped_fn
def _get_rank() -> int:
    # SLURM_PROCID can be set even if SLURM is not managing the multiprocessing,
    # therefore LOCAL_RANK needs to be checked first
    rank_keys = ("RANK", "LOCAL_RANK", "SLURM_PROCID", "JSM_NAMESPACE_RANK")
    for key in rank_keys:
        rank = os.environ.get(key)
        if rank is not None:
            return int(rank)
    return 0


# add the attribute to the function but don't overwrite
# in case Trainer has already set it
rank_zero_only.rank = getattr(rank_zero_only, "rank", _get_rank())
def get_logger(name=__name__, level=logging.INFO) -> logging.Logger:
    """Initialize multi-GPU-friendly python logger."""
    logger = logging.getLogger(name)
    logger.setLevel(level)

    # this ensures all logging levels get marked with the rank zero decorator
    # otherwise logs would get multiplied for each GPU process in multi-GPU setup
    for level in (
        "debug",
        "info",
        "warning",
        "error",
        "exception",
        "fatal",
        "critical",
    ):
        setattr(logger, level, rank_zero_only(getattr(logger, level)))

    return logger
log = get_logger(__name__)

""" Cauchy and Vandermonde kernels """

try:  # Try CUDA extension
    from .cauchy import cauchy_mult

    has_cauchy_extension = True
except ImportError:
    log.warning(
        "CUDA extension for cauchy multiplication not found."
        " Please install it via `cd /path/to/espnet/tools && . ./activate_python.sh"
        " && ./installers/install_cauchy_mult.sh`."
        " This should speed up end-to-end training by 10-50%"
    )
    has_cauchy_extension = False

try:  # Try pykeops
    import pykeops  # noqa
    from pykeops.torch import Genred

    has_pykeops = True
    log.info("Pykeops installation found.")

    def _broadcast_dims(*tensors):
        max_dim = max([len(tensor.shape) for tensor in tensors])
        tensors = [
            tensor.view((1,) * (max_dim - len(tensor.shape)) + tensor.shape)
            for tensor in tensors
        ]
        return tensors

    def cauchy_conj(v, z, w):
        """Pykeops version."""
        expr_num = "z * ComplexReal(v) - Real2Complex(Sum(v * w))"
        expr_denom = "ComplexMult(z-w, z-Conj(w))"

        cauchy_mult = Genred(
            f"ComplexDivide({expr_num}, {expr_denom})",
            [
                "v = Vj(2)",
                "z = Vi(2)",
                "w = Vj(2)",
            ],
            reduction_op="Sum",
            axis=1,
        )

        v, z, w = _broadcast_dims(v, z, w)
        v = _c2r(v)
        z = _c2r(z)
        w = _c2r(w)

        r = 2 * cauchy_mult(v, z, w, backend="GPU")
        return _r2c(r)

    def log_vandermonde(v, x, L):
        expr = "ComplexMult(v, ComplexExp(ComplexMult(x, l)))"
        vandermonde_mult = Genred(
            expr,
            [
                "v = Vj(2)",
                "x = Vj(2)",
                "l = Vi(2)",
            ],
            reduction_op="Sum",
            axis=1,
        )

        length = torch.arange(L).to(x)
        v, x, length = _broadcast_dims(v, x, length)
        v = _c2r(v)
        x = _c2r(x)
        length = _c2r(length)

        r = vandermonde_mult(v, x, length, backend="GPU")
        return 2 * _r2c(r).real

    def log_vandermonde_transpose(u, v, x, L):
        """Compute Vandermonde product.

        u: ... H L
        v: ... H N
        x: ... H N

        Returns: ... H N

        V = Vandermonde(a, L) : (H N L)
        contract_L(V * u * v)
        """
        expr = "ComplexMult(ComplexMult(v, u), ComplexExp(ComplexMult(x, l)))"
        vandermonde_mult = Genred(
            expr,
            [
                "u = Vj(2)",
                "v = Vi(2)",
                "x = Vi(2)",
                "l = Vj(2)",
            ],
            reduction_op="Sum",
            axis=1,
        )

        length = torch.arange(L).to(x)
        u, v, x, length = _broadcast_dims(u, v, x, length)
        u = _c2r(u)
        v = _c2r(v)
        x = _c2r(x)
        length = _c2r(length)

        r = vandermonde_mult(u, v, x, length, backend="GPU")
        return _r2c(r)

except ImportError:
    has_pykeops = False
    if not has_cauchy_extension:
        log.warning(
            "Falling back on slow Cauchy kernel. "
            "Install at least one of pykeops or the CUDA extension for efficiency."
        )
    def cauchy_naive(v, z, w):
        """Naive version.

        v, w: (..., N)
        z: (..., L)
        returns: (..., L)
        """
        cauchy_matrix = v.unsqueeze(-1) / (
            z.unsqueeze(-2) - w.unsqueeze(-1)
        )  # (... N L)
        return torch.sum(cauchy_matrix, dim=-2)
    # Vandermonde functions
    log.warning(
        "Falling back on slow Vandermonde kernel. "
        "Install pykeops for improved memory efficiency."
    )
    def log_vandermonde(v, x, L):
        r"""Compute Vandermonde product.

        v: (..., N)
        x: (..., N)
        returns: (..., L) \sum v x^l
        """
        vandermonde_matrix = torch.exp(
            x.unsqueeze(-1) * torch.arange(L).to(x)
        )  # (... N L)
        vandermonde_prod = contract(
            "... n, ... n l -> ... l", v, vandermonde_matrix
        )  # (... L)
        return 2 * vandermonde_prod.real
    def log_vandermonde_transpose(u, v, x, L):
        vandermonde_matrix = torch.exp(
            x.unsqueeze(-1) * torch.arange(L).to(x)
        )  # (... N L)
        vandermonde_prod = contract(
            "... l, ... n, ... n l -> ... n", u.to(x), v.to(x), vandermonde_matrix
        )  # (... L)
        return vandermonde_prod
def _conj(x):
    return torch.cat([x, x.conj()], dim=-1)


_c2r = torch.view_as_real
_r2c = torch.view_as_complex

if tuple(map(int, torch.__version__.split(".")[:2])) >= (1, 10):

    def _resolve_conj(x):
        return x.conj().resolve_conj()

else:

    def _resolve_conj(x):
        return x.conj()


""" Misc functional utilities """
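# Illustrative sketch (not part of the original module): how the complex-view
# helpers above behave. `_c2r`/`_r2c` reinterpret a complex tensor as a real
# tensor with a trailing dimension of size 2 and back, and `_conj` appends the
# complex conjugate along the last dimension (the "conjugate pairs" used below).
def _example_complex_views():
    x = torch.randn(3, 4, dtype=torch.cfloat)
    assert torch.equal(_r2c(_c2r(x)), x)  # round trip through the real view is lossless
    assert _conj(x).shape == (3, 8)  # conjugate pairs double the last dimension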
def power(L, A, v=None):
    """Compute A^L and the scan sum_i A^i v_i.

    A: (..., N, N)
    v: (..., N, L)
    """
    E = torch.eye(A.shape[-1]).to(A)  # , dtype=A.dtype, device=A.device)

    powers = [A]
    length = 1
    while True:
        if L % 2 == 1:
            E = powers[-1] @ E
        L //= 2
        if L == 0:
            break
        length *= 2
        powers.append(powers[-1] @ powers[-1])

    if v is None:
        return E

    # Invariants:
    # powers[-1] := A^length
    # length := largest po2 at most L

    # Note that an alternative divide and conquer to compute the reduction is possible
    # and can be embedded into the above loop without caching intermediate powers of A
    # We do this reverse divide-and-conquer for efficiency reasons:
    # 1) it involves fewer padding steps for non-po2 L
    # 2) it involves more contiguous arrays

    # Take care of edge case for non-po2 arrays
    # Note that this initial step is a no-op for the case of power of 2 (length == L)
    k = v.size(-1) - length
    v_ = powers.pop() @ v[..., length:]
    v = v[..., :length]
    v[..., :k] = v[..., :k] + v_

    # Handle reduction for power of 2
    while v.size(-1) > 1:
        v = rearrange(v, "... (z l) -> ... z l", z=2)
        v = v[..., 0, :] + powers.pop() @ v[..., 1, :]
    return E, v.squeeze(-1)
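# Illustrative sketch (not part of the original module): `power` computes A^L by
# repeated squaring, so for a small dense matrix it should agree with
# `torch.linalg.matrix_power` up to numerical error.
def _example_power():
    A = 0.1 * torch.randn(4, 4)
    AL = power(12, A)  # A^12 via repeated squaring
    assert torch.allclose(AL, torch.linalg.matrix_power(A, 12), atol=1e-5)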
""" HiPPO utilities """
def transition(measure, N):
    """A, B transition matrices for different measures."""
    # Legendre (translated)
    if measure == "legt":
        Q = np.arange(N, dtype=np.float64)
        R = (2 * Q + 1) ** 0.5
        j, i = np.meshgrid(Q, Q)
        A = R[:, None] * np.where(i < j, (-1.0) ** (i - j), 1) * R[None, :]
        B = R[:, None]
        A = -A

        # Halve again for timescale correctness
        A *= 0.5
        B *= 0.5
    # Legendre (scaled)
    elif measure == "legs":
        q = np.arange(N, dtype=np.float64)
        col, row = np.meshgrid(q, q)
        r = 2 * q + 1
        M = -(np.where(row >= col, r, 0) - np.diag(q))
        T = np.sqrt(np.diag(2 * q + 1))
        A = T @ M @ np.linalg.inv(T)
        B = np.diag(T)[:, None]
        # Otherwise "UserWarning: given NumPY array is not writeable..."
        # after torch.as_tensor(B)
        B = B.copy()
    elif measure == "legsd":
        # Essentially equivalent to S4D-LegS
        q = np.arange(N, dtype=np.float64)
        col, row = np.meshgrid(q, q)
        r = 2 * q + 1
        M = -(np.where(row >= col, r, 0) - np.diag(q))
        T = np.sqrt(np.diag(2 * q + 1))
        A = T @ M @ np.linalg.inv(T)
        B = np.diag(T)[:, None]
        # Otherwise "UserWarning: given NumPY array is not writeable..."
        # after torch.as_tensor(B)
        B = B.copy()
        A += 0.5 * B * B[None, :, 0]
        B = B / 2.0
    elif measure in ["fourier_diag", "foud"]:
        # Essentially equivalent to S4D-Lin
        freqs = np.arange(N // 2)
        d = np.stack([freqs, np.zeros(N // 2)], axis=-1).reshape(-1)[:-1]
        A = 2 * np.pi * (-np.diag(d, 1) + np.diag(d, -1))
        A = A - 0.5 * np.eye(N)
        B = np.zeros(N)
        B[0::2] = 2**0.5
        B[0] = 1
        B = B[:, None]
    elif measure in ["fourier", "fout"]:
        freqs = np.arange(N // 2)
        d = np.stack([np.zeros(N // 2), freqs], axis=-1).reshape(-1)[1:]
        A = np.pi * (-np.diag(d, 1) + np.diag(d, -1))
        B = np.zeros(N)
        B[0::2] = 2**0.5
        B[0] = 1

        # Subtract off rank correction - this corresponds
        # to the other endpoint u(t-1) in this case
        A = A - B[:, None] * B[None, :]
        B = B[:, None]
    else:
        raise NotImplementedError

    return A, B
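# Illustrative sketch (not part of the original module): the HiPPO transition
# matrices returned by `transition` are plain NumPy arrays of shape (N, N) and
# (N, 1); for example the LegS operator is lower triangular.
def _example_transition():
    A, B = transition("legs", 8)
    assert A.shape == (8, 8) and B.shape == (8, 1)
    assert np.allclose(A, np.tril(A))  # HiPPO-LegS A is lower triangular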
def rank_correction(measure, N, rank=1, dtype=torch.float):
    """Return low-rank matrix L such that A + L is normal."""
    if measure == "legs":
        assert rank >= 1
        P = torch.sqrt(0.5 + torch.arange(N, dtype=dtype)).unsqueeze(0)  # (1 N)
    elif measure == "legt":
        assert rank >= 2
        P = torch.sqrt(1 + 2 * torch.arange(N, dtype=dtype))  # (N)
        P0 = P.clone()
        P0[0::2] = 0.0
        P1 = P.clone()
        P1[1::2] = 0.0
        P = torch.stack([P0, P1], dim=0)  # (2 N)
        P *= 2 ** (
            -0.5
        )  # Halve the rank correction just like the original matrix was halved
    elif measure in ["fourier", "fout"]:
        P = torch.zeros(N)
        P[0::2] = 2**0.5
        P[0] = 1
        P = P.unsqueeze(0)
    elif measure in ["fourier_diag", "foud", "legsd"]:
        P = torch.zeros(1, N, dtype=dtype)
    else:
        raise NotImplementedError

    d = P.size(0)
    if rank > d:
        P = torch.cat([P, torch.zeros(rank - d, N, dtype=dtype)], dim=0)  # (rank N)
    return P
def nplr(measure, N, rank=1, dtype=torch.float, diagonalize_precision=True):
    """Decompose as Normal Plus Low-Rank (NPLR).

    Return w, p, q, V, B such that
    (w - p q^*, B) is unitarily equivalent to the original HiPPO A, B by the matrix V
    i.e. A = V[w - p q^*]V^*, B = V B
    """
    assert dtype == torch.float or dtype == torch.double
    cdtype = torch.cfloat if dtype == torch.float else torch.cdouble

    A, B = transition(measure, N)
    A = torch.as_tensor(A, dtype=dtype)  # (N, N)
    B = torch.as_tensor(B, dtype=dtype)[:, 0]  # (N,)

    P = rank_correction(measure, N, rank=rank, dtype=dtype)  # (r N)
    AP = A + torch.sum(P.unsqueeze(-2) * P.unsqueeze(-1), dim=-3)

    # We require AP to be nearly skew-symmetric
    _A = AP + AP.transpose(-1, -2)
    # if not torch.allclose(_A - _A[0,0]*torch.eye(N), torch.zeros(N, N), atol=1e-5):
    err = torch.sum((_A - _A[0, 0] * torch.eye(N)) ** 2) / N
    if err > 1e-5:
        print("WARNING: HiPPO matrix not skew symmetric", err)

    # Take advantage of identity + skew-symmetric form
    # to calculate real and imaginary parts separately
    # Imaginary part can use eigh instead of eig
    w_re = torch.mean(torch.diagonal(AP), -1, keepdim=True)

    # Diagonalize in double precision
    if diagonalize_precision:
        AP = AP.to(torch.double)
    w_im, V = torch.linalg.eigh(AP * -1j)  # (..., N) (..., N, N)
    if diagonalize_precision:
        w_im, V = w_im.to(cdtype), V.to(cdtype)
    w = w_re + 1j * w_im
    # Check: V w V^{-1} = A
    # print("check", V @ torch.diag_embed(w) @ V.conj().transpose(-1, -2))

    # Only keep half of each conjugate pair
    _, idx = torch.sort(w.imag)
    w_sorted = w[idx]
    V_sorted = V[:, idx]

    # There is an edge case when eigenvalues can be 0,
    # which requires some machinery to handle
    # We use a huge hack here: Assume only one pair is 0,
    # and that it is the first row/column of A (only happens in Fourier case)
    V = V_sorted[:, : N // 2]
    w = w_sorted[: N // 2]
    assert w[-2].abs() > 1e-4, "Only 1 zero eigenvalue allowed in diagonal part of A"
    if w[-1].abs() < 1e-4:
        V[:, -1] = 0.0
        V[0, -1] = 2**-0.5
        V[1, -1] = 2**-0.5 * 1j

    _AP = V @ torch.diag_embed(w) @ V.conj().transpose(-1, -2)
    err = torch.sum((2 * _AP.real - AP) ** 2) / N
    if err > 1e-5:
        print(
            "Warning: Diagonalization of A matrix not numerically precise - error", err
        )
    # print("check", V @ torch.diag_embed(w) @ V.conj().transpose(-1, -2))

    V_inv = V.conj().transpose(-1, -2)

    B = contract("ij, j -> i", V_inv, B.to(V))  # V^* B
    P = contract("ij, ...j -> ...i", V_inv, P.to(V))  # V^* P

    return w, P, B, V
def dplr(
    scaling,
    N,
    rank=1,
    H=1,
    dtype=torch.float,
    real_scale=1.0,
    imag_scale=1.0,
    random_real=False,
    random_imag=False,
    normalize=False,
    diagonal=True,
    random_B=False,
):
    assert dtype == torch.float or dtype == torch.double
    dtype = torch.cfloat if dtype == torch.float else torch.cdouble

    pi = torch.tensor(math.pi)
    if random_real:
        real_part = torch.rand(H, N // 2)
    else:
        real_part = 0.5 * torch.ones(H, N // 2)
    if random_imag:
        imag_part = N // 2 * torch.rand(H, N // 2)
    else:
        imag_part = repeat(torch.arange(N // 2), "n -> h n", h=H)

    real_part = real_scale * real_part
    if scaling == "random":
        imag_part = torch.randn(H, N // 2)
    elif scaling == "real":
        imag_part = 0 * imag_part
        real_part = 1 + repeat(torch.arange(N // 2), "n -> h n", h=H)
    elif scaling in ["linear", "lin"]:
        imag_part = pi * imag_part
    elif scaling in [
        "inverse",
        "inv",
    ]:  # Based on asymptotics of the default HiPPO matrix
        imag_part = 1 / pi * N * (N / (1 + 2 * imag_part) - 1)
    elif scaling in ["inverse2", "inv2"]:
        imag_part = 1 / pi * N * (N / (1 + imag_part) - 1)
    elif scaling in ["quadratic", "quad"]:
        imag_part = 1 / pi * (1 + 2 * imag_part) ** 2
    elif scaling in ["legs", "hippo"]:
        w, _, _, _ = nplr("legsd", N)
        imag_part = w.imag
    else:
        raise NotImplementedError
    imag_part = imag_scale * imag_part
    w = -real_part + 1j * imag_part

    # Initialize B
    if random_B:
        B = torch.randn(H, N // 2, dtype=dtype)
    else:
        B = torch.ones(H, N // 2, dtype=dtype)

    if normalize:
        norm = (
            -B / w
        )  # (H, N) # Result if you integrate the kernel with constant 1 function
        zeta = 2 * torch.sum(
            torch.abs(norm) ** 2, dim=-1, keepdim=True
        )  # Variance with a random C vector
        B = B / zeta**0.5

    P = torch.randn(rank, H, N // 2, dtype=dtype)
    if diagonal:
        P = P * 0.0
    V = torch.eye(N, dtype=dtype)[:: N // 2]  # Only used in testing
    V = repeat(V, "n m -> h n m", h=H)

    return w, P, B, V
def ssm(measure, N, R, H, **ssm_args):
    """Dispatcher to create single SSM initialization.

    N: state size
    R: rank (for DPLR parameterization)
    H: number of independent SSM copies
    """
    if measure == "dplr":
        w, P, B, V = dplr(N=N, rank=R, H=H, **ssm_args)
    elif measure.startswith("diag"):
        args = measure.split("-")
        assert args[0] == "diag" and len(args) > 1
        scaling = args[1]
        w, P, B, V = dplr(scaling=scaling, N=N, rank=R, H=H, diagonal=True, **ssm_args)
    else:
        w, P, B, V = nplr(measure, N, R, **ssm_args)
        w = repeat(w, "n -> s n", s=H)
        P = repeat(P, "r n -> r s n", s=H)
        B = repeat(B, "n -> s n", s=H)
        V = repeat(V, "n m -> s n m", s=H)
    return w, P, B, V
combinations = {
    "hippo": ["legs", "fourier"],
    "diag": ["diag-inv", "diag-lin"],
    "all": ["legs", "fourier", "diag-inv", "diag-lin"],
}
def combination(measures, N, R, S, **ssm_args):
    if isinstance(measures, str):
        measures = combinations[measures] if measures in combinations else [measures]

    assert S % len(measures) == 0, (
        f"{S} independent trainable SSM copies must be multiple of {len(measures)} "
        "different measures"
    )
    w, P, B, V = zip(
        *[ssm(measure, N, R, S // len(measures), **ssm_args) for measure in measures]
    )
    w = torch.cat(w, dim=0)  # (S N)
    P = torch.cat(P, dim=1)  # (R S N)
    B = torch.cat(B, dim=0)  # (S N)
    V = torch.cat(V, dim=0)  # (S N N)
    return w, P, B, V
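# Illustrative sketch (not part of the original module): shapes produced by the
# `combination` dispatcher for S copies of an N-dimensional SSM with rank-R
# low-rank correction. N is halved internally because only one element of each
# conjugate pair is stored.
def _example_combination():
    N, R, S = 16, 1, 4
    w, P, B, V = combination("legs", N, R, S)
    assert w.shape == (S, N // 2)  # diagonal part
    assert P.shape == (R, S, N // 2)  # low-rank part
    assert B.shape == (S, N // 2)
    assert V.shape == (S, N, N // 2)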
class OptimModule(nn.Module):
    """Interface for Module that allows registering buffers/parameters
    with configurable optimizer hyperparameters.
    """
    def register(self, name, tensor, lr=None):
        """Register a tensor with a configurable learning rate and 0 weight decay."""
        if lr == 0.0:
            self.register_buffer(name, tensor)
        else:
            self.register_parameter(name, nn.Parameter(tensor))

            optim = {"weight_decay": 0.0}
            if lr is not None:
                optim["lr"] = lr
            setattr(getattr(self, name), "_optim", optim)
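# Illustrative sketch (not part of the original module): the `_optim` attribute
# set by `register` above is meant to be consumed by an optimizer hook that
# builds per-parameter groups. A minimal version of such a hook (hypothetical
# helper, assuming AdamW) could look like this.
def _example_optim_groups(model: nn.Module, base_lr: float = 1e-3):
    plain, special = [], []
    for p in model.parameters():
        (special if hasattr(p, "_optim") else plain).append(p)
    groups = [{"params": plain, "lr": base_lr}]
    # Each special parameter carries its own {"weight_decay": 0.0, "lr": ...} dict
    groups += [{"params": [p], **p._optim} for p in special]
    return torch.optim.AdamW(groups, lr=base_lr)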
class SSKernelNPLR(OptimModule):
    """Stores a representation of and computes the SSKernel function.

    K_L(A^dt, B^dt, C) corresponding to a discretized state space,
    where A is Normal + Low Rank (NPLR)
    """

    @torch.no_grad()
    def _setup_C(self, L):
        """Construct C~ from C.

        Two modes are supported: go directly to length L if self.L is 1,
        or length is doubled
        """
        if self.L.item() == 0:
            if self.verbose:
                log.info(f"S4: Initializing kernel to length {L}")
            double_length = False
        elif L > self.L.item():  # 2*int(self.L) == L:
            if self.verbose:
                log.info(
                    f"S4: Doubling length from L = {self.L.item()} to {2*self.L.item()}"
                )
            double_length = True
            L = self.L.item()  # Convenience for the math below
        else:
            return

        C = _r2c(self.C)
        dA, _ = self._setup_state()
        dA_L = power(L, dA)
        # Multiply C by I - dA_L
        C_ = _conj(C)
        prod = contract("h m n, c h n -> c h m", dA_L.transpose(-1, -2), C_)
        if double_length:
            prod = -prod  # Multiply by I + dA_L instead
        C_ = C_ - prod
        C_ = C_[..., : self.N]  # Take conjugate pairs again
        self.C.copy_(_c2r(C_))

        self.L = 2 * self.L if double_length else self.L + L  # Preserve type/device

    def _omega(self, L, dtype, device, cache=True):
        """Calculate (and cache) FFT nodes and their "unprocessed" version with the bilinear transform.  # noqa

        This should be called every time the internal length self.L changes
        """
        # Use cached if available
        if cache and hasattr(self, "omega") and self.omega.size(-1) == L // 2 + 1:
            return self.omega, self.z

        omega = torch.tensor(
            np.exp(-2j * np.pi / (L)), dtype=dtype, device=device
        )  # \omega_{2L}
        omega = omega ** torch.arange(0, L // 2 + 1, device=device)
        z = 2 * (1 - omega) / (1 + omega)

        # Cache if necessary
        if cache:
            self.omega = omega
            self.z = z
        return omega, z

    def __init__(
        self,
        w,
        P,
        B,
        C,
        log_dt,
        L=None,  # starting/maximum length of kernel
        lr=None,
        verbose=False,
        keops=False,
        real_type="exp",  # ['none' | 'exp' | 'relu' | 'sigmoid']
        real_tolerance=1e-3,
        bandlimit=None,
    ):
        """Initialize kernel.

        L: Maximum length; this module computes an SSM kernel of length L
        A is represented by diag(w) - PP^*
        w: (S, N) diagonal part
        P: (R, S, N) low-rank part

        B: (S, N)
        C: (C, H, N)
        dt: (H) timescale per feature
        lr: [dict | float | None] hook to set lr of special parameters (A, B, dt)

        Dimensions:
        N (or d_state): state size
        H (or d_model): total SSM copies
        S (or n_ssm): number of trainable copies of (A, B, dt); must divide H
        R (or rank): rank of low-rank part
        C (or channels): system is 1-dim to C-dim

        The forward pass of this Module returns a tensor of shape (C, H, L)

        Note: tensor shape N here denotes half the true state size,
        because of conjugate symmetry
        """
        super().__init__()
        self.verbose = verbose
        self.keops = keops
        self.bandlimit = bandlimit
        self.real_type = real_type
        self.real_tolerance = real_tolerance

        # Rank of low-rank correction
        self.rank = P.shape[-3]
        assert w.size(-1) == P.size(-1) == B.size(-1) == C.size(-1)
        self.H = log_dt.size(-1)
        self.N = w.size(-1)

        # Check different SSM inits
        assert w.size(-2) == P.size(-2) == B.size(-2)  # n_ssm
        assert self.H % w.size(0) == 0
        self.n_ssm = w.size(0)
        self.broadcast = self.H // w.size(
            0
        )  # Each trainable SSM needs to be duplicated this many times

        # Broadcast everything to correct shapes
        C = C.expand(torch.broadcast_shapes(C.shape, (1, self.H, self.N)))  # (C, H, N)
        B = B.unsqueeze(0)  # (1, 1, N)

        # Register parameters
        self.C = nn.Parameter(_c2r(_resolve_conj(C)))
        if lr is None or isinstance(lr, float):
            lr_dict = {}
        else:
            lr_dict, lr = lr, None
        self.register("log_dt", log_dt, lr_dict.get("dt", lr))
        self.register("B", _c2r(B), lr_dict.get("B", lr))
        self.register("P", _c2r(P), lr_dict.get("A", lr))
        self.register("inv_w_real", self._w_init(w.real), lr_dict.get("A", lr))
        self.register("w_imag", w.imag, lr_dict.get("A", lr))

        self.l_max = L
        self.register_buffer("L", torch.tensor(0))  # Internal length

    def _w_init(self, w_real):
        w_real = torch.clamp(w_real, max=-self.real_tolerance)
        if self.real_type == "none":
            return -w_real
        elif self.real_type == "exp":
            return torch.log(-w_real)  # Some of the HiPPO methods have real part 0
        elif self.real_type == "relu":
            return -w_real
        elif self.real_type == "sigmoid":
            return torch.logit(-w_real)
        elif self.real_type == "softplus":
            return torch.log(torch.exp(-w_real) - 1)
        else:
            raise NotImplementedError

    def _w(self):
        # Get the internal w (diagonal) parameter
        if self.real_type == "none":
            w_real = -self.inv_w_real
        elif self.real_type == "exp":
            w_real = -torch.exp(self.inv_w_real)
        elif self.real_type == "relu":
            w_real = -F.relu(self.inv_w_real)
        elif self.real_type == "sigmoid":
            w_real = -F.sigmoid(self.inv_w_real)
        elif self.real_type == "softplus":
            w_real = -F.softplus(self.inv_w_real)
        else:
            raise NotImplementedError
        w = w_real + 1j * self.w_imag
        return w
    def forward(self, state=None, rate=1.0, L=None):
        """Forward pass.

        state: (B, H, N) initial state
        rate: sampling rate factor
        L: target length

        returns:
        (C, H, L) convolution kernel (generally C=1)
        (B, H, L) output from initial state
        """
        # Initialize C~
        # if necessary (done in forward pass so it's on the correct device)
        if self.L.item() == 0 and self.l_max is not None and self.l_max > 0:
            self._setup_C(self.l_max)

        # Handle sampling rate logic
        # The idea is that this kernel's length (in continuous units) is self.L,
        # while we are asked
        # to provide a kernel of length L at (relative) frequency rate
        if L is None:
            L = round(self.L.item() / rate)

        # Increase the internal length if needed
        continuous_L = round(rate * L)
        while continuous_L > self.L.item():
            self._setup_C(continuous_L)
        discrete_L = round(self.L.item() / rate)

        dt = torch.exp(self.log_dt) * rate
        B = _r2c(self.B)
        C = _r2c(self.C)
        P = _r2c(self.P)
        Q = P.conj()
        w = self._w()  # (n_ssm, N)

        # Address bandlimiting
        if self.bandlimit is not None:
            freqs = w.imag.abs() / (2 * math.pi)  # (H, N)
            freqs = dt[:, None] / rate * freqs  # (H, N)
            mask = torch.where(freqs < self.bandlimit * 0.5, 1, 0)
            C = C * mask

        # Get FFT nodes of right length
        omega, z = self._omega(
            discrete_L, dtype=w.dtype, device=w.device, cache=(rate == 1.0)
        )

        # Broadcast parameters to same hidden features H
        B = repeat(B, "1 t n -> 1 (v t) n", v=self.broadcast)
        P = repeat(P, "r t n -> r (v t) n", v=self.broadcast)
        Q = repeat(Q, "r t n -> r (v t) n", v=self.broadcast)
        w = repeat(w, "t n -> (v t) n", v=self.broadcast)

        # Augment B
        if state is not None:
            # Have to "unbilinear" the state to put it into the same "type" as B
            # Compute 1/dt * (I + dt/2 A) @ state
            # Can do this without expanding
            # (maybe minor speedup using conj symmetry in theory),
            # but it's easier to read this way
            s = _conj(state) if state.size(-1) == self.N else state  # (B H N)
            sA = s * _conj(w) - contract(  # (B H N)
                "bhm, rhm, rhn -> bhn", s, _conj(Q), _conj(P)
            )
            s = s / dt.unsqueeze(-1) + sA / 2
            s = s[..., : self.N]

            B = torch.cat([s, B], dim=-3)  # (B+1, H, N)

        # Incorporate dt into A
        w = w * dt.unsqueeze(-1)  # (H N)

        # Stack B and p, C and q for convenient batching
        B = torch.cat([B, P], dim=-3)  # (B+1+R, H, N)
        C = torch.cat([C, Q], dim=-3)  # (C+R, H, N)

        # Incorporate B and C batch dimensions
        v = B.unsqueeze(-3) * C.unsqueeze(-4)  # (B+1+R, C+R, H, N)

        # Calculate resolvent at omega
        if has_cauchy_extension and z.dtype == torch.cfloat and not self.keops:
            r = cauchy_mult(v, z, w, symmetric=True)
        elif has_pykeops:
            r = cauchy_conj(v, z, w)
        else:
            r = cauchy_naive(v, z, w)
        r = r * dt[None, None, :, None]  # (B+1+R, C+R, H, L)

        # Low-rank Woodbury correction
        if self.rank == 1:
            k_f = r[:-1, :-1, :, :] - r[:-1, -1:, :, :] * r[-1:, :-1, :, :] / (
                1 + r[-1:, -1:, :, :]
            )
        elif self.rank == 2:
            r00 = r[: -self.rank, : -self.rank, :, :]
            r01 = r[: -self.rank, -self.rank :, :, :]
            r10 = r[-self.rank :, : -self.rank, :, :]
            r11 = r[-self.rank :, -self.rank :, :, :]
            det = (1 + r11[:1, :1, :, :]) * (1 + r11[1:, 1:, :, :]) - r11[
                :1, 1:, :, :
            ] * r11[1:, :1, :, :]
            s = (
                r01[:, :1, :, :] * (1 + r11[1:, 1:, :, :]) * r10[:1, :, :, :]
                + r01[:, 1:, :, :] * (1 + r11[:1, :1, :, :]) * r10[1:, :, :, :]
                - r01[:, :1, :, :] * (r11[:1, 1:, :, :]) * r10[1:, :, :, :]
                - r01[:, 1:, :, :] * (r11[1:, :1, :, :]) * r10[:1, :, :, :]
            )
            s = s / det
            k_f = r00 - s
        else:
            r00 = r[: -self.rank, : -self.rank, :, :]
            r01 = r[: -self.rank, -self.rank :, :, :]
            r10 = r[-self.rank :, : -self.rank, :, :]
            r11 = r[-self.rank :, -self.rank :, :, :]
            r11 = rearrange(r11, "a b h n -> h n a b")
            r11 = torch.linalg.inv(torch.eye(self.rank, device=r.device) + r11)
            r11 = rearrange(r11, "h n a b -> a b h n")
            k_f = r00 - torch.einsum(
                "i j h n, j k h n, k l h n -> i l h n", r01, r11, r10
            )

        # Final correction for the bilinear transform
        k_f = k_f * 2 / (1 + omega)

        # Move from frequency to coefficients
        k = torch.fft.irfft(k_f, n=discrete_L)  # (B+1, C, H, L)

        # Truncate to target length
        k = k[..., :L]

        if state is not None:
            k_state = k[:-1, :, :, :]  # (B, C, H, L)
        else:
            k_state = None

        k_B = k[-1, :, :, :]  # (C H L)

        return k_B, k_state
    @torch.no_grad()
    def _setup_linear(self):
        """Create parameters that allow fast linear stepping of state."""
        w = self._w()
        B = _r2c(self.B)  # (H N)
        P = _r2c(self.P)
        Q = P.conj()

        # Repeat w shape properly
        B = repeat(B, "1 t n -> 1 (v t) n", v=self.broadcast)
        P = repeat(P, "r t n -> r (v t) n", v=self.broadcast)
        Q = repeat(Q, "r t n -> r (v t) n", v=self.broadcast)
        w = repeat(w, "t n -> (v t) n", v=self.broadcast)

        # Prepare Linear stepping
        dt = torch.exp(self.log_dt)
        D = (2.0 / dt.unsqueeze(-1) - w).reciprocal()  # (H, N)
        R = (
            torch.eye(self.rank, dtype=w.dtype, device=w.device)
            + 2 * contract("r h n, h n, s h n -> h r s", Q, D, P).real
        )  # (H R R)
        Q_D = rearrange(Q * D, "r h n -> h r n")
        try:
            R = torch.linalg.solve(R, Q_D)  # (H R N)
        except:  # noqa
            R = torch.tensor(
                np.linalg.solve(
                    R.to(Q_D).contiguous().detach().cpu(),
                    Q_D.contiguous().detach().cpu(),
                )
            ).to(Q_D)
        R = rearrange(R, "h r n -> r h n")

        self.step_params = {
            "D": D,  # (H N)
            "R": R,  # (R H N)
            "P": P,  # (R H N)
            "Q": Q,  # (R H N)
            "B": B,  # (1 H N)
            "E": 2.0 / dt.unsqueeze(-1) + w,  # (H N)
        }

    def _step_state_linear(self, u=None, state=None):
        """Step one time step as a recurrent model.

        Version of the step function that has time O(N) instead of O(N^2) per step,
        which takes advantage of the DPLR form and bilinear discretization.

        Unfortunately, as currently implemented it's about 2x slower
        because it calls several sequential operations.
        Perhaps a fused CUDA kernel implementation would be much faster

        u: (H) input
        state: (H, N/2) state with conjugate pairs
        Optionally, the state can have last dimension N

        Returns: same shape as state
        """
        C = _r2c(self.C)  # View used for dtype/device

        if u is None:  # Special case used to find dA
            u = torch.zeros(self.H, dtype=C.dtype, device=C.device)
        if state is None:  # Special case used to find dB
            state = torch.zeros(self.H, self.N, dtype=C.dtype, device=C.device)

        step_params = self.step_params.copy()
        if (
            state.size(-1) == self.N
        ):  # Only store half of the conjugate pairs; should be true by default
            # There should be a slightly faster way using conjugate symmetry
            def contract_fn(p, x, y):
                return contract(
                    "r h n, r h m, ... h m -> ... h n", _conj(p), _conj(x), _conj(y)
                )[
                    ..., : self.N
                ]  # inner outer product

        else:
            assert state.size(-1) == 2 * self.N
            step_params = {k: _conj(v) for k, v in step_params.items()}

            # Worth setting up a contract_expression in default_state
            # if we want to use this at inference time for stepping
            def contract_fn(p, x, y):
                return contract(
                    "r h n, r h m, ... h m -> ... h n", p, x, y
                )  # inner outer product

        D = step_params["D"]  # (H N)
        E = step_params["E"]  # (H N)
        R = step_params["R"]  # (R H N)
        P = step_params["P"]  # (R H N)
        Q = step_params["Q"]  # (R H N)
        B = step_params["B"]  # (1 H N)

        new_state = E * state - contract_fn(P, Q, state)  # (B H N)
        new_state = new_state + 2.0 * B * u.unsqueeze(-1)  # (B H N)
        new_state = D * (new_state - contract_fn(P, R, new_state))

        return new_state

    def _setup_state(self):
        """Construct dA and dB for discretized state equation."""
        # Construct dA and dB by using the stepping
        self._setup_linear()
        C = _r2c(self.C)  # Just returns a view that we use for finding dtype/device

        state = torch.eye(2 * self.N, dtype=C.dtype, device=C.device).unsqueeze(
            -2
        )  # (N 1 N)
        dA = self._step_state_linear(state=state)
        dA = rearrange(dA, "n h m -> h m n")

        u = C.new_ones(self.H)
        dB = self._step_state_linear(u=u)
        dB = _conj(dB)
        dB = rearrange(dB, "1 h n -> h n")  # (H N)
        return dA, dB

    def _step_state(self, u, state):
        """Step one time step as a recurrent model.

        Must be called after self.default_state() is used to construct
        an initial state!
        """
        next_state = self.state_contraction(self.dA, state) + self.input_contraction(
            self.dB, u
        )
        return next_state

    def _setup_step(self, mode="dense"):
        """Set up dA, dB, dC discretized parameters for stepping."""
        self.dA, self.dB = self._setup_state()

        # Calculate original C
        C = _conj(_r2c(self.C))  # (H C N)
        if self.L.item() == 0:
            dC = C
        else:
            # self.C represents C_tilde
            dA_L = power(self.L.item(), self.dA)
            E = torch.eye(self.dA.size(-1)).to(dA_L)

            dC = torch.linalg.solve(
                E - dA_L.transpose(-1, -2),
                C.unsqueeze(-1),
            ).squeeze(-1)
        self.dC = dC

        # Do special preprocessing for different step modes
        self._step_mode = mode
        if mode == "linear":
            # Linear case: special step function for the state, we need to handle
            # the output. Use conjugate symmetry by default,
            # which affects the output projection
            self.dC = 2 * self.dC[:, :, : self.N]
        elif mode == "diagonal":
            # Eigendecomposition of the A matrix
            L, V = torch.linalg.eig(self.dA)
            V_inv = torch.linalg.inv(V)
            # Check that the eigendecomposition is correct
            if self.verbose:
                print(
                    "Diagonalization error:",
                    torch.dist(V @ torch.diag_embed(L) @ V_inv, self.dA),
                )

            # Change the parameterization to diagonalize
            self.dA = L
            self.dB = contract("h n m, h m -> h n", V_inv, self.dB)
            self.dC = contract("h n m, c h n -> c h m", V, self.dC)
        elif mode == "dense":
            pass
        else:
            raise NotImplementedError(
                "NPLR Kernel step mode must be {'dense' | 'linear' | 'diagonal'}"
            )
    def default_state(self, *batch_shape):
        C = _r2c(self.C)
        N = C.size(-1)
        H = C.size(-2)

        # Cache the tensor contractions we will later do, for efficiency
        # These are put in this function because they depend on the batch size
        step_mode = getattr(self, "_step_mode", "dense")  # Used in default_state,
        # which is called without _setup_step() in forward_state()
        if step_mode != "linear":
            N *= 2

            if step_mode == "diagonal":
                self.state_contraction = contract_expression(
                    "h n, ... h n -> ... h n",
                    (H, N),
                    batch_shape + (H, N),
                )
            else:
                # Dense (quadratic) case: expand all terms
                self.state_contraction = contract_expression(
                    "h m n, ... h n -> ... h m",
                    (H, N, N),
                    batch_shape + (H, N),
                )

            self.input_contraction = contract_expression(
                "h n, ... h -> ... h n",
                (H, N),  # self.dB.shape
                batch_shape + (H,),
            )

        self.output_contraction = contract_expression(
            "c h n, ... h n -> ... c h",
            (C.shape[0], H, N),  # self.dC.shape
            batch_shape + (H, N),
        )

        state = torch.zeros(*batch_shape, H, N, dtype=C.dtype, device=C.device)
        return state
    def step(self, u, state):
        """Step one time step as a recurrent model.

        Must have called self._setup_step()
        and created state with self.default_state() before calling this
        """
        if self._step_mode == "linear":
            new_state = self._step_state_linear(u, state)
        else:
            new_state = self._step_state(u, state)
        y = self.output_contraction(self.dC, new_state)
        return y.real, new_state
class SSKernelDiag(OptimModule):
    """Version using (complex) diagonal state matrix (S4D)."""

    def __init__(
        self,
        A,
        B,
        C,
        log_dt,
        L=None,
        disc="bilinear",
        real_type="exp",
        lr=None,
        bandlimit=None,
    ):
        super().__init__()
        self.L = L
        self.disc = disc
        self.bandlimit = bandlimit
        self.real_type = real_type

        # Rank of low-rank correction
        assert A.size(-1) == C.size(-1)
        self.H = log_dt.size(-1)
        self.N = A.size(-1)
        assert A.size(-2) == B.size(-2)  # Number of independent SSMs trained
        assert self.H % A.size(-2) == 0
        self.n_ssm = A.size(-2)
        self.repeat = self.H // A.size(0)

        self.channels = C.shape[0]
        self.C = nn.Parameter(_c2r(_resolve_conj(C)))

        # Register parameters
        if lr is None or isinstance(lr, float):
            lr_dict = {}
        else:
            lr_dict, lr = lr, None

        self.register("log_dt", log_dt, lr_dict.get("dt", lr))
        self.register("A", _c2r(A), lr_dict.get("A", lr))
        self.register("B", _c2r(B), lr_dict.get("B", lr))
        self.register("inv_A_real", self._A_init(A.real), lr_dict.get("A", lr))
        self.register("A_imag", A.imag, lr_dict.get("A", lr))

    def _A_init(self, A_real):
        A_real = torch.clamp(A_real, max=-1e-4)
        if self.real_type == "none":
            return -A_real
        elif self.real_type == "exp":
            return torch.log(-A_real)  # Some of the HiPPO methods have real part 0
        elif self.real_type == "relu":
            return -A_real
        elif self.real_type == "sigmoid":
            return torch.logit(-A_real)
        elif self.real_type == "softplus":
            return torch.log(torch.exp(-A_real) - 1)
        else:
            raise NotImplementedError

    def _A(self):
        # Get the internal A (diagonal) parameter
        if self.real_type == "none":
            A_real = -self.inv_A_real
        elif self.real_type == "exp":
            A_real = -torch.exp(self.inv_A_real)
        elif self.real_type == "relu":
            # JAX version seems to NaN if you allow 0's,
            # although this code was fine without it
            A_real = -F.relu(self.inv_A_real) - 1e-4
        elif self.real_type == "sigmoid":
            A_real = -F.sigmoid(self.inv_A_real)
        elif self.real_type == "softplus":
            A_real = -F.softplus(self.inv_A_real)
        else:
            raise NotImplementedError
        A = A_real + 1j * self.A_imag
        return A
    def forward(self, L, state=None, rate=1.0, u=None):
        """Forward pass.

        state: (B, H, N) initial state
        rate: sampling rate factor
        L: target length

        returns:
        (C, H, L) convolution kernel (generally C=1)
        (B, H, L) output from initial state
        """
        dt = torch.exp(self.log_dt) * rate  # (H)
        C = _r2c(self.C)  # (C H N)
        A = self._A()  # (H N)

        B = _r2c(self.B)
        B = repeat(B, "t n -> 1 (v t) n", v=self.repeat)

        if self.bandlimit is not None:
            freqs = dt[:, None] / rate * A.imag.abs() / (2 * math.pi)  # (H, N)
            mask = torch.where(freqs < self.bandlimit * 0.5, 1, 0)
            C = C * mask

        # Incorporate dt into A
        A = repeat(A, "t n -> (v t) n", v=self.repeat)
        dtA = A * dt.unsqueeze(-1)  # (H N)

        # Augment B with state
        if state is not None:
            s = state / dt.unsqueeze(-1)
            if self.disc == "bilinear":
                s = s * (1.0 + dtA / 2)
            elif self.disc == "zoh":
                s = s * dtA * dtA.exp() / (dtA.exp() - 1.0)
            B = torch.cat([s, B], dim=-3)  # (1+B H N)

        C = (B[:, None, :, :] * C).view(-1, self.H, self.N)
        if self.disc == "zoh":
            # Power up
            C = C * (torch.exp(dtA) - 1.0) / A
            K = log_vandermonde(C, dtA, L)  # (H L)
        elif self.disc == "bilinear":
            C = C * (1.0 - dtA / 2).reciprocal() * dt.unsqueeze(-1)  # or * dtA / A
            dA = (1.0 + dtA / 2) / (1.0 - dtA / 2)
            K = log_vandermonde(C, dA.log(), L)
        elif self.disc == "dss":
            # Implementation from DSS meant for case
            # when real eigenvalues can be positive
            P = dtA.unsqueeze(-1) * torch.arange(L, device=C.device)  # [H N L]
            A_gt_0 = A.real > 0  # [N]
            if A_gt_0.any():
                with torch.no_grad():
                    P_max = dtA * (A_gt_0 * (L - 1))  # [H N]
                P = P - P_max.unsqueeze(-1)  # [H N L]
            S = P.exp()  # [H N L]

            dtA_neg = dtA * (1 - 2 * A_gt_0)  # [H N]
            num = dtA_neg.exp() - 1  # [H N]
            den = (dtA_neg * L).exp() - 1  # [H N]

            # Inline reciprocal function for DSS logic
            x = den * A
            x_conj = _resolve_conj(x)
            r = x_conj / (x * x_conj + 1e-7)

            C = C * num * r  # [C H N]
            K = contract("chn,hnl->chl", C, S).float()
        else:
            assert False, f"{self.disc} not supported"

        K = K.view(-1, self.channels, self.H, L)  # (1+B C H L)

        if state is not None:
            K_state = K[:-1, :, :, :]  # (B C H L)
        else:
            K_state = None
        K = K[-1, :, :, :]  # (C H L)
        return K, K_state
    def _setup_step(self):
        # These methods are organized
        # like this to be compatible with the NPLR kernel interface
        dt = torch.exp(self.log_dt)  # (H)
        B = _r2c(self.B)  # (H N)
        C = _r2c(self.C)  # (C H N)
        self.dC = C
        A = self._A()  # (H N)

        # Incorporate dt into A
        dtA = A * dt.unsqueeze(-1)  # (H N)
        if self.disc == "zoh":
            self.dA = torch.exp(dtA)  # (H N)
            self.dB = B * (torch.exp(dtA) - 1.0) / A  # (C H N)
        elif self.disc == "bilinear":
            self.dA = (1.0 + dtA / 2) / (1.0 - dtA / 2)
            self.dB = (
                B * (1.0 - dtA / 2).reciprocal() * dt.unsqueeze(-1)
            )  # or * dtA / A
    def default_state(self, *batch_shape):
        C = _r2c(self.C)
        state = torch.zeros(
            *batch_shape, self.H, self.N, dtype=C.dtype, device=C.device
        )
        return state
    def step(self, u, state):
        next_state = contract("h n, b h n -> b h n", self.dA, state) + contract(
            "h n, b h -> b h n", self.dB, u
        )
        y = contract("c h n, b h n -> b c h", self.dC, next_state)
        return 2 * y.real, next_state
    def forward_state(self, u, state):
        self._setup_step()
        AL = self.dA ** u.size(-1)
        u = u.flip(-1).to(self.dA).contiguous()  # (B H L)
        v = log_vandermonde_transpose(u, self.dB, self.dA.log(), u.size(-1))
        next_state = AL * state + v
        return next_state
class SSKernel(nn.Module):
    """Wrapper around SSKernel parameterizations.

    The SSKernel is expected to support the interface
    forward()
    default_state()
    _setup_step()
    step()
    """

    def __init__(
        self,
        H,
        N=64,
        L=None,
        measure="legs",
        rank=1,
        channels=1,
        dt_min=0.001,
        dt_max=0.1,
        deterministic=False,
        lr=None,
        mode="nplr",
        n_ssm=None,
        verbose=False,
        measure_args={},
        **kernel_args,
    ):
        r"""State Space Kernel which computes the convolution kernel $\\bar{K}$.

        H: Number of independent SSM copies; controls the size of the model.
            Also called d_model in the config.
        N: State size (dimensionality of parameters A, B, C).
            Also called d_state in the config.
            Generally shouldn't need to be adjusted and doesn't affect speed much.
        L: Maximum length of convolution kernel, if known.
            Should work in the majority of cases even if not known.
        measure: Options for initialization of (A, B).
            For NPLR mode, recommendations are "legs", "fout", "hippo"
            (combination of both).
            For Diag mode, recommendations are "diag-inv", "diag-lin", "diag-legs",
            and "diag" (combination of diag-inv and diag-lin)
        rank: Rank of low-rank correction for NPLR mode.
            Needs to be increased for measure "legt"
        channels: C channels turns the SSM from a 1-dim to C-dim map;
            can think of it having C separate "heads" per SSM.
            This was partly a feature to make it easier to implement bidirectionality;
            it is recommended to set channels=1
            and adjust H to control parameters instead
        dt_min, dt_max: min and max values for the step size dt (\Delta)
        mode: Which kernel algorithm to use. 'nplr' is the full S4 model;
            'diag' is the simpler S4D; 'slow' is a dense version for testing
        n_ssm: Number of independent trainable (A, B) SSMs, e.g.
            n_ssm=1 means all A/B parameters are tied across the H different
            instantiations of C.
            n_ssm=None means all H SSMs are completely independent.
            Generally, changing this option can save parameters
            but doesn't affect performance or speed much.
            This parameter must divide H
        lr: Passing in a number (e.g. 0.001) sets attributes of SSM parameters
            (A, B, dt). A custom optimizer hook is needed to configure the optimizer
            to set the learning rates appropriately for these parameters.
        """
        super().__init__()
        self.N = N
        self.H = H
        dtype, cdtype = torch.float, torch.cfloat
        self.channels = channels
        self.n_ssm = n_ssm if n_ssm is not None else H
        self.mode = mode
        self.verbose = verbose
        self.kernel_args = kernel_args

        # Generate dt
        if deterministic:
            log_dt = torch.exp(torch.linspace(math.log(dt_min), math.log(dt_max), H))
        else:
            log_dt = torch.rand(self.H, dtype=dtype) * (
                math.log(dt_max) - math.log(dt_min)
            ) + math.log(dt_min)

        # Compute the preprocessed representation
        w, P, B, V = combination(measure, self.N, rank, self.n_ssm, **measure_args)

        # Broadcast C to have H channels
        if deterministic:
            C = torch.zeros(channels, self.H, self.N, dtype=cdtype)
            C[:, :, :1] = 1.0
            C = contract("hmn, chn -> chm", V.conj().transpose(-1, -2), C)  # V^* C
        else:
            C = torch.randn(channels, self.H, self.N // 2, dtype=cdtype)

        # Broadcast other parameters to have n_ssm copies
        assert (
            self.n_ssm % B.size(-2) == 0
            and self.n_ssm % P.size(-2) == 0
            and self.n_ssm % w.size(-2) == 0
        )

        # Broadcast tensors to n_ssm copies
        # These will be the parameters,
        # so make sure tensors are materialized and contiguous
        B = repeat(B, "t n -> (v t) n", v=self.n_ssm // B.size(-2)).clone().contiguous()
        P = (
            repeat(P, "r t n -> r (v t) n", v=self.n_ssm // P.size(-2))
            .clone()
            .contiguous()
        )
        w = repeat(w, "t n -> (v t) n", v=self.n_ssm // w.size(-2)).clone().contiguous()
        C = C.contiguous()

        if mode == "nplr":
            self.kernel = SSKernelNPLR(
                w,
                P,
                B,
                C,
                log_dt,
                L=L,
                lr=lr,
                verbose=verbose,
                **kernel_args,
            )
        elif mode == "diag":
            if not measure.startswith("diag"):
                log.warning(
                    "Diagonal kernel (S4D) activated but initialization is not "
                    "intended for S4D. Set `measure` to 'diag-lin', 'diag-inv', or "
                    "'diag-legs' for the main variants, or 'diag' "
                    "for a combination of S4D-Lin and S4D-Inv."
                )
            C = C * repeat(B, "t n -> (v t) n", v=H // self.n_ssm)
            self.kernel = SSKernelDiag(
                w,
                B,
                C,
                log_dt,
                L=L,
                lr=lr,
                **kernel_args,
            )
        else:
            raise NotImplementedError(f"mode={mode} is not valid")
    def forward(self, state=None, L=None, rate=None):
        return self.kernel(state=state, L=L, rate=rate)
    @torch.no_grad()
    def forward_state(self, u, state):
        """Forward the state through a sequence.

        i.e. computes the state after passing chunk through SSM

        state: (B, H, N)
        u: (B, H, L)

        Returns: (B, H, N)
        """
        if hasattr(self.kernel, "forward_state"):
            return self.kernel.forward_state(u, state)

        dA, dB = self.kernel._setup_state()  # Construct dA, dB matrices
        # dA, dB = self.kernel.dA, self.kernel.dB # (H N N) (H N)

        conj = state.size(-1) != dA.size(-1)
        if conj:
            state = _conj(state)

        v = contract(
            "h n, b h l -> b h n l", dB, u.flip(-1)
        )  # dB.unsqueeze(-1) * u.flip(-1).unsqueeze(-2)
        AL, v = power(u.size(-1), dA, v)
        next_state = contract("h m n, b h n -> b h m", AL, state)
        next_state = next_state + v

        if conj:
            next_state = next_state[..., : next_state.size(-1) // 2]
        return next_state
    def _setup_step(self, **kwargs):
        # This method is intended to be private so that setting up an S4 module with
        # ```
        # if hasattr(module, 'setup_step'): module.setup_step()
        # ```
        # will not trigger this method multiple times
        self.kernel._setup_step(**kwargs)
    def step(self, u, state, **kwargs):
        y, state = self.kernel.step(u, state, **kwargs)
        return y, state
    def default_state(self, *args, **kwargs):
        return self.kernel.default_state(*args, **kwargs)
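# Illustrative sketch (not part of the original module): constructing the kernel
# wrapper directly and materializing a convolution kernel of length L, assuming
# only the slow CPU fallback kernels defined above are available.
def _example_sskernel():
    H, N, L = 8, 64, 128
    kernel = SSKernel(H, N=N, L=L, measure="legs", mode="nplr")
    k, _ = kernel(L=L, rate=1.0)  # k: (channels, H, L) real-valued kernel
    assert k.shape == (1, H, L)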
class S4(nn.Module):
    def __init__(
        self,
        d_model,
        d_state=64,
        l_max=None,
        channels=1,
        bidirectional=False,
        # Arguments for position-wise feedforward components
        activation="gelu",
        postact="glu",
        hyper_act=None,
        dropout=0.0,
        tie_dropout=False,
        bottleneck=None,
        gate=None,
        transposed=True,
        verbose=False,
        # SSM Kernel arguments
        **kernel_args,
    ):
        """Initialize S4 module.

        d_state: the dimension of the state, also denoted by N
        l_max: the maximum kernel length, also denoted by L.
            Set l_max=None to always use a global kernel
        channels: can be interpreted as a number of "heads";
            the SSM is a map from a 1-dim to C-dim sequence.
            It's not recommended to change this unless desperate for things to tune;
            instead, increase d_model for larger models
        bidirectional: if True, convolution kernel will be two-sided

        Position-wise feedforward components:
        --------------------
        activation: activation in between SS and FF
        postact: activation after FF
        hyper_act: use a "hypernetwork" multiplication (experimental)
        dropout: standard dropout argument.
            tie_dropout=True ties the dropout mask across the sequence length,
            emulating nn.Dropout1d

        Other arguments:
        --------------------
        transposed: choose backbone axis ordering of
            (B, L, H) (if False) or (B, H, L) (if True)
            [B=batch size, L=sequence length, H=hidden dimension]
        gate: add gated activation (GSS)
        bottleneck: reduce SSM dimension (GSS)

        See the class SSKernel for the kernel constructor which accepts kernel_args.
        Relevant options that are worth considering and tuning include
        "mode" + "measure", "dt_min", "dt_max", "lr"

        Other options are all experimental and should not need to be configured
        """
        super().__init__()
        if verbose:
            log.info(f"Constructing S4 (H, N, L) = ({d_model}, {d_state}, {l_max})")

        self.d_model = d_model
        self.H = d_model
        self.N = d_state
        self.L = l_max
        self.bidirectional = bidirectional
        self.channels = channels
        self.transposed = transposed

        self.gate = gate
        self.bottleneck = bottleneck

        if bottleneck is not None:
            self.H = self.H // bottleneck
            self.input_linear = LinearActivation(
                self.d_model,
                self.H,
                transposed=not self.transposed,
                activation=activation,
                activate=True,
            )

        if gate is not None:
            self.input_gate = LinearActivation(
                self.d_model,
                self.d_model * gate,
                transposed=not self.transposed,
                activation=activation,
                activate=True,
            )
            self.output_gate = LinearActivation(
                self.d_model * gate,
                self.d_model,
                transposed=self.transposed,
                activation=None,
                activate=False,
            )

        # optional multiplicative modulation GLU-style
        # https://arxiv.org/abs/2002.05202
        self.hyper = hyper_act is not None
        if self.hyper:
            channels *= 2
            self.hyper_activation = Activation(hyper_act)

        self.D = nn.Parameter(torch.randn(channels, self.H))

        if self.bidirectional:
            channels *= 2

        # SSM Kernel
        self.kernel = SSKernel(
            self.H,
            N=self.N,
            L=self.L,
            channels=channels,
            verbose=verbose,
            **kernel_args,
        )

        # Pointwise
        self.activation = Activation(activation)
        dropout_fn = DropoutNd if tie_dropout else nn.Dropout
        self.dropout = dropout_fn(dropout) if dropout > 0.0 else nn.Identity()

        # position-wise output transform to mix features
        self.output_linear = LinearActivation(
            self.H * self.channels,
            self.d_model * (1 if self.gate is None else self.gate),
            transposed=self.transposed,
            activation=postact,
            activate=True,
        )
    def forward(self, u, state=None, rate=1.0, lengths=None, **kwargs):
        """Forward pass.

        u: (B H L) if self.transposed else (B L H)
        state: (H N) never needed unless you know what you're doing

        Returns: same shape as u
        """
        if not self.transposed:
            u = u.transpose(-1, -2)
        L = u.size(-1)

        # Mask out padding tokens
        if isinstance(lengths, int):
            if lengths != L:
                lengths = torch.tensor(lengths, dtype=torch.long, device=u.device)
            else:
                lengths = None
        if lengths is not None:
            if lengths.ndim == 0:
                lengths = lengths.unsqueeze(0)
            assert (
                isinstance(lengths, torch.Tensor)
                and lengths.ndim == 1
                and lengths.size(0) in [1, u.size(0)]
            ), print(f"l:{lengths.ndim}, {lengths.size()}, {u.size()}")
            mask = torch.where(
                torch.arange(L, device=lengths.device) < lengths[:, None, None],
                1.0,
                0.0,
            )
            u = u * mask

        if self.gate is not None:
            v = self.input_gate(u)
        if self.bottleneck is not None:
            u = self.input_linear(u)

        # Compute SS Kernel
        L_kernel = L if self.L is None else min(L, round(self.L / rate))
        k, k_state = self.kernel(
            L=L_kernel, rate=rate, state=state
        )  # (C H L) (B C H L)

        # Convolution
        if self.bidirectional:
            k0, k1 = rearrange(k, "(s c) h l -> s c h l", s=2)
            k = F.pad(k0, (0, L)) + F.pad(k1.flip(-1), (L, 0))
        k_f = torch.fft.rfft(k, n=L_kernel + L)  # (C H L)
        u_f = torch.fft.rfft(u, n=L_kernel + L)  # (B H L)
        y_f = contract("bhl,chl->bchl", u_f, k_f)
        y = torch.fft.irfft(y_f, n=L_kernel + L)[..., :L]  # (B C H L)

        # Compute D term in state space equation - essentially a skip connection
        y = y + contract("bhl,ch->bchl", u, self.D)

        # Compute state update
        if state is not None:
            assert (
                not self.bidirectional
            ), "Bidirectional not supported with state forwarding"
            y = y + k_state
            # next_state = self.kernel.forward_state(u, state)
        else:
            next_state = None

        # Optional hyper-network multiplication
        if self.hyper:
            y, yh = rearrange(y, "b (s c) h l -> s b c h l", s=2)
            y = self.hyper_activation(yh) * y

        # Reshape to flatten channels
        y = rearrange(y, "... c h l -> ... (c h) l")

        y = self.dropout(self.activation(y))

        if not self.transposed:
            y = y.transpose(-1, -2)

        y = self.output_linear(y)

        if self.gate is not None:
            if not self.transposed:
                v = v.transpose(-1, -2)
            y = self.output_gate(y * v)

        return y, next_state
    def setup_step(self, **kwargs):
        self.kernel._setup_step(**kwargs)
    def step(self, u, state, **kwargs):
        """Step one time step as a recurrent model.

        Intended to be used during validation.

        u: (B H)
        state: (B H N)
        Returns: output (B H), state (B H N)
        """
        assert not self.training

        y, next_state = self.kernel.step(u, state)  # (B C H)
        y = y + u.unsqueeze(-2) * self.D
        y = rearrange(y, "b c h -> b (c h)")
        y = self.activation(y)
        if self.transposed:
            y = self.output_linear(y.unsqueeze(-1)).squeeze(-1)
        else:
            y = self.output_linear(y)
        return y, next_state
    def default_state(self, *batch_shape, device=None):
        # kernel is not a SequenceModule so it doesn't need to adhere to same interface
        # the kernel will know the device of its own parameters
        return self.kernel.default_state(*batch_shape)
    @property
    def d_output(self):
        return self.d_model
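# Illustrative usage sketch (not part of the original module): running the S4
# block in convolutional mode on a random batch, then in recurrent mode one
# step at a time. Shapes follow the docstrings above; `transposed=True` means
# inputs are (batch, d_model, length).
def _example_s4():
    torch.manual_seed(0)
    model = S4(d_model=16, d_state=64, l_max=256, transposed=True)
    u = torch.randn(2, 16, 256)  # (B, H, L)
    y, _ = model(u)  # convolutional (training) mode
    assert y.shape == u.shape

    model.eval()  # step() asserts not self.training
    model.setup_step()
    state = model.default_state(2)  # recurrent state for batch size 2
    y_t, state = model.step(u[..., 0], state)  # one step: (B, H) -> (B, H)
    assert y_t.shape == (2, 16)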