# Source code for espnet2.enh.diffusion.sdes

"""Abstract SDE classes, Reverse SDE, and VE/VP SDEs.

https://github.com/yang-song/score_sde_pytorch
and
https://github.com/sp-uhh/sgmse
"""

import abc
import warnings

import numpy as np
import torch

class SDE(abc.ABC):
"""SDE abstract class. Functions are designed for a mini-batch of inputs."""

def __init__(self, N):
"""Construct an SDE.

Args:
N: number of discretization time steps.
"""
super().__init__()
self.N = N

@property
@abc.abstractmethod
def T(self):
"""End time of the SDE."""
pass

    @abc.abstractmethod
    def sde(self, x, t, *args):
        """Return the drift and diffusion coefficients of the SDE at time t."""
        pass

    @abc.abstractmethod
    def marginal_prob(self, x, t, *args):
        """Parameters to determine the marginal distribution of the SDE, $p_t(x|args)$."""
        pass

    @abc.abstractmethod
    def prior_sampling(self, shape, *args):
        """Generate one sample from the prior distribution, $p_T(x|args)$, with the given `shape`."""
        pass

    @abc.abstractmethod
def prior_logp(self, z):
"""Compute log-density of the prior distribution.

Useful for computing the log-likelihood via probability flow ODE.

Args:
z: latent code
Returns:
log probability density
"""
pass

    def discretize(self, x, t, *args):
"""Discretize the SDE in the form: x_{i+1} = x_i + f_i(x_i) + G_i z_i.

        Useful for reverse diffusion sampling and probability flow sampling.
Defaults to Euler-Maruyama discretization.

Args:
x: a torch tensor
t: a torch float representing the time step (from 0 to self.T)

Returns:
f, G
"""
dt = 1 / self.N
drift, diffusion = self.sde(x, t, *args)
f = drift * dt
G = diffusion * torch.sqrt(torch.tensor(dt, device=t.device))
return f, G

    def reverse(oself, score_model, probability_flow=False):
        """Create the reverse-time SDE/ODE.

        The receiver is named ``oself`` (outer self) so that ``self`` can refer
        to instances of the nested RSDE class defined below.

        Args:
            score_model: A function that takes x, t and y and returns the score.
            probability_flow: If True, create the reverse-time ODE
                used for probability flow sampling.
        """
N = oself.N
T = oself.T
sde_fn = oself.sde
discretize_fn = oself.discretize

# Build the class for reverse-time SDE.
class RSDE(oself.__class__):
def __init__(self):
self.N = N
self.probability_flow = probability_flow

@property
def T(self):
return T

            def sde(self, x, t, *args):
                """Create the drift and diffusion functions for the reverse SDE/ODE."""
                rsde_parts = self.rsde_parts(x, t, *args)
                total_drift, diffusion = (
                    rsde_parts["total_drift"],
                    rsde_parts["diffusion"],
                )
                return total_drift, diffusion

def rsde_parts(self, x, t, *args):
sde_drift, sde_diffusion = sde_fn(x, t, *args)
score = score_model(x, t, *args)
score_drift = (
-sde_diffusion[:, None, None, None] ** 2
* score
* (0.5 if self.probability_flow else 1.0)
)
diffusion = (
torch.zeros_like(sde_diffusion)
if self.probability_flow
else sde_diffusion
)
total_drift = sde_drift + score_drift
return {
"total_drift": total_drift,
"diffusion": diffusion,
"sde_drift": sde_drift,
"sde_diffusion": sde_diffusion,
"score_drift": score_drift,
"score": score,
}

def discretize(self, x, t, *args):
"""Create discretized iteration rules for the reverse

diffusion sampler.
"""
f, G = discretize_fn(x, t, *args)
rev_f = f - G[:, None, None, None] ** 2 * score_model(x, t, *args) * (
0.5 if self.probability_flow else 1.0
)
rev_G = torch.zeros_like(G) if self.probability_flow else G
return rev_f, rev_G

return RSDE()

    @abc.abstractmethod
    def copy(self):
        """Return a new instance of this SDE with the same configuration."""
        pass
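
# --- Illustrative sketch, not part of the original module ---
# SDE.discretize() yields the Euler-Maruyama rule x_{i+1} = x_i + f_i + G_i z_i,
# which can also simulate the *forward* SDE. The helper below is a hypothetical
# example: it assumes an SDE subclass such as OUVESDE whose methods take the
# steady-state mean `y` as an extra argument, and 4-D inputs of shape (B, C, F, T).
def _demo_forward_simulation(sde, x0, y):
    """Simulate the forward SDE from x0 at t=0 up to t=sde.T (illustrative)."""
    x = x0
    timesteps = torch.linspace(0, sde.T, sde.N, device=x0.device)
    for t in timesteps:
        t_batch = torch.ones(x.shape[0], device=x.device) * t
        f, G = sde.discretize(x, t_batch, y)  # drift increment and noise scale
        z = torch.randn_like(x)
        x = x + f + G[:, None, None, None] * z
    return x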

class OUVESDE(SDE):
def __init__(
self, theta=1.5, sigma_min=0.05, sigma_max=0.5, N=1000, **ignored_kwargs
):
"""Construct an Ornstein-Uhlenbeck Variance Exploding SDE.

Note that the "steady-state mean" y is not provided at construction,
but must rather be given as an argument to the methods
which require it (e.g., sde or marginal_prob).

        dx = -theta (x-y) dt + sigma(t) dw

with

sigma(t) = sigma_min (sigma_max/sigma_min)^t * sqrt(2 log(sigma_max/sigma_min))

Args:
theta: stiffness parameter.
sigma_min: smallest sigma.
sigma_max: largest sigma.
N: number of discretization steps
"""
super().__init__(N)
self.theta = theta
self.sigma_min = sigma_min
self.sigma_max = sigma_max
self.logsig = np.log(self.sigma_max / self.sigma_min)
self.N = N

    def copy(self):
return OUVESDE(self.theta, self.sigma_min, self.sigma_max, N=self.N)

@property
def T(self):
return 1

    def sde(self, x, t, y):
drift = self.theta * (y - x)
# the sqrt(2*logsig) factor is required here so that logsig does not in the end
# affect the perturbation kernel standard deviation. this can be understood
# from solving the integral of [exp(2s) * g(s)^2] from s=0 to t with
# g(t) = sigma(t) as defined here, and seeing that logsig remains in the
# integral solution unless this sqrt(2*logsig) factor is included.
sigma = self.sigma_min * (self.sigma_max / self.sigma_min) ** t
diffusion = sigma * np.sqrt(2 * self.logsig)
return drift, diffusion

def _mean(self, x0, t, y):
theta = self.theta
exp_interp = torch.exp(-theta * t)[:, None, None, None]
return exp_interp * x0 + (1 - exp_interp) * y

    def _std(self, t):
        # This is a full solution to the ODE for P(t) in our derivations,
        # after choosing g(s) as in self.sde()
        sigma_min, theta, logsig = self.sigma_min, self.theta, self.logsig
        # could maybe replace the two torch.exp(... * t) terms here by cached values **t
        return torch.sqrt(
            (
                sigma_min**2
                * torch.exp(-2 * theta * t)
                * (torch.exp(2 * (theta + logsig) * t) - 1)
                * logsig
            )
            / (theta + logsig)
        )
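
    # Worked derivation (added for clarity), following the comment in sde():
    # with g(s) = sigma_min * (sigma_max / sigma_min)^s * sqrt(2 * logsig),
    # the perturbation-kernel variance is
    #   Var(t) = int_0^t exp(-2 * theta * (t - s)) * g(s)^2 ds
    #          = sigma_min^2 * exp(-2 * theta * t)
    #            * (exp(2 * (theta + logsig) * t) - 1) * logsig / (theta + logsig),
    # which is exactly the expression under torch.sqrt() in _std() above.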

    def marginal_prob(self, x0, t, y):
return self._mean(x0, t, y), self._std(t)

    def prior_sampling(self, shape, y):
if shape != y.shape:
warnings.warn(
f"Target shape {shape} does not match shape of y {y.shape}!"
"Ignoring target shape."
)
std = self._std(torch.ones((y.shape[0],), device=y.device))
x_T = y + torch.randn_like(y) * std[:, None, None, None]
return x_T

    def prior_logp(self, z):
raise NotImplementedError("prior_logp for OU SDE not yet implemented!")
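
# --- Illustrative sketch, not part of the original module ---
# Typical use of OUVESDE when training a score model: draw a random time t per
# batch element and sample x_t from the closed-form perturbation kernel given
# by marginal_prob(). All tensors below are random stand-ins with an assumed
# (B, C, F, T) layout.
def _demo_ouve_perturbation():
    sde = OUVESDE(theta=1.5, sigma_min=0.05, sigma_max=0.5, N=1000)
    x0 = torch.randn(4, 1, 256, 64)  # clean sample (stand-in)
    y = torch.randn(4, 1, 256, 64)  # noisy "steady-state mean" (stand-in)
    t = torch.rand(4) * sde.T  # one random diffusion time per batch element
    mean, std = sde.marginal_prob(x0, t, y)
    z = torch.randn_like(x0)
    x_t = mean + std[:, None, None, None] * z  # sample from the perturbation kernel
    return x_t, z, std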

class OUVPSDE(SDE):
def __init__(self, beta_min, beta_max, stiffness=1, N=1000, **ignored_kwargs):
"""OUVPSDE class.

!!! SGMSE authors observed instabilities around t=0.2. !!!

Construct an Ornstein-Uhlenbeck Variance Preserving SDE:

        dx = -1/2 * beta(t) * stiffness * (x-y) dt + sqrt(beta(t)) * dw

with

beta(t) = beta_min + t(beta_max - beta_min)

Note that the "steady-state mean" y is not provided at construction,
but must rather be given as an argument to the methods which
require it (e.g., sde or marginal_prob).

Args:
            beta_min: smallest beta.
            beta_max: largest beta.
stiffness: stiffness factor of the drift. 1 by default.
N: number of discretization steps
"""
super().__init__(N)
self.beta_min = beta_min
self.beta_max = beta_max
self.stiffness = stiffness
self.N = N

    def copy(self):
return OUVPSDE(self.beta_min, self.beta_max, self.stiffness, N=self.N)

@property
def T(self):
return 1

def _beta(self, t):
return self.beta_min + t * (self.beta_max - self.beta_min)

    def sde(self, x, t, y):
drift = 0.5 * self.stiffness * batch_broadcast(self._beta(t), y) * (y - x)
diffusion = torch.sqrt(self._beta(t))
return drift, diffusion

def _mean(self, x0, t, y):
b0, b1, s = self.beta_min, self.beta_max, self.stiffness
x0y_fac = torch.exp(-0.25 * s * t * (t * (b1 - b0) + 2 * b0))[
:, None, None, None
]
return y + x0y_fac * (x0 - y)

    def _std(self, t):
        b0, b1, s = self.beta_min, self.beta_max, self.stiffness
        # The perturbation-kernel variance solves dP/dt = -s*beta(t)*P + beta(t),
        # giving P(t) = (1 - exp(-s*B(t))) / s with B(t) = int_0^t beta(u) du;
        # the square root converts this variance into a standard deviation.
        return torch.sqrt(
            (1 - torch.exp(-0.5 * s * t * (t * (b1 - b0) + 2 * b0))) / s
        )

    def marginal_prob(self, x0, t, y):
return self._mean(x0, t, y), self._std(t)

    def prior_sampling(self, shape, y):
if shape != y.shape:
warnings.warn(
f"Target shape {shape} does not match shape of y {y.shape}!"
"Ignoring target shape."
)
std = self._std(torch.ones((y.shape[0],), device=y.device))
x_T = y + torch.randn_like(y) * std[:, None, None, None]
return x_T

    def prior_logp(self, z):
raise NotImplementedError("prior_logp for OU SDE not yet implemented!")
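
# --- Illustrative sketch, not part of the original module ---
# Reverse diffusion sampling with SDE.reverse(): wrap a score model, start from
# the prior at t=T, and iterate the reverse discretization down to t=0.
# `score_model` is a hypothetical callable (x, t, y) -> score; a dummy
# zero-score keeps the sketch self-contained.
def _demo_reverse_sampling(sde, y, score_model=None):
    if score_model is None:
        def score_model(x, t, *args):  # dummy stand-in for a trained network
            return torch.zeros_like(x)
    rsde = sde.reverse(score_model, probability_flow=False)
    x = sde.prior_sampling(y.shape, y)
    timesteps = torch.linspace(sde.T, 0, sde.N, device=y.device)
    for t in timesteps:
        t_batch = torch.ones(y.shape[0], device=y.device) * t
        rev_f, rev_G = rsde.discretize(x, t_batch, y)
        z = torch.randn_like(x)
        x = x - rev_f + rev_G[:, None, None, None] * z  # reverse Euler-Maruyama step
    return x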

"""Broadcasts a over all dimensions of x, except the batch dimension,

which must match.
"""

if len(a.shape) != 1:
a = a.squeeze()
if len(a.shape) != 1:
raise ValueError(
f"Don't know how to batch-broadcast tensor a "
f"with more than one effective dimension (shape {a.shape})"
)

if a.shape[0] != x.shape[0] and a.shape[0] != 1:
raise ValueError(
f"Don't know how to batch-broadcast shape {a.shape} over {x.shape} "
"as the batch dimension is not matching"
)

out = a.view((x.shape[0], *(1 for _ in range(len(x.shape) - 1))))
return out
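
# --- Illustrative sketch, not part of the original module ---
# batch_broadcast() reshapes a per-batch vector of shape (B,) so it broadcasts
# against a batched tensor of any rank:
def _demo_batch_broadcast():
    beta = torch.tensor([0.1, 0.2])  # per-batch coefficients, shape (2,)
    x = torch.randn(2, 1, 4, 4)  # batched tensor of shape (B, C, F, T)
    scaled = batch_broadcast(beta, x) * x  # beta viewed as (2, 1, 1, 1)
    return scaled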