#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# Copyright 2019 Shigeki Karita
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
"""Repeat the same layer definition."""
import torch
class MultiSequential(torch.nn.Sequential):
    """Multi-input multi-output torch.nn.Sequential.

    Each contained module is called with the unpacked outputs of the
    previous one (``args = m(*args)``), so every module is expected to
    return an iterable (typically a tuple) that the next module can unpack.
    Optionally supports layer drop: during training, each module is skipped
    independently with probability ``layer_drop_rate``.
    """

    def __init__(self, *args, layer_drop_rate=0.0):
        """Initialize MultiSequential with layer_drop.

        Args:
            *args (torch.nn.Module): Modules to apply in order.
            layer_drop_rate (float): Probability of dropping out each fn (layer)
                during training. Ignored in eval mode.

        """
        super().__init__(*args)
        self.layer_drop_rate = layer_drop_rate

    def forward(self, *args):
        """Apply the contained modules in order, threading outputs through.

        Args:
            *args: Inputs passed (unpacked) to the first module.

        Returns:
            The output of the last applied module, or the input ``args``
            tuple unchanged if no module was applied.

        """
        n_layers = len(self)
        if self.training and self.layer_drop_rate > 0:
            # uniform_ samples from [0, 1), so a layer is kept with
            # probability 1 - layer_drop_rate. The draw is done only when
            # dropping is possible, so eval mode (and rate-0 training) does
            # not advance the global torch RNG.
            probs = torch.empty(n_layers).uniform_()
            keep = (probs >= self.layer_drop_rate).tolist()
        else:
            keep = [True] * n_layers

        for m, use_layer in zip(self, keep):
            if use_layer:
                args = m(*args)
        return args
def repeat(N, fn, layer_drop_rate=0.0):
    """Build a MultiSequential from N modules produced by a factory.

    Args:
        N (int): Number of modules to create.
        fn (Callable): Factory called once per layer with the layer index
            (0, 1, ..., N - 1); each call must return a module.
        layer_drop_rate (float): Probability of dropping out each fn (layer),
            forwarded to MultiSequential.

    Returns:
        MultiSequential: Container holding the N created modules in order.

    """
    layers = list(map(fn, range(N)))
    return MultiSequential(*layers, layer_drop_rate=layer_drop_rate)