# coding=utf-8
# Copyright 2020 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# pylint: skip-file
"""Layers for defining NCSN++.
"""
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from espnet2.enh.layers.ncsnpp_utils import layers, up_or_down_sampling
# Short module-level aliases for building blocks provided by the imported
# ``layers`` module; the classes below refer to them by these names.
conv1x1 = layers.ddpm_conv1x1
conv3x3 = layers.ddpm_conv3x3
NIN = layers.NIN
default_init = layers.default_init
class GaussianFourierProjection(nn.Module):
    """Random Fourier feature embedding for noise levels.

    A frequency vector ``W`` of length ``embedding_size`` is drawn once from a
    standard normal distribution and multiplied by ``scale``. The forward pass
    returns the sines and cosines of ``x * W * 2 * pi``.

    Args:
        embedding_size: Number of random frequencies; the output has twice
            this many features on its last axis.
        scale: Multiplier applied to the sampled frequencies.
    """

    def __init__(self, embedding_size=256, scale=1.0):
        super().__init__()
        frequencies = torch.randn(embedding_size) * scale
        # Registered as a Parameter with gradients disabled, so it is never
        # updated by an optimizer.
        self.W = nn.Parameter(frequencies, requires_grad=False)

    def forward(self, x):
        """Return ``sin`` and ``cos`` features concatenated on the last axis.

        A new axis is inserted after the first axis of ``x`` and broadcast
        against the frequency vector.
        """
        angles = x[:, None] * self.W[None, :] * 2 * np.pi
        sin_part = torch.sin(angles)
        cos_part = torch.cos(angles)
        return torch.cat((sin_part, cos_part), dim=-1)
class Combine(nn.Module):
    """Merge a projected skip-connection tensor with a second tensor.

    The first input is passed through ``Conv_0`` (built by ``conv1x1`` from
    ``dim1`` to ``dim2``) and then either concatenated with the second input
    along axis 1 (``method="cat"``) or added to it (``method="sum"``).
    """

    def __init__(self, dim1, dim2, method="cat"):
        super().__init__()
        self.Conv_0 = conv1x1(dim1, dim2)
        self.method = method

    def forward(self, x, y):
        """Project ``x`` and combine it with ``y``.

        Raises:
            ValueError: If ``self.method`` is neither ``"cat"`` nor ``"sum"``.
        """
        projected = self.Conv_0(x)
        mode = self.method
        if mode == "sum":
            return projected + y
        if mode == "cat":
            return torch.cat((projected, y), dim=1)
        raise ValueError(f"Method {self.method} not recognized.")
class AttnBlockpp(nn.Module):
    """Residual self-attention block over spatial positions (from DDPM).

    The input is group-normalised and projected to queries, keys and values
    by three ``NIN`` layers. Attention scores between every pair of spatial
    positions are obtained by contracting queries and keys over the channel
    axis, scaled by ``C ** -0.5`` and soft-maxed over all key positions. The
    attended values go through a fourth ``NIN`` layer and are added back to
    the input.

    Args:
        channels: Number of input (and output) channels.
        skip_rescale: If true, divide the residual sum by ``sqrt(2)``.
        init_scale: Forwarded to the output ``NIN`` layer as ``init_scale``.
    """

    def __init__(self, channels, skip_rescale=False, init_scale=0.0):
        super().__init__()
        group_count = min(channels // 4, 32)
        self.GroupNorm_0 = nn.GroupNorm(
            num_groups=group_count, num_channels=channels, eps=1e-6
        )
        # Query, key and value projections, followed by the output projection.
        self.NIN_0 = NIN(channels, channels)
        self.NIN_1 = NIN(channels, channels)
        self.NIN_2 = NIN(channels, channels)
        self.NIN_3 = NIN(channels, channels, init_scale=init_scale)
        self.skip_rescale = skip_rescale

    def forward(self, x):
        """Apply attention to a 4-D input and return a same-shaped tensor."""
        B, C, H, W = x.shape
        normed = self.GroupNorm_0(x)
        query = self.NIN_0(normed)
        key = self.NIN_1(normed)
        value = self.NIN_2(normed)

        # scores[b, h, w, i, j]: similarity of query position (h, w) with key
        # position (i, j), summed over channels.
        scores = torch.einsum("bchw,bcij->bhwij", query, key) * (int(C) ** (-0.5))
        # Flatten the key grid so the softmax normalises over all key
        # positions at once, then restore the 2-D key layout.
        weights = F.softmax(scores.reshape(B, H, W, H * W), dim=-1)
        weights = weights.reshape(B, H, W, H, W)

        attended = torch.einsum("bhwij,bcij->bchw", weights, value)
        residual = x + self.NIN_3(attended)
        if self.skip_rescale:
            return residual / np.sqrt(2.0)
        return residual
class Upsample(nn.Module):
    """Upsample a 4-D feature map, optionally followed by / fused with a conv.

    Four configurations are supported:

    * ``fir=False, with_conv=False``: nearest-neighbour interpolation to twice
      the height and width.
    * ``fir=False, with_conv=True``: the same interpolation followed by
      ``Conv_0`` (built by ``conv3x3``).
    * ``fir=True, with_conv=False``: ``up_or_down_sampling.upsample_2d`` with
      ``fir_kernel`` and ``factor=2``.
    * ``fir=True, with_conv=True``: ``up_or_down_sampling.Conv2d`` constructed
      with ``up=True`` and ``resample_kernel=fir_kernel``.

    Args:
        in_ch: Number of input channels (only used when ``with_conv`` is true).
        out_ch: Number of output channels; falls back to ``in_ch`` when falsy.
        with_conv: Whether a convolution is applied.
        fir: Whether the FIR-kernel based resampling helpers are used.
        fir_kernel: Kernel handed to the FIR resampling helpers.
    """

    def __init__(
        self,
        in_ch=None,
        out_ch=None,
        with_conv=False,
        fir=False,
        fir_kernel=(1, 3, 3, 1),
    ):
        super().__init__()
        out_ch = out_ch if out_ch else in_ch
        # A convolution layer only exists when ``with_conv`` is requested; its
        # attribute name depends on whether the FIR variant is used.
        if with_conv:
            if fir:
                self.Conv2d_0 = up_or_down_sampling.Conv2d(
                    in_ch,
                    out_ch,
                    kernel=3,
                    up=True,
                    resample_kernel=fir_kernel,
                    use_bias=True,
                    kernel_init=default_init(),
                )
            else:
                self.Conv_0 = conv3x3(in_ch, out_ch)
        self.fir = fir
        self.with_conv = with_conv
        self.fir_kernel = fir_kernel
        self.out_ch = out_ch

    def forward(self, x):
        """Return the upsampled version of the 4-D tensor ``x``."""
        # Unpacking four values also rejects inputs that are not 4-D.
        _, _, H, W = x.shape
        if not self.fir:
            # ``mode`` must be given by keyword: the third positional
            # parameter of ``F.interpolate`` is ``scale_factor``, and passing
            # it together with ``size`` raises a ValueError.
            h = F.interpolate(x, size=(H * 2, W * 2), mode="nearest")
            if self.with_conv:
                h = self.Conv_0(h)
        elif self.with_conv:
            h = self.Conv2d_0(x)
        else:
            h = up_or_down_sampling.upsample_2d(x, self.fir_kernel, factor=2)
        return h
class Downsample(nn.Module):
    """Downsample a 4-D feature map, optionally with a convolution.

    Four configurations are supported:

    * ``fir=False, with_conv=False``: 2x2 average pooling with stride 2.
    * ``fir=False, with_conv=True``: pad one element at the end of each of
      the last two axes, then apply ``Conv_0`` (``conv3x3`` built with
      ``stride=2, padding=0``).
    * ``fir=True, with_conv=False``: ``up_or_down_sampling.downsample_2d``
      with ``fir_kernel`` and ``factor=2``.
    * ``fir=True, with_conv=True``: ``up_or_down_sampling.Conv2d`` constructed
      with ``down=True`` and ``resample_kernel=fir_kernel``.

    Args:
        in_ch: Number of input channels (only used when ``with_conv`` is true).
        out_ch: Number of output channels; falls back to ``in_ch`` when falsy.
        with_conv: Whether a convolution is applied.
        fir: Whether the FIR-kernel based resampling helpers are used.
        fir_kernel: Kernel handed to the FIR resampling helpers.
    """

    def __init__(
        self,
        in_ch=None,
        out_ch=None,
        with_conv=False,
        fir=False,
        fir_kernel=(1, 3, 3, 1),
    ):
        super().__init__()
        out_ch = out_ch or in_ch
        # Only the convolutional variants own a layer.
        if with_conv:
            if fir:
                self.Conv2d_0 = up_or_down_sampling.Conv2d(
                    in_ch,
                    out_ch,
                    kernel=3,
                    down=True,
                    resample_kernel=fir_kernel,
                    use_bias=True,
                    kernel_init=default_init(),
                )
            else:
                self.Conv_0 = conv3x3(in_ch, out_ch, stride=2, padding=0)
        self.out_ch = out_ch
        self.with_conv = with_conv
        self.fir = fir
        self.fir_kernel = fir_kernel

    def forward(self, x):
        """Return the downsampled version of the 4-D tensor ``x``."""
        # Unpacking four values rejects inputs that are not 4-D.
        _batch, _channels, _height, _width = x.shape
        if self.fir:
            if self.with_conv:
                return self.Conv2d_0(x)
            return up_or_down_sampling.downsample_2d(x, self.fir_kernel, factor=2)
        if self.with_conv:
            # Pad on the right and bottom only before the unpadded strided conv.
            padded = F.pad(x, (0, 1, 0, 1))
            return self.Conv_0(padded)
        return F.avg_pool2d(x, 2, stride=2)
class ResnetBlockDDPMpp(nn.Module):
    """Residual block adapted from DDPM.

    Main path: GroupNorm -> activation -> 3x3 conv -> (optional additive
    bias from the time embedding) -> GroupNorm -> activation -> dropout ->
    3x3 conv. The input is added to the result; when the channel count
    changes, the input first passes through ``Conv_2`` (if ``conv_shortcut``)
    or ``NIN_0``.

    Args:
        act: Activation callable applied to tensors.
        in_ch: Number of input channels.
        out_ch: Number of output channels; falls back to ``in_ch`` when falsy.
        temb_dim: Size of the time embedding; ``Dense_0`` is only created
            when this is not ``None``.
        conv_shortcut: Use a 3x3 conv instead of ``NIN`` on the shortcut.
        dropout: Dropout probability.
        skip_rescale: If true, divide the residual sum by ``sqrt(2)``.
        init_scale: Forwarded to the second conv as ``init_scale``.
    """

    def __init__(
        self,
        act,
        in_ch,
        out_ch=None,
        temb_dim=None,
        conv_shortcut=False,
        dropout=0.1,
        skip_rescale=False,
        init_scale=0.0,
    ):
        super().__init__()
        out_ch = out_ch or in_ch
        self.act = act
        self.out_ch = out_ch
        self.conv_shortcut = conv_shortcut
        self.skip_rescale = skip_rescale

        # Submodules are created in the same order as before so that
        # parameter ordering and random-number consumption do not change.
        self.GroupNorm_0 = nn.GroupNorm(
            num_groups=min(in_ch // 4, 32), num_channels=in_ch, eps=1e-6
        )
        self.Conv_0 = conv3x3(in_ch, out_ch)
        if temb_dim is not None:
            dense = nn.Linear(temb_dim, out_ch)
            # Replace the default weight init and zero the bias.
            dense.weight.data = default_init()(dense.weight.data.shape)
            nn.init.zeros_(dense.bias)
            self.Dense_0 = dense
        self.GroupNorm_1 = nn.GroupNorm(
            num_groups=min(out_ch // 4, 32), num_channels=out_ch, eps=1e-6
        )
        self.Dropout_0 = nn.Dropout(dropout)
        self.Conv_1 = conv3x3(out_ch, out_ch, init_scale=init_scale)
        if in_ch != out_ch:
            if conv_shortcut:
                self.Conv_2 = conv3x3(in_ch, out_ch)
            else:
                self.NIN_0 = NIN(in_ch, out_ch)

    def forward(self, x, temb=None):
        """Run the block on ``x``, optionally conditioned on ``temb``."""
        h = self.Conv_0(self.act(self.GroupNorm_0(x)))
        if temb is not None:
            # One bias value per (sample, channel), broadcast over the two
            # trailing axes.
            temb_bias = self.Dense_0(self.act(temb))
            h += temb_bias[:, :, None, None]
        h = self.act(self.GroupNorm_1(h))
        h = self.Conv_1(self.Dropout_0(h))

        if x.shape[1] != self.out_ch:
            shortcut = self.Conv_2 if self.conv_shortcut else self.NIN_0
            x = shortcut(x)

        residual = x + h
        if self.skip_rescale:
            return residual / np.sqrt(2.0)
        return residual
class ResnetBlockBigGANpp(nn.Module):
    """Residual block with optional 2x up- or down-sampling.

    Main path: GroupNorm -> activation -> (optional resampling) -> 3x3 conv
    -> (optional additive bias from the time embedding) -> GroupNorm ->
    activation -> dropout -> 3x3 conv. The shortcut receives the same
    resampling and, when the channel count changes or resampling is active,
    a ``conv1x1`` projection (``Conv_2``).

    Args:
        act: Activation callable applied to tensors.
        in_ch: Number of input channels.
        out_ch: Number of output channels; falls back to ``in_ch`` when falsy.
        temb_dim: Size of the time embedding; ``Dense_0`` is only created
            when this is not ``None``.
        up: Upsample inside the block (takes precedence over ``down``).
        down: Downsample inside the block.
        dropout: Dropout probability.
        fir: Use the FIR-kernel resampling helpers instead of the naive ones.
        fir_kernel: Kernel handed to the FIR resampling helpers.
        skip_rescale: If true, divide the residual sum by ``sqrt(2)``.
        init_scale: Forwarded to the second conv as ``init_scale``.
    """

    def __init__(
        self,
        act,
        in_ch,
        out_ch=None,
        temb_dim=None,
        up=False,
        down=False,
        dropout=0.1,
        fir=False,
        fir_kernel=(1, 3, 3, 1),
        skip_rescale=True,
        init_scale=0.0,
    ):
        super().__init__()
        out_ch = out_ch or in_ch
        self.act = act
        self.in_ch = in_ch
        self.out_ch = out_ch
        self.up = up
        self.down = down
        self.fir = fir
        self.fir_kernel = fir_kernel
        self.skip_rescale = skip_rescale

        # Submodules are created in the same order as before so that
        # parameter ordering and random-number consumption do not change.
        self.GroupNorm_0 = nn.GroupNorm(
            num_groups=min(in_ch // 4, 32), num_channels=in_ch, eps=1e-6
        )
        self.Conv_0 = conv3x3(in_ch, out_ch)
        if temb_dim is not None:
            dense = nn.Linear(temb_dim, out_ch)
            # Replace the default weight init and zero the bias.
            dense.weight.data = default_init()(dense.weight.shape)
            nn.init.zeros_(dense.bias)
            self.Dense_0 = dense
        self.GroupNorm_1 = nn.GroupNorm(
            num_groups=min(out_ch // 4, 32), num_channels=out_ch, eps=1e-6
        )
        self.Dropout_0 = nn.Dropout(dropout)
        self.Conv_1 = conv3x3(out_ch, out_ch, init_scale=init_scale)
        if in_ch != out_ch or up or down:
            self.Conv_2 = conv1x1(in_ch, out_ch)

    def _resample(self, tensor):
        """Apply the configured up/down-sampling, or return the input as is."""
        if self.up:
            if self.fir:
                return up_or_down_sampling.upsample_2d(
                    tensor, self.fir_kernel, factor=2
                )
            return up_or_down_sampling.naive_upsample_2d(tensor, factor=2)
        if self.down:
            if self.fir:
                return up_or_down_sampling.downsample_2d(
                    tensor, self.fir_kernel, factor=2
                )
            return up_or_down_sampling.naive_downsample_2d(tensor, factor=2)
        return tensor

    def forward(self, x, temb=None):
        """Run the block on ``x``, optionally conditioned on ``temb``."""
        h = self.act(self.GroupNorm_0(x))
        # Both the main path and the shortcut are resampled the same way.
        h = self._resample(h)
        x = self._resample(x)
        h = self.Conv_0(h)
        # Add bias to each feature map conditioned on the time embedding.
        if temb is not None:
            temb_bias = self.Dense_0(self.act(temb))
            h += temb_bias[:, :, None, None]
        h = self.act(self.GroupNorm_1(h))
        h = self.Conv_1(self.Dropout_0(h))

        needs_projection = self.in_ch != self.out_ch or self.up or self.down
        if needs_projection:
            x = self.Conv_2(x)

        residual = x + h
        if self.skip_rescale:
            return residual / np.sqrt(2.0)
        return residual