#!/usr/bin/env python3
# -*- coding: utf-8 -*-

# Copyright 2019 Shigeki Karita
#  Apache 2.0  (http://www.apache.org/licenses/LICENSE-2.0)

"""Repeat the same layer definition."""

import torch

class MultiSequential(torch.nn.Sequential):
    """Multi-input multi-output torch.nn.Sequential.

    Each child module is called with the unpacked output of the previous
    child (``m(*args)``), so every child must accept positional inputs and
    return an iterable (typically a tuple) that can be unpacked into the next
    child's inputs.
    """

    def __init__(self, *args, layer_drop_rate=0.0):
        """Initialize MultiSequential with layer_drop.

        Args:
            *args: Modules to chain, forwarded to ``torch.nn.Sequential``.
            layer_drop_rate (float): Probability of dropping out each fn (layer).
                Only applied while the module is in training mode.

        """
        super().__init__(*args)
        self.layer_drop_rate = layer_drop_rate

    def forward(self, *args):
        """Repeat.

        Feed ``args`` through each child module in order. In training mode,
        each child is skipped with probability ``layer_drop_rate``; a skipped
        child's inputs pass through unchanged to the next child. In eval
        mode every child is applied.

        Args:
            *args: Positional inputs for the first child module.

        Returns:
            The output of the last applied child module, or ``args`` (as a
            tuple) when no child was applied.

        """
        # One uniform [0, 1) sample per layer: a layer is kept when its
        # sample is >= layer_drop_rate, i.e. dropped with that probability.
        _probs = torch.empty(len(self)).uniform_()
        for idx, m in enumerate(self):
            # Layer drop is a training-time regularizer only; never skip
            # layers during evaluation/inference.
            if not self.training or (_probs[idx] >= self.layer_drop_rate):
                args = m(*args)
        return args
def repeat(N, fn, layer_drop_rate=0.0):
    """Repeat module N times.

    Args:
        N (int): Number of repeat time.
        fn (Callable): Function to generate module. It is called once per
            layer with the layer index ``n`` in ``range(N)``, so each layer
            is a freshly built module.
        layer_drop_rate (float): Probability of dropping out each fn (layer).

    Returns:
        MultiSequential: Repeated model instance.

    """
    layers = [fn(n) for n in range(N)]
    return MultiSequential(*layers, layer_drop_rate=layer_drop_rate)