espnet2.train.abs_gan_espnet_model.AbsGANESPnetModel
class espnet2.train.abs_gan_espnet_model.AbsGANESPnetModel(*args, **kwargs)
Bases: AbsESPnetModel, Module, ABC
The common abstract class among GAN-based tasks.
“ESPnetModel” refers to a class that inherits torch.nn.Module and holds the DNN models as member fields to which it delegates “forward” (the delegate pattern). Its “forward” must accept the argument “forward_generator” and return a dict of “loss”, “stats”, “weight”, and “optim_idx”. “optim_idx” must be 0 for the generator and 1 for the discriminator.
Example
>>> from espnet2.tasks.abs_task import AbsTask
>>> class YourESPnetModel(AbsGANESPnetModel):
...     def forward(self, input, input_lengths, forward_generator=True):
...         ...
...         if forward_generator:
...             # return loss for the generator
...             # optim idx 0 indicates generator optimizer
...             return dict(loss=loss, stats=stats, weight=weight, optim_idx=0)
...         else:
...             # return loss for the discriminator
...             # optim idx 1 indicates discriminator optimizer
...             return dict(loss=loss, stats=stats, weight=weight, optim_idx=1)
>>> class YourTask(AbsTask):
...     @classmethod
...     def build_model(cls, args: argparse.Namespace) -> YourESPnetModel:
...         return YourESPnetModel(...)
Initializes internal Module state, shared by both nn.Module and ScriptModule.
abstract collect_feats(**batch: Tensor) → Dict[str, Tensor]
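For illustration, a typical collect_feats implementation simply gathers the input features and their lengths so that the trainer can accumulate statistics (e.g., for feature normalization). A minimal sketch, assuming the subclass receives speech and speech_lengths inputs (the argument names are hypothetical, not required by the interface):
>>> import torch
>>> from typing import Dict
>>> class YourESPnetModel(AbsGANESPnetModel):
...     def collect_feats(
...         self, speech: torch.Tensor, speech_lengths: torch.Tensor, **kwargs
...     ) -> Dict[str, torch.Tensor]:
...         # Return the raw features and their lengths; the trainer uses
...         # them to compute dataset-level statistics.
...         return dict(feats=speech, feats_lengths=speech_lengths)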
abstract forward(forward_generator: bool = True, **batch: Tensor) → Dict[str, Tensor | Dict[str, Tensor] | int]
Return the generator loss or the discriminator loss.
This method must have the argument “forward_generator” to switch between the generator loss calculation and the discriminator loss calculation. If forward_generator is true, return the generator loss with optim_idx 0. If forward_generator is false, return the discriminator loss with optim_idx 1.
- Parameters: forward_generator (bool) – Whether to return the generator loss or the discriminator loss. This argument must have a default value.
- Returns:
- loss (Tensor): Loss scalar tensor.
- stats (Dict[str, float]): Statistics to be monitored.
- weight (Tensor): Weight tensor to summarize losses.
- optim_idx (int): Optimizer index (0 for G and 1 for D).
- Return type: Dict[str, Any]
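For context, a simplified sketch of how a trainer could consume this dict, alternating the two optimizers according to optim_idx. The names model, optimizers, and batch are illustrative, and a real GAN trainer handles additional concerns such as gradient accumulation and mixed precision:
>>> def train_one_step(model, optimizers, batch):
...     # Run the generator step (optim_idx 0), then the discriminator
...     # step (optim_idx 1), updating only the matching optimizer each time.
...     for forward_generator in (True, False):
...         retval = model(forward_generator=forward_generator, **batch)
...         optimizer = optimizers[retval["optim_idx"]]
...         optimizer.zero_grad()
...         retval["loss"].backward()
...         optimizer.step()
...     return retval["stats"]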