espnet2.cls.lightning_callbacks.MultilabelAUPRCCallback
class espnet2.cls.lightning_callbacks.MultilabelAUPRCCallback
Bases: Callback
Computes and logs Multilabel AUPRC (mAP) at the end of each validation epoch.
To use this callback, you must implement an update_mAP method in the ESPnet model. The method accepts a MultilabelAUPRC object and calls its update method with predictions and targets. For example:

```python
class MyESPnetModel(AbsESPnetModel):
    def update_mAP(self, mAP_function: MultilabelAUPRC):
        ...
        mAP_function.update(predictions, targets)
        ...
```

The model should also have a get_vocab_size() method that returns the number of labels/classes.
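The sketch below shows the model-side contract end to end without ESPnet or PyTorch installed. ToyMultilabelAUPRC and ToyESPnetModel are hypothetical stand-ins, not ESPnet classes. The sketch assumes "mAP" means the macro average of per-label average precision, a common step-wise estimate of AUPRC. The real MultilabelAUPRC may interpolate or handle labels with no positives differently.

```python
from typing import List, Sequence


def average_precision(scores: Sequence[float], labels: Sequence[int]) -> float:
    """Step-wise average precision for a single label (assumed AUPRC estimate)."""
    # Rank examples by descending score (ties keep their input order).
    order = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)
    n_pos = sum(labels)
    if n_pos == 0:
        return 0.0  # convention chosen for this sketch; libraries differ
    hits, total = 0, 0.0
    for rank, i in enumerate(order, start=1):
        if labels[i]:
            hits += 1
            total += hits / rank  # precision at each new recall level
    return total / n_pos


class ToyMultilabelAUPRC:
    """Hypothetical stand-in: accumulates batches, then macro-averages AP."""

    def __init__(self, num_labels: int) -> None:
        self.num_labels = num_labels
        self.preds: List[List[float]] = []
        self.targets: List[List[int]] = []

    def update(self, predictions, targets) -> None:
        # One row per example, one column per label.
        self.preds.extend(predictions)
        self.targets.extend(targets)

    def compute(self) -> float:
        aps = []
        for k in range(self.num_labels):
            col_scores = [row[k] for row in self.preds]
            col_labels = [row[k] for row in self.targets]
            aps.append(average_precision(col_scores, col_labels))
        return sum(aps) / self.num_labels


class ToyESPnetModel:
    """Hypothetical model implementing the two methods the callback needs."""

    def __init__(self, num_labels: int) -> None:
        self.num_labels = num_labels
        # Hypothetical cache of the most recent batch's outputs.
        self.last_predictions: List[List[float]] = []
        self.last_targets: List[List[int]] = []

    def get_vocab_size(self) -> int:
        return self.num_labels

    def update_mAP(self, mAP_function) -> None:
        mAP_function.update(self.last_predictions, self.last_targets)


model = ToyESPnetModel(num_labels=2)
metric = ToyMultilabelAUPRC(model.get_vocab_size())
model.last_predictions = [[0.9, 0.2], [0.1, 0.8], [0.6, 0.4]]
model.last_targets = [[1, 1], [0, 1], [0, 0]]
model.update_mAP(metric)
print(metric.compute())  # about 0.9167 (label APs: 1.0 and 5/6)
```

Keeping the metric object outside the model lets the callback own its lifetime (reset, accumulate, compute), while the model only has to know how to turn a batch into predictions and targets.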
compute_mAP(trainer)
Computes the mAP.
on_train_batch_end(trainer, pl_module, outputs, batch, batch_idx)
Called when the train batch ends.
NOTE
The value outputs["loss"] here is the loss returned from training_step, normalized with respect to accumulate_grad_batches.
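A small worked example of the note above follows. It assumes that "normalized" means the training_step loss divided by accumulate_grad_batches, which is how Lightning's automatic optimization scales losses during gradient accumulation. The helper unnormalized_loss is hypothetical, not part of Lightning or ESPnet.

```python
def unnormalized_loss(normalized_loss: float, accumulate_grad_batches: int) -> float:
    """Undo the assumed division by accumulate_grad_batches."""
    if accumulate_grad_batches < 1:
        raise ValueError("accumulate_grad_batches must be >= 1")
    return normalized_loss * accumulate_grad_batches


# training_step returned 2.0 and gradients are accumulated over 4 batches.
outputs = {"loss": 2.0 / 4}
print(outputs["loss"])  # 0.5
print(unnormalized_loss(outputs["loss"], 4))  # 2.0
```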
on_train_start(trainer, pl_module)
Called when the train begins.
on_validation_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx=0)
Called when the validation batch ends.
on_validation_epoch_end(trainer, pl_module)
Called when the validation epoch ends.
on_validation_epoch_start(trainer, pl_module)
Called when the validation epoch begins.
on_validation_start(trainer, pl_module)
Called when the validation loop begins.
setup_mAP(model)
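The following is a hypothetical sketch of how the validation hooks above could divide the work of "compute and log mAP at the end of each validation epoch": reset at epoch start, accumulate at each batch end, compute at epoch end. It is not ESPnet's implementation. StubMetric, StubModel, SketchMAPCallback and run_validation_epoch are invented for illustration, and run_validation_epoch is a simplified stand-in for Lightning's real hook order. The sketch also passes the model directly as pl_module; the source does not say how the real callback reaches the ESPnet model inside the LightningModule.

```python
from typing import List, Optional


class StubMetric:
    """Hypothetical metric stub that only counts update() calls."""

    def __init__(self, num_labels: int) -> None:
        self.num_labels = num_labels
        self.updates = 0

    def reset(self) -> None:
        self.updates = 0

    def update(self, predictions, targets) -> None:
        self.updates += 1

    def compute(self) -> float:
        # Placeholder value; a real metric would return multilabel AUPRC.
        return float(self.updates)


class StubModel:
    """Satisfies the update_mAP / get_vocab_size contract described above."""

    def get_vocab_size(self) -> int:
        return 3

    def update_mAP(self, mAP_function) -> None:
        mAP_function.update([], [])  # real code passes this batch's outputs


class SketchMAPCallback:
    """Hypothetical division of work across hooks (not ESPnet's code)."""

    def __init__(self) -> None:
        self.metric: Optional[StubMetric] = None
        self.history: List[float] = []

    def setup_mAP(self, model) -> None:
        self.metric = StubMetric(model.get_vocab_size())

    def on_validation_start(self, trainer, pl_module) -> None:
        if self.metric is None:
            self.setup_mAP(pl_module)

    def on_validation_epoch_start(self, trainer, pl_module) -> None:
        self.metric.reset()  # each epoch starts from an empty state

    def on_validation_batch_end(self, trainer, pl_module, outputs, batch,
                                batch_idx, dataloader_idx=0) -> None:
        pl_module.update_mAP(self.metric)  # accumulate this batch

    def on_validation_epoch_end(self, trainer, pl_module) -> None:
        # A real callback would log this value instead of storing it.
        self.history.append(self.compute_mAP(trainer))

    def compute_mAP(self, trainer) -> float:
        return self.metric.compute()


def run_validation_epoch(callback, model, num_batches: int) -> None:
    """Simplified stand-in for the order in which Lightning calls the hooks."""
    trainer = None  # a real Trainer object in Lightning
    callback.on_validation_start(trainer, model)
    callback.on_validation_epoch_start(trainer, model)
    for batch_idx in range(num_batches):
        callback.on_validation_batch_end(trainer, model, None, None, batch_idx)
    callback.on_validation_epoch_end(trainer, model)


cb = SketchMAPCallback()
model = StubModel()
run_validation_epoch(cb, model, num_batches=4)
run_validation_epoch(cb, model, num_batches=2)
print(cb.history)  # [4.0, 2.0]
```

Resetting in on_validation_epoch_start is what keeps the second epoch's value (2.0) independent of the first; without it, the metric would keep accumulating across epochs.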