espnet2.enh.diffusion.sampling.predictors.EulerMaruyamaPredictor
class espnet2.enh.diffusion.sampling.predictors.EulerMaruyamaPredictor(sde, score_fn, probability_flow=False)
Bases: Predictor
update_fn(x, t, *args)
One update of the predictor.
- Parameters:
- x – A PyTorch tensor representing the current state.
- t – A PyTorch tensor representing the current time step.
- *args – Optional additional arguments, in particular y for Ornstein-Uhlenbeck (OU) processes.
- Returns:
- x – A PyTorch tensor of the next state.
- x_mean – A PyTorch tensor of the next state without random noise. Useful for denoising.
- Return type: tuple (x, x_mean)
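The relationship between the two return values can be illustrated with a minimal sketch of a single reverse-time Euler-Maruyama step. This is not the ESPnet implementation: the function name `euler_maruyama_step` and the parameters `drift_fn`, `diffusion_fn`, and `n_steps` are hypothetical, and the state is a plain Python float rather than a PyTorch tensor. The actual class presumably obtains the drift and diffusion terms from its `sde` and `score_fn` arguments.

```python
import math
import random


def euler_maruyama_step(x, t, drift_fn, diffusion_fn, n_steps, rng=random):
    """One reverse-time Euler-Maruyama step on a scalar state (illustrative only).

    drift_fn(x, t) and diffusion_fn(t) are hypothetical stand-ins for the
    drift and diffusion coefficients of the reverse SDE.
    """
    dt = -1.0 / n_steps            # negative step: integrating backwards in time
    f = drift_fn(x, t)             # drift at the current state and time
    g = diffusion_fn(t)            # diffusion coefficient at the current time
    z = rng.gauss(0.0, 1.0)        # standard normal noise sample
    x_mean = x + f * dt            # next state without random noise
    x_next = x_mean + g * math.sqrt(-dt) * z  # next state with noise added
    return x_next, x_mean


# Toy usage: linear drift, constant diffusion, seeded noise for repeatability.
rng = random.Random(0)
x_next, x_mean = euler_maruyama_step(
    1.0, 0.5, lambda x, t: -x, lambda t: 0.5, n_steps=10, rng=rng
)
```

In this sketch `x_mean` is the deterministic part of the update, which is why a sampler may prefer it over `x_next` at the final step, where adding fresh noise would only degrade the result.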