# Source code for espnet2.train.abs_gan_espnet_model

# Copyright 2021 Tomoki Hayashi
#  Apache 2.0  (http://www.apache.org/licenses/LICENSE-2.0)

"""ESPnetModel abstract class for GAN-based training."""

from abc import ABC, abstractmethod
from typing import Dict, Union

import torch

from espnet2.train.abs_espnet_model import AbsESPnetModel

class AbsGANESPnetModel(AbsESPnetModel, torch.nn.Module, ABC):
    """The common abstract class among each GAN-based task.

    "ESPnetModel" is referred to a class which inherits torch.nn.Module,
    and makes the dnn-models "forward" as its member field,
    a.k.a delegate pattern. And "forward" must accept the argument
    "forward_generator" and return the dict of "loss", "stats", "weight",
    and "optim_idx". "optim_idx" for generator must be 0 and that for
    discriminator must be 1.

    Example:
        >>> from espnet2.tasks.abs_task import AbsTask
        >>> class YourESPnetModel(AbsGANESPnetModel):
        ...     def forward(self, input, input_lengths, forward_generator=True):
        ...         ...
        ...         if forward_generator:
        ...             # return loss for the generator
        ...             # optim idx 0 indicates generator optimizer
        ...             return dict(loss=loss, stats=stats, weight=weight, optim_idx=0)
        ...         else:
        ...             # return loss for the discriminator
        ...             # optim idx 1 indicates discriminator optimizer
        ...             return dict(loss=loss, stats=stats, weight=weight, optim_idx=1)
        >>> class YourTask(AbsTask):
        ...     @classmethod
        ...     def build_model(cls, args: argparse.Namespace) -> YourESPnetModel:
        ...         ...

    """

    @abstractmethod
    def forward(
        self,
        forward_generator: bool = True,
        **batch: torch.Tensor,
    ) -> Dict[str, Union[torch.Tensor, Dict[str, torch.Tensor], int]]:
        """Return the generator loss or the discriminator loss.

        This method must have an argument "forward_generator" to switch the
        generator loss calculation and the discriminator loss calculation.
        If forward_generator is true, return the generator loss with optim_idx 0.
        If forward_generator is false, return the discriminator loss with
        optim_idx 1.

        Args:
            forward_generator (bool): Whether to return the generator loss or
                the discriminator loss. This must have the default value.
            **batch: Batch of input tensors (and their length tensors).

        Returns:
            Dict[str, Any]:
                * loss (Tensor): Loss scalar tensor.
                * stats (Dict[str, float]): Statistics to be monitored.
                * weight (Tensor): Weight tensor to summarize losses.
                * optim_idx (int): Optimizer index (0 for G and 1 for D).

        """
        raise NotImplementedError

    @abstractmethod
    def collect_feats(self, **batch: torch.Tensor) -> Dict[str, torch.Tensor]:
        """Return the features extracted from the batch.

        NOTE(review): contract inherited from AbsESPnetModel.collect_feats
        (used for statistics collection) — confirm against the base class.
        """
        raise NotImplementedError