"""Implementations of different types of residual functions."""

import torch
from torch import nn

class Residual(nn.Module):
    """Residual connection with constant affine weights.

    Combines the skip input ``x`` and the branch output ``y`` as
    ``alpha * x + beta * y``. Can simulate standard residual
    (``alpha=1, beta=1``), no residual (``alpha=0``), and "constant gates".

    Args:
        i_layer: Index of the layer this connection belongs to. Stored only;
            subclasses use it for depth-dependent scaling.
        d_input: Width of the skip input ``x``.
        d_model: Width of the branch output ``y``; also reported as
            ``d_output``.
        alpha: Constant multiplier on ``x``. A falsy value (``0.0``) drops
            the skip path entirely.
        beta: Constant multiplier on ``y``.
    """

    def __init__(self, i_layer, d_input, d_model, alpha=1.0, beta=1.0):
        super().__init__()
        # x and y can only be summed when their widths agree, unless the
        # skip path is switched off with alpha == 0.
        assert (d_input == d_model) or alpha == 0.0, (
            "d_input must equal d_model unless alpha == 0.0"
        )
        self.i_layer = i_layer
        self.d_input = d_input
        self.d_model = d_model
        self.alpha = alpha
        self.beta = beta

    @property
    def d_output(self):
        """Width of the tensor returned by ``forward`` (equal to ``d_model``)."""
        return self.d_model

    def forward(self, x, y, transposed=False):
        """Return ``alpha * x + beta * y``.

        Args:
            x: Skip-connection input.
            y: Output of the wrapped branch.
            transposed: Unused here; accepted so all residual variants share
                one call signature.

        Returns:
            The combined tensor. When ``alpha`` is falsy, ``x`` is never
            touched, and when additionally ``beta == 1.0`` the very same
            ``y`` object is returned.
        """
        # Skip the multiplications for the common unit/zero weights.
        if self.beta != 1.0:
            y = self.beta * y
        if self.alpha:
            return self.alpha * x + y
        return y
class Affine(Residual):
    """Residual connection with learnable scalar multipliers on the main branch.

    Computes ``alpha * x + c * y`` where ``c`` is a learnable parameter
    initialized to ``beta * i_layer ** (-gamma)``.

    Args:
        *args: Positional arguments forwarded to ``Residual``
            (``i_layer, d_input, d_model``).
        scalar: If true, learn a single scalar multiplier; otherwise learn
            one multiplier per input dimension (``d_input`` of them).
        gamma: Depth exponent used for the initial value of the multiplier.
        **kwargs: Keyword arguments forwarded to ``Residual``
            (``alpha``, ``beta``).
    """

    def __init__(self, *args, scalar=True, gamma=0.0, **kwargs):
        super().__init__(*args, **kwargs)
        self.scalar = scalar
        self.gamma = gamma

        # Initial multiplier: beta scaled down with depth by i_layer ** -gamma.
        init_value = self.beta * self.i_layer ** (-self.gamma)
        n_weights = 1 if self.scalar else self.d_input
        self.affine = nn.Parameter(init_value * torch.ones(n_weights))

    def forward(self, x, y, transposed=False):
        """Return ``alpha * x + c * y`` with the learnable multiplier ``c``.

        Args:
            x: Skip-connection input.
            y: Output of the wrapped branch.
            transposed: If true, a trailing axis is appended to ``c`` so it
                broadcasts against the second-to-last axis of ``y`` instead
                of the last one.

        Returns:
            The combined tensor.
        """
        weight = self.affine
        if transposed:
            weight = weight.unsqueeze(-1)
        return self.alpha * x + weight * y
class Feedforward(Residual):
    """Residual variant with the skip path disabled.

    Fixes ``alpha=0.0`` and ``beta=1.0`` on the parent class, so the parent
    ``forward`` returns the branch output ``y`` unchanged.

    Args:
        *args: Positional arguments forwarded to ``Residual``
            (``i_layer, d_input, d_model``).
    """

    def __init__(self, *args):
        super().__init__(*args, alpha=0.0, beta=1.0)
class Highway(Residual):
    """Gated residual connection mixing ``x`` and ``y`` with a learned gate.

    With ``r = sigmoid(Wx(x) + Wy(y))`` the output is
    ``scaling_correction * (1 - r) * x + r * Wy(y)``.

    Args:
        *args: Positional arguments forwarded to ``Residual``
            (``i_layer, d_input, d_model``).
        scaling_correction: If true, the ``x`` term is multiplied by the
            constant 1.732; otherwise by 1.0.
        elemwise: If true, ``Wy`` is a learnable per-feature vector applied
            by element-wise multiplication; otherwise it is a full
            ``nn.Linear`` layer.
    """

    def __init__(self, *args, scaling_correction=False, elemwise=False):
        super().__init__(*args)
        # NOTE(review): 1.732 looks like an approximation of sqrt(3) -- confirm.
        self.scaling_correction = 1.732 if scaling_correction else 1.0
        self.elemwise = elemwise
        self.Wx = nn.Linear(self.d_input, self.d_input)
        if self.elemwise:
            self.Wy = nn.Parameter(torch.randn(self.d_input))
        else:
            self.Wy = nn.Linear(self.d_input, self.d_input)

    def forward(self, x, y, transposed=False):
        """Return the gated combination of ``x`` and ``y``.

        Args:
            x: Skip-connection input.
            y: Output of the wrapped branch.
            transposed: If true, the feature axis is taken to be the
                second-to-last one (the same layout ``Affine.forward``
                broadcasts against); the inputs are swapped to
                features-last for the gate and the result is swapped back.

        Returns:
            The gated tensor, in the same layout as the inputs.
        """
        if transposed:
            # nn.Linear and the element-wise weight both act on the last
            # axis, so bring the feature axis there first.
            x = x.transpose(-1, -2)
            y = y.transpose(-1, -2)

        y = self.Wy * y if self.elemwise else self.Wy(y)
        gate = torch.sigmoid(self.Wx(x) + y)
        z = self.scaling_correction * (1.0 - gate) * x + gate * y

        if transposed:
            z = z.transpose(-1, -2)
        return z
class DecayResidual(Residual):
    """Residual connection that can decay the linear combination depending on depth.

    Uses ``beta = i_layer ** (-power)`` as the weight on ``y`` and derives
    the weight on ``x`` from it; the constant ``alpha``/``beta`` stored by
    the parent class are not used in ``forward``.

    Args:
        *args: Positional arguments forwarded to ``Residual``
            (``i_layer, d_input, d_model``).
        power: Exponent of the depth decay.
        l2: If true, ``alpha = sqrt(1 - beta**2)`` (squared weights sum to
            one); otherwise ``alpha = 1 - beta`` (weights sum to one).
    """

    def __init__(self, *args, power=0.5, l2=True):
        super().__init__(*args)
        self.power = power
        self.l2 = l2

    def forward(self, x, y, transposed=False):
        """Return ``alpha * x + beta * y`` with depth-dependent weights.

        Args:
            x: Skip-connection input.
            y: Output of the wrapped branch.
            transposed: Unused here; accepted so all residual variants share
                one call signature.

        Returns:
            The combined tensor.
        """
        # NOTE(review): i_layer == 0 with power > 0 raises ZeroDivisionError;
        # presumably layers are indexed from 1 -- confirm against the caller.
        beta = self.i_layer ** (-self.power)
        alpha = (1.0 - beta**2) ** 0.5 if self.l2 else 1.0 - beta
        return alpha * x + beta * y
# Maps configuration names to residual classes. Several names are aliases
# for the same class.
registry = {
    # Single-letter codes.
    "F": Feedforward,
    "N": Feedforward,
    "R": Residual,
    "H": Highway,
    "D": DecayResidual,
    "A": Affine,
    # Spelled-out names.
    "none": Feedforward,
    "ff": Feedforward,
    "feedforward": Feedforward,
    "residual": Residual,
    "highway": Highway,
    "decay": DecayResidual,
    "affine": Affine,
}