espnet2.enh.layers.uses2_comp.ATFBlock
class espnet2.enh.layers.uses2_comp.ATFBlock(input_size, input_resolution=(130, 64), window_size=(10, 8), mlp_ratio=4, qkv_bias=True, qk_scale=None, dropout=0.0, att_dropout=0.0, drop_path=0.0, use_checkpoint=False, rnn_type='lstm', hidden_size=128, att_heads=4, activation='relu', bidirectional=True, norm_type='cLN', ch_mode='att', ch_att_dim=256, eps=1e-05, with_channel_modeling=True)
Bases: Module
Container module for a single Attentive Time-Frequency Block.
- Parameters:
- input_size (int) – dimension of the input feature.
- input_resolution (tuple) – frequency and time dimension of the input feature. Only used for efficient training. Should be close to the actual spectrum size (F, T) of training samples.
- window_size (tuple) – size of the Time-Frequency window in Swin-Transformer.
- mlp_ratio (int) – ratio of the MLP hidden size to embedding size in BasicLayer.
- qkv_bias (bool) – If True, add a learnable bias to query, key, value in BasicLayer.
- qk_scale (float) – Override default qk scale of head_dim ** -0.5 in BasicLayer if set.
- dropout (float) – dropout ratio. Default is 0.
- att_dropout (float) – attention dropout ratio in BasicLayer. Default is 0.
- drop_path (float) – drop-path ratio in BasicLayer. Default is 0.
- use_checkpoint (bool) – whether to use checkpointing to save memory.
- rnn_type (str) – type of the RNN cell in the improved Transformer layer.
- hidden_size (int) – hidden dimension of the RNN cell.
- att_heads (int) – number of attention heads in Transformer.
- activation (str) – non-linear activation function applied in each block.
- bidirectional (bool) – whether the RNN layers are bidirectional.
- norm_type (str) – normalization type in the improved Transformer layer.
- ch_mode (str) – mode of channel modeling. Select from “att”, “tac”, and “att_tac”.
- ch_att_dim (int) – dimension of the channel attention.
- eps (float) – epsilon for layer normalization.
- with_channel_modeling (bool) – whether to use channel attention.
forward(input, ref_channel=None, mem_size=20)
Forward pass of the ATFBlock.
- Parameters:
- input (torch.Tensor) – feature sequence (batch, C, N, freq, time)
- ref_channel (None or int) – index of the reference channel.
- mem_size (int) – length of the memory tokens
- Returns: output sequence (batch, C, N, freq, time)
- Return type: torch.Tensor
freq_path_process(x)
Process the feature sequence along the frequency path.
pad_to_window_multiples(input, window_size)
Pad the input feature to multiples of the window size.
- Parameters:
- input (torch.Tensor) – input feature (…, freq, time)
- window_size (tuple) – size of the window (H, W).
- Returns: padded input feature (…, n * H, m * W)
- Return type: torch.Tensor
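The padding rule can be sketched in plain Python: each of the last two dimensions is rounded up to the nearest multiple of the corresponding window dimension. The helper name `padded_shape` is illustrative and not part of espnet2:

```python
import math

def padded_shape(freq, time, window_size):
    """Compute the (freq, time) sizes after pad_to_window_multiples.

    Each dimension is rounded up to the nearest multiple of the
    corresponding window dimension (H for freq, W for time).
    """
    H, W = window_size
    n = math.ceil(freq / H)   # number of windows along frequency
    m = math.ceil(time / W)   # number of windows along time
    return n * H, m * W

print(padded_shape(130, 64, (10, 8)))  # (130, 64): already multiples, no padding
print(padded_shape(133, 70, (10, 8)))  # (140, 72): padded up to the next multiples
```

This rounding is what makes the feature map divisible into non-overlapping Time-Frequency windows for the Swin-Transformer layers.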
time_freq_process(x)
Process the feature sequence jointly along the time and frequency paths.
time_path_process(x)
Process the feature sequence along the time path.