espnet2.enh.layers.uses2_swin.USES2_Swin
class espnet2.enh.layers.uses2_swin.USES2_Swin(input_size, output_size, bottleneck_size=64, num_blocks=3, num_spatial_blocks=2, swin_block_depth=(4, 4, 4, 4), input_resolution=(130, 256), window_size=(10, 8), mlp_ratio=4, qkv_bias=True, qk_scale=None, att_heads=4, dropout=0.0, att_dropout=0.0, drop_path=0.0, activation='relu', use_checkpoint=False, ch_mode='att_tac', ch_att_dim=256, eps=1e-05)
Bases: Module
Unconstrained Speech Enhancement and Separation v2 (USES2-Swin) Network.
Reference:
[1] W. Zhang, J.-w. Jung, and Y. Qian, “Improving Design of Input Condition Invariant Speech Enhancement,” in Proc. ICASSP, 2024.
[2] W. Zhang, K. Saijo, Z.-Q. Wang, S. Watanabe, and Y. Qian, “Toward Universal Speech Enhancement for Diverse Input Conditions,” in Proc. ASRU, 2023.
- Parameters:
- input_size (int) – dimension of the input feature.
- output_size (int) – dimension of the output.
- bottleneck_size (int) – dimension of the bottleneck feature. Must be a multiple of att_heads.
- num_blocks (int) – number of ResSwinBlock blocks.
- num_spatial_blocks (int) – number of ResSwinBlock blocks with channel modeling.
- swin_block_depth (Tuple[int]) – depth of each ResSwinBlock.
- input_resolution (tuple) – frequency and time dimension of the input feature. Only used for efficient training. Should be close to the actual spectrum size (F, T) of training samples.
- window_size (tuple) – size of the Time-Frequency window in Swin-Transformer.
- mlp_ratio (int) – ratio of the MLP hidden size to embedding size in BasicLayer.
- qkv_bias (bool) – If True, add a learnable bias to query, key, value in BasicLayer.
- qk_scale (float) – Override default qk scale of head_dim ** -0.5 in BasicLayer if set.
- att_heads (int) – number of attention heads in Transformer.
- dropout (float) – dropout ratio in BasicLayer. Default is 0.
- att_dropout (float) – attention dropout ratio in BasicLayer.
- drop_path (float) – drop-path ratio in BasicLayer.
- activation (str) – non-linear activation function applied in each block.
- use_checkpoint (bool) – whether to use checkpointing to save memory.
- ch_mode (str) – mode of channel modeling; select from “att”, “tac”, and “att_tac”.
- ch_att_dim (int) – dimension of the channel attention.
- eps (float) – epsilon for layer normalization.
Initializes internal Module state, shared by both nn.Module and ScriptModule.
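A minimal construction sketch based on the parameters documented above; the feature dimensions (input_size, output_size) are illustrative placeholders and should match the upstream encoder, while the remaining values mirror the listed defaults:

```python
from espnet2.enh.layers.uses2_swin import USES2_Swin

# Illustrative instantiation; input_size / output_size depend on the
# upstream feature extractor (values here are placeholders).
model = USES2_Swin(
    input_size=4,
    output_size=4,
    bottleneck_size=64,           # must be a multiple of att_heads
    att_heads=4,
    num_blocks=3,
    num_spatial_blocks=2,
    input_resolution=(130, 256),  # should be close to the (freq, time) size of training samples
    window_size=(10, 8),
    ch_mode="att_tac",
)
```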
forward(input, ref_channel=None)
USES2-Swin forward.
- Parameters:
- input (torch.Tensor) – input feature (batch, mics, input_size, freq, time)
- ref_channel (None or int) – index of the reference channel.
- Returns: output feature (batch, output_size, freq, time)
- Return type: output (torch.Tensor)
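A hedged usage sketch of the forward pass; tensor shapes follow the documented (batch, mics, input_size, freq, time) layout, with freq and time chosen here to match the default input_resolution purely for illustration:

```python
import torch
from espnet2.enh.layers.uses2_swin import USES2_Swin

model = USES2_Swin(input_size=4, output_size=4)  # illustrative feature dims

# Dummy multi-channel input: (batch, mics, input_size, freq, time).
x = torch.randn(2, 2, 4, 130, 256)
out = model(x, ref_channel=0)
print(out.shape)  # expected: (2, 4, 130, 256), i.e. (batch, output_size, freq, time)
```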