espnet2.gan_tts.jets.length_regulator.GaussianUpsampling
class espnet2.gan_tts.jets.length_regulator.GaussianUpsampling(delta=0.1)
Bases: torch.nn.Module
Gaussian upsampling with a fixed temperature, as described in https://arxiv.org/abs/2010.04301.
Initializes the module with the fixed temperature delta (default: 0.1).
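As a reading aid, the upsampling can be summarized as follows. This is a sketch that assumes the class follows the ESPnet source as we read it; batching and masking are omitted. Let h_i be the hidden state of token i, d_i its duration in frames, N = T_text, and u_t the output for frame t = 0, ..., T_feats − 1. Each token gets a center c_i, and each frame is a softmax-weighted mix of all token states:

$$
c_i = \sum_{j=1}^{i} d_j - \frac{d_i}{2}, \qquad
w_{t,i} = \frac{\exp\!\left(-\delta\,(t - c_i)^2\right)}{\sum_{k=1}^{N} \exp\!\left(-\delta\,(t - c_k)^2\right)}, \qquad
u_t = \sum_{i=1}^{N} w_{t,i}\, h_i
$$

Here δ is the constructor argument delta. It is a single constant shared by all tokens rather than a learned per-token width, so it plays the role of 1/(2σ²) for one common σ. Larger δ gives sharper alignments that approach plain repetition of each token state.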
forward(hs, ds, h_masks=None, d_masks=None)
Upsample hidden states according to durations.
- Parameters:
  - hs (Tensor) – Batched hidden states to be expanded (B, T_text, adim).
  - ds (Tensor) – Batched token durations in frames (B, T_text).
  - h_masks (Tensor, optional) – Mask tensor over output frames (B, T_feats). Defaults to None.
  - d_masks (Tensor, optional) – Mask tensor over input tokens (B, T_text). Defaults to None.
- Returns: Expanded hidden states (B, T_feats, adim).
- Return type: Tensor
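To make the shapes in forward concrete, here is a minimal, dependency-free sketch for a single unbatched sequence. The function name gaussian_upsample is hypothetical and is not part of ESPnet. It assumes the computation described for this class (token centers from cumulative durations, softmax over negative scaled squared distances) and it ignores h_masks and d_masks. Lists stand in for the (T_text, adim) and (T_text,) tensors.

```python
import math
from typing import List


def gaussian_upsample(
    hs: List[List[float]], ds: List[float], delta: float = 0.1
) -> List[List[float]]:
    """Hypothetical single-sequence sketch of Gaussian upsampling.

    hs: T_text hidden vectors, each of length adim.
    ds: T_text durations in frames.
    Returns T_feats = int(sum(ds)) vectors of length adim (masks are not handled).
    """
    t_feats = int(sum(ds))  # output length; fractional totals are truncated
    # Token centers: cumulative duration up to token i minus half of d_i.
    centers = []
    total = 0.0
    for d in ds:
        total += d
        centers.append(total - d / 2)

    adim = len(hs[0])
    out = []
    for t in range(t_feats):
        # Energies -delta * (t - c_i)^2; subtract the max for a stable softmax.
        energies = [-delta * (t - c) ** 2 for c in centers]
        m = max(energies)
        weights = [math.exp(e - m) for e in energies]
        z = sum(weights)
        weights = [w / z for w in weights]
        # Each output frame is a convex combination of the token hidden states.
        out.append([sum(w * h[k] for w, h in zip(weights, hs)) for k in range(adim)])
    return out


# Two tokens with one-hot states, lasting 2 and 3 frames -> 5 output frames.
frames = gaussian_upsample([[1.0, 0.0], [0.0, 1.0]], [2, 3])
```

With one-hot token states, each output row is exactly the attention weights for that frame, so early frames lean toward the first token and later frames toward the second. Raising delta (for example to 10.0) makes the rows nearly one-hot, which mirrors the sharpening effect of a larger temperature value.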