# Source code for espnet2.enh.layers.skim

# An implementation of the SkiM model described in
# "SkiM: Skipping Memory LSTM for Low-Latency Real-Time Continuous Speech Separation".

import torch
import torch.nn as nn

from espnet2.enh.layers.dprnn import SingleRNN, merge_feature, split_feature
from espnet2.enh.layers.tcn import choose_norm

class MemLSTM(nn.Module):
    """The Mem-LSTM of SkiM.

    It post-processes the final hidden / cell states produced by a SegLSTM,
    running across the segment axis.

    Args:
        hidden_size: int, dimension of the hidden state.
        dropout: float, dropout ratio. Default is 0.
        bidirectional: bool, whether the LSTM layers are bidirectional.
            Default is False.
        mem_type: 'hc', 'h', 'c' or 'id'.
            It controls whether the hidden (or cell) state of
            SegLSTM will be processed by MemLSTM.
            In 'id' mode, both the hidden and cell states will
            be identically returned.
        norm_type: gLN, cLN. cLN is for causal implementation.
    """

    def __init__(
        self,
        hidden_size,
        dropout=0.0,
        bidirectional=False,
        mem_type="hc",
        norm_type="cLN",
    ):
        super().__init__()
        self.hidden_size = hidden_size
        self.bidirectional = bidirectional
        self.input_size = (int(bidirectional) + 1) * hidden_size
        self.mem_type = mem_type

        supported = ["hc", "h", "c", "id"]
        assert (
            mem_type in supported
        ), f"only support 'hc', 'h', 'c' and 'id', current type: {mem_type}"

        # Build the "h" branch first and the "c" branch second so that the
        # sub-module registration order is h_net, h_norm, c_net, c_norm.
        for tag in ("h", "c"):
            if mem_type not in ("hc", tag):
                continue
            setattr(
                self,
                f"{tag}_net",
                SingleRNN(
                    "LSTM",
                    input_size=self.input_size,
                    hidden_size=self.hidden_size,
                    dropout=dropout,
                    bidirectional=bidirectional,
                ),
            )
            setattr(
                self,
                f"{tag}_norm",
                choose_norm(
                    norm_type=norm_type, channel_size=self.input_size, shape="BTD"
                ),
            )

    def extra_repr(self) -> str:
        """Return the extra fields shown in the module's repr."""
        return f"Mem_type: {self.mem_type}, bidirectional: {self.bidirectional}"

    @staticmethod
    def _pack(x, B, S, d, H):
        """Reshape a state from (d, B*S, H) to (B, S, d*H)."""
        return x.transpose(1, 0).contiguous().view(B, S, d * H)

    @staticmethod
    def _unpack(x, B, S, d, H):
        """Reshape a state from (B, S, d*H) back to (d, B*S, H)."""
        return x.view(B * S, d, H).transpose(1, 0).contiguous()

    def forward(self, hc, S):
        """Process the SegLSTM states of all segments at once.

        Args:
            hc: (h, c), tuple of hidden and cell states from SegLSTM.
                Shape of h and c: (d, B*S, H).
            S: number of segments in SegLSTM.

        Returns:
            Tuple (h, c), each of shape (d, B*S, H). When the module is not
            bidirectional, both states are delayed by one segment: segment 0
            receives zeros and segment s receives the state of segment s-1.
        """
        h, c = hc
        d, BS, H = h.shape
        B = BS // S
        mode = self.mem_type

        if mode == "id":
            ret_val = hc
        else:
            h = self._pack(h, B, S, d, H)  # B, S, dH
            c = self._pack(c, B, S, d, H)  # B, S, dH
            if mode == "hc":
                h = h + self.h_norm(self.h_net(h)[0])
                c = c + self.c_norm(self.c_net(c)[0])
            elif mode == "h":
                h = h + self.h_norm(self.h_net(h)[0])
                c = torch.zeros_like(c)
            elif mode == "c":
                h = torch.zeros_like(h)
                c = c + self.c_norm(self.c_net(c)[0])
            ret_val = (
                self._unpack(h, B, S, d, H),
                self._unpack(c, B, S, d, H),
            )

        if self.bidirectional:
            return ret_val

        # for causal setup: a segment may only see states of earlier segments
        delayed = []
        for x in ret_val:
            x = self._pack(x, B, S, d, H)
            x = torch.cat([torch.zeros_like(x[:, :1, :]), x[:, :-1, :]], dim=1)
            delayed.append(self._unpack(x, B, S, d, H))
        return tuple(delayed)

    def forward_one_step(self, hc, state):
        """Process the states of a single segment (streaming use).

        Args:
            hc: (h, c), tuple of hidden and cell states, each (d, B, H).
            state: two-element list holding the recurrent states of the
                h-branch and c-branch networks; updated in place.

        Returns:
            Tuple (hc, state) with the processed states and the updated list.
            In 'id' mode both arguments are returned untouched.
        """
        if self.mem_type == "id":
            return hc, state

        h, c = hc
        d, B, H = h.shape
        h = self._pack(h, B, 1, d, H)  # B, 1, dH
        c = self._pack(c, B, 1, d, H)  # B, 1, dH

        mode = self.mem_type
        if mode == "hc":
            h_res, state[0] = self.h_net(h, state[0])
            h = h + self.h_norm(h_res)
            c_res, state[1] = self.c_net(c, state[1])
            c = c + self.c_norm(c_res)
        elif mode == "h":
            h_res, state[0] = self.h_net(h, state[0])
            h = h + self.h_norm(h_res)
            c = torch.zeros_like(c)
        elif mode == "c":
            h = torch.zeros_like(h)
            c_res, state[1] = self.c_net(c, state[1])
            c = c + self.c_norm(c_res)

        h = h.transpose(1, 0).contiguous()
        c = c.transpose(1, 0).contiguous()
        return (h, c), state
class SegLSTM(nn.Module):
    """The Seg-LSTM of SkiM.

    A single-layer LSTM followed by dropout, a linear projection back to the
    input dimension, a normalization layer and a residual connection.

    Args:
        input_size: int, dimension of the input feature.
            The input should have shape (batch, seq_len, input_size).
        hidden_size: int, dimension of the hidden state.
        dropout: float, dropout ratio. Default is 0.
        bidirectional: bool, whether the LSTM layers are bidirectional.
            Default is False.
        norm_type: gLN, cLN. cLN is for causal implementation.
    """

    def __init__(
        self, input_size, hidden_size, dropout=0.0, bidirectional=False, norm_type="cLN"
    ):
        super().__init__()

        self.input_size = input_size
        self.hidden_size = hidden_size
        self.num_direction = int(bidirectional) + 1

        self.lstm = nn.LSTM(
            input_size,
            hidden_size,
            1,
            batch_first=True,
            bidirectional=bidirectional,
        )
        self.dropout = nn.Dropout(p=dropout)
        self.proj = nn.Linear(hidden_size * self.num_direction, input_size)
        self.norm = choose_norm(
            norm_type=norm_type, channel_size=input_size, shape="BTD"
        )

    def forward(self, input, hc):
        """Run the LSTM over one batch of segments.

        Args:
            input: tensor of shape (B, T, H).
            hc: tuple (h, c) of initial LSTM states, or None to start from
                zero states of shape (num_direction, B, hidden_size).

        Returns:
            Tuple (output, (h, c)): output has the same shape as ``input``;
            (h, c) are the final states returned by the LSTM.
        """
        # input shape: B, T, H
        B, _, _ = input.shape

        if hc is None:
            # In the first input SkiM block, h and c are not available.
            # Allocate the zero states directly on the input's device and
            # dtype instead of creating them on the CPU and copying them over.
            state_shape = (self.num_direction, B, self.hidden_size)
            h = torch.zeros(state_shape, dtype=input.dtype, device=input.device)
            c = torch.zeros(state_shape, dtype=input.dtype, device=input.device)
        else:
            h, c = hc

        output, (h, c) = self.lstm(input, (h, c))
        output = self.dropout(output)
        # Project every frame back to the input dimension.
        output = self.proj(output.contiguous().view(-1, output.shape[2])).view(
            input.shape
        )
        # Residual connection around the normalized projection.
        output = input + self.norm(output)

        return output, (h, c)
class SkiM(nn.Module):
    """Skipping Memory Net.

    Args:
        input_size: int, dimension of the input feature.
            Input shape should be (batch, length, input_size).
        hidden_size: int, dimension of the hidden state.
        output_size: int, dimension of the output size.
        dropout: float, dropout ratio. Default is 0.
        num_blocks: number of basic SkiM blocks.
        segment_size: segmentation size for splitting long features.
        bidirectional: bool, whether the RNN layers are bidirectional.
        mem_type: 'hc', 'h', 'c', 'id' or None.
            It controls whether the hidden (or cell) state of SegLSTM
            will be processed by MemLSTM.
            In 'id' mode, both the hidden and cell states will
            be identically returned.
            When mem_type is None, the MemLSTM will be removed.
        norm_type: gLN, cLN. cLN is for causal implementation.
        seg_overlap: bool, whether the segmentation will reserve 50%
            overlap for adjacent segments. Default is False.
    """

    def __init__(
        self,
        input_size,
        hidden_size,
        output_size,
        dropout=0.0,
        num_blocks=2,
        segment_size=20,
        bidirectional=True,
        mem_type="hc",
        norm_type="gLN",
        seg_overlap=False,
    ):
        super().__init__()
        self.input_size = input_size
        self.output_size = output_size
        self.hidden_size = hidden_size
        self.segment_size = segment_size
        self.dropout = dropout
        self.num_blocks = num_blocks
        self.mem_type = mem_type
        self.norm_type = norm_type
        self.seg_overlap = seg_overlap

        supported = ["hc", "h", "c", "id", None]
        assert (
            mem_type in supported
        ), f"only support 'hc', 'h', 'c', 'id', and None, current type: {mem_type}"

        # One SegLSTM per block.
        self.seg_lstms = nn.ModuleList(
            [
                SegLSTM(
                    input_size=input_size,
                    hidden_size=hidden_size,
                    dropout=dropout,
                    bidirectional=bidirectional,
                    norm_type=norm_type,
                )
                for _ in range(num_blocks)
            ]
        )

        # One MemLSTM between each pair of consecutive blocks.
        if self.mem_type is not None:
            self.mem_lstms = nn.ModuleList(
                [
                    MemLSTM(
                        hidden_size,
                        dropout=dropout,
                        bidirectional=bidirectional,
                        mem_type=mem_type,
                        norm_type=norm_type,
                    )
                    for _ in range(num_blocks - 1)
                ]
            )

        self.output_fc = nn.Sequential(
            nn.PReLU(), nn.Conv1d(input_size, output_size, 1)
        )

    def forward(self, input):
        """Process a whole sequence segment by segment.

        Args:
            input: tensor of shape (B, T, D), where T is cut into S segments
                of K = segment_size frames.

        Returns:
            Tensor of shape (B, T, output_size).
        """
        # input shape: B, T (S*K), D
        B, T, D = input.shape

        if self.seg_overlap:
            segments, rest = split_feature(
                input.transpose(1, 2), segment_size=self.segment_size
            )  # B, D, K, S
            segments = segments.permute(0, 3, 2, 1).contiguous()  # B, S, K, D
        else:
            padded, rest = self._padfeature(input=input)
            segments = padded.view(B, -1, self.segment_size, D)  # B, S, K, D

        B, S, K, D = segments.shape
        assert K == self.segment_size

        output = segments.view(B * S, K, D).contiguous()  # BS, K, D
        hc = None
        last_block = self.num_blocks - 1
        for i in range(self.num_blocks):
            output, hc = self.seg_lstms[i](output, hc)  # BS, K, D
            # No MemLSTM follows the last block.
            if self.mem_type and i < last_block:
                hc = self.mem_lstms[i](hc, S)

        if self.seg_overlap:
            output = output.view(B, S, K, D).permute(0, 3, 2, 1)  # B, D, K, S
            output = merge_feature(output, rest)  # B, D, T
            return self.output_fc(output).transpose(1, 2)

        output = output.view(B, S * K, D)[:, :T, :]  # B, T, D
        return self.output_fc(output.transpose(1, 2)).transpose(1, 2)

    def _padfeature(self, input):
        """Zero-pad the time axis of a (B, T, D) tensor.

        Returns:
            Tuple (padded_input, rest), where rest is the number of frames
            appended. Note that rest equals segment_size (a whole extra
            segment of zeros) when T is already a multiple of segment_size.
        """
        _, T, _ = input.shape
        rest = self.segment_size - T % self.segment_size

        if rest > 0:
            input = torch.nn.functional.pad(input, (0, 0, 0, rest))
        return input, rest

    def forward_stream(self, input_frame, states):
        """Process a single frame, carrying the recurrent states along.

        Args:
            input_frame: tensor of shape (B, 1, N).
            states: dict returned by the previous call, or a falsy value
                (e.g. None) on the first frame. It holds "current_step",
                "seg_state" (one (h, c) per block) and "mem_state"
                (one [h_state, c_state] list per MemLSTM).

        Returns:
            Tuple (output, states) with the output frame and updated states.
        """
        # input_frame # B, 1, N
        B, _, _ = input_frame.shape

        def empty_seg_states():
            # Zero (h, c) pair for one SegLSTM, matching the frame's
            # device and dtype.
            return (
                input_frame.new_zeros(1, B, self.hidden_size),
                input_frame.new_zeros(1, B, self.hidden_size),
            )

        if not states:
            states = {
                "current_step": 0,
                "seg_state": [empty_seg_states() for _ in range(self.num_blocks)],
                "mem_state": [[None, None] for _ in range(self.num_blocks - 1)],
            }

        step = states["current_step"]
        if step and step % self.segment_size == 0:
            # A new segment starts: block 0 restarts from zeros and block
            # i + 1 starts from the MemLSTM-processed states of block i.
            fresh = [empty_seg_states() for _ in range(self.num_blocks)]
            for i in range(self.num_blocks - 1):
                fresh[i + 1], states["mem_state"][i] = self.mem_lstms[
                    i
                ].forward_one_step(states["seg_state"][i], states["mem_state"][i])
            states["seg_state"] = fresh

        output = input_frame
        for i in range(self.num_blocks):
            output, states["seg_state"][i] = self.seg_lstms[i](
                output, states["seg_state"][i]
            )

        states["current_step"] += 1

        output = self.output_fc(output.transpose(1, 2)).transpose(1, 2)
        return output, states
if __name__ == "__main__":
    # Self-check: for a causal configuration, feeding the model one frame at
    # a time through forward_stream must reproduce the segment-wise forward.
    torch.manual_seed(111)

    seq_len = 100

    model = SkiM(
        16,
        11,
        16,
        dropout=0.0,
        num_blocks=4,
        segment_size=20,
        bidirectional=False,
        mem_type="hc",
        norm_type="cLN",
        seg_overlap=False,
    )
    model.eval()

    feats = torch.randn(3, seq_len, 16)

    # Only values are compared, so no autograd graph is needed.
    with torch.no_grad():
        seg_output = model(feats)

        state = None
        for i in range(seq_len):
            input_frame = feats[:, i : i + 1, :]
            output, state = model.forward_stream(
                input_frame=input_frame, states=state
            )
            # torch.testing.assert_allclose is deprecated in favour of
            # assert_close. The tolerances are passed explicitly so that the
            # check is not tightened to assert_close's stricter float32
            # defaults.
            torch.testing.assert_close(
                output, seg_output[:, i : i + 1, :], rtol=1e-4, atol=1e-5
            )

    print("streaming ok")