# Source code for espnet.nets.pytorch_backend.fastspeech.duration_calculator

#!/usr/bin/env python3
# -*- coding: utf-8 -*-

# Copyright 2019 Tomoki Hayashi
#  Apache 2.0  (http://www.apache.org/licenses/LICENSE-2.0)

"""Duration calculator related modules."""

import torch

from espnet.nets.pytorch_backend.e2e_tts_tacotron2 import Tacotron2
from espnet.nets.pytorch_backend.e2e_tts_transformer import Transformer
from espnet.nets.pytorch_backend.nets_utils import pad_list

[docs]class DurationCalculator(torch.nn.Module): """Duration calculator module for FastSpeech. Todo: * Fix the duplicated calculation of diagonal head decision """ def __init__(self, teacher_model): """Initialize duration calculator module. Args: teacher_model (e2e_tts_transformer.Transformer): Pretrained auto-regressive Transformer. """ super(DurationCalculator, self).__init__() if isinstance(teacher_model, Transformer): self.register_buffer("diag_head_idx", torch.tensor(-1)) elif isinstance(teacher_model, Tacotron2): pass else: raise ValueError( "teacher model should be the instance of " "e2e_tts_transformer.Transformer or e2e_tts_tacotron2.Tacotron2." ) self.teacher_model = teacher_model
[docs] def forward(self, xs, ilens, ys, olens, spembs=None): """Calculate forward propagation. Args: xs (Tensor): Batch of the padded sequences of character ids (B, Tmax). ilens (Tensor): Batch of lengths of each input sequence (B,). ys (Tensor): Batch of the padded sequence of target features (B, Lmax, odim). olens (Tensor): Batch of lengths of each output sequence (B,). spembs (Tensor, optional): Batch of speaker embedding vectors (B, spk_embed_dim). Returns: Tensor: Batch of durations (B, Tmax). """ if isinstance(self.teacher_model, Transformer): att_ws = self._calculate_encoder_decoder_attentions( xs, ilens, ys, olens, spembs=spembs ) # TODO(kan-bayashi): fix this issue # this does not work in multi-gpu case. registered buffer is not saved. if int(self.diag_head_idx) == -1: self._init_diagonal_head(att_ws) att_ws = att_ws[:, self.diag_head_idx] else: # NOTE(kan-bayashi): Here we assume that the teacher is tacotron 2 att_ws = self.teacher_model.calculate_all_attentions( xs, ilens, ys, spembs=spembs, keep_tensor=True ) durations = [ self._calculate_duration(att_w, ilen, olen) for att_w, ilen, olen in zip(att_ws, ilens, olens) ] return pad_list(durations, 0)
@staticmethod def _calculate_duration(att_w, ilen, olen): return torch.stack( [att_w[:olen, :ilen].argmax(-1).eq(i).sum() for i in range(ilen)] ) def _init_diagonal_head(self, att_ws): diagonal_scores = att_ws.max(dim=-1)[0].mean(dim=-1).mean(dim=0) # (H * L,) self.register_buffer("diag_head_idx", diagonal_scores.argmax()) def _calculate_encoder_decoder_attentions(self, xs, ilens, ys, olens, spembs=None): att_dict = self.teacher_model.calculate_all_attentions( xs, ilens, ys, olens, spembs=spembs, skip_output=True, keep_tensor=True ) return [att_dict[k] for k in att_dict.keys() if "src_attn" in k], dim=1 ) # (B, H*L, Lmax, Tmax)