espnet.nets.pytorch_backend.tacotron2.decoder.ZoneOutCell
class espnet.nets.pytorch_backend.tacotron2.decoder.ZoneOutCell(cell, zoneout_rate=0.1)
Bases: torch.nn.Module
ZoneOut Cell module.
This module implements zoneout as described in "Zoneout: Regularizing RNNs by Randomly Preserving Hidden Activations". The code is adapted from eladhoffer/seq2seq.pytorch.
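For reference, the zoneout update from the cited paper can be written as follows, with p = zoneout_rate and \tilde{h}_t the state proposed by the wrapped cell. For an LSTM, the same rule is applied to the cell state c_t. This is the paper's formulation; this page does not spell out the implementation, so treating it as an exact description of this class is an assumption.

```latex
\begin{aligned}
\tilde{h}_t &= \mathrm{cell}(x_t, h_{t-1}) \\
\text{training:}\quad h_t &= d_t \odot h_{t-1} + (1 - d_t) \odot \tilde{h}_t,
  \qquad d_t \sim \mathrm{Bernoulli}(p) \text{ element-wise} \\
\text{inference:}\quad h_t &= p \, h_{t-1} + (1 - p) \, \tilde{h}_t
\end{aligned}
```

A mask entry of 1 preserves the previous value; at inference the random mask is replaced by its expectation p, analogous to dropout.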
Examples
>>> lstm = torch.nn.LSTMCell(16, 32)
>>> lstm = ZoneOutCell(lstm, 0.5)
Initialize the zoneout cell module.
- Parameters:
- cell (torch.nn.Module) – PyTorch recurrent cell module, e.g., torch.nn.LSTMCell.
- zoneout_rate (float, optional) – Zoneout probability, i.e., the probability that each unit keeps its previous value. Must be in the range 0.0 to 1.0.
forward(inputs, hidden)
Calculate forward propagation for a single time step.
- Parameters:
- inputs (Tensor) – Batch of input tensor (B, input_size).
- hidden (tuple) – Pair of tensors:
- Tensor: Batch of initial hidden states (B, hidden_size).
- Tensor: Batch of initial cell states (B, hidden_size).
- Returns:
- Tensor: Batch of next hidden states (B, hidden_size).
- Tensor: Batch of next cell states (B, hidden_size).
- Return type: tuple
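The sketch below is a minimal, hypothetical re-implementation that follows the paper's zoneout rule so it runs without ESPnet installed; `ZoneOutSketch` is an illustrative name, not part of ESPnet, and the real class may differ in details. It also shows the calling pattern for forward: one time step per call, with the (hidden, cell) tuple fed back in. It assumes an LSTM-style cell whose state is a tuple.

```python
import torch


class ZoneOutSketch(torch.nn.Module):
    """Hypothetical minimal zoneout wrapper (not the ESPnet source)."""

    def __init__(self, cell: torch.nn.Module, zoneout_rate: float = 0.1):
        super().__init__()
        if not 0.0 <= zoneout_rate <= 1.0:
            raise ValueError("zoneout_rate must be in [0.0, 1.0]")
        self.cell = cell
        self.zoneout_rate = zoneout_rate

    def _zoneout(self, prev: torch.Tensor, new: torch.Tensor) -> torch.Tensor:
        p = self.zoneout_rate
        if self.training:
            # Per-unit Bernoulli mask: 1 keeps the previous value, 0 takes the new one.
            mask = torch.bernoulli(torch.full_like(prev, p))
            return mask * prev + (1.0 - mask) * new
        # Evaluation: replace the random mask by its expectation p.
        return p * prev + (1.0 - p) * new

    def forward(self, inputs, hidden):
        # Assumes an LSTM-style cell: hidden and the result are (h, c) tuples.
        next_hidden = self.cell(inputs, hidden)
        return tuple(self._zoneout(h, n) for h, n in zip(hidden, next_hidden))


# Unroll over a short sequence: B=4, T=5, input_size=16, hidden_size=32.
torch.manual_seed(0)
cell = ZoneOutSketch(torch.nn.LSTMCell(16, 32), zoneout_rate=0.5)
x = torch.randn(4, 5, 16)
h, c = torch.zeros(4, 32), torch.zeros(4, 32)
for t in range(x.size(1)):
    h, c = cell(x[:, t], (h, c))
print(h.shape, c.shape)  # torch.Size([4, 32]) torch.Size([4, 32])
```

With ESPnet installed, the same loop should work with ZoneOutCell in place of ZoneOutSketch, since its forward takes and returns the (hidden, cell) tuple per the signature documented above. Remember to call `.eval()` at inference so the expected-value rule is used instead of random masks.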