Source code for espnet.nets.pytorch_backend.transformer.initializer

#!/usr/bin/env python3
# -*- coding: utf-8 -*-

# Copyright 2019 Shigeki Karita
#  Apache 2.0  (http://www.apache.org/licenses/LICENSE-2.0)

"""Parameter initialization."""

import torch

from espnet.nets.pytorch_backend.transformer.layer_norm import LayerNorm


def initialize(model, init_type="pytorch"):
    """Initialize Transformer module.

    With ``init_type == "pytorch"`` the model is returned untouched, keeping
    PyTorch's per-module default initialization. For any other supported
    type:

    * every parameter with more than one dimension is re-initialized with
      the requested scheme (Kaiming variants assume a ReLU nonlinearity);
    * every one-dimensional parameter is set to zero;
    * ``torch.nn.Embedding`` and layer-norm modules are then reset with
      their own ``reset_parameters()``.

    :param torch.nn.Module model: transformer instance
    :param str init_type: initialization type; one of "pytorch",
        "xavier_uniform", "xavier_normal", "kaiming_uniform",
        "kaiming_normal"
    :raises ValueError: if ``init_type`` is not a supported type; in that
        case no parameter of ``model`` is modified
    """
    if init_type == "pytorch":
        return

    weight_initializers = {
        "xavier_uniform": torch.nn.init.xavier_uniform_,
        "xavier_normal": torch.nn.init.xavier_normal_,
        "kaiming_uniform": lambda w: torch.nn.init.kaiming_uniform_(
            w, nonlinearity="relu"
        ),
        "kaiming_normal": lambda w: torch.nn.init.kaiming_normal_(
            w, nonlinearity="relu"
        ),
    }
    # Resolve the initializer before touching any parameter, so an unknown
    # init_type is rejected even for models without multi-dimensional
    # parameters and never leaves the model partially re-initialized.
    try:
        init_weight = weight_initializers[init_type]
    except KeyError:
        raise ValueError(f"Unknown initialization: {init_type}") from None

    for p in model.parameters():
        if p.dim() > 1:
            # weight matrices / kernels
            init_weight(p.data)
        elif p.dim() == 1:
            # biases (and, as a side effect, norm gains; restored below)
            p.data.zero_()

    # The loop above also overwrote the parameters of these modules (e.g. it
    # zeroed layer-norm gains), so restore each module's own defaults.
    # torch.nn.LayerNorm is listed alongside the ESPnet LayerNorm so plain
    # PyTorch layer norms are not left with a zero gain.
    for m in model.modules():
        if isinstance(m, (torch.nn.Embedding, torch.nn.LayerNorm, LayerNorm)):
            m.reset_parameters()