LogSAD / anomalib /callbacks /normalization /min_max_normalization.py
zhiqing0205
Add core libraries: anomalib, dinov2, open_clip_local
3de7bf6
"""Anomaly Score Normalization Callback that uses min-max normalization."""
# Copyright (C) 2022-2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
from typing import Any
import torch
from lightning.pytorch import Trainer
from lightning.pytorch.utilities.types import STEP_OUTPUT
from anomalib.metrics import MinMax
from anomalib.models.components import AnomalyModule
from anomalib.utils.normalization.min_max import normalize
from .base import NormalizationCallback
class _MinMaxNormalizationCallback(NormalizationCallback):
    """Callback that normalizes the image-level and pixel-level anomaly scores using min-max normalization.

    During validation the callback feeds the step outputs into
    ``pl_module.normalization_metrics`` (a ``MinMax`` metric). During test and
    predict it rewrites the step outputs in place by passing them through
    ``normalize`` together with the module thresholds and the recorded
    min/max statistics.

    Note: This callback is set within the Engine.
    """

    def setup(self, trainer: Trainer, pl_module: AnomalyModule, stage: str | None = None) -> None:
        """Add min_max metrics to normalization metrics.

        Args:
            trainer: Unused.
            pl_module: Module that receives a ``normalization_metrics`` attribute
                when it does not already have one.
            stage: Unused.

        Raises:
            AttributeError: If ``pl_module.normalization_metrics`` already exists
                but is not a ``MinMax`` instance.
        """
        del trainer, stage  # These variables are not used.

        if not hasattr(pl_module, "normalization_metrics"):
            # Fresh metric, explicitly placed on the CPU.
            pl_module.normalization_metrics = MinMax().cpu()
        elif not isinstance(pl_module.normalization_metrics, MinMax):
            msg = f"Expected normalization_metrics to be of type MinMax, got {type(pl_module.normalization_metrics)}"
            raise AttributeError(
                msg,
            )

    def on_test_start(self, trainer: Trainer, pl_module: AnomalyModule) -> None:
        """Call when the test begins.

        Sets the threshold of every available metric collection on the module
        (image, pixel and semantic-pixel) to ``0.5``.

        Args:
            trainer: Unused.
            pl_module: Module whose metric collections are updated.
        """
        del trainer  # `trainer` variable is not used.

        # NOTE(review): 0.5 is presumably the value the raw threshold maps to
        # after `normalize` -- confirm against anomalib.utils.normalization.min_max.
        for metric in (pl_module.image_metrics, pl_module.pixel_metrics, pl_module.semantic_pixel_metrics):
            if metric is not None:
                metric.set_threshold(0.5)

    def on_validation_batch_end(
        self,
        trainer: Trainer,
        pl_module: AnomalyModule,
        outputs: STEP_OUTPUT,
        batch: Any,  # noqa: ANN401
        batch_idx: int,
        dataloader_idx: int = 0,
    ) -> None:
        """Call when the validation batch ends, update the min and max observed values.

        Exactly one source of values is used per batch, in this order of
        preference: ``"anomaly_maps"``, ``"box_scores"``, ``"pred_scores"``.

        Args:
            trainer: Unused.
            pl_module: Module holding the ``normalization_metrics`` to update.
            outputs: Validation step outputs; must contain at least one of the
                keys listed above.
            batch: Unused.
            batch_idx: Unused.
            dataloader_idx: Unused.

        Raises:
            ValueError: If ``outputs`` contains none of the supported keys.
        """
        del trainer, batch, batch_idx, dataloader_idx  # These variables are not used.

        if "anomaly_maps" in outputs:
            pl_module.normalization_metrics(outputs["anomaly_maps"])
        elif "box_scores" in outputs:
            # `box_scores` is a sequence of tensors; concatenate them so the
            # metric is updated with a single tensor.
            pl_module.normalization_metrics(torch.cat(outputs["box_scores"]))
        elif "pred_scores" in outputs:
            pl_module.normalization_metrics(outputs["pred_scores"])
        else:
            msg = "No values found for normalization, provide anomaly maps, bbox scores, or image scores"
            raise ValueError(msg)

    def on_test_batch_end(
        self,
        trainer: Trainer,
        pl_module: AnomalyModule,
        outputs: STEP_OUTPUT | None,
        batch: Any,  # noqa: ANN401
        batch_idx: int,
        dataloader_idx: int = 0,
    ) -> None:
        """Call when the test batch ends, normalizes the predicted scores and anomaly maps.

        Args:
            trainer: Unused.
            pl_module: Module providing thresholds and normalization statistics.
            outputs: Test step outputs, modified in place. ``None`` is accepted
                and leaves nothing to normalize.
            batch: Unused.
            batch_idx: Unused.
            dataloader_idx: Unused.
        """
        del trainer, batch, batch_idx, dataloader_idx  # These variables are not used.

        self._normalize_batch(outputs, pl_module)

    def on_predict_batch_end(
        self,
        trainer: Trainer,
        pl_module: AnomalyModule,
        outputs: Any,  # noqa: ANN401
        batch: Any,  # noqa: ANN401
        batch_idx: int,
        dataloader_idx: int = 0,
    ) -> None:
        """Call when the predict batch ends, normalizes the predicted scores and anomaly maps.

        Args:
            trainer: Unused.
            pl_module: Module providing thresholds and normalization statistics.
            outputs: Predict step outputs, modified in place. ``None`` is
                accepted and leaves nothing to normalize.
            batch: Unused.
            batch_idx: Unused.
            dataloader_idx: Unused.
        """
        del trainer, batch, batch_idx, dataloader_idx  # These variables are not used.

        self._normalize_batch(outputs, pl_module)

    @staticmethod
    def _normalize_batch(outputs: Any, pl_module: AnomalyModule) -> None:  # noqa: ANN401
        """Normalize a batch of predictions.

        Each of ``"pred_scores"``, ``"anomaly_maps"`` and ``"box_scores"`` that
        is present in ``outputs`` is replaced by its normalized counterpart.
        Image scores use the image threshold; anomaly maps and box scores use
        the pixel threshold.

        Args:
            outputs: Step outputs to rewrite in place, or ``None``.
            pl_module: Module providing ``image_threshold``, ``pixel_threshold``
                and ``normalization_metrics``.
        """
        # The test hook is annotated as accepting `None`; a step that produced
        # no outputs has nothing to normalize, and the membership tests below
        # would otherwise fail with a TypeError.
        if outputs is None:
            return

        # Thresholds and statistics are moved to the CPU before being handed
        # to `normalize`.
        image_threshold = pl_module.image_threshold.value.cpu()
        pixel_threshold = pl_module.pixel_threshold.value.cpu()
        stats = pl_module.normalization_metrics.cpu()

        if "pred_scores" in outputs:
            outputs["pred_scores"] = normalize(outputs["pred_scores"], image_threshold, stats.min, stats.max)
        if "anomaly_maps" in outputs:
            outputs["anomaly_maps"] = normalize(outputs["anomaly_maps"], pixel_threshold, stats.min, stats.max)
        if "box_scores" in outputs:
            outputs["box_scores"] = [
                normalize(scores, pixel_threshold, stats.min, stats.max) for scores in outputs["box_scores"]
            ]