"""Base Anomaly Module for Training Task.""" # Copyright (C) 2022-2024 Intel Corporation # SPDX-License-Identifier: Apache-2.0 import importlib import logging from abc import ABC, abstractmethod from collections import OrderedDict from typing import TYPE_CHECKING, Any import lightning.pytorch as pl import torch from lightning.pytorch.trainer.states import TrainerFn from lightning.pytorch.utilities.types import STEP_OUTPUT from torch import nn from torchvision.transforms.v2 import Compose, Normalize, Resize, Transform from anomalib import LearningType from anomalib.metrics import AnomalibMetricCollection from anomalib.metrics.threshold import BaseThreshold from .export_mixin import ExportMixin if TYPE_CHECKING: from lightning.pytorch.callbacks import Callback from torchmetrics import Metric logger = logging.getLogger(__name__) class AnomalyModule(ExportMixin, pl.LightningModule, ABC): """AnomalyModule to train, validate, predict and test images. Acts as a base class for all the Anomaly Modules in the library. """ def __init__(self) -> None: super().__init__() logger.info("Initializing %s model.", self.__class__.__name__) self.save_hyperparameters() self.model: nn.Module self.loss: nn.Module self.callbacks: list[Callback] self.image_threshold: BaseThreshold self.pixel_threshold: BaseThreshold self.normalization_metrics: Metric self.image_metrics: AnomalibMetricCollection self.pixel_metrics: AnomalibMetricCollection self.semantic_pixel_metrics: AnomalibMetricCollection self._transform: Transform | None = None self._input_size: tuple[int, int] | None = None self._is_setup = False # flag to track if setup has been called from the trainer @property def name(self) -> str: """Name of the model.""" return self.__class__.__name__ def setup(self, stage: str | None = None) -> None: """Calls the _setup method to build the model if the model is not already built.""" if getattr(self, "model", None) is None or not self._is_setup: self._setup() if isinstance(stage, TrainerFn): # only set the flag if the stage is a TrainerFn, which means the setup has been called from a trainer self._is_setup = True def _setup(self) -> None: """The _setup method is used to build the torch model dynamically or adjust something about them. The model implementer may override this method to build the model. This is useful when the model cannot be set in the `__init__` method because it requires some information or data that is not available at the time of initialization. """ def forward(self, batch: dict[str, str | torch.Tensor], *args, **kwargs) -> Any: # noqa: ANN401 """Perform the forward-pass by passing input tensor to the module. Args: batch (dict[str, str | torch.Tensor]): Input batch. *args: Arguments. **kwargs: Keyword arguments. Returns: Tensor: Output tensor from the model. """ del args, kwargs # These variables are not used. return self.model(batch) def validation_step(self, batch: dict[str, str | torch.Tensor], *args, **kwargs) -> STEP_OUTPUT: """To be implemented in the subclasses.""" raise NotImplementedError def predict_step( self, batch: dict[str, str | torch.Tensor], batch_idx: int, dataloader_idx: int = 0, ) -> STEP_OUTPUT: """Step function called during :meth:`~lightning.pytorch.trainer.Trainer.predict`. By default, it calls :meth:`~lightning.pytorch.core.lightning.LightningModule.forward`. Override to add any processing logic. Args: batch (Any): Current batch batch_idx (int): Index of current batch dataloader_idx (int): Index of the current dataloader Return: Predicted output """ del dataloader_idx # These variables are not used. return self.validation_step(batch, batch_idx) def test_step(self, batch: dict[str, str | torch.Tensor], batch_idx: int, *args, **kwargs) -> STEP_OUTPUT: """Calls validation_step for anomaly map/score calculation. Args: batch (dict[str, str | torch.Tensor]): Input batch batch_idx (int): Batch index args: Arguments. kwargs: Keyword arguments. Returns: Dictionary containing images, features, true labels and masks. These are required in `validation_epoch_end` for feature concatenation. """ del args, kwargs # These variables are not used. return self.predict_step(batch, batch_idx) @property @abstractmethod def trainer_arguments(self) -> dict[str, Any]: """Arguments used to override the trainer parameters so as to train the model correctly.""" raise NotImplementedError def _save_to_state_dict(self, destination: OrderedDict, prefix: str, keep_vars: bool) -> None: if hasattr(self, "image_threshold"): destination[ "image_threshold_class" ] = f"{self.image_threshold.__class__.__module__}.{self.image_threshold.__class__.__name__}" if hasattr(self, "pixel_threshold"): destination[ "pixel_threshold_class" ] = f"{self.pixel_threshold.__class__.__module__}.{self.pixel_threshold.__class__.__name__}" if hasattr(self, "normalization_metrics"): normalization_class = self.normalization_metrics.__class__ destination["normalization_class"] = f"{normalization_class.__module__}.{normalization_class.__name__}" return super()._save_to_state_dict(destination, prefix, keep_vars) def load_state_dict(self, state_dict: OrderedDict[str, Any], strict: bool = True) -> Any: # noqa: ANN401 """Initialize auxiliary object.""" if "image_threshold_class" in state_dict: self.image_threshold = self._get_instance(state_dict, "image_threshold_class") if "pixel_threshold_class" in state_dict: self.pixel_threshold = self._get_instance(state_dict, "pixel_threshold_class") if "normalization_class" in state_dict: self.normalization_metrics = self._get_instance(state_dict, "normalization_class") # Used to load metrics if there is any related data in state_dict self._load_metrics(state_dict) return super().load_state_dict(state_dict, strict) def _load_metrics(self, state_dict: OrderedDict[str, torch.Tensor]) -> None: """Load metrics from saved checkpoint.""" self._add_metrics("pixel", state_dict) self._add_metrics("image", state_dict) def _add_metrics(self, name: str, state_dict: OrderedDict[str, torch.Tensor]) -> None: """Sets the pixel/image metrics. Args: name (str): is it pixel or image. state_dict (OrderedDict[str, Tensor]): state dict of the model. """ metric_keys = [key for key in state_dict if key.startswith(f"{name}_metrics")] if any(metric_keys): if not hasattr(self, f"{name}_metrics"): setattr(self, f"{name}_metrics", AnomalibMetricCollection([], prefix=f"{name}_")) metrics = getattr(self, f"{name}_metrics") for key in metric_keys: class_name = key.split(".")[1] try: metrics_module = importlib.import_module("anomalib.metrics") metrics_cls = getattr(metrics_module, class_name) except (ImportError, AttributeError) as exception: msg = f"Class {class_name} not found in module anomalib.metrics" raise ImportError(msg) from exception logger.info("Loading %s metrics from state dict", class_name) metrics.add_metrics(metrics_cls()) def _get_instance(self, state_dict: OrderedDict[str, Any], dict_key: str) -> BaseThreshold: """Get the threshold class from the ``state_dict``.""" class_path = state_dict.pop(dict_key) module = importlib.import_module(".".join(class_path.split(".")[:-1])) return getattr(module, class_path.split(".")[-1])() @property @abstractmethod def learning_type(self) -> LearningType: """Learning type of the model.""" raise NotImplementedError @property def transform(self) -> Transform: """Retrieve the model-specific transform. If a transform has been set using `set_transform`, it will be returned. Otherwise, we will use the model-specific default transform, conditioned on the input size. """ return self._transform def set_transform(self, transform: Transform) -> None: """Update the transform linked to the model instance.""" self._transform = transform def configure_transforms(self, image_size: tuple[int, int] | None = None) -> Transform: """Default transforms. The default transform is resize to 256x256 and normalize to ImageNet stats. Individual models can override this method to provide custom transforms. """ logger.warning( "No implementation of `configure_transforms` was provided in the Lightning model. Using default " "transforms from the base class. This may not be suitable for your use case. Please override " "`configure_transforms` in your model.", ) image_size = image_size or (256, 256) return Compose( [ Resize(image_size, antialias=True), Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), ], ) @property def input_size(self) -> tuple[int, int] | None: """Return the effective input size of the model. The effective input size is the size of the input tensor after the transform has been applied. If the transform is not set, or if the transform does not change the shape of the input tensor, this method will return None. """ transform = self.transform or self.configure_transforms() if transform is None: return None dummy_input = torch.zeros(1, 3, 1, 1) output_shape = transform(dummy_input).shape[-2:] if output_shape == (1, 1): return None return output_shape[-2:] def on_save_checkpoint(self, checkpoint: dict[str, Any]) -> None: """Called when saving the model to a checkpoint. Saves the transform to the checkpoint. """ checkpoint["transform"] = self.transform def on_load_checkpoint(self, checkpoint: dict[str, Any]) -> None: """Called when loading the model from a checkpoint. Loads the transform from the checkpoint and calls setup to ensure that the torch model is built before loading the state dict. """ self._transform = checkpoint["transform"] self.setup("load_checkpoint")