File size: 7,400 Bytes
3de7bf6 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 |
"""Thresholding callback."""
# Copyright (C) 2023 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
import importlib
from typing import Any
import torch
from lightning.pytorch import Callback, Trainer
from lightning.pytorch.utilities.types import STEP_OUTPUT
from omegaconf import DictConfig, ListConfig
from anomalib.metrics.threshold import BaseThreshold
from anomalib.models import AnomalyModule
from anomalib.utils.types import THRESHOLD
class _ThresholdCallback(Callback):
    """Set up and apply thresholding during validation.

    The callback builds an image-level and a pixel-level threshold from the
    ``threshold`` argument, attaches them to the module in :meth:`setup` when the
    module does not already define them, and resets, updates and computes the
    module's thresholds over each validation epoch.

    Note:
        This callback is set within the Engine.

    Args:
        threshold (THRESHOLD): Threshold instance, tuple of instances, or configuration.
            See :meth:`_initialize_thresholds` for the accepted forms.
            Defaults to ``"F1AdaptiveThreshold"``.
    """

    def __init__(
        self,
        threshold: THRESHOLD = "F1AdaptiveThreshold",
    ) -> None:
        super().__init__()
        # Declared before initialization so the attributes are visible up front;
        # ``_initialize_thresholds`` is what actually assigns them.
        self.image_threshold: BaseThreshold
        self.pixel_threshold: BaseThreshold
        self._initialize_thresholds(threshold)

    def setup(self, trainer: Trainer, pl_module: AnomalyModule, stage: str) -> None:
        """Attach the thresholds to ``pl_module`` unless it already has them.

        Args:
            trainer (Trainer): Unused.
            pl_module (AnomalyModule): Module that receives ``image_threshold`` / ``pixel_threshold``.
            stage (str): Unused.
        """
        del trainer, stage  # Unused arguments.
        # Existing attributes on the module take precedence over the callback's own thresholds.
        if not hasattr(pl_module, "image_threshold"):
            pl_module.image_threshold = self.image_threshold
        if not hasattr(pl_module, "pixel_threshold"):
            pl_module.pixel_threshold = self.pixel_threshold

    def on_validation_epoch_start(self, trainer: Trainer, pl_module: AnomalyModule) -> None:
        """Reset the module's thresholds at the start of each validation epoch."""
        del trainer  # Unused argument.
        self._reset(pl_module)

    def on_validation_batch_end(
        self,
        trainer: Trainer,
        pl_module: AnomalyModule,
        outputs: STEP_OUTPUT | None,
        batch: Any,  # noqa: ANN401
        batch_idx: int,
        dataloader_idx: int = 0,
    ) -> None:
        """Move the batch outputs to CPU and update the thresholds with them.

        Nothing happens when ``outputs`` is ``None``.
        """
        del trainer, batch, batch_idx, dataloader_idx  # Unused arguments.
        if outputs is None:
            return
        # Keep the returned object: dicts are updated in place, but a bare tensor is
        # only moved by rebinding to the result of ``.cpu()``.
        outputs = self._outputs_to_cpu(outputs)
        self._update(pl_module, outputs)

    def on_validation_epoch_end(self, trainer: Trainer, pl_module: AnomalyModule) -> None:
        """Compute the final threshold values at the end of the validation epoch."""
        del trainer  # Unused argument.
        self._compute(pl_module)

    def _initialize_thresholds(
        self,
        threshold: THRESHOLD,
    ) -> None:
        """Initialize ``self.image_threshold`` and ``self.pixel_threshold``.

        Args:
            threshold (THRESHOLD):
                Threshold configuration

        Example:
            >>> _initialize_thresholds(F1AdaptiveThreshold())

            or

            >>> _initialize_thresholds((ManualThreshold(0.5), ManualThreshold(0.5)))

            or configuration

            For more details on configuration see :meth:`_load_from_config`

        Raises:
            TypeError: ``threshold`` is of an unsupported type, or is a tuple whose first
                two entries are not both ``BaseThreshold`` instances.
            ValueError: A list configuration holds fewer than two entries
                (raised by :meth:`_load_from_config`).
        """
        # TODO(djdameln): Add tests for each case
        # CVS-122661

        # When only a single threshold class is passed.
        # This initializes image and pixel thresholds with the same class
        # >>> _initialize_thresholds(F1AdaptiveThreshold())
        if isinstance(threshold, BaseThreshold):
            self.image_threshold = threshold
            self.pixel_threshold = threshold.clone()

        # When a tuple of threshold classes are passed
        # >>> _initialize_thresholds((ManualThreshold(0.5), ManualThreshold(0.5)))
        elif isinstance(threshold, tuple):
            # Only the first two entries (image, pixel) are used. Validate both so that a
            # short or mixed tuple fails here with a clear message instead of an IndexError
            # or a non-threshold object being stored silently.
            image_and_pixel = threshold[:2]
            if len(image_and_pixel) < 2 or not all(isinstance(item, BaseThreshold) for item in image_and_pixel):
                msg = f"Invalid threshold type {type(threshold)}: expected a tuple of two BaseThreshold instances"
                raise TypeError(msg)
            self.image_threshold, self.pixel_threshold = image_and_pixel

        # When the passed threshold is not an instance of a Threshold class.
        elif isinstance(threshold, str | DictConfig | ListConfig | list):
            self._load_from_config(threshold)

        else:
            msg = f"Invalid threshold type {type(threshold)}"
            raise TypeError(msg)

    def _load_from_config(self, threshold: DictConfig | str | ListConfig | list[dict[str, str | float]]) -> None:
        """Load the thresholding class based on the config.

        A single entry (string or ``DictConfig``) is used for the image threshold and
        cloned for the pixel threshold. A list supplies the image threshold first and
        the pixel threshold second.

        Example:
            threshold: F1AdaptiveThreshold

            or

            threshold:
                class_path: ManualThreshold
                init_args:
                    default_value: 0.5

            or

            threshold:
                - F1AdaptiveThreshold
                - F1AdaptiveThreshold

            or

            threshold:
                - class_path: ManualThreshold
                  init_args:
                      default_value: 0.5
                - class_path: F1AdaptiveThreshold

        Raises:
            ValueError: A list configuration holds fewer than two entries.
            TypeError: ``threshold`` is not a string, ``DictConfig``, ``ListConfig`` or list.
        """
        if isinstance(threshold, str | DictConfig):
            self.image_threshold = self._get_threshold_from_config(threshold)
            self.pixel_threshold = self.image_threshold.clone()
        elif isinstance(threshold, ListConfig | list):
            # Both positions are indexed below; fail early with an explicit message.
            if len(threshold) < 2:
                msg = f"Expected two threshold configs (image, pixel), got {len(threshold)}: {threshold}"
                raise ValueError(msg)
            self.image_threshold = self._get_threshold_from_config(threshold[0])
            self.pixel_threshold = self._get_threshold_from_config(threshold[1])
        else:
            msg = f"Invalid threshold config {threshold}"
            raise TypeError(msg)

    def _get_threshold_from_config(self, threshold: DictConfig | str | dict[str, str | float]) -> BaseThreshold:
        """Return the instantiated threshold object.

        A ``class_path`` without a dot is looked up in ``anomalib.metrics.threshold``;
        otherwise everything before the last dot is imported as the module and the
        last component is taken as the class name.

        Example:
            >>> _get_threshold_from_config("F1AdaptiveThreshold")

            or

            >>> config = DictConfig({
            ...    "class_path": "ManualThreshold",
            ...    "init_args": {"default_value": 0.7}
            ... })
            >>> _get_threshold_from_config(config)

            or

            >>> config = DictConfig({
            ...    "class_path": "anomalib.metrics.threshold.F1AdaptiveThreshold"
            ... })
            >>> _get_threshold_from_config(config)

        Raises:
            ImportError: The module part of ``class_path`` cannot be imported.
            AttributeError: The module does not define the requested class.

        Returns:
            (BaseThreshold): Instance of threshold object.
        """
        if isinstance(threshold, str):
            threshold = DictConfig({"class_path": threshold})

        class_path = threshold["class_path"]
        # ``init_args`` may be missing or explicitly ``None``; both mean "no arguments".
        init_args = threshold.get("init_args") or {}

        if "." in class_path:
            module_path, _, class_name = class_path.rpartition(".")
        else:
            module_path, class_name = "anomalib.metrics.threshold", class_path

        module = importlib.import_module(module_path)
        class_ = getattr(module, class_name)
        return class_(**init_args)

    def _reset(self, pl_module: AnomalyModule) -> None:
        """Reset both thresholds held by ``pl_module``."""
        pl_module.image_threshold.reset()
        pl_module.pixel_threshold.reset()

    def _outputs_to_cpu(self, output: STEP_OUTPUT) -> STEP_OUTPUT | dict[str, Any]:
        """Recursively move tensors in ``output`` to CPU.

        Dictionaries are updated in place (and returned); a tensor is returned as the
        result of ``.cpu()``; any other value is returned unchanged.
        """
        if isinstance(output, dict):
            for key, value in output.items():
                output[key] = self._outputs_to_cpu(value)
        elif isinstance(output, torch.Tensor):
            output = output.cpu()
        return output

    def _update(self, pl_module: AnomalyModule, outputs: STEP_OUTPUT) -> None:
        """Update the module's thresholds with one batch of outputs.

        The image threshold is always updated from ``pred_scores`` and ``label``. The
        pixel threshold is updated only when both ``mask`` and ``anomaly_maps`` are
        present in ``outputs``.
        """
        # The thresholds are moved to CPU before updating, presumably to match the
        # outputs, which the caller has already moved to CPU.
        pl_module.image_threshold.cpu()
        pl_module.image_threshold.update(outputs["pred_scores"], outputs["label"].int())
        if "mask" in outputs and "anomaly_maps" in outputs:
            pl_module.pixel_threshold.cpu()
            pl_module.pixel_threshold.update(outputs["anomaly_maps"], outputs["mask"].int())

    def _compute(self, pl_module: AnomalyModule) -> None:
        """Compute the threshold values on ``pl_module``.

        When the pixel threshold received no update, it takes the image threshold's
        value instead of being computed.
        """
        pl_module.image_threshold.compute()
        if pl_module.pixel_threshold._update_called:  # noqa: SLF001
            pl_module.pixel_threshold.compute()
        else:
            pl_module.pixel_threshold.value = pl_module.image_threshold.value
|