Source code for espnet2.enh.diffusion.score_based_diffusion

# The implementation is based on:
# https://github.com/sp-uhh/sgmse
# Licensed under MIT


import math

import torch

import espnet2.enh.diffusion.sampling as sampling
from espnet2.enh.diffusion.abs_diffusion import AbsDiffusion
from espnet2.enh.diffusion.sdes import OUVESDE, OUVPSDE, SDE
from espnet2.enh.layers.dcunet import DCUNet
from espnet2.enh.layers.ncsnpp import NCSNpp
from espnet2.train.class_choices import ClassChoices

score_choices = ClassChoices(
    name="score_model",
    classes=dict(dcunet=DCUNet, ncsnpp=NCSNpp),
    type_check=torch.nn.Module,
    default=None,
)

sde_choices = ClassChoices(
    name="sde",
    classes=dict(
        ouve=OUVESDE,
        ouvp=OUVPSDE,
    ),
    type_check=SDE,
    default="ouve",
)
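
# A minimal usage sketch (illustrative, not part of the original module):
# ClassChoices maps a config string to its class, which is then instantiated
# with the matching *_conf dict. The conf contents here are hypothetical and
# would normally come from an ESPnet YAML config.
#
#     score_model_class = score_choices.get_class("ncsnpp")  # -> NCSNpp
#     dnn = score_model_class(**score_model_conf)
#     sde = sde_choices.get_class("ouve")(**sde_conf)        # -> OUVESDE(...)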


class ScoreModel(AbsDiffusion):
    def __init__(self, **kwargs):
        super().__init__()

        score_model_class = score_choices.get_class(kwargs["score_model"])
        self.dnn = score_model_class(**kwargs["score_model_conf"])
        self.sde = sde_choices.get_class(kwargs["sde"])(**kwargs["sde_conf"])
        # kwargs is a plain dict, so use .get() for optional entries;
        # getattr() on a dict would always fall through to the default.
        self.loss_type = kwargs.get("loss_type", "mse")
        self.t_eps = kwargs.get("t_eps", 3e-2)

    def _loss(self, err):
        if self.loss_type == "mse":
            losses = torch.square(err.abs())
        elif self.loss_type == "mae":
            losses = err.abs()
        else:
            raise ValueError("Unsupported loss type: {}".format(self.loss_type))
        # taken from reduce_op function: sum over channels and position
        # and mean over batch dim; presumably only important for the absolute
        # loss value, not for gradients
        loss = torch.mean(
            0.5 * torch.sum(losses.reshape(losses.shape[0], -1), dim=-1)
        )
        return loss
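
    # A shape sketch of the reduction above (added for illustration): with err
    # of shape [B, C, F, T], reshape(B, -1) flattens channels and positions, so
    # for loss_type="mse" the result is the mean over B of
    # 0.5 * sum_{C,F,T} |err|^2.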

    def get_pc_sampler(
        self, predictor_name, corrector_name, y, N=None, minibatch=None, **kwargs
    ):
        N = self.sde.N if N is None else N
        sde = self.sde.copy()
        sde.N = N

        kwargs = {"eps": self.t_eps, **kwargs}
        if minibatch is None:
            return sampling.get_pc_sampler(
                predictor_name,
                corrector_name,
                sde=sde,
                score_fn=self.score_fn,
                y=y,
                **kwargs
            )
        else:
            M = y.shape[0]

            def batched_sampling_fn():
                samples, ns = [], []
                for i in range(int(math.ceil(M / minibatch))):
                    y_mini = y[i * minibatch : (i + 1) * minibatch]
                    sampler = sampling.get_pc_sampler(
                        predictor_name,
                        corrector_name,
                        sde=sde,
                        score_fn=self.score_fn,
                        y=y_mini,
                        **kwargs
                    )
                    sample, n = sampler()
                    samples.append(sample)
                    ns.append(n)
                samples = torch.cat(samples, dim=0)
                return samples, ns

            return batched_sampling_fn

    def get_ode_sampler(self, y, N=None, minibatch=None, **kwargs):
        N = self.sde.N if N is None else N
        sde = self.sde.copy()
        sde.N = N

        kwargs = {"eps": self.t_eps, **kwargs}
        if minibatch is None:
            return sampling.get_ode_sampler(
                sde, self.score_fn, y=y, device=y.device, **kwargs
            )
        else:
            M = y.shape[0]

            def batched_sampling_fn():
                samples, ns = [], []
                for i in range(int(math.ceil(M / minibatch))):
                    y_mini = y[i * minibatch : (i + 1) * minibatch]
                    sampler = sampling.get_ode_sampler(
                        sde, self.score_fn, y=y_mini, device=y_mini.device, **kwargs
                    )
                    sample, n = sampler()
                    samples.append(sample)
                    ns.append(n)
                # return the concatenated batch, not just the last minibatch
                samples = torch.cat(samples, dim=0)
                return samples, ns

            return batched_sampling_fn

    def score_fn(self, x, t, y):
        # Concatenate y as an extra channel
        dnn_input = torch.cat([x, y], dim=1)

        # the minus is most likely unimportant here - taken from Song's repo
        score = -self.dnn(dnn_input, t)
        return score
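
    # Note (added for clarity): conditioning on the noisy observation y is done
    # purely by channel concatenation, so the configured score model must
    # accept twice as many input channels as x alone (here [B, 2, F, T] for
    # single-channel spectra).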

    def forward(
        self,
        feature_ref,
        feature_mix,
    ):
        # feature_ref: B, T, F
        # feature_mix: B, T, F
        x = feature_ref.permute(0, 2, 1).unsqueeze(1)
        y = feature_mix.permute(0, 2, 1).unsqueeze(1)

        t = (
            torch.rand(x.shape[0], device=x.device) * (self.sde.T - self.t_eps)
            + self.t_eps
        )
        mean, std = self.sde.marginal_prob(x, t, y)
        # i.i.d. normal noise; for complex dtypes, the real and imaginary
        # parts each have var=0.5
        z = torch.randn_like(x)
        sigmas = std[:, None, None, None]
        perturbed_data = mean + sigmas * z
        score = self.score_fn(perturbed_data, t, y)
        assert score.shape == x.shape, "Check the output shape of the score_fn."
        err = score * sigmas + z
        loss = self._loss(err)
        return loss
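
    # Background (added, not in the original source): since perturbed_data =
    # mean + sigmas * z, the score of the perturbation kernel is
    #     grad_x log p(x_t | x_0, y) = -z / sigmas,
    # so a perfect score network makes err = score * sigmas + z vanish. The
    # objective is thus denoising score matching with an implicit sigma^2
    # weighting.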

    def enhance(
        self,
        noisy_spectrum,
        sampler_type="pc",
        predictor="reverse_diffusion",
        corrector="ald",
        N=30,
        corrector_steps=1,
        snr=0.5,
        **kwargs
    ):
        """Enhance function.

        Args:
            noisy_spectrum (torch.Tensor): noisy feature in [Batch, T, F]
            sampler_type (str): sampler, 'pc' for Predictor-Corrector and
                'ode' for ODE sampler.
            predictor (str): the name of the Predictor:
                'reverse_diffusion', 'euler_maruyama', or 'none'
            corrector (str): the name of the Corrector:
                'langevin', 'ald', or 'none'
            N (int): the number of reverse sampling steps.
            corrector_steps (int): number of steps in the Corrector.
            snr (float): the SNR to use for the corrector.

        Returns:
            X_Hat (torch.Tensor): enhanced feature in [Batch, T, F]
        """
        Y = noisy_spectrum.permute(0, 2, 1).unsqueeze(1)
        if sampler_type == "pc":
            sampler = self.get_pc_sampler(
                predictor,
                corrector,
                Y,
                N=N,
                corrector_steps=corrector_steps,
                snr=snr,
                intermediate=False,
                **kwargs
            )
        elif sampler_type == "ode":
            sampler = self.get_ode_sampler(Y, N=N, **kwargs)
        else:
            # raise instead of printing, since `sampler` would otherwise be
            # undefined below
            raise ValueError("{} is not a valid sampler type!".format(sampler_type))
        X_Hat, nfe = sampler()
        X_Hat = X_Hat.squeeze(1).permute(0, 2, 1)
        return X_Hat
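

# A hypothetical end-to-end sketch (configs and shapes are illustrative only;
# real score_model_conf/sde_conf values come from an ESPnet YAML config):
#
#     model = ScoreModel(
#         score_model="ncsnpp", score_model_conf={},
#         sde="ouve", sde_conf={},
#     )
#     ref = torch.randn(2, 100, 257, dtype=torch.cfloat)  # clean STFT [B, T, F]
#     mix = torch.randn(2, 100, 257, dtype=torch.cfloat)  # noisy STFT [B, T, F]
#     loss = model(ref, mix)         # training objective
#     enhanced = model.enhance(mix)  # reverse-SDE sampling, returns [B, T, F]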