espnet2.enh.layers.uses2_swin.ResSwinBlock
class espnet2.enh.layers.uses2_swin.ResSwinBlock(input_size, input_resolution=(130, 256), swin_block_depth=(4, 4, 4, 4), window_size=(10, 8), mlp_ratio=2, qkv_bias=True, qk_scale=None, dropout=0.0, att_dropout=0.0, drop_path=0.0, activation='relu', att_heads=4, use_checkpoint=False, ch_mode='att_tac', ch_att_dim=256, eps=1e-05, with_channel_modeling=True)
Bases: Module
Container module for a single Residual Shifted-Window Transformer Block.
- Parameters:
- input_size (int) – dimension of the input feature.
- input_resolution (tuple) – frequency and time dimension of the input feature. Only used for efficient training. Should be close to the actual spectrum size (F, T) of training samples.
- swin_block_depth (Tuple[int]) – depth of each Swin-Transformer block.
- window_size (tuple) – size of the Time-Frequency window in Swin-Transformer.
- mlp_ratio (int) – ratio of the MLP hidden size to embedding size in BasicLayer.
- qkv_bias (bool) – If True, add a learnable bias to query, key, value in BasicLayer.
- qk_scale (float) – if set, overrides the default qk scale of head_dim ** -0.5 in BasicLayer.
- dropout (float) – dropout ratio in BasicLayer. Default is 0.
- att_dropout (float) – attention dropout ratio in BasicLayer. Default is 0.
- drop_path (float) – drop-path ratio in BasicLayer. Default is 0.
- activation (str) – non-linear activation function applied in each block.
- att_heads (int) – number of attention heads.
- use_checkpoint (bool) – whether to use checkpointing to save memory.
- ch_mode (str) – mode of channel modeling. Select from “att”, “tac” and “att_tac”.
- ch_att_dim (int) – dimension of the channel attention.
- eps (float) – epsilon for layer normalization.
- with_channel_modeling (bool) – whether to use channel attention.
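A minimal construction sketch, assuming an installed ESPnet; input_size=64 and the other values shown are illustrative choices, with the remaining arguments left at the defaults in the signature above:

```python
import torch

from espnet2.enh.layers.uses2_swin import ResSwinBlock

# Illustrative configuration; the concrete values are assumptions for the example.
block = ResSwinBlock(
    input_size=64,                  # feature dimension N of the input
    input_resolution=(130, 256),    # expected (freq, time) resolution for training
    swin_block_depth=(4, 4, 4, 4),  # depth of each Swin-Transformer block
    window_size=(10, 8),            # (freq, time) attention window
    att_heads=4,
    ch_mode="att_tac",
)
```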
forward(input, ref_channel=None)
Forward.
- Parameters:
- input (torch.Tensor) – feature sequence (batch, C, N, freq, time)
- ref_channel (None or int) – index of the reference channel.
- Returns: output sequence (batch, C, N, freq, time)
- Return type: torch.Tensor
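Continuing the sketch above, a forward call with a dummy multi-channel feature; the shapes are assumptions chosen to match input_resolution:

```python
# (batch, C, N, freq, time) = (2, 4, 64, 130, 256)
x = torch.randn(2, 4, 64, 130, 256)
y = block(x, ref_channel=0)
# y: (batch, C, N, freq, time), per the documented return shape
```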
pad_to_window_multiples(input, window_size)
Pad the input feature to multiples of the window size.
- Parameters:
- input (torch.Tensor) – input feature (batch, C, N, freq, time)
- window_size (tuple) – size of the window (H, W).
- Returns: padded input feature (batch, C, N, n * H, m * W), where n * H and m * W are the smallest multiples of H and W covering freq and time.
- Return type: torch.Tensor
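The method's contract can be illustrated with a self-contained sketch of multiple-of-window padding; pad_to_window_multiples_sketch is a hypothetical stand-in written from the documented shapes, not the library's actual implementation:

```python
import math

import torch
import torch.nn.functional as F


def pad_to_window_multiples_sketch(x: torch.Tensor, window_size: tuple) -> torch.Tensor:
    """Zero-pad the last two dims of (batch, C, N, freq, time) to multiples of (H, W)."""
    H, W = window_size
    freq, time = x.shape[-2], x.shape[-1]
    pad_f = math.ceil(freq / H) * H - freq  # freq -> n * H
    pad_t = math.ceil(time / W) * W - time  # time -> m * W
    # F.pad orders pads from the last dim inward: (time_left, time_right, freq_left, freq_right)
    return F.pad(x, (0, pad_t, 0, pad_f))


padded = pad_to_window_multiples_sketch(torch.randn(2, 4, 64, 129, 250), (10, 8))
print(padded.shape)  # torch.Size([2, 4, 64, 130, 256])
```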