espnet2.asr.encoder.beats_encoder.BeatsEncoder
class espnet2.asr.encoder.beats_encoder.BeatsEncoder(input_size: int, beats_ckpt_path: str | None = None, max_layer: int | None = None, downsampling_rate: int = 1, adapter_config: str = '', use_weighted_representation: bool = False, beats_config: BeatsConfig | None = None, specaug_config: Dict | None = None, add_positional_information: bool = False, max_positions: int | None = None)
Bases: AbsEncoder
BEATs: Audio Pre-Training with Acoustic Tokenizers.
(https://arxiv.org/abs/2212.09058)
- Parameters:
- beats_ckpt_path – Path to a pretrained BEATs checkpoint. If beats_config is provided and does not match the config stored in the checkpoint, loading may raise an error.
- max_layer – Propagate input through all layers for encoding if None. Otherwise, use layers up to max_layer.
- downsampling_rate – Downsampling rate for the encoder. Applied if > 1.
- adapter_config – Path to a config file for the wav2vec2 adapter.
- use_weighted_representation – Use weighted representations from max_layer if True. Weights are randomly initialized.
- beats_config – BeatsConfig object. If provided, we will try to override the config in the checkpoint. This can be used, for example, to change dropout rates when fine-tuning the model from a pretrained checkpoint.
- specaug_config – Dictionary containing parameters for SpecAugment. If provided, SpecAugment will be applied.
- add_positional_information – Add learned positional embeddings.
- max_positions – Maximum number of positions for positional embeddings. Required if add_positional_information is True.
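When use_weighted_representation is True, the encoder combines the outputs of layers up to max_layer with learned weights instead of using only the last layer. The sketch below assumes a common scheme (one scalar weight per layer, softmax-normalized, followed by a weighted sum over layers); the function name weighted_layer_sum and the (L, B, T, D) layout are hypothetical, and the actual ESPnet implementation may normalize differently.

```python
import numpy as np

def weighted_layer_sum(layer_outputs: np.ndarray, layer_logits: np.ndarray) -> np.ndarray:
    """Hypothetical helper: combine per-layer hidden states into one representation.

    layer_outputs: (L, B, T, D) stacked outputs of layers 1..max_layer.
    layer_logits:  (L,) unnormalized per-layer weights (randomly initialized,
                   then learned during fine-tuning in the assumed scheme).
    Returns:       (B, T, D) weighted sum over the layer axis.
    """
    # Softmax over layers so the weights are positive and sum to 1.
    shifted = layer_logits - layer_logits.max()
    weights = np.exp(shifted) / np.exp(shifted).sum()
    # Contract the layer axis: sum over l of weights[l] * layer_outputs[l].
    return np.tensordot(weights, layer_outputs, axes=(0, 0))

rng = np.random.default_rng(0)
hidden = rng.standard_normal((4, 2, 10, 8))  # 4 layers, batch 2, 10 frames, dim 8
logits = rng.standard_normal(4)              # random initialization
combined = weighted_layer_sum(hidden, logits)
print(combined.shape)  # (2, 10, 8)
```

With equal logits the weights are uniform, so the result reduces to the mean over layers; training moves the weights toward the most useful layers.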
extract_features(source: Tensor, padding_mask: Tensor | None = None, max_layer: int | None = None)
Extract features from raw audio.
forward(xs_pad: Tensor, ilens: Tensor, prev_states: Tensor | None = None) → Tuple[Tensor, Tensor, Tensor | None]
Wrapper for compatibility with ESPnet's AbsEncoder interface.
- Parameters:
- xs_pad – (B, T, D)
- ilens – (B,)
- prev_states – None
- Returns: audio_representation (B, T, D), output_lens (B,), masks (None)
- Return type: Tuple[Tensor, Tensor, Tensor | None]
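forward receives per-utterance lengths (ilens), whereas extract_features takes a padding_mask. A minimal sketch of converting between the two, assuming the convention that True marks padded positions; lengths_to_padding_mask and mask_to_lengths are hypothetical helper names, not ESPnet functions.

```python
import numpy as np

def lengths_to_padding_mask(ilens: np.ndarray, max_len: int) -> np.ndarray:
    """(B,) lengths -> (B, T) boolean mask, True where the position is padding."""
    positions = np.arange(max_len)[None, :]  # (1, T)
    return positions >= ilens[:, None]       # broadcasts to (B, T)

def mask_to_lengths(padding_mask: np.ndarray) -> np.ndarray:
    """(B, T) padding mask -> (B,) count of valid (non-padded) positions."""
    return (~padding_mask).sum(axis=-1)

ilens = np.array([5, 3])
mask = lengths_to_padding_mask(ilens, max_len=5)
print(mask.astype(int))
# [[0 0 0 0 0]
#  [0 0 0 1 1]]
print(mask_to_lengths(mask))  # [5 3]
```

Applying mask_to_lengths to a mask that has been reduced to the encoder's frame rate is one way to obtain output_lens.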
forward_padding_mask(features: Tensor, padding_mask: Tensor) → Tensor
Forward the padding mask. features: (B, T, C); padding_mask: (B, T).
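The padding mask is defined at input resolution, while the features have fewer frames, so the mask must be reduced to the feature frame rate. A minimal sketch of one way to do this, modeled on the original BEATs reference code (an assumption, not a confirmed description of this method): trim the remainder so the input length is a multiple of the frame count, group consecutive positions per frame, and mark a frame as padding only if every position in its group is padding. reduce_padding_mask is a hypothetical name.

```python
import numpy as np

def reduce_padding_mask(features: np.ndarray, padding_mask: np.ndarray) -> np.ndarray:
    """features: (B, T', C); padding_mask: (B, T), True = padded. Returns (B, T')."""
    num_frames = features.shape[1]
    extra = padding_mask.shape[1] % num_frames
    if extra > 0:
        # Trim trailing positions that do not fill a whole frame.
        padding_mask = padding_mask[:, :-extra]
    # Group consecutive input positions per output frame: (B, T', T // T').
    grouped = padding_mask.reshape(padding_mask.shape[0], num_frames, -1)
    # A frame counts as padding only if all of its positions are padding.
    return grouped.all(axis=-1)

feats = np.zeros((1, 3, 4))                  # 3 output frames
mask = np.array([[False] * 5 + [True] * 5])  # 10 input positions, last 5 padded
print(reduce_padding_mask(feats, mask))  # [[False False  True]]
```

Using all() rather than any() keeps a frame that mixes real and padded positions, so no real input is masked out.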
output_size() → int
preprocess(source: Tensor) → Tensor
Preprocess raw audio.
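Preprocessing converts raw waveforms into filterbank frames. Assuming Kaldi-style settings like those of the original BEATs release (16 kHz audio, 25 ms window, 10 ms shift, edges snipped), the number of frames per utterance follows from simple arithmetic. The helper num_fbank_frames and its defaults are illustrative assumptions, not parameters of this class.

```python
def num_fbank_frames(num_samples: int, sample_rate: int = 16000,
                     frame_length_ms: float = 25.0, frame_shift_ms: float = 10.0) -> int:
    """Frames produced by a snip-edges framing of num_samples samples (assumed settings)."""
    window = int(sample_rate * frame_length_ms / 1000)  # 400 samples at 16 kHz
    shift = int(sample_rate * frame_shift_ms / 1000)    # 160 samples at 16 kHz
    if num_samples < window:
        return 0  # too short to fill a single window
    return 1 + (num_samples - window) // shift

print(num_fbank_frames(16000 * 10))  # 998 frames for 10 s of audio
```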
reload_pretrained_parameters()
Initialization function for BEATs.
This must be called last in the initialization procedure. Initialization occurs in three steps:
1. ESPnet initializes all modules.
2. This function initializes the BEATs encoder, overriding step 1.
3. Optionally, if a pretrained checkpoint is available, its weights are loaded, overriding steps 2 and 1.
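The override order of the three steps can be illustrated with plain dictionaries standing in for parameter tensors. This is a toy sketch: the function initialize and the parameter names are hypothetical, and it assumes the checkpoint only overrides the keys it actually contains.

```python
from typing import Dict, Optional

def initialize(espnet_init: Dict[str, str], beats_init: Dict[str, str],
               checkpoint: Optional[Dict[str, str]] = None) -> Dict[str, str]:
    params = dict(espnet_init)     # Step 1: ESPnet initializes every module.
    params.update(beats_init)      # Step 2: BEATs init overrides step 1 for its parameters.
    if checkpoint is not None:
        params.update(checkpoint)  # Step 3: pretrained weights override steps 2 and 1.
    return params

espnet = {"encoder.w": "espnet", "layer_weights": "espnet", "decoder.w": "espnet"}
beats = {"encoder.w": "beats", "layer_weights": "beats"}
ckpt = {"encoder.w": "pretrained"}
print(initialize(espnet, beats, ckpt))
# {'encoder.w': 'pretrained', 'layer_weights': 'beats', 'decoder.w': 'espnet'}
```

Parameters missing from the checkpoint (here the hypothetical layer_weights) keep their step 2 values. This ordering also suggests why the function must run last: any generic initialization applied afterward would overwrite the loaded weights.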