espnet2.enh.diffusion.sdes.SDE
class espnet2.enh.diffusion.sdes.SDE(N)
Bases: ABC
SDE abstract class. Functions are designed for a mini-batch of inputs.
Construct an SDE.
- Parameters: N – number of discretization time steps.
abstract property T
End time of the SDE.
abstract copy()
discretize(x, t, *args)
Discretize the SDE in the form: x_{i+1} = x_i + f_i(x_i) + G_i z_i.
Useful for reverse diffusion sampling and probability flow sampling. Defaults to Euler-Maruyama discretization.
- Parameters:
- x – a torch tensor
- t – a torch float representing the time step (from 0 to self.T)
- Returns: f, G
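The default discretization can be sketched as follows. This is a minimal, standalone illustration of one Euler-Maruyama step, not the ESPnet implementation. The scalar Ornstein-Uhlenbeck coefficients in `toy_sde`, the step size `dt = 1 / N`, and all function names are assumptions made for the example. Plain Python floats stand in for torch tensors.

```python
import math
import random


def toy_sde(x, t, theta=2.0, sigma=0.5):
    """Hypothetical scalar SDE: dx = -theta * x dt + sigma dw."""
    drift = -theta * x
    diffusion = sigma
    return drift, diffusion


def discretize(x, t, N):
    """Euler-Maruyama: scale the drift by dt and the diffusion by sqrt(dt)."""
    dt = 1.0 / N  # assumed step size for N equal steps
    drift, diffusion = toy_sde(x, t)
    f = drift * dt
    G = diffusion * math.sqrt(dt)
    return f, G


def step(x, t, N, z=None):
    """One update x_{i+1} = x_i + f_i(x_i) + G_i z_i with z_i ~ N(0, 1)."""
    if z is None:
        z = random.gauss(0.0, 1.0)
    f, G = discretize(x, t, N)
    return x + f + G * z
```

Passing `z=0.0` to `step` gives the deterministic part of the update, which is convenient when checking the drift term in isolation.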
abstract marginal_prob(x, t, *args)
Parameters to determine the marginal distribution of the SDE, $p_t(x|args)$.
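As an illustration of what such parameters can look like, the sketch below computes the closed-form Gaussian marginal of a scalar Ornstein-Uhlenbeck process. It assumes the marginal is Gaussian and that the parameters are a mean and a standard deviation. The page does not state this, and the function name and coefficients `theta` and `sigma` are hypothetical.

```python
import math


def ou_marginal_prob(x0, t, theta=1.0, sigma=1.0):
    """Mean and std of x_t given x_0 for dx = -theta * x dt + sigma dw."""
    mean = x0 * math.exp(-theta * t)
    var = sigma ** 2 * (1.0 - math.exp(-2.0 * theta * t)) / (2.0 * theta)
    return mean, math.sqrt(var)
```

At `t = 0` the mean equals the starting point and the standard deviation is zero. As `t` grows, the mean decays toward zero and the variance approaches `sigma ** 2 / (2 * theta)`.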
abstract prior_logp(z)
Compute log-density of the prior distribution.
Useful for computing the log-likelihood via probability flow ODE.
- Parameters: z – latent code
- Returns: log probability density
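For intuition, here is a minimal sketch of a prior log-density, assuming an isotropic Gaussian prior $N(0, \sigma^2 I)$. The actual prior is defined by each concrete subclass. The function name and the use of a flat list of floats in place of a torch tensor are assumptions made for the example.

```python
import math


def gaussian_prior_logp(z, sigma=1.0):
    """Log-density of a flat list of floats z under N(0, sigma^2 I)."""
    d = len(z)
    sq_norm = sum(v * v for v in z)
    log_norm_const = -0.5 * d * math.log(2.0 * math.pi * sigma ** 2)
    return log_norm_const - sq_norm / (2.0 * sigma ** 2)
```

For a mini-batch, the same formula would be applied per sample, summing over all non-batch dimensions.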
abstract prior_sampling(shape, *args)
Generate one sample of the given shape from the prior distribution, $p_T(x|args)$.
reverse(score_model, probability_flow=False)
Create the reverse-time SDE/ODE.
- Parameters:
- score_model – A function that takes x, t and y and returns the score.
- probability_flow – If True, create the reverse-time ODE used for probability flow sampling.
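The relationship between the forward coefficients and the reverse-time ones can be sketched with the standard result for score-based models: the reverse-time SDE has drift $f - g^2 \nabla_x \log p_t(x)$ and diffusion $g$, while the probability flow ODE has drift $f - \frac{1}{2} g^2 \nabla_x \log p_t(x)$ and no diffusion. The helper below is a scalar illustration with a hypothetical name. It is not the ESPnet implementation, whose internals are not described on this page.

```python
def reverse_coefficients(drift, diffusion, score, probability_flow=False):
    """Drift and diffusion of the reverse-time SDE, or of the ODE if requested."""
    # The ODE uses half the score correction and drops the noise term.
    weight = 0.5 if probability_flow else 1.0
    rev_drift = drift - weight * diffusion ** 2 * score
    rev_diffusion = 0.0 if probability_flow else diffusion
    return rev_drift, rev_diffusion
```

Here `score` stands for the value returned by the score model at the current `x` and `t`. In a sampler, these coefficients would be integrated backwards in time from `T` toward 0.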
abstract sde(x, t, *args)