espnet2.gan_tts.jets.alignments.AlignmentModule
class espnet2.gan_tts.jets.alignments.AlignmentModule(adim, odim, cache_prior=True)
Bases: torch.nn.Module
Alignment Learning Framework for parallel TTS models, proposed in "One TTS Alignment To Rule Them All":
https://arxiv.org/abs/2108.10447
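The core of the framework can be illustrated on a single utterance: each acoustic frame is compared with every text token, and a softmax over the text axis turns negative L2 distances into alignment probabilities. The following is a minimal pure-Python sketch under stated assumptions, not the module's implementation. It assumes both inputs have already been encoded into a shared space of the same dimension (the learned encoders are omitted), and the helper name `soft_alignment_log_probs` is hypothetical.

```python
import math
from typing import List, Optional, Sequence

Vector = Sequence[float]


def soft_alignment_log_probs(
    text: Sequence[Vector],
    feats: Sequence[Vector],
    text_mask: Optional[Sequence[bool]] = None,
) -> List[List[float]]:
    """Hypothetical helper: (T_feats, T_text) log alignment probabilities.

    text:  T_text vectors, assumed to be already encoded.
    feats: T_feats vectors, assumed to live in the same space as text.
    text_mask[j] == True marks a padded text position.
    """
    log_probs = []
    for frame in feats:
        scores = []
        for j, token in enumerate(text):
            if text_mask is not None and text_mask[j]:
                scores.append(-math.inf)  # padded token gets zero probability
            else:
                scores.append(-math.dist(frame, token))  # negative L2 distance
        # Numerically stable log-softmax over the text axis.
        peak = max(scores)
        log_z = peak + math.log(sum(math.exp(s - peak) for s in scores))
        log_probs.append([s - log_z for s in scores])
    return log_probs


# Usage: two 2-D text vectors and three feature frames.
text = [[0.0, 0.0], [1.0, 1.0]]
feats = [[0.1, 0.0], [0.9, 1.0], [1.0, 1.1]]
for row in soft_alignment_log_probs(text, feats):
    print([round(math.exp(x), 3) for x in row])
```

Masked text positions receive a log-probability of -inf, so no frame can align to them; at least one position per row must stay unmasked for the softmax to be defined.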
Initialize AlignmentModule.
- Parameters:
  - adim (int) – Dimension of attention (and of the input text embeddings).
  - odim (int) – Dimension of the acoustic features.
  - cache_prior (bool) – Whether to cache the computed beta-binomial prior for reuse.
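The beta-binomial prior encourages a roughly diagonal alignment: early frames favor early tokens and late frames favor late tokens. A minimal sketch of such a log prior follows. It assumes the parameterization alpha = w * i and beta = w * (T_feats - i + 1) for the 1-based frame index i, with n = T_text - 1 so the support covers exactly the T_text positions; ESPnet's exact parameterization and caching may differ. The function names are hypothetical, and `functools.lru_cache` stands in for whatever cache the module uses when cache_prior is enabled.

```python
import math
from functools import lru_cache
from typing import List, Tuple


def _log_beta(a: float, b: float) -> float:
    # log B(a, b) = log Gamma(a) + log Gamma(b) - log Gamma(a + b)
    return math.lgamma(a) + math.lgamma(b) - math.lgamma(a + b)


def beta_binomial_log_prior(t_feats: int, t_text: int, w: float = 1.0) -> List[List[float]]:
    """Hypothetical sketch: (T_feats, T_text) log beta-binomial prior.

    Frame i (1-based) uses alpha = w * i and beta = w * (t_feats - i + 1),
    an assumed parameterization; n = t_text - 1 so k covers 0..t_text-1.
    """
    n = t_text - 1
    prior = []
    for i in range(1, t_feats + 1):
        alpha, beta = w * i, w * (t_feats - i + 1)
        row = []
        for k in range(t_text):
            # log C(n, k) + log B(k + alpha, n - k + beta) - log B(alpha, beta)
            log_choose = math.lgamma(n + 1) - math.lgamma(k + 1) - math.lgamma(n - k + 1)
            row.append(log_choose + _log_beta(k + alpha, n - k + beta) - _log_beta(alpha, beta))
        prior.append(row)
    return prior


@lru_cache(maxsize=None)
def cached_beta_binomial_log_prior(t_feats: int, t_text: int) -> Tuple[Tuple[float, ...], ...]:
    # Stand-in for cache_prior=True: each (t_feats, t_text) pair is computed once.
    # Tuples are returned so callers cannot mutate the cached value.
    return tuple(tuple(row) for row in beta_binomial_log_prior(t_feats, t_text))


# Usage: the most likely token moves from the first to the last as frames advance.
prior = beta_binomial_log_prior(t_feats=5, t_text=3)
print([max(range(3), key=row.__getitem__) for row in prior])
```

In the log domain, such a prior is combined with the attention log-probabilities by addition. Caching pays off because the prior depends only on the two lengths, which repeat across batches.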
 
forward(text, feats, text_lengths, feats_lengths, x_masks=None)
Calculate the log probability of the alignment (attention) matrix used for the alignment loss.
- Parameters:
  - text (Tensor) – Batched text embedding (B, T_text, adim).
  - feats (Tensor) – Batched acoustic features (B, T_feats, odim).
  - text_lengths (Tensor) – Text length tensor (B,).
  - feats_lengths (Tensor) – Feature length tensor (B,).
  - x_masks (Tensor) – Optional mask tensor (B, T_text) marking padded text positions.
 
- Returns: Log probability of the attention matrix (B, T_feats, T_text).
- Return type: Tensor
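The returned log attention matrix is what the alignment loss consumes. In the paper, that loss is a forward-sum objective: the negative log of the total probability of all monotonic alignments, computed with a CTC-style forward algorithm. The sketch below runs that forward recursion for a single utterance in pure Python. It omits the CTC blank token, the name `forward_sum_loss` is hypothetical rather than an ESPnet class, and it assumes the matrix has already been cut to the utterance's true lengths.

```python
import math
from typing import Sequence


def _logaddexp(a: float, b: float) -> float:
    # log(exp(a) + exp(b)), safe when either argument is -inf.
    if a == -math.inf:
        return b
    if b == -math.inf:
        return a
    m = max(a, b)
    return m + math.log(math.exp(a - m) + math.exp(b - m))


def forward_sum_loss(log_p_attn: Sequence[Sequence[float]]) -> float:
    """Hypothetical sketch: -log sum over monotonic alignments.

    log_p_attn has shape (T_feats, T_text) for one utterance. A path starts
    at token 0, ends at the last token, and at each new frame either stays
    on the same token or advances by exactly one.
    """
    t_text = len(log_p_attn[0])
    alpha = [-math.inf] * t_text
    alpha[0] = log_p_attn[0][0]  # the first frame must align to token 0
    for row in log_p_attn[1:]:
        new_alpha = [-math.inf] * t_text
        for j in range(t_text):
            # Arrive at token j by staying on j or advancing from j - 1.
            prev = alpha[j] if j == 0 else _logaddexp(alpha[j], alpha[j - 1])
            new_alpha[j] = prev + row[j]
        alpha = new_alpha
    return -alpha[-1]  # the last frame must align to the last token


# Usage: uniform probabilities over 2 tokens for 3 frames.
# Two monotonic paths, each with probability 0.5**3, so the loss is log(4).
uniform = [[math.log(0.5)] * 2 for _ in range(3)]
print(forward_sum_loss(uniform))
```

For a batch, the loss would be computed per utterance on the slice [:T_feats, :T_text] given by feats_lengths and text_lengths. When T_feats < T_text, no monotonic path can visit every token, so the loss is infinite.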
