from collections import OrderedDict
from typing import Dict, List, Optional, Tuple, Union
import torch
from torch_complex.tensor import ComplexTensor
from espnet2.enh.layers.complex_utils import is_complex, new_complex_like
from espnet2.enh.layers.uses import USES
from espnet2.enh.separator.abs_separator import AbsSeparator
class USESSeparator(AbsSeparator):
    """Separator that wraps the USES network between a Conv2d encoder/decoder pair.

    Each complex time-frequency bin is projected to an embedding by
    ``post_encoder``, processed by ``USES``, and projected back to a
    (real, imag) pair by ``pre_decoder``.
    """

    def __init__(
        self,
        input_dim: int,
        num_spk: int = 2,
        enc_channels: int = 256,
        bottleneck_size: int = 64,
        num_blocks: int = 6,
        num_spatial_blocks: int = 3,
        ref_channel: Optional[int] = None,
        segment_size: int = 64,
        memory_size: int = 20,
        memory_types: int = 1,
        # Transformer-related arguments
        rnn_type: str = "lstm",
        bidirectional: bool = True,
        hidden_size: int = 128,
        att_heads: int = 4,
        dropout: float = 0.0,
        norm_type: str = "cLN",
        activation: str = "relu",
        ch_mode: Union[str, List[str]] = "att",
        ch_att_dim: int = 256,
        eps: float = 1e-5,
        additional: Optional[dict] = None,
    ):
        """Unconstrained Speech Enhancement and Separation (USES) Network.

        Reference:
            [1] W. Zhang, K. Saijo, Z.-Q., Wang, S. Watanabe, and Y. Qian,
            “Toward Universal Speech Enhancement for Diverse Input Conditions,”
            in Proc. ASRU, 2023.

        Args:
            input_dim (int): input feature dimension.
                Not used as the model is independent of the input size.
            num_spk (int): number of speakers.
            enc_channels (int): feature dimension after the Conv1D encoder.
            bottleneck_size (int): dimension of the bottleneck feature.
                Must be a multiple of `att_heads`.
            num_blocks (int): number of processing blocks.
            num_spatial_blocks (int): number of processing blocks with channel
                modeling.
            ref_channel (Optional[int]): index of the reference channel
                (used in channel modeling modules).
            segment_size (int): number of frames in each non-overlapping segment.
                This is used to segment long utterances into smaller chunks for
                efficient processing.
            memory_size (int): group size of global memory tokens.
                The basic use of memory tokens is to store the history information
                from previous segments.
                The memory tokens are updated by the output of the last block after
                processing each segment.
            memory_types (int): number of memory token groups.
                Each group corresponds to a different type of processing, i.e.,
                the first group is used for denoising without dereverberation,
                the second group is used for denoising with dereverberation.
            rnn_type: string, select from 'RNN', 'LSTM' and 'GRU'.
            bidirectional (bool): whether the inter-chunk RNN layers are
                bidirectional.
            hidden_size (int): dimension of the hidden state.
            att_heads (int): number of attention heads.
            dropout (float): dropout ratio. Default is 0.
            norm_type: type of normalization to use after each inter- or
                intra-chunk NN block.
            activation: the nonlinear activation function.
            ch_mode: str or list, mode of channel modeling.
                Select from "att" and "tac".
            ch_att_dim (int): dimension of the channel attention.
            eps (float): epsilon for layer normalization.
            additional (dict or None): extra keyword arguments forwarded to
                `USES`. They are merged on top of ``{"memory_types": memory_types}``
                and may therefore override `memory_types`.
                None (the default) is treated as an empty dict.
        """
        super().__init__()

        self._num_spk = num_spk
        self.enc_channels = enc_channels
        self.ref_channel = ref_channel

        # used to project each complex-valued time-frequency bin to an embedding
        self.post_encoder = torch.nn.Conv2d(2, enc_channels, (3, 3), padding=(1, 1))

        assert bottleneck_size % att_heads == 0, (bottleneck_size, att_heads)

        opt = {
            "memory_types": memory_types,
        }
        # arguments in `opt` can be updated at inference time to process different data
        if additional is not None:
            opt.update(additional)

        self.uses = USES(
            enc_channels,
            output_size=enc_channels * num_spk,
            bottleneck_size=bottleneck_size,
            num_blocks=num_blocks,
            num_spatial_blocks=num_spatial_blocks,
            segment_size=segment_size,
            memory_size=memory_size,
            **opt,
            # Transformer-specific arguments
            rnn_type=rnn_type,
            bidirectional=bidirectional,
            hidden_size=hidden_size,
            att_heads=att_heads,
            dropout=dropout,
            norm_type=norm_type,
            activation=activation,
            ch_mode=ch_mode,
            ch_att_dim=ch_att_dim,
            eps=eps,
        )

        # used to project each embedding back to the complex-valued time-frequency bin
        self.pre_decoder = torch.nn.ConvTranspose2d(
            enc_channels, 2, (3, 3), padding=(1, 1)
        )

    def _decode_specs(
        self,
        processed: torch.Tensor,
        input: Union[torch.Tensor, ComplexTensor],
        batch_size: int,
        num_freqs: int,
        num_frames: int,
    ) -> List[Union[torch.Tensor, ComplexTensor]]:
        """Project USES output embeddings back to per-speaker complex spectra.

        Args:
            processed (torch.Tensor): output of `self.uses`; it is reshaped to
                (batch_size * num_spk, enc_channels, num_freqs, num_frames).
            input (torch.Tensor or ComplexTensor): the original input of
                `forward`; only used to decide the complex type of the output.
            batch_size (int): batch size B.
            num_freqs (int): number of frequency bins F.
            num_frames (int): number of time frames T.

        Returns:
            List of `num_spk` complex spectra, each of shape (B, T, F).
        """
        processed = processed.reshape(
            batch_size * self.num_spk, self.enc_channels, num_freqs, num_frames
        )
        processed = self.pre_decoder(processed)
        # B, num_spk, 2, T, F
        specs = processed.reshape(
            batch_size, self.num_spk, 2, num_freqs, num_frames
        ).moveaxis(-1, -2)
        real, imag = specs[:, :, 0], specs[:, :, 1]
        # B, num_spk, T, F
        if is_complex(input):
            # keep the same complex representation as the input
            complex_specs = new_complex_like(input, (real, imag))
        else:
            # real-valued input with a trailing (real, imag) axis
            complex_specs = ComplexTensor(real, imag)
        return list(complex_specs.unbind(1))

    def forward(
        self,
        input: Union[torch.Tensor, ComplexTensor],
        ilens: torch.Tensor,
        additional: Optional[Dict] = None,
    ) -> Tuple[List[Union[torch.Tensor, ComplexTensor]], torch.Tensor, OrderedDict]:
        """Forward.

        Args:
            input (torch.Tensor or ComplexTensor): STFT spectrum [B, T, (C,) F (,2)]
                B is the batch size
                T is the number of time frames
                C is the number of microphone channels (optional)
                F is the number of frequency bins
                2 is real and imaginary parts (optional if input is a complex tensor)
            ilens (torch.Tensor): input lengths [Batch]
            additional (Dict or None): other data included in model
                "mode": one of ("no_dereverb", "dereverb", "both").
                    Defaults to "no_dereverb" when `additional` is None or has
                    no "mode" key.
                1. "no_dereverb": only use the first memory group for denoising
                    without dereverberation
                2. "dereverb": only use the second memory group for denoising
                    with dereverberation
                3. "both": use both memory groups for denoising with and without
                    dereverberation

        Returns:
            masked (List[Union(torch.Tensor, ComplexTensor)]): [(B, T, F), ...]
            ilens (torch.Tensor): (B,)
            others (OrderedDict): additional predicted data.
                Empty unless mode is "both", in which case it holds the
                dereverberated spectra:
                    'dereverb1': complex spectrum (Batch, Frames, Freq),
                    ...
                    'dereverbn': complex spectrum (Batch, Frames, Freq),

        Raises:
            ValueError: if the input has an unsupported number of dimensions
                or "mode" is not one of the supported values.
        """
        # B, 2, T, (C,) F
        if is_complex(input):
            feature = torch.stack([input.real, input.imag], dim=1)
        else:
            assert input.size(-1) == 2, input.shape
            feature = input.moveaxis(-1, 1)

        # B, C, 2, F, T
        if feature.ndim == 4:
            # single-channel input: insert a singleton channel axis
            feature = feature.moveaxis(-1, -2).unsqueeze(1)
        elif feature.ndim == 5:
            feature = feature.permute(0, 3, 1, 4, 2).contiguous()
        else:
            raise ValueError(f"Invalid input shape: {feature.shape}")

        batch_size, num_channels, num_ri, num_freqs, num_frames = feature.shape
        # fold the channel axis into the batch axis for the 2-D convolution
        feature = feature.reshape(-1, num_ri, num_freqs, num_frames)
        feature = self.post_encoder(feature)  # B*C, enc_channels, F, T
        feature = feature.reshape(
            batch_size, num_channels, -1, num_freqs, num_frames
        ).contiguous()

        others = OrderedDict()
        mode = (
            "no_dereverb"
            if additional is None
            else additional.get("mode", "no_dereverb")
        )

        # B, enc_channels * num_spk, F, T
        if mode == "no_dereverb":
            processed = self.uses(feature, ref_channel=self.ref_channel)
        elif mode == "dereverb":
            processed = self.uses(feature, ref_channel=self.ref_channel, mem_idx=1)
        elif mode == "both":
            # For training with multi-condition data
            # 1. denoised output without dereverberation
            processed = self.uses(feature, ref_channel=self.ref_channel, mem_idx=0)
            # 2. denoised output with dereverberation
            processed2 = self.uses(feature, ref_channel=self.ref_channel, mem_idx=1)
            dereverbed = self._decode_specs(
                processed2, input, batch_size, num_freqs, num_frames
            )
            for spk, spec in enumerate(dereverbed, start=1):
                others[f"dereverb{spk}"] = spec
        else:
            raise ValueError(mode)

        specs = self._decode_specs(
            processed, input, batch_size, num_freqs, num_frames
        )
        return specs, ilens, others

    @property
    def num_spk(self):
        """Number of speakers (output streams) the separator produces."""
        return self._num_spk