# Source code for espnet2.uasr.loss.abs_loss

from abc import ABC, abstractmethod

import torch

# Machine epsilon of torch's default floating-point dtype, captured once at
# import time (a later torch.set_default_dtype call does not update it).
EPS = torch.finfo(
    torch.get_default_dtype()
).eps


class AbsUASRLoss(torch.nn.Module, ABC):
    """Base class for all UASR (unsupervised ASR) loss modules.

    Subclasses must implement :meth:`forward` and should override
    :attr:`name`.
    """

    # the name will be the key that appears in the reporter
    @property
    def name(self) -> str:
        """Reporter key for this loss; must be overridden by subclasses.

        Raises:
            NotImplementedError: always, in this base class.
        """
        # Previously this *returned* the NotImplementedError class, which
        # silently handed callers a class object instead of a str.
        raise NotImplementedError

    @abstractmethod
    def forward(
        self,
    ) -> torch.Tensor:
        """Compute the loss.

        Returns:
            torch.Tensor: the return tensor should be shape of (batch).

        Raises:
            NotImplementedError: always, in this base class.
        """
        raise NotImplementedError