espnet2.gan_codec.shared.discriminator.stft_discriminator.ComplexSTFTDiscriminator
class espnet2.gan_codec.shared.discriminator.stft_discriminator.ComplexSTFTDiscriminator(*, in_channels=1, channels=32, strides=[[1, 2], [2, 2], [1, 2], [2, 2], [1, 2], [2, 2]], chan_mults=[1, 2, 4, 4, 8, 8], n_fft=1024, hop_length=256, win_length=1024, stft_normalized=False, logits_abs=True)
Bases: Module
Complex STFT discriminator used in SoundStream.
Initialize Complex STFT Discriminator used in SoundStream.
Adapted from https://github.com/alibaba-damo-academy/FunCodec.git
- Parameters:
- in_channels (int) – Number of input channels.
- channels (int) – Base number of convolution channels, scaled per layer by chan_mults.
- strides (List[List[int, int]]) – Strides of the Conv2d layers, one pair per layer.
- chan_mults (List[int]) – Channel multipliers, one per layer.
- n_fft (int) – FFT size of the STFT.
- hop_length (int) – Hop length of the STFT.
- win_length (int) – Window length of the STFT.
- stft_normalized (bool) – Whether to normalize the STFT output.
- logits_abs (bool) – Whether to take the absolute value of the complex output logits.
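To make the interplay of `channels`, `chan_mults`, and `strides` concrete, here is a small pure-Python sketch that derives a layer plan from the default arguments. It assumes, as the name `chan_mults` suggests, that each layer's channel count is `channels * mult` and that each stride pair downsamples its axis by that factor. `layer_plan` is a hypothetical helper, not part of ESPnet, and it ignores kernel sizes and padding.

```python
from math import prod


def layer_plan(channels=32,
               chan_mults=(1, 2, 4, 4, 8, 8),
               strides=((1, 2), (2, 2), (1, 2), (2, 2), (1, 2), (2, 2))):
    """Hypothetical helper: per-layer channels and total stride per axis.

    Assumes layer i has channels * chan_mults[i] channels and that the
    stride pairs multiply together along each axis (kernel/padding ignored).
    """
    # Our own sanity check: one multiplier per strided layer.
    if len(chan_mults) != len(strides):
        raise ValueError("chan_mults and strides must have one entry per layer")
    layer_channels = [channels * mult for mult in chan_mults]
    # Cumulative downsampling along the first and second stride axes.
    total_stride = (prod(s[0] for s in strides), prod(s[1] for s in strides))
    return layer_channels, total_stride
```

With the defaults, `layer_plan()` returns `([32, 64, 128, 128, 256, 256], (8, 64))`: under these assumptions the first stride axis is reduced by 8 overall and the second by 64.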
forward(x)
Calculate forward propagation.
- Parameters: x (Tensor) – Input signal (B, 1, T).
- Returns: Discriminator outputs as a nested list.
- Return type: List[List[Tensor]]
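The discriminator operates on an STFT of the input signal, computed with the configured `n_fft` and `hop_length`. The sketch below estimates the spectrogram size for a given number of samples. It assumes a one-sided STFT with `torch.stft`'s default `center=True` padding (`n_fft // 2` samples on each side). Whether ESPnet passes exactly these settings is an assumption, and `stft_shape` is a hypothetical helper.

```python
def stft_shape(num_samples, n_fft=1024, hop_length=256):
    """Hypothetical helper: (freq_bins, frames) of a centered one-sided STFT.

    Assumes torch.stft-style center=True padding of n_fft // 2 on both sides.
    """
    # One-sided spectrum of a real signal keeps n_fft // 2 + 1 bins.
    freq_bins = n_fft // 2 + 1
    # Length after centered padding, then count full windows at hop_length.
    padded = num_samples + 2 * (n_fft // 2)
    frames = 1 + (padded - n_fft) // hop_length
    return freq_bins, frames
```

For one second of 16 kHz audio with the defaults, `stft_shape(16000)` gives `(513, 63)`.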
Reference:
- Paper: https://arxiv.org/pdf/2107.03312.pdf
- Implementation: https://github.com/alibaba-damo-academy/FunCodec.git