File size: 4,965 Bytes
3de7bf6 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 |
"""Callback that attaches necessary pre/post-processing to the model."""
# Copyright (C) 2023 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
from typing import Any
import torch
from lightning import Callback
from lightning.pytorch import Trainer
from lightning.pytorch.utilities.types import STEP_OUTPUT
from anomalib.data.utils import boxes_to_anomaly_maps, boxes_to_masks, masks_to_boxes
from anomalib.models import AnomalyModule
class _PostProcessorCallback(Callback):
    """Applies post-processing to the model outputs.

    For every validation/test/predict batch whose step output is a dict, the
    output is completed in place:

    * missing ``pred_scores`` are inferred from ``anomaly_maps`` or from the
      per-image box scores,
    * missing ``anomaly_maps`` are rendered from ``pred_boxes`` (plus a
      ground-truth ``mask`` from ``boxes`` when present),
    * while testing or predicting only, thresholded predictions
      (``pred_labels``, ``pred_masks``, ``pred_boxes``/``box_scores``,
      ``box_labels``) are added using the module's image/pixel thresholds.

    Note: This callback is set within the Engine.
    """

    def __init__(self) -> None:
        super().__init__()

    def on_validation_batch_end(
        self,
        trainer: Trainer,
        pl_module: AnomalyModule,
        outputs: STEP_OUTPUT | None,
        batch: Any,  # noqa: ANN401
        batch_idx: int,
        dataloader_idx: int = 0,
    ) -> None:
        """Post-process the output of a validation step in place."""
        del batch, batch_idx, dataloader_idx  # Unused arguments.
        if outputs is not None:
            self.post_process(trainer, pl_module, outputs)

    def on_test_batch_end(
        self,
        trainer: Trainer,
        pl_module: AnomalyModule,
        outputs: STEP_OUTPUT | None,
        batch: Any,  # noqa: ANN401
        batch_idx: int,
        dataloader_idx: int = 0,
    ) -> None:
        """Post-process the output of a test step in place."""
        del batch, batch_idx, dataloader_idx  # Unused arguments.
        if outputs is not None:
            self.post_process(trainer, pl_module, outputs)

    def on_predict_batch_end(
        self,
        trainer: Trainer,
        pl_module: AnomalyModule,
        outputs: Any,  # noqa: ANN401
        batch: Any,  # noqa: ANN401
        batch_idx: int,
        dataloader_idx: int = 0,
    ) -> None:
        """Post-process the output of a predict step in place."""
        del batch, batch_idx, dataloader_idx  # Unused arguments.
        if outputs is not None:
            self.post_process(trainer, pl_module, outputs)

    def post_process(self, trainer: Trainer, pl_module: AnomalyModule, outputs: STEP_OUTPUT) -> None:
        """Complete a single step output in place.

        Non-dict outputs are left untouched. Scores/maps are always inferred;
        thresholded predictions are only added while the trainer is testing
        or predicting (not during validation).

        Args:
            trainer: Trainer running the current stage.
            pl_module: Module providing ``image_threshold`` and ``pixel_threshold``.
            outputs: Step output; modified in place when it is a dict.
        """
        if not isinstance(outputs, dict):
            return
        self._post_process(outputs)
        if trainer.predicting or trainer.testing:
            self._compute_scores_and_labels(pl_module, outputs)

    @staticmethod
    def _compute_scores_and_labels(
        pl_module: AnomalyModule,
        outputs: dict[str, Any],
    ) -> None:
        """Threshold scores, anomaly maps and boxes into predictions (in place).

        All comparisons are inclusive (``>=``), so a value equal to the
        threshold is treated as anomalous for images, pixels and boxes alike.

        Args:
            pl_module: Module whose ``image_threshold.value`` and
                ``pixel_threshold.value`` are used as decision thresholds.
            outputs: Step output dict; new prediction keys are written into it.
        """
        if "pred_scores" in outputs:
            outputs["pred_labels"] = outputs["pred_scores"] >= pl_module.image_threshold.value
        if "anomaly_maps" in outputs:
            outputs["pred_masks"] = outputs["anomaly_maps"] >= pl_module.pixel_threshold.value
            if "pred_boxes" not in outputs:
                # Boxes are derived from the already-thresholded masks, so every one is anomalous (label 1).
                outputs["pred_boxes"], outputs["box_scores"] = masks_to_boxes(
                    outputs["pred_masks"],
                    outputs["anomaly_maps"],
                )
                # Allocate labels next to their boxes (torch.ones defaults to CPU), matching the
                # thresholding branch below whose labels inherit the score tensors' device.
                outputs["box_labels"] = [
                    torch.ones(boxes.shape[0], device=boxes.device) for boxes in outputs["pred_boxes"]
                ]
        # apply thresholding to boxes
        if "box_scores" in outputs and "box_labels" not in outputs:
            # Inclusive comparison, consistent with ``pred_masks`` above: box-only models have their
            # anomaly maps rendered from these same box scores, so a tie must not flip the label.
            is_anomalous = [scores >= pl_module.pixel_threshold.value for scores in outputs["box_scores"]]
            outputs["box_labels"] = [labels.int() for labels in is_anomalous]

    @staticmethod
    def _post_process(outputs: STEP_OUTPUT) -> None:
        """Infer missing scores and anomaly maps from the model predictions (in place).

        * ``pred_scores`` missing, ``anomaly_maps`` present: each sample's score
          is the maximum value of its anomaly map.
        * ``pred_scores`` missing, ``box_scores`` and ``label`` present: each
          image's score is its highest box score, or 0 when it has no boxes.
        * ``pred_boxes`` present without ``anomaly_maps``: anomaly maps are
          rendered from the boxes at the spatial size of ``image``; when
          ground-truth ``boxes`` are present a ``mask`` is rendered from them too.

        Non-dict outputs are ignored.
        """
        if not isinstance(outputs, dict):
            return

        if "pred_scores" not in outputs and "anomaly_maps" in outputs:
            # infer image scores from anomaly maps: flatten each sample, take its max
            outputs["pred_scores"] = (
                outputs["anomaly_maps"]  # noqa: PD011
                .reshape(outputs["anomaly_maps"].shape[0], -1)
                .max(dim=1)
                .values
            )
        elif "pred_scores" not in outputs and "box_scores" in outputs and "label" in outputs:
            # infer image score from bbox confidence scores.
            # ``label`` only provides the shape/device of the score tensor here.
            # NOTE(review): without ``label`` (e.g. unlabeled prediction data — confirm) no image
            # score is inferred from boxes, so no ``pred_labels`` will be produced either.
            outputs["pred_scores"] = torch.zeros_like(outputs["label"]).float()
            for idx, (boxes, scores) in enumerate(zip(outputs["pred_boxes"], outputs["box_scores"], strict=True)):
                if boxes.numel():  # images without boxes keep a score of 0
                    outputs["pred_scores"][idx] = scores.max().item()

        if "pred_boxes" in outputs and "anomaly_maps" not in outputs:
            # create anomaly maps from bbox predictions for thresholding and evaluation
            image_size: tuple[int, int] = outputs["image"].shape[-2:]
            pred_boxes: torch.Tensor = outputs["pred_boxes"]
            box_scores: torch.Tensor = outputs["box_scores"]

            outputs["anomaly_maps"] = boxes_to_anomaly_maps(pred_boxes, box_scores, image_size)

            if "boxes" in outputs:
                # ground-truth boxes -> ground-truth masks at the same resolution
                true_boxes: list[torch.Tensor] = outputs["boxes"]
                outputs["mask"] = boxes_to_masks(true_boxes, image_size)
|