LogSAD / anomalib /metrics /plotting_utils.py
zhiqing0205
Add core libraries: anomalib, dinov2, open_clip_local
3de7bf6
"""Helper functions to generate ROC-style plots of various metrics."""
# Copyright (C) 2022-2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
import torch
from matplotlib import pyplot as plt
from matplotlib.axis import Axis
from matplotlib.figure import Figure
def plot_figure(
    x_vals: torch.Tensor,
    y_vals: torch.Tensor,
    auc: torch.Tensor,
    xlim: tuple[float, float],
    ylim: tuple[float, float],
    xlabel: str,
    ylabel: str,
    loc: str,
    title: str,
    sample_points: int = 1000,
) -> tuple[Figure, Axis]:
    """Generate a simple, ROC-style plot, where x_vals is plotted against y_vals.

    The curve is subsampled when x_vals holds more than ``sample_points`` entries,
    because matplotlib draws every single point, which takes very long for dense
    curves (e.g. those computed from high-resolution segmentations).

    Subsampling takes evenly strided indices starting at the first point and always
    keeps the last point, so at most ``sample_points + 1`` points are drawn and the
    curve still spans its full range.

    Args:
        x_vals (torch.Tensor): x values to plot
        y_vals (torch.Tensor): y values to plot
        auc (torch.Tensor): normalized area under the curve spanned by x_vals, y_vals
        xlim (tuple[float, float]): displayed range for x-axis
        ylim (tuple[float, float]): displayed range for y-axis
        xlabel (str): label of x axis
        ylabel (str): label of y axis
        loc (str): string-based legend location, for details see
            https://matplotlib.org/stable/api/_as_gen/matplotlib.axes.Axes.legend.html
        title (str): title of the plot
        sample_points (int): number of sampling points to subsample x_vals/y_vals with.
            Must be at least 1.
            Defaults to ``1000``.

    Raises:
        ValueError: If ``sample_points`` is smaller than 1.

    Returns:
        tuple[Figure, Axis]: Figure and the contained Axis
    """
    # NOTE(review): ``plt.subplots()`` returns a ``matplotlib.axes.Axes``, not a
    # ``matplotlib.axis.Axis``; the annotation is kept as-is for interface stability.

    # Validate before creating the figure so a bad argument does not leave an
    # orphaned figure registered in pyplot's global state.
    if sample_points < 1:
        msg = f"sample_points must be a positive integer, got {sample_points}."
        raise ValueError(msg)

    x_vals = x_vals.detach().cpu()
    y_vals = y_vals.detach().cpu()

    num_points = x_vals.size(0)
    if num_points > sample_points:
        # Ceiling division: the smallest stride that yields no more than
        # ``sample_points`` strided indices. Floor division would give a stride
        # of 1 (i.e. no subsampling) for any length below 2 * sample_points.
        stride = -(-num_points // sample_points)
        idx = torch.arange(0, num_points, stride, device=x_vals.device)

        # Always include the last point, but only add it when the stride did not
        # already land on it, to avoid a duplicated sample.
        last_idx = num_points - 1
        if last_idx % stride != 0:
            idx = torch.cat((idx, torch.tensor([last_idx], device=x_vals.device)))

        x_vals = torch.index_select(x_vals, 0, idx)
        y_vals = torch.index_select(y_vals, 0, idx)

    fig, axis = plt.subplots()
    axis.plot(
        x_vals,
        y_vals,
        color="darkorange",
        figure=fig,
        lw=2,
        label=f"AUC: {auc.detach().cpu():0.2f}",
    )

    axis.set_xlim(xlim)
    axis.set_ylim(ylim)
    axis.set_xlabel(xlabel)
    axis.set_ylabel(ylabel)
    axis.legend(loc=loc)
    axis.set_title(title)
    return fig, axis