|
"""MetricsManager callback.""" |
|
|
|
|
|
|
|
|
|
|
|
import logging |
|
from enum import Enum |
|
from typing import Any |
|
|
|
import torch |
|
from lightning.pytorch import Callback, Trainer |
|
from lightning.pytorch.utilities.types import STEP_OUTPUT |
|
|
|
from anomalib import TaskType |
|
from anomalib.metrics import create_metric_collection |
|
from anomalib.models import AnomalyModule |
|
|
|
logger = logging.getLogger(__name__) |
|
|
|
|
|
class Device(str, Enum): |
|
"""Device on which to compute metrics.""" |
|
|
|
CPU = "cpu" |
|
GPU = "gpu" |
|
|
|
|
|
class _MetricsCallback(Callback): |
|
"""Create image and pixel-level AnomalibMetricsCollection. |
|
|
|
This callback creates AnomalibMetricsCollection based on the |
|
list of strings provided for image and pixel-level metrics. |
|
After these MetricCollections are created, the callback assigns |
|
these to the lightning module. |
|
|
|
Args: |
|
task (TaskType | str): Task type of the current run. |
|
image_metrics (list[str] | str | dict[str, dict[str, Any]] | None): List of image-level metrics. |
|
pixel_metrics (list[str] | str | dict[str, dict[str, Any]] | None): List of pixel-level metrics. |
|
device (str): Whether to compute metrics on cpu or gpu. Defaults to cpu. |
|
""" |
|
|
|
def __init__( |
|
self, |
|
task: TaskType | str = TaskType.SEGMENTATION, |
|
image_metrics: list[str] | str | dict[str, dict[str, Any]] | None = None, |
|
pixel_metrics: list[str] | str | dict[str, dict[str, Any]] | None = None, |
|
device: Device = Device.CPU, |
|
) -> None: |
|
super().__init__() |
|
self.task = TaskType(task) |
|
self.image_metric_names = image_metrics |
|
self.pixel_metric_names = pixel_metrics |
|
self.device = device |
|
|
|
def setup( |
|
self, |
|
trainer: Trainer, |
|
pl_module: AnomalyModule, |
|
stage: str | None = None, |
|
) -> None: |
|
"""Set image and pixel-level AnomalibMetricsCollection within Anomalib Model. |
|
|
|
Args: |
|
trainer (pl.Trainer): PyTorch Lightning Trainer |
|
pl_module (AnomalyModule): Anomalib Model that inherits pl LightningModule. |
|
stage (str | None, optional): fit, validate, test or predict. Defaults to None. |
|
""" |
|
del stage, trainer |
|
image_metric_names = [] if self.image_metric_names is None else self.image_metric_names |
|
if isinstance(image_metric_names, str): |
|
image_metric_names = [image_metric_names] |
|
|
|
pixel_metric_names: list[str] | dict[str, dict[str, Any]] |
|
if self.pixel_metric_names is None: |
|
pixel_metric_names = [] |
|
elif self.task == TaskType.CLASSIFICATION: |
|
pixel_metric_names = [] |
|
logger.warning( |
|
"Cannot perform pixel-level evaluation when task type is classification. " |
|
"Ignoring the following pixel-level metrics: %s", |
|
self.pixel_metric_names, |
|
) |
|
else: |
|
pixel_metric_names = ( |
|
self.pixel_metric_names.copy() |
|
if not isinstance(self.pixel_metric_names, str) |
|
else [self.pixel_metric_names] |
|
) |
|
|
|
|
|
|
|
semantic_pixel_metric_names: list[str] | dict[str, dict[str, Any]] = [] |
|
|
|
if "SPRO" in pixel_metric_names: |
|
if isinstance(pixel_metric_names, list): |
|
pixel_metric_names.remove("SPRO") |
|
semantic_pixel_metric_names = ["SPRO"] |
|
elif isinstance(pixel_metric_names, dict): |
|
spro_metric = pixel_metric_names.pop("SPRO") |
|
semantic_pixel_metric_names = {"SPRO": spro_metric} |
|
else: |
|
logger.warning("Unexpected type for pixel_metric_names: %s", type(pixel_metric_names)) |
|
|
|
if isinstance(pl_module, AnomalyModule): |
|
pl_module.image_metrics = create_metric_collection(image_metric_names, "image_") |
|
if hasattr(pl_module, "pixel_metrics"): |
|
new_metrics = create_metric_collection(pixel_metric_names) |
|
for name in new_metrics: |
|
if name not in pl_module.pixel_metrics: |
|
pl_module.pixel_metrics.add_metrics(new_metrics[name]) |
|
else: |
|
pl_module.pixel_metrics = create_metric_collection(pixel_metric_names, "pixel_") |
|
pl_module.semantic_pixel_metrics = create_metric_collection(semantic_pixel_metric_names, "pixel_") |
|
self._set_threshold(pl_module) |
|
|
|
def on_validation_epoch_start( |
|
self, |
|
trainer: Trainer, |
|
pl_module: AnomalyModule, |
|
) -> None: |
|
del trainer |
|
|
|
pl_module.image_metrics.reset() |
|
pl_module.pixel_metrics.reset() |
|
pl_module.semantic_pixel_metrics.reset() |
|
|
|
def on_validation_batch_end( |
|
self, |
|
trainer: Trainer, |
|
pl_module: AnomalyModule, |
|
outputs: STEP_OUTPUT | None, |
|
batch: Any, |
|
batch_idx: int, |
|
dataloader_idx: int = 0, |
|
) -> None: |
|
del trainer, batch, batch_idx, dataloader_idx |
|
|
|
if outputs is not None: |
|
self._outputs_to_device(outputs) |
|
self._update_metrics(pl_module, outputs) |
|
|
|
def on_validation_epoch_end( |
|
self, |
|
trainer: Trainer, |
|
pl_module: AnomalyModule, |
|
) -> None: |
|
del trainer |
|
|
|
self._set_threshold(pl_module) |
|
self._log_metrics(pl_module) |
|
|
|
def on_test_epoch_start( |
|
self, |
|
trainer: Trainer, |
|
pl_module: AnomalyModule, |
|
) -> None: |
|
del trainer |
|
|
|
pl_module.image_metrics.reset() |
|
pl_module.pixel_metrics.reset() |
|
pl_module.semantic_pixel_metrics.reset() |
|
|
|
def on_test_batch_end( |
|
self, |
|
trainer: Trainer, |
|
pl_module: AnomalyModule, |
|
outputs: STEP_OUTPUT | None, |
|
batch: Any, |
|
batch_idx: int, |
|
dataloader_idx: int = 0, |
|
) -> None: |
|
del trainer, batch, batch_idx, dataloader_idx |
|
|
|
if outputs is not None: |
|
self._outputs_to_device(outputs) |
|
self._update_metrics(pl_module, outputs) |
|
|
|
def on_test_epoch_end( |
|
self, |
|
trainer: Trainer, |
|
pl_module: AnomalyModule, |
|
) -> None: |
|
del trainer |
|
|
|
self._log_metrics(pl_module) |
|
|
|
def _set_threshold(self, pl_module: AnomalyModule) -> None: |
|
pl_module.image_metrics.set_threshold(pl_module.image_threshold.value.item()) |
|
pl_module.pixel_metrics.set_threshold(pl_module.pixel_threshold.value.item()) |
|
pl_module.semantic_pixel_metrics.set_threshold(pl_module.pixel_threshold.value.item()) |
|
|
|
def _update_metrics( |
|
self, |
|
pl_module: AnomalyModule, |
|
output: STEP_OUTPUT, |
|
) -> None: |
|
pl_module.image_metrics.to(self.device) |
|
pl_module.image_metrics.update(output["pred_scores"], output["label"].int()) |
|
if "mask" in output and "anomaly_maps" in output: |
|
pl_module.pixel_metrics.to(self.device) |
|
pl_module.pixel_metrics.update(torch.squeeze(output["anomaly_maps"]), torch.squeeze(output["mask"].int())) |
|
if "semantic_mask" in output and "anomaly_maps" in output: |
|
pl_module.semantic_pixel_metrics.to(self.device) |
|
pl_module.semantic_pixel_metrics.update(torch.squeeze(output["anomaly_maps"]), output["semantic_mask"]) |
|
|
|
def _outputs_to_device(self, output: STEP_OUTPUT) -> STEP_OUTPUT | dict[str, Any]: |
|
if isinstance(output, dict): |
|
for key, value in output.items(): |
|
output[key] = self._outputs_to_device(value) |
|
elif isinstance(output, torch.Tensor): |
|
output = output.to(self.device) |
|
elif isinstance(output, list): |
|
for i, value in enumerate(output): |
|
output[i] = self._outputs_to_device(value) |
|
return output |
|
|
|
@staticmethod |
|
def _log_metrics(pl_module: AnomalyModule) -> None: |
|
"""Log computed performance metrics.""" |
|
pl_module.log_dict(pl_module.image_metrics, prog_bar=True) |
|
if pl_module.pixel_metrics.update_called: |
|
pl_module.log_dict(pl_module.pixel_metrics, prog_bar=False) |
|
if pl_module.semantic_pixel_metrics.update_called: |
|
pl_module.log_dict(pl_module.semantic_pixel_metrics, prog_bar=False) |
|
|