Source code for espnet2.asr.state_spaces.block

# This code is derived from https://github.com/HazyResearch/state-spaces

"""Implements a full residual block around a black box layer.

Configurable options include:
normalization position: prenorm or postnorm
normalization type: batchnorm, layernorm etc.
subsampling/pooling
residual options: feedforward, residual, affine scalars, depth-dependent scaling, etc.
"""

from functools import partial

from torch import nn

import espnet2.asr.state_spaces.utils as utils
from espnet2.asr.state_spaces import registry
from espnet2.asr.state_spaces.base import SequenceModule
from espnet2.asr.state_spaces.components import (
    DropoutNd,
    Normalization,
    StochasticDepth,
)
from espnet2.asr.state_spaces.pool import registry as pool_registry
from espnet2.asr.state_spaces.residual import registry as residual_registry


class SequenceResidualBlock(SequenceModule):
    """Residual block wrapper for black box layer.

    The SequenceResidualBlock class implements a generic
    (batch, length, d_input) -> (batch, length, d_input) transformation.

    Args:
        d_input: Input feature dimension
        i_layer: Layer index, only needs to be passed into certain residuals like Decay
        prenorm: Apply normalization before (prenorm) or after (postnorm) the black box layer
        dropout: Dropout for black box module
        tie_dropout: Tie dropout mask across sequence like nn.Dropout1d/nn.Dropout2d
        transposed: Transpose inputs so each layer receives (batch, dim, length)
        layer: Config for black box module
        residual: Config for residual function
        norm: Config for normalization layer
        pool: Config for pooling layer per stage
        drop_path: Drop ratio for stochastic depth
    """

    def __init__(
        self,
        d_input,
        i_layer=None,
        prenorm=True,
        dropout=0.0,
        tie_dropout=False,
        transposed=False,
        layer=None,
        residual=None,
        norm=None,
        pool=None,
        drop_path=0.0,
    ):
        super().__init__()

        self.i_layer = i_layer
        self.d_input = d_input

        # Black box layer
        # self.layer = utils.instantiate(registry.layer, layer, d_input)
        if layer is None:
            layer = {}
        self.layer = utils.instantiate(registry.layer, layer, d_input)
        self.prenorm = prenorm
        self.transposed = transposed

        # Residual
        # d_residual is the output dimension after residual
        if residual is None:
            self.residual = None
            self.d_residual = self.layer.d_output
        else:
            self.residual = utils.instantiate(
                residual_registry, residual, i_layer, d_input, self.layer.d_output
            )
            self.d_residual = self.residual.d_output

        # Normalization
        d_norm = d_input if self.prenorm else self.d_residual
        # We don't use config to directly instantiate
        # since Normalization has some special cases
        if norm is None:
            self.norm = None
        elif isinstance(norm, str):
            self.norm = Normalization(d_norm, transposed=self.transposed, _name_=norm)
        else:
            self.norm = Normalization(d_norm, transposed=self.transposed, **norm)

        # Pool
        self.pool = utils.instantiate(
            pool_registry, pool, self.d_residual, transposed=self.transposed
        )

        # Dropout
        dropout_cls = (
            partial(DropoutNd, transposed=self.transposed)
            if tie_dropout
            else nn.Dropout
        )
        self.drop = dropout_cls(dropout) if dropout > 0.0 else nn.Identity()

        # Stochastic depth
        self.drop_path = (
            StochasticDepth(drop_path, mode="row") if drop_path > 0.0 else nn.Identity()
        )

    @property
    def d_output(self):
        return self.pool.d_output if self.pool is not None else self.d_residual

    @property
    def d_state(self):
        return self.layer.d_state

    @property
    def state_to_tensor(self):
        return self.layer.state_to_tensor

    def default_state(self, *args, **kwargs):
        return self.layer.default_state(*args, **kwargs)

    def forward(self, x, state=None, **kwargs):
        y = x

        # Pre-norm
        if self.norm is not None and self.prenorm:
            y = self.norm(y)

        # Black box layer
        y, state = self.layer(y, state=state, **kwargs)

        # Residual
        if self.residual is not None:
            y = self.residual(x, self.drop_path(self.drop(y)), self.transposed)

        # Post-norm
        if self.norm is not None and not self.prenorm:
            y = self.norm(y)

        # Pool
        if self.pool is not None:
            y = self.pool(y)

        return y, state

    def step(self, x, state, **kwargs):
        y = x

        # Pre-norm
        if self.norm is not None and self.prenorm:
            y = self.norm.step(y)

        # Black box layer
        y, state = self.layer.step(y, state, **kwargs)

        # Residual
        if self.residual is not None:
            # NOTE: this would not work with a concat residual function (catformer)
            y = self.residual(x, y, transposed=False)

        # Post-norm
        if self.norm is not None and not self.prenorm:
            y = self.norm.step(y)

        # Pool
        if self.pool is not None:
            y = self.pool(y)

        return y, state
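

# ----------------------------------------------------------------------------
# Illustrative usage sketch (not part of the original module). It shows how a
# SequenceResidualBlock might be constructed and driven in whole-sequence mode
# (forward) and in stepwise/recurrent mode (default_state + step). The registry
# keys used below ("s4" for the black box layer, "R" for a plain residual
# connection, "layer" for LayerNorm) are assumptions; substitute whatever names
# are registered in your installation.
# ----------------------------------------------------------------------------
if __name__ == "__main__":
    import torch

    block = SequenceResidualBlock(
        d_input=64,
        i_layer=1,
        prenorm=True,
        dropout=0.1,
        layer={"_name_": "s4"},  # config for the wrapped black box layer (assumed key)
        residual="R",            # plain residual connection (assumed key)
        norm="layer",            # LayerNorm; a dict config is also accepted
        drop_path=0.0,
    )

    x = torch.randn(8, 100, 64)  # (batch, length, d_input)
    y, _ = block(x)              # whole-sequence forward pass
    print(y.shape)               # (8, 100, block.d_output)

    # Stepwise (recurrent) inference: initialize a state, then feed one
    # timestep at a time.
    state = block.default_state(8, device=x.device)
    y_t, state = block.step(x[:, 0], state)
    print(y_t.shape)             # (8, block.d_output)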