File size: 4,384 Bytes
3de7bf6 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 |
"""Anomaly Score Normalization Callback that uses min-max normalization."""
# Copyright (C) 2022-2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
from typing import Any
import torch
from lightning.pytorch import Trainer
from lightning.pytorch.utilities.types import STEP_OUTPUT
from anomalib.metrics import MinMax
from anomalib.models.components import AnomalyModule
from anomalib.utils.normalization.min_max import normalize
from .base import NormalizationCallback
class _MinMaxNormalizationCallback(NormalizationCallback):
    """Callback that normalizes the image-level and pixel-level anomaly scores using min-max normalization.

    During validation, the observed score range is accumulated in ``pl_module.normalization_metrics``
    (a ``MinMax`` metric). During testing and prediction, those statistics are combined with the module's
    image/pixel thresholds to rescale ``pred_scores``, ``anomaly_maps`` and ``box_scores`` in the step
    outputs, in place.

    Note: This callback is set within the Engine.
    """

    def setup(self, trainer: Trainer, pl_module: AnomalyModule, stage: str | None = None) -> None:
        """Add min_max metrics to normalization metrics.

        Args:
            trainer: Lightning trainer (unused).
            pl_module: Model on which ``normalization_metrics`` is created if it does not exist yet.
            stage: Lightning stage name (unused).

        Raises:
            AttributeError: If ``pl_module.normalization_metrics`` already exists but is not a ``MinMax``.
        """
        del trainer, stage  # These variables are not used.

        if not hasattr(pl_module, "normalization_metrics"):
            pl_module.normalization_metrics = MinMax().cpu()
        elif not isinstance(pl_module.normalization_metrics, MinMax):
            msg = f"Expected normalization_metrics to be of type MinMax, got {type(pl_module.normalization_metrics)}"
            raise AttributeError(msg)

    def on_test_start(self, trainer: Trainer, pl_module: AnomalyModule) -> None:
        """Call when the test begins.

        Sets the threshold of every available metric collection on ``pl_module`` to 0.5.

        Args:
            trainer: Lightning trainer (unused).
            pl_module: Model whose metric collections are updated.
        """
        del trainer  # `trainer` variable is not used.

        # NOTE(review): 0.5 presumably corresponds to where `normalize` maps the original threshold
        # after min-max rescaling -- confirm against `anomalib.utils.normalization.min_max.normalize`.
        # `semantic_pixel_metrics` is optional: modules that do not define it are skipped instead of
        # raising AttributeError.
        metric_collections = (
            pl_module.image_metrics,
            pl_module.pixel_metrics,
            getattr(pl_module, "semantic_pixel_metrics", None),
        )
        for metric in metric_collections:
            if metric is not None:
                metric.set_threshold(0.5)

    def on_validation_batch_end(
        self,
        trainer: Trainer,
        pl_module: AnomalyModule,
        outputs: STEP_OUTPUT,
        batch: Any,  # noqa: ANN401
        batch_idx: int,
        dataloader_idx: int = 0,
    ) -> None:
        """Call when the validation batch ends, update the min and max observed values.

        Only one source of values is used per batch, chosen in priority order:
        ``anomaly_maps``, then ``box_scores`` (concatenated), then ``pred_scores``.

        Args:
            trainer: Lightning trainer (unused).
            pl_module: Model whose ``normalization_metrics`` is updated.
            outputs: Validation step outputs.
            batch: Input batch (unused).
            batch_idx: Batch index (unused).
            dataloader_idx: Dataloader index (unused).

        Raises:
            ValueError: If ``outputs`` is ``None`` or contains none of the supported keys.
        """
        del trainer, batch, batch_idx, dataloader_idx  # These variables are not used.

        msg = "No values found for normalization, provide anomaly maps, bbox scores, or image scores"
        # A step that returned nothing has no values to collect; report it with the same message
        # instead of failing with an opaque `TypeError` on the membership test below.
        if outputs is None:
            raise ValueError(msg)

        if "anomaly_maps" in outputs:
            pl_module.normalization_metrics(outputs["anomaly_maps"])
        elif "box_scores" in outputs:
            # `box_scores` is a sequence of per-item tensors; merge them so the metric sees one tensor.
            pl_module.normalization_metrics(torch.cat(outputs["box_scores"]))
        elif "pred_scores" in outputs:
            pl_module.normalization_metrics(outputs["pred_scores"])
        else:
            raise ValueError(msg)

    def on_test_batch_end(
        self,
        trainer: Trainer,
        pl_module: AnomalyModule,
        outputs: STEP_OUTPUT | None,
        batch: Any,  # noqa: ANN401
        batch_idx: int,
        dataloader_idx: int = 0,
    ) -> None:
        """Call when the test batch ends, normalizes the predicted scores and anomaly maps.

        Args:
            trainer: Lightning trainer (unused).
            pl_module: Model providing thresholds and min/max statistics.
            outputs: Test step outputs, modified in place; ``None`` is ignored.
            batch: Input batch (unused).
            batch_idx: Batch index (unused).
            dataloader_idx: Dataloader index (unused).
        """
        del trainer, batch, batch_idx, dataloader_idx  # These variables are not used.
        self._normalize_batch(outputs, pl_module)

    def on_predict_batch_end(
        self,
        trainer: Trainer,
        pl_module: AnomalyModule,
        outputs: Any,  # noqa: ANN401
        batch: Any,  # noqa: ANN401
        batch_idx: int,
        dataloader_idx: int = 0,
    ) -> None:
        """Call when the predict batch ends, normalizes the predicted scores and anomaly maps.

        Args:
            trainer: Lightning trainer (unused).
            pl_module: Model providing thresholds and min/max statistics.
            outputs: Predict step outputs, modified in place; ``None`` is ignored.
            batch: Input batch (unused).
            batch_idx: Batch index (unused).
            dataloader_idx: Dataloader index (unused).
        """
        del trainer, batch, batch_idx, dataloader_idx  # These variables are not used.
        self._normalize_batch(outputs, pl_module)

    @staticmethod
    def _normalize_batch(outputs: Any, pl_module: AnomalyModule) -> None:  # noqa: ANN401
        """Normalize a batch of predictions in place.

        Args:
            outputs: Step outputs. When present, the ``pred_scores``, ``anomaly_maps`` and
                ``box_scores`` entries are replaced by their normalized versions. ``None`` is a no-op.
            pl_module: Model providing ``image_threshold``, ``pixel_threshold`` and
                ``normalization_metrics`` (the collected min/max statistics).
        """
        if outputs is None:
            # The step produced nothing, so there is nothing to normalize.
            return

        image_threshold = pl_module.image_threshold.value.cpu()
        pixel_threshold = pl_module.pixel_threshold.value.cpu()
        # NOTE(review): if `normalization_metrics` is an nn.Module/torchmetrics Metric, `.cpu()` moves its
        # state in place rather than returning a copy -- confirm this side effect is intended.
        stats = pl_module.normalization_metrics.cpu()

        if "pred_scores" in outputs:
            outputs["pred_scores"] = normalize(outputs["pred_scores"], image_threshold, stats.min, stats.max)
        if "anomaly_maps" in outputs:
            outputs["anomaly_maps"] = normalize(outputs["anomaly_maps"], pixel_threshold, stats.min, stats.max)
        if "box_scores" in outputs:
            # Box scores are normalized per item against the pixel-level threshold.
            outputs["box_scores"] = [
                normalize(scores, pixel_threshold, stats.min, stats.max) for scores in outputs["box_scores"]
            ]
|