espnet2.enh.layers.swin_transformer.WindowAttention
class espnet2.enh.layers.swin_transformer.WindowAttention(dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0.0, proj_drop=0.0)
Bases: Module
Window-based multi-head self-attention (W-MSA) with relative position bias.
It supports both shifted and non-shifted windows.
- Parameters:
- dim (int) – Number of input channels.
- window_size (tuple[int]) – The height and width of the window.
- num_heads (int) – Number of attention heads.
- qkv_bias (bool, optional) – If True, add a learnable bias to the query, key, and value projections.
- qk_scale (float | None, optional) – If not None, override the default qk scale of head_dim ** -0.5.
- attn_drop (float, optional) – Dropout ratio of the attention weights.
- proj_drop (float, optional) – Dropout ratio of the output.
Initializes internal Module state, shared by both nn.Module and ScriptModule.
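A minimal construction sketch; the channel count, window size, and number of heads below are illustrative values, not defaults of this module:

```python
from espnet2.enh.layers.swin_transformer import WindowAttention

# W-MSA over 7x7 windows of 96-dim features with 3 attention heads
# (96, (7, 7), and 3 are illustrative values).
attn = WindowAttention(
    dim=96,
    window_size=(7, 7),
    num_heads=3,
    qkv_bias=True,   # learnable bias on query/key/value
    attn_drop=0.0,   # dropout on the attention weights
    proj_drop=0.0,   # dropout on the output
)
```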
extra_repr() → str
Set the extra representation of the module.
To print customized extra information, you should re-implement this method in your own modules. Both single-line and multi-line strings are acceptable.
forward(x, mask=None)
WindowAttention Forward.
- Parameters:
- x – Input features with shape (num_windows*B, N, C).
- mask – (0/-inf) attention mask with shape (num_windows, Wh*Ww, Wh*Ww), or None.
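A usage sketch of the forward pass, reusing the `attn` module from the construction example above; the output is assumed to keep the input shape, as in the reference Swin Transformer implementation:

```python
import torch

num_windows, N, C = 8, 7 * 7, 96       # N = Wh*Ww tokens per window; values are illustrative
x = torch.randn(num_windows, N, C)     # (num_windows*B, N, C) with B = 1

out = attn(x)                          # non-shifted windows: no mask needed
print(out.shape)                       # expected: torch.Size([8, 49, 96])

# Shifted-window case: pass a (num_windows, Wh*Ww, Wh*Ww) mask of 0 / -inf values.
mask = torch.zeros(num_windows, N, N)  # all zeros == no positions blocked
out_shifted = attn(x, mask=mask)
```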
get_relative_position_index(H, W)
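The docstring attaches no description to this method. The sketch below shows the standard Swin Transformer way of building the relative position index for an H x W window; it is assumed, not verified, that the ESPnet method follows the same recipe:

```python
import torch

def relative_position_index(H: int, W: int) -> torch.Tensor:
    """Standard Swin-style relative position index for an H x W window (sketch)."""
    coords = torch.stack(
        torch.meshgrid(torch.arange(H), torch.arange(W), indexing="ij")
    )                                                         # (2, H, W)
    coords_flat = torch.flatten(coords, 1)                    # (2, H*W)
    rel = coords_flat[:, :, None] - coords_flat[:, None, :]   # (2, H*W, H*W)
    rel = rel.permute(1, 2, 0).contiguous()                   # (H*W, H*W, 2)
    rel[:, :, 0] += H - 1                                     # shift row offsets to start at 0
    rel[:, :, 1] += W - 1                                     # shift column offsets to start at 0
    rel[:, :, 0] *= 2 * W - 1                                 # fold (row, col) into a single index
    return rel.sum(-1)                                        # (H*W, H*W)
```

Each entry of the returned (H*W, H*W) tensor indexes into a learnable relative position bias table with (2H-1)*(2W-1) rows per head.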