zhiqing0205
Add core libraries: anomalib, dinov2, open_clip_local
3de7bf6
"""Implementation of AUROC metric based on TorchMetrics."""
# Copyright (C) 2022-2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
import torch
from matplotlib.figure import Figure
from torchmetrics.classification.roc import BinaryROC
from torchmetrics.utilities.compute import auc
from .plotting_utils import plot_figure
class AUROC(BinaryROC):
"""Area under the ROC curve.
Examples:
>>> import torch
>>> from anomalib.metrics import AUROC
...
>>> preds = torch.tensor([0.13, 0.26, 0.08, 0.92, 0.03])
>>> target = torch.tensor([0, 0, 1, 1, 0])
...
>>> auroc = AUROC()
>>> auroc(preds, target)
tensor(0.6667)
It is possible to update the metric state incrementally:
>>> auroc.update(preds[:2], target[:2])
>>> auroc.update(preds[2:], target[2:])
>>> auroc.compute()
tensor(0.6667)
To plot the ROC curve, use the ``generate_figure`` method:
>>> fig, title = auroc.generate_figure()
"""
def compute(self) -> torch.Tensor:
"""First compute ROC curve, then compute area under the curve.
Returns:
Tensor: Value of the AUROC metric
"""
tpr: torch.Tensor
fpr: torch.Tensor
fpr, tpr = self._compute()
return auc(fpr, tpr, reorder=True)
def update(self, preds: torch.Tensor, target: torch.Tensor) -> None:
"""Update state with new values.
Need to flatten new values as ROC expects them in this format for binary classification.
Args:
preds (torch.Tensor): predictions of the model
target (torch.Tensor): ground truth targets
"""
super().update(preds.flatten(), target.flatten())
def _compute(self) -> tuple[torch.Tensor, torch.Tensor]:
"""Compute fpr/tpr value pairs.
Returns:
Tuple containing Tensors for fpr and tpr
"""
tpr: torch.Tensor
fpr: torch.Tensor
fpr, tpr, _thresholds = super().compute()
return (fpr, tpr)
def generate_figure(self) -> tuple[Figure, str]:
"""Generate a figure containing the ROC curve, the baseline and the AUROC.
Returns:
tuple[Figure, str]: Tuple containing both the figure and the figure title to be used for logging
"""
fpr, tpr = self._compute()
auroc = self.compute()
xlim = (0.0, 1.0)
ylim = (0.0, 1.0)
xlabel = "False Positive Rate"
ylabel = "True Positive Rate"
loc = "lower right"
title = "ROC"
fig, axis = plot_figure(fpr, tpr, auroc, xlim, ylim, xlabel, ylabel, loc, title)
axis.plot(
[0, 1],
[0, 1],
color="navy",
lw=2,
linestyle="--",
figure=fig,
)
return fig, title