espnet2.enh.layers.dcunet.ArgsComplexMultiplicationWrapper
class espnet2.enh.layers.dcunet.ArgsComplexMultiplicationWrapper(module_cls, *args, **kwargs)
Bases: Module
Adapted from asteroid’s complex_nn.py, extended so that extra args/kwargs given to forward() are passed through to the wrapped modules.
Make a complex-valued module F from a real-valued module f by applying complex multiplication rules:
F(a + i b) = f1(a) - f2(b) + i (f1(b) + f2(a))
where f1, f2 are instances of f that do not share weights.
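The rule can be checked by hand with scalars. The following is a minimal sketch in plain Python, not the ESPnet code: it assumes the standard complex-multiplication convention (real part f1(a) - f2(b), imaginary part f1(b) + f2(a)) and uses two hypothetical bias-free linear maps, `f1` and `f2`, with independent weights `w1` and `w2`. Under that assumption the result should equal multiplying the input by the complex weight w1 + i w2.

```python
def complex_apply(f1, f2, z):
    """Apply real-valued maps f1, f2 to complex z via complex multiplication rules."""
    a, b = z.real, z.imag
    real = f1(a) - f2(b)
    imag = f1(b) + f2(a)
    return complex(real, imag)


# Hypothetical scalar "modules": bias-free linear maps that do not share weights.
w1, w2 = 2.0, -3.0


def f1(t):
    return w1 * t


def f2(t):
    return w2 * t


z = complex(1.5, 0.5)
out = complex_apply(f1, f2, z)

# For linear maps, the wrapper is the same as multiplying by a complex weight.
expected = complex(w1, w2) * z
print(out, expected)
```

For nonlinear `f1` and `f2` the equivalence to a single complex weight no longer holds; the wrapper simply applies the same combination rule to whatever the two real-valued modules return.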
- Parameters: module_cls (callable) – A class or function that returns a Torch module/functional. Constructor of f in the formula above. Called twice with *args, **kwargs to construct the real and imaginary component modules.
Initializes internal Module state, shared by both nn.Module and ScriptModule.
forward(x, *args, **kwargs)
Defines the computation performed at every call.
Should be overridden by all subclasses.
NOTE
Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.
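The behaviour described above can be illustrated with a small self-contained sketch. It is not the ESPnet implementation: `ComplexWrapperSketch` and `ScaledLinear` are hypothetical names introduced here, the combination rule is assumed to be standard complex multiplication, and `torch.complex` is used to build the output. The sketch shows extra keyword arguments being passed through to both submodules, and the difference between calling the instance and calling `forward()` directly when a hook is registered.

```python
import torch
from torch import nn


class ComplexWrapperSketch(nn.Module):
    """Hypothetical stand-in for the wrapper; not the ESPnet implementation."""

    def __init__(self, module_cls, *args, **kwargs):
        super().__init__()
        # module_cls is called twice, so the two submodules do not share weights.
        self.re_module = module_cls(*args, **kwargs)
        self.im_module = module_cls(*args, **kwargs)

    def forward(self, x, *args, **kwargs):
        # Extra args/kwargs are forwarded unchanged to both submodules.
        real = (self.re_module(x.real, *args, **kwargs)
                - self.im_module(x.imag, *args, **kwargs))
        imag = (self.re_module(x.imag, *args, **kwargs)
                + self.im_module(x.real, *args, **kwargs))
        return torch.complex(real, imag)


class ScaledLinear(nn.Module):
    """Hypothetical real-valued module whose forward takes an extra argument."""

    def __init__(self, in_features, out_features):
        super().__init__()
        self.linear = nn.Linear(in_features, out_features, bias=False)

    def forward(self, x, scale=1.0):
        return scale * self.linear(x)


torch.manual_seed(0)
layer = ComplexWrapperSketch(ScaledLinear, 4, 3)

# Record every time the forward hook fires.
calls = []
handle = layer.register_forward_hook(lambda mod, inp, out: calls.append("hook"))

x = torch.complex(torch.randn(2, 4), torch.randn(2, 4))
y = layer(x, scale=0.5)                 # runs forward() and the registered hook
y_direct = layer.forward(x, scale=0.5)  # same values, but the hook is skipped
handle.remove()

print(y.shape, y.dtype, calls)
```

After both calls, `calls` contains a single entry, because only the call on the instance triggers the hook.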