espnet2.enh.diffusion.score_based_diffusion.ScoreModel
class espnet2.enh.diffusion.score_based_diffusion.ScoreModel(**kwargs)
Bases: AbsDiffusion
Initializes internal Module state, shared by both nn.Module and ScriptModule.
enhance(noisy_specturm, sampler_type='pc', predictor='reverse_diffusion', corrector='ald', N=30, corrector_steps=1, snr=0.5, **kwargs)
Enhance a noisy spectrum by running the reverse diffusion sampler conditioned on it.
- Parameters:
  - noisy_specturm (torch.Tensor) – noisy feature in [Batch, T, F]
  - sampler_type (str) – sampler type: ‘pc’ for the Predictor-Corrector sampler or ‘ode’ for the ODE sampler.
  - predictor (str) – name of the predictor: ‘reverse_diffusion’, ‘euler_maruyama’, or ‘none’.
  - corrector (str) – name of the corrector: ‘langevin’, ‘ald’, or ‘none’.
  - N (int) – number of reverse sampling steps.
  - corrector_steps (int) – number of corrector steps per sampling step.
  - snr (float) – signal-to-noise ratio that sets the corrector step size.
- Returns: X_Hat, the enhanced feature in [Batch, T, F]
- Return type: torch.Tensor
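A typical Predictor-Corrector sampler from the score-SDE literature works as follows. For each of the N time steps, the corrector runs corrector_steps Langevin updates, with a step size scaled by snr. The predictor then takes one reverse-time step.

The sketch is a hedged toy in plain Python, not ESPnet's implementation. It rests on these assumptions:

- The "clean" data is a 1-D Gaussian.
- The forward process is a variance-exploding SDE.
- An exact analytic score stands in for the trained network.
- The predictor is Euler-Maruyama, analogous to predictor='euler_maruyama'.
- The corrector uses an annealed-Langevin step-size rule modeled on Song et al.'s score-SDE samplers, analogous to corrector='ald'.

All names (pc_sample, sigma, g2, marginal_var, score) are hypothetical.

```python
import math
import random

# Hypothetical toy setup (not ESPnet code): clean "data" is the 1-D Gaussian
# N(MU0, S0**2), diffused by a variance-exploding (VE) SDE dx = g(t) dw.
MU0, S0 = 2.0, 0.5
SIGMA_MIN, SIGMA_MAX = 0.01, 10.0


def sigma(t):
    # Noise level of the VE SDE at time t in [0, 1].
    return SIGMA_MIN * (SIGMA_MAX / SIGMA_MIN) ** t


def g2(t):
    # Squared diffusion coefficient, g(t)^2 = d[sigma(t)^2]/dt.
    return 2.0 * sigma(t) ** 2 * math.log(SIGMA_MAX / SIGMA_MIN)


def marginal_var(t):
    # Variance of the diffused toy density at time t.
    return S0 ** 2 + sigma(t) ** 2 - SIGMA_MIN ** 2


def score(x, t):
    # Exact score of the diffused toy density; a trained network would
    # approximate this quantity in a real model (assumption for the toy).
    return -(x - MU0) / marginal_var(t)


def pc_sample(num_samples, N=30, corrector_steps=1, snr=0.5, seed=0):
    rng = random.Random(seed)
    dt = 1.0 / N
    # Start from the t = 1 marginal, which is known exactly in this toy.
    xs = [rng.gauss(MU0, math.sqrt(marginal_var(1.0))) for _ in range(num_samples)]
    for i in range(N):
        t = 1.0 - i * dt
        new_xs = []
        for x in xs:
            # Corrector: annealed Langevin dynamics with an SNR-scaled step size.
            for _ in range(corrector_steps):
                step = 2.0 * (snr * sigma(t)) ** 2
                x += step * score(x, t) + math.sqrt(2.0 * step) * rng.gauss(0.0, 1.0)
            # Predictor: one Euler-Maruyama step of the reverse-time SDE, t -> t - dt.
            x += g2(t) * score(x, t) * dt + math.sqrt(g2(t) * dt) * rng.gauss(0.0, 1.0)
            new_xs.append(x)
        xs = new_xs
    return xs


samples = pc_sample(4000)
mean = sum(samples) / len(samples)
std = math.sqrt(sum((s - mean) ** 2 for s in samples) / len(samples))
print(f"mean={mean:.3f} std={std:.3f}")  # should land near MU0 and S0
```

With 4000 samples, the printed mean and standard deviation should come out close to the toy's clean distribution, MU0 = 2.0 and S0 = 0.5.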
forward(feature_ref, feature_mix)
Defines the computation performed at every call.
Should be overridden by all subclasses.
NOTE
Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.
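get_ode_sampler(y, N=None, minibatch=None, **kwargs)
Build the ODE sampler used by enhance when sampler_type='ode', conditioned on the noisy spectrum y.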
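An ODE sampler integrates a deterministic ODE backwards in time instead of simulating a stochastic reverse SDE. A common choice in score-based models is the probability-flow ODE, dx/dt = f(x, t) - 0.5 g(t)^2 ∇_x log p_t(x). Whether get_ode_sampler uses exactly this ODE, and which solver it uses, is an assumption here.

This hedged toy reuses a hypothetical setup: a 1-D Gaussian, a VE SDE with f = 0, and an analytic score. It integrates with a fixed-step Heun method, a solver swapped in for simplicity. All names (ode_sample, velocity, sigma, g2, marginal_var, score) are hypothetical.

```python
import math
import random

# Hypothetical toy (not ESPnet code): VE SDE over a 1-D Gaussian N(MU0, S0**2).
MU0, S0 = 2.0, 0.5
SIGMA_MIN, SIGMA_MAX = 0.01, 10.0


def sigma(t):
    # Noise level of the VE SDE at time t in [0, 1].
    return SIGMA_MIN * (SIGMA_MAX / SIGMA_MIN) ** t


def g2(t):
    # Squared diffusion coefficient, g(t)^2 = d[sigma(t)^2]/dt.
    return 2.0 * sigma(t) ** 2 * math.log(SIGMA_MAX / SIGMA_MIN)


def marginal_var(t):
    # Variance of the diffused toy density at time t.
    return S0 ** 2 + sigma(t) ** 2 - SIGMA_MIN ** 2


def score(x, t):
    # Exact score of the diffused toy density (stands in for a trained network).
    return -(x - MU0) / marginal_var(t)


def velocity(x, t):
    # Probability-flow ODE for a VE SDE (zero drift): dx/dt = -0.5 * g(t)^2 * score.
    return -0.5 * g2(t) * score(x, t)


def ode_sample(num_samples, N=30, seed=0):
    rng = random.Random(seed)
    h = 1.0 / N
    # Start from the t = 1 marginal, which is known exactly in this toy.
    xs = [rng.gauss(MU0, math.sqrt(marginal_var(1.0))) for _ in range(num_samples)]
    for i in range(N):
        t = 1.0 - i * h
        new_xs = []
        for x in xs:
            # Heun step backwards in time: Euler estimate, then trapezoidal average.
            d1 = velocity(x, t)
            x_euler = x - h * d1
            d2 = velocity(x_euler, t - h)
            new_xs.append(x - 0.5 * h * (d1 + d2))
        xs = new_xs
    return xs


samples = ode_sample(4000)
mean = sum(samples) / len(samples)
std = math.sqrt(sum((s - mean) ** 2 for s in samples) / len(samples))
print(f"mean={mean:.3f} std={std:.3f}")
```

No noise is injected during integration, so the same seed always maps the same starting samples to the same outputs. The spread of the result comes entirely from the starting distribution.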
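get_pc_sampler(predictor_name, corrector_name, y, N=None, minibatch=None, **kwargs)
Build the Predictor-Corrector sampler used by enhance when sampler_type='pc', using the named predictor and corrector and conditioned on the noisy spectrum y.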
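score_fn(x, t, y)
Compute the score estimate for the current state x at diffusion time t, conditioned on the noisy spectrum y.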
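A score function approximates the gradient of a log-density, ∇_x log p_t(x | y). As a hedged illustration, consider a Gaussian perturbation x_t = mean + std * z. Its score has the closed form -(x_t - mean) / std^2 = -z / std, and the snippet checks this against a finite difference. The snippet is plain Python, not ESPnet code, and all names are hypothetical.

This closed form is the usual denoising score-matching target. Whether ScoreModel's training loss uses exactly this form is an assumption.

```python
import math


def gaussian_log_density(x, mean, std):
    # log N(x; mean, std^2)
    return -0.5 * ((x - mean) / std) ** 2 - math.log(std * math.sqrt(2.0 * math.pi))


def analytic_score(x, mean, std):
    # d/dx log N(x; mean, std^2) = -(x - mean) / std^2
    return -(x - mean) / std ** 2


def finite_difference_score(x, mean, std, eps=1e-5):
    # Central-difference approximation of the same derivative.
    return (gaussian_log_density(x + eps, mean, std)
            - gaussian_log_density(x - eps, mean, std)) / (2.0 * eps)


# Perturb a hypothetical clean value with a unit Gaussian noise draw z.
mean, std, z = 1.5, 0.3, -0.7
x_t = mean + std * z
print(analytic_score(x_t, mean, std), -z / std, finite_difference_score(x_t, mean, std))
```

All three printed values agree up to floating-point error. A network trained to output this target at many noise levels is what score_fn would evaluate during sampling.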