"""Base Anomaly Module for Training Task.""" |
|
|
|
|
|
|
|
|
|
import importlib |
|
import logging |
|
from abc import ABC, abstractmethod |
|
from collections import OrderedDict |
|
from typing import TYPE_CHECKING, Any |
|
|
|
import lightning.pytorch as pl |
|
import torch |
|
from lightning.pytorch.trainer.states import TrainerFn |
|
from lightning.pytorch.utilities.types import STEP_OUTPUT |
|
from torch import nn |
|
from torchvision.transforms.v2 import Compose, Normalize, Resize, Transform |
|
|
|
from anomalib import LearningType |
|
from anomalib.metrics import AnomalibMetricCollection |
|
from anomalib.metrics.threshold import BaseThreshold |
|
|
|
from .export_mixin import ExportMixin |
|
|
|
if TYPE_CHECKING:
    from lightning.pytorch.callbacks import Callback
    from torchmetrics import Metric


logger = logging.getLogger(__name__)


class AnomalyModule(ExportMixin, pl.LightningModule, ABC):
    """AnomalyModule to train, validate, predict and test images.

    Acts as a base class for all the Anomaly Modules in the library.
    """
    def __init__(self) -> None:
        super().__init__()
        logger.info("Initializing %s model.", self.__class__.__name__)

        self.save_hyperparameters()
        self.model: nn.Module
        self.loss: nn.Module
        self.callbacks: list[Callback]

        self.image_threshold: BaseThreshold
        self.pixel_threshold: BaseThreshold

        self.normalization_metrics: Metric

        self.image_metrics: AnomalibMetricCollection
        self.pixel_metrics: AnomalibMetricCollection
        self.semantic_pixel_metrics: AnomalibMetricCollection

        self._transform: Transform | None = None
        self._input_size: tuple[int, int] | None = None

        self._is_setup = False

    @property
    def name(self) -> str:
        """Name of the model."""
        return self.__class__.__name__

    def setup(self, stage: str | None = None) -> None:
        """Call the ``_setup`` method to build the model if it is not already built."""
        if getattr(self, "model", None) is None or not self._is_setup:
            self._setup()
            if isinstance(stage, TrainerFn):
                # only set the is_setup flag if the model is setup via the trainer
                self._is_setup = True

    def _setup(self) -> None:
        """Build the torch model dynamically, or adjust it after creation.

        The model implementer may override this method to build the model. This is useful when the model cannot be
        set in the ``__init__`` method because it requires information or data that is not available at the time of
        initialization.
        """
    def forward(self, batch: dict[str, str | torch.Tensor], *args, **kwargs) -> Any:
        """Perform the forward-pass by passing the input batch to the torch model.

        Args:
            batch (dict[str, str | torch.Tensor]): Input batch.
            *args: Arguments.
            **kwargs: Keyword arguments.

        Returns:
            Tensor: Output tensor from the model.
        """
        del args, kwargs  # These variables are not used.
        return self.model(batch)

    def validation_step(self, batch: dict[str, str | torch.Tensor], *args, **kwargs) -> STEP_OUTPUT:
        """To be implemented in the subclasses."""
        raise NotImplementedError

    def predict_step(
        self,
        batch: dict[str, str | torch.Tensor],
        batch_idx: int,
        dataloader_idx: int = 0,
    ) -> STEP_OUTPUT:
        """Step function called during :meth:`~lightning.pytorch.trainer.Trainer.predict`.

        By default, it calls :meth:`validation_step` to obtain the anomaly predictions.
        Override to add any processing logic.

        Args:
            batch (dict[str, str | torch.Tensor]): Current batch.
            batch_idx (int): Index of the current batch.
            dataloader_idx (int): Index of the current dataloader.

        Returns:
            Predicted output.
        """
        del dataloader_idx  # This variable is not used.
        return self.validation_step(batch, batch_idx)

    def test_step(self, batch: dict[str, str | torch.Tensor], batch_idx: int, *args, **kwargs) -> STEP_OUTPUT:
        """Call ``predict_step`` for anomaly map/score calculation.

        Args:
            batch (dict[str, str | torch.Tensor]): Input batch.
            batch_idx (int): Batch index.
            args: Arguments.
            kwargs: Keyword arguments.

        Returns:
            Dictionary containing images, features, true labels and masks.
            These are required in `validation_epoch_end` for feature concatenation.
        """
        del args, kwargs  # These variables are not used.
        return self.predict_step(batch, batch_idx)

    @property
    @abstractmethod
    def trainer_arguments(self) -> dict[str, Any]:
        """Arguments used to override the trainer parameters so as to train the model correctly."""
        raise NotImplementedError

    def _save_to_state_dict(self, destination: OrderedDict, prefix: str, keep_vars: bool) -> None:
        # Store the fully qualified class paths of the dynamically created
        # components, so that ``load_state_dict`` can re-instantiate them.
        if hasattr(self, "image_threshold"):
            destination[
                "image_threshold_class"
            ] = f"{self.image_threshold.__class__.__module__}.{self.image_threshold.__class__.__name__}"
        if hasattr(self, "pixel_threshold"):
            destination[
                "pixel_threshold_class"
            ] = f"{self.pixel_threshold.__class__.__module__}.{self.pixel_threshold.__class__.__name__}"
        if hasattr(self, "normalization_metrics"):
            normalization_class = self.normalization_metrics.__class__
            destination["normalization_class"] = f"{normalization_class.__module__}.{normalization_class.__name__}"

        return super()._save_to_state_dict(destination, prefix, keep_vars)
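
    # For instance, if ``image_threshold`` were an ``F1AdaptiveThreshold``
    # (assuming it lives in ``anomalib.metrics.threshold``), the checkpoint
    # would gain the entry:
    #
    #     state_dict["image_threshold_class"]
    #     # -> "anomalib.metrics.threshold.F1AdaptiveThreshold"
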
    def load_state_dict(self, state_dict: OrderedDict[str, Any], strict: bool = True) -> Any:
        """Initialize the auxiliary objects before loading the state dict."""
        if "image_threshold_class" in state_dict:
            self.image_threshold = self._get_instance(state_dict, "image_threshold_class")
        if "pixel_threshold_class" in state_dict:
            self.pixel_threshold = self._get_instance(state_dict, "pixel_threshold_class")
        if "normalization_class" in state_dict:
            self.normalization_metrics = self._get_instance(state_dict, "normalization_class")

        self._load_metrics(state_dict)

        return super().load_state_dict(state_dict, strict)

    def _load_metrics(self, state_dict: OrderedDict[str, torch.Tensor]) -> None:
        """Load metrics from a saved checkpoint."""
        self._add_metrics("pixel", state_dict)
        self._add_metrics("image", state_dict)

    def _add_metrics(self, name: str, state_dict: OrderedDict[str, torch.Tensor]) -> None:
        """Set the pixel/image metrics.

        Args:
            name (str): Metric group to restore; either ``"pixel"`` or ``"image"``.
            state_dict (OrderedDict[str, torch.Tensor]): State dict of the model.
        """
        metric_keys = [key for key in state_dict if key.startswith(f"{name}_metrics")]
        if any(metric_keys):
            if not hasattr(self, f"{name}_metrics"):
                setattr(self, f"{name}_metrics", AnomalibMetricCollection([], prefix=f"{name}_"))
            metrics = getattr(self, f"{name}_metrics")
            for key in metric_keys:
                # Keys follow the pattern "<name>_metrics.<ClassName>.<state>",
                # so the second component is the metric class name.
                class_name = key.split(".")[1]
                try:
                    metrics_module = importlib.import_module("anomalib.metrics")
                    metrics_cls = getattr(metrics_module, class_name)
                except (ImportError, AttributeError) as exception:
                    msg = f"Class {class_name} not found in module anomalib.metrics"
                    raise ImportError(msg) from exception
                logger.info("Loading %s metrics from state dict", class_name)
                metrics.add_metrics(metrics_cls())

    def _get_instance(self, state_dict: OrderedDict[str, Any], dict_key: str) -> BaseThreshold:
        """Get the threshold class from the ``state_dict``."""
        class_path = state_dict.pop(dict_key)
        module = importlib.import_module(".".join(class_path.split(".")[:-1]))
        return getattr(module, class_path.split(".")[-1])()
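
    # e.g. a stored value of "anomalib.metrics.threshold.F1AdaptiveThreshold"
    # is split into the module path "anomalib.metrics.threshold", which is
    # imported, and the class name "F1AdaptiveThreshold", which is looked up
    # and instantiated with its default arguments.
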
    @property
    @abstractmethod
    def learning_type(self) -> LearningType:
        """Learning type of the model."""
        raise NotImplementedError

    @property
    def transform(self) -> Transform | None:
        """Retrieve the model-specific transform.

        If a transform has been set using `set_transform`, it will be returned. Otherwise ``None`` is returned, and
        the model-specific default transform from `configure_transforms`, conditioned on the input size, is used
        instead.
        """
        return self._transform

    def set_transform(self, transform: Transform) -> None:
        """Update the transform linked to the model instance."""
        self._transform = transform
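
    # Usage sketch (the 224x224 size is an arbitrary example):
    #
    #     model.set_transform(Compose([Resize((224, 224), antialias=True)]))
    #     assert model.input_size == (224, 224)
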
    def configure_transforms(self, image_size: tuple[int, int] | None = None) -> Transform:
        """Default transforms.

        The default transform resizes the image to 256x256 and normalizes it using the ImageNet statistics.
        Individual models can override this method to provide custom transforms.
        """
        logger.warning(
            "No implementation of `configure_transforms` was provided in the Lightning model. Using default "
            "transforms from the base class. This may not be suitable for your use case. Please override "
            "`configure_transforms` in your model.",
        )
        image_size = image_size or (256, 256)
        return Compose(
            [
                Resize(image_size, antialias=True),
                Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
            ],
        )
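
    # An override could, for instance, pin the input size and skip the
    # ImageNet normalization (a hypothetical subclass method):
    #
    #     def configure_transforms(self, image_size=None):
    #         return Compose([Resize((224, 224), antialias=True)])
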
    @property
    def input_size(self) -> tuple[int, int] | None:
        """Return the effective input size of the model.

        The effective input size is the size of the input tensor after the transform has been applied. If the
        transform is not set, or if the transform does not change the shape of the input tensor, this property
        returns None.
        """
        transform = self.transform or self.configure_transforms()
        if transform is None:
            return None
        # Probe with a (1, 1) dummy image: geometric transforms change its
        # shape to the effective input size, photometric ones leave it alone.
        dummy_input = torch.zeros(1, 3, 1, 1)
        output_shape = transform(dummy_input).shape[-2:]
        return None if output_shape == (1, 1) else output_shape
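
    # e.g. with the default transform above, the (1, 1) probe comes out as
    # (256, 256), so ``input_size`` is (256, 256); with a purely photometric
    # pipeline such as ``Compose([Normalize(...)])`` it stays (1, 1), so
    # ``input_size`` is None.
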
    def on_save_checkpoint(self, checkpoint: dict[str, Any]) -> None:
        """Called when saving the model to a checkpoint.

        Saves the transform to the checkpoint.
        """
        checkpoint["transform"] = self.transform

    def on_load_checkpoint(self, checkpoint: dict[str, Any]) -> None:
        """Called when loading the model from a checkpoint.

        Loads the transform from the checkpoint and calls setup to ensure that the torch model is built before
        loading the state dict.
        """
        self._transform = checkpoint["transform"]
        self.setup("load_checkpoint")