Source code for espnet2.tts2.feats_extract.identity

from typing import Any, Dict, Optional, Tuple, Union

import torch
from typeguard import typechecked

from espnet2.tts2.feats_extract.abs_feats_extract import AbsFeatsExtractDiscrete

class IdentityFeatureExtract(AbsFeatsExtractDiscrete):
    """Pass an already-discrete input sequence through unchanged.

    The only transformation applied to the input is a cast to ``torch.long``.
    """

    @typechecked
    def __init__(self):
        super().__init__()

    def forward(
        self, input: torch.Tensor, input_lengths: torch.Tensor
    ) -> Tuple[Any, Dict]:
        """Validate the discrete input and return it cast to ``torch.long``.

        Args:
            input: Integer-typed tensor with exactly two dimensions.
            input_lengths: Tensor whose first dimension has the same size as
                the first dimension of ``input``.

        Returns:
            The pair ``(input.long(), input_lengths)``; ``input_lengths`` is
            returned untouched.

        Raises:
            AssertionError: If ``input`` is complex, floating point or
                boolean, is not 2-dimensional, or its first dimension does
                not match that of ``input_lengths``.
        """
        # NOTE(review): the return annotation says ``Dict`` for the second
        # element, but a tensor is returned. The annotation is kept as-is
        # because it presumably mirrors the abstract base class -- confirm.

        # The checks are explicit raises rather than ``assert`` statements so
        # that they are still enforced when Python runs with ``-O``. The
        # exception type and messages match the previous assert-based checks.

        # torch doesn't have an .is_int() function, so every non-integer
        # dtype family is ruled out instead.
        if (
            input.is_complex()
            or input.is_floating_point()
            or input.dtype == torch.bool
        ):
            raise AssertionError("Invalid data type.")
        if input.dim() != 2:
            raise AssertionError("Input should have 2 dimensions.")
        if input.size(0) != input_lengths.size(0):
            raise AssertionError("Invalid lengths.")

        return input.long(), input_lengths