espnet2.enh.loss.criterions.time_domain.SDRLoss
class espnet2.enh.loss.criterions.time_domain.SDRLoss(filter_length=512, use_cg_iter=None, clamp_db=None, zero_mean=True, load_diag=None, name=None, only_for_test=False, is_noise_loss=False, is_dereverb_loss=False)
Bases: TimeDomainLoss
Signal-to-distortion ratio (SDR) loss.
filter_length: int : The length of the distortion filter allowed (default: 512)
use_cg_iter: int : If provided, the number of iterations of an iterative solver (conjugate gradient) used to find the distortion filter coefficients instead of direct Gaussian elimination. This can speed up the computation when the filters are long. A value of 10 has been shown to give good accuracy in most cases and is sufficient when this loss is used to train neural separation networks.
clamp_db: float : If provided, the output value is clamped to the range [-clamp_db, clamp_db] dB.
zero_mean: bool : When True, each signal has its mean subtracted before the SDR is computed.
load_diag: float : If provided, this small value is added to the diagonal coefficients of the system matrices when solving for the filter coefficients. This can help stabilize the metric when some of the reference signals may sometimes be all zeros.
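The parameters above describe a BSS Eval-style SDR in which the reference may pass through a distortion filter of up to filter_length taps before it is compared with the estimate. The NumPy sketch below illustrates that idea for a single pair of 1-D signals; it is not ESPnet's code. The helper names (sdr_db, _autocorr, _crosscorr, _solve_cg), the eps guard, and the zero-padded (Toeplitz) form of the normal equations are assumptions made for illustration. use_cg_iter is mimicked with a plain conjugate-gradient loop and load_diag with diagonal loading; the function returns the SDR itself, which the loss negates.

```python
import numpy as np


def _autocorr(x: np.ndarray, n_lags: int) -> np.ndarray:
    # acf[k] = sum_n x[n] * x[n + k], treating x as zero outside its support.
    n = len(x)
    return np.array([np.dot(x[: n - k], x[k:]) if k < n else 0.0 for k in range(n_lags)])


def _crosscorr(ref: np.ndarray, est: np.ndarray, n_lags: int) -> np.ndarray:
    # b[k] = sum_n ref[n - k] * est[n]: correlation of est with ref delayed by k samples.
    n = len(ref)
    return np.array([np.dot(ref[: n - k], est[k:]) if k < n else 0.0 for k in range(n_lags)])


def _solve_cg(a: np.ndarray, b: np.ndarray, n_iter: int, tol: float = 1e-10) -> np.ndarray:
    # Plain conjugate gradient for a symmetric positive-definite system, capped at n_iter steps.
    x = np.zeros_like(b)
    r = b.copy()  # residual b - a @ x with x = 0
    p = r.copy()
    rs_old = r @ r
    b_norm = np.sqrt(rs_old)
    for _ in range(n_iter):
        if np.sqrt(rs_old) <= tol * b_norm:
            break
        ap = a @ p
        alpha = rs_old / (p @ ap)
        x = x + alpha * p
        r = r - alpha * ap
        rs_new = r @ r
        p = r + (rs_new / rs_old) * p
        rs_old = rs_new
    return x


def sdr_db(ref, est, filter_length=512, use_cg_iter=None, clamp_db=None,
           zero_mean=True, load_diag=None, eps=1e-8):
    """SDR in dB of 1-D est against 1-D ref, allowing a filter_length-tap distortion filter."""
    ref = np.asarray(ref, dtype=float)
    est = np.asarray(est, dtype=float)
    if zero_mean:
        ref = ref - ref.mean()
        est = est - est.mean()

    # Normal equations gram @ h = b for the filter h that best maps ref onto est.
    lags = np.arange(filter_length)
    gram = _autocorr(ref, filter_length)[np.abs(lags[:, None] - lags[None, :])]  # Toeplitz
    b = _crosscorr(ref, est, filter_length)
    system = gram
    if load_diag is not None:
        system = gram + load_diag * np.eye(filter_length)  # diagonal loading
    if use_cg_iter is None:
        h = np.linalg.solve(system, b)  # direct solve
    else:
        h = _solve_cg(system, b, use_cg_iter)  # iterative solve

    # Energies of the filtered reference (target) and of the residual (error) for this h.
    target_energy = h @ gram @ h
    error_energy = max(est @ est - 2.0 * (h @ b) + target_energy, 0.0)
    sdr = 10.0 * np.log10((target_energy + eps) / (error_energy + eps))
    if clamp_db is not None:
        sdr = np.clip(sdr, -clamp_db, clamp_db)
    return float(sdr)
```

With this formulation, a short FIR colouring of ref (for example np.convolve(ref, [1.0, 0.5, -0.3])[:len(ref)]) scores far higher with filter_length=8 than with filter_length=1, which is the point of allowing a distortion filter: mild filtering of the reference is not counted as distortion.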
forward(ref: Tensor, est: Tensor) → Tensor
Compute the SDR loss of the estimate against the reference.
- Parameters:
- ref – Tensor of shape (…, n_samples), the reference signal
- est – Tensor of shape (…, n_samples), the estimated signal
- Returns: the SDR loss, i.e. the negative SDR in dB, with shape (…,)
- Return type: Tensor
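forward returns one value per leading index: inputs of shape (…, n_samples) give a loss of shape (…,), equal to the negative SDR so that minimising the loss maximises SDR. A vectorised NumPy sketch of that contract follows, assuming zero_mean=True, a direct solve, and FFT-based correlations. The function name negative_sdr_loss, its small default filter_length, and the eps guard are illustrative choices, not ESPnet's API.

```python
import numpy as np


def negative_sdr_loss(ref: np.ndarray, est: np.ndarray, filter_length: int = 16,
                      eps: float = 1e-8) -> np.ndarray:
    """Negative SDR (dB) per signal: (..., n_samples) inputs -> (...,) output."""
    ref = np.asarray(ref, dtype=float)
    est = np.asarray(est, dtype=float)
    if ref.shape != est.shape:
        raise ValueError("ref and est must both have shape (..., n_samples)")
    # zero_mean=True: remove each signal's own mean along the time axis.
    ref = ref - ref.mean(axis=-1, keepdims=True)
    est = est - est.mean(axis=-1, keepdims=True)

    n_samples = ref.shape[-1]
    nfft = n_samples + filter_length - 1  # long enough to avoid circular wrap-around
    ref_f = np.fft.rfft(ref, n=nfft)
    est_f = np.fft.rfft(est, n=nfft)
    # acf[..., k] = sum_n ref[n] * ref[n + k]; xcorr[..., k] = sum_n ref[n] * est[n + k]
    acf = np.fft.irfft(np.abs(ref_f) ** 2, n=nfft)[..., :filter_length]
    xcorr = np.fft.irfft(np.conj(ref_f) * est_f, n=nfft)[..., :filter_length]

    # Batched Toeplitz Gram matrices, shape (..., filter_length, filter_length).
    lags = np.arange(filter_length)
    gram = acf[..., np.abs(lags[:, None] - lags[None, :])]
    # Direct solve of the normal equations for every leading index at once.
    h = np.linalg.solve(gram, xcorr[..., None])[..., 0]

    target = np.einsum("...i,...ij,...j->...", h, gram, h)  # energy of the projection
    est_energy = np.einsum("...n,...n->...", est, est)
    error = np.maximum(est_energy - target, 0.0)  # Pythagoras, valid for an exact solve
    sdr = 10.0 * np.log10((target + eps) / (error + eps))
    return -sdr
```

For example, negative_sdr_loss(ref, est) with ref and est of shape (2, 3, 2000) returns a (2, 3) array. Unlike the PyTorch module, this sketch does not propagate gradients, so it only illustrates the shapes and the arithmetic.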