espnet2.asr.state_spaces.base.SequenceModule
class espnet2.asr.state_spaces.base.SequenceModule(*args, **kwargs)
Bases: Module
Abstract sequence model class.
All models must adhere to this interface.
A SequenceModule is generally a model that transforms an input of shape (n_batch, l_sequence, d_model) into an output of shape (n_batch, l_sequence, d_output).
REQUIRED methods and attributes:
forward, d_model, d_output: control the standard forward pass, a sequence-to-sequence transformation.
__init__ should also satisfy the following interface (see SequenceIdentity for an example):
def __init__(self, d_model, transposed=False, **kwargs)
OPTIONAL methods:
default_state, step: allow stepping the model recurrently with a hidden state.
state_to_tensor, d_state: allow decoding from the hidden state.
Initializes internal Module state, shared by both nn.Module and ScriptModule.
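The REQUIRED part of the interface can be sketched as follows. This is a minimal, framework-free illustration, not ESPnet code. `SequenceModuleSketch` and `IdentitySketch` are hypothetical names, the real base class derives from `torch.nn.Module` rather than `abc.ABC`, and nested Python lists stand in for tensors so the snippet runs without PyTorch. `IdentitySketch` is only loosely modeled on the SequenceIdentity example named in the docstring; its actual source is not reproduced here.

```python
from abc import ABC, abstractmethod


class SequenceModuleSketch(ABC):
    """Hypothetical stand-in for the documented interface (not the ESPnet class).

    Tensors are modeled as nested Python lists so the sketch runs without PyTorch.
    """

    @property
    @abstractmethod
    def d_model(self):
        """Model (input) dimension: required."""

    @property
    @abstractmethod
    def d_output(self):
        """Output dimension: required."""

    @abstractmethod
    def forward(self, x, state=None, **kwargs):
        """Map (batch, length, d_model) to (batch, length, d_output); return (y, state)."""


class IdentitySketch(SequenceModuleSketch):
    """Hypothetical identity layer using the documented __init__ signature."""

    def __init__(self, d_model, transposed=False, **kwargs):
        self._d_model = d_model
        self.transposed = transposed  # stored only to mirror the documented signature

    @property
    def d_model(self):
        return self._d_model

    @property
    def d_output(self):
        return self._d_model  # identity: output width equals input width

    def forward(self, x, state=None, **kwargs):
        return x, state  # pass the input through unchanged and echo the state
```

Because the abstract members are declared with `abc.abstractmethod`, instantiating a class that leaves `forward`, `d_model`, or `d_output` unimplemented raises `TypeError`. This is one way to enforce the REQUIRED contract in a sketch; the actual base class may not enforce it this way.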
property d_model
Model dimension (generally same as input dimension).
This attribute is required for all SequenceModule instantiations. It is used by the rest of the pipeline (e.g. model backbone, encoder) to track the internal shapes of the full model.
property d_output
Output dimension of model.
This attribute is required for all SequenceModule instantiations. It is used by the rest of the pipeline (e.g. model backbone, decoder) to track the internal shapes of the full model.
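To illustrate how d_model and d_output let a pipeline track shapes, the sketch below checks that a stack of layers lines up (each layer's d_model must equal the previous layer's d_output) and reports the final width. `LayerSpec` and `chain_output_dim` are hypothetical helpers, not part of ESPnet; a real backbone may perform this bookkeeping differently.

```python
from dataclasses import dataclass


@dataclass
class LayerSpec:
    """Hypothetical layer exposing only the two required dimension attributes."""
    d_model: int
    d_output: int


def chain_output_dim(layers):
    """Return the output width of a layer stack after checking that shapes line up."""
    if not layers:
        raise ValueError("empty layer stack")
    for prev, nxt in zip(layers, layers[1:]):
        # The next layer consumes what the previous layer produces.
        if prev.d_output != nxt.d_model:
            raise ValueError(f"shape mismatch: d_output={prev.d_output} -> d_model={nxt.d_model}")
    return layers[-1].d_output
```

For example, `chain_output_dim([LayerSpec(80, 256), LayerSpec(256, 256), LayerSpec(256, 10)])` returns `10`, while a stack whose widths do not chain raises `ValueError`.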
property d_state
Return the dimension of the output of self.state_to_tensor.
default_state(*batch_shape, device=None)
Create an initial state for a batch of inputs.
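A minimal sketch of default_state, assuming the initial state is all zeros with shape (*batch_shape, d_state). Both the zero initialization and the trailing d_state axis are assumptions for illustration; `ZeroStateLayer` and `zeros` are hypothetical names, and nested lists stand in for `torch.zeros`.

```python
def zeros(shape):
    """Nested-list stand-in for a zero tensor of the given shape."""
    if not shape:
        return 0.0
    return [zeros(shape[1:]) for _ in range(shape[0])]


class ZeroStateLayer:
    """Hypothetical layer whose hidden state has width d_state."""

    def __init__(self, d_state):
        self.d_state = d_state

    def default_state(self, *batch_shape, device=None):
        # device is accepted to match the documented signature; lists have no device.
        return zeros((*batch_shape, self.d_state))
```

Accepting `*batch_shape` lets callers request a state for no batch dimension, a single batch dimension, or several leading dimensions.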
forward(x, state=None, **kwargs)
Forward pass.
A sequence-to-sequence transformation with an optional state.
Generally, this should map a tensor of shape (batch, length, self.d_model) to a tensor of shape (batch, length, self.d_output).
Additionally, it returns a “state”, which can be any additional information. For example, RNN and SSM layers may return their hidden state, while some types of transformer layers (e.g. Transformer-XL) may want to pass a state as well.
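The returned state is what allows a long sequence to be processed in chunks. The sketch below uses a hypothetical cumulative-sum layer, `RunningSum`, chosen only because its hidden state is easy to follow: it maps (batch, length, d_model) to the same shape and returns the final running sum as its state. Nested lists stand in for tensors; this is not an ESPnet layer.

```python
class RunningSum:
    """Hypothetical layer: h_t = h_{t-1} + x_t and y_t = h_t (a cumulative sum)."""

    def __init__(self, d_model):
        self.d_model = d_model
        self.d_output = d_model  # a cumulative sum preserves the feature width

    def forward(self, x, state=None, **kwargs):
        """Map x of shape (batch, length, d_model) to y of the same shape.

        state: optional carry-in of shape (batch, d_model); defaults to zeros.
        Returns (y, final_state) so that the next chunk can resume from here.
        """
        if state is None:
            state = [[0.0] * self.d_model for _ in x]
        y, final_state = [], []
        for x_b, h_b in zip(x, state):
            h = list(h_b)  # copy so the caller's state is not mutated
            out_b = []
            for x_t in x_b:
                h = [h_i + x_i for h_i, x_i in zip(h, x_t)]
                out_b.append(h)
            y.append(out_b)
            final_state.append(h)
        return y, final_state
```

Running the first two time steps, then passing the returned state into a second call on the remaining step, yields the same outputs as a single call on the whole sequence.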
property state_to_tensor
Return a function mapping a state to a single tensor.
This method should be implemented if one wants to use the hidden state instead of the output sequence for the final prediction. It is currently used only with the StateDecoder.
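A sketch of state_to_tensor and d_state for a layer whose state is not a single tensor. It assumes an LSTM-style `(h, c)` pair and keeps only `h`. `TupleStateLayer` and `decode_from_state` are hypothetical; the readout only gestures at what a state-based decoder such as the StateDecoder might do, and does not reproduce its implementation. Nested lists stand in for tensors.

```python
class TupleStateLayer:
    """Hypothetical layer whose recurrent state is an (h, c) pair, as in an LSTM."""

    def __init__(self, d_hidden):
        self.d_hidden = d_hidden

    @property
    def d_state(self):
        # Width of the tensor produced by state_to_tensor.
        return self.d_hidden

    @property
    def state_to_tensor(self):
        # A property that returns a function, matching the documented interface.
        return lambda state: state[0]  # keep h, drop the cell state c


def decode_from_state(layer, state, weights):
    """Hypothetical linear readout: one scalar per batch row from the state tensor."""
    features = layer.state_to_tensor(state)  # shape (batch, d_state)
    if any(len(row) != layer.d_state for row in features):
        raise ValueError("state tensor width does not match d_state")
    return [sum(f * w for f, w in zip(row, weights)) for row in features]
```

Because d_state is known without running the model, a decoder can size its readout weights before seeing any state.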
step(x, state=None, **kwargs)
Step the model recurrently for one step of the input sequence.
For example, this should correspond to unrolling an RNN for one step. If the forward pass has signature (B, L, H1) -> (B, L, H2), this method should generally have signature (B, H1) -> (B, H2) with an optional recurrent state.
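The relationship between step and forward can be checked directly: feeding a sequence one time step at a time through step should reproduce forward. The sketch below assumes a hypothetical cumulative-sum layer, `RunningSumRNN`, whose step maps (B, H) to (B, H) and whose forward maps (B, L, H) to (B, L, H). `unroll_with_step` is also a hypothetical helper, and nested lists stand in for tensors.

```python
class RunningSumRNN:
    """Hypothetical recurrent layer: h_t = h_{t-1} + x_t and y_t = h_t."""

    def __init__(self, d_model):
        self.d_model = self.d_output = d_model

    def default_state(self, *batch_shape, device=None):
        batch = batch_shape[0]  # this sketch supports a single batch dimension
        return [[0.0] * self.d_model for _ in range(batch)]

    def step(self, x, state=None, **kwargs):
        """One time step: x is (B, H); returns y and the new state, both (B, H)."""
        if state is None:
            state = self.default_state(len(x))
        new_state = [[h + v for h, v in zip(h_b, x_b)] for h_b, x_b in zip(state, x)]
        return new_state, new_state  # the output is the updated hidden state

    def forward(self, x, state=None, **kwargs):
        """Whole sequence: x is (B, L, H); returns y (B, L, H) and the final state."""
        if state is None:
            state = self.default_state(len(x))
        y, final_state = [], []
        for x_b, h in zip(x, state):
            out_b = []
            for x_t in x_b:
                h = [h_i + x_i for h_i, x_i in zip(h, x_t)]
                out_b.append(h)
            y.append(out_b)
            final_state.append(h)
        return y, final_state


def unroll_with_step(layer, x):
    """Feed x (B, L, H) through layer.step one time step at a time."""
    state = layer.default_state(len(x))
    outputs = []
    for t in range(len(x[0])):
        x_t = [x_b[t] for x_b in x]  # time step t for every batch element: (B, H)
        y_t, state = layer.step(x_t, state)
        outputs.append(y_t)
    # Re-stack the per-step outputs from (L, B, H) to (B, L, H).
    y = [[outputs[t][b] for t in range(len(outputs))] for b in range(len(x))]
    return y, state
```

For any input, `unroll_with_step(layer, x)` should equal `layer.forward(x)`. This equivalence is the property that makes recurrent stepping useful for streaming or autoregressive decoding.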