# Source code for espnet2.enh.diffusion.sdes

"""Abstract SDE classes, Reverse SDE, and VE/VP SDEs.

https://github.com/yang-song/score_sde_pytorch
and
https://github.com/sp-uhh/sgmse
"""

import abc
import warnings

import numpy as np
import torch

class SDE(abc.ABC):
"""SDE abstract class. Functions are designed for a mini-batch of inputs."""

def __init__(self, N):
"""Construct an SDE.

Args:
N: number of discretization time steps.
"""
super().__init__()
self.N = N

@property
@abc.abstractmethod
def T(self):
"""End time of the SDE."""
pass

    @abc.abstractmethod
    def sde(self, x, t, *args):
        """Return the drift and diffusion coefficients of the SDE at time t."""
        pass

    @abc.abstractmethod
    def marginal_prob(self, x, t, *args):
        """Parameters to determine the marginal distribution of the SDE, $p_t(x|args)$."""
        pass

    @abc.abstractmethod
    def prior_sampling(self, shape, *args):
        """Generate one sample from the prior distribution, $p_T(x|args)$, with the given `shape`."""
        pass

    @abc.abstractmethod
def prior_logp(self, z):
"""Compute log-density of the prior distribution.

Useful for computing the log-likelihood via probability flow ODE.

Args:
z: latent code
Returns:
log probability density
"""
pass

    def discretize(self, x, t, *args):
"""Discretize the SDE in the form: x_{i+1} = x_i + f_i(x_i) + G_i z_i.

        Useful for reverse diffusion sampling and probability flow sampling.
Defaults to Euler-Maruyama discretization.

Args:
x: a torch tensor
t: a torch float representing the time step (from 0 to self.T)

Returns:
f, G
"""
dt = 1 / self.N
drift, diffusion = self.sde(x, t, *args)
f = drift * dt
G = diffusion * torch.sqrt(torch.tensor(dt, device=t.device))
return f, G

    def reverse(oself, score_model, probability_flow=False):
        """Create the reverse-time SDE/ODE.

        The receiver is named ``oself`` (outer self) so that ``self`` can refer
        to instances of the nested RSDE class defined below.

        Args:
            score_model: A function that takes x, t and y and returns the score.
            probability_flow: If True, create the reverse-time ODE
                used for probability flow sampling.
        """
N = oself.N
T = oself.T
sde_fn = oself.sde
discretize_fn = oself.discretize

# Build the class for reverse-time SDE.
class RSDE(oself.__class__):
def __init__(self):
self.N = N
self.probability_flow = probability_flow

@property
def T(self):
return T

            def sde(self, x, t, *args):
                """Create the drift and diffusion functions for the reverse SDE/ODE."""
                rsde_parts = self.rsde_parts(x, t, *args)
                total_drift, diffusion = (
                    rsde_parts["total_drift"],
                    rsde_parts["diffusion"],
                )
                return total_drift, diffusion

def rsde_parts(self, x, t, *args):
sde_drift, sde_diffusion = sde_fn(x, t, *args)
score = score_model(x, t, *args)
score_drift = (
-sde_diffusion[:, None, None, None] ** 2
* score
* (0.5 if self.probability_flow else 1.0)
)
diffusion = (
torch.zeros_like(sde_diffusion)
if self.probability_flow
else sde_diffusion
)
total_drift = sde_drift + score_drift
return {
"total_drift": total_drift,
"diffusion": diffusion,
"sde_drift": sde_drift,
"sde_diffusion": sde_diffusion,
"score_drift": score_drift,
"score": score,
}

def discretize(self, x, t, *args):
"""Create discretized iteration rules for the reverse

diffusion sampler.
"""
f, G = discretize_fn(x, t, *args)
rev_f = f - G[:, None, None, None] ** 2 * score_model(x, t, *args) * (
0.5 if self.probability_flow else 1.0
)
rev_G = torch.zeros_like(G) if self.probability_flow else G
return rev_f, rev_G

return RSDE()

    @abc.abstractmethod
    def copy(self):
        """Return a new instance of this SDE with the same configuration."""
        pass
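
# --- Illustrative sketch, not part of the original module ---
# SDE.discretize() yields the Euler-Maruyama rule x_{i+1} = x_i + f_i + G_i z_i,
# which can also simulate the *forward* SDE. The helper below is a hypothetical
# example: it assumes an SDE subclass such as OUVESDE whose methods take the
# steady-state mean `y` as an extra argument, and 4-D inputs of shape (B, C, F, T).
def _demo_forward_simulation(sde, x0, y):
    """Simulate the forward SDE from x0 at t=0 up to t=sde.T (illustrative)."""
    x = x0
    timesteps = torch.linspace(0, sde.T, sde.N, device=x0.device)
    for t in timesteps:
        t_batch = torch.ones(x.shape[0], device=x.device) * t
        f, G = sde.discretize(x, t_batch, y)  # drift increment and noise scale
        z = torch.randn_like(x)
        x = x + f + G[:, None, None, None] * z
    return x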

class OUVESDE(SDE):
def __init__(
self, theta=1.5, sigma_min=0.05, sigma_max=0.5, N=1000, **ignored_kwargs
):
"""Construct an Ornstein-Uhlenbeck Variance Exploding SDE.

Note that the "steady-state mean" y is not provided at construction,
but must rather be given as an argument to the methods
which require it (e.g., sde or marginal_prob).

        dx = -theta (x-y) dt + sigma(t) dw

with

sigma(t) = sigma_min (sigma_max/sigma_min)^t * sqrt(2 log(sigma_max/sigma_min))

Args:
theta: stiffness parameter.
sigma_min: smallest sigma.
sigma_max: largest sigma.
N: number of discretization steps
"""
super().__init__(N)
self.theta = theta
self.sigma_min = sigma_min
self.sigma_max = sigma_max
self.logsig = np.log(self.sigma_max / self.sigma_min)
self.N = N

    def copy(self):
return OUVESDE(self.theta, self.sigma_min, self.sigma_max, N=self.N)

@property
def T(self):
return 1

    def sde(self, x, t, y):
drift = self.theta * (y - x)
# the sqrt(2*logsig) factor is required here so that logsig does not in the end
# affect the perturbation kernel standard deviation. this can be understood
# from solving the integral of [exp(2s) * g(s)^2] from s=0 to t with
# g(t) = sigma(t) as defined here, and seeing that logsig remains in the
# integral solution unless this sqrt(2*logsig) factor is included.
sigma = self.sigma_min * (self.sigma_max / self.sigma_min) ** t
diffusion = sigma * np.sqrt(2 * self.logsig)
return drift, diffusion

def _mean(self, x0, t, y):
theta = self.theta
exp_interp = torch.exp(-theta * t)[:, None, None, None]
return exp_interp * x0 + (1 - exp_interp) * y

    def _std(self, t):
        # This is a full solution to the ODE for P(t) in our derivations,
        # after choosing g(s) as in self.sde()
        sigma_min, theta, logsig = self.sigma_min, self.theta, self.logsig
        # could maybe replace the two torch.exp(... * t) terms here by cached values **t
        return torch.sqrt(
            (
                sigma_min**2
                * torch.exp(-2 * theta * t)
                * (torch.exp(2 * (theta + logsig) * t) - 1)
                * logsig
            )
            / (theta + logsig)
        )
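
    # Worked derivation (added for clarity), following the comment in sde():
    # with g(s) = sigma_min * (sigma_max / sigma_min)^s * sqrt(2 * logsig),
    # the perturbation-kernel variance is
    #   Var(t) = int_0^t exp(-2 * theta * (t - s)) * g(s)^2 ds
    #          = sigma_min^2 * exp(-2 * theta * t)
    #            * (exp(2 * (theta + logsig) * t) - 1) * logsig / (theta + logsig),
    # which is exactly the expression under torch.sqrt() in _std() above.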

    def marginal_prob(self, x0, t, y):
return self._mean(x0, t, y), self._std(t)

    def prior_sampling(self, shape, y):
if shape != y.shape:
warnings.warn(
f"Target shape {shape} does not match shape of y {y.shape}!"
"Ignoring target shape."
)
std = self._std(torch.ones((y.shape[0],), device=y.device))
x_T = y + torch.randn_like(y) * std[:, None, None, None]
return x_T

    def prior_logp(self, z):
raise NotImplementedError("prior_logp for OU SDE not yet implemented!")
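
# --- Illustrative sketch, not part of the original module ---
# Typical use of OUVESDE when training a score model: draw a random time t per
# batch element and sample x_t from the closed-form perturbation kernel given
# by marginal_prob(). All tensors below are random stand-ins with an assumed
# (B, C, F, T) layout.
def _demo_ouve_perturbation():
    sde = OUVESDE(theta=1.5, sigma_min=0.05, sigma_max=0.5, N=1000)
    x0 = torch.randn(4, 1, 256, 64)  # clean sample (stand-in)
    y = torch.randn(4, 1, 256, 64)  # noisy "steady-state mean" (stand-in)
    t = torch.rand(4) * sde.T  # one random diffusion time per batch element
    mean, std = sde.marginal_prob(x0, t, y)
    z = torch.randn_like(x0)
    x_t = mean + std[:, None, None, None] * z  # sample from the perturbation kernel
    return x_t, z, std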

class OUVPSDE(SDE):
def __init__(self, beta_min, beta_max, stiffness=1, N=1000, **ignored_kwargs):
"""OUVPSDE class.

!!! SGMSE authors observed instabilities around t=0.2. !!!

Construct an Ornstein-Uhlenbeck Variance Preserving SDE:

        dx = -1/2 * beta(t) * stiffness * (x-y) dt + sqrt(beta(t)) * dw

with

beta(t) = beta_min + t(beta_max - beta_min)

Note that the "steady-state mean" y is not provided at construction,
but must rather be given as an argument to the methods which
require it (e.g., sde or marginal_prob).

Args:
            beta_min: smallest beta.
            beta_max: largest beta.
stiffness: stiffness factor of the drift. 1 by default.
N: number of discretization steps
"""
super().__init__(N)
self.beta_min = beta_min
self.beta_max = beta_max
self.stiffness = stiffness
self.N = N

    def copy(self):
return OUVPSDE(self.beta_min, self.beta_max, self.stiffness, N=self.N)

@property
def T(self):
return 1

def _beta(self, t):
return self.beta_min + t * (self.beta_max - self.beta_min)

    def sde(self, x, t, y):
drift = 0.5 * self.stiffness * batch_broadcast(self._beta(t), y) * (y - x)
diffusion = torch.sqrt(self._beta(t))
return drift, diffusion

def _mean(self, x0, t, y):
b0, b1, s = self.beta_min, self.beta_max, self.stiffness
x0y_fac = torch.exp(-0.25 * s * t * (t * (b1 - b0) + 2 * b0))[
:, None, None, None
]
return y + x0y_fac * (x0 - y)

    def _std(self, t):
        b0, b1, s = self.beta_min, self.beta_max, self.stiffness
        # The perturbation-kernel variance solves dP/dt = -s*beta(t)*P + beta(t),
        # giving P(t) = (1 - exp(-s*B(t))) / s with B(t) = int_0^t beta(u) du;
        # the square root converts this variance into a standard deviation.
        return torch.sqrt(
            (1 - torch.exp(-0.5 * s * t * (t * (b1 - b0) + 2 * b0))) / s
        )

    def marginal_prob(self, x0, t, y):
return self._mean(x0, t, y), self._std(t)

    def prior_sampling(self, shape, y):
if shape != y.shape:
warnings.warn(
f"Target shape {shape} does not match shape of y {y.shape}!"
"Ignoring target shape."
)
std = self._std(torch.ones((y.shape[0],), device=y.device))
x_T = y + torch.randn_like(y) * std[:, None, None, None]
return x_T

    def prior_logp(self, z):
raise NotImplementedError("prior_logp for OU SDE not yet implemented!")
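
# --- Illustrative sketch, not part of the original module ---
# Reverse diffusion sampling with SDE.reverse(): wrap a score model, start from
# the prior at t=T, and iterate the reverse discretization down to t=0.
# `score_model` is a hypothetical callable (x, t, y) -> score; a dummy
# zero-score keeps the sketch self-contained.
def _demo_reverse_sampling(sde, y, score_model=None):
    if score_model is None:
        def score_model(x, t, *args):  # dummy stand-in for a trained network
            return torch.zeros_like(x)
    rsde = sde.reverse(score_model, probability_flow=False)
    x = sde.prior_sampling(y.shape, y)
    timesteps = torch.linspace(sde.T, 0, sde.N, device=y.device)
    for t in timesteps:
        t_batch = torch.ones(y.shape[0], device=y.device) * t
        rev_f, rev_G = rsde.discretize(x, t_batch, y)
        z = torch.randn_like(x)
        x = x - rev_f + rev_G[:, None, None, None] * z  # reverse Euler-Maruyama step
    return x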

"""Broadcasts a over all dimensions of x, except the batch dimension,

which must match.
"""

if len(a.shape) != 1:
a = a.squeeze()
if len(a.shape) != 1:
raise ValueError(
f"Don't know how to batch-broadcast tensor a "
f"with more than one effective dimension (shape {a.shape})"
)

if a.shape[0] != x.shape[0] and a.shape[0] != 1:
raise ValueError(
f"Don't know how to batch-broadcast shape {a.shape} over {x.shape} "
"as the batch dimension is not matching"
)

out = a.view((x.shape[0], *(1 for _ in range(len(x.shape) - 1))))
return out
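
# --- Illustrative sketch, not part of the original module ---
# batch_broadcast() reshapes a per-batch vector of shape (B,) so it broadcasts
# against a batched tensor of any rank:
def _demo_batch_broadcast():
    beta = torch.tensor([0.1, 0.2])  # per-batch coefficients, shape (2,)
    x = torch.randn(2, 1, 4, 4)  # batched tensor of shape (B, C, F, T)
    scaled = batch_broadcast(beta, x) * x  # beta viewed as (2, 1, 1, 1)
    return scaled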