from abc import ABC, abstractmethod
from typing import Dict, List, Tuple
[docs]class AbsLossWrapper(torch.nn.Module, ABC):
"""Base class for all Enhancement loss wrapper modules."""
# The weight for the current loss in the multi-task learning.
# The overall training target will be combined as:
# loss = weight_1 * loss_1 + ... + weight_N * loss_N
weight = 1.0
) -> Tuple[torch.Tensor, Dict, Dict]: