"""Base Anomaly Module for Training Task."""
# Copyright (C) 2022-2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
import importlib
import logging
from abc import ABC, abstractmethod
from collections import OrderedDict
from typing import TYPE_CHECKING, Any
import lightning.pytorch as pl
import torch
from lightning.pytorch.trainer.states import TrainerFn
from lightning.pytorch.utilities.types import STEP_OUTPUT
from torch import nn
from torchvision.transforms.v2 import Compose, Normalize, Resize, Transform
from anomalib import LearningType
from anomalib.metrics import AnomalibMetricCollection
from anomalib.metrics.threshold import BaseThreshold
from .export_mixin import ExportMixin
if TYPE_CHECKING:
from lightning.pytorch.callbacks import Callback
from torchmetrics import Metric
logger = logging.getLogger(__name__)
class AnomalyModule(ExportMixin, pl.LightningModule, ABC):
"""AnomalyModule to train, validate, predict and test images.
Acts as a base class for all the Anomaly Modules in the library.
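    Example:
        A concrete model subclasses this module, assigns a torch model and implements the abstract
        properties. This is an illustrative sketch only; ``MyTorchModel`` is a placeholder name, not a
        class provided by anomalib.
        >>> class MyAnomalyModel(AnomalyModule):
        ...     def __init__(self) -> None:
        ...         super().__init__()
        ...         self.model = MyTorchModel()  # placeholder torch model
        ...     @property
        ...     def trainer_arguments(self) -> dict[str, Any]:
        ...         return {"max_epochs": 1}
        ...     @property
        ...     def learning_type(self) -> LearningType:
        ...         return LearningType.ONE_CLASS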
"""
def __init__(self) -> None:
super().__init__()
logger.info("Initializing %s model.", self.__class__.__name__)
self.save_hyperparameters()
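        # The attributes below are declared without assignment; they are expected to be populated later,
        # either by the concrete model implementation or by the callbacks attached by the engine.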
self.model: nn.Module
self.loss: nn.Module
self.callbacks: list[Callback]
self.image_threshold: BaseThreshold
self.pixel_threshold: BaseThreshold
self.normalization_metrics: Metric
self.image_metrics: AnomalibMetricCollection
self.pixel_metrics: AnomalibMetricCollection
self.semantic_pixel_metrics: AnomalibMetricCollection
self._transform: Transform | None = None
self._input_size: tuple[int, int] | None = None
self._is_setup = False # flag to track if setup has been called from the trainer
@property
def name(self) -> str:
"""Name of the model."""
return self.__class__.__name__
def setup(self, stage: str | None = None) -> None:
"""Calls the _setup method to build the model if the model is not already built."""
if getattr(self, "model", None) is None or not self._is_setup:
self._setup()
if isinstance(stage, TrainerFn):
# only set the flag if the stage is a TrainerFn, which means the setup has been called from a trainer
self._is_setup = True
def _setup(self) -> None:
"""The _setup method is used to build the torch model dynamically or adjust something about them.
The model implementer may override this method to build the model. This is useful when the model cannot be set
in the `__init__` method because it requires some information or data that is not available at the time of
initialization.
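        Example:
            A subclass could build its torch model here, once dataset information is available. Sketch
            only; ``MyTorchModel`` is a placeholder name, not part of anomalib.
            >>> class MyAnomalyModel(AnomalyModule):
            ...     def _setup(self) -> None:
            ...         self.model = MyTorchModel(input_size=self._input_size)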
"""
def forward(self, batch: dict[str, str | torch.Tensor], *args, **kwargs) -> Any: # noqa: ANN401
"""Perform the forward-pass by passing input tensor to the module.
Args:
batch (dict[str, str | torch.Tensor]): Input batch.
*args: Arguments.
**kwargs: Keyword arguments.
Returns:
Tensor: Output tensor from the model.
"""
del args, kwargs # These variables are not used.
return self.model(batch)
def validation_step(self, batch: dict[str, str | torch.Tensor], *args, **kwargs) -> STEP_OUTPUT:
"""To be implemented in the subclasses."""
raise NotImplementedError
def predict_step(
self,
batch: dict[str, str | torch.Tensor],
batch_idx: int,
dataloader_idx: int = 0,
) -> STEP_OUTPUT:
"""Step function called during :meth:`~lightning.pytorch.trainer.Trainer.predict`.
        By default, it calls the module's ``validation_step``. Override to add any processing logic.
        Args:
            batch (dict[str, str | torch.Tensor]): Current batch.
            batch_idx (int): Index of the current batch.
            dataloader_idx (int): Index of the current dataloader.
        Returns:
            Predicted output.
        """
        del dataloader_idx  # This variable is not used.
return self.validation_step(batch, batch_idx)
def test_step(self, batch: dict[str, str | torch.Tensor], batch_idx: int, *args, **kwargs) -> STEP_OUTPUT:
"""Calls validation_step for anomaly map/score calculation.
Args:
batch (dict[str, str | torch.Tensor]): Input batch
batch_idx (int): Batch index
args: Arguments.
kwargs: Keyword arguments.
Returns:
Dictionary containing images, features, true labels and masks.
These are required in `validation_epoch_end` for feature concatenation.
"""
del args, kwargs # These variables are not used.
return self.predict_step(batch, batch_idx)
@property
@abstractmethod
def trainer_arguments(self) -> dict[str, Any]:
"""Arguments used to override the trainer parameters so as to train the model correctly."""
raise NotImplementedError
def _save_to_state_dict(self, destination: OrderedDict, prefix: str, keep_vars: bool) -> None:
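        # Store the fully-qualified class names of the threshold and normalization components alongside the
        # weights, so that ``load_state_dict`` can re-instantiate them when the checkpoint is loaded.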
if hasattr(self, "image_threshold"):
destination[
"image_threshold_class"
] = f"{self.image_threshold.__class__.__module__}.{self.image_threshold.__class__.__name__}"
if hasattr(self, "pixel_threshold"):
destination[
"pixel_threshold_class"
] = f"{self.pixel_threshold.__class__.__module__}.{self.pixel_threshold.__class__.__name__}"
if hasattr(self, "normalization_metrics"):
normalization_class = self.normalization_metrics.__class__
destination["normalization_class"] = f"{normalization_class.__module__}.{normalization_class.__name__}"
return super()._save_to_state_dict(destination, prefix, keep_vars)
def load_state_dict(self, state_dict: OrderedDict[str, Any], strict: bool = True) -> Any: # noqa: ANN401
"""Initialize auxiliary object."""
if "image_threshold_class" in state_dict:
self.image_threshold = self._get_instance(state_dict, "image_threshold_class")
if "pixel_threshold_class" in state_dict:
self.pixel_threshold = self._get_instance(state_dict, "pixel_threshold_class")
if "normalization_class" in state_dict:
self.normalization_metrics = self._get_instance(state_dict, "normalization_class")
# Used to load metrics if there is any related data in state_dict
self._load_metrics(state_dict)
return super().load_state_dict(state_dict, strict)
def _load_metrics(self, state_dict: OrderedDict[str, torch.Tensor]) -> None:
"""Load metrics from saved checkpoint."""
self._add_metrics("pixel", state_dict)
self._add_metrics("image", state_dict)
def _add_metrics(self, name: str, state_dict: OrderedDict[str, torch.Tensor]) -> None:
"""Sets the pixel/image metrics.
Args:
            name (str): Whether to add ``pixel`` or ``image`` metrics.
state_dict (OrderedDict[str, Tensor]): state dict of the model.
"""
metric_keys = [key for key in state_dict if key.startswith(f"{name}_metrics")]
if any(metric_keys):
if not hasattr(self, f"{name}_metrics"):
setattr(self, f"{name}_metrics", AnomalibMetricCollection([], prefix=f"{name}_"))
metrics = getattr(self, f"{name}_metrics")
for key in metric_keys:
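                # Keys are expected to follow the pattern "<name>_metrics.<MetricClassName>.<state>", so the
                # second dot-separated component is the metric class name.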
class_name = key.split(".")[1]
try:
metrics_module = importlib.import_module("anomalib.metrics")
metrics_cls = getattr(metrics_module, class_name)
except (ImportError, AttributeError) as exception:
msg = f"Class {class_name} not found in module anomalib.metrics"
raise ImportError(msg) from exception
logger.info("Loading %s metrics from state dict", class_name)
metrics.add_metrics(metrics_cls())
def _get_instance(self, state_dict: OrderedDict[str, Any], dict_key: str) -> BaseThreshold:
"""Get the threshold class from the ``state_dict``."""
class_path = state_dict.pop(dict_key)
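        # ``class_path`` is a fully-qualified name stored by ``_save_to_state_dict``,
        # e.g. "anomalib.metrics.threshold.F1AdaptiveThreshold" (illustrative value).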
module = importlib.import_module(".".join(class_path.split(".")[:-1]))
return getattr(module, class_path.split(".")[-1])()
@property
@abstractmethod
def learning_type(self) -> LearningType:
"""Learning type of the model."""
raise NotImplementedError
@property
    def transform(self) -> Transform | None:
        """Retrieve the model-specific transform.
        If a transform has been set using `set_transform`, it will be returned. Otherwise, `None` is returned,
        and callers are expected to fall back to the model-specific default transform from `configure_transforms`.
        """
        return self._transform
def set_transform(self, transform: Transform) -> None:
"""Update the transform linked to the model instance."""
self._transform = transform
def configure_transforms(self, image_size: tuple[int, int] | None = None) -> Transform:
"""Default transforms.
        The default transform resizes the input to 256x256 and normalizes it using ImageNet statistics.
        Individual models can override this method to provide custom transforms.
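        Example:
            A model can override this method to use its own preprocessing. Sketch only; the image size and
            normalization statistics below are illustrative, not required values.
            >>> class MyAnomalyModel(AnomalyModule):
            ...     def configure_transforms(self, image_size=None):
            ...         size = image_size or (224, 224)
            ...         return Compose([
            ...             Resize(size, antialias=True),
            ...             Normalize(mean=[0.5] * 3, std=[0.5] * 3),
            ...         ])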
"""
logger.warning(
"No implementation of `configure_transforms` was provided in the Lightning model. Using default "
"transforms from the base class. This may not be suitable for your use case. Please override "
"`configure_transforms` in your model.",
)
image_size = image_size or (256, 256)
return Compose(
[
Resize(image_size, antialias=True),
Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
],
)
@property
def input_size(self) -> tuple[int, int] | None:
"""Return the effective input size of the model.
        The effective input size is the size of the input tensor after the transform has been applied. It is
        inferred by passing a dummy tensor through the transform (falling back to `configure_transforms` when no
        transform is set). If the resulting transform does not alter the spatial shape, this property returns None.
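        Example:
            Illustrative only; the reported size depends on the transform that has been set.
            >>> model.set_transform(Resize((224, 224), antialias=True))
            >>> model.input_size
            torch.Size([224, 224])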
"""
transform = self.transform or self.configure_transforms()
if transform is None:
return None
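        # Pass a 1x1 dummy image through the transform: if the spatial size is still 1x1 afterwards,
        # the transform does not impose a fixed input size.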
dummy_input = torch.zeros(1, 3, 1, 1)
output_shape = transform(dummy_input).shape[-2:]
if output_shape == (1, 1):
return None
return output_shape[-2:]
def on_save_checkpoint(self, checkpoint: dict[str, Any]) -> None:
"""Called when saving the model to a checkpoint.
Saves the transform to the checkpoint.
"""
checkpoint["transform"] = self.transform
def on_load_checkpoint(self, checkpoint: dict[str, Any]) -> None:
"""Called when loading the model from a checkpoint.
Loads the transform from the checkpoint and calls setup to ensure that the torch model is built before loading
the state dict.
"""
self._transform = checkpoint["transform"]
self.setup("load_checkpoint")