|
"""Implementation of SPRO metric based on TorchMetrics.""" |
|
|
|
|
|
|
|
|
|
import json |
|
import logging |
|
from pathlib import Path |
|
from typing import Any |
|
|
|
import torch |
|
from torchmetrics import Metric |
|
|
|
from anomalib.data.utils import validate_path |
|
|
|
# Module-level logger, named after this module.
logger: logging.Logger = logging.getLogger(name=__name__)
|
|
|
|
|
class SPRO(Metric):
    """Saturated Per-Region Overlap (SPRO) Score.

    This metric computes the macro average of the saturated per-region overlap between the
    predicted anomaly masks and the ground truth masks.

    Args:
        threshold (float): Threshold used to binarize the predictions.
            Defaults to ``0.5``.
        saturation_config (str | Path | None): Path to the saturation configuration file.
            Defaults to ``None``, in which case every region saturates at its own defect area.
            The score is then equivalent to the PRO metric, except that regions are separated
            by mask files.
        kwargs: Additional arguments to the TorchMetrics base class.

    Example:
        Import the metric from the package:

        >>> import torch
        >>> from anomalib.metrics import SPRO

        Create random ``preds`` and ``labels`` tensors:

        >>> labels = torch.randint(low=0, high=2, size=(2, 10, 5), dtype=torch.float32)
        >>> labels = [labels]
        >>> preds = torch.rand_like(labels[0][:1])

        Compute the SPRO score for labels and preds:

        >>> spro = SPRO(threshold=0.5)
        >>> spro.update(preds, labels)
        >>> spro.compute()
        tensor(0.6333)

        .. note::

            Note that the example above shows random predictions and labels.
            Therefore, the SPRO score above may not be reproducible.
    """

    def __init__(self, threshold: float = 0.5, saturation_config: str | Path | None = None, **kwargs) -> None:
        super().__init__(**kwargs)
        self.threshold = threshold
        # Maps pixel value -> saturation settings; None when no (or no readable) config file is given.
        self.saturation_config = load_saturation_config(saturation_config) if saturation_config is not None else None
        if self.saturation_config is None:
            logger.warning(
                "The saturation_config attribute is empty, the threshold is set to the defect area. "
                "This is equivalent to PRO metric but with the 'region' are separated by mask files",
            )
        # Running sum of per-region scores and number of scored regions; both summed across processes.
        self.add_state("score", default=torch.tensor(0.0), dist_reduce_fx="sum")
        self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum")

    def update(self, predictions: torch.Tensor, masks: list[torch.Tensor]) -> None:
        """Accumulate the SPRO score of the current batch into the metric state.

        Args:
            predictions (torch.Tensor): Predicted anomaly masks.
            masks (list[torch.Tensor]): Ground truth anomaly masks with original height and width. Each element
                of the list is a tensor holding the masks of the corresponding image, one mask per entry.

        Example:
            To update the metric state for the current batch, use the ``update`` method:

            >>> spro.update(preds, labels)
        """
        score, total = spro_score(
            predictions=predictions,
            targets=masks,
            threshold=self.threshold,
            saturation_config=self.saturation_config,
        )
        self.score += score
        self.total += total

    def compute(self) -> torch.Tensor:
        """Compute the macro average of the SPRO score across all masks in all batches.

        Returns:
            torch.Tensor: 0-d tensor with the mean saturated overlap over all scored regions, or
            ``1.0`` when no anomalous region has been seen.

        Example:
            To compute the metric based on the state accumulated from multiple batches, use the ``compute`` method:

            >>> spro.compute()
            tensor(0.5433)
        """
        if self.total == 0:
            # No anomalous region was scored (e.g. only normal images); report a perfect score.
            # Returned as a 0-d tensor on the state's device to match the shape of the normal path.
            return torch.tensor(1.0, device=self.score.device)
        return self.score / self.total
|
|
|
|
|
def spro_score( |
|
predictions: torch.Tensor, |
|
targets: list[torch.Tensor], |
|
threshold: float = 0.5, |
|
saturation_config: dict | None = None, |
|
) -> torch.Tensor: |
|
"""Calculate the SPRO score for a batch of predictions. |
|
|
|
Args: |
|
predictions (torch.Tensor): Predicted anomaly masks. |
|
targets: (list[torch.Tensor]): Ground truth anomaly masks with original height and width. Each element in the |
|
list is a tensor list of masks for the corresponding image. |
|
threshold (float): When predictions are passed as float, the threshold is used to binarize the predictions. |
|
Defaults: ``0.5``. |
|
saturation_config (dict): Saturations configuration for each label (pixel value) as the keys. |
|
Defaults: ``None`` (which the score is equivalent to PRO metric, but with the 'region' are |
|
separated by mask files. |
|
|
|
Returns: |
|
torch.Tensor: Scalar value representing the average SPRO score for the input batch. |
|
""" |
|
|
|
if len(predictions.shape) == 2: |
|
predictions = predictions.unsqueeze(0) |
|
|
|
|
|
predictions = torch.nn.functional.interpolate(predictions.unsqueeze(1), targets[0].shape[-2:]) |
|
|
|
|
|
if predictions.dtype == torch.float: |
|
predictions = predictions > threshold |
|
|
|
score = torch.tensor(0.0) |
|
total = 0 |
|
|
|
for i, target in enumerate(targets): |
|
|
|
for mask in target: |
|
label = torch.max(mask) |
|
if label == 0: |
|
continue |
|
|
|
target_per_label = mask == label |
|
true_pos = torch.sum(predictions[i] & target_per_label) |
|
|
|
|
|
defect_area = torch.sum(target_per_label) |
|
|
|
if saturation_config is not None: |
|
|
|
saturation_per_label = saturation_config[label.int().item()] |
|
saturation_threshold = saturation_per_label["saturation_threshold"] |
|
|
|
if saturation_per_label["relative_saturation"]: |
|
saturation_threshold *= defect_area |
|
|
|
|
|
if saturation_threshold > defect_area: |
|
warning_msg = ( |
|
f"Saturation threshold for label {label.int().item()} is larger than defect area. " |
|
"Setting it to defect area." |
|
) |
|
logger.warning(warning_msg) |
|
saturation_threshold = defect_area |
|
else: |
|
|
|
saturation_threshold = defect_area |
|
|
|
|
|
score += torch.minimum(true_pos / saturation_threshold, torch.tensor(1.0)) |
|
total += 1 |
|
return score, total |
|
|
|
|
|
def load_saturation_config(config_path: str | Path) -> dict[int, Any] | None:
    """Load saturation configurations from a JSON file.

    Args:
        config_path (str | Path): Path to the saturation configuration file.

    Returns:
        dict[int, Any] | None: A dictionary mapping each entry's ``pixel_value`` to that entry.
        ``None`` if the config file is not found.

    Example JSON format in the config file of MVTec LOCO dataset:
        [
            {
                "defect_name": "1_additional_pushpin",
                "pixel_value": 255,
                "saturation_threshold": 6300,
                "relative_saturation": false
            },
            {
                "defect_name": "2_additional_pushpins",
                "pixel_value": 254,
                "saturation_threshold": 12600,
                "relative_saturation": false
            },
            ...
        ]
    """
    try:
        config_path = validate_path(config_path)
        # JSON is UTF-8 by specification; don't depend on the platform's locale encoding.
        with Path(config_path).open(encoding="utf-8") as file:
            configs = json.load(file)
    except FileNotFoundError:
        logger.warning("The saturation config file %s does not exist. Returning None.", config_path)
        return None

    return {conf["pixel_value"]: conf for conf in configs}
|
|