espnet2.enh.diffusion.sampling.correctors.LangevinCorrector
class espnet2.enh.diffusion.sampling.correctors.LangevinCorrector(sde, score_fn, snr, n_steps)
Bases: Corrector
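The constructor arguments are not documented on this page. The sketch below assumes the class follows the Langevin corrector of Song et al. (score-based generative modeling with SDEs). Under that assumption, `snr` plays the role of the target signal-to-noise ratio $r$, `n_steps` is the number of Langevin iterations run per call to `update_fn`, and `score_fn` supplies $s_\theta(x, t)$. Each iteration would then compute:

```latex
z \sim \mathcal{N}(0, I), \qquad
\varepsilon = 2\left(\frac{r\,\lVert z\rVert_2}{\lVert s_\theta(x, t)\rVert_2}\right)^{2}, \qquad
\bar{x} = x + \varepsilon\, s_\theta(x, t), \qquad
x \leftarrow \bar{x} + \sqrt{2\varepsilon}\, z
```

Here $\bar{x}$ corresponds to the noise-free `x_mean` returned by `update_fn`. The step size is chosen so that $\lVert \varepsilon s_\theta \rVert / \lVert \sqrt{2\varepsilon}\, z \rVert = r$. In the original paper, $\varepsilon$ is also scaled by a factor $\alpha_i$ for VP SDEs. This page does not say whether this class applies that factor.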
update_fn(x, t, *args)
One update of the corrector.
- Parameters:
- x – A PyTorch tensor representing the current state.
- t – A PyTorch tensor representing the current time step.
- *args – Optional additional arguments, in particular y for Ornstein-Uhlenbeck (OU) processes.
- Returns:
- x – A PyTorch tensor of the next state.
- x_mean – A PyTorch tensor of the next state without random noise. Useful for denoising.
- Return type: tuple of two PyTorch tensors (x, x_mean)
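The NumPy sketch below shows what a call to `update_fn` does, assuming the Song et al. Langevin step without the VP-specific $\alpha_i$ factor. `LangevinCorrectorSketch` is a hypothetical stand-in, not the ESPnet class. It uses NumPy arrays instead of PyTorch tensors, and it adds a `seed` argument that the real constructor does not have. `sde` is accepted only to mirror the documented signature and is unused. Each iteration rescales the score so that the gradient step has the requested SNR relative to the injected noise. The method returns both the noisy state `x` and the noise-free `x_mean`.

```python
import numpy as np


class LangevinCorrectorSketch:
    """Hypothetical NumPy stand-in for a Langevin corrector (not the ESPnet implementation)."""

    def __init__(self, sde, score_fn, snr, n_steps, seed=None):
        self.sde = sde            # kept for signature parity; unused in this sketch
        self.score_fn = score_fn  # callable (x, t, *args) -> score estimate with x's shape
        self.snr = snr            # target signal-to-noise ratio r
        self.n_steps = n_steps    # Langevin iterations per update_fn call
        self.rng = np.random.default_rng(seed)  # hypothetical: explicit RNG for reproducibility

    def update_fn(self, x, t, *args):
        x_mean = x  # returned unchanged if n_steps == 0
        for _ in range(self.n_steps):
            grad = self.score_fn(x, t, *args)  # extra args (e.g. y) are forwarded as-is
            noise = self.rng.standard_normal(x.shape)
            # Per-example L2 norms (batch on axis 0), averaged over the batch.
            grad_norm = np.linalg.norm(grad.reshape(grad.shape[0], -1), axis=-1).mean()
            noise_norm = np.linalg.norm(noise.reshape(noise.shape[0], -1), axis=-1).mean()
            # Step size that makes ||eps * grad|| / ||sqrt(2 eps) * noise|| equal to snr.
            step_size = 2.0 * (self.snr * noise_norm / grad_norm) ** 2
            x_mean = x + step_size * grad                   # deterministic (denoised) part
            x = x_mean + np.sqrt(2.0 * step_size) * noise   # add Langevin noise
        return x, x_mean
```

As a usage example, take a standard Gaussian target, whose score is `-x`. The call `LangevinCorrectorSketch(None, lambda x, t: -x, snr=0.15, n_steps=2, seed=0).update_fn(np.ones((2, 4)), 0.5)` returns an `(x, x_mean)` pair of arrays with shape `(2, 4)`. If `n_steps` is 0, the input is returned unchanged for both outputs.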