#!/usr/bin/env python3
# -*- coding: utf-8 -*-

# Copyright 2019 Shigeki Karita
#  Apache 2.0  (http://www.apache.org/licenses/LICENSE-2.0)

"""Layer normalization module."""

import torch

class LayerNorm(torch.nn.LayerNorm):
    """Layer normalization module.

    Thin wrapper around :class:`torch.nn.LayerNorm` that can normalize over
    an axis other than the last one. The requested axis is swapped into the
    last position, normalized by the parent class, and swapped back.

    Args:
        nout (int): Output dim size (size of the normalized axis).
        dim (int): Dimension to be normalized.
        eps (float): Value added to the denominator for numerical stability.

    """

    def __init__(self, nout, dim=-1, eps=1e-12):
        """Construct a LayerNorm object."""
        super().__init__(nout, eps=eps)
        self.dim = dim

    def forward(self, x):
        """Apply layer normalization.

        Args:
            x (torch.Tensor): Input tensor.

        Returns:
            torch.Tensor: Normalized tensor.

        """
        if self.dim == -1:
            # The parent class already normalizes over the last axis.
            return super().forward(x)
        # Move the target axis to the end, normalize, then restore the layout.
        swapped = x.transpose(self.dim, -1)
        normalized = super().forward(swapped)
        return normalized.transpose(self.dim, -1)