"""Implementation of SPRO metric based on TorchMetrics."""

# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

import json
import logging
from pathlib import Path
from typing import Any

import torch
from torchmetrics import Metric

from anomalib.data.utils import validate_path

logger = logging.getLogger(__name__)


class SPRO(Metric):
    """Saturated Per-Region Overlap (SPRO) Score.

    This metric computes the macro average of the saturated per-region overlap between the
    predicted anomaly masks and the ground truth masks. Each ground-truth region contributes
    ``min(true_positive_area / saturation_threshold, 1.0)``, so a region's score saturates
    once its overlap with the prediction reaches the configured threshold.

    Args:
        threshold (float): Threshold used to binarize the predictions.
            Defaults to ``0.5``.
        saturation_config (str | Path | None): Path to the saturation configuration file.
            Defaults to ``None``, in which case the score is equivalent to the PRO metric,
            but with regions separated by mask files.
        kwargs: Additional arguments to the TorchMetrics base class.

    Example:
        Import the metric from the package:

        >>> import torch
        >>> from anomalib.metrics import SPRO

        Create random ``preds`` and ``labels`` tensors:

        >>> labels = torch.randint(low=0, high=2, size=(2, 10, 5), dtype=torch.float32)
        >>> labels = [labels]
        >>> preds = torch.rand_like(labels[0][:1])

        Compute the SPRO score for labels and preds:

        >>> spro = SPRO(threshold=0.5)
        >>> spro.update(preds, labels)
        >>> spro.compute()
        tensor(0.6333)
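
        To apply saturation thresholds, pass the path to a JSON configuration file in the
        MVTec LOCO format (a sketch; the file path below is hypothetical):

        >>> spro = SPRO(threshold=0.5, saturation_config="./saturation_config.json")  # doctest: +SKIP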

        .. note::
            The example above uses random predictions and labels, so the SPRO
            score shown may not be reproducible.

    """

    def __init__(self, threshold: float = 0.5, saturation_config: str | Path | None = None, **kwargs) -> None:
        super().__init__(**kwargs)
        self.threshold = threshold
        self.saturation_config = load_saturation_config(saturation_config) if saturation_config is not None else None
        if self.saturation_config is None:
            logger.warning(
                "No saturation config was provided, so the saturation threshold is set to "
                "the defect area. This is equivalent to the PRO metric, but with regions "
                "separated by mask files.",
            )
        self.add_state("score", default=torch.tensor(0.0), dist_reduce_fx="sum")
        self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum")

    def update(self, predictions: torch.Tensor, masks: list[torch.Tensor]) -> None:
        """Compute the SPRO score for the current batch.

        Args:
            predictions (torch.Tensor): Predicted anomaly masks.
            masks (list[torch.Tensor]): Ground truth anomaly masks at the original height and width. Each
                element in the list is a tensor of masks for the corresponding image.

        Example:
            To update the metric state for the current batch, use the ``update`` method:

            >>> spro.update(preds, labels)
        """
        score, total = spro_score(
            predictions=predictions,
            targets=masks,
            threshold=self.threshold,
            saturation_config=self.saturation_config,
        )
        self.score += score
        self.total += total

    def compute(self) -> torch.Tensor:
        """Compute the macro average of the SPRO score across all masks in all batches.

        Example:
            To compute the metric based on the state accumulated from multiple batches, use the ``compute`` method:

            >>> spro.compute()
            tensor(0.5433)
        """
        if self.total == 0:  # only background/normal images
            return torch.tensor(1.0)
        return self.score / self.total


def spro_score(
    predictions: torch.Tensor,
    targets: list[torch.Tensor],
    threshold: float = 0.5,
    saturation_config: dict | None = None,
) -> tuple[torch.Tensor, int]:
    """Calculate the SPRO score for a batch of predictions.

    Args:
        predictions (torch.Tensor): Predicted anomaly masks.
        targets (list[torch.Tensor]): Ground truth anomaly masks at the original height and width. Each
            element in the list is a tensor of masks for the corresponding image.
        threshold (float): When predictions are passed as floats, this threshold is used to binarize them.
            Defaults to ``0.5``.
        saturation_config (dict | None): Saturation configuration keyed by label (pixel value).
            Defaults to ``None``, in which case the score is equivalent to the PRO metric,
            but with regions separated by mask files.

    Returns:
        tuple[torch.Tensor, int]: The summed per-region SPRO scores and the number of
            ground-truth regions in the batch.
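
    Example:
        A minimal sketch of calling ``spro_score`` directly on random data (shapes are
        illustrative; the resulting values are not reproducible):

        >>> preds = torch.rand(1, 10, 5)
        >>> masks = [torch.randint(low=0, high=2, size=(2, 10, 5), dtype=torch.float32)]
        >>> score, total = spro_score(preds, masks, threshold=0.5)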
    """
    # Add batch dim if not exist
    if len(predictions.shape) == 2:
        predictions = predictions.unsqueeze(0)

    # Resize the prediction to have the same size as the target mask
    predictions = torch.nn.functional.interpolate(predictions.unsqueeze(1), targets[0].shape[-2:])

    # Binarize float predictions using the given threshold
    if predictions.dtype.is_floating_point:
        predictions = predictions > threshold

    score = torch.tensor(0.0)
    total = 0
    # Iterate for each image in the batch
    for i, target in enumerate(targets):
        # Iterate for each ground-truth mask per image
        for mask in target:
            label = torch.max(mask)
            if label == 0:  # Skip if only normal/background
                continue
            # Calculate true positive
            target_per_label = mask == label
            true_pos = torch.sum(predictions[i] & target_per_label)

            # Calculate the anomalous area of the ground-truth
            defect_area = torch.sum(target_per_label)

            if saturation_config is not None:
                # Adjust saturation threshold based on configuration
                saturation_per_label = saturation_config[label.int().item()]
                saturation_threshold = saturation_per_label["saturation_threshold"]

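                # When ``relative_saturation`` is true, the configured threshold is a
                # fraction of the defect area rather than an absolute pixel count.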
                if saturation_per_label["relative_saturation"]:
                    saturation_threshold *= defect_area

                # Check if threshold is larger than defect area
                if saturation_threshold > defect_area:
                    warning_msg = (
                        f"Saturation threshold for label {label.int().item()} is larger than the defect area. "
                        "Setting it to the defect area."
                    )
                    logger.warning(warning_msg)
                    saturation_threshold = defect_area
            else:
                # Without a saturation config, saturate at the full defect area (PRO-equivalent)
                saturation_threshold = defect_area

            # Update score with minimum of true_pos/saturation_threshold and 1.0
            score += torch.minimum(true_pos / saturation_threshold, torch.tensor(1.0))
            total += 1
    return score, total


def load_saturation_config(config_path: str | Path) -> dict[int, Any] | None:
    """Load saturation configurations from a JSON file.

    Args:
        config_path (str | Path): Path to the saturation configuration file.

    Returns:
        dict[int, Any] | None: A dictionary with pixel values as keys and the corresponding
            configurations as values, or ``None`` if the config file is not found.

    Example JSON format in the config file of MVTec LOCO dataset:
    [
        {
            "defect_name": "1_additional_pushpin",
            "pixel_value": 255,
            "saturation_threshold": 6300,
            "relative_saturation": false
        },
        {
            "defect_name": "2_additional_pushpins",
            "pixel_value": 254,
            "saturation_threshold": 12600,
            "relative_saturation": false
        },
        ...
    ]
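
    Example:
        Assuming a file ``saturation_config.json`` in the format above exists (the path
        here is hypothetical), the loader returns a dict keyed by pixel value:

        >>> config = load_saturation_config("saturation_config.json")  # doctest: +SKIP
        >>> config[255]["saturation_threshold"]  # doctest: +SKIP
        6300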
    """
    try:
        config_path = validate_path(config_path)
        with Path.open(config_path) as file:
            configs = json.load(file)
        # Create a dictionary with pixel values as keys
        return {conf["pixel_value"]: conf for conf in configs}
    except FileNotFoundError:
        logger.warning("The saturation config file %s does not exist. Returning None.", config_path)
        return None