diff --git a/anomalib/__init__.py b/anomalib/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..711eb023e9eec4c99444cad842309e019b6d6482 --- /dev/null +++ b/anomalib/__init__.py @@ -0,0 +1,24 @@ +"""Anomalib library for research and benchmarking.""" + +# Copyright (C) 2022-2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +from enum import Enum + +__version__ = "1.1.0dev" + + +class LearningType(str, Enum): + """Learning type defining how the model learns from the dataset samples.""" + + ONE_CLASS = "one_class" + ZERO_SHOT = "zero_shot" + FEW_SHOT = "few_shot" + + +class TaskType(str, Enum): + """Task type used when generating predictions on the dataset.""" + + CLASSIFICATION = "classification" + DETECTION = "detection" + SEGMENTATION = "segmentation" diff --git a/anomalib/callbacks/__init__.py b/anomalib/callbacks/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..12ec54d8f355ec22e5a8185a6d162d39ad7a02c2 --- /dev/null +++ b/anomalib/callbacks/__init__.py @@ -0,0 +1,64 @@ +"""Callbacks for Anomalib models.""" + +# Copyright (C) 2022 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + + +import logging +from importlib import import_module +from pathlib import Path + +import yaml +from jsonargparse import Namespace +from lightning.pytorch.callbacks import Callback +from omegaconf import DictConfig, ListConfig, OmegaConf + +from .checkpoint import ModelCheckpoint +from .graph import GraphLogger +from .model_loader import LoadModelCallback +from .tiler_configuration import TilerConfigurationCallback +from .timer import TimerCallback + +__all__ = [ + "ModelCheckpoint", + "GraphLogger", + "LoadModelCallback", + "TilerConfigurationCallback", + "TimerCallback", +] + + +logger = logging.getLogger(__name__) + + +def get_callbacks(config: DictConfig | ListConfig | Namespace) -> list[Callback]: + """Return base callbacks for all the lightning models. + + Args: + config (DictConfig | ListConfig | Namespace): Model config + + Return: + (list[Callback]): List of callbacks. + """ + logger.info("Loading the callbacks") + + callbacks: list[Callback] = [] + + if "ckpt_path" in config.trainer and config.ckpt_path is not None: + load_model = LoadModelCallback(config.ckpt_path) + callbacks.append(load_model) + + if "optimization" in config and "nncf" in config.optimization and config.optimization.nncf.apply: + # NNCF wraps torch's jit which conflicts with kornia's jit calls. + # Hence, nncf is imported only when required + nncf_module = import_module("anomalib.utils.callbacks.nncf.callback") + nncf_callback = nncf_module.NNCFCallback + nncf_config = yaml.safe_load(OmegaConf.to_yaml(config.optimization.nncf)) + callbacks.append( + nncf_callback( + config=nncf_config, + export_dir=str(Path(config.project.path) / "compressed"), + ), + ) + + return callbacks diff --git a/anomalib/callbacks/checkpoint.py b/anomalib/callbacks/checkpoint.py new file mode 100644 index 0000000000000000000000000000000000000000..d4af9dfa8ecb410345126ddf601c6a7a3d3369b8 --- /dev/null +++ b/anomalib/callbacks/checkpoint.py @@ -0,0 +1,58 @@ +"""Anomalib Model Checkpoint Callback.""" + +# Copyright (C) 2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +from lightning.pytorch import Trainer +from lightning.pytorch.callbacks import ModelCheckpoint as LightningCheckpoint +from lightning.pytorch.trainer.states import TrainerFn + +from anomalib import LearningType + + +class ModelCheckpoint(LightningCheckpoint): + """Anomalib Model Checkpoint Callback. + + This class overrides the Lightning ModelCheckpoint callback to enable saving checkpoints without running any + training steps. This is useful for zero-/few-shot models, where the fit sequence only consists of validation. + + To enable saving checkpoints without running any training steps, we need to override two checks which are being + called in the ``on_validation_end`` method of the parent class: + - ``_should_save_on_train_epoch_end``: This method checks whether the checkpoint should be saved at the end of a + training epoch, or at the end of the validation sequence. We modify this method to default to saving at the end + of the validation sequence when the model is of zero- or few-shot type, unless ``save_on_train_epoch_end`` is + specifically set by the user. + - ``_should_skip_saving_checkpoint``: This method checks whether the checkpoint should be saved at all. We modify + this method to allow saving during both the ``FITTING`` and ``VALIDATING`` states. In addition, we allow saving + if the global step has not changed since the last checkpoint, but only for zero- and few-shot models. This is + needed because both the last global step and the last checkpoint remain unchanged during zero-/few-shot + training, which would otherwise prevent saving checkpoints during validation. + """ + + def _should_skip_saving_checkpoint(self, trainer: Trainer) -> bool: + """Checks whether the checkpoint should be saved. + + Overrides the parent method to allow saving during both the ``FITTING`` and ``VALIDATING`` states, and to allow + saving when the global step and last_global_step_saved are both 0 (only for zero-/few-shot models). + """ + is_zero_or_few_shot = trainer.model.learning_type in [LearningType.ZERO_SHOT, LearningType.FEW_SHOT] + return ( + bool(trainer.fast_dev_run) # disable checkpointing with fast_dev_run + or trainer.state.fn not in [TrainerFn.FITTING, TrainerFn.VALIDATING] # don't save anything during non-fit + or trainer.sanity_checking # don't save anything during sanity check + or (self._last_global_step_saved == trainer.global_step and not is_zero_or_few_shot) + ) + + def _should_save_on_train_epoch_end(self, trainer: Trainer) -> bool: + """Checks whether the checkpoint should be saved at the end of a training epoch or validation sequence. + + Overrides the parent method to default to saving at the end of the validation sequence when the model is of + zero- or few-shot type, unless ``save_on_train_epoch_end`` is specifically set by the user. + """ + if self._save_on_train_epoch_end is not None: + return self._save_on_train_epoch_end + + if trainer.model.learning_type in [LearningType.ZERO_SHOT, LearningType.FEW_SHOT]: + return False + + return super()._should_save_on_train_epoch_end(trainer) diff --git a/anomalib/callbacks/graph.py b/anomalib/callbacks/graph.py new file mode 100644 index 0000000000000000000000000000000000000000..e2f27e3a997df51c78a01dfa59bdbe6c0c1fc7e4 --- /dev/null +++ b/anomalib/callbacks/graph.py @@ -0,0 +1,61 @@ +"""Log model graph to respective logger.""" + +# Copyright (C) 2022 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +import torch +from lightning.pytorch import Callback, LightningModule, Trainer + +from anomalib.loggers import AnomalibCometLogger, AnomalibTensorBoardLogger, AnomalibWandbLogger + + +class GraphLogger(Callback): + """Log model graph to respective logger. + + Examples: + Log model graph to Tensorboard + + >>> from anomalib.callbacks import GraphLogger + >>> from anomalib.loggers import AnomalibTensorBoardLogger + >>> from anomalib.engine import Engine + ... + >>> logger = AnomalibTensorBoardLogger() + >>> callbacks = [GraphLogger()] + >>> engine = Engine(logger=logger, callbacks=callbacks) + + Log model graph to Comet + + >>> from anomalib.loggers import AnomalibCometLogger + >>> from anomalib.engine import Engine + ... + >>> logger = AnomalibCometLogger() + >>> callbacks = [GraphLogger()] + >>> engine = Engine(logger=logger, callbacks=callbacks) + """ + + def on_train_start(self, trainer: Trainer, pl_module: LightningModule) -> None: + """Log model graph to respective logger. + + Args: + trainer: Trainer object which contans reference to loggers. + pl_module: LightningModule object which is logged. + """ + for logger in trainer.loggers: + if isinstance(logger, AnomalibWandbLogger): + # NOTE: log graph gets populated only after one backward pass. This won't work for models which do not + # require training such as Padim + logger.watch(pl_module, log_graph=True, log="all") + break + + def on_train_end(self, trainer: Trainer, pl_module: LightningModule) -> None: + """Unwatch model if configured for wandb and log it model graph in Tensorboard if specified. + + Args: + trainer: Trainer object which contans reference to loggers. + pl_module: LightningModule object which is logged. + """ + for logger in trainer.loggers: + if isinstance(logger, AnomalibCometLogger | AnomalibTensorBoardLogger): + logger.log_graph(pl_module, input_array=torch.ones((1, 3, 256, 256))) + elif isinstance(logger, AnomalibWandbLogger): + logger.experiment.unwatch(pl_module) diff --git a/anomalib/callbacks/metrics.py b/anomalib/callbacks/metrics.py new file mode 100644 index 0000000000000000000000000000000000000000..8c32e6ec40949a6f5963dc8d0721bc1e4dc86708 --- /dev/null +++ b/anomalib/callbacks/metrics.py @@ -0,0 +1,226 @@ +"""MetricsManager callback.""" + +# Copyright (C) 2023 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + + +import logging +from enum import Enum +from typing import Any + +import torch +from lightning.pytorch import Callback, Trainer +from lightning.pytorch.utilities.types import STEP_OUTPUT + +from anomalib import TaskType +from anomalib.metrics import create_metric_collection +from anomalib.models import AnomalyModule + +logger = logging.getLogger(__name__) + + +class Device(str, Enum): + """Device on which to compute metrics.""" + + CPU = "cpu" + GPU = "gpu" + + +class _MetricsCallback(Callback): + """Create image and pixel-level AnomalibMetricsCollection. + + This callback creates AnomalibMetricsCollection based on the + list of strings provided for image and pixel-level metrics. + After these MetricCollections are created, the callback assigns + these to the lightning module. + + Args: + task (TaskType | str): Task type of the current run. + image_metrics (list[str] | str | dict[str, dict[str, Any]] | None): List of image-level metrics. + pixel_metrics (list[str] | str | dict[str, dict[str, Any]] | None): List of pixel-level metrics. + device (str): Whether to compute metrics on cpu or gpu. Defaults to cpu. + """ + + def __init__( + self, + task: TaskType | str = TaskType.SEGMENTATION, + image_metrics: list[str] | str | dict[str, dict[str, Any]] | None = None, + pixel_metrics: list[str] | str | dict[str, dict[str, Any]] | None = None, + device: Device = Device.CPU, + ) -> None: + super().__init__() + self.task = TaskType(task) + self.image_metric_names = image_metrics + self.pixel_metric_names = pixel_metrics + self.device = device + + def setup( + self, + trainer: Trainer, + pl_module: AnomalyModule, + stage: str | None = None, + ) -> None: + """Set image and pixel-level AnomalibMetricsCollection within Anomalib Model. + + Args: + trainer (pl.Trainer): PyTorch Lightning Trainer + pl_module (AnomalyModule): Anomalib Model that inherits pl LightningModule. + stage (str | None, optional): fit, validate, test or predict. Defaults to None. + """ + del stage, trainer # this variable is not used. + image_metric_names = [] if self.image_metric_names is None else self.image_metric_names + if isinstance(image_metric_names, str): + image_metric_names = [image_metric_names] + + pixel_metric_names: list[str] | dict[str, dict[str, Any]] + if self.pixel_metric_names is None: + pixel_metric_names = [] + elif self.task == TaskType.CLASSIFICATION: + pixel_metric_names = [] + logger.warning( + "Cannot perform pixel-level evaluation when task type is classification. " + "Ignoring the following pixel-level metrics: %s", + self.pixel_metric_names, + ) + else: + pixel_metric_names = ( + self.pixel_metric_names.copy() + if not isinstance(self.pixel_metric_names, str) + else [self.pixel_metric_names] + ) + + # create a separate metric collection for metrics that operate over the semantic segmentation mask + # (segmentation mask with a separate channel for each defect type) + semantic_pixel_metric_names: list[str] | dict[str, dict[str, Any]] = [] + # currently only SPRO metric is supported as semantic segmentation metric + if "SPRO" in pixel_metric_names: + if isinstance(pixel_metric_names, list): + pixel_metric_names.remove("SPRO") + semantic_pixel_metric_names = ["SPRO"] + elif isinstance(pixel_metric_names, dict): + spro_metric = pixel_metric_names.pop("SPRO") + semantic_pixel_metric_names = {"SPRO": spro_metric} + else: + logger.warning("Unexpected type for pixel_metric_names: %s", type(pixel_metric_names)) + + if isinstance(pl_module, AnomalyModule): + pl_module.image_metrics = create_metric_collection(image_metric_names, "image_") + if hasattr(pl_module, "pixel_metrics"): # incase metrics are loaded from model checkpoint + new_metrics = create_metric_collection(pixel_metric_names) + for name in new_metrics: + if name not in pl_module.pixel_metrics: + pl_module.pixel_metrics.add_metrics(new_metrics[name]) + else: + pl_module.pixel_metrics = create_metric_collection(pixel_metric_names, "pixel_") + pl_module.semantic_pixel_metrics = create_metric_collection(semantic_pixel_metric_names, "pixel_") + self._set_threshold(pl_module) + + def on_validation_epoch_start( + self, + trainer: Trainer, + pl_module: AnomalyModule, + ) -> None: + del trainer # Unused argument. + + pl_module.image_metrics.reset() + pl_module.pixel_metrics.reset() + pl_module.semantic_pixel_metrics.reset() + + def on_validation_batch_end( + self, + trainer: Trainer, + pl_module: AnomalyModule, + outputs: STEP_OUTPUT | None, + batch: Any, # noqa: ANN401 + batch_idx: int, + dataloader_idx: int = 0, + ) -> None: + del trainer, batch, batch_idx, dataloader_idx # Unused arguments. + + if outputs is not None: + self._outputs_to_device(outputs) + self._update_metrics(pl_module, outputs) + + def on_validation_epoch_end( + self, + trainer: Trainer, + pl_module: AnomalyModule, + ) -> None: + del trainer # Unused argument. + + self._set_threshold(pl_module) + self._log_metrics(pl_module) + + def on_test_epoch_start( + self, + trainer: Trainer, + pl_module: AnomalyModule, + ) -> None: + del trainer # Unused argument. + + pl_module.image_metrics.reset() + pl_module.pixel_metrics.reset() + pl_module.semantic_pixel_metrics.reset() + + def on_test_batch_end( + self, + trainer: Trainer, + pl_module: AnomalyModule, + outputs: STEP_OUTPUT | None, + batch: Any, # noqa: ANN401 + batch_idx: int, + dataloader_idx: int = 0, + ) -> None: + del trainer, batch, batch_idx, dataloader_idx # Unused arguments. + + if outputs is not None: + self._outputs_to_device(outputs) + self._update_metrics(pl_module, outputs) + + def on_test_epoch_end( + self, + trainer: Trainer, + pl_module: AnomalyModule, + ) -> None: + del trainer # Unused argument. + + self._log_metrics(pl_module) + + def _set_threshold(self, pl_module: AnomalyModule) -> None: + pl_module.image_metrics.set_threshold(pl_module.image_threshold.value.item()) + pl_module.pixel_metrics.set_threshold(pl_module.pixel_threshold.value.item()) + pl_module.semantic_pixel_metrics.set_threshold(pl_module.pixel_threshold.value.item()) + + def _update_metrics( + self, + pl_module: AnomalyModule, + output: STEP_OUTPUT, + ) -> None: + pl_module.image_metrics.to(self.device) + pl_module.image_metrics.update(output["pred_scores"], output["label"].int()) + if "mask" in output and "anomaly_maps" in output: + pl_module.pixel_metrics.to(self.device) + pl_module.pixel_metrics.update(torch.squeeze(output["anomaly_maps"]), torch.squeeze(output["mask"].int())) + if "semantic_mask" in output and "anomaly_maps" in output: + pl_module.semantic_pixel_metrics.to(self.device) + pl_module.semantic_pixel_metrics.update(torch.squeeze(output["anomaly_maps"]), output["semantic_mask"]) + + def _outputs_to_device(self, output: STEP_OUTPUT) -> STEP_OUTPUT | dict[str, Any]: + if isinstance(output, dict): + for key, value in output.items(): + output[key] = self._outputs_to_device(value) + elif isinstance(output, torch.Tensor): + output = output.to(self.device) + elif isinstance(output, list): + for i, value in enumerate(output): + output[i] = self._outputs_to_device(value) + return output + + @staticmethod + def _log_metrics(pl_module: AnomalyModule) -> None: + """Log computed performance metrics.""" + pl_module.log_dict(pl_module.image_metrics, prog_bar=True) + if pl_module.pixel_metrics.update_called: + pl_module.log_dict(pl_module.pixel_metrics, prog_bar=False) + if pl_module.semantic_pixel_metrics.update_called: + pl_module.log_dict(pl_module.semantic_pixel_metrics, prog_bar=False) diff --git a/anomalib/callbacks/model_loader.py b/anomalib/callbacks/model_loader.py new file mode 100644 index 0000000000000000000000000000000000000000..8f89958d77a1323317d10831617d2caedda2945e --- /dev/null +++ b/anomalib/callbacks/model_loader.py @@ -0,0 +1,39 @@ +"""Callback that loads model weights from the state dict.""" + +# Copyright (C) 2022 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + + +import logging + +import torch +from lightning.pytorch import Callback, Trainer + +from anomalib.models.components import AnomalyModule + +logger = logging.getLogger(__name__) + + +class LoadModelCallback(Callback): + """Callback that loads the model weights from the state dict. + + Examples: + >>> from anomalib.callbacks import LoadModelCallback + >>> from anomalib.engine import Engine + ... + >>> callbacks = [LoadModelCallback(weights_path="path/to/weights.pt")] + >>> engine = Engine(callbacks=callbacks) + """ + + def __init__(self, weights_path: str) -> None: + self.weights_path = weights_path + + def setup(self, trainer: Trainer, pl_module: AnomalyModule, stage: str | None = None) -> None: + """Call when inference begins. + + Loads the model weights from ``weights_path`` into the PyTorch module. + """ + del trainer, stage # These variables are not used. + + logger.info("Loading the model from %s", self.weights_path) + pl_module.load_state_dict(torch.load(self.weights_path, map_location=pl_module.device)["state_dict"]) diff --git a/anomalib/callbacks/nncf/__init__.py b/anomalib/callbacks/nncf/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..074a1bd8612a096baf8612d39a16a1ada68bfd70 --- /dev/null +++ b/anomalib/callbacks/nncf/__init__.py @@ -0,0 +1,4 @@ +"""Integration NNCF.""" + +# Copyright (C) 2022 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 diff --git a/anomalib/callbacks/nncf/callback.py b/anomalib/callbacks/nncf/callback.py new file mode 100644 index 0000000000000000000000000000000000000000..4157ff2f5e27dd878dc202f3ccec23689bdf3aa5 --- /dev/null +++ b/anomalib/callbacks/nncf/callback.py @@ -0,0 +1,106 @@ +"""Callbacks for NNCF optimization.""" + +# Copyright (C) 2022 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + + +import subprocess # nosec B404 +from pathlib import Path +from typing import TYPE_CHECKING, Any + +import lightning.pytorch as pl +from lightning.pytorch import Callback +from nncf import NNCFConfig +from nncf.torch import register_default_init_args + +from anomalib.callbacks.nncf.utils import InitLoader, wrap_nncf_model + +if TYPE_CHECKING: + from nncf.api.compression import CompressionAlgorithmController + + +class NNCFCallback(Callback): + """Callback for NNCF compression. + + Assumes that the pl module contains a 'model' attribute, which is + the PyTorch module that must be compressed. + + Args: + config (dict): NNCF Configuration + export_dir (Str): Path where the export `onnx` and the OpenVINO `xml` and `bin` IR are saved. + If None model will not be exported. + """ + + def __init__(self, config: dict, export_dir: str | None = None) -> None: + self.export_dir = export_dir + self.config = NNCFConfig(config) + self.nncf_ctrl: CompressionAlgorithmController | None = None + + def setup(self, trainer: pl.Trainer, pl_module: pl.LightningModule, stage: str | None = None) -> None: + """Call when fit or test begins. + + Takes the pytorch model and wraps it using the compression controller + so that it is ready for nncf fine-tuning. + """ + del stage # `stage` variable is not used. + + if self.nncf_ctrl is not None: + return + + # Get validate subset to initialize quantization, + # because train subset does not contain anomalous images. + init_loader = InitLoader(trainer.datamodule.val_dataloader()) + config = register_default_init_args(self.config, init_loader) + + self.nncf_ctrl, pl_module.model = wrap_nncf_model( + model=pl_module.model, + config=config, + dataloader=trainer.datamodule.train_dataloader(), + init_state_dict=None, # type: ignore[arg-type] + ) + + def on_train_batch_start( + self, + trainer: pl.Trainer, + pl_module: pl.LightningModule, + batch: Any, # noqa: ANN401 + batch_idx: int, + unused: int = 0, + ) -> None: + """Call when the train batch begins. + + Prepare compression method to continue training the model in the next step. + """ + del trainer, pl_module, batch, batch_idx, unused # These variables are not used. + + if self.nncf_ctrl: + self.nncf_ctrl.scheduler.step() + + def on_train_epoch_start(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None: + """Call when the train epoch starts. + + Prepare compression method to continue training the model in the next epoch. + """ + del trainer, pl_module # `trainer` and `pl_module` variables are not used. + + if self.nncf_ctrl: + self.nncf_ctrl.scheduler.epoch_step() + + def on_train_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None: + """Call when the train ends. + + Exports onnx model and if compression controller is not None, uses the onnx model to generate the OpenVINO IR. + """ + del trainer, pl_module # `trainer` and `pl_module` variables are not used. + + if self.export_dir is None or self.nncf_ctrl is None: + return + + Path(self.export_dir).mkdir(parents=True, exist_ok=True) + onnx_path = str(Path(self.export_dir) / "model_nncf.onnx") + self.nncf_ctrl.export_model(onnx_path) + + optimize_command = ["mo", "--input_model", onnx_path, "--output_dir", self.export_dir] + # TODO(samet-akcay): Check if mo can be done via python API + # CVS-122665 + subprocess.run(optimize_command, check=True) # noqa: S603 # nosec B603 diff --git a/anomalib/callbacks/nncf/utils.py b/anomalib/callbacks/nncf/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..5a41c59ce1142e6f348c1847998d4ef3f70d0c71 --- /dev/null +++ b/anomalib/callbacks/nncf/utils.py @@ -0,0 +1,243 @@ +"""Utils for NNCf optimization.""" + +# Copyright (C) 2022 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + + +import logging +from copy import copy +from typing import TYPE_CHECKING, Any + +import torch +from nncf import NNCFConfig +from nncf.api.compression import CompressionAlgorithmController +from nncf.torch import create_compressed_model, load_state, register_default_init_args +from nncf.torch.initialization import PTInitializingDataLoader +from nncf.torch.nncf_network import NNCFNetwork +from torch import nn +from torch.utils.data.dataloader import DataLoader + +if TYPE_CHECKING: + from collections.abc import Iterator + + +logger = logging.getLogger(name="NNCF compression") + + +class InitLoader(PTInitializingDataLoader): + """Initializing data loader for NNCF to be used with unsupervised training algorithms.""" + + def __init__(self, data_loader: DataLoader) -> None: + super().__init__(data_loader) + self._data_loader_iter: Iterator + + def __iter__(self) -> "InitLoader": + """Create iterator for dataloader.""" + self._data_loader_iter = iter(self._data_loader) + return self + + def __next__(self) -> torch.Tensor: + """Return next item from dataloader iterator.""" + loaded_item = next(self._data_loader_iter) + return loaded_item["image"] + + def get_inputs(self, dataloader_output: dict[str, str | torch.Tensor]) -> tuple[tuple, dict]: + """Get input to model. + + Returns: + (dataloader_output,), {}: tuple[tuple, dict]: The current model call to be made during + the initialization process + """ + return (dataloader_output,), {} + + def get_target(self, _): # noqa: ANN001, ANN201 + """Return structure for ground truth in loss criterion based on dataloader output. + + This implementation does not do anything and is a placeholder. + + Returns: + None + """ + return + + +def wrap_nncf_model( + model: nn.Module, + config: dict, + dataloader: DataLoader, + init_state_dict: dict, +) -> tuple[CompressionAlgorithmController, NNCFNetwork]: + """Wrap model by NNCF. + + :param model: Anomalib model. + :param config: NNCF config. + :param dataloader: Dataloader for initialization of NNCF model. + :param init_state_dict: Opti + :return: compression controller, compressed model + """ + nncf_config = NNCFConfig.from_dict(config) + + if not dataloader and not init_state_dict: + logger.warning( + "Either dataloader or NNCF pre-trained " + "model checkpoint should be set. Without this, " + "quantizers will not be initialized", + ) + + compression_state = None + resuming_state_dict = None + if init_state_dict: + resuming_state_dict = init_state_dict.get("model") + compression_state = init_state_dict.get("compression_state") + + if dataloader: + init_loader = InitLoader(dataloader) + nncf_config = register_default_init_args(nncf_config, init_loader) + + nncf_ctrl, nncf_model = create_compressed_model( + model=model, + config=nncf_config, + dump_graphs=False, + compression_state=compression_state, + ) + + if resuming_state_dict: + load_state(nncf_model, resuming_state_dict, is_resume=True) + + return nncf_ctrl, nncf_model + + +def is_state_nncf(state: dict) -> bool: + """Check if state is the result of NNCF-compressed model.""" + return bool(state.get("meta", {}).get("nncf_enable_compression", False)) + + +def compose_nncf_config(nncf_config: dict, enabled_options: list[str]) -> dict: + """Compose NNCf config by selected options. + + :param nncf_config: + :param enabled_options: + :return: config + """ + optimisation_parts = nncf_config + optimisation_parts_to_choose = [] + if "order_of_parts" in optimisation_parts: + # The result of applying the changes from optimisation parts + # may depend on the order of applying the changes + # (e.g. if for nncf_quantization it is sufficient to have `total_epochs=2`, + # but for sparsity it is required `total_epochs=50`) + # So, user can define `order_of_parts` in the optimisation_config + # to specify the order of applying the parts. + order_of_parts = optimisation_parts["order_of_parts"] + if not isinstance(order_of_parts, list): + msg = 'The field "order_of_parts" in optimization config should be a list' + raise TypeError(msg) + + for part in enabled_options: + if part not in order_of_parts: + msg = f"The part {part} is selected, but it is absent in order_of_parts={order_of_parts}" + raise ValueError(msg) + + optimisation_parts_to_choose = [part for part in order_of_parts if part in enabled_options] + + if "base" not in optimisation_parts: + msg = 'Error: the optimisation config does not contain the "base" part' + raise KeyError(msg) + nncf_config_part = optimisation_parts["base"] + + for part in optimisation_parts_to_choose: + if part not in optimisation_parts: + msg = f'Error: the optimisation config does not contain the part "{part}"' + raise KeyError(msg) + optimisation_part_dict = optimisation_parts[part] + try: + nncf_config_part = merge_dicts_and_lists_b_into_a(nncf_config_part, optimisation_part_dict) + except AssertionError as cur_error: + err_descr = ( + f"Error during merging the parts of nncf configs:\n" + f"the current part={part}, " + f"the order of merging parts into base is {optimisation_parts_to_choose}.\n" + f"The error is:\n{cur_error}" + ) + raise RuntimeError(err_descr) from None + + return nncf_config_part + + +def merge_dicts_and_lists_b_into_a( + a: dict[Any, Any] | list[Any], + b: dict[Any, Any] | list[Any], +) -> dict[Any, Any] | list[Any]: + """Merge dict configs. + + Args: + a (dict[Any, Any] | list[Any]): First dict or list. + b (dict[Any, Any] | list[Any]): Second dict or list. + + Returns: + dict[Any, Any] | list[Any]: Merged dict or list. + """ + return _merge_dicts_and_lists_b_into_a(a, b, "") + + +def _merge_dicts_and_lists_b_into_a( + a: dict[Any, Any] | list[Any], + b: dict[Any, Any] | list[Any], + cur_key: int | str | None = None, +) -> dict[Any, Any] | list[Any]: + """Merge dict configs. + + * works with usual dicts and lists and derived types + * supports merging of lists (by concatenating the lists) + * makes recursive merging for dict + dict case + * overwrites when merging scalar into scalar + Note that we merge b into a (whereas Config makes merge a into b), + since otherwise the order of list merging is counter-intuitive. + + Args: + a (dict[Any, Any] | list[Any]): First dict or list. + b (dict[Any, Any] | list[Any]): Second dict or list. + cur_key (int | str | None, optional): key for current level of recursion. Defaults to None. + + Returns: + dict[Any, Any] | list[Any]: Merged dict or list. + """ + + def _err_str(_a: dict | list, _b: dict | list, _key: int | str | None = None) -> str: + _key_str = "of whole structures" if _key is None else f"during merging for key=`{_key}`" + return ( + f"Error in merging parts of config: different types {_key_str}," + f" type(a) = {type(_a)}," + f" type(b) = {type(_b)}" + ) + + if not (isinstance(a, dict | list)): + msg = f"Can merge only dicts and lists, whereas type(a)={type(a)}" + raise TypeError(msg) + + if not (isinstance(b, dict | list)): + raise TypeError(_err_str(a, b, cur_key)) + + if (isinstance(a, list) and not isinstance(b, list)) or (isinstance(b, list) and not isinstance(a, list)): + raise TypeError(_err_str(a, b, cur_key)) + + if isinstance(a, list) and isinstance(b, list): + # the main diff w.r.t. mmcf.Config -- merging of lists + return a + b + + a = copy(a) + for k in b: + if k not in a: + a[k] = copy(b[k]) + continue + new_cur_key = str(cur_key) + "." + k if cur_key else k + if isinstance(a[k], dict | list): + a[k] = _merge_dicts_and_lists_b_into_a(a[k], b[k], new_cur_key) + continue + + if any(isinstance(b[k], t) for t in [dict, list]): + raise TypeError(_err_str(a[k], b[k], new_cur_key)) + + # suppose here that a[k] and b[k] are scalars, just overwrite + a[k] = b[k] + return a diff --git a/anomalib/callbacks/normalization/__init__.py b/anomalib/callbacks/normalization/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..a502b1aa5efce57de24724a434bf06f0e4c41390 --- /dev/null +++ b/anomalib/callbacks/normalization/__init__.py @@ -0,0 +1,12 @@ +"""Normalization callbacks. + +Note: These callbacks are used within the Engine. +""" + +# Copyright (C) 2023-2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +from .min_max_normalization import _MinMaxNormalizationCallback +from .utils import get_normalization_callback + +__all__ = ["get_normalization_callback", "_MinMaxNormalizationCallback"] diff --git a/anomalib/callbacks/normalization/base.py b/anomalib/callbacks/normalization/base.py new file mode 100644 index 0000000000000000000000000000000000000000..08129058894a87319e1c8565ad2340347a908a55 --- /dev/null +++ b/anomalib/callbacks/normalization/base.py @@ -0,0 +1,29 @@ +"""Base Normalization Callback.""" + +# Copyright (C) 2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +from abc import ABC, abstractmethod + +from lightning.pytorch import Callback +from lightning.pytorch.utilities.types import STEP_OUTPUT + +from anomalib.models.components import AnomalyModule + + +class NormalizationCallback(Callback, ABC): + """Base normalization callback.""" + + @staticmethod + @abstractmethod + def _normalize_batch(batch: STEP_OUTPUT, pl_module: AnomalyModule) -> None: + """Normalize an output batch. + + Args: + batch (dict[str, torch.Tensor]): Output batch. + pl_module (AnomalyModule): AnomalyModule instance. + + Returns: + dict[str, torch.Tensor]: Normalized batch. + """ + raise NotImplementedError diff --git a/anomalib/callbacks/normalization/min_max_normalization.py b/anomalib/callbacks/normalization/min_max_normalization.py new file mode 100644 index 0000000000000000000000000000000000000000..f22a36afd07f128161869553e1b9c910f2ae49f6 --- /dev/null +++ b/anomalib/callbacks/normalization/min_max_normalization.py @@ -0,0 +1,109 @@ +"""Anomaly Score Normalization Callback that uses min-max normalization.""" + +# Copyright (C) 2022-2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + + +from typing import Any + +import torch +from lightning.pytorch import Trainer +from lightning.pytorch.utilities.types import STEP_OUTPUT + +from anomalib.metrics import MinMax +from anomalib.models.components import AnomalyModule +from anomalib.utils.normalization.min_max import normalize + +from .base import NormalizationCallback + + +class _MinMaxNormalizationCallback(NormalizationCallback): + """Callback that normalizes the image-level and pixel-level anomaly scores using min-max normalization. + + Note: This callback is set within the Engine. + """ + + def setup(self, trainer: Trainer, pl_module: AnomalyModule, stage: str | None = None) -> None: + """Add min_max metrics to normalization metrics.""" + del trainer, stage # These variables are not used. + + if not hasattr(pl_module, "normalization_metrics"): + pl_module.normalization_metrics = MinMax().cpu() + elif not isinstance(pl_module.normalization_metrics, MinMax): + msg = f"Expected normalization_metrics to be of type MinMax, got {type(pl_module.normalization_metrics)}" + raise AttributeError( + msg, + ) + + def on_test_start(self, trainer: Trainer, pl_module: AnomalyModule) -> None: + """Call when the test begins.""" + del trainer # `trainer` variable is not used. + + for metric in (pl_module.image_metrics, pl_module.pixel_metrics, pl_module.semantic_pixel_metrics): + if metric is not None: + metric.set_threshold(0.5) + + def on_validation_batch_end( + self, + trainer: Trainer, + pl_module: AnomalyModule, + outputs: STEP_OUTPUT, + batch: Any, # noqa: ANN401 + batch_idx: int, + dataloader_idx: int = 0, + ) -> None: + """Call when the validation batch ends, update the min and max observed values.""" + del trainer, batch, batch_idx, dataloader_idx # These variables are not used. + + if "anomaly_maps" in outputs: + pl_module.normalization_metrics(outputs["anomaly_maps"]) + elif "box_scores" in outputs: + pl_module.normalization_metrics(torch.cat(outputs["box_scores"])) + elif "pred_scores" in outputs: + pl_module.normalization_metrics(outputs["pred_scores"]) + else: + msg = "No values found for normalization, provide anomaly maps, bbox scores, or image scores" + raise ValueError(msg) + + def on_test_batch_end( + self, + trainer: Trainer, + pl_module: AnomalyModule, + outputs: STEP_OUTPUT | None, + batch: Any, # noqa: ANN401 + batch_idx: int, + dataloader_idx: int = 0, + ) -> None: + """Call when the test batch ends, normalizes the predicted scores and anomaly maps.""" + del trainer, batch, batch_idx, dataloader_idx # These variables are not used. + + self._normalize_batch(outputs, pl_module) + + def on_predict_batch_end( + self, + trainer: Trainer, + pl_module: AnomalyModule, + outputs: Any, # noqa: ANN401 + batch: Any, # noqa: ANN401 + batch_idx: int, + dataloader_idx: int = 0, + ) -> None: + """Call when the predict batch ends, normalizes the predicted scores and anomaly maps.""" + del trainer, batch, batch_idx, dataloader_idx # These variables are not used. + + self._normalize_batch(outputs, pl_module) + + @staticmethod + def _normalize_batch(outputs: Any, pl_module: AnomalyModule) -> None: # noqa: ANN401 + """Normalize a batch of predictions.""" + image_threshold = pl_module.image_threshold.value.cpu() + pixel_threshold = pl_module.pixel_threshold.value.cpu() + stats = pl_module.normalization_metrics.cpu() + if "pred_scores" in outputs: + outputs["pred_scores"] = normalize(outputs["pred_scores"], image_threshold, stats.min, stats.max) + if "anomaly_maps" in outputs: + outputs["anomaly_maps"] = normalize(outputs["anomaly_maps"], pixel_threshold, stats.min, stats.max) + if "box_scores" in outputs: + outputs["box_scores"] = [ + normalize(scores, pixel_threshold, stats.min, stats.max) for scores in outputs["box_scores"] + ] diff --git a/anomalib/callbacks/normalization/utils.py b/anomalib/callbacks/normalization/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..fca2d3f29d6383a945d13aa29939e0859b2fbf04 --- /dev/null +++ b/anomalib/callbacks/normalization/utils.py @@ -0,0 +1,78 @@ +"""Normalization callback utils.""" + +# Copyright (C) 2022 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +import importlib + +from lightning.pytorch import Callback +from omegaconf import DictConfig + +from anomalib.utils.normalization import NormalizationMethod +from anomalib.utils.types import NORMALIZATION + +from .min_max_normalization import _MinMaxNormalizationCallback + + +def get_normalization_callback( + normalization_method: NORMALIZATION = NormalizationMethod.MIN_MAX, +) -> Callback | None: + """Return normalization object. + + normalization_method is an instance of ``Callback``, it is returned as is. + + if normalization_method is of type ``NormalizationMethod``, then a new class is created based on the type of + normalization_method. + + Otherwise it expects a dictionary containing class_path and init_args. + normalization_method: + class_path: MinMaxNormalizer + init_args: + - + - + + Example: + >>> normalizer = get_normalization_callback(NormalizationMethod.MIN_MAX) + or + >>> normalizer = get_normalization_callback("min_max") + or + >>> normalizer = get_normalization_callback({"class_path": "MinMaxNormalizationCallback", "init_args": {}}) + or + >>> normalizer = get_normalization_callback(MinMaxNormalizationCallback()) + """ + normalizer: Callback | None + if isinstance(normalization_method, NormalizationMethod | str): + normalizer = _get_normalizer_from_method(NormalizationMethod(normalization_method)) + elif isinstance(normalization_method, Callback): + normalizer = normalization_method + elif isinstance(normalization_method, DictConfig): + normalizer = _parse_normalizer_config(normalization_method) + else: + msg = f"Unknown normalizer type {normalization_method}" + raise TypeError(msg) + return normalizer + + +def _get_normalizer_from_method(normalization_method: NormalizationMethod | str) -> Callback | None: + if normalization_method == NormalizationMethod.NONE: + normalizer = None + elif normalization_method == NormalizationMethod.MIN_MAX: + normalizer = _MinMaxNormalizationCallback() + else: + msg = f"Unknown normalization method {normalization_method}" + raise ValueError(msg) + return normalizer + + +def _parse_normalizer_config(normalization_method: DictConfig) -> Callback: + class_path = normalization_method.class_path + init_args = normalization_method.init_args + + if len(class_path.split(".")) == 1: + module_path = "anomalib.utils.callbacks.normalization" + else: + module_path = ".".join(class_path.split(".")[:-1]) + class_path = class_path.split(".")[-1] + module = importlib.import_module(module_path) + class_ = getattr(module, class_path) + return class_(**init_args) diff --git a/anomalib/callbacks/post_processor.py b/anomalib/callbacks/post_processor.py new file mode 100644 index 0000000000000000000000000000000000000000..c2aa251c6bc53e4db46526343983af90789bf2ea --- /dev/null +++ b/anomalib/callbacks/post_processor.py @@ -0,0 +1,125 @@ +"""Callback that attaches necessary pre/post-processing to the model.""" + +# Copyright (C) 2023 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + + +from typing import Any + +import torch +from lightning import Callback +from lightning.pytorch import Trainer +from lightning.pytorch.utilities.types import STEP_OUTPUT + +from anomalib.data.utils import boxes_to_anomaly_maps, boxes_to_masks, masks_to_boxes +from anomalib.models import AnomalyModule + + +class _PostProcessorCallback(Callback): + """Applies post-processing to the model outputs. + + Note: This callback is set within the Engine. + """ + + def __init__(self) -> None: + super().__init__() + + def on_validation_batch_end( + self, + trainer: Trainer, + pl_module: AnomalyModule, + outputs: STEP_OUTPUT | None, + batch: Any, # noqa: ANN401 + batch_idx: int, + dataloader_idx: int = 0, + ) -> None: + del batch, batch_idx, dataloader_idx # Unused arguments. + + if outputs is not None: + self.post_process(trainer, pl_module, outputs) + + def on_test_batch_end( + self, + trainer: Trainer, + pl_module: AnomalyModule, + outputs: STEP_OUTPUT | None, + batch: Any, # noqa: ANN401 + batch_idx: int, + dataloader_idx: int = 0, + ) -> None: + del batch, batch_idx, dataloader_idx # Unused arguments. + + if outputs is not None: + self.post_process(trainer, pl_module, outputs) + + def on_predict_batch_end( + self, + trainer: Trainer, + pl_module: AnomalyModule, + outputs: Any, # noqa: ANN401 + batch: Any, # noqa: ANN401 + batch_idx: int, + dataloader_idx: int = 0, + ) -> None: + del batch, batch_idx, dataloader_idx # Unused arguments. + + if outputs is not None: + self.post_process(trainer, pl_module, outputs) + + def post_process(self, trainer: Trainer, pl_module: AnomalyModule, outputs: STEP_OUTPUT) -> None: + if isinstance(outputs, dict): + self._post_process(outputs) + if trainer.predicting or trainer.testing: + self._compute_scores_and_labels(pl_module, outputs) + + @staticmethod + def _compute_scores_and_labels( + pl_module: AnomalyModule, + outputs: dict[str, Any], + ) -> None: + if "pred_scores" in outputs: + outputs["pred_labels"] = outputs["pred_scores"] >= pl_module.image_threshold.value + if "anomaly_maps" in outputs: + outputs["pred_masks"] = outputs["anomaly_maps"] >= pl_module.pixel_threshold.value + if "pred_boxes" not in outputs: + outputs["pred_boxes"], outputs["box_scores"] = masks_to_boxes( + outputs["pred_masks"], + outputs["anomaly_maps"], + ) + outputs["box_labels"] = [torch.ones(boxes.shape[0]) for boxes in outputs["pred_boxes"]] + # apply thresholding to boxes + if "box_scores" in outputs and "box_labels" not in outputs: + # apply threshold to assign normal/anomalous label to boxes + is_anomalous = [scores > pl_module.pixel_threshold.value for scores in outputs["box_scores"]] + outputs["box_labels"] = [labels.int() for labels in is_anomalous] + + @staticmethod + def _post_process(outputs: STEP_OUTPUT) -> None: + """Compute labels based on model predictions.""" + if isinstance(outputs, dict): + if "pred_scores" not in outputs and "anomaly_maps" in outputs: + # infer image scores from anomaly maps + outputs["pred_scores"] = ( + outputs["anomaly_maps"] # noqa: PD011 + .reshape(outputs["anomaly_maps"].shape[0], -1) + .max(dim=1) + .values + ) + elif "pred_scores" not in outputs and "box_scores" in outputs and "label" in outputs: + # infer image score from bbox confidence scores + outputs["pred_scores"] = torch.zeros_like(outputs["label"]).float() + for idx, (boxes, scores) in enumerate(zip(outputs["pred_boxes"], outputs["box_scores"], strict=True)): + if boxes.numel(): + outputs["pred_scores"][idx] = scores.max().item() + + if "pred_boxes" in outputs and "anomaly_maps" not in outputs: + # create anomaly maps from bbox predictions for thresholding and evaluation + image_size: tuple[int, int] = outputs["image"].shape[-2:] + pred_boxes: torch.Tensor = outputs["pred_boxes"] + box_scores: torch.Tensor = outputs["box_scores"] + + outputs["anomaly_maps"] = boxes_to_anomaly_maps(pred_boxes, box_scores, image_size) + + if "boxes" in outputs: + true_boxes: list[torch.Tensor] = outputs["boxes"] + outputs["mask"] = boxes_to_masks(true_boxes, image_size) diff --git a/anomalib/callbacks/thresholding.py b/anomalib/callbacks/thresholding.py new file mode 100644 index 0000000000000000000000000000000000000000..1a4d12febdb735336c6d56c8ef7e7a44515ba478 --- /dev/null +++ b/anomalib/callbacks/thresholding.py @@ -0,0 +1,197 @@ +"""Thresholding callback.""" + +# Copyright (C) 2023 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +import importlib +from typing import Any + +import torch +from lightning.pytorch import Callback, Trainer +from lightning.pytorch.utilities.types import STEP_OUTPUT +from omegaconf import DictConfig, ListConfig + +from anomalib.metrics.threshold import BaseThreshold +from anomalib.models import AnomalyModule +from anomalib.utils.types import THRESHOLD + + +class _ThresholdCallback(Callback): + """Setup/apply thresholding. + + Note: This callback is set within the Engine. + """ + + def __init__( + self, + threshold: THRESHOLD = "F1AdaptiveThreshold", + ) -> None: + super().__init__() + self._initialize_thresholds(threshold) + self.image_threshold: BaseThreshold + self.pixel_threshold: BaseThreshold + + def setup(self, trainer: Trainer, pl_module: AnomalyModule, stage: str) -> None: + del trainer, stage # Unused arguments. + if not hasattr(pl_module, "image_threshold"): + pl_module.image_threshold = self.image_threshold + if not hasattr(pl_module, "pixel_threshold"): + pl_module.pixel_threshold = self.pixel_threshold + + def on_validation_epoch_start(self, trainer: Trainer, pl_module: AnomalyModule) -> None: + del trainer # Unused argument. + self._reset(pl_module) + + def on_validation_batch_end( + self, + trainer: Trainer, + pl_module: AnomalyModule, + outputs: STEP_OUTPUT | None, + batch: Any, # noqa: ANN401 + batch_idx: int, + dataloader_idx: int = 0, + ) -> None: + del trainer, batch, batch_idx, dataloader_idx # Unused arguments. + if outputs is not None: + self._outputs_to_cpu(outputs) + self._update(pl_module, outputs) + + def on_validation_epoch_end(self, trainer: Trainer, pl_module: AnomalyModule) -> None: + del trainer # Unused argument. + self._compute(pl_module) + + def _initialize_thresholds( + self, + threshold: THRESHOLD, + ) -> None: + """Initialize ``self.image_threshold`` and ``self.pixel_threshold``. + + Args: + threshold (THRESHOLD): + Threshold configuration + + Example: + >>> _initialize_thresholds(F1AdaptiveThreshold()) + or + >>> _initialize_thresholds((ManualThreshold(0.5), ManualThreshold(0.5))) + or configuration + + For more details on configuration see :fun:`_load_from_config` + + Raises: + ValueError: Unknown threshold class or incorrect configuration + """ + # TODO(djdameln): Add tests for each case + # CVS-122661 + # When only a single threshold class is passed. + # This initializes image and pixel thresholds with the same class + # >>> _initialize_thresholds(F1AdaptiveThreshold()) + if isinstance(threshold, BaseThreshold): + self.image_threshold = threshold + self.pixel_threshold = threshold.clone() + + # When a tuple of threshold classes are passed + # >>> _initialize_thresholds((ManualThreshold(0.5), ManualThreshold(0.5))) + elif isinstance(threshold, tuple) and isinstance(threshold[0], BaseThreshold): + self.image_threshold = threshold[0] + self.pixel_threshold = threshold[1] + # When the passed threshold is not an instance of a Threshold class. + elif isinstance(threshold, str | DictConfig | ListConfig | list): + self._load_from_config(threshold) + else: + msg = f"Invalid threshold type {type(threshold)}" + raise TypeError(msg) + + def _load_from_config(self, threshold: DictConfig | str | ListConfig | list[dict[str, str | float]]) -> None: + """Load the thresholding class based on the config. + + Example: + threshold: F1AdaptiveThreshold + or + threshold: + class_path: F1AdaptiveThreshold + init_args: + - + or + threshold: + - F1AdaptiveThreshold + - F1AdaptiveThreshold + or + threshold: + - class_path: F1AdaptiveThreshold + init_args: + - + - class_path: F1AdaptiveThreshold + """ + if isinstance(threshold, str | DictConfig): + self.image_threshold = self._get_threshold_from_config(threshold) + self.pixel_threshold = self.image_threshold.clone() + elif isinstance(threshold, ListConfig | list): + self.image_threshold = self._get_threshold_from_config(threshold[0]) + self.pixel_threshold = self._get_threshold_from_config(threshold[1]) + else: + msg = f"Invalid threshold config {threshold}" + raise TypeError(msg) + + def _get_threshold_from_config(self, threshold: DictConfig | str | dict[str, str | float]) -> BaseThreshold: + """Return the instantiated threshold object. + + Example: + >>> _get_threshold_from_config(F1AdaptiveThreshold) + or + >>> config = DictConfig({ + ... "class_path": "ManualThreshold", + ... "init_args": {"default_value": 0.7} + ... }) + >>> __get_threshold_from_config(config) + or + >>> config = DictConfig({ + ... "class_path": "anomalib.metrics.threshold.F1AdaptiveThreshold" + ... }) + >>> __get_threshold_from_config(config) + + Returns: + (BaseThreshold): Instance of threshold object. + """ + if isinstance(threshold, str): + threshold = DictConfig({"class_path": threshold}) + + class_path = threshold["class_path"] + init_args = threshold.get("init_args", {}) + + if len(class_path.split(".")) == 1: + module_path = "anomalib.metrics.threshold" + + else: + module_path = ".".join(class_path.split(".")[:-1]) + class_path = class_path.split(".")[-1] + + module = importlib.import_module(module_path) + class_ = getattr(module, class_path) + return class_(**init_args) + + def _reset(self, pl_module: AnomalyModule) -> None: + pl_module.image_threshold.reset() + pl_module.pixel_threshold.reset() + + def _outputs_to_cpu(self, output: STEP_OUTPUT) -> STEP_OUTPUT | dict[str, Any]: + if isinstance(output, dict): + for key, value in output.items(): + output[key] = self._outputs_to_cpu(value) + elif isinstance(output, torch.Tensor): + output = output.cpu() + return output + + def _update(self, pl_module: AnomalyModule, outputs: STEP_OUTPUT) -> None: + pl_module.image_threshold.cpu() + pl_module.image_threshold.update(outputs["pred_scores"], outputs["label"].int()) + if "mask" in outputs and "anomaly_maps" in outputs: + pl_module.pixel_threshold.cpu() + pl_module.pixel_threshold.update(outputs["anomaly_maps"], outputs["mask"].int()) + + def _compute(self, pl_module: AnomalyModule) -> None: + pl_module.image_threshold.compute() + if pl_module.pixel_threshold._update_called: # noqa: SLF001 + pl_module.pixel_threshold.compute() + else: + pl_module.pixel_threshold.value = pl_module.image_threshold.value diff --git a/anomalib/callbacks/tiler_configuration.py b/anomalib/callbacks/tiler_configuration.py new file mode 100644 index 0000000000000000000000000000000000000000..193349711ad9c28bb1afbfc0cd668de0bde27ce8 --- /dev/null +++ b/anomalib/callbacks/tiler_configuration.py @@ -0,0 +1,74 @@ +"""Tiler Callback.""" + +# Copyright (C) 2022 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + + +from collections.abc import Sequence + +import lightning.pytorch as pl +from lightning.pytorch.callbacks import Callback + +from anomalib.data.utils.tiler import ImageUpscaleMode, Tiler +from anomalib.models.components import AnomalyModule + +__all__ = ["TilerConfigurationCallback"] + + +class TilerConfigurationCallback(Callback): + """Tiler Configuration Callback.""" + + def __init__( + self, + enable: bool = False, + tile_size: int | Sequence = 256, + stride: int | Sequence | None = None, + remove_border_count: int = 0, + mode: ImageUpscaleMode = ImageUpscaleMode.PADDING, + ) -> None: + """Set tiling configuration from the command line. + + Args: + enable (bool): Boolean to enable tiling operation. + Defaults to False. + tile_size ([int | Sequence]): Tile size. + Defaults to 256. + stride ([int | Sequence]): Stride to move tiles on the image. + remove_border_count (int, optional): Number of pixels to remove from the image before + tiling. Defaults to 0. + mode (str, optional): Up-scaling mode when untiling overlapping tiles. + Defaults to "padding". + tile_count (SupportsIndex, optional): Number of random tiles to sample from the image. + Defaults to 4. + """ + self.enable = enable + self.tile_size = tile_size + self.stride = stride + self.remove_border_count = remove_border_count + self.mode = mode + + def setup(self, trainer: pl.Trainer, pl_module: pl.LightningModule, stage: str | None = None) -> None: + """Set Tiler object within Anomalib Model. + + Args: + trainer (pl.Trainer): PyTorch Lightning Trainer + pl_module (pl.LightningModule): Anomalib Model that inherits pl LightningModule. + stage (str | None, optional): fit, validate, test or predict. Defaults to None. + + Raises: + ValueError: When Anomalib Model doesn't contain ``Tiler`` object, it means the model + doesn not support tiling operation. + """ + del trainer, stage # These variables are not used. + + if self.enable: + if isinstance(pl_module, AnomalyModule) and hasattr(pl_module.model, "tiler"): + pl_module.model.tiler = Tiler( + tile_size=self.tile_size, + stride=self.stride, + remove_border_count=self.remove_border_count, + mode=self.mode, + ) + else: + msg = "Model does not support tiling." + raise ValueError(msg) diff --git a/anomalib/callbacks/timer.py b/anomalib/callbacks/timer.py new file mode 100644 index 0000000000000000000000000000000000000000..ee9658a9b0c6c8a67d3445e20b310272e1219634 --- /dev/null +++ b/anomalib/callbacks/timer.py @@ -0,0 +1,109 @@ +"""Callback to measure training and testing time of a PyTorch Lightning module.""" + +# Copyright (C) 2022 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +import logging +import time + +import torch +from lightning.pytorch import Callback, LightningModule, Trainer + +logger = logging.getLogger(__name__) + + +class TimerCallback(Callback): + """Callback that measures the training and testing time of a PyTorch Lightning module. + + Examples: + >>> from anomalib.callbacks import TimerCallback + >>> from anomalib.engine import Engine + ... + >>> callbacks = [TimerCallback()] + >>> engine = Engine(callbacks=callbacks) + """ + + def __init__(self) -> None: + self.start: float + self.num_images: int = 0 + + def on_fit_start(self, trainer: Trainer, pl_module: LightningModule) -> None: + """Call when fit begins. + + Sets the start time to the time training started. + + Args: + trainer (Trainer): PyTorch Lightning trainer. + pl_module (LightningModule): Current training module. + + Returns: + None + """ + del trainer, pl_module # These variables are not used. + + self.start = time.time() + + def on_fit_end(self, trainer: Trainer, pl_module: LightningModule) -> None: + """Call when fit ends. + + Prints the time taken for training. + + Args: + trainer (Trainer): PyTorch Lightning trainer. + pl_module (LightningModule): Current training module. + + Returns: + None + """ + del trainer, pl_module # Unused arguments. + logger.info("Training took %5.2f seconds", (time.time() - self.start)) + + def on_test_start(self, trainer: Trainer, pl_module: LightningModule) -> None: + """Call when the test begins. + + Sets the start time to the time testing started. + Goes over all the test dataloaders and adds the number of images in each. + + Args: + trainer (Trainer): PyTorch Lightning trainer. + pl_module (LightningModule): Current training module. + + Returns: + None + """ + del pl_module # Unused argument. + + self.start = time.time() + self.num_images = 0 + + if trainer.test_dataloaders is not None: # Check to placate Mypy. + if isinstance(trainer.test_dataloaders, torch.utils.data.dataloader.DataLoader): + self.num_images += len(trainer.test_dataloaders.dataset) + else: + for dataloader in trainer.test_dataloaders: + self.num_images += len(dataloader.dataset) + + def on_test_end(self, trainer: Trainer, pl_module: LightningModule) -> None: + """Call when the test ends. + + Prints the time taken for testing and the throughput in frames per second. + + Args: + trainer (Trainer): PyTorch Lightning trainer. + pl_module (LightningModule): Current training module. + + Returns: + None + """ + del pl_module # Unused argument. + + testing_time = time.time() - self.start + output = f"Testing took {testing_time} seconds\nThroughput " + if trainer.test_dataloaders is not None: + if isinstance(trainer.test_dataloaders, torch.utils.data.dataloader.DataLoader): + test_data_loader = trainer.test_dataloaders + else: + test_data_loader = trainer.test_dataloaders[0] + output += f"(batch_size={test_data_loader.batch_size})" + output += f" : {self.num_images/testing_time} FPS" + logger.info(output) diff --git a/anomalib/callbacks/visualizer.py b/anomalib/callbacks/visualizer.py new file mode 100644 index 0000000000000000000000000000000000000000..c78c1f6ab823aab2b4b0bd58c1b0af11e9b03b3f --- /dev/null +++ b/anomalib/callbacks/visualizer.py @@ -0,0 +1,182 @@ +"""Visualizer Callback. + +This is assigned by Anomalib Engine internally. +""" + +# Copyright (C) 2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +import logging +from pathlib import Path +from typing import Any, cast + +from lightning.pytorch import Callback, Trainer +from lightning.pytorch.utilities.types import STEP_OUTPUT + +from anomalib.data.utils.image import save_image, show_image +from anomalib.loggers import AnomalibWandbLogger +from anomalib.loggers.base import ImageLoggerBase +from anomalib.models import AnomalyModule +from anomalib.utils.visualization import ( + BaseVisualizer, + GeneratorResult, + VisualizationStep, +) + +logger = logging.getLogger(__name__) + + +class _VisualizationCallback(Callback): + """Callback for visualization that is used internally by the Engine. + + Args: + visualizers (BaseVisualizer | list[BaseVisualizer]): + Visualizer objects that are used for computing the visualizations. Defaults to None. + save (bool, optional): Save the image. Defaults to False. + root (Path | None, optional): The path to save the images. Defaults to None. + log (bool, optional): Log the images into the loggers. Defaults to False. + show (bool, optional): Show the images. Defaults to False. + + Example: + >>> visualizers = [ImageVisualizer(), MetricsVisualizer()] + >>> visualization_callback = _VisualizationCallback( + ... visualizers=visualizers, + ... save=True, + ... root="results/images" + ... ) + + CLI + $ anomalib train --model Padim --data MVTec \ + --visualization.visualizers ImageVisualizer \ + --visualization.visualizers+=MetricsVisualizer + or + $ anomalib train --model Padim --data MVTec \ + --visualization.visualizers '[ImageVisualizer, MetricsVisualizer]' + + Raises: + ValueError: Incase `root` is None and `save` is True. + """ + + def __init__( + self, + visualizers: BaseVisualizer | list[BaseVisualizer], + save: bool = False, + root: Path | None = None, + log: bool = False, + show: bool = False, + ) -> None: + self.save = save + if save and root is None: + msg = "`root` must be provided if save is True" + raise ValueError(msg) + self.root: Path = root if root is not None else Path() # need this check for mypy + self.log = log + self.show = show + self.generators = visualizers if isinstance(visualizers, list) else [visualizers] + + def on_test_batch_end( + self, + trainer: Trainer, + pl_module: AnomalyModule, + outputs: STEP_OUTPUT | None, + batch: Any, # noqa: ANN401 + batch_idx: int, + dataloader_idx: int = 0, + ) -> None: + for generator in self.generators: + if generator.visualize_on == VisualizationStep.BATCH: + for result in generator( + trainer=trainer, + pl_module=pl_module, + outputs=outputs, + batch=batch, + batch_idx=batch_idx, + dataloader_idx=dataloader_idx, + ): + if self.save: + if result.file_name is None: + msg = "``save`` is set to ``True`` but file name is ``None``" + raise ValueError(msg) + + # Get the filename to save the image. + # Filename is split based on the datamodule name and category. + # For example, if the filename is `MVTec/bottle/000.png`, then the + # filename is split based on `MVTec/bottle` and `000.png` is saved. + if trainer.datamodule is not None: + filename = str(result.file_name).split( + sep=f"{trainer.datamodule.name}/{trainer.datamodule.category}", + )[-1] + else: + filename = Path(result.file_name).name + save_image(image=result.image, root=self.root, filename=filename) + if self.show: + show_image(image=result.image, title=str(result.file_name)) + if self.log: + self._add_to_logger(result, pl_module, trainer) + + def on_test_end(self, trainer: Trainer, pl_module: AnomalyModule) -> None: + for generator in self.generators: + if generator.visualize_on == VisualizationStep.STAGE_END: + for result in generator(trainer=trainer, pl_module=pl_module): + if self.save: + if result.file_name is None: + msg = "``save`` is set to ``True`` but file name is ``None``" + raise ValueError(msg) + save_image(image=result.image, root=self.root, filename=result.file_name) + if self.show: + show_image(image=result.image, title=str(result.file_name)) + if self.log: + self._add_to_logger(result, pl_module, trainer) + + for logger in trainer.loggers: + if isinstance(logger, AnomalibWandbLogger): + logger.save() + + def on_predict_batch_end( + self, + trainer: Trainer, + pl_module: AnomalyModule, + outputs: STEP_OUTPUT | None, + batch: Any, # noqa: ANN401 + batch_idx: int, + dataloader_idx: int = 0, + ) -> None: + return self.on_test_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx) + + def on_predict_end(self, trainer: Trainer, pl_module: AnomalyModule) -> None: + return self.on_test_end(trainer, pl_module) + + def _add_to_logger( + self, + result: GeneratorResult, + module: AnomalyModule, + trainer: Trainer, + ) -> None: + """Add image to logger. + + Args: + result (GeneratorResult): Output from the generators. + module (AnomalyModule): LightningModule from which the global step is extracted. + trainer (Trainer): Trainer object. + """ + # Store names of logger and the logger in a dict + available_loggers = { + type(logger).__name__.lower().replace("logger", "").replace("anomalib", ""): logger + for logger in trainer.loggers + } + # save image to respective logger + if result.file_name is None: + msg = "File name is None" + raise ValueError(msg) + filename = result.file_name + image = result.image + for log_to in available_loggers: + # check if logger object is same as the requested object + if isinstance(available_loggers[log_to], ImageLoggerBase): + logger: ImageLoggerBase = cast(ImageLoggerBase, available_loggers[log_to]) # placate mypy + _name = filename.parent.name + "_" + filename.name if isinstance(filename, Path) else filename + logger.add_image( + image=image, + name=_name, + global_step=module.global_step, + ) diff --git a/anomalib/cli/__init__.py b/anomalib/cli/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..78b54e5988e97bdd250cbc001ba837949df6f687 --- /dev/null +++ b/anomalib/cli/__init__.py @@ -0,0 +1,8 @@ +"""Anomalib CLI.""" + +# Copyright (C) 2022 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +from .cli import AnomalibCLI + +__all__ = ["AnomalibCLI"] diff --git a/anomalib/cli/cli.py b/anomalib/cli/cli.py new file mode 100644 index 0000000000000000000000000000000000000000..210a96aa16099fb7e885923be977fb2bb7a233ea --- /dev/null +++ b/anomalib/cli/cli.py @@ -0,0 +1,483 @@ +"""Anomalib CLI.""" + +# Copyright (C) 2023-2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +import logging +from collections.abc import Callable, Sequence +from functools import partial +from pathlib import Path +from types import MethodType +from typing import Any + +from jsonargparse import ActionConfigFile, ArgumentParser, Namespace +from jsonargparse._actions import _ActionSubCommands +from rich import traceback + +from anomalib import TaskType, __version__ +from anomalib.cli.utils.help_formatter import CustomHelpFormatter, get_short_docstring +from anomalib.cli.utils.openvino import add_openvino_export_arguments +from anomalib.loggers import configure_logger + +traceback.install() +logger = logging.getLogger("anomalib.cli") + +_LIGHTNING_AVAILABLE = True +try: + from lightning.pytorch import Trainer + from torch.utils.data import DataLoader, Dataset + + from anomalib.data import AnomalibDataModule + from anomalib.engine import Engine + from anomalib.metrics.threshold import BaseThreshold + from anomalib.models import AnomalyModule + from anomalib.utils.config import update_config + +except ImportError: + _LIGHTNING_AVAILABLE = False + + +class AnomalibCLI: + """Implementation of a fully configurable CLI tool for anomalib. + + The advantage of this tool is its flexibility to configure the pipeline + from both the CLI and a configuration file (.yaml or .json). It is even + possible to use both the CLI and a configuration file simultaneously. + For more details, the reader could refer to PyTorch Lightning CLI + documentation. + + ``save_config_kwargs`` is set to ``overwrite=True`` so that the + ``SaveConfigCallback`` overwrites the config if it already exists. + """ + + def __init__(self, args: Sequence[str] | None = None) -> None: + self.parser = self.init_parser() + self.subcommand_parsers: dict[str, ArgumentParser] = {} + self.subcommand_method_arguments: dict[str, list[str]] = {} + self.add_subcommands() + self.config = self.parser.parse_args(args=args) + self.subcommand = self.config["subcommand"] + if _LIGHTNING_AVAILABLE: + self.before_instantiate_classes() + self.instantiate_classes() + self._run_subcommand() + + def init_parser(self, **kwargs) -> ArgumentParser: + """Method that instantiates the argument parser.""" + kwargs.setdefault("dump_header", [f"anomalib=={__version__}"]) + parser = ArgumentParser(formatter_class=CustomHelpFormatter, **kwargs) + parser.add_argument( + "-c", + "--config", + action=ActionConfigFile, + help="Path to a configuration file in json or yaml format.", + ) + return parser + + @staticmethod + def subcommands() -> dict[str, set[str]]: + """Skip predict subcommand as it is added later.""" + return { + "fit": {"model", "train_dataloaders", "val_dataloaders", "datamodule"}, + "validate": {"model", "dataloaders", "datamodule"}, + "test": {"model", "dataloaders", "datamodule"}, + } + + @staticmethod + def anomalib_subcommands() -> dict[str, dict[str, str]]: + """Return a dictionary of subcommands and their description.""" + return { + "train": {"description": "Fit the model and then call test on the trained model."}, + "predict": {"description": "Run inference on a model."}, + "export": {"description": "Export the model to ONNX or OpenVINO format."}, + } + + def add_subcommands(self, **kwargs) -> None: + """Initialize base subcommands and add anomalib specific on top of it.""" + parser_subcommands = self.parser.add_subcommands() + + # Extra subcommand: install + self._set_install_subcommand(parser_subcommands) + + if not _LIGHTNING_AVAILABLE: + # If environment is not configured to use pl, do not add a subcommand for Engine. + return + + # Add Trainer subcommands + for subcommand in self.subcommands(): + sub_parser = self.init_parser(**kwargs) + + fn = getattr(Trainer, subcommand) + # extract the first line description in the docstring for the subcommand help message + description = get_short_docstring(fn) + subparser_kwargs = kwargs.get(subcommand, {}) + subparser_kwargs.setdefault("description", description) + + self.subcommand_parsers[subcommand] = sub_parser + parser_subcommands.add_subcommand(subcommand, sub_parser, help=description) + self.add_trainer_arguments(sub_parser, subcommand) + + # Add anomalib subcommands + for subcommand in self.anomalib_subcommands(): + sub_parser = self.init_parser(**kwargs) + + self.subcommand_parsers[subcommand] = sub_parser + parser_subcommands.add_subcommand( + subcommand, + sub_parser, + help=self.anomalib_subcommands()[subcommand]["description"], + ) + # add arguments to subcommand + getattr(self, f"add_{subcommand}_arguments")(sub_parser) + + def add_arguments_to_parser(self, parser: ArgumentParser) -> None: + """Extend trainer's arguments to add engine arguments. + + .. note:: + Since ``Engine`` parameters are manually added, any change to the + ``Engine`` class should be reflected manually. + """ + from anomalib.callbacks.normalization import get_normalization_callback + + parser.add_function_arguments(get_normalization_callback, "normalization") + parser.add_argument("--task", type=TaskType | str, default=TaskType.SEGMENTATION) + parser.add_argument( + "--metrics.image", + type=list[str] | str | dict[str, dict[str, Any]] | None, + default=["F1Score", "AUROC"], + ) + parser.add_argument( + "--metrics.pixel", + type=list[str] | str | dict[str, dict[str, Any]] | None, + default=None, + required=False, + ) + parser.add_argument("--metrics.threshold", type=BaseThreshold | str, default="F1AdaptiveThreshold") + parser.add_argument("--logging.log_graph", type=bool, help="Log the model to the logger", default=False) + if hasattr(parser, "subcommand") and parser.subcommand not in ("export", "predict"): + parser.link_arguments("task", "data.init_args.task") + parser.add_argument( + "--default_root_dir", + type=Path, + help="Path to save the results.", + default=Path("./results"), + ) + parser.link_arguments("default_root_dir", "trainer.default_root_dir") + # TODO(ashwinvaidya17): Tiling should also be a category of its own + # CVS-122659 + + def add_trainer_arguments(self, parser: ArgumentParser, subcommand: str) -> None: + """Add train arguments to the parser.""" + self._add_default_arguments_to_parser(parser) + self._add_trainer_arguments_to_parser(parser, add_optimizer=True, add_scheduler=True) + parser.add_subclass_arguments( + AnomalyModule, + "model", + fail_untyped=False, + required=True, + ) + parser.add_subclass_arguments(AnomalibDataModule, "data") + self.add_arguments_to_parser(parser) + skip: set[str | int] = set(self.subcommands()[subcommand]) + added = parser.add_method_arguments( + Trainer, + subcommand, + skip=skip, + ) + self.subcommand_method_arguments[subcommand] = added + + def add_train_arguments(self, parser: ArgumentParser) -> None: + """Add train arguments to the parser.""" + self._add_default_arguments_to_parser(parser) + self._add_trainer_arguments_to_parser(parser, add_optimizer=True, add_scheduler=True) + parser.add_subclass_arguments( + AnomalyModule, + "model", + fail_untyped=False, + required=True, + ) + parser.add_subclass_arguments(AnomalibDataModule, "data") + self.add_arguments_to_parser(parser) + added = parser.add_method_arguments( + Engine, + "train", + skip={"model", "datamodule", "val_dataloaders", "test_dataloaders", "train_dataloaders"}, + ) + self.subcommand_method_arguments["train"] = added + + def add_predict_arguments(self, parser: ArgumentParser) -> None: + """Add predict arguments to the parser.""" + self._add_default_arguments_to_parser(parser) + self._add_trainer_arguments_to_parser(parser) + parser.add_subclass_arguments( + AnomalyModule, + "model", + fail_untyped=False, + required=True, + ) + parser.add_argument( + "--data", + type=Dataset | AnomalibDataModule | DataLoader | str | Path, + required=True, + ) + added = parser.add_method_arguments( + Engine, + "predict", + skip={"model", "dataloaders", "datamodule", "dataset", "data_path"}, + ) + self.subcommand_method_arguments["predict"] = added + self.add_arguments_to_parser(parser) + + def add_export_arguments(self, parser: ArgumentParser) -> None: + """Add export arguments to the parser.""" + self._add_default_arguments_to_parser(parser) + self._add_trainer_arguments_to_parser(parser) + parser.add_subclass_arguments( + AnomalyModule, + "model", + fail_untyped=False, + required=True, + ) + added = parser.add_method_arguments( + Engine, + "export", + skip={"ov_args", "model"}, + ) + self.subcommand_method_arguments["export"] = added + add_openvino_export_arguments(parser) + self.add_arguments_to_parser(parser) + + def _set_install_subcommand(self, action_subcommand: _ActionSubCommands) -> None: + sub_parser = ArgumentParser(formatter_class=CustomHelpFormatter) + sub_parser.add_argument( + "--option", + help="Install the full or optional-dependencies.", + default="full", + type=str, + choices=["full", "core", "dev", "loggers", "notebooks", "openvino"], + ) + sub_parser.add_argument( + "-v", + "--verbose", + help="Set Logger level to INFO", + action="store_true", + ) + + self.subcommand_parsers["install"] = sub_parser + action_subcommand.add_subcommand( + "install", + sub_parser, + help="Install the full-package for anomalib.", + ) + + def before_instantiate_classes(self) -> None: + """Modify the configuration to properly instantiate classes and sets up tiler.""" + subcommand = self.config["subcommand"] + if subcommand in (*self.subcommands(), "train", "predict"): + self.config[subcommand] = update_config(self.config[subcommand]) + + def instantiate_classes(self) -> None: + """Instantiate classes depending on the subcommand. + + For trainer related commands it instantiates all the model, datamodule and trainer classes. + But for subcommands we do not want to instantiate any trainer specific classes such as datamodule, model, etc + This is because the subcommand is responsible for instantiating and executing code based on the passed config + """ + if self.config["subcommand"] in (*self.subcommands(), "predict"): # trainer commands + # since all classes are instantiated, the LightningCLI also creates an unused ``Trainer`` object. + # the minor change here is that engine is instantiated instead of trainer + self.config_init = self.parser.instantiate_classes(self.config) + self.datamodule = self._get(self.config_init, "data") + if isinstance(self.datamodule, Dataset): + self.datamodule = DataLoader(self.datamodule) + self.model = self._get(self.config_init, "model") + self._configure_optimizers_method_to_model() + self.instantiate_engine() + else: + self.config_init = self.parser.instantiate_classes(self.config) + subcommand = self.config["subcommand"] + if subcommand in ("train", "export"): + self.instantiate_engine() + if "model" in self.config_init[subcommand]: + self.model = self._get(self.config_init, "model") + else: + self.model = None + if "data" in self.config_init[subcommand]: + self.datamodule = self._get(self.config_init, "data") + else: + self.datamodule = None + + def instantiate_engine(self) -> None: + """Instantiate the engine. + + .. note:: + Most of the code in this method is taken from ``LightningCLI``'s + ``instantiate_trainer`` method. Refer to that method for more + details. + """ + from lightning.pytorch.cli import SaveConfigCallback + + from anomalib.callbacks import get_callbacks + + engine_args = { + "normalization": self._get(self.config_init, "normalization.normalization_method"), + "threshold": self._get(self.config_init, "metrics.threshold"), + "task": self._get(self.config_init, "task"), + "image_metrics": self._get(self.config_init, "metrics.image"), + "pixel_metrics": self._get(self.config_init, "metrics.pixel"), + } + trainer_config = {**self._get(self.config_init, "trainer", default={}), **engine_args} + key = "callbacks" + if key in trainer_config: + if trainer_config[key] is None: + trainer_config[key] = [] + elif not isinstance(trainer_config[key], list): + trainer_config[key] = [trainer_config[key]] + if not trainer_config.get("fast_dev_run", False): + config_callback = SaveConfigCallback( + self._parser(self.subcommand), + self.config.get(str(self.subcommand), self.config), + overwrite=True, + ) + trainer_config[key].append(config_callback) + trainer_config[key].extend(get_callbacks(self.config[self.subcommand])) + self.engine = Engine(**trainer_config) + + def _run_subcommand(self) -> None: + """Run subcommand depending on the subcommand. + + This overrides the original ``_run_subcommand`` to run the ``Engine`` + method rather than the ``Train`` method. + """ + if self.subcommand == "install": + from anomalib.cli.install import anomalib_install + + install_kwargs = self.config.get("install", {}) + anomalib_install(**install_kwargs) + elif self.config["subcommand"] in (*self.subcommands(), "train", "export", "predict"): + fn = getattr(self.engine, self.subcommand) + fn_kwargs = self._prepare_subcommand_kwargs(self.subcommand) + fn(**fn_kwargs) + else: + self.config_init = self.parser.instantiate_classes(self.config) + getattr(self, f"{self.subcommand}")() + + @property + def fit(self) -> Callable: + """Fit the model using engine's fit method.""" + return self.engine.fit + + @property + def validate(self) -> Callable: + """Validate the model using engine's validate method.""" + return self.engine.validate + + @property + def test(self) -> Callable: + """Test the model using engine's test method.""" + return self.engine.test + + @property + def predict(self) -> Callable: + """Predict using engine's predict method.""" + return self.engine.predict + + @property + def train(self) -> Callable: + """Train the model using engine's train method.""" + return self.engine.train + + @property + def export(self) -> Callable: + """Export the model using engine's export method.""" + return self.engine.export + + def _add_trainer_arguments_to_parser( + self, + parser: ArgumentParser, + add_optimizer: bool = False, + add_scheduler: bool = False, + ) -> None: + """Add trainer arguments to the parser.""" + parser.add_class_arguments(Trainer, "trainer", fail_untyped=False, instantiate=False, sub_configs=True) + + if add_optimizer: + from torch.optim import Optimizer + + optim_kwargs = {"instantiate": False, "fail_untyped": False, "skip": {"params"}} + parser.add_subclass_arguments( + baseclass=(Optimizer,), + nested_key="optimizer", + **optim_kwargs, + ) + if add_scheduler: + from lightning.pytorch.cli import LRSchedulerTypeTuple + + scheduler_kwargs = {"instantiate": False, "fail_untyped": False, "skip": {"optimizer"}} + parser.add_subclass_arguments( + baseclass=LRSchedulerTypeTuple, + nested_key="lr_scheduler", + **scheduler_kwargs, + ) + + def _add_default_arguments_to_parser(self, parser: ArgumentParser) -> None: + """Adds default arguments to the parser.""" + parser.add_argument( + "--seed_everything", + type=bool | int, + default=True, + help=( + "Set to an int to run seed_everything with this value before classes instantiation." + "Set to True to use a random seed." + ), + ) + + def _get(self, config: Namespace, key: str, default: Any = None) -> Any: # noqa: ANN401 + """Utility to get a config value which might be inside a subcommand.""" + return config.get(str(self.subcommand), config).get(key, default) + + def _prepare_subcommand_kwargs(self, subcommand: str) -> dict[str, Any]: + """Prepares the keyword arguments to pass to the subcommand to run.""" + fn_kwargs = { + k: v for k, v in self.config_init[subcommand].items() if k in self.subcommand_method_arguments[subcommand] + } + fn_kwargs["model"] = self.model + if self.datamodule is not None: + if isinstance(self.datamodule, AnomalibDataModule): + fn_kwargs["datamodule"] = self.datamodule + elif isinstance(self.datamodule, DataLoader): + fn_kwargs["dataloaders"] = self.datamodule + elif isinstance(self.datamodule, Path | str): + fn_kwargs["data_path"] = self.datamodule + return fn_kwargs + + def _parser(self, subcommand: str | None) -> ArgumentParser: + if subcommand is None: + return self.parser + # return the subcommand parser for the subcommand passed + return self.subcommand_parsers[subcommand] + + def _configure_optimizers_method_to_model(self) -> None: + from lightning.pytorch.cli import LightningCLI, instantiate_class + + optimizer_cfg = self._get(self.config_init, "optimizer", None) + if optimizer_cfg is None: + return + lr_scheduler_cfg = self._get(self.config_init, "lr_scheduler", {}) + + optimizer = instantiate_class(self.model.parameters(), optimizer_cfg) + lr_scheduler = instantiate_class(optimizer, lr_scheduler_cfg) if lr_scheduler_cfg else None + fn = partial(LightningCLI.configure_optimizers, optimizer=optimizer, lr_scheduler=lr_scheduler) + + # override the existing method + self.model.configure_optimizers = MethodType(fn, self.model) + + +def main() -> None: + """Trainer via Anomalib CLI.""" + configure_logger() + AnomalibCLI() + + +if __name__ == "__main__": + main() diff --git a/anomalib/cli/install.py b/anomalib/cli/install.py new file mode 100644 index 0000000000000000000000000000000000000000..31432be487c7ea2fa61f3aa1ea2e2c25f781b393 --- /dev/null +++ b/anomalib/cli/install.py @@ -0,0 +1,81 @@ +"""Anomalib install subcommand code.""" + +# Copyright (C) 2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +import logging + +from pkg_resources import Requirement +from rich.console import Console +from rich.logging import RichHandler + +from anomalib.cli.utils.installation import ( + get_requirements, + get_torch_install_args, + parse_requirements, +) + +logger = logging.getLogger("pip") +logger.setLevel(logging.WARNING) # setLevel: CRITICAL, ERROR, WARNING, INFO, DEBUG, NOTSET +console = Console() +handler = RichHandler( + console=console, + show_level=False, + show_path=False, +) +logger.addHandler(handler) + + +def anomalib_install(option: str = "full", verbose: bool = False) -> int: + """Install Anomalib requirements. + + Args: + option (str | None): Optional-dependency to install requirements for. + verbose (bool): Set pip logger level to INFO + + Raises: + ValueError: When the task is not supported. + + Returns: + int: Status code of the pip install command. + """ + from pip._internal.commands import create_command + + requirements_dict = get_requirements("anomalib") + + requirements = [] + if option == "full": + for extra in requirements_dict: + requirements.extend(requirements_dict[extra]) + elif option in requirements_dict: + requirements.extend(requirements_dict[option]) + elif option is not None: + requirements.append(Requirement.parse(option)) + + # Parse requirements into torch and other requirements. + # This is done to parse the correct version of torch (cpu/cuda). + torch_requirement, other_requirements = parse_requirements(requirements, skip_torch=option not in ("full", "core")) + + # Get install args for torch to install it from a specific index-url + install_args: list[str] = [] + torch_install_args = [] + if option in ("full", "core") and torch_requirement is not None: + torch_install_args = get_torch_install_args(torch_requirement) + + # Combine torch and other requirements. + install_args = other_requirements + torch_install_args + + # Install requirements. + with console.status("[bold green]Installing packages... This may take a few minutes.\n") as status: + if verbose: + logger.setLevel(logging.INFO) + status.stop() + console.log(f"Installation list: [yellow]{install_args}[/yellow]") + status_code = create_command("install").main(install_args) + if status_code == 0: + console.log(f"Installation Complete: {install_args}") + + if status_code == 0: + console.print("Anomalib Installation [bold green]Complete.[/bold green]") + + return status_code diff --git a/anomalib/cli/utils/__init__.py b/anomalib/cli/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..028c9727287d47568029d3febe33f33cdc9ee17a --- /dev/null +++ b/anomalib/cli/utils/__init__.py @@ -0,0 +1,8 @@ +"""Anomalib CLI Utils.""" + +# Copyright (C) 2023 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +from .help_formatter import CustomHelpFormatter + +__all__ = ["CustomHelpFormatter"] diff --git a/anomalib/cli/utils/help_formatter.py b/anomalib/cli/utils/help_formatter.py new file mode 100644 index 0000000000000000000000000000000000000000..3b6c89b6a246b8c8a4676750dfcb32e6555fad35 --- /dev/null +++ b/anomalib/cli/utils/help_formatter.py @@ -0,0 +1,268 @@ +"""Custom Help Formatters for Anomalib CLI.""" + +# Copyright (C) 2023 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +import argparse +import re +import sys +from typing import TypeVar + +import docstring_parser +from jsonargparse import DefaultHelpFormatter +from rich.markdown import Markdown +from rich.panel import Panel +from rich_argparse import RichHelpFormatter + +REQUIRED_ARGUMENTS = { + "train": {"model", "model.help", "data", "data.help", "ckpt_path", "config"}, + "fit": {"model", "model.help", "data", "data.help", "ckpt_path", "config"}, + "validate": {"model", "model.help", "data", "data.help", "ckpt_path", "config"}, + "test": {"model", "model.help", "data", "data.help", "ckpt_path", "config"}, + "predict": {"model", "model.help", "data", "data.help", "ckpt_path", "config"}, + "export": {"model", "model.help", "export_type", "ckpt_path", "config"}, +} + +try: + from anomalib.engine import Engine + + DOCSTRING_USAGE = { + "train": Engine.train, + "fit": Engine.fit, + "validate": Engine.validate, + "test": Engine.test, + "predict": Engine.predict, + "export": Engine.export, + } +except ImportError: + print("To use other subcommand using `anomalib install`") + + +def get_short_docstring(component: TypeVar) -> str: + """Get the short description from the docstring. + + Args: + component (TypeVar): The component to get the docstring from + + Returns: + str: The short description + """ + if component.__doc__ is None: + return "" + docstring = docstring_parser.parse(component.__doc__) + return docstring.short_description + + +def get_verbosity_subcommand() -> dict: + """Parse command line arguments and returns a dictionary of key-value pairs. + + Returns: + A dictionary containing the parsed command line arguments. + + Examples: + >>> import sys + >>> sys.argv = ['anomalib', 'train', '-h', '-v'] + >>> get_verbosity_subcommand() + {'subcommand': 'train', 'help': True, 'verbosity': 1} + """ + arguments: dict = {"subcommand": None, "help": False, "verbosity": 2} + if len(sys.argv) >= 2 and sys.argv[1] not in ("--help", "-h"): + arguments["subcommand"] = sys.argv[1] + if "--help" in sys.argv or "-h" in sys.argv: + arguments["help"] = True + if arguments["subcommand"] in REQUIRED_ARGUMENTS: + arguments["verbosity"] = 0 + if "-v" in sys.argv or "--verbose" in sys.argv: + arguments["verbosity"] = 1 + if "-vv" in sys.argv: + arguments["verbosity"] = 2 + return arguments + + +def get_intro() -> Markdown: + """Return a Markdown object containing the introduction text for Anomalib CLI Guide. + + The introduction text includes a brief description of the guide and links to the Github repository and documentation + + Returns: + A Markdown object containing the introduction text for Anomalib CLI Guide. + """ + intro_markdown = ( + "# Anomalib CLI Guide\n\n" + "Github Repository: [https://github.com/openvinotoolkit/anomalib](https://github.com/openvinotoolkit/anomalib)." + "\n\n" + "A better guide is provided by the [documentation](https://anomalib.readthedocs.io/en/latest/index.html)." + ) + return Markdown(intro_markdown) + + +def get_verbose_usage(subcommand: str = "train") -> str: + """Return a string containing verbose usage information for the specified subcommand. + + Args: + ---- + subcommand (str): The name of the subcommand to get verbose usage information for. Defaults to "train". + + Returns: + ------- + str: A string containing verbose usage information for the specified subcommand. + """ + return ( + "To get more overridable argument information, run the command below.\n" + "```python\n" + "# Verbosity Level 1\n" + f"anomalib {subcommand} [optional_arguments] -h -v\n" + "# Verbosity Level 2\n" + f"anomalib {subcommand} [optional_arguments] -h -vv\n" + "```" + ) + + +def get_cli_usage_docstring(component: object | None) -> str | None: + r"""Get the cli usage from the docstring. + + Args: + ---- + component (Optional[object]): The component to get the docstring from + + Returns: + ------- + Optional[str]: The quick-start guide as Markdown format. + + Example: + ------- + component.__doc__ = ''' + + + CLI Usage: + 1. First Step. + 2. Second Step. + + + ''' + >>> get_cli_usage_docstring(component) + "1. First Step.\n2. Second Step." + """ + if component is None or component.__doc__ is None or "CLI Usage" not in component.__doc__: + return None + + pattern = r"CLI Usage:(.*?)(?=\n{2,}|\Z)" + match = re.search(pattern, component.__doc__, re.DOTALL) + + if match: + contents = match.group(1).strip().split("\n") + return "\n".join([content.strip() for content in contents]) + return None + + +def render_guide(subcommand: str | None = None) -> list: + """Render a guide for the specified subcommand. + + Args: + ---- + subcommand (Optional[str]): The subcommand to render the guide for. + + Returns: + ------- + list: A list of contents to be displayed in the guide. + """ + if subcommand is None or subcommand not in DOCSTRING_USAGE: + return [] + contents = [get_intro()] + target_command = DOCSTRING_USAGE[subcommand] + cli_usage = get_cli_usage_docstring(target_command) + if cli_usage is not None: + cli_usage += f"\n{get_verbose_usage(subcommand)}" + quick_start = Panel(Markdown(cli_usage), border_style="dim", title="Quick-Start", title_align="left") + contents.append(quick_start) + return contents + + +class CustomHelpFormatter(RichHelpFormatter, DefaultHelpFormatter): + """A custom help formatter for Anomalib CLI. + + This formatter extends the RichHelpFormatter and DefaultHelpFormatter classes to provide + a more detailed and customizable help output for Anomalib CLI. + + Attributes: + verbosity_level : int + The level of verbosity for the help output. + subcommand : str | None + The subcommand to render the guide for. + + Methods: + add_usage(usage, actions, *args, **kwargs) + Add usage information to the help output. + add_argument(action) + Add an argument to the help output. + format_help() + Format the help output. + """ + + verbosity_dict = get_verbosity_subcommand() + verbosity_level = verbosity_dict["verbosity"] + subcommand = verbosity_dict["subcommand"] + + def add_usage(self, usage: str | None, actions: list, *args, **kwargs) -> None: + """Add usage information to the formatter. + + Args: + ---- + usage (str | None): A string describing the usage of the program. + actions (list): An list of argparse.Action objects. + *args (Any): Additional positional arguments to pass to the superclass method. + **kwargs (Any): Additional keyword arguments to pass to the superclass method. + + Returns: + ------- + None + """ + if self.subcommand in REQUIRED_ARGUMENTS: + if self.verbosity_level == 0: + actions = [] + elif self.verbosity_level == 1: + actions = [action for action in actions if action.dest in REQUIRED_ARGUMENTS[self.subcommand]] + + super().add_usage(usage, actions, *args, **kwargs) + + def add_argument(self, action: argparse.Action) -> None: + """Add an argument to the help formatter. + + If the verbose level is set to 0, the argument is not added. + If the verbose level is set to 1 and the argument is not in the non-skip list, the argument is not added. + + Args: + ---- + action (argparse.Action): The action to add to the help formatter. + """ + if self.subcommand in REQUIRED_ARGUMENTS: + if self.verbosity_level == 0: + return + if self.verbosity_level == 1 and action.dest not in REQUIRED_ARGUMENTS[self.subcommand]: + return + super().add_argument(action) + + def format_help(self) -> str: + """Format the help message for the current command and returns it as a string. + + The help message includes information about the command's arguments and options, + as well as any additional information provided by the command's help guide. + + Returns: + str: A string containing the formatted help message. + """ + with self.console.capture() as capture: + section = self._root_section + if self.subcommand in REQUIRED_ARGUMENTS and self.verbosity_level in (0, 1) and len(section.rich_items) > 1: + contents = render_guide(self.subcommand) + for content in contents: + self.console.print(content) + if self.verbosity_level > 0: + if len(section.rich_items) > 1: + section = Panel(section, border_style="dim", title="Arguments", title_align="left") + self.console.print(section, highlight=False, soft_wrap=True) + help_msg = capture.get() + + if help_msg: + help_msg = self._long_break_matcher.sub("\n\n", help_msg).rstrip() + "\n" + return help_msg diff --git a/anomalib/cli/utils/installation.py b/anomalib/cli/utils/installation.py new file mode 100644 index 0000000000000000000000000000000000000000..df5cd974a6efae071c80a536caf32f6b6ae4035d --- /dev/null +++ b/anomalib/cli/utils/installation.py @@ -0,0 +1,430 @@ +"""Anomalib installation util functions.""" + +# Copyright (C) 2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +import json +import os +import platform +import re +from importlib.metadata import requires +from pathlib import Path +from warnings import warn + +from pkg_resources import Requirement + +AVAILABLE_TORCH_VERSIONS = { + "2.0.0": {"torchvision": "0.15.1", "cuda": ("11.7", "11.8")}, + "2.0.1": {"torchvision": "0.15.2", "cuda": ("11.7", "11.8")}, + "2.1.1": {"torchvision": "0.16.1", "cuda": ("11.8", "12.1")}, + "2.1.2": {"torchvision": "0.16.2", "cuda": ("11.8", "12.1")}, + "2.2.0": {"torchvision": "0.16.2", "cuda": ("11.8", "12.1")}, +} + + +def get_requirements(module: str = "anomalib") -> dict[str, list[Requirement]]: + """Get requirements of module from importlib.metadata. + + This function returns list of required packages from importlib_metadata. + + Example: + >>> get_requirements("anomalib") + { + "base": ["jsonargparse==4.27.1", ...], + "core": ["torch==2.1.1", ...], + ... + } + + Returns: + dict[str, list[Requirement]]: List of required packages for each optional-extras. + """ + requirement_list: list[str] | None = requires(module) + extra_requirement: dict[str, list[Requirement]] = {} + if requirement_list is None: + return extra_requirement + for requirement in requirement_list: + extra = "core" + requirement_extra: list[str] = requirement.replace(" ", "").split(";") + if isinstance(requirement_extra, list) and len(requirement_extra) > 1: + extra = requirement_extra[-1].split("==")[-1].strip("'\"") + _requirement_name = requirement_extra[0] + _requirement = Requirement.parse(_requirement_name) + if extra in extra_requirement: + extra_requirement[extra].append(_requirement) + else: + extra_requirement[extra] = [_requirement] + return extra_requirement + + +def parse_requirements( + requirements: list[Requirement], + skip_torch: bool = False, +) -> tuple[str | None, list[str]]: + """Parse requirements and returns torch and other requirements. + + Args: + requirements (list[Requirement]): List of requirements. + skip_torch (bool): Whether to skip torch requirement. Defaults to False. + + Raises: + ValueError: If torch requirement is not found. + + Examples: + >>> requirements = [ + ... Requirement.parse("torch==1.13.0"), + ... Requirement.parse("onnx>=1.8.1"), + ... ] + >>> parse_requirements(requirements=requirements) + (Requirement.parse("torch==1.13.0"), + Requirement.parse("onnx>=1.8.1")) + + Returns: + tuple[str, list[str], list[str]]: Tuple of torch and other requirements. + """ + torch_requirement: str | None = None + other_requirements: list[str] = [] + + for requirement in requirements: + if requirement.unsafe_name == "torch": + torch_requirement = str(requirement) + if len(requirement.specs) > 1: + warn( + "requirements.txt contains. Please remove other versions of torch from requirements.", + stacklevel=2, + ) + + # Rest of the requirements are task requirements. + # Other torch-related requirements such as `torchvision` are to be excluded. + # This is because torch-related requirements are already handled in torch_requirement. + else: + # if not requirement.unsafe_name.startswith("torch"): + other_requirements.append(str(requirement)) + + if not skip_torch and not torch_requirement: + msg = "Could not find torch requirement. Anoamlib depends on torch. Please add torch to your requirements." + raise ValueError(msg) + + # Get the unique list of the requirements. + other_requirements = list(set(other_requirements)) + + return torch_requirement, other_requirements + + +def get_cuda_version() -> str | None: + """Get CUDA version installed on the system. + + Examples: + >>> # Assume that CUDA version is 11.2 + >>> get_cuda_version() + "11.2" + + >>> # Assume that CUDA is not installed on the system + >>> get_cuda_version() + None + + Returns: + str | None: CUDA version installed on the system. + """ + # 1. Check CUDA_HOME Environment variable + cuda_home = os.environ.get("CUDA_HOME", "/usr/local/cuda") + + if Path(cuda_home).exists(): + # Check $CUDA_HOME/version.json file. + version_file = Path(cuda_home) / "version.json" + if version_file.is_file(): + with Path(version_file).open() as file: + data = json.load(file) + cuda_version = data.get("cuda", {}).get("version", None) + if cuda_version is not None: + cuda_version_parts = cuda_version.split(".") + return ".".join(cuda_version_parts[:2]) + # 2. 'nvcc --version' check & without version.json case + try: + result = os.popen(cmd="nvcc --version") + output = result.read() + + cuda_version_pattern = r"cuda_(\d+\.\d+)" + cuda_version_match = re.search(cuda_version_pattern, output) + + if cuda_version_match is not None: + return cuda_version_match.group(1) + except OSError: + msg = "Could not find cuda-version. Instead, the CPU version of torch will be installed." + warn(msg, stacklevel=2) + return None + + +def update_cuda_version_with_available_torch_cuda_build(cuda_version: str, torch_version: str) -> str: + """Update the installed CUDA version with the highest supported CUDA version by PyTorch. + + Args: + cuda_version (str): The installed CUDA version. + torch_version (str): The PyTorch version. + + Raises: + Warning: If the installed CUDA version is not supported by PyTorch. + + Examples: + >>> update_cuda_version_with_available_torch_cuda_builds("11.1", "1.13.0") + "11.6" + + >>> update_cuda_version_with_available_torch_cuda_builds("11.7", "1.13.0") + "11.7" + + >>> update_cuda_version_with_available_torch_cuda_builds("11.8", "1.13.0") + "11.7" + + >>> update_cuda_version_with_available_torch_cuda_builds("12.1", "2.0.1") + "11.8" + + Returns: + str: The updated CUDA version. + """ + max_supported_cuda = max(AVAILABLE_TORCH_VERSIONS[torch_version]["cuda"]) + min_supported_cuda = min(AVAILABLE_TORCH_VERSIONS[torch_version]["cuda"]) + bounded_cuda_version = max(min(cuda_version, max_supported_cuda), min_supported_cuda) + + if cuda_version != bounded_cuda_version: + warn( + f"Installed CUDA version is v{cuda_version}. \n" + f"v{min_supported_cuda} <= Supported CUDA version <= v{max_supported_cuda}.\n" + f"This script will use CUDA v{bounded_cuda_version}.\n" + f"However, this may not be safe, and you are advised to install the correct version of CUDA.\n" + f"For more details, refer to https://pytorch.org/get-started/locally/", + stacklevel=2, + ) + cuda_version = bounded_cuda_version + + return cuda_version + + +def get_cuda_suffix(cuda_version: str) -> str: + """Get CUDA suffix for PyTorch versions. + + Args: + cuda_version (str): CUDA version installed on the system. + + Note: + The CUDA version of PyTorch is not always the same as the CUDA version + that is installed on the system. For example, the latest PyTorch + version (1.10.0) supports CUDA 11.3, but the latest CUDA version + that is available for download is 11.2. Therefore, we need to use + the latest available CUDA version for PyTorch instead of the CUDA + version that is installed on the system. Therefore, this function + shoudl be regularly updated to reflect the latest available CUDA. + + Examples: + >>> get_cuda_suffix(cuda_version="11.2") + "cu112" + + >>> get_cuda_suffix(cuda_version="11.8") + "cu118" + + Returns: + str: CUDA suffix for PyTorch or mmX version. + """ + return f"cu{cuda_version.replace('.', '')}" + + +def get_hardware_suffix(with_available_torch_build: bool = False, torch_version: str | None = None) -> str: + """Get hardware suffix for PyTorch or mmX versions. + + Args: + with_available_torch_build (bool): Whether to use the latest available + PyTorch build or not. If True, the latest available PyTorch build + will be used. If False, the installed PyTorch build will be used. + Defaults to False. + torch_version (str | None): PyTorch version. This is only used when the + ``with_available_torch_build`` is True. + + Examples: + >>> # Assume that CUDA version is 11.2 + >>> get_hardware_suffix() + "cu112" + + >>> # Assume that CUDA is not installed on the system + >>> get_hardware_suffix() + "cpu" + + Assume that that installed CUDA version is 12.1. + However, the latest available CUDA version for PyTorch v2.0 is 11.8. + Therefore, we use 11.8 instead of 12.1. This is because PyTorch does not + support CUDA 12.1 yet. In this case, we could correct the CUDA version + by setting `with_available_torch_build` to True. + + >>> cuda_version = get_cuda_version() + "12.1" + >>> get_hardware_suffix(with_available_torch_build=True, torch_version="2.0.1") + "cu118" + + Returns: + str: Hardware suffix for PyTorch or mmX version. + """ + cuda_version = get_cuda_version() + if cuda_version: + if with_available_torch_build: + if torch_version is None: + msg = "``torch_version`` must be provided when with_available_torch_build is True." + raise ValueError(msg) + cuda_version = update_cuda_version_with_available_torch_cuda_build(cuda_version, torch_version) + hardware_suffix = get_cuda_suffix(cuda_version) + else: + hardware_suffix = "cpu" + + return hardware_suffix + + +def add_hardware_suffix_to_torch( + requirement: Requirement, + hardware_suffix: str | None = None, + with_available_torch_build: bool = False, +) -> str: + """Add hardware suffix to the torch requirement. + + Args: + requirement (Requirement): Requirement object comprising requirement + details. + hardware_suffix (str | None): Hardware suffix. If None, it will be set + to the correct hardware suffix. Defaults to None. + with_available_torch_build (bool): To check whether the installed + CUDA version is supported by the latest available PyTorch build. + Defaults to False. + + Examples: + >>> from pkg_resources import Requirement + >>> req = "torch>=1.13.0, <=2.0.1" + >>> requirement = Requirement.parse(req) + >>> requirement.name, requirement.specs + ('torch', [('>=', '1.13.0'), ('<=', '2.0.1')]) + + >>> add_hardware_suffix_to_torch(requirement) + 'torch>=1.13.0+cu121, <=2.0.1+cu121' + + ``with_available_torch_build=True`` will use the latest available PyTorch build. + >>> req = "torch==2.0.1" + >>> requirement = Requirement.parse(req) + >>> add_hardware_suffix_to_torch(requirement, with_available_torch_build=True) + 'torch==2.0.1+cu118' + + It is possible to pass the ``hardware_suffix`` manually. + >>> req = "torch==2.0.1" + >>> requirement = Requirement.parse(req) + >>> add_hardware_suffix_to_torch(requirement, hardware_suffix="cu121") + 'torch==2.0.1+cu111' + + Raises: + ValueError: When the requirement has more than two version criterion. + + Returns: + str: Updated torch package with the right cuda suffix. + """ + name = requirement.unsafe_name + updated_specs: list[str] = [] + + for operator, version in requirement.specs: + hardware_suffix = hardware_suffix or get_hardware_suffix(with_available_torch_build, version) + updated_version = version + f"+{hardware_suffix}" if not version.startswith(("2.1", "2.2")) else version + + # ``specs`` contains operators and versions as follows: + # These are to be concatenated again for the updated version. + updated_specs.append(operator + updated_version) + + updated_requirement: str = "" + + if updated_specs: + # This is the case when specs are e.g. ['<=1.9.1+cu111'] + if len(updated_specs) == 1: + updated_requirement = name + updated_specs[0] + # This is the case when specs are e.g., ['<=1.9.1+cu111', '>=1.8.1+cu111'] + elif len(updated_specs) == 2: + updated_requirement = name + updated_specs[0] + ", " + updated_specs[1] + else: + msg = ( + "Requirement version can be a single value or a range. \n" + "For example it could be torch>=1.8.1 " + "or torch>=1.8.1, <=1.9.1\n" + f"Got {updated_specs} instead." + ) + raise ValueError(msg) + return updated_requirement + + +def get_torch_install_args(requirement: str | Requirement) -> list[str]: + """Get the install arguments for Torch requirement. + + This function will return the install arguments for the Torch requirement + and its corresponding torchvision requirement. + + Args: + requirement (str | Requirement): The torch requirement. + + Raises: + RuntimeError: If the OS is not supported. + + Example: + >>> from pkg_resources import Requirement + >>> requriment = "torch>=1.13.0" + >>> get_torch_install_args(requirement) + ['--extra-index-url', 'https://download.pytorch.org/whl/cpu', + 'torch==1.13.0+cpu', 'torchvision==0.14.0+cpu'] + + Returns: + list[str]: The install arguments. + """ + if isinstance(requirement, str): + requirement = Requirement.parse(requirement) + + # NOTE: This does not take into account if the requirement has multiple versions + # such as torch<2.0.1,>=1.13.0 + if len(requirement.specs) < 1: + return [str(requirement)] + select_spec_idx = 0 + for i, spec in enumerate(requirement.specs): + if "=" in spec[0]: + select_spec_idx = i + break + operator, version = requirement.specs[select_spec_idx] + if version not in AVAILABLE_TORCH_VERSIONS: + version = max(AVAILABLE_TORCH_VERSIONS.keys()) + warn( + f"Torch Version will be selected as {version}.", + stacklevel=2, + ) + install_args: list[str] = [] + + if platform.system() in ("Linux", "Windows"): + # Get the hardware suffix (eg., +cpu, +cu116 and +cu118 etc.) + hardware_suffix = get_hardware_suffix(with_available_torch_build=True, torch_version=version) + + # Create the PyTorch Index URL to download the correct wheel. + index_url = f"https://download.pytorch.org/whl/{hardware_suffix}" + + # Create the PyTorch version depending on the CUDA version. For example, + # If CUDA version is 11.2, then the PyTorch version is 1.8.0+cu112. + # If CUDA version is None, then the PyTorch version is 1.8.0+cpu. + torch_version = add_hardware_suffix_to_torch(requirement, hardware_suffix, with_available_torch_build=True) + + # Get the torchvision version depending on the torch version. + torchvision_version = AVAILABLE_TORCH_VERSIONS[version]["torchvision"] + torchvision_requirement = f"torchvision{operator}{torchvision_version}" + if isinstance(torchvision_version, str) and not torchvision_version.startswith("0.16"): + torchvision_requirement += f"+{hardware_suffix}" + + # Return the install arguments. + install_args += [ + "--extra-index-url", + # "--index-url", + index_url, + torch_version, + torchvision_requirement, + ] + elif platform.system() in ("macos", "Darwin"): + torch_version = str(requirement) + install_args += [torch_version] + else: + msg = f"Unsupported OS: {platform.system()}" + raise RuntimeError(msg) + + return install_args diff --git a/anomalib/cli/utils/openvino.py b/anomalib/cli/utils/openvino.py new file mode 100644 index 0000000000000000000000000000000000000000..65ac7b80db1e332a8e6f7235b1dbefb2a352b682 --- /dev/null +++ b/anomalib/cli/utils/openvino.py @@ -0,0 +1,32 @@ +"""Utils for OpenVINO parser.""" + +# Copyright (C) 2023 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +import logging + +from jsonargparse import ArgumentParser + +from anomalib.utils.exceptions import try_import + +logger = logging.getLogger(__name__) + + +if try_import("openvino"): + from openvino.tools.ovc.cli_parser import get_common_cli_parser +else: + get_common_cli_parser = None + + +def add_openvino_export_arguments(parser: ArgumentParser) -> None: + """Add OpenVINO arguments to parser under --mo key.""" + if get_common_cli_parser is not None: + group = parser.add_argument_group("OpenVINO Model Optimizer arguments (optional)") + ov_parser = get_common_cli_parser() + # remove redundant keys from mo keys + for arg in ov_parser._actions: # noqa: SLF001 + if arg.dest in ("help", "input_model", "output_dir"): + continue + group.add_argument(f"--ov_args.{arg.dest}", type=arg.type, default=arg.default, help=arg.help) + else: + logger.info("OpenVINO is possibly not installed in the environment. Skipping adding it to parser.") diff --git a/anomalib/data/__init__.py b/anomalib/data/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..9a58f5e58945d236ac64e554e5967722d656c526 --- /dev/null +++ b/anomalib/data/__init__.py @@ -0,0 +1,72 @@ +"""Anomalib Datasets.""" + +# Copyright (C) 2022-2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + + +import importlib +import logging +from enum import Enum +from itertools import chain + +from omegaconf import DictConfig, ListConfig + +from anomalib.utils.config import to_tuple + +from .base import AnomalibDataModule, AnomalibDataset +from .depth import DepthDataFormat, Folder3D, MVTec3D +from .image import BTech, Folder, ImageDataFormat, Kolektor, MVTec, MVTecLoco, Visa +from .predict import PredictDataset +from .utils import LabelName +from .video import Avenue, ShanghaiTech, UCSDped, VideoDataFormat + +logger = logging.getLogger(__name__) + + +DataFormat = Enum( # type: ignore[misc] + "DataFormat", + {i.name: i.value for i in chain(DepthDataFormat, ImageDataFormat, VideoDataFormat)}, +) + + +def get_datamodule(config: DictConfig | ListConfig) -> AnomalibDataModule: + """Get Anomaly Datamodule. + + Args: + config (DictConfig | ListConfig): Configuration of the anomaly model. + + Returns: + PyTorch Lightning DataModule + """ + logger.info("Loading the datamodule") + + module = importlib.import_module(".".join(config.data.class_path.split(".")[:-1])) + dataclass = getattr(module, config.data.class_path.split(".")[-1]) + init_args = {**config.data.get("init_args", {})} # get dict + if "image_size" in init_args: + init_args["image_size"] = to_tuple(init_args["image_size"]) + + return dataclass(**init_args) + + +__all__ = [ + "AnomalibDataset", + "AnomalibDataModule", + "DepthDataFormat", + "ImageDataFormat", + "VideoDataFormat", + "get_datamodule", + "BTech", + "Folder", + "Folder3D", + "PredictDataset", + "Kolektor", + "MVTec", + "MVTec3D", + "MVTecLoco", + "Avenue", + "UCSDped", + "ShanghaiTech", + "Visa", + "LabelName", +] diff --git a/anomalib/data/base/__init__.py b/anomalib/data/base/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..00d67a7ea3bb6fa88ea42521d953398360cffa91 --- /dev/null +++ b/anomalib/data/base/__init__.py @@ -0,0 +1,18 @@ +"""Base classes for custom dataset and datamodules.""" + +# Copyright (C) 2022 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + + +from .datamodule import AnomalibDataModule +from .dataset import AnomalibDataset +from .depth import AnomalibDepthDataset +from .video import AnomalibVideoDataModule, AnomalibVideoDataset + +__all__ = [ + "AnomalibDataset", + "AnomalibDataModule", + "AnomalibVideoDataset", + "AnomalibVideoDataModule", + "AnomalibDepthDataset", +] diff --git a/anomalib/data/base/datamodule.py b/anomalib/data/base/datamodule.py new file mode 100644 index 0000000000000000000000000000000000000000..bc5063d8ab9870844b0f7c4347eeca9579761283 --- /dev/null +++ b/anomalib/data/base/datamodule.py @@ -0,0 +1,305 @@ +"""Anomalib datamodule base class.""" + +# Copyright (C) 2022-2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + + +import logging +from abc import ABC, abstractmethod +from typing import TYPE_CHECKING, Any + +from lightning.pytorch import LightningDataModule +from lightning.pytorch.trainer.states import TrainerFn +from lightning.pytorch.utilities.types import EVAL_DATALOADERS, TRAIN_DATALOADERS +from torch.utils.data.dataloader import DataLoader, default_collate +from torchvision.transforms.v2 import Resize, Transform + +from anomalib.data.utils import TestSplitMode, ValSplitMode, random_split, split_by_label +from anomalib.data.utils.synthetic import SyntheticAnomalyDataset + +if TYPE_CHECKING: + from pandas import DataFrame + + from anomalib.data.base.dataset import AnomalibDataset + +logger = logging.getLogger(__name__) + + +def collate_fn(batch: list) -> dict[str, Any]: + """Collate bounding boxes as lists. + + Bounding boxes and `masks` (not `mask`) are collated as a list of tensors. If `masks` exists, + the `mask_path` is also collated as a list since each element in the batch could be unequal. + For all other entries, the default collate function is used. + + Args: + batch (List): list of items in the batch where len(batch) is equal to the batch size. + + Returns: + dict[str, Any]: Dictionary containing the collated batch information. + """ + elem = batch[0] # sample an element from the batch to check the type. + out_dict = {} + if isinstance(elem, dict): + if "boxes" in elem: + # collate boxes as list + out_dict["boxes"] = [item.pop("boxes") for item in batch] + if "semantic_mask" in elem: + # semantic masks have a variable number of channels, so we collate them as a list + out_dict["semantic_mask"] = [item.pop("semantic_mask") for item in batch] + if "mask_path" in elem and isinstance(elem["mask_path"], list): + # collate mask paths as list + out_dict["mask_path"] = [item.pop("mask_path") for item in batch] + # collate other data normally + out_dict.update({key: default_collate([item[key] for item in batch]) for key in elem}) + return out_dict + return default_collate(batch) + + +class AnomalibDataModule(LightningDataModule, ABC): + """Base Anomalib data module. + + Args: + train_batch_size (int): Batch size used by the train dataloader. + eval_batch_size (int): Batch size used by the val and test dataloaders. + num_workers (int): Number of workers used by the train, val and test dataloaders. + val_split_mode (ValSplitMode): Determines how the validation split is obtained. + Options: [none, same_as_test, from_test, synthetic] + val_split_ratio (float): Fraction of the train or test images held our for validation. + test_split_mode (Optional[TestSplitMode], optional): Determines how the test split is obtained. + Options: [none, from_dir, synthetic]. + Defaults to ``None``. + test_split_ratio (float): Fraction of the train images held out for testing. + Defaults to ``None``. + image_size (tuple[int, int], optional): Size to which input images should be resized. + Defaults to ``None``. + transform (Transform, optional): Transforms that should be applied to the input images. + Defaults to ``None``. + train_transform (Transform, optional): Transforms that should be applied to the input images during training. + Defaults to ``None``. + eval_transform (Transform, optional): Transforms that should be applied to the input images during evaluation. + Defaults to ``None``. + seed (int | None, optional): Seed used during random subset splitting. + Defaults to ``None``. + """ + + def __init__( + self, + train_batch_size: int, + eval_batch_size: int, + num_workers: int, + val_split_mode: ValSplitMode | str, + val_split_ratio: float, + test_split_mode: TestSplitMode | str | None = None, + test_split_ratio: float | None = None, + image_size: tuple[int, int] | None = None, + transform: Transform | None = None, + train_transform: Transform | None = None, + eval_transform: Transform | None = None, + seed: int | None = None, + ) -> None: + super().__init__() + self.train_batch_size = train_batch_size + self.eval_batch_size = eval_batch_size + self.num_workers = num_workers + self.test_split_mode = TestSplitMode(test_split_mode) if test_split_mode else TestSplitMode.NONE + self.test_split_ratio = test_split_ratio + self.val_split_mode = ValSplitMode(val_split_mode) + self.val_split_ratio = val_split_ratio + self.image_size = image_size + self.seed = seed + + # set transforms + if bool(train_transform) != bool(eval_transform): + msg = "Only one of train_transform and eval_transform was specified. This is not recommended because \ + it could lead to unexpected behaviour. Please ensure training and eval transforms have the same \ + reshape and normalization characteristics." + logger.warning(msg) + self._train_transform = train_transform or transform + self._eval_transform = eval_transform or transform + + self.train_data: AnomalibDataset + self.val_data: AnomalibDataset + self.test_data: AnomalibDataset + + self._samples: DataFrame | None = None + self._category: str = "" + + self._is_setup = False # flag to track if setup has been called from the trainer + + @property + def name(self) -> str: + """Name of the datamodule.""" + return self.__class__.__name__ + + def setup(self, stage: str | None = None) -> None: + """Set up train, validation and test data. + + Args: + stage: str | None: Train/Val/Test stages. + Defaults to ``None``. + """ + has_subset = any(hasattr(self, subset) for subset in ["train_data", "val_data", "test_data"]) + if not has_subset or not self._is_setup: + self._setup(stage) + self._create_test_split() + self._create_val_split() + if isinstance(stage, TrainerFn): + # only set the flag if the stage is a TrainerFn, which means the setup has been called from a trainer + self._is_setup = True + + @abstractmethod + def _setup(self, _stage: str | None = None) -> None: + """Set up the datasets and perform dynamic subset splitting. + + This method may be overridden in subclass for custom splitting behaviour. + + Note: + The stage argument is not used here. This is because, for a given instance of an AnomalibDataModule + subclass, all three subsets are created at the first call of setup(). This is to accommodate the subset + splitting behaviour of anomaly tasks, where the validation set is usually extracted from the test set, and + the test set must therefore be created as early as the `fit` stage. + + """ + raise NotImplementedError + + @property + def category(self) -> str: + """Get the category of the datamodule.""" + return self._category + + @category.setter + def category(self, category: str) -> None: + """Set the category of the datamodule.""" + self._category = category + + def _create_test_split(self) -> None: + """Obtain the test set based on the settings in the config.""" + if self.test_data.has_normal: + # split the test data into normal and anomalous so these can be processed separately + normal_test_data, self.test_data = split_by_label(self.test_data) + elif self.test_split_mode != TestSplitMode.NONE: + # when the user did not provide any normal images for testing, we sample some from the training set, + # except when the user explicitly requested no test splitting. + logger.info( + "No normal test images found. Sampling from training set using a split ratio of %0.2f", + self.test_split_ratio, + ) + if self.test_split_ratio is not None: + self.train_data, normal_test_data = random_split(self.train_data, self.test_split_ratio, seed=self.seed) + + if self.test_split_mode == TestSplitMode.FROM_DIR: + self.test_data += normal_test_data + elif self.test_split_mode == TestSplitMode.SYNTHETIC: + self.test_data = SyntheticAnomalyDataset.from_dataset(normal_test_data) + elif self.test_split_mode != TestSplitMode.NONE: + msg = f"Unsupported Test Split Mode: {self.test_split_mode}" + raise ValueError(msg) + + def _create_val_split(self) -> None: + """Obtain the validation set based on the settings in the config.""" + if self.val_split_mode == ValSplitMode.FROM_TRAIN: + # randomly sampled from train set + self.train_data, self.val_data = random_split( + self.train_data, + self.val_split_ratio, + label_aware=True, + seed=self.seed, + ) + elif self.val_split_mode == ValSplitMode.FROM_TEST: + # randomly sampled from test set + self.test_data, self.val_data = random_split( + self.test_data, + self.val_split_ratio, + label_aware=True, + seed=self.seed, + ) + elif self.val_split_mode == ValSplitMode.SAME_AS_TEST: + # equal to test set + self.val_data = self.test_data + elif self.val_split_mode == ValSplitMode.SYNTHETIC: + # converted from random training sample + self.train_data, normal_val_data = random_split(self.train_data, self.val_split_ratio, seed=self.seed) + self.val_data = SyntheticAnomalyDataset.from_dataset(normal_val_data) + elif self.val_split_mode == ValSplitMode.FROM_DIR: + # the val_data is prepared in subclass + assert hasattr( + self, + "val_data", + ), f"FROM_DIR is not supported for {self.__class__.__name__} which does not assign val_data in _setup." + elif self.val_split_mode != ValSplitMode.NONE: + msg = f"Unknown validation split mode: {self.val_split_mode}" + raise ValueError(msg) + + def train_dataloader(self) -> TRAIN_DATALOADERS: + """Get train dataloader.""" + return DataLoader( + dataset=self.train_data, + shuffle=True, + batch_size=self.train_batch_size, + num_workers=self.num_workers, + ) + + def val_dataloader(self) -> EVAL_DATALOADERS: + """Get validation dataloader.""" + return DataLoader( + dataset=self.val_data, + shuffle=False, + batch_size=self.eval_batch_size, + num_workers=self.num_workers, + collate_fn=collate_fn, + ) + + def test_dataloader(self) -> EVAL_DATALOADERS: + """Get test dataloader.""" + return DataLoader( + dataset=self.test_data, + shuffle=False, + batch_size=self.eval_batch_size, + num_workers=self.num_workers, + collate_fn=collate_fn, + ) + + def predict_dataloader(self) -> EVAL_DATALOADERS: + """Use the test dataloader for inference unless overridden.""" + return self.test_dataloader() + + @property + def transform(self) -> Transform: + """Property that returns the user-specified transform for the datamodule, if any. + + This property is accessed by the engine to set the transform for the model. The eval_transform takes precedence + over the train_transform, because the transform that we store in the model is the one that should be used during + inference. + """ + if self._eval_transform: + return self._eval_transform + return None + + @property + def train_transform(self) -> Transform: + """Get the transforms that will be passed to the train dataset. + + If the train_transform is not set, the engine will request the transform from the model. + """ + if self._train_transform: + return self._train_transform + if getattr(self, "trainer", None) and self.trainer.model and self.trainer.model.transform: + return self.trainer.model.transform + if self.image_size: + return Resize(self.image_size, antialias=True) + return None + + @property + def eval_transform(self) -> Transform: + """Get the transform that will be passed to the val/test/predict datasets. + + If the eval_transform is not set, the engine will request the transform from the model. + """ + if self._eval_transform: + return self._eval_transform + if getattr(self, "trainer", None) and self.trainer.model and self.trainer.model.transform: + return self.trainer.model.transform + if self.image_size: + return Resize(self.image_size, antialias=True) + return None diff --git a/anomalib/data/base/dataset.py b/anomalib/data/base/dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..7cfba278ac6f2eb9fb5eef4e749cf990cda86635 --- /dev/null +++ b/anomalib/data/base/dataset.py @@ -0,0 +1,208 @@ +"""Anomalib dataset base class.""" + +# Copyright (C) 2022-2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +import copy +import logging +from abc import ABC +from collections.abc import Sequence +from pathlib import Path + +import pandas as pd +import torch +from pandas import DataFrame +from torch.utils.data import Dataset +from torchvision.transforms.v2 import Transform +from torchvision.tv_tensors import Mask + +from anomalib import TaskType +from anomalib.data.utils import LabelName, masks_to_boxes, read_image, read_mask + +_EXPECTED_COLUMNS_CLASSIFICATION = ["image_path", "split"] +_EXPECTED_COLUMNS_SEGMENTATION = [*_EXPECTED_COLUMNS_CLASSIFICATION, "mask_path"] +_EXPECTED_COLUMNS_PERTASK = { + "classification": _EXPECTED_COLUMNS_CLASSIFICATION, + "segmentation": _EXPECTED_COLUMNS_SEGMENTATION, + "detection": _EXPECTED_COLUMNS_SEGMENTATION, +} + +logger = logging.getLogger(__name__) + + +class AnomalibDataset(Dataset, ABC): + """Anomalib dataset. + + The dataset is based on a dataframe that contains the information needed by the dataloader to load each of + the dataset items into memory. + + The samples dataframe must be set from the subclass using the setter of the `samples` property. + + The DataFrame must, at least, include the following columns: + - `split` (str): The subset to which the dataset item is assigned (e.g., 'train', 'test'). + - `image_path` (str): Path to the file system location where the image is stored. + - `label_index` (int): Index of the anomaly label, typically 0 for 'normal' and 1 for 'anomalous'. + - `mask_path` (str, optional): Path to the ground truth masks (for the anomalous images only). + Required if task is 'segmentation'. + + Example DataFrame: + +---+-------------------+-----------+-------------+------------------+-------+ + | | image_path | label | label_index | mask_path | split | + +---+-------------------+-----------+-------------+------------------+-------+ + | 0 | path/to/image.png | anomalous | 1 | path/to/mask.png | train | + +---+-------------------+-----------+-------------+------------------+-------+ + + Note: + The example above is illustrative and may need to be adjusted based on the specific dataset structure. + + Args: + task (str): Task type, either 'classification' or 'segmentation' + transform (Transform, optional): Transforms that should be applied to the input images. + Defaults to ``None``. + """ + + def __init__(self, task: TaskType | str, transform: Transform | None = None) -> None: + super().__init__() + self.task = TaskType(task) + self.transform = transform + self._samples: DataFrame | None = None + self._category: str | None = None + + @property + def name(self) -> str: + """Name of the dataset.""" + class_name = self.__class__.__name__ + + # Remove the `_dataset` suffix from the class name + if class_name.endswith("Dataset"): + class_name = class_name[:-7] + + return class_name + + def __len__(self) -> int: + """Get length of the dataset.""" + return len(self.samples) + + def subsample(self, indices: Sequence[int], inplace: bool = False) -> "AnomalibDataset": + """Subsamples the dataset at the provided indices. + + Args: + indices (Sequence[int]): Indices at which the dataset is to be subsampled. + inplace (bool): When true, the subsampling will be performed on the instance itself. + Defaults to ``False``. + """ + if len(set(indices)) != len(indices): + msg = "No duplicates allowed in indices." + raise ValueError(msg) + dataset = self if inplace else copy.deepcopy(self) + dataset.samples = self.samples.iloc[indices].reset_index(drop=True) + return dataset + + @property + def samples(self) -> DataFrame: + """Get the samples dataframe.""" + if self._samples is None: + msg = ( + "Dataset does not have a samples dataframe. Ensure that a dataframe has been assigned to " + "`dataset.samples`." + ) + raise RuntimeError(msg) + return self._samples + + @samples.setter + def samples(self, samples: DataFrame) -> None: + """Overwrite the samples with a new dataframe. + + Args: + samples (DataFrame): DataFrame with new samples. + """ + # validate the passed samples by checking the + if not isinstance(samples, DataFrame): + msg = f"samples must be a pandas.DataFrame, found {type(samples)}" + raise TypeError(msg) + + expected_columns = _EXPECTED_COLUMNS_PERTASK[self.task] + if not all(col in samples.columns for col in expected_columns): + msg = f"samples must have (at least) columns {expected_columns}, found {samples.columns}" + raise ValueError(msg) + + if not samples["image_path"].apply(lambda p: Path(p).exists()).all(): + msg = "missing file path(s) in samples" + raise FileNotFoundError(msg) + + self._samples = samples.sort_values(by="image_path", ignore_index=True) + + @property + def category(self) -> str | None: + """Get the category of the dataset.""" + return self._category + + @category.setter + def category(self, category: str) -> None: + """Set the category of the dataset.""" + self._category = category + + @property + def has_normal(self) -> bool: + """Check if the dataset contains any normal samples.""" + return LabelName.NORMAL in list(self.samples.label_index) + + @property + def has_anomalous(self) -> bool: + """Check if the dataset contains any anomalous samples.""" + return LabelName.ABNORMAL in list(self.samples.label_index) + + def __getitem__(self, index: int) -> dict[str, str | torch.Tensor]: + """Get dataset item for the index ``index``. + + Args: + index (int): Index to get the item. + + Returns: + dict[str, str | torch.Tensor]: Dict of image tensor during training. Otherwise, Dict containing image path, + target path, image tensor, label and transformed bounding box. + """ + image_path = self.samples.iloc[index].image_path + mask_path = self.samples.iloc[index].mask_path + label_index = self.samples.iloc[index].label_index + + image = read_image(image_path, as_tensor=True) + item = {"image_path": image_path, "label": label_index} + + if self.task == TaskType.CLASSIFICATION: + item["image"] = self.transform(image) if self.transform else image + elif self.task in (TaskType.DETECTION, TaskType.SEGMENTATION): + # Only Anomalous (1) images have masks in anomaly datasets + # Therefore, create empty mask for Normal (0) images. + mask = ( + Mask(torch.zeros(image.shape[-2:])).to(torch.uint8) + if label_index == LabelName.NORMAL + else read_mask(mask_path, as_tensor=True) + ) + item["image"], item["mask"] = self.transform(image, mask) if self.transform else (image, mask) + + if self.task == TaskType.DETECTION: + # create boxes from masks for detection task + boxes, _ = masks_to_boxes(item["mask"]) + item["boxes"] = boxes[0] + else: + msg = f"Unknown task type: {self.task}" + raise ValueError(msg) + + return item + + def __add__(self, other_dataset: "AnomalibDataset") -> "AnomalibDataset": + """Concatenate this dataset with another dataset. + + Args: + other_dataset (AnomalibDataset): Dataset to concatenate with. + + Returns: + AnomalibDataset: Concatenated dataset. + """ + if not isinstance(other_dataset, self.__class__): + msg = "Cannot concatenate datasets that are not of the same type." + raise TypeError(msg) + dataset = copy.deepcopy(self) + dataset.samples = pd.concat([self.samples, other_dataset.samples], ignore_index=True) + return dataset diff --git a/anomalib/data/base/depth.py b/anomalib/data/base/depth.py new file mode 100644 index 0000000000000000000000000000000000000000..dbd5377cb63b475ac690bb19675c3772114ab1c9 --- /dev/null +++ b/anomalib/data/base/depth.py @@ -0,0 +1,76 @@ +"""Base Depth Dataset.""" + +# Copyright (C) 2023-2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +from abc import ABC + +import torch +from PIL import Image +from torchvision.transforms.functional import to_tensor +from torchvision.transforms.v2 import Transform +from torchvision.tv_tensors import Mask + +from anomalib import TaskType +from anomalib.data.base.dataset import AnomalibDataset +from anomalib.data.utils import LabelName, masks_to_boxes, read_depth_image + + +class AnomalibDepthDataset(AnomalibDataset, ABC): + """Base depth anomalib dataset class. + + Args: + task (str): Task type, either 'classification' or 'segmentation' + transform (Transform, optional): Transforms that should be applied to the input images. + Defaults to ``None``. + """ + + def __init__(self, task: TaskType, transform: Transform | None = None) -> None: + super().__init__(task, transform) + + self.transform = transform + + def __getitem__(self, index: int) -> dict[str, str | torch.Tensor]: + """Return rgb image, depth image and mask. + + Args: + index (int): Index of the item to be returned. + + Returns: + dict[str, str | torch.Tensor]: Dictionary containing the image, depth image and mask. + """ + image_path = self.samples.iloc[index].image_path + mask_path = self.samples.iloc[index].mask_path + label_index = self.samples.iloc[index].label_index + depth_path = self.samples.iloc[index].depth_path + + image = to_tensor(Image.open(image_path)) + depth_image = to_tensor(read_depth_image(depth_path)) + item = {"image_path": image_path, "depth_path": depth_path, "label": label_index} + + if self.task == TaskType.CLASSIFICATION: + item["image"], item["depth_image"] = ( + self.transform(image, depth_image) if self.transform else (image, depth_image) + ) + elif self.task in (TaskType.DETECTION, TaskType.SEGMENTATION): + # Only Anomalous (1) images have masks in anomaly datasets + # Therefore, create empty mask for Normal (0) images. + mask = ( + Mask(torch.zeros(image.shape[-2:])) + if label_index == LabelName.NORMAL + else Mask(to_tensor(Image.open(mask_path)).squeeze()) + ) + item["image"], item["depth_image"], item["mask"] = ( + self.transform(image, depth_image, mask) if self.transform else (image, depth_image, mask) + ) + item["mask_path"] = mask_path + + if self.task == TaskType.DETECTION: + # create boxes from masks for detection task + boxes, _ = masks_to_boxes(item["mask"]) + item["boxes"] = boxes[0] + else: + msg = f"Unknown task type: {self.task}" + raise ValueError(msg) + + return item diff --git a/anomalib/data/base/video.py b/anomalib/data/base/video.py new file mode 100644 index 0000000000000000000000000000000000000000..5f04ebfe3b12963fbbc8e8cda91271c6c802042e --- /dev/null +++ b/anomalib/data/base/video.py @@ -0,0 +1,213 @@ +"""Base Video Dataset.""" + +# Copyright (C) 2023-2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +from abc import ABC +from enum import Enum +from typing import TYPE_CHECKING, Any + +import torch +from pandas import DataFrame +from torchvision.transforms.v2 import Transform +from torchvision.transforms.v2.functional import to_dtype_video +from torchvision.tv_tensors import Mask + +from anomalib import TaskType +from anomalib.data.base.datamodule import AnomalibDataModule +from anomalib.data.base.dataset import AnomalibDataset +from anomalib.data.utils import ValSplitMode, masks_to_boxes +from anomalib.data.utils.video import ClipsIndexer + +if TYPE_CHECKING: + from collections.abc import Callable + + +class VideoTargetFrame(str, Enum): + """Target frame for a video-clip. + + Used in multi-frame models to determine which frame's ground truth information will be used. + """ + + FIRST = "first" + LAST = "last" + MID = "mid" + ALL = "all" + + +class AnomalibVideoDataset(AnomalibDataset, ABC): + """Base video anomalib dataset class. + + Args: + task (str): Task type, either 'classification' or 'segmentation' + clip_length_in_frames (int): Number of video frames in each clip. + frames_between_clips (int): Number of frames between each consecutive video clip. + transform (Transform, optional): Transforms that should be applied to the input clips. + Defaults to ``None``. + target_frame (VideoTargetFrame): Specifies the target frame in the video clip, used for ground truth retrieval. + Defaults to ``VideoTargetFrame.LAST``. + """ + + def __init__( + self, + task: TaskType, + clip_length_in_frames: int, + frames_between_clips: int, + transform: Transform | None = None, + target_frame: VideoTargetFrame = VideoTargetFrame.LAST, + ) -> None: + super().__init__(task, transform) + + self.clip_length_in_frames = clip_length_in_frames + self.frames_between_clips = frames_between_clips + self.transform = transform + + self.indexer: ClipsIndexer | None = None + self.indexer_cls: Callable | None = None + + self.target_frame = target_frame + + def __len__(self) -> int: + """Get length of the dataset.""" + if not isinstance(self.indexer, ClipsIndexer): + msg = "self.indexer must be an instance of ClipsIndexer." + raise TypeError(msg) + return self.indexer.num_clips() + + @property + def samples(self) -> DataFrame: + """Get the samples dataframe.""" + return super().samples + + @samples.setter + def samples(self, samples: DataFrame) -> None: + """Overwrite samples and re-index subvideos. + + Args: + samples (DataFrame): DataFrame with new samples. + + Raises: + ValueError: If the indexer class is not set. + """ + super(AnomalibVideoDataset, self.__class__).samples.fset(self, samples) # type: ignore[attr-defined] + self._setup_clips() + + def _setup_clips(self) -> None: + """Compute the video and frame indices of the subvideos. + + Should be called after each change to self._samples + """ + if not callable(self.indexer_cls): + msg = "self.indexer_cls must be callable." + raise TypeError(msg) + self.indexer = self.indexer_cls( # pylint: disable=not-callable + video_paths=list(self.samples.image_path), + mask_paths=list(self.samples.mask_path), + clip_length_in_frames=self.clip_length_in_frames, + frames_between_clips=self.frames_between_clips, + ) + + def _select_targets(self, item: dict[str, Any]) -> dict[str, Any]: + """Select the target frame from the clip. + + Args: + item (dict[str, Any]): Item containing the clip information. + + Raises: + ValueError: If the target frame is not one of the supported options. + + Returns: + dict[str, Any]: Selected item from the clip. + """ + if self.target_frame == VideoTargetFrame.FIRST: + idx = 0 + elif self.target_frame == VideoTargetFrame.LAST: + idx = -1 + elif self.target_frame == VideoTargetFrame.MID: + idx = int(self.clip_length_in_frames / 2) + else: + msg = f"Unknown video target frame: {self.target_frame}" + raise ValueError(msg) + + if item.get("mask") is not None: + item["mask"] = item["mask"][idx, ...] + if item.get("boxes") is not None: + item["boxes"] = item["boxes"][idx] + if item.get("label") is not None: + item["label"] = item["label"][idx] + if item.get("original_image") is not None: + item["original_image"] = item["original_image"][idx] + if item.get("frames") is not None: + item["frames"] = item["frames"][idx] + return item + + def __getitem__(self, index: int) -> dict[str, str | torch.Tensor]: + """Get the dataset item for the index ``index``. + + Args: + index (int): Index of the item to be returned. + + Returns: + dict[str, str | torch.Tensor]: Dictionary containing the mask, clip and file system information. + """ + if not isinstance(self.indexer, ClipsIndexer): + msg = "self.indexer must be an instance of ClipsIndexer." + raise TypeError(msg) + item = self.indexer.get_item(index) + item["image"] = to_dtype_video(video=item["image"], scale=True) + # include the untransformed image for visualization + item["original_image"] = item["image"].to(torch.uint8) + + # apply transforms + if item.get("mask") is not None: + if self.transform: + item["image"], item["mask"] = self.transform(item["image"], Mask(item["mask"])) + item["label"] = torch.Tensor([1 in frame for frame in item["mask"]]).int().squeeze(0) + if self.task == TaskType.DETECTION: + item["boxes"], _ = masks_to_boxes(item["mask"]) + item["boxes"] = item["boxes"][0] if len(item["boxes"]) == 1 else item["boxes"] + elif self.transform: + item["image"] = self.transform(item["image"]) + + # squeeze temporal dimensions in case clip length is 1 + item["image"] = item["image"].squeeze(0) + + # include only target frame in gt + if self.clip_length_in_frames > 1 and self.target_frame != VideoTargetFrame.ALL: + item = self._select_targets(item) + + if item["mask"] is None: + item.pop("mask") + + return item + + +class AnomalibVideoDataModule(AnomalibDataModule): + """Base class for video data modules.""" + + def _create_test_split(self) -> None: + """Video datamodules do not support dynamic assignment of the test split.""" + + def _setup(self, _stage: str | None = None) -> None: + """Set up the datasets and perform dynamic subset splitting. + + This method may be overridden in subclass for custom splitting behaviour. + + Video datamodules are not compatible with synthetic anomaly generation. + """ + if self.train_data is None: + msg = "self.train_data cannot be None." + raise ValueError(msg) + + if self.test_data is None: + msg = "self.test_data cannot be None." + raise ValueError(msg) + + self.train_data.setup() + self.test_data.setup() + + if self.val_split_mode == ValSplitMode.SYNTHETIC: + msg = f"Val split mode {self.test_split_mode} not supported for video datasets." + raise ValueError(msg) + + self._create_val_split() diff --git a/anomalib/data/depth/__init__.py b/anomalib/data/depth/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..16674ed0c300cad9681d663dfed7c6bbacf5a52a --- /dev/null +++ b/anomalib/data/depth/__init__.py @@ -0,0 +1,20 @@ +"""Anomalib Depth Datasets.""" + +# Copyright (C) 2023 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + + +from enum import Enum + +from .folder_3d import Folder3D +from .mvtec_3d import MVTec3D + + +class DepthDataFormat(str, Enum): + """Supported Depth Dataset Types.""" + + MVTEC_3D = "mvtec_3d" + FOLDER_3D = "folder_3d" + + +__all__ = ["Folder3D", "MVTec3D"] diff --git a/anomalib/data/depth/folder_3d.py b/anomalib/data/depth/folder_3d.py new file mode 100644 index 0000000000000000000000000000000000000000..39d8ccab76234622476e1dbdca6f661ae77075d5 --- /dev/null +++ b/anomalib/data/depth/folder_3d.py @@ -0,0 +1,433 @@ +"""Custom Folder Dataset. + +This script creates a custom dataset from a folder. +""" + +# Copyright (C) 2022 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + + +from pathlib import Path + +from pandas import DataFrame, isna +from torchvision.transforms.v2 import Transform + +from anomalib import TaskType +from anomalib.data.base import AnomalibDataModule, AnomalibDepthDataset +from anomalib.data.errors import MisMatchError +from anomalib.data.utils import ( + DirType, + LabelName, + Split, + TestSplitMode, + ValSplitMode, +) +from anomalib.data.utils.path import _prepare_files_labels, validate_and_resolve_path + + +def make_folder3d_dataset( # noqa: C901 + normal_dir: str | Path, + root: str | Path | None = None, + abnormal_dir: str | Path | None = None, + normal_test_dir: str | Path | None = None, + mask_dir: str | Path | None = None, + normal_depth_dir: str | Path | None = None, + abnormal_depth_dir: str | Path | None = None, + normal_test_depth_dir: str | Path | None = None, + split: str | Split | None = None, + extensions: tuple[str, ...] | None = None, +) -> DataFrame: + """Make Folder Dataset. + + Args: + normal_dir (str | Path): Path to the directory containing normal images. + root (str | Path | None): Path to the root directory of the dataset. + Defaults to ``None``. + abnormal_dir (str | Path | None, optional): Path to the directory containing abnormal images. + Defaults to ``None``. + normal_test_dir (str | Path | None, optional): Path to the directory containing normal images for the test + dataset. Normal test images will be a split of `normal_dir` if `None`. + Defaults to ``None``. + mask_dir (str | Path | None, optional): Path to the directory containing the mask annotations. + Defaults to ``None``. + normal_depth_dir (str | Path | None, optional): Path to the directory containing + normal depth images for the test dataset. Normal test depth images will be a split of `normal_dir` + Defaults to ``None``. + abnormal_depth_dir (str | Path | None, optional): Path to the directory containing abnormal depth images for + the test dataset. + Defaults to ``None``. + normal_test_depth_dir (str | Path | None, optional): Path to the directory containing normal depth images for + the test dataset. Normal test images will be a split of `normal_dir` if `None`. + Defaults to ``None``. + split (str | Split | None, optional): Dataset split (ie., Split.FULL, Split.TRAIN or Split.TEST). + Defaults to ``None``. + extensions (tuple[str, ...] | None, optional): Type of the image extensions to read from the directory. + Defaults to ``None``. + + Returns: + DataFrame: an output dataframe containing samples for the requested split (ie., train or test) + """ + normal_dir = validate_and_resolve_path(normal_dir, root) + abnormal_dir = validate_and_resolve_path(abnormal_dir, root) if abnormal_dir else None + normal_test_dir = validate_and_resolve_path(normal_test_dir, root) if normal_test_dir else None + mask_dir = validate_and_resolve_path(mask_dir, root) if mask_dir else None + normal_depth_dir = validate_and_resolve_path(normal_depth_dir, root) if normal_depth_dir else None + abnormal_depth_dir = validate_and_resolve_path(abnormal_depth_dir, root) if abnormal_depth_dir else None + normal_test_depth_dir = validate_and_resolve_path(normal_test_depth_dir, root) if normal_test_depth_dir else None + + if not normal_dir.is_dir(): + msg = "A folder location must be provided in normal_dir." + raise ValueError(msg) + + filenames = [] + labels = [] + dirs = {DirType.NORMAL: normal_dir} + + if abnormal_dir: + dirs[DirType.ABNORMAL] = abnormal_dir + + if normal_test_dir: + dirs[DirType.NORMAL_TEST] = normal_test_dir + + if normal_depth_dir: + dirs[DirType.NORMAL_DEPTH] = normal_depth_dir + + if abnormal_depth_dir: + dirs[DirType.ABNORMAL_DEPTH] = abnormal_depth_dir + + if normal_test_depth_dir: + dirs[DirType.NORMAL_TEST_DEPTH] = normal_test_depth_dir + + if mask_dir: + dirs[DirType.MASK] = mask_dir + + for dir_type, path in dirs.items(): + filename, label = _prepare_files_labels(path, dir_type, extensions) + filenames += filename + labels += label + + samples = DataFrame({"image_path": filenames, "label": labels}) + samples = samples.sort_values(by="image_path", ignore_index=True) + + # Create label index for normal (0) and abnormal (1) images. + samples.loc[ + (samples.label == DirType.NORMAL) | (samples.label == DirType.NORMAL_TEST), + "label_index", + ] = LabelName.NORMAL + samples.loc[(samples.label == DirType.ABNORMAL), "label_index"] = LabelName.ABNORMAL + samples.label_index = samples.label_index.astype("Int64") + + # If a path to mask is provided, add it to the sample dataframe. + if normal_depth_dir: + samples.loc[samples.label == DirType.NORMAL, "depth_path"] = samples.loc[ + samples.label == DirType.NORMAL_DEPTH + ].image_path.to_numpy() + samples.loc[samples.label == DirType.ABNORMAL, "depth_path"] = samples.loc[ + samples.label == DirType.ABNORMAL_DEPTH + ].image_path.to_numpy() + + if normal_test_dir: + samples.loc[samples.label == DirType.NORMAL_TEST, "depth_path"] = samples.loc[ + samples.label == DirType.NORMAL_TEST_DEPTH + ].image_path.to_numpy() + + # make sure every rgb image has a corresponding depth image and that the file exists + mismatch = ( + samples.loc[samples.label_index == LabelName.ABNORMAL] + .apply(lambda x: Path(x.image_path).stem in Path(x.depth_path).stem, axis=1) + .all() + ) + if not mismatch: + msg = """Mismatch between anomalous images and depth images. Make sure the mask files + in 'xyz' folder follow the same naming convention as the anomalous images in the dataset + (e.g. image: '000.png', depth: '000.tiff').""" + raise MisMatchError(msg) + + missing_depth_files = samples.depth_path.apply( + lambda x: Path(x).exists() if not isna(x) else True, + ).all() + if not missing_depth_files: + msg = "Missing depth image files." + raise FileNotFoundError(msg) + + samples = samples.astype({"depth_path": "str"}) + + # If a path to mask is provided, add it to the sample dataframe. + if mask_dir and abnormal_dir: + samples.loc[samples.label == DirType.ABNORMAL, "mask_path"] = samples.loc[ + samples.label == DirType.MASK + ].image_path.to_numpy() + samples["mask_path"] = samples["mask_path"].fillna("") + samples = samples.astype({"mask_path": "str"}) + + # make sure all the files exist + if not samples.mask_path.apply( + lambda x: Path(x).exists() if x != "" else True, + ).all(): + msg = f"Missing mask files. mask_dir={mask_dir}" + raise FileNotFoundError(msg) + else: + samples["mask_path"] = "" + + # remove all the rows with temporal image samples that have already been assigned + samples = samples.loc[ + (samples.label == DirType.NORMAL) | (samples.label == DirType.ABNORMAL) | (samples.label == DirType.NORMAL_TEST) + ] + + # Ensure the pathlib objects are converted to str. + # This is because torch dataloader doesn't like pathlib. + samples = samples.astype({"image_path": "str"}) + + # Create train/test split. + # By default, all the normal samples are assigned as train. + # and all the abnormal samples are test. + samples.loc[(samples.label == DirType.NORMAL), "split"] = Split.TRAIN + samples.loc[(samples.label == DirType.ABNORMAL) | (samples.label == DirType.NORMAL_TEST), "split"] = Split.TEST + + # Get the data frame for the split. + if split: + samples = samples[samples.split == split] + samples = samples.reset_index(drop=True) + + return samples + + +class Folder3DDataset(AnomalibDepthDataset): + """Folder dataset. + + Args: + name (str): Name of the dataset. + task (TaskType): Task type. (``classification``, ``detection`` or ``segmentation``). + transform (Transform): Transforms that should be applied to the input images. + normal_dir (str | Path): Path to the directory containing normal images. + root (str | Path | None): Root folder of the dataset. + Defaults to ``None``. + abnormal_dir (str | Path | None, optional): Path to the directory containing abnormal images. + Defaults to ``None``. + normal_test_dir (str | Path | None, optional): Path to the directory containing + normal images for the test dataset. + Defaults to ``None``. + mask_dir (str | Path | None, optional): Path to the directory containing + the mask annotations. + Defaults to ``None``. + normal_depth_dir (str | Path | None, optional): Path to the directory containing + normal depth images for the test dataset. Normal test depth images will be a split of `normal_dir` + Defaults to ``None``. + abnormal_depth_dir (str | Path | None, optional): Path to the directory containing abnormal depth images for + the test dataset. + Defaults to ``None``. + normal_test_depth_dir (str | Path | None, optional): Path to the directory containing + normal depth images for the test dataset. Normal test images will be a split of `normal_dir` if `None`. + Defaults to ``None``. + transform (Transform, optional): Transforms that should be applied to the input images. + Defaults to ``None``. + split (str | Split | None): Fixed subset split that follows from folder structure on file system. + Choose from [Split.FULL, Split.TRAIN, Split.TEST] + Defaults to ``None``. + extensions (tuple[str, ...] | None, optional): Type of the image extensions to read from the directory. + Defaults to ``None``. + + Raises: + ValueError: When task is set to classification and `mask_dir` is provided. When `mask_dir` is + provided, `task` should be set to `segmentation`. + """ + + def __init__( + self, + name: str, + task: TaskType, + normal_dir: str | Path, + root: str | Path | None = None, + abnormal_dir: str | Path | None = None, + normal_test_dir: str | Path | None = None, + mask_dir: str | Path | None = None, + normal_depth_dir: str | Path | None = None, + abnormal_depth_dir: str | Path | None = None, + normal_test_depth_dir: str | Path | None = None, + transform: Transform | None = None, + split: str | Split | None = None, + extensions: tuple[str, ...] | None = None, + ) -> None: + super().__init__(task, transform) + + self._name = name + self.split = split + self.root = root + self.normal_dir = normal_dir + self.abnormal_dir = abnormal_dir + self.normal_test_dir = normal_test_dir + self.mask_dir = mask_dir + self.normal_depth_dir = normal_depth_dir + self.abnormal_depth_dir = abnormal_depth_dir + self.normal_test_depth_dir = normal_test_depth_dir + self.extensions = extensions + + self.samples = make_folder3d_dataset( + root=self.root, + normal_dir=self.normal_dir, + abnormal_dir=self.abnormal_dir, + normal_test_dir=self.normal_test_dir, + mask_dir=self.mask_dir, + normal_depth_dir=self.normal_depth_dir, + abnormal_depth_dir=self.abnormal_depth_dir, + normal_test_depth_dir=self.normal_test_depth_dir, + split=self.split, + extensions=self.extensions, + ) + + @property + def name(self) -> str: + """Name of the dataset. + + Folder3D dataset overrides the name property to provide a custom name. + """ + return self._name + + +class Folder3D(AnomalibDataModule): + """Folder DataModule. + + Args: + name (str): Name of the dataset. This is used to name the datamodule, especially when logging/saving. + normal_dir (str | Path): Name of the directory containing normal images. + root (str | Path | None): Path to the root folder containing normal and abnormal dirs. + Defaults to ``None``. + abnormal_dir (str | Path | None): Name of the directory containing abnormal images. + Defaults to ``abnormal``. + normal_test_dir (str | Path | None, optional): Path to the directory containing normal images for the test + dataset. + Defaults to ``None``. + mask_dir (str | Path | None, optional): Path to the directory containing the mask annotations. + Defaults to ``None``. + normal_depth_dir (str | Path | None, optional): Path to the directory containing + normal depth images for the test dataset. Normal test depth images will be a split of `normal_dir` + abnormal_depth_dir (str | Path | None, optional): Path to the directory containing + abnormal depth images for the test dataset. + normal_test_depth_dir (str | Path | None, optional): Path to the directory containing + normal depth images for the test dataset. Normal test images will be a split of `normal_dir` + if `None`. Defaults to None. + normal_split_ratio (float, optional): Ratio to split normal training images and add to the + test set in case test set doesn't contain any normal images. + Defaults to 0.2. + extensions (tuple[str, ...] | None, optional): Type of the image extensions to read from the + directory. Defaults to None. + train_batch_size (int, optional): Training batch size. + Defaults to ``32``. + eval_batch_size (int, optional): Test batch size. + Defaults to ``32``. + num_workers (int, optional): Number of workers. + Defaults to ``8``. + task (TaskType, optional): Task type. Could be ``classification``, ``detection`` or ``segmentation``. + Defaults to ``TaskType.SEGMENTATION``. + image_size (tuple[int, int], optional): Size to which input images should be resized. + Defaults to ``None``. + transform (Transform, optional): Transforms that should be applied to the input images. + Defaults to ``None``. + train_transform (Transform, optional): Transforms that should be applied to the input images during training. + Defaults to ``None``. + eval_transform (Transform, optional): Transforms that should be applied to the input images during evaluation. + Defaults to ``None``. + test_split_mode (TestSplitMode): Setting that determines how the testing subset is obtained. + Defaults to ``TestSplitMode.FROM_DIR``. + test_split_ratio (float): Fraction of images from the train set that will be reserved for testing. + Defaults to ``0.2``. + val_split_mode (ValSplitMode): Setting that determines how the validation subset is obtained. + Defaults to ``ValSplitMode.FROM_TEST``. + val_split_ratio (float): Fraction of train or test images that will be reserved for validation. + Defaults to ``0.5``. + seed (int | None, optional): Seed used during random subset splitting. + Defaults to ``None``. + """ + + def __init__( + self, + name: str, + normal_dir: str | Path, + root: str | Path, + abnormal_dir: str | Path | None = None, + normal_test_dir: str | Path | None = None, + mask_dir: str | Path | None = None, + normal_depth_dir: str | Path | None = None, + abnormal_depth_dir: str | Path | None = None, + normal_test_depth_dir: str | Path | None = None, + extensions: tuple[str] | None = None, + train_batch_size: int = 32, + eval_batch_size: int = 32, + num_workers: int = 8, + task: TaskType | str = TaskType.SEGMENTATION, + image_size: tuple[int, int] | None = None, + transform: Transform | None = None, + train_transform: Transform | None = None, + eval_transform: Transform | None = None, + test_split_mode: TestSplitMode | str = TestSplitMode.FROM_DIR, + test_split_ratio: float = 0.2, + val_split_mode: ValSplitMode | str = ValSplitMode.FROM_TEST, + val_split_ratio: float = 0.5, + seed: int | None = None, + ) -> None: + super().__init__( + train_batch_size=train_batch_size, + eval_batch_size=eval_batch_size, + num_workers=num_workers, + image_size=image_size, + transform=transform, + train_transform=train_transform, + eval_transform=eval_transform, + test_split_mode=test_split_mode, + test_split_ratio=test_split_ratio, + val_split_mode=val_split_mode, + val_split_ratio=val_split_ratio, + seed=seed, + ) + self._name = name + self.task = TaskType(task) + self.root = Path(root) + self.normal_dir = normal_dir + self.abnormal_dir = abnormal_dir + self.normal_test_dir = normal_test_dir + self.mask_dir = mask_dir + self.normal_depth_dir = normal_depth_dir + self.abnormal_depth_dir = abnormal_depth_dir + self.normal_test_depth_dir = normal_test_depth_dir + self.extensions = extensions + + def _setup(self, _stage: str | None = None) -> None: + self.train_data = Folder3DDataset( + name=self.name, + task=self.task, + transform=self.train_transform, + split=Split.TRAIN, + root=self.root, + normal_dir=self.normal_dir, + abnormal_dir=self.abnormal_dir, + normal_test_dir=self.normal_test_dir, + mask_dir=self.mask_dir, + normal_depth_dir=self.normal_depth_dir, + abnormal_depth_dir=self.abnormal_depth_dir, + normal_test_depth_dir=self.normal_test_depth_dir, + extensions=self.extensions, + ) + + self.test_data = Folder3DDataset( + name=self.name, + task=self.task, + transform=self.eval_transform, + split=Split.TEST, + root=self.root, + normal_dir=self.normal_dir, + abnormal_dir=self.abnormal_dir, + normal_test_dir=self.normal_test_dir, + normal_depth_dir=self.normal_depth_dir, + abnormal_depth_dir=self.abnormal_depth_dir, + normal_test_depth_dir=self.normal_test_depth_dir, + mask_dir=self.mask_dir, + extensions=self.extensions, + ) + + @property + def name(self) -> str: + """Name of the datamodule. + + Folder3D datamodule overrides the name property to provide a custom name. + """ + return self._name diff --git a/anomalib/data/depth/mvtec_3d.py b/anomalib/data/depth/mvtec_3d.py new file mode 100644 index 0000000000000000000000000000000000000000..7ac09cf5e7645df868f12c8e5a7d399b1cb8bf67 --- /dev/null +++ b/anomalib/data/depth/mvtec_3d.py @@ -0,0 +1,302 @@ +"""MVTec 3D-AD Dataset (CC BY-NC-SA 4.0). + +Description: + This script contains PyTorch Dataset, Dataloader and PyTorch Lightning DataModule for the MVTec 3D-AD dataset. + If the dataset is not on the file system, the script downloads and extracts the dataset and create PyTorch data + objects. + +License: + MVTec 3D-AD dataset is released under the Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International + License (CC BY-NC-SA 4.0)(https://creativecommons.org/licenses/by-nc-sa/4.0/). + +Reference: + - Paul Bergmann, Xin Jin, David Sattlegger, Carsten Steger: The MVTec 3D-AD Dataset for Unsupervised 3D Anomaly + Detection and Localization in: Proceedings of the 17th International Joint Conference on Computer Vision, + Imaging and Computer Graphics Theory and Applications - Volume 5: VISAPP, 202-213, 2022, DOI: 10.5220/ + 0010865000003124. +""" + +# Copyright (C) 2022 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + + +import logging +from collections.abc import Sequence +from pathlib import Path + +from pandas import DataFrame +from torchvision.transforms.v2 import Transform + +from anomalib import TaskType +from anomalib.data.base import AnomalibDataModule, AnomalibDepthDataset +from anomalib.data.errors import MisMatchError +from anomalib.data.utils import ( + DownloadInfo, + LabelName, + Split, + TestSplitMode, + ValSplitMode, + download_and_extract, + validate_path, +) + +logger = logging.getLogger(__name__) + + +IMG_EXTENSIONS = [".png", ".PNG", ".tiff"] + +DOWNLOAD_INFO = DownloadInfo( + name="mvtec_3d", + url="https://www.mydrive.ch/shares/45920/dd1eb345346df066c63b5c95676b961b/download/428824485-1643285832" + "/mvtec_3d_anomaly_detection.tar.xz", + hashsum="d8bb2800fbf3ac88e798da6ae10dc819", +) + +CATEGORIES = ("bagel", "cable_gland", "carrot", "cookie", "dowel", "foam", "peach", "potato", "rope", "tire") + + +def make_mvtec_3d_dataset( + root: str | Path, + split: str | Split | None = None, + extensions: Sequence[str] | None = None, +) -> DataFrame: + """Create MVTec 3D-AD samples by parsing the MVTec AD data file structure. + + The files are expected to follow this structure: + - `path/to/dataset/split/category/image_filename.png` + - `path/to/dataset/ground_truth/category/mask_filename.png` + + This function creates a DataFrame to store the parsed information. The DataFrame follows this format: + + +---+---------------+-------+---------+---------------+---------------------------------------+-------------+ + | | path | split | label | image_path | mask_path | label_index | + +---+---------------+-------+---------+---------------+---------------------------------------+-------------+ + | 0 | datasets/name | test | defect | filename.png | ground_truth/defect/filename_mask.png | 1 | + +---+---------------+-------+---------+---------------+---------------------------------------+-------------+ + + Args: + root (Path): Path to the dataset. + split (str | Split | None, optional): Dataset split (e.g., 'train' or 'test'). + Defaults to ``None``. + extensions (Sequence[str] | None, optional): List of file extensions to be included in the dataset. + Defaults to ``None``. + + Examples: + The following example shows how to get training samples from the MVTec 3D-AD 'bagel' category: + + >>> from pathlib import Path + >>> root = Path('./MVTec3D') + >>> category = 'bagel' + >>> path = root / category + >>> print(path) + PosixPath('MVTec3D/bagel') + + >>> samples = create_mvtec_3d_ad_samples(path, split='train') + >>> print(samples.head()) + path split label image_path mask_path label_index + MVTec3D/bagel train good MVTec3D/bagel/train/good/rgb/105.png MVTec3D/bagel/ground_truth/good/gt/105.png 0 + MVTec3D/bagel train good MVTec3D/bagel/train/good/rgb/017.png MVTec3D/bagel/ground_truth/good/gt/017.png 0 + + Returns: + DataFrame: An output DataFrame containing the samples of the dataset. + """ + if extensions is None: + extensions = IMG_EXTENSIONS + + root = validate_path(root) + samples_list = [(str(root),) + f.parts[-4:] for f in root.glob(r"**/*") if f.suffix in extensions] + if not samples_list: + msg = f"Found 0 images in {root}" + raise RuntimeError(msg) + + samples = DataFrame(samples_list, columns=["path", "split", "label", "type", "file_name"]) + + # Modify image_path column by converting to absolute path + samples.loc[(samples.type == "rgb"), "image_path"] = ( + samples.path + "/" + samples.split + "/" + samples.label + "/" + "rgb/" + samples.file_name + ) + samples.loc[(samples.type == "rgb"), "depth_path"] = ( + samples.path + + "/" + + samples.split + + "/" + + samples.label + + "/" + + "xyz/" + + samples.file_name.str.split(".").str[0] + + ".tiff" + ) + + # Create label index for normal (0) and anomalous (1) images. + samples.loc[(samples.label == "good"), "label_index"] = LabelName.NORMAL + samples.loc[(samples.label != "good"), "label_index"] = LabelName.ABNORMAL + samples.label_index = samples.label_index.astype(int) + + # separate masks from samples + mask_samples = samples.loc[((samples.split == "test") & (samples.type == "rgb"))].sort_values( + by="image_path", + ignore_index=True, + ) + samples = samples.sort_values(by="image_path", ignore_index=True) + + # assign mask paths to all test images + samples.loc[((samples.split == "test") & (samples.type == "rgb")), "mask_path"] = ( + mask_samples.path + "/" + samples.split + "/" + samples.label + "/" + "gt/" + samples.file_name + ) + samples = samples.dropna(subset=["image_path"]) + samples = samples.astype({"image_path": "str", "mask_path": "str", "depth_path": "str"}) + + # assert that the right mask files are associated with the right test images + mismatch_masks = ( + samples.loc[samples.label_index == LabelName.ABNORMAL] + .apply(lambda x: Path(x.image_path).stem in Path(x.mask_path).stem, axis=1) + .all() + ) + if not mismatch_masks: + msg = """Mismatch between anomalous images and ground truth masks. Make sure the mask files + in 'ground_truth' folder follow the same naming convention as the anomalous images in + the dataset (e.g. image: '000.png', mask: '000.png' or '000_mask.png').""" + raise MisMatchError(msg) + + mismatch_depth = ( + samples.loc[samples.label_index == LabelName.ABNORMAL] + .apply(lambda x: Path(x.image_path).stem in Path(x.depth_path).stem, axis=1) + .all() + ) + if not mismatch_depth: + msg = """Mismatch between anomalous images and depth images. Make sure the mask files in + 'xyz' folder follow the same naming convention as the anomalous images in the dataset + (e.g. image: '000.png', depth: '000.tiff').""" + raise MisMatchError(msg) + + if split: + samples = samples[samples.split == split].reset_index(drop=True) + + return samples + + +class MVTec3DDataset(AnomalibDepthDataset): + """MVTec 3D dataset class. + + Args: + task (TaskType): Task type, ``classification``, ``detection`` or ``segmentation`` + root (Path | str): Path to the root of the dataset + Defaults to ``"./datasets/MVTec3D"``. + category (str): Sub-category of the dataset, e.g. 'bagel' + Defaults to ``"bagel"``. + transform (Transform, optional): Transforms that should be applied to the input images. + Defaults to ``None``. + split (str | Split | None): Split of the dataset, usually Split.TRAIN or Split.TEST + Defaults to ``None``. + """ + + def __init__( + self, + task: TaskType, + root: Path | str = "./datasets/MVTec3D", + category: str = "bagel", + transform: Transform | None = None, + split: str | Split | None = None, + ) -> None: + super().__init__(task=task, transform=transform) + + self.root_category = Path(root) / Path(category) + self.split = split + self.samples = make_mvtec_3d_dataset(self.root_category, split=self.split, extensions=IMG_EXTENSIONS) + + +class MVTec3D(AnomalibDataModule): + """MVTec Datamodule. + + Args: + root (Path | str): Path to the root of the dataset + Defaults to ``"./datasets/MVTec3D"``. + category (str): Category of the MVTec dataset (e.g. "bottle" or "cable"). + Defaults to ``bagel``. + train_batch_size (int, optional): Training batch size. + Defaults to ``32``. + eval_batch_size (int, optional): Test batch size. + Defaults to ``32``. + num_workers (int, optional): Number of workers. + Defaults to ``8``. + task (TaskType): Task type, 'classification', 'detection' or 'segmentation' + Defaults to ``TaskType.SEGMENTATION``. + image_size (tuple[int, int], optional): Size to which input images should be resized. + Defaults to ``None``. + transform (Transform, optional): Transforms that should be applied to the input images. + Defaults to ``None``. + train_transform (Transform, optional): Transforms that should be applied to the input images during training. + Defaults to ``None``. + eval_transform (Transform, optional): Transforms that should be applied to the input images during evaluation. + Defaults to ``None``. + test_split_mode (TestSplitMode): Setting that determines how the testing subset is obtained. + Defaults to ``TestSplitMode.FROM_DIR``. + test_split_ratio (float): Fraction of images from the train set that will be reserved for testing. + Defaults to ``0.2``. + val_split_mode (ValSplitMode): Setting that determines how the validation subset is obtained. + Defaults to ``ValSplitMode.SAME_AS_TEST``. + val_split_ratio (float): Fraction of train or test images that will be reserved for validation. + Defaults to ``0.5``. + seed (int | None, optional): Seed which may be set to a fixed value for reproducibility. + Defaults to ``None``. + """ + + def __init__( + self, + root: Path | str = "./datasets/MVTec3D", + category: str = "bagel", + train_batch_size: int = 32, + eval_batch_size: int = 32, + num_workers: int = 8, + task: TaskType | str = TaskType.SEGMENTATION, + image_size: tuple[int, int] | None = None, + transform: Transform | None = None, + train_transform: Transform | None = None, + eval_transform: Transform | None = None, + test_split_mode: TestSplitMode | str = TestSplitMode.FROM_DIR, + test_split_ratio: float = 0.2, + val_split_mode: ValSplitMode | str = ValSplitMode.SAME_AS_TEST, + val_split_ratio: float = 0.5, + seed: int | None = None, + ) -> None: + super().__init__( + train_batch_size=train_batch_size, + eval_batch_size=eval_batch_size, + num_workers=num_workers, + image_size=image_size, + transform=transform, + train_transform=train_transform, + eval_transform=eval_transform, + test_split_mode=test_split_mode, + test_split_ratio=test_split_ratio, + val_split_mode=val_split_mode, + val_split_ratio=val_split_ratio, + seed=seed, + ) + + self.task = TaskType(task) + self.root = Path(root) + self.category = category + + def _setup(self, _stage: str | None = None) -> None: + self.train_data = MVTec3DDataset( + task=self.task, + transform=self.train_transform, + split=Split.TRAIN, + root=self.root, + category=self.category, + ) + self.test_data = MVTec3DDataset( + task=self.task, + transform=self.eval_transform, + split=Split.TEST, + root=self.root, + category=self.category, + ) + + def prepare_data(self) -> None: + """Download the dataset if not available.""" + if (self.root / self.category).is_dir(): + logger.info("Found the dataset.") + else: + download_and_extract(self.root, DOWNLOAD_INFO) diff --git a/anomalib/data/errors.py b/anomalib/data/errors.py new file mode 100644 index 0000000000000000000000000000000000000000..97c956663c2fec4629a3146f3d355efcd4008065 --- /dev/null +++ b/anomalib/data/errors.py @@ -0,0 +1,19 @@ +"""Custom Exception Class for Mismatch Detection (MisMatchError).""" + +# Copyright (C) 2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + + +class MisMatchError(Exception): + """Exception raised when a mismatch is detected. + + Attributes: + message (str): Explanation of the error. + """ + + def __init__(self, message: str = "") -> None: + if message: + self.message = message + else: + self.message = "Mismatch detected." + super().__init__(self.message) diff --git a/anomalib/data/image/__init__.py b/anomalib/data/image/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..4db05b51df3ab6c7cf23698ff92c53588eb4d1ad --- /dev/null +++ b/anomalib/data/image/__init__.py @@ -0,0 +1,33 @@ +"""Anomalib Image Datasets. + +This module contains the supported image datasets for Anomalib. +""" + +# Copyright (C) 2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + + +from enum import Enum + +from .btech import BTech +from .folder import Folder +from .kolektor import Kolektor +from .mvtec import MVTec +from .mvtec_loco import MVTecLoco +from .visa import Visa + + +class ImageDataFormat(str, Enum): + """Supported Image Dataset Types.""" + + MVTEC = "mvtec" + MVTEC_3D = "mvtec_3d" + MVTEC_LOCO = "mvtec_loco" + BTECH = "btech" + KOLEKTOR = "kolektor" + FOLDER = "folder" + FOLDER_3D = "folder_3d" + VISA = "visa" + + +__all__ = ["BTech", "Folder", "Kolektor", "MVTec", "MVTecLoco", "Visa"] diff --git a/anomalib/data/image/btech.py b/anomalib/data/image/btech.py new file mode 100644 index 0000000000000000000000000000000000000000..33bbd68c4ca38989f842af75392c94e966882285 --- /dev/null +++ b/anomalib/data/image/btech.py @@ -0,0 +1,362 @@ +"""BTech Dataset. + +This script contains PyTorch Lightning DataModule for the BTech dataset. + +If the dataset is not on the file system, the script downloads and +extracts the dataset and create PyTorch data objects. +""" + +# Copyright (C) 2022-2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +import logging +import shutil +from pathlib import Path + +import cv2 +import pandas as pd +from pandas.core.frame import DataFrame +from torchvision.transforms.v2 import Transform +from tqdm import tqdm + +from anomalib import TaskType +from anomalib.data.base import AnomalibDataModule, AnomalibDataset +from anomalib.data.utils import ( + DownloadInfo, + LabelName, + Split, + TestSplitMode, + ValSplitMode, + download_and_extract, + validate_path, +) + +logger = logging.getLogger(__name__) + +DOWNLOAD_INFO = DownloadInfo( + name="btech", + url="https://avires.dimi.uniud.it/papers/btad/btad.zip", + hashsum="461c9387e515bfed41ecaae07c50cf6b10def647b36c9e31d239ab2736b10d2a", +) + +CATEGORIES = ("01", "02", "03") + + +def make_btech_dataset(path: Path, split: str | Split | None = None) -> DataFrame: + """Create BTech samples by parsing the BTech data file structure. + + The files are expected to follow the structure: + + .. code-block:: bash + + path/to/dataset/split/category/image_filename.png + path/to/dataset/ground_truth/category/mask_filename.png + + Args: + path (Path): Path to dataset + split (str | Split | None, optional): Dataset split (ie., either train or test). + Defaults to ``None``. + + Example: + The following example shows how to get training samples from BTech 01 category: + + .. code-block:: python + + >>> root = Path('./BTech') + >>> category = '01' + >>> path = root / category + >>> path + PosixPath('BTech/01') + + >>> samples = make_btech_dataset(path, split='train') + >>> samples.head() + path split label image_path mask_path label_index + 0 BTech/01 train 01 BTech/01/train/ok/105.bmp BTech/01/ground_truth/ok/105.png 0 + 1 BTech/01 train 01 BTech/01/train/ok/017.bmp BTech/01/ground_truth/ok/017.png 0 + ... + + Returns: + DataFrame: an output dataframe containing samples for the requested split (ie., train or test) + """ + path = validate_path(path) + + samples_list = [ + (str(path),) + filename.parts[-3:] for filename in path.glob("**/*") if filename.suffix in (".bmp", ".png") + ] + if not samples_list: + msg = f"Found 0 images in {path}" + raise RuntimeError(msg) + + samples = pd.DataFrame(samples_list, columns=["path", "split", "label", "image_path"]) + samples = samples[samples.split != "ground_truth"] + + # Create mask_path column + # (safely handles cases where non-mask image_paths end with either .png or .bmp) + samples["mask_path"] = ( + samples.path + + "/ground_truth/" + + samples.label + + "/" + + samples.image_path.str.rstrip("png").str.rstrip(".").str.rstrip("bmp").str.rstrip(".") + + ".png" + ) + + # Modify image_path column by converting to absolute path + samples["image_path"] = samples.path + "/" + samples.split + "/" + samples.label + "/" + samples.image_path + + # Good images don't have mask + samples.loc[(samples.split == "test") & (samples.label == "ok"), "mask_path"] = "" + + # Create label index for normal (0) and anomalous (1) images. + samples.loc[(samples.label == "ok"), "label_index"] = LabelName.NORMAL + samples.loc[(samples.label != "ok"), "label_index"] = LabelName.ABNORMAL + samples.label_index = samples.label_index.astype(int) + + # Get the data frame for the split. + if split: + samples = samples[samples.split == split] + samples = samples.reset_index(drop=True) + + return samples + + +class BTechDataset(AnomalibDataset): + """Btech Dataset class. + + Args: + root: Path to the BTech dataset + category: Name of the BTech category. + transform (Transform, optional): Transforms that should be applied to the input images. + Defaults to ``None``. + split: 'train', 'val' or 'test' + task: ``classification``, ``detection`` or ``segmentation`` + create_validation_set: Create a validation subset in addition to the train and test subsets + + Examples: + >>> from anomalib.data.image.btech import BTechDataset + >>> from anomalib.data.utils.transforms import get_transforms + >>> transform = get_transforms(image_size=256) + >>> dataset = BTechDataset( + ... task="classification", + ... transform=transform, + ... root='./datasets/BTech', + ... category='01', + ... ) + >>> dataset[0].keys() + >>> dataset.setup() + dict_keys(['image']) + + >>> dataset.split = "test" + >>> dataset[0].keys() + dict_keys(['image', 'image_path', 'label']) + + >>> dataset.task = "segmentation" + >>> dataset.split = "train" + >>> dataset[0].keys() + dict_keys(['image']) + + >>> dataset.split = "test" + >>> dataset[0].keys() + dict_keys(['image_path', 'label', 'mask_path', 'image', 'mask']) + + >>> dataset[0]["image"].shape, dataset[0]["mask"].shape + (torch.Size([3, 256, 256]), torch.Size([256, 256])) + """ + + def __init__( + self, + root: str | Path, + category: str, + transform: Transform | None = None, + split: str | Split | None = None, + task: TaskType | str = TaskType.SEGMENTATION, + ) -> None: + super().__init__(task, transform) + + self.root_category = Path(root) / category + self.split = split + self.samples = make_btech_dataset(path=self.root_category, split=self.split) + + +class BTech(AnomalibDataModule): + """BTech Lightning Data Module. + + Args: + root (Path | str): Path to the BTech dataset. + Defaults to ``"./datasets/BTech"``. + category (str): Name of the BTech category. + Defaults to ``"01"``. + train_batch_size (int, optional): Training batch size. + Defaults to ``32``. + eval_batch_size (int, optional): Eval batch size. + Defaults to ``32``. + num_workers (int, optional): Number of workers. + Defaults to ``8``. + task (TaskType, optional): Task type. + Defaults to ``TaskType.SEGMENTATION``. + image_size (tuple[int, int], optional): Size to which input images should be resized. + Defaults to ``None``. + transform (Transform, optional): Transforms that should be applied to the input images. + Defaults to ``None``. + train_transform (Transform, optional): Transforms that should be applied to the input images during training. + Defaults to ``None``. + eval_transform (Transform, optional): Transforms that should be applied to the input images during evaluation. + Defaults to ``None``. + test_split_mode (TestSplitMode, optional): Setting that determines how the testing subset is obtained. + Defaults to ``TestSplitMode.FROM_DIR``. + test_split_ratio (float, optional): Fraction of images from the train set that will be reserved for testing. + Defaults to ``0.2``. + val_split_mode (ValSplitMode, optional): Setting that determines how the validation subset is obtained. + Defaults to ``ValSplitMode.SAME_AS_TEST``. + val_split_ratio (float, optional): Fraction of train or test images that will be reserved for validation. + Defaults to ``0.5``. + seed (int | None, optional): Seed which may be set to a fixed value for reproducibility. + Defaults to ``None``. + + Examples: + To create the BTech datamodule, we need to instantiate the class, and call the ``setup`` method. + + >>> from anomalib.data import BTech + >>> datamodule = BTech( + ... root="./datasets/BTech", + ... category="01", + ... image_size=256, + ... train_batch_size=32, + ... eval_batch_size=32, + ... num_workers=8, + ... transform_config_train=None, + ... transform_config_eval=None, + ... ) + >>> datamodule.setup() + + To get the train dataloader and the first batch of data: + + >>> i, data = next(enumerate(datamodule.train_dataloader())) + >>> data.keys() + dict_keys(['image']) + >>> data["image"].shape + torch.Size([32, 3, 256, 256]) + + To access the validation dataloader and the first batch of data: + + >>> i, data = next(enumerate(datamodule.val_dataloader())) + >>> data.keys() + dict_keys(['image_path', 'label', 'mask_path', 'image', 'mask']) + >>> data["image"].shape, data["mask"].shape + (torch.Size([32, 3, 256, 256]), torch.Size([32, 256, 256])) + + Similarly, to access the test dataloader and the first batch of data: + + >>> i, data = next(enumerate(datamodule.test_dataloader())) + >>> data.keys() + dict_keys(['image_path', 'label', 'mask_path', 'image', 'mask']) + >>> data["image"].shape, data["mask"].shape + (torch.Size([32, 3, 256, 256]), torch.Size([32, 256, 256])) + """ + + def __init__( + self, + root: Path | str = "./datasets/BTech", + category: str = "01", + train_batch_size: int = 32, + eval_batch_size: int = 32, + num_workers: int = 8, + task: TaskType | str = TaskType.SEGMENTATION, + image_size: tuple[int, int] | None = None, + transform: Transform | None = None, + train_transform: Transform | None = None, + eval_transform: Transform | None = None, + test_split_mode: TestSplitMode | str = TestSplitMode.FROM_DIR, + test_split_ratio: float = 0.2, + val_split_mode: ValSplitMode | str = ValSplitMode.SAME_AS_TEST, + val_split_ratio: float = 0.5, + seed: int | None = None, + ) -> None: + super().__init__( + train_batch_size=train_batch_size, + eval_batch_size=eval_batch_size, + num_workers=num_workers, + image_size=image_size, + transform=transform, + train_transform=train_transform, + eval_transform=eval_transform, + test_split_mode=test_split_mode, + test_split_ratio=test_split_ratio, + val_split_mode=val_split_mode, + val_split_ratio=val_split_ratio, + seed=seed, + ) + + self.root = Path(root) + self.category = category + self.task = TaskType(task) + + def _setup(self, _stage: str | None = None) -> None: + self.train_data = BTechDataset( + task=self.task, + transform=self.train_transform, + split=Split.TRAIN, + root=self.root, + category=self.category, + ) + self.test_data = BTechDataset( + task=self.task, + transform=self.eval_transform, + split=Split.TEST, + root=self.root, + category=self.category, + ) + + def prepare_data(self) -> None: + """Download the dataset if not available. + + This method checks if the specified dataset is available in the file system. + If not, it downloads and extracts the dataset into the appropriate directory. + + Example: + Assume the dataset is not available on the file system. + Here's how the directory structure looks before and after calling the + `prepare_data` method: + + Before: + + .. code-block:: bash + + $ tree datasets + datasets + ├── dataset1 + └── dataset2 + + Calling the method: + + .. code-block:: python + + >> datamodule = BTech(root="./datasets/BTech", category="01") + >> datamodule.prepare_data() + + After: + + .. code-block:: bash + + $ tree datasets + datasets + ├── dataset1 + ├── dataset2 + └── BTech + ├── 01 + ├── 02 + └── 03 + """ + if (self.root / self.category).is_dir(): + logger.info("Found the dataset.") + else: + download_and_extract(self.root.parent, DOWNLOAD_INFO) + + # rename folder and convert images + logger.info("Renaming the dataset directory") + shutil.move(src=str(self.root.parent / "BTech_Dataset_transformed"), dst=str(self.root)) + logger.info("Convert the bmp formats to png to have consistent image extensions") + for filename in tqdm(self.root.glob("**/*.bmp"), desc="Converting bmp to png"): + image = cv2.imread(str(filename)) + cv2.imwrite(str(filename.with_suffix(".png")), image) + filename.unlink() diff --git a/anomalib/data/image/folder.py b/anomalib/data/image/folder.py new file mode 100644 index 0000000000000000000000000000000000000000..61f853e8c2460495f1f5f56a209a54364ae8e4aa --- /dev/null +++ b/anomalib/data/image/folder.py @@ -0,0 +1,478 @@ +"""Custom Folder Dataset. + +This script creates a custom dataset from a folder. +""" + +# Copyright (C) 2022-2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +from collections.abc import Sequence +from pathlib import Path + +from pandas import DataFrame +from torchvision.transforms.v2 import Transform + +from anomalib import TaskType +from anomalib.data.base import AnomalibDataModule, AnomalibDataset +from anomalib.data.errors import MisMatchError +from anomalib.data.utils import ( + DirType, + LabelName, + Split, + TestSplitMode, + ValSplitMode, +) +from anomalib.data.utils.path import _prepare_files_labels, validate_and_resolve_path + + +def make_folder_dataset( + normal_dir: str | Path | Sequence[str | Path], + root: str | Path | None = None, + abnormal_dir: str | Path | Sequence[str | Path] | None = None, + normal_test_dir: str | Path | Sequence[str | Path] | None = None, + mask_dir: str | Path | Sequence[str | Path] | None = None, + split: str | Split | None = None, + extensions: tuple[str, ...] | None = None, +) -> DataFrame: + """Make Folder Dataset. + + Args: + normal_dir (str | Path | Sequence): Path to the directory containing normal images. + root (str | Path | None): Path to the root directory of the dataset. + Defaults to ``None``. + abnormal_dir (str | Path | Sequence | None, optional): Path to the directory containing abnormal images. + Defaults to ``None``. + normal_test_dir (str | Path | Sequence | None, optional): Path to the directory containing normal images for + the test dataset. Normal test images will be a split of `normal_dir` if `None`. + Defaults to ``None``. + mask_dir (str | Path | Sequence | None, optional): Path to the directory containing the mask annotations. + Defaults to ``None``. + split (str | Split | None, optional): Dataset split (ie., Split.FULL, Split.TRAIN or Split.TEST). + Defaults to ``None``. + extensions (tuple[str, ...] | None, optional): Type of the image extensions to read from the directory. + Defaults to ``None``. + + Returns: + DataFrame: an output dataframe containing samples for the requested split (ie., train or test). + + Examples: + Assume that we would like to use this ``make_folder_dataset`` to create a dataset from a folder. + We could then create the dataset as follows, + + .. code-block:: python + + folder_df = make_folder_dataset( + normal_dir=dataset_root / "good", + abnormal_dir=dataset_root / "crack", + split="train", + ) + folder_df.head() + + .. code-block:: bash + + image_path label label_index mask_path split + 0 ./toy/good/00.jpg DirType.NORMAL 0 Split.TRAIN + 1 ./toy/good/01.jpg DirType.NORMAL 0 Split.TRAIN + 2 ./toy/good/02.jpg DirType.NORMAL 0 Split.TRAIN + 3 ./toy/good/03.jpg DirType.NORMAL 0 Split.TRAIN + 4 ./toy/good/04.jpg DirType.NORMAL 0 Split.TRAIN + """ + + def _resolve_path_and_convert_to_list(path: str | Path | Sequence[str | Path] | None) -> list[Path]: + """Convert path to list of paths. + + Args: + path (str | Path | Sequence | None): Path to replace with Sequence[str | Path]. + + Examples: + >>> _resolve_path_and_convert_to_list("dir") + [Path("path/to/dir")] + >>> _resolve_path_and_convert_to_list(["dir1", "dir2"]) + [Path("path/to/dir1"), Path("path/to/dir2")] + + Returns: + list[Path]: The result of path replaced by Sequence[str | Path]. + """ + if isinstance(path, Sequence) and not isinstance(path, str): + return [validate_and_resolve_path(dir_path, root) for dir_path in path] + return [validate_and_resolve_path(path, root)] if path is not None else [] + + # All paths are changed to the List[Path] type and used. + normal_dir = _resolve_path_and_convert_to_list(normal_dir) + abnormal_dir = _resolve_path_and_convert_to_list(abnormal_dir) + normal_test_dir = _resolve_path_and_convert_to_list(normal_test_dir) + mask_dir = _resolve_path_and_convert_to_list(mask_dir) + if len(normal_dir) == 0: + msg = "A folder location must be provided in normal_dir." + raise ValueError(msg) + + filenames = [] + labels = [] + dirs = {DirType.NORMAL: normal_dir} + + if abnormal_dir: + dirs[DirType.ABNORMAL] = abnormal_dir + + if normal_test_dir: + dirs[DirType.NORMAL_TEST] = normal_test_dir + + if mask_dir: + dirs[DirType.MASK] = mask_dir + + for dir_type, paths in dirs.items(): + for path in paths: + filename, label = _prepare_files_labels(path, dir_type, extensions) + filenames += filename + labels += label + + samples = DataFrame({"image_path": filenames, "label": labels}) + samples = samples.sort_values(by="image_path", ignore_index=True) + + # Create label index for normal (0) and abnormal (1) images. + samples.loc[ + (samples.label == DirType.NORMAL) | (samples.label == DirType.NORMAL_TEST), + "label_index", + ] = LabelName.NORMAL + samples.loc[(samples.label == DirType.ABNORMAL), "label_index"] = LabelName.ABNORMAL + samples.label_index = samples.label_index.astype("Int64") + + # If a path to mask is provided, add it to the sample dataframe. + + if len(mask_dir) > 0 and len(abnormal_dir) > 0: + samples.loc[samples.label == DirType.ABNORMAL, "mask_path"] = samples.loc[ + samples.label == DirType.MASK + ].image_path.to_numpy() + samples["mask_path"] = samples["mask_path"].fillna("") + samples = samples.astype({"mask_path": "str"}) + + # make sure all every rgb image has a corresponding mask image. + if not ( + samples.loc[samples.label_index == LabelName.ABNORMAL] + .apply(lambda x: Path(x.image_path).stem in Path(x.mask_path).stem, axis=1) + .all() + ): + msg = """Mismatch between anomalous images and mask images. Make sure the mask files " + "folder follow the same naming convention as the anomalous images in the dataset " + "(e.g. image: '000.png', mask: '000.png').""" + raise MisMatchError(msg) + + else: + samples["mask_path"] = "" + + # remove all the rows with temporal image samples that have already been assigned + samples = samples.loc[ + (samples.label == DirType.NORMAL) | (samples.label == DirType.ABNORMAL) | (samples.label == DirType.NORMAL_TEST) + ] + + # Ensure the pathlib objects are converted to str. + # This is because torch dataloader doesn't like pathlib. + samples = samples.astype({"image_path": "str"}) + + # Create train/test split. + # By default, all the normal samples are assigned as train. + # and all the abnormal samples are test. + samples.loc[(samples.label == DirType.NORMAL), "split"] = Split.TRAIN + samples.loc[(samples.label == DirType.ABNORMAL) | (samples.label == DirType.NORMAL_TEST), "split"] = Split.TEST + + # Get the data frame for the split. + if split: + samples = samples[samples.split == split] + samples = samples.reset_index(drop=True) + + return samples + + +class FolderDataset(AnomalibDataset): + """Folder dataset. + + This class is used to create a dataset from a folder. The class utilizes the Torch Dataset class. + + Args: + name (str): Name of the dataset. This is used to name the datamodule, especially when logging/saving. + task (TaskType): Task type. (``classification``, ``detection`` or ``segmentation``). + transform (Transform, optional): Transforms that should be applied to the input images. + Defaults to ``None``. + normal_dir (str | Path | Sequence): Path to the directory containing normal images. + root (str | Path | None): Root folder of the dataset. + Defaults to ``None``. + abnormal_dir (str | Path | Sequence | None, optional): Path to the directory containing abnormal images. + Defaults to ``None``. + normal_test_dir (str | Path | Sequence | None, optional): Path to the directory containing + normal images for the test dataset. + Defaults to ``None``. + mask_dir (str | Path | Sequence | None, optional): Path to the directory containing + the mask annotations. + Defaults to ``None``. + split (str | Split | None): Fixed subset split that follows from folder structure on file system. + Choose from [Split.FULL, Split.TRAIN, Split.TEST] + Defaults to ``None``. + extensions (tuple[str, ...] | None, optional): Type of the image extensions to read from the directory. + Defaults to ``None``. + + Raises: + ValueError: When task is set to classification and `mask_dir` is provided. When `mask_dir` is + provided, `task` should be set to `segmentation`. + + Examples: + Assume that we would like to use this ``FolderDataset`` to create a dataset from a folder for a classification + task. We could first create the transforms, + + >>> from anomalib.data.utils import InputNormalizationMethod, get_transforms + >>> transform = get_transforms(image_size=256, normalization=InputNormalizationMethod.NONE) + + We could then create the dataset as follows, + + .. code-block:: python + + folder_dataset_classification_train = FolderDataset( + normal_dir=dataset_root / "good", + abnormal_dir=dataset_root / "crack", + split="train", + transform=transform, + task=TaskType.CLASSIFICATION, + ) + + """ + + def __init__( + self, + name: str, + task: TaskType, + normal_dir: str | Path | Sequence[str | Path], + transform: Transform | None = None, + root: str | Path | None = None, + abnormal_dir: str | Path | Sequence[str | Path] | None = None, + normal_test_dir: str | Path | Sequence[str | Path] | None = None, + mask_dir: str | Path | Sequence[str | Path] | None = None, + split: str | Split | None = None, + extensions: tuple[str, ...] | None = None, + ) -> None: + super().__init__(task, transform) + + self._name = name + self.split = split + self.root = root + self.normal_dir = normal_dir + self.abnormal_dir = abnormal_dir + self.normal_test_dir = normal_test_dir + self.mask_dir = mask_dir + self.extensions = extensions + + self.samples = make_folder_dataset( + root=self.root, + normal_dir=self.normal_dir, + abnormal_dir=self.abnormal_dir, + normal_test_dir=self.normal_test_dir, + mask_dir=self.mask_dir, + split=self.split, + extensions=self.extensions, + ) + + @property + def name(self) -> str: + """Name of the dataset. + + Folder dataset overrides the name property to provide a custom name. + """ + return self._name + + +class Folder(AnomalibDataModule): + """Folder DataModule. + + Args: + name (str): Name of the dataset. This is used to name the datamodule, especially when logging/saving. + normal_dir (str | Path | Sequence): Name of the directory containing normal images. + root (str | Path | None): Path to the root folder containing normal and abnormal dirs. + Defaults to ``None``. + abnormal_dir (str | Path | None | Sequence): Name of the directory containing abnormal images. + Defaults to ``None``. + normal_test_dir (str | Path | Sequence | None, optional): Path to the directory containing + normal images for the test dataset. + Defaults to ``None``. + mask_dir (str | Path | Sequence | None, optional): Path to the directory containing + the mask annotations. + Defaults to ``None``. + normal_split_ratio (float, optional): Ratio to split normal training images and add to the + test set in case test set doesn't contain any normal images. + Defaults to 0.2. + extensions (tuple[str, ...] | None, optional): Type of the image extensions to read from the + directory. + Defaults to ``None``. + train_batch_size (int, optional): Training batch size. + Defaults to ``32``. + eval_batch_size (int, optional): Validation, test and predict batch size. + Defaults to ``32``. + num_workers (int, optional): Number of workers. + Defaults to ``8``. + task (TaskType, optional): Task type. Could be ``classification``, ``detection`` or ``segmentation``. + Defaults to ``segmentation``. + image_size (tuple[int, int], optional): Size to which input images should be resized. + Defaults to ``None``. + transform (Transform, optional): Transforms that should be applied to the input images. + Defaults to ``None``. + train_transform (Transform, optional): Transforms that should be applied to the input images during training. + Defaults to ``None``. + eval_transform (Transform, optional): Transforms that should be applied to the input images during evaluation. + Defaults to ``None``. + test_split_mode (TestSplitMode): Setting that determines how the testing subset is obtained. + Defaults to ``TestSplitMode.FROM_DIR``. + test_split_ratio (float): Fraction of images from the train set that will be reserved for testing. + Defaults to ``0.2``. + val_split_mode (ValSplitMode): Setting that determines how the validation subset is obtained. + Defaults to ``ValSplitMode.FROM_TEST``. + val_split_ratio (float): Fraction of train or test images that will be reserved for validation. + Defaults to ``0.5``. + seed (int | None, optional): Seed used during random subset splitting. + Defaults to ``None``. + + Examples: + The following code demonstrates how to use the ``Folder`` datamodule. Assume that the dataset is structured + as follows: + + .. code-block:: bash + + $ tree sample_dataset + sample_dataset + ├── colour + │ ├── 00.jpg + │ ├── ... + │ └── x.jpg + ├── crack + │ ├── 00.jpg + │ ├── ... + │ └── y.jpg + ├── good + │ ├── ... + │ └── z.jpg + ├── LICENSE + └── mask + ├── colour + │ ├── ... + │ └── x.jpg + └── crack + ├── ... + └── y.jpg + + .. code-block:: python + + folder_datamodule = Folder( + root=dataset_root, + normal_dir="good", + abnormal_dir="crack", + task=TaskType.SEGMENTATION, + mask_dir=dataset_root / "mask" / "crack", + image_size=256, + normalization=InputNormalizationMethod.NONE, + ) + folder_datamodule.setup() + + To access the training images, + + .. code-block:: python + + >> i, data = next(enumerate(folder_datamodule.train_dataloader())) + >> print(data.keys(), data["image"].shape) + + To access the test images, + + .. code-block:: python + + >> i, data = next(enumerate(folder_datamodule.test_dataloader())) + >> print(data.keys(), data["image"].shape) + """ + + def __init__( + self, + name: str, + normal_dir: str | Path | Sequence[str | Path], + root: str | Path | None = None, + abnormal_dir: str | Path | Sequence[str | Path] | None = None, + normal_test_dir: str | Path | Sequence[str | Path] | None = None, + mask_dir: str | Path | Sequence[str | Path] | None = None, + normal_split_ratio: float = 0.2, + extensions: tuple[str] | None = None, + train_batch_size: int = 32, + eval_batch_size: int = 32, + num_workers: int = 8, + task: TaskType | str = TaskType.SEGMENTATION, + image_size: tuple[int, int] | None = None, + transform: Transform | None = None, + train_transform: Transform | None = None, + eval_transform: Transform | None = None, + test_split_mode: TestSplitMode | str = TestSplitMode.FROM_DIR, + test_split_ratio: float = 0.2, + val_split_mode: ValSplitMode | str = ValSplitMode.FROM_TEST, + val_split_ratio: float = 0.5, + seed: int | None = None, + ) -> None: + self._name = name + self.root = root + self.normal_dir = normal_dir + self.abnormal_dir = abnormal_dir + self.normal_test_dir = normal_test_dir + self.mask_dir = mask_dir + self.task = TaskType(task) + self.extensions = extensions + test_split_mode = TestSplitMode(test_split_mode) + val_split_mode = ValSplitMode(val_split_mode) + super().__init__( + train_batch_size=train_batch_size, + eval_batch_size=eval_batch_size, + num_workers=num_workers, + test_split_mode=test_split_mode, + test_split_ratio=test_split_ratio, + val_split_mode=val_split_mode, + val_split_ratio=val_split_ratio, + image_size=image_size, + transform=transform, + train_transform=train_transform, + eval_transform=eval_transform, + seed=seed, + ) + + if task == TaskType.SEGMENTATION and test_split_mode == TestSplitMode.FROM_DIR and mask_dir is None: + msg = ( + f"Segmentation task requires mask directory if test_split_mode is {test_split_mode}. " + "You could set test_split_mode to {TestSplitMode.NONE} or provide a mask directory." + ) + raise ValueError( + msg, + ) + + self.normal_split_ratio = normal_split_ratio + + def _setup(self, _stage: str | None = None) -> None: + self.train_data = FolderDataset( + name=self.name, + task=self.task, + transform=self.train_transform, + split=Split.TRAIN, + root=self.root, + normal_dir=self.normal_dir, + abnormal_dir=self.abnormal_dir, + normal_test_dir=self.normal_test_dir, + mask_dir=self.mask_dir, + extensions=self.extensions, + ) + + self.test_data = FolderDataset( + name=self.name, + task=self.task, + transform=self.eval_transform, + split=Split.TEST, + root=self.root, + normal_dir=self.normal_dir, + abnormal_dir=self.abnormal_dir, + normal_test_dir=self.normal_test_dir, + mask_dir=self.mask_dir, + extensions=self.extensions, + ) + + @property + def name(self) -> str: + """Name of the datamodule. + + Folder datamodule overrides the name property to provide a custom name. + """ + return self._name diff --git a/anomalib/data/image/kolektor.py b/anomalib/data/image/kolektor.py new file mode 100644 index 0000000000000000000000000000000000000000..049c770c45bd93f9737790ee2d60ee029d81482a --- /dev/null +++ b/anomalib/data/image/kolektor.py @@ -0,0 +1,342 @@ +"""Kolektor Surface-Defect Dataset (CC BY-NC-SA 4.0). + +Description: + This script provides a PyTorch Dataset, DataLoader, and PyTorch Lightning DataModule for the Kolektor + Surface-Defect dataset. The dataset can be accessed at `Kolektor Surface-Defect Dataset `_. + +License: + The Kolektor Surface-Defect dataset is released under the Creative Commons Attribution-NonCommercial-ShareAlike + 4.0 International License (CC BY-NC-SA 4.0). For more details, visit + `Creative Commons License `_. + +Reference: + Tabernik, Domen, Samo Šela, Jure Skvarč, and Danijel Skočaj. "Segmentation-based deep-learning approach + for surface-defect detection." Journal of Intelligent Manufacturing 31, no. 3 (2020): 759-776. +""" + +# Copyright (C) 2023-2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +import logging +from pathlib import Path + +import numpy as np +from cv2 import imread +from pandas import DataFrame +from sklearn.model_selection import train_test_split +from torchvision.transforms.v2 import Transform + +from anomalib import TaskType +from anomalib.data.base import AnomalibDataModule, AnomalibDataset +from anomalib.data.errors import MisMatchError +from anomalib.data.utils import ( + DownloadInfo, + Split, + TestSplitMode, + ValSplitMode, + download_and_extract, + validate_path, +) + +__all__ = ["Kolektor", "KolektorDataset", "make_kolektor_dataset"] + +logger = logging.getLogger(__name__) + +DOWNLOAD_INFO = DownloadInfo( + name="kolektor", + url="https://go.vicos.si/kolektorsdd", + hashsum="65dc621693418585de9c4467d1340ea7958a6181816f0dc2883a1e8b61f9d4dc", + filename="KolektorSDD.zip", +) + + +def is_mask_anomalous(path: str) -> int: + """Check if a mask shows defects. + + Args: + path (str): Path to the mask file. + + Returns: + int: 1 if the mask shows defects, 0 otherwise. + + Example: + Assume that the following image is a mask for a defective image. + Then the function will return 1. + + >>> from anomalib.data.image.kolektor import is_mask_anomalous + >>> path = './KolektorSDD/kos01/Part0_label.bmp' + >>> is_mask_anomalous(path) + 1 + """ + img_arr = imread(path) + if np.all(img_arr == 0): + return 0 + return 1 + + +def make_kolektor_dataset( + root: str | Path, + train_split_ratio: float = 0.8, + split: str | Split | None = None, +) -> DataFrame: + """Create Kolektor samples by parsing the Kolektor data file structure. + + The files are expected to follow this structure: + - Image files: `path/to/dataset/item/image_filename.jpg`, `path/to/dataset/kos01/Part0.jpg` + - Mask files: `path/to/dataset/item/mask_filename.bmp`, `path/to/dataset/kos01/Part0_label.bmp` + + This function creates a DataFrame to store the parsed information in the following format: + + +---+-------------------+--------+-------+---------+-----------------------+------------------------+-------------+ + | | path | item | split | label | image_path | mask_path | label_index | + +---+-------------------+--------+-------+---------+-----------------------+------------------------+-------------+ + | 0 | KolektorSDD | kos01 | test | Bad | /path/to/image_file | /path/to/mask_file | 1 | + +---+-------------------+--------+-------+---------+-----------------------+------------------------+-------------+ + + Args: + root (Path): Path to the dataset. + train_split_ratio (float, optional): Ratio for splitting good images into train/test sets. + Defaults to ``0.8``. + split (str | Split | None, optional): Dataset split (either 'train' or 'test'). + Defaults to ``None``. + + Returns: + pandas.DataFrame: An output DataFrame containing the samples of the dataset. + + Example: + The following example shows how to get training samples from the Kolektor Dataset: + + >>> from pathlib import Path + >>> root = Path('./KolektorSDD/') + >>> samples = create_kolektor_samples(root, train_split_ratio=0.8) + >>> samples.head() + path item split label image_path mask_path label_index + 0 KolektorSDD kos01 train Good KolektorSDD/kos01/Part0.jpg KolektorSDD/kos01/Part0_label.bmp 0 + 1 KolektorSDD kos01 train Good KolektorSDD/kos01/Part1.jpg KolektorSDD/kos01/Part1_label.bmp 0 + 2 KolektorSDD kos01 train Good KolektorSDD/kos01/Part2.jpg KolektorSDD/kos01/Part2_label.bmp 0 + 3 KolektorSDD kos01 test Good KolektorSDD/kos01/Part3.jpg KolektorSDD/kos01/Part3_label.bmp 0 + 4 KolektorSDD kos01 train Good KolektorSDD/kos01/Part4.jpg KolektorSDD/kos01/Part4_label.bmp 0 + """ + root = validate_path(root) + + # Get list of images and masks + samples_list = [(str(root),) + f.parts[-2:] for f in root.glob(r"**/*") if f.suffix == ".jpg"] + masks_list = [(str(root),) + f.parts[-2:] for f in root.glob(r"**/*") if f.suffix == ".bmp"] + + if not samples_list: + msg = f"Found 0 images in {root}" + raise RuntimeError(msg) + + # Create dataframes + samples = DataFrame(samples_list, columns=["path", "item", "image_path"]) + masks = DataFrame(masks_list, columns=["path", "item", "image_path"]) + + # Modify image_path column by converting to absolute path + samples["image_path"] = samples.path + "/" + samples.item + "/" + samples.image_path + masks["image_path"] = masks.path + "/" + masks.item + "/" + masks.image_path + + # Sort samples by image path + samples = samples.sort_values(by="image_path", ignore_index=True) + masks = masks.sort_values(by="image_path", ignore_index=True) + + # Add mask paths for sample images + samples["mask_path"] = masks.image_path.to_numpy() + + # Use is_good func to configure the label_index + samples["label_index"] = samples["mask_path"].apply(is_mask_anomalous) + samples.label_index = samples.label_index.astype(int) + + # Use label indexes to label data + samples.loc[(samples.label_index == 0), "label"] = "Good" + samples.loc[(samples.label_index == 1), "label"] = "Bad" + + # Add all 'Bad' samples to test set + samples.loc[(samples.label == "Bad"), "split"] = "test" + + # Divide 'good' images to train/test on 0.8/0.2 ratio + train_samples, test_samples = train_test_split( + samples[samples.label == "Good"], + train_size=train_split_ratio, + random_state=42, + ) + samples.loc[train_samples.index, "split"] = "train" + samples.loc[test_samples.index, "split"] = "test" + + # Reorder columns + samples = samples[["path", "item", "split", "label", "image_path", "mask_path", "label_index"]] + + # assert that the right mask files are associated with the right test images + if not ( + samples.loc[samples.label_index == 1] + .apply(lambda x: Path(x.image_path).stem in Path(x.mask_path).stem, axis=1) + .all() + ): + msg = """Mismatch between anomalous images and ground truth masks. Make sure the mask files + follow the same naming convention as the anomalous images in the dataset + (e.g. image: 'Part0.jpg', mask: 'Part0_label.bmp').""" + raise MisMatchError(msg) + + # Get the dataframe for the required split + if split: + samples = samples[samples.split == split].reset_index(drop=True) + + return samples + + +class KolektorDataset(AnomalibDataset): + """Kolektor dataset class. + + Args: + task (TaskType): Task type, ``classification``, ``detection`` or ``segmentation`` + root (Path | str): Path to the root of the dataset + Defaults to ``./datasets/kolektor``. + transform (Transform, optional): Transforms that should be applied to the input images. + Defaults to ``None``. + split (str | Split | None): Split of the dataset, usually Split.TRAIN or Split.TEST + Defaults to ``None``. + """ + + def __init__( + self, + task: TaskType, + root: Path | str = "./datasets/kolektor", + transform: Transform | None = None, + split: str | Split | None = None, + ) -> None: + super().__init__(task=task, transform=transform) + + self.root = root + self.split = split + self.samples = make_kolektor_dataset(self.root, train_split_ratio=0.8, split=self.split) + + +class Kolektor(AnomalibDataModule): + """Kolektor Datamodule. + + Args: + root (Path | str): Path to the root of the dataset + train_batch_size (int, optional): Training batch size. + Defaults to ``32``. + eval_batch_size (int, optional): Test batch size. + Defaults to ``32``. + num_workers (int, optional): Number of workers. + Defaults to ``8``. + task TaskType): Task type, 'classification', 'detection' or 'segmentation' + Defaults to ``TaskType.SEGMENTATION``. + image_size (tuple[int, int], optional): Size to which input images should be resized. + Defaults to ``None``. + transform (Transform, optional): Transforms that should be applied to the input images. + Defaults to ``None``. + train_transform (Transform, optional): Transforms that should be applied to the input images during training. + Defaults to ``None``. + eval_transform (Transform, optional): Transforms that should be applied to the input images during evaluation. + Defaults to ``None``. + test_split_mode (TestSplitMode): Setting that determines how the testing subset is obtained. + Defaults to ``TestSplitMode.FROM_DIR`` + test_split_ratio (float): Fraction of images from the train set that will be reserved for testing. + Defaults to ``0.2`` + val_split_mode (ValSplitMode): Setting that determines how the validation subset is obtained. + Defaults to ``ValSplitMode.SAME_AS_TEST`` + val_split_ratio (float): Fraction of train or test images that will be reserved for validation. + Defaults to ``0.5`` + seed (int | None, optional): Seed which may be set to a fixed value for reproducibility. + Defaults to ``None``. + """ + + def __init__( + self, + root: Path | str = "./datasets/kolektor", + train_batch_size: int = 32, + eval_batch_size: int = 32, + num_workers: int = 8, + task: TaskType | str = TaskType.SEGMENTATION, + image_size: tuple[int, int] | None = None, + transform: Transform | None = None, + train_transform: Transform | None = None, + eval_transform: Transform | None = None, + test_split_mode: TestSplitMode | str = TestSplitMode.FROM_DIR, + test_split_ratio: float = 0.2, + val_split_mode: ValSplitMode | str = ValSplitMode.SAME_AS_TEST, + val_split_ratio: float = 0.5, + seed: int | None = None, + ) -> None: + super().__init__( + train_batch_size=train_batch_size, + eval_batch_size=eval_batch_size, + num_workers=num_workers, + image_size=image_size, + transform=transform, + train_transform=train_transform, + eval_transform=eval_transform, + test_split_mode=test_split_mode, + test_split_ratio=test_split_ratio, + val_split_mode=val_split_mode, + val_split_ratio=val_split_ratio, + seed=seed, + ) + + self.task = TaskType(task) + self.root = Path(root) + + def _setup(self, _stage: str | None = None) -> None: + self.train_data = KolektorDataset( + task=self.task, + transform=self.train_transform, + split=Split.TRAIN, + root=self.root, + ) + self.test_data = KolektorDataset( + task=self.task, + transform=self.eval_transform, + split=Split.TEST, + root=self.root, + ) + + def prepare_data(self) -> None: + """Download the dataset if not available. + + This method checks if the specified dataset is available in the file system. + If not, it downloads and extracts the dataset into the appropriate directory. + + Example: + Assume the dataset is not available on the file system. + Here's how the directory structure looks before and after calling the + `prepare_data` method: + + Before: + + .. code-block:: bash + + $ tree datasets + datasets + ├── dataset1 + └── dataset2 + + Calling the method: + + .. code-block:: python + + >> datamodule = Kolektor(root="./datasets/kolektor") + >> datamodule.prepare_data() + + After: + + .. code-block:: bash + + $ tree datasets + datasets + ├── dataset1 + ├── dataset2 + └── kolektor + ├── kolektorsdd + ├── kos01 + ├── ... + └── kos50 + ├── Part0.jpg + ├── Part0_label.bmp + └── ... + """ + if (self.root).is_dir(): + logger.info("Found the dataset.") + else: + download_and_extract(self.root, DOWNLOAD_INFO) diff --git a/anomalib/data/image/mvtec.py b/anomalib/data/image/mvtec.py new file mode 100644 index 0000000000000000000000000000000000000000..c2cdc69755e9b6730934854e5f409008060439ca --- /dev/null +++ b/anomalib/data/image/mvtec.py @@ -0,0 +1,414 @@ +"""MVTec AD Dataset (CC BY-NC-SA 4.0). + +Description: + This script contains PyTorch Dataset, Dataloader and PyTorch Lightning + DataModule for the MVTec AD dataset. If the dataset is not on the file system, + the script downloads and extracts the dataset and create PyTorch data objects. + +License: + MVTec AD dataset is released under the Creative Commons + Attribution-NonCommercial-ShareAlike 4.0 International License + (CC BY-NC-SA 4.0)(https://creativecommons.org/licenses/by-nc-sa/4.0/). + +References: + - Paul Bergmann, Kilian Batzner, Michael Fauser, David Sattlegger, Carsten Steger: + The MVTec Anomaly Detection Dataset: A Comprehensive Real-World Dataset for + Unsupervised Anomaly Detection; in: International Journal of Computer Vision + 129(4):1038-1059, 2021, DOI: 10.1007/s11263-020-01400-4. + + - Paul Bergmann, Michael Fauser, David Sattlegger, Carsten Steger: MVTec AD — + A Comprehensive Real-World Dataset for Unsupervised Anomaly Detection; + in: IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR), + 9584-9592, 2019, DOI: 10.1109/CVPR.2019.00982. +""" + +# Copyright (C) 2022-2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +import logging +from collections.abc import Sequence +from pathlib import Path + +from pandas import DataFrame +from torchvision.transforms.v2 import Transform + +from anomalib import TaskType +from anomalib.data.base import AnomalibDataModule, AnomalibDataset +from anomalib.data.errors import MisMatchError +from anomalib.data.utils import ( + DownloadInfo, + LabelName, + Split, + TestSplitMode, + ValSplitMode, + download_and_extract, + validate_path, +) + +logger = logging.getLogger(__name__) + + +IMG_EXTENSIONS = (".png", ".PNG") + +DOWNLOAD_INFO = DownloadInfo( + name="mvtec", + url="https://www.mydrive.ch/shares/38536/3830184030e49fe74747669442f0f282/download/420938113-1629952094" + "/mvtec_anomaly_detection.tar.xz", + hashsum="cf4313b13603bec67abb49ca959488f7eedce2a9f7795ec54446c649ac98cd3d", +) + +CATEGORIES = ( + "bottle", + "cable", + "capsule", + "carpet", + "grid", + "hazelnut", + "leather", + "metal_nut", + "pill", + "screw", + "tile", + "toothbrush", + "transistor", + "wood", + "zipper", +) + + +def make_mvtec_dataset( + root: str | Path, + split: str | Split | None = None, + extensions: Sequence[str] | None = None, +) -> DataFrame: + """Create MVTec AD samples by parsing the MVTec AD data file structure. + + The files are expected to follow the structure: + path/to/dataset/split/category/image_filename.png + path/to/dataset/ground_truth/category/mask_filename.png + + This function creates a dataframe to store the parsed information based on the following format: + + +---+---------------+-------+---------+---------------+---------------------------------------+-------------+ + | | path | split | label | image_path | mask_path | label_index | + +===+===============+=======+=========+===============+=======================================+=============+ + | 0 | datasets/name | test | defect | filename.png | ground_truth/defect/filename_mask.png | 1 | + +---+---------------+-------+---------+---------------+---------------------------------------+-------------+ + + Args: + root (Path): Path to dataset + split (str | Split | None, optional): Dataset split (ie., either train or test). + Defaults to ``None``. + extensions (Sequence[str] | None, optional): List of file extensions to be included in the dataset. + Defaults to ``None``. + + Examples: + The following example shows how to get training samples from MVTec AD bottle category: + + >>> root = Path('./MVTec') + >>> category = 'bottle' + >>> path = root / category + >>> path + PosixPath('MVTec/bottle') + + >>> samples = make_mvtec_dataset(path, split='train', split_ratio=0.1, seed=0) + >>> samples.head() + path split label image_path mask_path label_index + 0 MVTec/bottle train good MVTec/bottle/train/good/105.png MVTec/bottle/ground_truth/good/105_mask.png 0 + 1 MVTec/bottle train good MVTec/bottle/train/good/017.png MVTec/bottle/ground_truth/good/017_mask.png 0 + 2 MVTec/bottle train good MVTec/bottle/train/good/137.png MVTec/bottle/ground_truth/good/137_mask.png 0 + 3 MVTec/bottle train good MVTec/bottle/train/good/152.png MVTec/bottle/ground_truth/good/152_mask.png 0 + 4 MVTec/bottle train good MVTec/bottle/train/good/109.png MVTec/bottle/ground_truth/good/109_mask.png 0 + + Returns: + DataFrame: an output dataframe containing the samples of the dataset. + """ + if extensions is None: + extensions = IMG_EXTENSIONS + + root = validate_path(root) + samples_list = [(str(root),) + f.parts[-3:] for f in root.glob(r"**/*") if f.suffix in extensions] + if not samples_list: + msg = f"Found 0 images in {root}" + raise RuntimeError(msg) + + samples = DataFrame(samples_list, columns=["path", "split", "label", "image_path"]) + + # Modify image_path column by converting to absolute path + samples["image_path"] = samples.path + "/" + samples.split + "/" + samples.label + "/" + samples.image_path + + # Create label index for normal (0) and anomalous (1) images. + samples.loc[(samples.label == "good"), "label_index"] = LabelName.NORMAL + samples.loc[(samples.label != "good"), "label_index"] = LabelName.ABNORMAL + samples.label_index = samples.label_index.astype(int) + + # separate masks from samples + mask_samples = samples.loc[samples.split == "ground_truth"].sort_values(by="image_path", ignore_index=True) + samples = samples[samples.split != "ground_truth"].sort_values(by="image_path", ignore_index=True) + + # assign mask paths to anomalous test images + samples["mask_path"] = "" + samples.loc[ + (samples.split == "test") & (samples.label_index == LabelName.ABNORMAL), + "mask_path", + ] = mask_samples.image_path.to_numpy() + + # assert that the right mask files are associated with the right test images + abnormal_samples = samples.loc[samples.label_index == LabelName.ABNORMAL] + if ( + len(abnormal_samples) + and not abnormal_samples.apply(lambda x: Path(x.image_path).stem in Path(x.mask_path).stem, axis=1).all() + ): + msg = """Mismatch between anomalous images and ground truth masks. Make sure t + he mask files in 'ground_truth' folder follow the same naming convention as the + anomalous images in the dataset (e.g. image: '000.png', mask: '000.png' or '000_mask.png').""" + raise MisMatchError(msg) + + if split: + samples = samples[samples.split == split].reset_index(drop=True) + + return samples + + +class MVTecDataset(AnomalibDataset): + """MVTec dataset class. + + Args: + task (TaskType): Task type, ``classification``, ``detection`` or ``segmentation``. + root (Path | str): Path to the root of the dataset. + Defaults to ``./datasets/MVTec``. + category (str): Sub-category of the dataset, e.g. 'bottle' + Defaults to ``bottle``. + transform (Transform, optional): Transforms that should be applied to the input images. + Defaults to ``None``. + split (str | Split | None): Split of the dataset, usually Split.TRAIN or Split.TEST + Defaults to ``None``. + + Examples: + .. code-block:: python + + from anomalib.data.image.mvtec import MVTecDataset + from anomalib.data.utils.transforms import get_transforms + + transform = get_transforms(image_size=256) + dataset = MVTecDataset( + task="classification", + transform=transform, + root='./datasets/MVTec', + category='zipper', + ) + dataset.setup() + print(dataset[0].keys()) + # Output: dict_keys(['image_path', 'label', 'image']) + + When the task is segmentation, the dataset will also contain the mask: + + .. code-block:: python + + dataset.task = "segmentation" + dataset.setup() + print(dataset[0].keys()) + # Output: dict_keys(['image_path', 'label', 'image', 'mask_path', 'mask']) + + The image is a torch tensor of shape (C, H, W) and the mask is a torch tensor of shape (H, W). + + .. code-block:: python + + print(dataset[0]["image"].shape, dataset[0]["mask"].shape) + # Output: (torch.Size([3, 256, 256]), torch.Size([256, 256])) + + """ + + def __init__( + self, + task: TaskType, + root: Path | str = "./datasets/MVTec", + category: str = "bottle", + transform: Transform | None = None, + split: str | Split | None = None, + ) -> None: + super().__init__(task=task, transform=transform) + + self.root_category = Path(root) / Path(category) + self.category = category + self.split = split + self.samples = make_mvtec_dataset(self.root_category, split=self.split, extensions=IMG_EXTENSIONS) + + +class MVTec(AnomalibDataModule): + """MVTec Datamodule. + + Args: + root (Path | str): Path to the root of the dataset. + Defaults to ``"./datasets/MVTec"``. + category (str): Category of the MVTec dataset (e.g. "bottle" or "cable"). + Defaults to ``"bottle"``. + train_batch_size (int, optional): Training batch size. + Defaults to ``32``. + eval_batch_size (int, optional): Test batch size. + Defaults to ``32``. + num_workers (int, optional): Number of workers. + Defaults to ``8``. + task TaskType): Task type, 'classification', 'detection' or 'segmentation' + Defaults to ``TaskType.SEGMENTATION``. + image_size (tuple[int, int], optional): Size to which input images should be resized. + Defaults to ``None``. + transform (Transform, optional): Transforms that should be applied to the input images. + Defaults to ``None``. + train_transform (Transform, optional): Transforms that should be applied to the input images during training. + Defaults to ``None``. + eval_transform (Transform, optional): Transforms that should be applied to the input images during evaluation. + Defaults to ``None``. + test_split_mode (TestSplitMode): Setting that determines how the testing subset is obtained. + Defaults to ``TestSplitMode.FROM_DIR``. + test_split_ratio (float): Fraction of images from the train set that will be reserved for testing. + Defaults to ``0.2``. + val_split_mode (ValSplitMode): Setting that determines how the validation subset is obtained. + Defaults to ``ValSplitMode.SAME_AS_TEST``. + val_split_ratio (float): Fraction of train or test images that will be reserved for validation. + Defaults to ``0.5``. + seed (int | None, optional): Seed which may be set to a fixed value for reproducibility. + Defualts to ``None``. + + Examples: + To create an MVTec AD datamodule with default settings: + + >>> datamodule = MVTec() + >>> datamodule.setup() + >>> i, data = next(enumerate(datamodule.train_dataloader())) + >>> data.keys() + dict_keys(['image_path', 'label', 'image', 'mask_path', 'mask']) + + >>> data["image"].shape + torch.Size([32, 3, 256, 256]) + + To change the category of the dataset: + + >>> datamodule = MVTec(category="cable") + + To change the image and batch size: + + >>> datamodule = MVTec(image_size=(512, 512), train_batch_size=16, eval_batch_size=8) + + MVTec AD dataset does not provide a validation set. If you would like + to use a separate validation set, you can use the ``val_split_mode`` and + ``val_split_ratio`` arguments to create a validation set. + + >>> datamodule = MVTec(val_split_mode=ValSplitMode.FROM_TEST, val_split_ratio=0.1) + + This will subsample the test set by 10% and use it as the validation set. + If you would like to create a validation set synthetically that would + not change the test set, you can use the ``ValSplitMode.SYNTHETIC`` option. + + >>> datamodule = MVTec(val_split_mode=ValSplitMode.SYNTHETIC, val_split_ratio=0.2) + + """ + + def __init__( + self, + root: Path | str = "./datasets/MVTec", + category: str = "bottle", + train_batch_size: int = 32, + eval_batch_size: int = 32, + num_workers: int = 8, + task: TaskType | str = TaskType.SEGMENTATION, + image_size: tuple[int, int] | None = None, + transform: Transform | None = None, + train_transform: Transform | None = None, + eval_transform: Transform | None = None, + test_split_mode: TestSplitMode | str = TestSplitMode.FROM_DIR, + test_split_ratio: float = 0.2, + val_split_mode: ValSplitMode | str = ValSplitMode.SAME_AS_TEST, + val_split_ratio: float = 0.5, + seed: int | None = None, + ) -> None: + super().__init__( + train_batch_size=train_batch_size, + eval_batch_size=eval_batch_size, + image_size=image_size, + transform=transform, + train_transform=train_transform, + eval_transform=eval_transform, + num_workers=num_workers, + test_split_mode=test_split_mode, + test_split_ratio=test_split_ratio, + val_split_mode=val_split_mode, + val_split_ratio=val_split_ratio, + seed=seed, + ) + + self.task = TaskType(task) + self.root = Path(root) + self.category = category + + def _setup(self, _stage: str | None = None) -> None: + """Set up the datasets and perform dynamic subset splitting. + + This method may be overridden in subclass for custom splitting behaviour. + + Note: + The stage argument is not used here. This is because, for a given instance of an AnomalibDataModule + subclass, all three subsets are created at the first call of setup(). This is to accommodate the subset + splitting behaviour of anomaly tasks, where the validation set is usually extracted from the test set, and + the test set must therefore be created as early as the `fit` stage. + + """ + self.train_data = MVTecDataset( + task=self.task, + transform=self.train_transform, + split=Split.TRAIN, + root=self.root, + category=self.category, + ) + self.test_data = MVTecDataset( + task=self.task, + transform=self.eval_transform, + split=Split.TEST, + root=self.root, + category=self.category, + ) + + def prepare_data(self) -> None: + """Download the dataset if not available. + + This method checks if the specified dataset is available in the file system. + If not, it downloads and extracts the dataset into the appropriate directory. + + Example: + Assume the dataset is not available on the file system. + Here's how the directory structure looks before and after calling the + `prepare_data` method: + + Before: + + .. code-block:: bash + + $ tree datasets + datasets + ├── dataset1 + └── dataset2 + + Calling the method: + + .. code-block:: python + + >> datamodule = MVTec(root="./datasets/MVTec", category="bottle") + >> datamodule.prepare_data() + + After: + + .. code-block:: bash + + $ tree datasets + datasets + ├── dataset1 + ├── dataset2 + └── MVTec + ├── bottle + ├── ... + └── zipper + """ + if (self.root / self.category).is_dir(): + logger.info("Found the dataset.") + else: + download_and_extract(self.root, DOWNLOAD_INFO) diff --git a/anomalib/data/image/mvtec_loco.py b/anomalib/data/image/mvtec_loco.py new file mode 100644 index 0000000000000000000000000000000000000000..4ef2b4ab8e502c797e8ff7a0689a0178096baef1 --- /dev/null +++ b/anomalib/data/image/mvtec_loco.py @@ -0,0 +1,480 @@ +"""MVTec LOCO AD Dataset (CC BY-NC-SA 4.0). + +Description: + This script contains PyTorch Dataset, Dataloader and PyTorch Lightning + DataModule for the MVTec LOCO AD dataset. If the dataset is not on the file system, + the script downloads and extracts the dataset and create PyTorch data objects. + +License: + MVTec LOCO AD dataset is released under the Creative Commons + Attribution-NonCommercial-ShareAlike 4.0 International License + (CC BY-NC-SA 4.0)(https://creativecommons.org/licenses/by-nc-sa/4.0/). + +References: + - Paul Bergmann, Kilian Batzner, Michael Fauser, David Sattlegger, and Carsten Steger: + Beyond Dents and Scratches: Logical Constraints in Unsupervised Anomaly Detection and Localization; + in: International Journal of Computer Vision (IJCV) 130, 947-969, 2022, DOI: 10.1007/s11263-022-01578-9 +""" + +# Copyright (C) 2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +import logging +from collections.abc import Sequence +from pathlib import Path + +import torch +from pandas import DataFrame +from PIL import Image +from torchvision.transforms.v2 import Transform +from torchvision.transforms.v2.functional import to_image +from torchvision.tv_tensors import Mask + +from anomalib import TaskType +from anomalib.data.base import AnomalibDataModule, AnomalibDataset +from anomalib.data.utils import ( + DownloadInfo, + LabelName, + Split, + TestSplitMode, + ValSplitMode, + download_and_extract, + masks_to_boxes, + read_image, + validate_path, +) + +logger = logging.getLogger(__name__) + + +IMG_EXTENSIONS = (".png", ".PNG") + +DOWNLOAD_INFO = DownloadInfo( + name="mvtec_loco", + url="https://www.mydrive.ch/shares/48237/1b9106ccdfbb09a0c414bd49fe44a14a/download/430647091-1646842701" + "/mvtec_loco_anomaly_detection.tar.xz", + hashsum="9e7c84dba550fd2e59d8e9e231c929c45ba737b6b6a6d3814100f54d63aae687", +) + +CATEGORIES = ( + "breakfast_box", + "juice_bottle", + "pushpins", + "screw_bag", + "splicing_connectors", +) + + +def make_mvtec_loco_dataset( + root: str | Path, + split: str | Split | None = None, + extensions: Sequence[str] = IMG_EXTENSIONS, +) -> DataFrame: + """Create MVTec LOCO AD samples by parsing the original MVTec LOCO AD data file structure. + + The files are expected to follow the structure: + path/to/dataset/split/category/image_filename.png + path/to/dataset/ground_truth/category/image_filename/000.png + + where there can be multiple ground-truth masks for the corresponding anomalous images. + + This function creates a dataframe to store the parsed information based on the following format: + + +---+---------------+-------+---------+-------------------------+-----------------------------+-------------+ + | | path | split | label | image_path | mask_path | label_index | + +===+===============+=======+=========+===============+=======================================+=============+ + | 0 | datasets/name | test | defect | path/to/image/file.png | [path/to/masks/file.png] | 1 | + +---+---------------+-------+---------+-------------------------+-----------------------------+-------------+ + + Args: + root (str | Path): Path to dataset + split (str | Split | None): Dataset split (ie., either train or test). + Defaults to ``None``. + extensions (Sequence[str]): List of file extensions to be included in the dataset. + Defaults to ``None``. + + Returns: + DataFrame: an output dataframe containing the samples of the dataset. + + Examples: + The following example shows how to get test samples from MVTec LOCO AD pushpins category: + + >>> root = Path('./MVTec_LOCO') + >>> category = 'pushpins' + >>> path = root / category + >>> samples = make_mvtec_loco_dataset(path, split='test') + """ + root = validate_path(root) + + # Retrieve the image and mask files + samples_list = [] + for f in root.glob("**/*"): + if f.suffix in extensions: + parts = f.parts + # 'ground_truth' and non 'ground_truth' path have a different structure + if "ground_truth" not in parts: + split_folder, label_folder, image_file = parts[-3:] + image_path = f"{root}/{split_folder}/{label_folder}/{image_file}" + samples_list.append((str(root), split_folder, label_folder, "", image_path)) + else: + split_folder, label_folder, image_folder, image_file = parts[-4:] + image_path = f"{root}/{split_folder}/{label_folder}/{image_folder}/{image_file}" + samples_list.append((str(root), split_folder, label_folder, image_folder, image_path)) + + if not samples_list: + msg = f"Found 0 images in {root}" + raise RuntimeError(msg) + + samples = DataFrame(samples_list, columns=["path", "split", "label", "image_folder", "image_path"]) + + # Replace validation to Split.VAL.value in the split column + samples["split"] = samples["split"].replace("validation", Split.VAL.value) + + # Create label index for normal (0) and anomalous (1) images. + samples.loc[(samples.label == "good"), "label_index"] = LabelName.NORMAL + samples.loc[(samples.label != "good"), "label_index"] = LabelName.ABNORMAL + samples.label_index = samples.label_index.astype(int) + + # separate ground-truth masks from samples + mask_samples = samples.loc[samples.split == "ground_truth"].sort_values(by="image_path", ignore_index=True) + samples = samples[samples.split != "ground_truth"].sort_values(by="image_path", ignore_index=True) + + # Group masks and aggregate the path into a list + mask_samples = ( + mask_samples.groupby(["path", "split", "label", "image_folder"])["image_path"] + .agg(list) + .reset_index() + .rename(columns={"image_path": "mask_path"}) + ) + + # assign mask paths to anomalous test images + samples["mask_path"] = "" + samples.loc[ + (samples.split == "test") & (samples.label_index == LabelName.ABNORMAL), + "mask_path", + ] = mask_samples.mask_path.to_numpy() + + # validate that the right mask files are associated with the right test images + if len(samples.loc[samples.label_index == LabelName.ABNORMAL]): + image_stems = samples.loc[samples.label_index == LabelName.ABNORMAL]["image_path"].apply(lambda x: Path(x).stem) + mask_parent_stems = samples.loc[samples.label_index == LabelName.ABNORMAL]["mask_path"].apply( + lambda x: {Path(mask_path).parent.stem for mask_path in x}, + ) + + if not all( + next(iter(mask_stems)) == image_stem + for image_stem, mask_stems in zip(image_stems, mask_parent_stems, strict=True) + ): + error_message = ( + "Mismatch between anomalous images and ground truth masks. " + "Make sure the parent folder of the mask files in 'ground_truth' folder " + "follows the same naming convention as the anomalous images in the dataset " + "(e.g., image: '005.png', mask: '005/000.png')." + ) + raise ValueError(error_message) + + if split: + samples = samples[samples.split == split].reset_index(drop=True) + + return samples + + +class MVTecLocoDataset(AnomalibDataset): + """MVTec LOCO dataset class. + + Args: + task (TaskType): Task type, ``classification``, ``detection`` or ``segmentation``. + root (Path | str): Path to the root of the dataset. + Defaults to ``./datasets/MVTec_LOCO``. + category (str): Sub-category of the dataset, e.g. 'breakfast_box' + Defaults to ``breakfast_box``. + transform (Transform, optional): Transforms that should be applied to the input images. + Defaults to ``None``. + split (str | Split | None): Split of the dataset, Split.TRAIN, Split.VAL, or Split.TEST + Defaults to ``None``. + + Examples: + .. code-block:: python + + from anomalib.data.image.mvtec_loco import MVTecLocoDataset + from anomalib.data.utils.transforms import get_transforms + from torchvision.transforms.v2 import Resize + + transform = Resize((256, 256)) + dataset = MVTecLocoDataset( + task="classification", + transform=transform, + root='./datasets/MVTec_LOCO', + category='breakfast_box', + ) + dataset.setup() + print(dataset[0].keys()) + # Output: dict_keys(['image_path', 'label', 'image']) + + When the task is segmentation, the dataset will also contain the mask: + + .. code-block:: python + + dataset.task = "segmentation" + dataset.setup() + print(dataset[0].keys()) + # Output: dict_keys(['image_path', 'label', 'image', 'mask_path', 'mask']) + + The image is a torch tensor of shape (C, H, W) and the mask is a torch tensor of shape (H, W). + + .. code-block:: python + + print(dataset[0]["image"].shape, dataset[0]["mask"].shape) + # Output: (torch.Size([3, 256, 256]), torch.Size([256, 256])) + """ + + def __init__( + self, + task: TaskType, + root: Path | str = "./datasets/MVTec_LOCO", + category: str = "breakfast_box", + transform: Transform | None = None, + split: str | Split | None = None, + ) -> None: + super().__init__(task=task, transform=transform) + + self.root_category = Path(root) / category + self.split = split + self.samples = make_mvtec_loco_dataset( + self.root_category, + split=self.split, + extensions=IMG_EXTENSIONS, + ) + + @staticmethod + def _read_mask(mask_path: str | Path) -> Mask: + image = Image.open(mask_path).convert("L") + return Mask(to_image(image).squeeze(), dtype=torch.uint8) + + def __getitem__(self, index: int) -> dict[str, str | torch.Tensor]: + """Get dataset item for the index ``index``. + + This method is mostly based on the super class implementation, with some different as follows: + - Using 'torch.where' to make sure the 'mask' in the return item is binarized + - An additional 'masks' is added, the non-binary masks with original size for the SPRO metric calculation + Args: + index (int): Index to get the item. + + Returns: + dict[str, str | torch.Tensor]: Dict of image tensor during training. Otherwise, Dict containing image path, + target path, image tensor, label and transformed bounding box. + """ + image_path = self.samples.iloc[index].image_path + mask_path = self.samples.iloc[index].mask_path + label_index = self.samples.iloc[index].label_index + + image = read_image(image_path, as_tensor=True) + item = {"image_path": image_path, "label": label_index} + + if self.task == TaskType.CLASSIFICATION: + item["image"] = self.transform(image) if self.transform else image + elif self.task in (TaskType.DETECTION, TaskType.SEGMENTATION): + # Only Anomalous (1) images have masks in anomaly datasets + # Therefore, create empty mask for Normal (0) images. + if isinstance(mask_path, str): + mask_path = [mask_path] + semantic_mask = ( + Mask(torch.zeros(image.shape[-2:])).to(torch.uint8) + if label_index == LabelName.NORMAL + else Mask(torch.stack([self._read_mask(path) for path in mask_path])) + ) + + binary_mask = Mask(semantic_mask.view(-1, *semantic_mask.shape[-2:]).int().any(dim=0).to(torch.uint8)) + item["image"], item["mask"] = self.transform(image, binary_mask) if self.transform else (image, binary_mask) + + item["mask_path"] = mask_path + # List of masks with the original size for saturation based metrics calculation + item["semantic_mask"] = semantic_mask + + if self.task == TaskType.DETECTION: + # create boxes from masks for detection task + boxes, _ = masks_to_boxes(item["mask"]) + item["boxes"] = boxes[0] + else: + msg = f"Unknown task type: {self.task}" + raise ValueError(msg) + + return item + + +class MVTecLoco(AnomalibDataModule): + """MVTec LOCO Datamodule. + + Args: + root (Path | str): Path to the root of the dataset. + Defaults to ``"./datasets/MVTec_LOCO"``. + category (str): Category of the MVTec LOCO dataset (e.g. "breakfast_box"). + Defaults to ``"breakfast_box"``. + train_batch_size (int, optional): Training batch size. + Defaults to ``32``. + eval_batch_size (int, optional): Test batch size. + Defaults to ``32``. + num_workers (int, optional): Number of workers. + Defaults to ``8``. + task TaskType): Task type, 'classification', 'detection' or 'segmentation' + Defaults to ``TaskType.SEGMENTATION``. + image_size (tuple[int, int], optional): Size to which input images should be resized. + Defaults to ``None``. + transform (Transform, optional): Transforms that should be applied to the input images. + Defaults to ``None``. + train_transform (Transform, optional): Transforms that should be applied to the input images during training. + Defaults to ``None``. + eval_transform (Transform, optional): Transforms that should be applied to the input images during evaluation. + Defaults to ``None``. + test_split_mode (TestSplitMode): Setting that determines how the testing subset is obtained. + Defaults to ``TestSplitMode.FROM_DIR``. + test_split_ratio (float): Fraction of images from the train set that will be reserved for testing. + Defaults to ``0.2``. + val_split_mode (ValSplitMode): Setting that determines how the validation subset is obtained. + Defaults to ``ValSplitMode.FROM_DIR``. + val_split_ratio (float): Fraction of train or test images that will be reserved for validation. + Defaults to ``0.5``. + seed (int | None, optional): Seed which may be set to a fixed value for reproducibility. + Defaults to ``None``. + + Examples: + To create an MVTec LOCO AD datamodule with default settings: + + >>> datamodule = MVTecLoco(root="anomalib/datasets/MVTec_LOCO") + >>> datamodule.setup() + >>> i, data = next(enumerate(datamodule.train_dataloader())) + >>> data.keys() + dict_keys(['image_path', 'label', 'image', 'mask_path', 'mask']) + + >>> data["image"].shape + torch.Size([32, 3, 256, 256]) + + To change the category of the dataset: + + >>> datamodule = MVTecLoco(category="pushpins") + + To change the image and batch size: + + >>> datamodule = MVTecLoco(image_size=(512, 512), train_batch_size=16, eval_batch_size=8) + + MVTec LOCO AD dataset provide an independent validation set with normal images only in the 'validation' folder. + If you would like to use a different validation set splitted from train or test set, + you can use the ``val_split_mode`` and ``val_split_ratio`` arguments to create a new validation set. + + >>> datamodule = MVTecLoco(val_split_mode=ValSplitMode.FROM_TEST, val_split_ratio=0.1) + + This will subsample the test set by 10% and use it as the validation set. + If you would like to create a validation set synthetically that would + not change the test set, you can use the ``ValSplitMode.SYNTHETIC`` option. + + >>> datamodule = MVTecLoco(val_split_mode=ValSplitMode.SYNTHETIC, val_split_ratio=0.2) + """ + + def __init__( + self, + root: Path | str = "./datasets/MVTec_LOCO", + category: str = "breakfast_box", + train_batch_size: int = 32, + eval_batch_size: int = 32, + num_workers: int = 8, + task: TaskType = TaskType.SEGMENTATION, + image_size: tuple[int, int] | None = None, + transform: Transform | None = None, + train_transform: Transform | None = None, + eval_transform: Transform | None = None, + test_split_mode: TestSplitMode = TestSplitMode.FROM_DIR, + test_split_ratio: float = 0.2, + val_split_mode: ValSplitMode = ValSplitMode.FROM_DIR, + val_split_ratio: float = 0.5, + seed: int | None = None, + ) -> None: + super().__init__( + train_batch_size=train_batch_size, + eval_batch_size=eval_batch_size, + image_size=image_size, + transform=transform, + train_transform=train_transform, + eval_transform=eval_transform, + num_workers=num_workers, + test_split_mode=test_split_mode, + test_split_ratio=test_split_ratio, + val_split_mode=val_split_mode, + val_split_ratio=val_split_ratio, + seed=seed, + ) + self.task = task + self.root = Path(root) + self.category = category + + def _setup(self, _stage: str | None = None) -> None: + """Set up the datasets, configs, and perform dynamic subset splitting. + + This method overrides the parent class's method to also setup the val dataset. + The MVTec LOCO dataset provides an independent validation subset. + """ + self.train_data = MVTecLocoDataset( + task=self.task, + transform=self.train_transform, + split=Split.TRAIN, + root=self.root, + category=self.category, + ) + self.val_data = MVTecLocoDataset( + task=self.task, + transform=self.eval_transform, + split=Split.VAL, + root=self.root, + category=self.category, + ) + self.test_data = MVTecLocoDataset( + task=self.task, + transform=self.eval_transform, + split=Split.TEST, + root=self.root, + category=self.category, + ) + + def prepare_data(self) -> None: + """Download the dataset if not available. + + This method checks if the specified dataset is available in the file system. + If not, it downloads and extracts the dataset into the appropriate directory. + + Example: + Assume the dataset is not available on the file system. + Here's how the directory structure looks before and after calling the + `prepare_data` method: + + Before: + + .. code-block:: bash + + $ tree datasets + datasets + ├── dataset1 + └── dataset2 + + Calling the method: + + .. code-block:: python + + >> datamodule = MVTecLoco(root="./datasets/MVTec_LOCO", category="breakfast_box") + >> datamodule.prepare_data() + + After: + + .. code-block:: bash + + $ tree datasets + datasets + ├── dataset1 + ├── dataset2 + └── MVTec_LOCO + ├── breakfast_box + ├── ... + └── splicing_connectors + """ + if (self.root / self.category).is_dir(): + logger.info("Found the dataset.") + else: + download_and_extract(self.root, DOWNLOAD_INFO) diff --git a/anomalib/data/image/visa.py b/anomalib/data/image/visa.py new file mode 100644 index 0000000000000000000000000000000000000000..d732e7c2be5c049fae8ea0e9d3f69d7e6606489c --- /dev/null +++ b/anomalib/data/image/visa.py @@ -0,0 +1,364 @@ +"""Visual Anomaly (VisA) Dataset (CC BY-NC-SA 4.0). + +Description: + This script contains PyTorch Dataset, Dataloader and PyTorch + Lightning DataModule for the Visual Anomal (VisA) dataset. + If the dataset is not on the file system, the script downloads and + extracts the dataset and create PyTorch data objects. +License: + The VisA dataset is released under the Creative Commons + Attribution-NonCommercial-ShareAlike 4.0 International License + (CC BY-NC-SA 4.0)(https://creativecommons.org/licenses/by-nc-sa/4.0/). +Reference: + - Zou, Y., Jeong, J., Pemula, L., Zhang, D., & Dabeer, O. (2022). SPot-the-Difference + Self-supervised Pre-training for Anomaly Detection and Segmentation. In European + Conference on Computer Vision (pp. 392-408). Springer, Cham. +""" + +# Copyright (C) 2022-2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +# Subset splitting code adapted from https://github.com/amazon-science/spot-diff +# Original licence: Apache-2.0 + + +import csv +import logging +import shutil +from pathlib import Path + +import cv2 +from torchvision.transforms.v2 import Transform + +from anomalib import TaskType +from anomalib.data.base import AnomalibDataModule, AnomalibDataset +from anomalib.data.utils import ( + DownloadInfo, + Split, + TestSplitMode, + ValSplitMode, + download_and_extract, +) + +from .mvtec import make_mvtec_dataset + +logger = logging.getLogger(__name__) + +EXTENSIONS = (".png", ".jpg", ".JPG") + +DOWNLOAD_INFO = DownloadInfo( + name="VisA", + url="https://amazon-visual-anomaly.s3.us-west-2.amazonaws.com/VisA_20220922.tar", + hashsum="2eb8690c803ab37de0324772964100169ec8ba1fa3f7e94291c9ca673f40f362", +) + +CATEGORIES = ( + "candle", + "capsules", + "cashew", + "chewinggum", + "fryum", + "macaroni1", + "macaroni2", + "pcb1", + "pcb2", + "pcb3", + "pcb4", + "pipe_fryum", +) + + +class VisaDataset(AnomalibDataset): + """VisA dataset class. + + Args: + task (TaskType): Task type, ``classification``, ``detection`` or ``segmentation`` + root (str | Path): Path to the root of the dataset + category (str): Sub-category of the dataset, e.g. 'candle' + transform (Transform, optional): Transforms that should be applied to the input images. + Defaults to ``None``. + split (str | Split | None): Split of the dataset, usually Split.TRAIN or Split.TEST + Defaults to ``None``. + + Examples: + To create a Visa dataset for classification: + + .. code-block:: python + + from anomalib.data.image.visa import VisaDataset + from anomalib.data.utils.transforms import get_transforms + + transform = get_transforms(image_size=256) + dataset = VisaDataset( + task="classification", + transform=transform, + split="train", + root="./datasets/visa/visa_pytorch/", + category="candle", + ) + dataset.setup() + dataset[0].keys() + + # Output + dict_keys(['image_path', 'label', 'image']) + + If you want to use the dataset for segmentation, you can use the same + code as above, with the task set to ``segmentation``. The dataset will + then have a ``mask`` key in the output dictionary. + + .. code-block:: python + + from anomalib.data.image.visa import VisaDataset + from anomalib.data.utils.transforms import get_transforms + + transform = get_transforms(image_size=256) + dataset = VisaDataset( + task="segmentation", + transform=transform, + split="train", + root="./datasets/visa/visa_pytorch/", + category="candle", + ) + dataset.setup() + dataset[0].keys() + + # Output + dict_keys(['image_path', 'label', 'image', 'mask_path', 'mask']) + + """ + + def __init__( + self, + task: TaskType, + root: str | Path, + category: str, + transform: Transform | None = None, + split: str | Split | None = None, + ) -> None: + super().__init__(task=task, transform=transform) + + self.root_category = Path(root) / category + self.split = split + self.samples = make_mvtec_dataset(self.root_category, split=self.split, extensions=EXTENSIONS) + + +class Visa(AnomalibDataModule): + """VisA Datamodule. + + Args: + root (Path | str): Path to the root of the dataset + Defaults to ``"./datasets/visa"``. + category (str): Category of the Visa dataset such as ``candle``. + Defaults to ``"candle"``. + train_batch_size (int, optional): Training batch size. + Defaults to ``32``. + eval_batch_size (int, optional): Test batch size. + Defaults to ``32``. + num_workers (int, optional): Number of workers. + Defaults to ``8``. + task (TaskType): Task type, 'classification', 'detection' or 'segmentation' + Defaults to ``TaskType.SEGMENTATION``. + image_size (tuple[int, int], optional): Size to which input images should be resized. + Defaults to ``None``. + transform (Transform, optional): Transforms that should be applied to the input images. + Defaults to ``None``. + train_transform (Transform, optional): Transforms that should be applied to the input images during training. + Defaults to ``None``. + eval_transform (Transform, optional): Transforms that should be applied to the input images during evaluation. + Defaults to ``None``. + test_split_mode (TestSplitMode): Setting that determines how the testing subset is obtained. + Defaults to ``TestSplitMode.FROM_DIR``. + test_split_ratio (float): Fraction of images from the train set that will be reserved for testing. + Defaults to ``0.2``. + val_split_mode (ValSplitMode): Setting that determines how the validation subset is obtained. + Defaults to ``ValSplitMode.SAME_AS_TEST``. + val_split_ratio (float): Fraction of train or test images that will be reserved for validation. + Defatuls to ``0.5``. + seed (int | None, optional): Seed which may be set to a fixed value for reproducibility. + Defaults to ``None``. + """ + + def __init__( + self, + root: Path | str = "./datasets/visa", + category: str = "capsules", + train_batch_size: int = 32, + eval_batch_size: int = 32, + num_workers: int = 8, + task: TaskType | str = TaskType.SEGMENTATION, + image_size: tuple[int, int] | None = None, + transform: Transform | None = None, + train_transform: Transform | None = None, + eval_transform: Transform | None = None, + test_split_mode: TestSplitMode | str = TestSplitMode.FROM_DIR, + test_split_ratio: float = 0.2, + val_split_mode: ValSplitMode | str = ValSplitMode.SAME_AS_TEST, + val_split_ratio: float = 0.5, + seed: int | None = None, + ) -> None: + super().__init__( + train_batch_size=train_batch_size, + eval_batch_size=eval_batch_size, + num_workers=num_workers, + image_size=image_size, + transform=transform, + train_transform=train_transform, + eval_transform=eval_transform, + test_split_mode=test_split_mode, + test_split_ratio=test_split_ratio, + val_split_mode=val_split_mode, + val_split_ratio=val_split_ratio, + seed=seed, + ) + + self.task = TaskType(task) + self.root = Path(root) + self.split_root = self.root / "visa_pytorch" + self.category = category + + def _setup(self, _stage: str | None = None) -> None: + self.train_data = VisaDataset( + task=self.task, + transform=self.train_transform, + split=Split.TRAIN, + root=self.split_root, + category=self.category, + ) + self.test_data = VisaDataset( + task=self.task, + transform=self.eval_transform, + split=Split.TEST, + root=self.split_root, + category=self.category, + ) + + def prepare_data(self) -> None: + """Download the dataset if not available. + + This method checks if the specified dataset is available in the file system. + If not, it downloads and extracts the dataset into the appropriate directory. + + Example: + Assume the dataset is not available on the file system. + Here's how the directory structure looks before and after calling the + `prepare_data` method: + + Before: + + .. code-block:: bash + + $ tree datasets + datasets + ├── dataset1 + └── dataset2 + + Calling the method: + + .. code-block:: python + + >> datamodule = Visa() + >> datamodule.prepare_data() + + After: + + .. code-block:: bash + + $ tree datasets + datasets + ├── dataset1 + ├── dataset2 + └── visa + ├── candle + ├── ... + ├── pipe_fryum + │ ├── Data + │ └── image_anno.csv + ├── split_csv + │ ├── 1cls.csv + │ ├── 2cls_fewshot.csv + │ └── 2cls_highshot.csv + ├── VisA_20220922.tar + └── visa_pytorch + ├── candle + ├── ... + ├── pcb4 + └── pipe_fryum + + ``prepare_data`` ensures that the dataset is converted to MVTec + format. ``visa_pytorch`` is the directory that contains the dataset + in the MVTec format. ``visa`` is the directory that contains the + original dataset. + """ + if (self.split_root / self.category).is_dir(): + # dataset is available, and split has been applied + logger.info("Found the dataset and train/test split.") + elif (self.root / self.category).is_dir(): + # dataset is available, but split has not yet been applied + logger.info("Found the dataset. Applying train/test split.") + self.apply_cls1_split() + else: + # dataset is not available + download_and_extract(self.root, DOWNLOAD_INFO) + logger.info("Downloaded the dataset. Applying train/test split.") + self.apply_cls1_split() + + def apply_cls1_split(self) -> None: + """Apply the 1-class subset splitting using the fixed split in the csv file. + + adapted from https://github.com/amazon-science/spot-diff + """ + logger.info("preparing data") + categories = [ + "candle", + "capsules", + "cashew", + "chewinggum", + "fryum", + "macaroni1", + "macaroni2", + "pcb1", + "pcb2", + "pcb3", + "pcb4", + "pipe_fryum", + ] + + split_file = self.root / "split_csv" / "1cls.csv" + + for category in categories: + train_folder = self.split_root / category / "train" + test_folder = self.split_root / category / "test" + mask_folder = self.split_root / category / "ground_truth" + + train_img_good_folder = train_folder / "good" + test_img_good_folder = test_folder / "good" + test_img_bad_folder = test_folder / "bad" + test_mask_bad_folder = mask_folder / "bad" + + train_img_good_folder.mkdir(parents=True, exist_ok=True) + test_img_good_folder.mkdir(parents=True, exist_ok=True) + test_img_bad_folder.mkdir(parents=True, exist_ok=True) + test_mask_bad_folder.mkdir(parents=True, exist_ok=True) + + with split_file.open(encoding="utf-8") as file: + csvreader = csv.reader(file) + next(csvreader) + for row in csvreader: + category, split, label, image_path, mask_path = row + label = "good" if label == "normal" else "bad" + image_name = image_path.split("/")[-1] + mask_name = mask_path.split("/")[-1] + + img_src_path = self.root / image_path + msk_src_path = self.root / mask_path + img_dst_path = self.split_root / category / split / label / image_name + msk_dst_path = self.split_root / category / "ground_truth" / label / mask_name + + shutil.copyfile(img_src_path, img_dst_path) + if split == "test" and label == "bad": + mask = cv2.imread(str(msk_src_path)) + + # binarize mask + mask[mask != 0] = 255 + + cv2.imwrite(str(msk_dst_path), mask) diff --git a/anomalib/data/predict.py b/anomalib/data/predict.py new file mode 100644 index 0000000000000000000000000000000000000000..148e34a49795b2691b68580a20535b41e14ac551 --- /dev/null +++ b/anomalib/data/predict.py @@ -0,0 +1,52 @@ +"""Inference Dataset.""" + +# Copyright (C) 2022-2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + + +from pathlib import Path +from typing import Any + +from torch.utils.data.dataset import Dataset +from torchvision.transforms.v2 import Transform + +from anomalib.data.utils import get_image_filenames, read_image + + +class PredictDataset(Dataset): + """Inference Dataset to perform prediction. + + Args: + path (str | Path): Path to an image or image-folder. + transform (A.Compose | None, optional): Transform object describing the transforms that are + applied to the inputs. + image_size (int | tuple[int, int] | None, optional): Target image size + to resize the original image. Defaults to None. + """ + + def __init__( + self, + path: str | Path, + transform: Transform | None = None, + image_size: int | tuple[int, int] = (256, 256), + ) -> None: + super().__init__() + + self.image_filenames = get_image_filenames(path) + self.transform = transform + self.image_size = image_size + + def __len__(self) -> int: + """Get the number of images in the given path.""" + return len(self.image_filenames) + + def __getitem__(self, index: int) -> dict[str, Any]: + """Get the image based on the `index`.""" + image_filename = self.image_filenames[index] + image = read_image(image_filename, as_tensor=True) + if self.transform: + image = self.transform(image) + pre_processed = {"image": image} + pre_processed["image_path"] = str(image_filename) + + return pre_processed diff --git a/anomalib/data/transforms/__init__.py b/anomalib/data/transforms/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..146fb19e15f5b349941a665385f7f9bef6c77f99 --- /dev/null +++ b/anomalib/data/transforms/__init__.py @@ -0,0 +1,8 @@ +"""Custom input transforms for Anomalib.""" + +# Copyright (C) 2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +from .center_crop import ExportableCenterCrop + +__all__ = ["ExportableCenterCrop"] diff --git a/anomalib/data/transforms/center_crop.py b/anomalib/data/transforms/center_crop.py new file mode 100644 index 0000000000000000000000000000000000000000..88b8655aae8c4b5f320d275cea942d9513b6714e --- /dev/null +++ b/anomalib/data/transforms/center_crop.py @@ -0,0 +1,87 @@ +"""Custom Torchvision transforms for Anomalib.""" + +# Original Code +# Copyright (c) Soumith Chintala 2016 +# https://github.com/pytorch/vision/blob/v0.16.1/torchvision/transforms/v2/functional/_geometry.py +# SPDX-License-Identifier: BSD-3-Clause +# +# Modified +# Copyright (C) 2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +from typing import Any + +import torch +from torch.nn.functional import pad +from torchvision.transforms.v2 import Transform +from torchvision.transforms.v2.functional._geometry import ( + _center_crop_compute_padding, + _center_crop_parse_output_size, + _parse_pad_padding, +) + + +def _center_crop_compute_crop_anchor( + crop_height: int, + crop_width: int, + image_height: int, + image_width: int, +) -> tuple[int, int]: + """Compute the anchor point for center-cropping. + + This function is a modified version of the torchvision.transforms.functional._center_crop_compute_crop_anchor + function. The original function uses `round` to compute the anchor point, which is not compatible with ONNX. + + Args: + crop_height (int): Desired height of the crop. + crop_width (int): Desired width of the crop. + image_height (int): Height of the input image. + image_width (int): Width of the input image. + """ + crop_top = torch.tensor((image_height - crop_height) / 2.0).round().int().item() + crop_left = torch.tensor((image_width - crop_width) / 2.0).round().int().item() + return crop_top, crop_left + + +def center_crop_image(image: torch.Tensor, output_size: list[int]) -> torch.Tensor: + """Apply center-cropping to an input image. + + Uses the modified anchor point computation function to compute the anchor point for center-cropping. + + Args: + image (torch.Tensor): Input image to be center-cropped. + output_size (list[int]): Desired output size of the crop. + """ + crop_height, crop_width = _center_crop_parse_output_size(output_size) + shape = image.shape + if image.numel() == 0: + return image.reshape(shape[:-2] + (crop_height, crop_width)) + image_height, image_width = shape[-2:] + + if crop_height > image_height or crop_width > image_width: + padding_ltrb = _center_crop_compute_padding(crop_height, crop_width, image_height, image_width) + image = pad(image, _parse_pad_padding(padding_ltrb), value=0.0) + + image_height, image_width = image.shape[-2:] + if crop_width == image_width and crop_height == image_height: + return image + + crop_top, crop_left = _center_crop_compute_crop_anchor(crop_height, crop_width, image_height, image_width) + return image[..., crop_top : (crop_top + crop_height), crop_left : (crop_left + crop_width)] + + +class ExportableCenterCrop(Transform): + """Transform that applies center-cropping to an input image and allows to be exported to ONNX. + + Args: + size (int | tuple[int, int]): Desired output size of the crop. + """ + + def __init__(self, size: int | tuple[int, int]) -> None: + super().__init__() + self.size = list(size) if isinstance(size, tuple) else [size, size] + + def _transform(self, inpt: torch.Tensor, params: dict[str, Any]) -> torch.Tensor: + """Apply the transform.""" + del params + return center_crop_image(inpt, output_size=self.size) diff --git a/anomalib/data/utils/__init__.py b/anomalib/data/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e75ba5bf49a810892745e68cb294d05410964fbd --- /dev/null +++ b/anomalib/data/utils/__init__.py @@ -0,0 +1,56 @@ +"""Helper utilities for data.""" + +# Copyright (C) 2022-2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +from .augmenter import Augmenter +from .boxes import boxes_to_anomaly_maps, boxes_to_masks, masks_to_boxes +from .download import DownloadInfo, download_and_extract +from .generators import random_2d_perlin +from .image import ( + generate_output_image_filename, + get_image_filenames, + get_image_height_and_width, + read_depth_image, + read_image, + read_mask, +) +from .label import LabelName +from .path import ( + DirType, + _check_and_convert_path, + _prepare_files_labels, + resolve_path, + validate_and_resolve_path, + validate_path, +) +from .split import Split, TestSplitMode, ValSplitMode, concatenate_datasets, random_split, split_by_label + +__all__ = [ + "generate_output_image_filename", + "get_image_filenames", + "get_image_height_and_width", + "random_2d_perlin", + "read_image", + "read_mask", + "read_depth_image", + "random_split", + "split_by_label", + "concatenate_datasets", + "Split", + "ValSplitMode", + "TestSplitMode", + "LabelName", + "DirType", + "Augmenter", + "masks_to_boxes", + "boxes_to_masks", + "boxes_to_anomaly_maps", + "download_and_extract", + "DownloadInfo", + "_check_and_convert_path", + "_prepare_files_labels", + "resolve_path", + "validate_path", + "validate_and_resolve_path", +] diff --git a/anomalib/data/utils/augmenter.py b/anomalib/data/utils/augmenter.py new file mode 100644 index 0000000000000000000000000000000000000000..48338e0aa8b3d2828a0f2f2b6fd9de8e29dc31d4 --- /dev/null +++ b/anomalib/data/utils/augmenter.py @@ -0,0 +1,172 @@ +"""Augmenter module to generates out-of-distribution samples for the DRAEM implementation.""" + +# Original Code +# Copyright (c) 2021 VitjanZ +# https://github.com/VitjanZ/DRAEM. +# SPDX-License-Identifier: MIT +# +# Modified +# Copyright (C) 2022-2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + + +import math +import random +from pathlib import Path + +import cv2 +import imgaug.augmenters as iaa +import numpy as np +import torch +from PIL import Image +from torchvision.datasets.folder import IMG_EXTENSIONS + +from anomalib.data.utils.generators.perlin import random_2d_perlin + + +def nextpow2(value: int) -> int: + """Return the smallest power of 2 greater than or equal to the input value.""" + return 2 ** (math.ceil(math.log(value, 2))) + + +class Augmenter: + """Class that generates noisy augmentations of input images. + + Args: + anomaly_source_path (str | None): Path to a folder of images that will be used as source of the anomalous + noise. If not specified, random noise will be used instead. + p_anomalous (float): Probability that the anomalous perturbation will be applied to a given image. + beta (float): Parameter that determines the opacity of the noise mask. + """ + + def __init__( + self, + anomaly_source_path: str | None = None, + p_anomalous: float = 0.5, + beta: float | tuple[float, float] = (0.2, 1.0), + ) -> None: + self.p_anomalous = p_anomalous + self.beta = beta + + self.anomaly_source_paths: list[Path] = [] + if anomaly_source_path is not None: + for img_ext in IMG_EXTENSIONS: + self.anomaly_source_paths.extend(Path(anomaly_source_path).rglob("*" + img_ext)) + + self.augmenters = [ + iaa.GammaContrast((0.5, 2.0), per_channel=True), + iaa.MultiplyAndAddToBrightness(mul=(0.8, 1.2), add=(-30, 30)), + iaa.pillike.EnhanceSharpness(), + iaa.AddToHueAndSaturation((-50, 50), per_channel=True), + iaa.Solarize(0.5, threshold=(32, 128)), + iaa.Posterize(), + iaa.Invert(), + iaa.pillike.Autocontrast(), + iaa.pillike.Equalize(), + iaa.Affine(rotate=(-45, 45)), + ] + self.rot = iaa.Sequential([iaa.Affine(rotate=(-90, 90))]) + + def rand_augmenter(self) -> iaa.Sequential: + """Select 3 random transforms that will be applied to the anomaly source images. + + Returns: + A selection of 3 transforms. + """ + aug_ind = np.random.default_rng().choice(np.arange(len(self.augmenters)), 3, replace=False) + return iaa.Sequential([self.augmenters[aug_ind[0]], self.augmenters[aug_ind[1]], self.augmenters[aug_ind[2]]]) + + def generate_perturbation( + self, + height: int, + width: int, + anomaly_source_path: Path | str | None = None, + ) -> tuple[np.ndarray, np.ndarray]: + """Generate an image containing a random anomalous perturbation using a source image. + + Args: + height (int): height of the generated image. + width: (int): width of the generated image. + anomaly_source_path (Path | str | None): Path to an image file. If not provided, random noise will be used + instead. + + Returns: + Image containing a random anomalous perturbation, and the corresponding ground truth anomaly mask. + """ + # Generate random perlin noise + perlin_scale = 6 + min_perlin_scale = 0 + + perlin_scalex = 2 ** np.random.default_rng().integers(min_perlin_scale, perlin_scale) + perlin_scaley = 2 ** np.random.default_rng().integers(min_perlin_scale, perlin_scale) + + perlin_noise = random_2d_perlin((nextpow2(height), nextpow2(width)), (perlin_scalex, perlin_scaley))[ + :height, + :width, + ] + perlin_noise = self.rot(image=perlin_noise) + + # Create mask from perlin noise + mask = np.where(perlin_noise > 0.5, np.ones_like(perlin_noise), np.zeros_like(perlin_noise)) + mask = np.expand_dims(mask, axis=2).astype(np.float32) + + # Load anomaly source image + if anomaly_source_path: + anomaly_source_img = np.array(Image.open(anomaly_source_path)) + anomaly_source_img = cv2.resize(anomaly_source_img, dsize=(width, height)) + else: # if no anomaly source is specified, we use the perlin noise as anomalous source + anomaly_source_img = np.expand_dims(perlin_noise, 2).repeat(3, 2) + anomaly_source_img = (anomaly_source_img * 255).astype(np.uint8) + + # Augment anomaly source image + aug = self.rand_augmenter() + anomaly_img_augmented = aug(image=anomaly_source_img) + + # Create anomalous perturbation that we will apply to the image + perturbation = anomaly_img_augmented.astype(np.float32) * mask / 255.0 + + return perturbation, mask + + def augment_batch(self, batch: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + """Generate anomalous augmentations for a batch of input images. + + Args: + batch (torch.Tensor): Batch of input images + + Returns: + - Augmented image to which anomalous perturbations have been added. + - Ground truth masks corresponding to the anomalous perturbations. + """ + batch_size, channels, height, width = batch.shape + + # Collect perturbations + perturbations_list = [] + masks_list = [] + for _ in range(batch_size): + if torch.rand(1) > self.p_anomalous: # include normal samples + perturbations_list.append(torch.zeros((channels, height, width))) + masks_list.append(torch.zeros((1, height, width))) + else: + anomaly_source_path = ( + random.sample(self.anomaly_source_paths, 1)[0] if len(self.anomaly_source_paths) > 0 else None + ) + perturbation, mask = self.generate_perturbation(height, width, anomaly_source_path) + perturbations_list.append(torch.Tensor(perturbation).permute((2, 0, 1))) + masks_list.append(torch.Tensor(mask).permute((2, 0, 1))) + + perturbations = torch.stack(perturbations_list).to(batch.device) + masks = torch.stack(masks_list).to(batch.device) + + # Apply perturbations batch wise + if isinstance(self.beta, float): + beta = self.beta + elif isinstance(self.beta, tuple): + beta = torch.rand(batch_size) * (self.beta[1] - self.beta[0]) + self.beta[0] + beta = beta.view(batch_size, 1, 1, 1).expand_as(batch).to(batch.device) # type: ignore[attr-defined] + else: + msg = "Beta must be either float or tuple of floats" + raise TypeError(msg) + + augmented_batch = batch * (1 - masks) + (beta) * perturbations + (1 - beta) * batch * (masks) + + return augmented_batch, masks diff --git a/anomalib/data/utils/boxes.py b/anomalib/data/utils/boxes.py new file mode 100644 index 0000000000000000000000000000000000000000..5af6f6470bf9d5e07aeab9fdc9ae0795d1d387b4 --- /dev/null +++ b/anomalib/data/utils/boxes.py @@ -0,0 +1,117 @@ +"""Helper functions for processing bounding box detections and annotations.""" + +# Copyright (C) 2022-2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + + +import torch + +from anomalib.utils.cv import connected_components_cpu, connected_components_gpu + + +def masks_to_boxes( + masks: torch.Tensor, + anomaly_maps: torch.Tensor | None = None, +) -> tuple[list[torch.Tensor], list[torch.Tensor]]: + """Convert a batch of segmentation masks to bounding box coordinates. + + Args: + masks (torch.Tensor): Input tensor of shape (B, 1, H, W), (B, H, W) or (H, W) + anomaly_maps (Tensor | None, optional): Anomaly maps of shape (B, 1, H, W), (B, H, W) or (H, W) which are + used to determine an anomaly score for the converted bounding boxes. + + Returns: + list[torch.Tensor]: A list of length B where each element is a tensor of shape (N, 4) + containing the bounding box coordinates of the objects in the masks in xyxy format. + list[torch.Tensor]: A list of length B where each element is a tensor of length (N) + containing an anomaly score for each of the converted boxes. + """ + height, width = masks.shape[-2:] + masks = masks.view((-1, 1, height, width)).float() # reshape to (B, 1, H, W) and cast to float + if anomaly_maps is not None: + anomaly_maps = anomaly_maps.view((-1,) + masks.shape[-2:]) + + if masks.is_cpu: + batch_comps = connected_components_cpu(masks).squeeze(1) + else: + batch_comps = connected_components_gpu(masks).squeeze(1) + + batch_boxes = [] + batch_scores = [] + for im_idx, im_comps in enumerate(batch_comps): + labels = torch.unique(im_comps) + im_boxes = [] + im_scores = [] + for label in labels[labels != 0]: + y_loc, x_loc = torch.where(im_comps == label) + # add box + box = torch.Tensor([torch.min(x_loc), torch.min(y_loc), torch.max(x_loc), torch.max(y_loc)]).to( + masks.device, + ) + im_boxes.append(box) + if anomaly_maps is not None: + im_scores.append(torch.max(anomaly_maps[im_idx, y_loc, x_loc])) + batch_boxes.append(torch.stack(im_boxes) if im_boxes else torch.empty((0, 4), device=masks.device)) + batch_scores.append(torch.stack(im_scores) if im_scores else torch.empty(0, device=masks.device)) + + return batch_boxes, batch_scores + + +def boxes_to_masks(boxes: list[torch.Tensor], image_size: tuple[int, int]) -> torch.Tensor: + """Convert bounding boxes to segmentations masks. + + Args: + boxes (list[torch.Tensor]): A list of length B where each element is a tensor of shape (N, 4) + containing the bounding box coordinates of the regions of interest in xyxy format. + image_size (tuple[int, int]): Image size of the output masks in (H, W) format. + + Returns: + Tensor: torch.Tensor of shape (B, H, W) in which each slice is a binary mask showing the pixels contained by a + bounding box. + """ + masks = torch.zeros((len(boxes), *image_size)).to(boxes[0].device) + for im_idx, im_boxes in enumerate(boxes): + for box in im_boxes: + x_1, y_1, x_2, y_2 = box.int() + masks[im_idx, y_1 : y_2 + 1, x_1 : x_2 + 1] = 1 + return masks + + +def boxes_to_anomaly_maps(boxes: torch.Tensor, scores: torch.Tensor, image_size: tuple[int, int]) -> torch.Tensor: + """Convert bounding box coordinates to anomaly heatmaps. + + Args: + boxes (list[torch.Tensor]): A list of length B where each element is a tensor of shape (N, 4) + containing the bounding box coordinates of the regions of interest in xyxy format. + scores (list[torch.Tensor]): A list of length B where each element is a 1D tensor of length N + containing the anomaly scores for each region of interest. + image_size (tuple[int, int]): Image size of the output masks in (H, W) format. + + Returns: + Tensor: torch.Tensor of shape (B, H, W). The pixel locations within each bounding box are collectively + assigned the anomaly score of the bounding box. In the case of overlapping bounding boxes, + the highest score is used. + """ + anomaly_maps = torch.zeros((len(boxes), *image_size)).to(boxes[0].device) + for im_idx, (im_boxes, im_scores) in enumerate(zip(boxes, scores, strict=False)): + im_map = torch.zeros((im_boxes.shape[0], *image_size)) + for box_idx, (box, score) in enumerate(zip(im_boxes, im_scores, strict=True)): + x_1, y_1, x_2, y_2 = box.int() + im_map[box_idx, y_1 : y_2 + 1, x_1 : x_2 + 1] = score + anomaly_maps[im_idx], _ = im_map.max(dim=0) + return anomaly_maps + + +def scale_boxes(boxes: torch.Tensor, image_size: torch.Size, new_size: torch.Size) -> torch.Tensor: + """Scale bbox coordinates to a new image size. + + Args: + boxes (torch.Tensor): Boxes of shape (N, 4) - (x1, y1, x2, y2). + image_size (Size): Size of the original image in which the bbox coordinates were retrieved. + new_size (Size): New image size to which the bbox coordinates will be scaled. + + Returns: + Tensor: Updated boxes of shape (N, 4) - (x1, y1, x2, y2). + """ + scale = torch.Tensor([*new_size]) / torch.Tensor([*image_size]) + return boxes * scale.repeat(2).to(boxes.device) diff --git a/anomalib/data/utils/download.py b/anomalib/data/utils/download.py new file mode 100644 index 0000000000000000000000000000000000000000..558768b654d3924e91d0b595fbbcc38a62ca591d --- /dev/null +++ b/anomalib/data/utils/download.py @@ -0,0 +1,364 @@ +"""Helper to show progress bars with `urlretrieve`, check hash of file.""" + +# Copyright (C) 2022-2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + + +import hashlib +import io +import logging +import os +import re +import tarfile +from collections.abc import Iterable +from dataclasses import dataclass +from pathlib import Path +from tarfile import TarFile, TarInfo +from urllib.request import urlretrieve +from zipfile import ZipFile + +from tqdm import tqdm + +logger = logging.getLogger(__name__) + + +@dataclass +class DownloadInfo: + """Info needed to download a dataset from a url.""" + + name: str + url: str + hashsum: str + filename: str | None = None + + +class DownloadProgressBar(tqdm): + """Create progress bar for urlretrieve. Subclasses `tqdm`. + + For information about the parameters in constructor, refer to `tqdm`'s documentation. + + Args: + iterable (Iterable | None): Iterable to decorate with a progressbar. + Leave blank to manually manage the updates. + desc (str | None): Prefix for the progressbar. + total (int | float | None): The number of expected iterations. If unspecified, + len(iterable) is used if possible. If float("inf") or as a last + resort, only basic progress statistics are displayed + (no ETA, no progressbar). + If `gui` is True and this parameter needs subsequent updating, + specify an initial arbitrary large positive number, + e.g. 9e9. + leave (bool | None): upon termination of iteration. If `None`, will leave only if `position` is `0`. + file (io.TextIOWrapper | io.StringIO | None): Specifies where to output the progress messages + (default: sys.stderr). Uses `file.write(str)` and + `file.flush()` methods. For encoding, see + `write_bytes`. + ncols (int | None): The width of the entire output message. If specified, + dynamically resizes the progressbar to stay within this bound. + If unspecified, attempts to use environment width. The + fallback is a meter width of 10 and no limit for the counter and + statistics. If 0, will not print any meter (only stats). + mininterval (float | None): Minimum progress display update interval [default: 0.1] seconds. + maxinterval (float | None): Maximum progress display update interval [default: 10] seconds. + Automatically adjusts `miniters` to correspond to `mininterval` + after long display update lag. Only works if `dynamic_miniters` + or monitor thread is enabled. + miniters (int | float | None): Minimum progress display update interval, in iterations. + If 0 and `dynamic_miniters`, will automatically adjust to equal + `mininterval` (more CPU efficient, good for tight loops). + If > 0, will skip display of specified number of iterations. + Tweak this and `mininterval` to get very efficient loops. + If your progress is erratic with both fast and slow iterations + (network, skipping items, etc) you should set miniters=1. + use_ascii (str | bool | None): If unspecified or False, use unicode (smooth blocks) to fill + the meter. The fallback is to use ASCII characters " 123456789#". + disable (bool | None): Whether to disable the entire progressbar wrapper + [default: False]. If set to None, disable on non-TTY. + unit (str | None): String that will be used to define the unit of each iteration + [default: it]. + unit_scale (int | float | bool): If 1 or True, the number of iterations will be reduced/scaled + automatically and a metric prefix following the + International System of Units standard will be added + (kilo, mega, etc.) [default: False]. If any other non-zero + number, will scale `total` and `n`. + dynamic_ncols (bool | None): If set, constantly alters `ncols` and `nrows` to the + environment (allowing for window resizes) [default: False]. + smoothing (float | None): Exponential moving average smoothing factor for speed estimates + (ignored in GUI mode). Ranges from 0 (average speed) to 1 + (current/instantaneous speed) [default: 0.3]. + bar_format (str | None): Specify a custom bar string formatting. May impact performance. + [default: '{l_bar}{bar}{r_bar}'], where + l_bar='{desc}: {percentage:3.0f}%|' and + r_bar='| {n_fmt}/{total_fmt} [{elapsed}<{remaining}, ' + '{rate_fmt}{postfix}]' + Possible vars: l_bar, bar, r_bar, n, n_fmt, total, total_fmt, + percentage, elapsed, elapsed_s, ncols, nrows, desc, unit, + rate, rate_fmt, rate_noinv, rate_noinv_fmt, + rate_inv, rate_inv_fmt, postfix, unit_divisor, + remaining, remaining_s, eta. + Note that a trailing ": " is automatically removed after {desc} + if the latter is empty. + initial (int | float | None): The initial counter value. Useful when restarting a progress + bar [default: 0]. If using float, consider specifying `{n:.3f}` + or similar in `bar_format`, or specifying `unit_scale`. + position (int | None): Specify the line offset to print this bar (starting from 0) + Automatic if unspecified. + Useful to manage multiple bars at once (eg, from threads). + postfix (dict | None): Specify additional stats to display at the end of the bar. + Calls `set_postfix(**postfix)` if possible (dict). + unit_divisor (float | None): [default: 1000], ignored unless `unit_scale` is True. + write_bytes (bool | None): If (default: None) and `file` is unspecified, + bytes will be written in Python 2. If `True` will also write + bytes. In all other cases will default to unicode. + lock_args (tuple | None): Passed to `refresh` for intermediate output + (initialisation, iterating, and updating). + nrows (int | None): The screen height. If specified, hides nested bars + outside this bound. If unspecified, attempts to use environment height. + The fallback is 20. + colour (str | None): Bar colour (e.g. 'green', '#00ff00'). + delay (float | None): Don't display until [default: 0] seconds have elapsed. + gui (bool | None): WARNING: internal parameter - do not use. + Use tqdm.gui.tqdm(...) instead. If set, will attempt to use + matplotlib animations for a graphical output [default: False]. + + + Example: + >>> with DownloadProgressBar(unit='B', unit_scale=True, miniters=1, desc=url.split('/')[-1]) as p_bar: + >>> urllib.request.urlretrieve(url, filename=output_path, reporthook=p_bar.update_to) + """ + + def __init__( + self, + iterable: Iterable | None = None, + desc: str | None = None, + total: int | float | None = None, + leave: bool | None = True, + file: io.TextIOWrapper | io.StringIO | None = None, + ncols: int | None = None, + mininterval: float | None = 0.1, + maxinterval: float | None = 10.0, + miniters: int | float | None = None, + use_ascii: bool | str | None = None, + disable: bool | None = False, + unit: str | None = "it", + unit_scale: bool | int | float | None = False, + dynamic_ncols: bool | None = False, + smoothing: float | None = 0.3, + bar_format: str | None = None, + initial: int | float | None = 0, + position: int | None = None, + postfix: dict | None = None, + unit_divisor: float | None = 1000, + write_bytes: bool | None = None, + lock_args: tuple | None = None, + nrows: int | None = None, + colour: str | None = None, + delay: float | None = 0, + gui: bool | None = False, + **kwargs, + ) -> None: + super().__init__( + iterable=iterable, + desc=desc, + total=total, + leave=leave, + file=file, + ncols=ncols, + mininterval=mininterval, + maxinterval=maxinterval, + miniters=miniters, + ascii=use_ascii, + disable=disable, + unit=unit, + unit_scale=unit_scale, + dynamic_ncols=dynamic_ncols, + smoothing=smoothing, + bar_format=bar_format, + initial=initial, + position=position, + postfix=postfix, + unit_divisor=unit_divisor, + write_bytes=write_bytes, + lock_args=lock_args, + nrows=nrows, + colour=colour, + delay=delay, + gui=gui, + **kwargs, + ) + self.total: int | float | None + + def update_to(self, chunk_number: int = 1, max_chunk_size: int = 1, total_size: int | None = None) -> None: + """Progress bar hook for tqdm. + + Based on https://stackoverflow.com/a/53877507 + The implementor does not have to bother about passing parameters to this as it gets them from urlretrieve. + However the context needs a few parameters. Refer to the example. + + Args: + chunk_number (int, optional): The current chunk being processed. Defaults to 1. + max_chunk_size (int, optional): Maximum size of each chunk. Defaults to 1. + total_size (int, optional): Total download size. Defaults to None. + """ + if total_size is not None: + self.total = total_size + self.update(chunk_number * max_chunk_size - self.n) + + +def is_file_potentially_dangerous(file_name: str) -> bool: + """Check if a file is potentially dangerous. + + Args: + file_name (str): Filename. + + Returns: + bool: True if the member is potentially dangerous, False otherwise. + + """ + # Some example criteria. We could expand this. + unsafe_patterns = ["/etc/", "/root/"] + return any(re.search(pattern, file_name) for pattern in unsafe_patterns) + + +def safe_extract(tar_file: TarFile, root: Path, members: list[TarInfo]) -> None: + """Extract safe members from a tar archive. + + Args: + tar_file (TarFile): TarFile object. + root (Path): Root directory where the dataset will be stored. + members (List[TarInfo]): List of safe members to be extracted. + + """ + for member in members: + tar_file.extract(member, root) + + +def generate_hash(file_path: str | Path, algorithm: str = "sha256") -> str: + """Generate a hash of a file using the specified algorithm. + + Args: + file_path (str | Path): Path to the file to hash. + algorithm (str): The hashing algorithm to use (e.g., 'sha256', 'sha3_512'). + + Returns: + str: The hexadecimal hash string of the file. + + Raises: + ValueError: If the specified hashing algorithm is not supported. + """ + # Get the hashing algorithm. + try: + hasher = getattr(hashlib, algorithm)() + except AttributeError as err: + msg = f"Unsupported hashing algorithm: {algorithm}" + raise ValueError(msg) from err + + # Read the file in chunks to avoid loading it all into memory + with Path(file_path).open("rb") as file: + for chunk in iter(lambda: file.read(4096), b""): + hasher.update(chunk) + + # Return the computed hash value in hexadecimal format + return hasher.hexdigest() + + +def check_hash(file_path: Path, expected_hash: str, algorithm: str = "sha256") -> None: + """Raise value error if hash does not match the calculated hash of the file. + + Args: + file_path (Path): Path to file. + expected_hash (str): Expected hash of the file. + algorithm (str): Hashing algorithm to use ('sha256', 'sha3_512', etc.). + """ + # Compare the calculated hash with the expected hash + calculated_hash = generate_hash(file_path, algorithm) + if calculated_hash != expected_hash: + msg = ( + f"Calculated hash {calculated_hash} of downloaded file {file_path} does not match the required hash " + f"{expected_hash}." + ) + raise ValueError(msg) + + +def extract(file_name: Path, root: Path) -> None: + """Extract a dataset. + + Args: + file_name (Path): Path of the file to be extracted. + root (Path): Root directory where the dataset will be stored. + + """ + logger.info("Extracting dataset into root folder.") + + # Safely extract zip files + if file_name.suffix == ".zip": + with ZipFile(file_name, "r") as zip_file: + for file_info in zip_file.infolist(): + if not is_file_potentially_dangerous(file_info.filename): + zip_file.extract(file_info, root) + + # Safely extract tar files. + elif file_name.suffix in (".tar", ".gz", ".xz", ".tgz"): + with tarfile.open(file_name) as tar_file: + members = tar_file.getmembers() + safe_members = [member for member in members if not is_file_potentially_dangerous(member.name)] + safe_extract(tar_file, root, safe_members) + + else: + msg = f"Unrecognized file format: {file_name}" + raise ValueError(msg) + + logger.info("Cleaning up files.") + file_name.unlink() + + +def download_and_extract(root: Path, info: DownloadInfo) -> None: + """Download and extract a dataset. + + Args: + root (Path): Root directory where the dataset will be stored. + info (DownloadInfo): Info needed to download the dataset. + """ + root.mkdir(parents=True, exist_ok=True) + + # save the compressed file in the specified root directory, using the same file name as on the server + downloaded_file_path = root / info.filename if info.filename else root / info.url.split("/")[-1] + + if downloaded_file_path.exists(): + logger.info("Existing dataset archive found. Skipping download stage.") + else: + logger.info("Downloading the %s dataset.", info.name) + # audit url. allowing only http:// or https:// + if info.url.startswith("http://") or info.url.startswith("https://"): + with DownloadProgressBar(unit="B", unit_scale=True, miniters=1, desc=info.name) as progress_bar: + urlretrieve( # noqa: S310 # nosec B310 + url=f"{info.url}", + filename=downloaded_file_path, + reporthook=progress_bar.update_to, + ) + logger.info("Checking the hash of the downloaded file.") + check_hash(downloaded_file_path, info.hashsum) + else: + msg = f"Invalid URL to download dataset. Supported 'http://' or 'https://' but '{info.url}' is requested" + raise RuntimeError(msg) + + extract(downloaded_file_path, root) + + +def is_within_directory(directory: Path, target: Path) -> bool: + """Check if a target path is located within a given directory. + + Args: + directory (Path): path of the parent directory + target (Path): path of the target + + Returns: + (bool): True if the target is within the directory, False otherwise + """ + abs_directory = directory.resolve() + abs_target = target.resolve() + + # TODO(djdameln): Replace with pathlib is_relative_to after switching to Python 3.10 + # CVS-122655 + prefix = os.path.commonprefix([abs_directory, abs_target]) + return prefix == str(abs_directory) diff --git a/anomalib/data/utils/generators/__init__.py b/anomalib/data/utils/generators/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..a79bad9770640d2dbd2b66f468b8cc4147794d5d --- /dev/null +++ b/anomalib/data/utils/generators/__init__.py @@ -0,0 +1,8 @@ +"""Utilities to generate synthetic data.""" + +# Copyright (C) 2022 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +from .perlin import random_2d_perlin + +__all__ = ["random_2d_perlin"] diff --git a/anomalib/data/utils/generators/perlin.py b/anomalib/data/utils/generators/perlin.py new file mode 100644 index 0000000000000000000000000000000000000000..fa683d7546589554de8433d51f8a557335849f3d --- /dev/null +++ b/anomalib/data/utils/generators/perlin.py @@ -0,0 +1,160 @@ +"""Helper functions for generating Perlin noise.""" + +# Original Code +# Copyright (c) 2021 VitjanZ +# https://github.com/VitjanZ/DRAEM. +# SPDX-License-Identifier: MIT +# +# Modified +# Copyright (C) 2022-2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +# ruff: noqa + +import math + +import numpy as np +import torch + + +def lerp_np(x, y, w): + """Helper function.""" + return (y - x) * w + x + + +def rand_perlin_2d_octaves_np(shape, res, octaves=1, persistence=0.5): + """Generate Perlin noise parameterized by the octaves method. Numpy version.""" + noise = np.zeros(shape) + frequency = 1 + amplitude = 1 + for _ in range(octaves): + noise += amplitude * generate_perlin_noise_2d(shape, (frequency * res[0], frequency * res[1])) + frequency *= 2 + amplitude *= persistence + return noise + + +def generate_perlin_noise_2d(shape, res): + """Fractal perlin noise.""" + + def f(t): + return 6 * t**5 - 15 * t**4 + 10 * t**3 + + delta = (res[0] / shape[0], res[1] / shape[1]) + d = (shape[0] // res[0], shape[1] // res[1]) + grid = np.mgrid[0 : res[0] : delta[0], 0 : res[1] : delta[1]].transpose(1, 2, 0) % 1 + # Gradients + angles = 2 * np.pi * np.random.default_rng().random(res[0] + 1, res[1] + 1) + gradients = np.dstack((np.cos(angles), np.sin(angles))) + g00 = gradients[0:-1, 0:-1].repeat(d[0], 0).repeat(d[1], 1) + g10 = gradients[1:, 0:-1].repeat(d[0], 0).repeat(d[1], 1) + g01 = gradients[0:-1, 1:].repeat(d[0], 0).repeat(d[1], 1) + g11 = gradients[1:, 1:].repeat(d[0], 0).repeat(d[1], 1) + # Ramps + n00 = np.sum(grid * g00, 2) + n10 = np.sum(np.dstack((grid[:, :, 0] - 1, grid[:, :, 1])) * g10, 2) + n01 = np.sum(np.dstack((grid[:, :, 0], grid[:, :, 1] - 1)) * g01, 2) + n11 = np.sum(np.dstack((grid[:, :, 0] - 1, grid[:, :, 1] - 1)) * g11, 2) + # Interpolation + t = f(grid) + n0 = n00 * (1 - t[:, :, 0]) + t[:, :, 0] * n10 + n1 = n01 * (1 - t[:, :, 0]) + t[:, :, 0] * n11 + return np.sqrt(2) * ((1 - t[:, :, 1]) * n0 + t[:, :, 1] * n1) + + +def random_2d_perlin( + shape: tuple, + res: tuple[int | torch.Tensor, int | torch.Tensor], + fade=lambda t: 6 * t**5 - 15 * t**4 + 10 * t**3, +) -> np.ndarray | torch.Tensor: + """Returns a random 2d perlin noise array. + + Args: + shape (tuple): Shape of the 2d map. + res (tuple[int | torch.Tensor, int | torch.Tensor]): Tuple of scales for perlin noise for height and width dimension. + fade (_type_, optional): Function used for fading the resulting 2d map. + Defaults to equation 6*t**5-15*t**4+10*t**3. + + Returns: + np.ndarray | torch.Tensor: Random 2d-array/tensor generated using perlin noise. + """ + if isinstance(res[0], int | np.integer): + result = _rand_perlin_2d_np(shape, res, fade) + elif isinstance(res[0], torch.Tensor): + result = _rand_perlin_2d(shape, res, fade) + else: + msg = f"got scales of type {type(res[0])}" + raise TypeError(msg) + return result + + +def _rand_perlin_2d_np(shape, res, fade=lambda t: 6 * t**5 - 15 * t**4 + 10 * t**3): + """Generate a random image containing Perlin noise. Numpy version.""" + delta = (res[0] / shape[0], res[1] / shape[1]) + d = (shape[0] // res[0], shape[1] // res[1]) + grid = np.mgrid[0 : res[0] : delta[0], 0 : res[1] : delta[1]].transpose(1, 2, 0) % 1 + + angles = 2 * math.pi * np.random.default_rng().random((res[0] + 1, res[1] + 1)) + gradients = np.stack((np.cos(angles), np.sin(angles)), axis=-1) + + def tile_grads(slice1, slice2): + return np.repeat(np.repeat(gradients[slice1[0] : slice1[1], slice2[0] : slice2[1]], d[0], axis=0), d[1], axis=1) + + def dot(grad, shift): + return ( + np.stack((grid[: shape[0], : shape[1], 0] + shift[0], grid[: shape[0], : shape[1], 1] + shift[1]), axis=-1) + * grad[: shape[0], : shape[1]] + ).sum(axis=-1) + + n00 = dot(tile_grads([0, -1], [0, -1]), [0, 0]) + n10 = dot(tile_grads([1, None], [0, -1]), [-1, 0]) + n01 = dot(tile_grads([0, -1], [1, None]), [0, -1]) + n11 = dot(tile_grads([1, None], [1, None]), [-1, -1]) + t = fade(grid[: shape[0], : shape[1]]) + return math.sqrt(2) * lerp_np(lerp_np(n00, n10, t[..., 0]), lerp_np(n01, n11, t[..., 0]), t[..., 1]) + + +def _rand_perlin_2d(shape, res, fade=lambda t: 6 * t**5 - 15 * t**4 + 10 * t**3): + """Generate a random image containing Perlin noise. PyTorch version.""" + delta = (res[0] / shape[0], res[1] / shape[1]) + d = (shape[0] // res[0], shape[1] // res[1]) + + grid = torch.stack(torch.meshgrid(torch.arange(0, res[0], delta[0]), torch.arange(0, res[1], delta[1])), dim=-1) % 1 + angles = 2 * math.pi * torch.rand(res[0] + 1, res[1] + 1) + gradients = torch.stack((torch.cos(angles), torch.sin(angles)), dim=-1) + + def tile_grads(slice1, slice2): + return ( + gradients[slice1[0] : slice1[1], slice2[0] : slice2[1]] + .repeat_interleave(d[0], 0) + .repeat_interleave(d[1], 1) + ) + + def dot(grad, shift): + return ( + torch.stack( + (grid[: shape[0], : shape[1], 0] + shift[0], grid[: shape[0], : shape[1], 1] + shift[1]), + dim=-1, + ) + * grad[: shape[0], : shape[1]] + ).sum(dim=-1) + + n00 = dot(tile_grads([0, -1], [0, -1]), [0, 0]) + + n10 = dot(tile_grads([1, None], [0, -1]), [-1, 0]) + n01 = dot(tile_grads([0, -1], [1, None]), [0, -1]) + n11 = dot(tile_grads([1, None], [1, None]), [-1, -1]) + t = fade(grid[: shape[0], : shape[1]]) + return math.sqrt(2) * torch.lerp(torch.lerp(n00, n10, t[..., 0]), torch.lerp(n01, n11, t[..., 0]), t[..., 1]) + + +def rand_perlin_2d_octaves(shape, res, octaves=1, persistence=0.5): + """Generate Perlin noise parameterized by the octaves method. PyTorch version.""" + noise = torch.zeros(shape) + frequency = 1 + amplitude = 1 + for _ in range(octaves): + noise += amplitude * _rand_perlin_2d(shape, (frequency * res[0], frequency * res[1])) + frequency *= 2 + amplitude *= persistence + return noise diff --git a/anomalib/data/utils/image.py b/anomalib/data/utils/image.py new file mode 100644 index 0000000000000000000000000000000000000000..1c3a71fd3d3a0b7b074db1762cddea700d4ab621 --- /dev/null +++ b/anomalib/data/utils/image.py @@ -0,0 +1,471 @@ +"""Image Utils.""" + +# Copyright (C) 2022-2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + + +import logging +import math +from collections.abc import Sequence +from pathlib import Path + +import cv2 +import numpy as np +import tifffile as tiff +import torch +from matplotlib.figure import Figure +from PIL import Image +from torch.nn import functional as F # noqa: N812 +from torchvision.datasets.folder import IMG_EXTENSIONS +# from torchvision.transforms.v2.functional import to_tensor,to_image_tensor +from torchvision.transforms.v2.functional import to_dtype,to_image + +from torchvision.tv_tensors import Mask + +from anomalib.data.utils.path import validate_path + +logger = logging.getLogger(__name__) + + +def is_image_file(filename: str | Path) -> bool: + """Check if the filename is an image file. + + Args: + filename (str | Path): Filename to check. + + Returns: + bool: True if the filename is an image file. + + Examples: + >>> is_image_file("000.png") + True + + >>> is_image_file("002.JPEG") + True + + >>> is_image_file("009.tiff") + True + + >>> is_image_file("002.avi") + False + """ + filename = Path(filename) + return filename.suffix.lower() in IMG_EXTENSIONS + + +def get_image_filename(filename: str | Path) -> Path: + """Get image filename. + + Args: + filename (str | Path): Filename to check. + + Returns: + Path: Image filename. + + Examples: + Assume that we have the following files in the directory: + + .. code-block:: bash + + $ ls + 000.png 001.jpg 002.JPEG 003.tiff 004.png 005.txt + + >>> get_image_filename("000.png") + PosixPath('000.png') + + >>> get_image_filename("001.jpg") + PosixPath('001.jpg') + + >>> get_image_filename("009.tiff") + Traceback (most recent call last): + File "", line 1, in + File "", line 18, in get_image_filename + FileNotFoundError: File not found: 009.tiff + + >>> get_image_filename("005.txt") + Traceback (most recent call last): + File "", line 1, in + File "", line 18, in get_image_filename + ValueError: ``filename`` is not an image file. 005.txt + """ + filename = Path(filename) + + if not filename.exists(): + msg = f"File not found: {filename}" + raise FileNotFoundError(msg) + + if not is_image_file(filename): + msg = f"``filename`` is not an image file: {filename}" + raise ValueError(msg) + return filename + + +def get_image_filenames_from_dir(path: str | Path) -> list[Path]: + """Get image filenames from directory. + + Args: + path (str | Path): Path to image directory. + + Raises: + ValueError: When ``path`` is not a directory. + + Returns: + list[Path]: Image filenames. + + Examples: + Assume that we have the following files in the directory: + $ ls + 000.png 001.jpg 002.JPEG 003.tiff 004.png 005.png + + >>> get_image_filenames_from_dir(".") + [PosixPath('000.png'), PosixPath('001.jpg'), PosixPath('002.JPEG'), + PosixPath('003.tiff'), PosixPath('004.png'), PosixPath('005.png')] + + >>> get_image_filenames_from_dir("009.tiff") + Traceback (most recent call last): + File "", line 1, in + File "", line 18, in get_image_filenames_from_dir + ValueError: ``path`` is not a directory: 009.tiff + """ + path = Path(path) + if not path.is_dir(): + msg = f"Path is not a directory: {path}" + raise ValueError(msg) + + image_filenames = [get_image_filename(f) for f in path.glob("**/*") if is_image_file(f)] + + if not image_filenames: + msg = f"Found 0 images in {path}" + raise ValueError(msg) + + return sorted(image_filenames) + + +def get_image_filenames(path: str | Path, base_dir: str | Path | None = None) -> list[Path]: + """Get image filenames. + + Args: + path (str | Path): Path to image or image-folder. + base_dir (Path): Base directory to restrict file access. + + Returns: + list[Path]: List of image filenames. + + Examples: + Assume that we have the following files in the directory: + + .. code-block:: bash + + $ tree images + images + ├── bad + │ ├── 003.png + │ └── 004.jpg + └── good + ├── 000.png + └── 001.tiff + + We can get the image filenames with various ways: + + >>> get_image_filenames("images/bad/003.png") + PosixPath('/home/sakcay/Projects/anomalib/images/bad/003.png')] + + It is possible to recursively get the image filenames from a directory: + + >>> get_image_filenames("images") + [PosixPath('/home/sakcay/Projects/anomalib/images/bad/003.png'), + PosixPath('/home/sakcay/Projects/anomalib/images/bad/004.jpg'), + PosixPath('/home/sakcay/Projects/anomalib/images/good/001.tiff'), + PosixPath('/home/sakcay/Projects/anomalib/images/good/000.png')] + + If we want to restrict the file access to a specific directory, + we can use ``base_dir`` argument. + + >>> get_image_filenames("images", base_dir="images/bad") + Traceback (most recent call last): + File "", line 1, in + File "", line 18, in get_image_filenames + ValueError: Access denied: Path is outside the allowed directory. + """ + path = validate_path(path, base_dir) + image_filenames: list[Path] = [] + + if path.is_file(): + image_filenames = [get_image_filename(path)] + elif path.is_dir(): + image_filenames = get_image_filenames_from_dir(path) + else: + msg = "Path is not a file or directory" + raise FileNotFoundError(msg) + + return image_filenames + + +def duplicate_filename(path: str | Path) -> Path: + """Check and duplicate filename. + + This function checks the path and adds a suffix if it already exists on the file system. + + Args: + path (str | Path): Input Path + + Examples: + >>> path = Path("datasets/MVTec/bottle/test/broken_large/000.png") + >>> path.exists() + True + + If we pass this to ``duplicate_filename`` function we would get the following: + >>> duplicate_filename(path) + PosixPath('datasets/MVTec/bottle/test/broken_large/000_1.png') + + Returns: + Path: Duplicated output path. + """ + path = Path(path) + + if not path.exists(): + return path + + i = 0 + while True: + duplicated_path = path if i == 0 else path.parent / (path.stem + f"_{i}" + path.suffix) + if not duplicated_path.exists(): + break + i += 1 + + return duplicated_path + + +def generate_output_image_filename(input_path: str | Path, output_path: str | Path) -> Path: + """Generate an output filename to save the inference image. + + This function generates an output filaname by checking the input and output filenames. Input path is + the input to infer, and output path is the path to save the output predictions specified by the user. + + The function expects ``input_path`` to always be a file, not a directory. ``output_path`` could be a + filename or directory. If it is a filename, the function checks if the specified filename exists on + the file system. If yes, the function calls ``duplicate_filename`` to duplicate the filename to avoid + overwriting the existing file. If ``output_path`` is a directory, this function adds the parent and + filenames of ``input_path`` to ``output_path``. + + Args: + input_path (str | Path): Path to the input image to infer. + output_path (str | Path): Path to output to save the predictions. + Could be a filename or a directory. + + Examples: + >>> input_path = Path("datasets/MVTec/bottle/test/broken_large/000.png") + >>> output_path = Path("datasets/MVTec/bottle/test/broken_large/000.png") + >>> generate_output_image_filename(input_path, output_path) + PosixPath('datasets/MVTec/bottle/test/broken_large/000_1.png') + + >>> input_path = Path("datasets/MVTec/bottle/test/broken_large/000.png") + >>> output_path = Path("results/images") + >>> generate_output_image_filename(input_path, output_path) + PosixPath('results/images/broken_large/000.png') + + Raises: + ValueError: When the ``input_path`` is not a file. + + Returns: + Path: The output filename to save the output predictions from the inferencer. + """ + input_path = validate_path(input_path) + output_path = validate_path(output_path, should_exist=False) + + # Input validation: Check if input_path is a valid directory or file + if input_path.is_file() is False: + msg = "input_path is expected to be a file to generate a proper output filename." + raise ValueError(msg) + + # If the output is a directory, then add parent directory name + # and filename to the path. This is to ensure we do not overwrite + # images and organize based on the categories. + if output_path.is_dir(): + output_image_filename = output_path / input_path.parent.name / input_path.name + elif output_path.is_file() and output_path.exists(): + msg = f"{output_path} already exists. Renaming the file to avoid overwriting." + logger.warning(msg) + output_image_filename = duplicate_filename(output_path) + else: + output_image_filename = output_path + + output_image_filename.parent.mkdir(parents=True, exist_ok=True) + + return output_image_filename + + +def get_image_height_and_width(image_size: int | Sequence[int]) -> tuple[int, int]: + """Get image height and width from ``image_size`` variable. + + Args: + image_size (int | Sequence[int] | None, optional): Input image size. + + Raises: + ValueError: Image size not None, int or Sequence of values. + + Examples: + >>> get_image_height_and_width(image_size=256) + (256, 256) + + >>> get_image_height_and_width(image_size=(256, 256)) + (256, 256) + + >>> get_image_height_and_width(image_size=(256, 256, 3)) + (256, 256) + + >>> get_image_height_and_width(image_size=256.) + Traceback (most recent call last): + File "", line 1, in + File "", line 18, in get_image_height_and_width + ValueError: ``image_size`` could be either int or tuple[int, int] + + Returns: + tuple[int | None, int | None]: A tuple containing image height and width values. + """ + if isinstance(image_size, int): + height_and_width = (image_size, image_size) + elif isinstance(image_size, Sequence): + height_and_width = int(image_size[0]), int(image_size[1]) + else: + msg = "``image_size`` could be either int or tuple[int, int]" + raise TypeError(msg) + + return height_and_width + + +def read_image(path: str | Path, as_tensor: bool = False) -> torch.Tensor | np.ndarray: + """Read image from disk in RGB format. + + Args: + path (str, Path): path to the image file + as_tensor (bool, optional): If True, returns the image as a tensor. Defaults to False. + + Example: + >>> image = read_image("test_image.jpg") + >>> type(image) + + >>> + >>> image = read_image("test_image.jpg", as_tensor=True) + >>> type(image) + + + Returns: + image as numpy array + """ + image = Image.open(path).convert("RGB") + # return to_tensor(to_image_tensor(image), torch.float32, scale=True) if as_tensor else np.array(image) / 255.0 + return to_dtype(to_image(image), torch.float32, scale=True) if as_tensor else np.array(image) / 255.0 + + +def read_mask(path: str | Path, as_tensor: bool = False) -> torch.Tensor | np.ndarray: + """Read mask from disk. + + Args: + path (str, Path): path to the mask file + as_tensor (bool, optional): If True, returns the mask as a tensor. Defaults to False. + + Example: + >>> mask = read_mask("test_mask.png") + >>> type(mask) + + >>> + >>> mask = read_mask("test_mask.png", as_tensor=True) + >>> type(mask) + + """ + image = Image.open(path).convert("L") + return Mask(to_image(image).squeeze() / 255, dtype=torch.uint8) if as_tensor else np.array(image) + + +def read_depth_image(path: str | Path) -> np.ndarray: + """Read tiff depth image from disk. + + Args: + path (str, Path): path to the image file + + Example: + >>> image = read_depth_image("test_image.tiff") + + Returns: + image as numpy array + """ + path = path if isinstance(path, str) else str(path) + return tiff.imread(path) + + +def pad_nextpow2(batch: torch.Tensor) -> torch.Tensor: + """Compute required padding from input size and return padded images. + + Finds the largest dimension and computes a square image of dimensions that are of the power of 2. + In case the image dimension is odd, it returns the image with an extra padding on one side. + + Args: + batch (torch.Tensor): Input images + + Returns: + batch: Padded batch + """ + # find the largest dimension + l_dim = 2 ** math.ceil(math.log(max(*batch.shape[-2:]), 2)) + padding_w = [math.ceil((l_dim - batch.shape[-2]) / 2), math.floor((l_dim - batch.shape[-2]) / 2)] + padding_h = [math.ceil((l_dim - batch.shape[-1]) / 2), math.floor((l_dim - batch.shape[-1]) / 2)] + return F.pad(batch, pad=[*padding_h, *padding_w]) + + +def show_image(image: np.ndarray | Figure, title: str = "Image") -> None: + """Show an image on the screen. + + Args: + image (np.ndarray | Figure): Image that will be shown in the window. + title (str, optional): Title that will be given to that window. Defaults to "Image". + """ + if isinstance(image, Figure): + image = figure_to_array(image) + image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR) + cv2.imshow(title, image) + cv2.waitKey(0) + cv2.destroyAllWindows() + + +def save_image(filename: Path | str, image: np.ndarray | Figure, root: Path | None = None) -> None: + """Save an image to the file system. + + Args: + filename (Path | str): Path or filename to which the image will be saved. + image (np.ndarray | Figure): Image that will be saved to the file system. + root (Path, optional): Root directory to save the image. If provided, the top level directory of an absolute + filename will be overwritten. Defaults to None. + """ + if isinstance(image, Figure): + image = figure_to_array(image) + + file_path = Path(filename) + # if file_path is absolute, then root is ignored + # so we remove the top level directory from the path + if file_path.is_absolute() and root: + file_path = Path(*file_path.parts[2:]) # OS-AGNOSTIC + if root: + file_path = root / file_path + + # Make unique file_path if file already exists + file_path = duplicate_filename(file_path) + + file_path.parent.mkdir(parents=True, exist_ok=True) + image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR) + cv2.imwrite(str(file_path), image) + + +def figure_to_array(fig: Figure) -> np.ndarray: + """Convert a matplotlib figure to a numpy array. + + Args: + fig (Figure): Matplotlib figure. + + Returns: + np.ndarray: Numpy array containing the image. + """ + fig.canvas.draw() + # convert figure to np.ndarray for saving via visualizer + img = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8) + return img.reshape(fig.canvas.get_width_height()[::-1] + (3,)) diff --git a/anomalib/data/utils/label.py b/anomalib/data/utils/label.py new file mode 100644 index 0000000000000000000000000000000000000000..28908c816958ed90c1175b3593894e4af244558e --- /dev/null +++ b/anomalib/data/utils/label.py @@ -0,0 +1,13 @@ +"""Label name enum class.""" + +# Copyright (C) 2023-2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +from enum import Enum + + +class LabelName(int, Enum): + """Name of label.""" + + NORMAL = 0 + ABNORMAL = 1 diff --git a/anomalib/data/utils/path.py b/anomalib/data/utils/path.py new file mode 100644 index 0000000000000000000000000000000000000000..ca0435be41117dbd85f0f314dca7d1a0b47f80fd --- /dev/null +++ b/anomalib/data/utils/path.py @@ -0,0 +1,235 @@ +"""Path Utils.""" + +# Copyright (C) 2022-2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + + +import os +import re +from enum import Enum +from pathlib import Path + +from torchvision.datasets.folder import IMG_EXTENSIONS + + +class DirType(str, Enum): + """Dir type names.""" + + NORMAL = "normal" + ABNORMAL = "abnormal" + NORMAL_TEST = "normal_test" + NORMAL_DEPTH = "normal_depth" + ABNORMAL_DEPTH = "abnormal_depth" + NORMAL_TEST_DEPTH = "normal_test_depth" + MASK = "mask_dir" + + +def _check_and_convert_path(path: str | Path) -> Path: + """Check an input path, and convert to Pathlib object. + + Args: + path (str | Path): Input path. + + Returns: + Path: Output path converted to pathlib object. + """ + if not isinstance(path, Path): + path = Path(path) + return path + + +def _prepare_files_labels( + path: str | Path, + path_type: str, + extensions: tuple[str, ...] | None = None, +) -> tuple[list, list]: + """Return a list of filenames and list corresponding labels. + + Args: + path (str | Path): Path to the directory containing images. + path_type (str): Type of images in the provided path ("normal", "abnormal", "normal_test") + extensions (tuple[str, ...] | None, optional): Type of the image extensions to read from the + directory. + + Returns: + List, List: Filenames of the images provided in the paths, labels of the images provided in the paths + """ + path = _check_and_convert_path(path) + if extensions is None: + extensions = IMG_EXTENSIONS + + if isinstance(extensions, str): + extensions = (extensions,) + + if not all(extension.startswith(".") for extension in extensions): + msg = f"All extensions {extensions} must start with the dot" + raise RuntimeError(msg) + + filenames = [ + f + for f in path.glob("**/*") + if f.suffix in extensions and not f.is_dir() and not any(part.startswith(".") for part in f.parts) + ] + if not filenames: + msg = f"Found 0 {path_type} images in {path} with extensions {extensions}" + raise RuntimeError(msg) + + labels = [path_type] * len(filenames) + + return filenames, labels + + +def resolve_path(folder: str | Path, root: str | Path | None = None) -> Path: + """Combine root and folder and returns the absolute path. + + This allows users to pass either a root directory and relative paths, or absolute paths to each of the + image sources. This function makes sure that the samples dataframe always contains absolute paths. + + Args: + folder (str | Path | None): Folder location containing image or mask data. + root (str | Path | None): Root directory for the dataset. + """ + folder = Path(folder) + if folder.is_absolute(): + path = folder + # path is relative. + elif root is None: + # no root provided; return absolute path + path = folder.resolve() + else: + # root provided; prepend root and return absolute path + path = (Path(root) / folder).resolve() + return path + + +def is_path_too_long(path: str | Path, max_length: int = 512) -> bool: + r"""Check if the path contains too long input. + + Args: + path (str | Path): Path to check. + max_length (int): Maximum length a path can be before it is considered too long. + Defaults to ``512``. + + Returns: + bool: True if the path contains too long input, False otherwise. + + Examples: + >>> contains_too_long_input("./datasets/MVTec/bottle/train/good/000.png") + False + + >>> contains_too_long_input("./datasets/MVTec/bottle/train/good/000.png" + "a" * 4096) + True + """ + return len(str(path)) > max_length + + +def contains_non_printable_characters(path: str | Path) -> bool: + r"""Check if the path contains non-printable characters. + + Args: + path (str | Path): Path to check. + + Returns: + bool: True if the path contains non-printable characters, False otherwise. + + Examples: + >>> contains_non_printable_characters("./datasets/MVTec/bottle/train/good/000.png") + False + + >>> contains_non_printable_characters("./datasets/MVTec/bottle/train/good/000.png\0") + True + """ + printable_pattern = re.compile(r"^[\x20-\x7E]+$") + return not printable_pattern.match(str(path)) + + +def validate_path(path: str | Path, base_dir: str | Path | None = None, should_exist: bool = True) -> Path: + """Validate the path. + + Args: + path (str | Path): Path to validate. + base_dir (str | Path): Base directory to restrict file access. + should_exist (bool): If True, do not raise an exception if the path does not exist. + + Returns: + Path: Validated path. + + Examples: + >>> validate_path("./datasets/MVTec/bottle/train/good/000.png") + PosixPath('/abs/path/to/anomalib/datasets/MVTec/bottle/train/good/000.png') + + >>> validate_path("./datasets/MVTec/bottle/train/good/000.png", base_dir="./datasets/MVTec") + PosixPath('/abs/path/to/anomalib/datasets/MVTec/bottle/train/good/000.png') + + >>> validate_path("/path/to/unexisting/file") + Traceback (most recent call last): + File "", line 1, in + File "", line 18, in validate_path + FileNotFoundError: Path does not exist: /path/to/unexisting/file + + Accessing a file without read permission should raise PermissionError: + + .. note:: + + Note that, we are using ``/usr/local/bin`` directory as an example here. + If this directory does not exist on your system, this will raise + ``FileNotFoundError`` instead of ``PermissionError``. You could change + the directory to any directory that you do not have read permission. + + >>> validate_path("/bin/bash", base_dir="/bin/") + Traceback (most recent call last): + File "", line 1, in + File "", line 18, in validate_path + PermissionError: Read permission denied for the file: /usr/local/bin + + """ + # Check if the path is of an appropriate type + if not isinstance(path, str | Path): + raise TypeError("Expected str, bytes or os.PathLike object, not " + type(path).__name__) + + # Check if the path is too long + if is_path_too_long(path): + msg = f"Path is too long: {path}" + raise ValueError(msg) + + # Check if the path contains non-printable characters + if contains_non_printable_characters(path): + msg = f"Path contains non-printable characters: {path}" + raise ValueError(msg) + + # Sanitize paths + path = Path(path).resolve() + base_dir = Path(base_dir).resolve() if base_dir else Path.home() + + # In case path ``should_exist``, the path is valid, and should be + # checked for read and execute permissions. + if should_exist: + # Check if the path exists + if not path.exists(): + msg = f"Path does not exist: {path}" + raise FileNotFoundError(msg) + + # Check the read and execute permissions + if not (os.access(path, os.R_OK) or os.access(path, os.X_OK)): + msg = f"Read or execute permissions denied for the path: {path}" + raise PermissionError(msg) + + return path + + +def validate_and_resolve_path( + folder: str | Path, + root: str | Path | None = None, + base_dir: str | Path | None = None, +) -> Path: + """Validate and resolve the path. + + Args: + folder (str | Path): Folder location containing image or mask data. + root (str | Path | None): Root directory for the dataset. + base_dir (str | Path | None): Base directory to restrict file access. + + Returns: + Path: Validated and resolved path. + """ + return validate_path(resolve_path(folder, root), base_dir) diff --git a/anomalib/data/utils/split.py b/anomalib/data/utils/split.py new file mode 100644 index 0000000000000000000000000000000000000000..566dba8c2841c871206147bfd5c81e12b4f8e911 --- /dev/null +++ b/anomalib/data/utils/split.py @@ -0,0 +1,141 @@ +"""Dataset Split Utils. + +This module contains function in regards to splitting normal images in training set, +and creating validation sets from test sets. + +These function are useful + - when the test set does not contain any normal images. + - when the dataset doesn't have a validation set. +""" + +# Copyright (C) 2022-2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + + +import logging +import math +from collections.abc import Sequence +from enum import Enum +from typing import TYPE_CHECKING + +import torch + +if TYPE_CHECKING: + from anomalib import data + +logger = logging.getLogger(__name__) + + +class Split(str, Enum): + """Split of a subset.""" + + TRAIN = "train" + VAL = "val" + TEST = "test" + + +class TestSplitMode(str, Enum): + """Splitting mode used to obtain subset.""" + + NONE = "none" + FROM_DIR = "from_dir" + SYNTHETIC = "synthetic" + + +class ValSplitMode(str, Enum): + """Splitting mode used to obtain validation subset.""" + + NONE = "none" + SAME_AS_TEST = "same_as_test" + FROM_TRAIN = "from_train" + FROM_TEST = "from_test" + SYNTHETIC = "synthetic" + FROM_DIR = "from_dir" + + +def concatenate_datasets(datasets: Sequence["data.AnomalibDataset"]) -> "data.AnomalibDataset": + """Concatenate multiple datasets into a single dataset object. + + Args: + datasets (Sequence[AnomalibDataset]): Sequence of at least two datasets. + + Returns: + AnomalibDataset: Dataset that contains the combined samples of all input datasets. + """ + concat_dataset = datasets[0] + for dataset in datasets[1:]: + concat_dataset += dataset + return concat_dataset + + +def random_split( + dataset: "data.AnomalibDataset", + split_ratio: float | Sequence[float], + label_aware: bool = False, + seed: int | None = None, +) -> list["data.AnomalibDataset"]: + """Perform a random split of a dataset. + + Args: + dataset (AnomalibDataset): Source dataset + split_ratio (Union[float, Sequence[float]]): Fractions of the splits that will be produced. The values in the + sequence must sum to 1. If a single value is passed, the ratio will be converted to + [1-split_ratio, split_ratio]. + label_aware (bool): When True, the relative occurrence of the different class labels of the source dataset will + be maintained in each of the subsets. + seed (int | None, optional): Seed that can be passed if results need to be reproducible + """ + if isinstance(split_ratio, float): + split_ratio = [1 - split_ratio, split_ratio] + + if not (math.isclose(sum(split_ratio), 1) and sum(split_ratio) <= 1): + msg = f"Split ratios must sum to 1, found {sum(split_ratio)}" + raise ValueError(msg) + + if not all(0 < ratio < 1 for ratio in split_ratio): + msg = f"All split ratios must be between 0 and 1, found {split_ratio}" + raise ValueError(msg) + + # create list of source data + if label_aware and "label_index" in dataset.samples: + indices_per_label = [group.index for _, group in dataset.samples.groupby("label_index")] + per_label_datasets = [dataset.subsample(indices) for indices in indices_per_label] + else: + per_label_datasets = [dataset] + + # outer list: per-label unique, inner list: random subsets with the given ratio + subsets: list[list["data.AnomalibDataset"]] = [] + # split each (label-aware) subset of source data + for label_dataset in per_label_datasets: + # get subset lengths + subset_lengths = [math.floor(len(label_dataset.samples) * ratio) for ratio in split_ratio] + for i in range(len(label_dataset.samples) - sum(subset_lengths)): + subset_idx = i % sum(subset_lengths) + subset_lengths[subset_idx] += 1 + if 0 in subset_lengths: + msg = """Zero subset length encountered during splitting. This means one of your subsets + might be empty or devoid of either normal or anomalous images.""" + logger.warning(msg) + + # perform random subsampling + random_state = torch.Generator().manual_seed(seed) if seed else None + indices = torch.randperm(len(label_dataset.samples), generator=random_state) + subsets.append( + [label_dataset.subsample(subset_indices) for subset_indices in torch.split(indices, subset_lengths)], + ) + + # invert outer/inner lists + # outer list: subsets with the given ratio, inner list: per-label unique + subsets = list(map(list, zip(*subsets, strict=True))) + return [concatenate_datasets(subset) for subset in subsets] + + +def split_by_label(dataset: "data.AnomalibDataset") -> tuple["data.AnomalibDataset", "data.AnomalibDataset"]: + """Split the dataset into the normal and anomalous subsets.""" + samples = dataset.samples + normal_indices = samples[samples.label_index == 0].index + anomalous_indices = samples[samples.label_index == 1].index + + normal_subset = dataset.subsample(list(normal_indices)) + anomalous_subset = dataset.subsample(list(anomalous_indices)) + return normal_subset, anomalous_subset diff --git a/anomalib/data/utils/synthetic.py b/anomalib/data/utils/synthetic.py new file mode 100644 index 0000000000000000000000000000000000000000..67b8dcef998cf8b93a0a7c5de062bcb88fcd724b --- /dev/null +++ b/anomalib/data/utils/synthetic.py @@ -0,0 +1,172 @@ +"""Dataset that generates synthetic anomalies. + +This dataset can be used when there is a lack of real anomalous data. +""" + +# Copyright (C) 2022-2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + + +import logging +import math +import shutil +from copy import deepcopy +from pathlib import Path +from tempfile import mkdtemp + +import cv2 +import pandas as pd +from pandas import DataFrame, Series +from torchvision.transforms.v2 import Compose + +from anomalib import TaskType +from anomalib.data.base.dataset import AnomalibDataset +from anomalib.data.utils import Augmenter, Split, read_image + +logger = logging.getLogger(__name__) + + +ROOT = "./.tmp/synthetic_anomaly" + + +def make_synthetic_dataset( + source_samples: DataFrame, + image_dir: Path, + mask_dir: Path, + anomalous_ratio: float = 0.5, +) -> DataFrame: + """Convert a set of normal samples into a mixed set of normal and synthetic anomalous samples. + + The synthetic images will be saved to the file system in the specified root directory under /images. + For the synthetic anomalous images, the masks will be saved under /ground_truth. + + Args: + source_samples (DataFrame): Normal images that will be used as source for the synthetic anomalous images. + image_dir (Path): Directory to which the synthetic anomalous image files will be written. + mask_dir (Path): Directory to which the ground truth anomaly masks will be written. + anomalous_ratio (float): Fraction of source samples that will be converted into anomalous samples. + """ + if 1 in source_samples.label_index.to_numpy(): + msg = "All source images must be normal." + raise ValueError(msg) + + if not image_dir.is_dir(): + msg = f"{image_dir} is not a folder." + raise NotADirectoryError(msg) + + if not mask_dir.is_dir(): + msg = f"{mask_dir} is not a folder." + raise NotADirectoryError(msg) + + # filter relevant columns + source_samples = source_samples.filter(["image_path", "label", "label_index", "mask_path", "split"]) + # randomly select samples for augmentation + n_anomalous = int(anomalous_ratio * len(source_samples)) + anomalous_samples = source_samples.sample(n_anomalous) + normal_samples = source_samples.drop(anomalous_samples.index) + anomalous_samples = anomalous_samples.reset_index(drop=True) + + # initialize augmenter + augmenter = Augmenter("./datasets/dtd", p_anomalous=1.0, beta=(0.01, 0.2)) + + def augment(sample: Series) -> Series: + """Apply synthetic anomalous augmentation to a sample from a dataframe. + + Reads an image, applies the augmentations, writes the augmented image and corresponding mask to the file system, + and returns a new Series object with the updates labels and file locations. + + Args: + sample (Series): DataFrame row containing info about the image that will be augmented. + + Returns: + Series: DataFrame row with updated information about the augmented image. + """ + # read and transform image + image = read_image(sample.image_path, as_tensor=True) + # apply anomalous perturbation + aug_im, mask = augmenter.augment_batch(image.unsqueeze(0)) + # target file name with leading zeros + file_name = f"{str(sample.name).zfill(int(math.log10(n_anomalous)) + 1)}.png" + # write image + aug_im = (aug_im.squeeze().permute((1, 2, 0)) * 255).numpy() + aug_im = cv2.cvtColor(aug_im, cv2.COLOR_RGB2BGR) + im_path = image_dir / file_name + cv2.imwrite(str(im_path), aug_im) + # write mask + mask = (mask.squeeze() * 255).numpy() + mask_path = mask_dir / file_name + cv2.imwrite(str(mask_path), mask) + out = { + "image_path": str(im_path), + "label": "abnormal", + "label_index": 1, + "mask_path": str(mask_path), + "split": Split.VAL, + } + return Series(out) + + anomalous_samples = anomalous_samples.apply(augment, axis=1) + + return pd.concat([normal_samples, anomalous_samples], ignore_index=True) + + +class SyntheticAnomalyDataset(AnomalibDataset): + """Dataset which reads synthetically generated anomalous images from a temporary folder. + + Args: + task (str): Task type, either "classification" or "segmentation". + transform (A.Compose): Transform object describing the transforms that are applied to the inputs. + source_samples (DataFrame): Normal samples to which the anomalous augmentations will be applied. + """ + + def __init__(self, task: TaskType, transform: Compose, source_samples: DataFrame) -> None: + super().__init__(task, transform) + + self.source_samples = source_samples + + # Files will be written to a temporary directory in the workdir, which is cleaned up after code execution + root = Path(ROOT) + root.mkdir(parents=True, exist_ok=True) + + self.root = Path(mkdtemp(dir=root)) + self.im_dir = self.root / "abnormal" + self.mask_dir = self.root / "ground_truth" + + # create directories + self.im_dir.mkdir() + self.mask_dir.mkdir() + + self._cleanup = True # flag that determines if temp dir is cleaned up when instance is deleted + self.samples = make_synthetic_dataset(self.source_samples, self.im_dir, self.mask_dir, 0.5) + + @classmethod + def from_dataset(cls: type["SyntheticAnomalyDataset"], dataset: AnomalibDataset) -> "SyntheticAnomalyDataset": + """Create a synthetic anomaly dataset from an existing dataset of normal images. + + Args: + dataset (AnomalibDataset): Dataset consisting of only normal images that will be converrted to a synthetic + anomalous dataset with a 50/50 normal anomalous split. + """ + return cls(task=dataset.task, transform=dataset.transform, source_samples=dataset.samples) + + def __copy__(self) -> "SyntheticAnomalyDataset": + """Return a shallow copy of the dataset object and prevents cleanup when original object is deleted.""" + cls = self.__class__ + new = cls.__new__(cls) + new.__dict__.update(self.__dict__) + self._cleanup = False + return new + + def __deepcopy__(self, _memo: dict) -> "SyntheticAnomalyDataset": + """Return a deep copy of the dataset object and prevents cleanup when original object is deleted.""" + cls = self.__class__ + new = cls.__new__(cls) + for key, value in self.__dict__.items(): + setattr(new, key, deepcopy(value)) + self._cleanup = False + return new + + def __del__(self) -> None: + """Make sure the temporary directory is cleaned up when the dataset object is deleted.""" + if self._cleanup: + shutil.rmtree(self.root) diff --git a/anomalib/data/utils/tiler.py b/anomalib/data/utils/tiler.py new file mode 100644 index 0000000000000000000000000000000000000000..ba6ad35dfa8d07343ffb9490696d1d19aa45147c --- /dev/null +++ b/anomalib/data/utils/tiler.py @@ -0,0 +1,425 @@ +"""Image Tiler.""" + +# Copyright (C) 2022-2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + + +from collections.abc import Sequence +from enum import Enum +from itertools import product +from math import ceil + +import torch +import torchvision.transforms as T # noqa: N812 +from torch.nn import functional as F # noqa: N812 + + +class ImageUpscaleMode(str, Enum): + """Type of mode when upscaling image.""" + + PADDING = "padding" + INTERPOLATION = "interpolation" + + +class StrideSizeError(Exception): + """StrideSizeError to raise exception when stride size is greater than the tile size.""" + + +def compute_new_image_size(image_size: tuple, tile_size: tuple, stride: tuple) -> tuple: + """Check if image size is divisible by tile size and stride. + + If not divisible, it resizes the image size to make it divisible. + + Args: + image_size (tuple): Original image size + tile_size (tuple): Tile size + stride (tuple): Stride + + Examples: + >>> compute_new_image_size(image_size=(512, 512), tile_size=(256, 256), stride=(128, 128)) + (512, 512) + + >>> compute_new_image_size(image_size=(512, 512), tile_size=(222, 222), stride=(111, 111)) + (555, 555) + + Returns: + tuple: Updated image size that is divisible by tile size and stride. + """ + + def __compute_new_edge_size(edge_size: int, tile_size: int, stride: int) -> int: + """Resize within the edge level.""" + if (edge_size - tile_size) % stride != 0: + edge_size = (ceil((edge_size - tile_size) / stride) * stride) + tile_size + + return edge_size + + resized_h = __compute_new_edge_size(image_size[0], tile_size[0], stride[0]) + resized_w = __compute_new_edge_size(image_size[1], tile_size[1], stride[1]) + + return resized_h, resized_w + + +def upscale_image(image: torch.Tensor, size: tuple, mode: ImageUpscaleMode = ImageUpscaleMode.PADDING) -> torch.Tensor: + """Upscale image to the desired size via either padding or interpolation. + + Args: + image (torch.Tensor): Image + size (tuple): tuple to which image is upscaled. + mode (str, optional): Upscaling mode. Defaults to "padding". + + Examples: + >>> image = torch.rand(1, 3, 512, 512) + >>> image = upscale_image(image, size=(555, 555), mode="padding") + >>> image.shape + torch.Size([1, 3, 555, 555]) + + >>> image = torch.rand(1, 3, 512, 512) + >>> image = upscale_image(image, size=(555, 555), mode="interpolation") + >>> image.shape + torch.Size([1, 3, 555, 555]) + + Returns: + Tensor: Upscaled image. + """ + image_h, image_w = image.shape[2:] + resize_h, resize_w = size + + if mode == ImageUpscaleMode.PADDING: + pad_h = resize_h - image_h + pad_w = resize_w - image_w + + image = F.pad(image, [0, pad_w, 0, pad_h]) + elif mode == ImageUpscaleMode.INTERPOLATION: + image = F.interpolate(input=image, size=(resize_h, resize_w)) + else: + msg = f"Unknown mode {mode}. Only padding and interpolation is available." + raise ValueError(msg) + + return image + + +def downscale_image( + image: torch.Tensor, + size: tuple, + mode: ImageUpscaleMode = ImageUpscaleMode.PADDING, +) -> torch.Tensor: + """Opposite of upscaling. This image downscales image to a desired size. + + Args: + image (torch.Tensor): Input image + size (tuple): Size to which image is down scaled. + mode (str, optional): Downscaling mode. Defaults to "padding". + + Examples: + >>> x = torch.rand(1, 3, 512, 512) + >>> y = upscale_image(image, upscale_size=(555, 555), mode="padding") + >>> y = downscale_image(y, size=(512, 512), mode='padding') + >>> torch.allclose(x, y) + True + + Returns: + Tensor: Downscaled image + """ + input_h, input_w = size + if mode == ImageUpscaleMode.PADDING: + image = image[:, :, :input_h, :input_w] + else: + image = F.interpolate(input=image, size=(input_h, input_w)) + + return image + + +class Tiler: + """Tile Image into (non)overlapping Patches. Images are tiled in order to efficiently process large images. + + Args: + tile_size: Tile dimension for each patch + stride: Stride length between patches + remove_border_count: Number of border pixels to be removed from tile before untiling + mode: Upscaling mode for image resize.Supported formats: padding, interpolation + + Examples: + >>> import torch + >>> from torchvision import transforms + >>> from skimage.data import camera + >>> tiler = Tiler(tile_size=256,stride=128) + >>> image = transforms.ToTensor()(camera()) + >>> tiles = tiler.tile(image) + >>> image.shape, tiles.shape + (torch.Size([3, 512, 512]), torch.Size([9, 3, 256, 256])) + + >>> # Perform your operations on the tiles. + + >>> # Untile the patches to reconstruct the image + >>> reconstructed_image = tiler.untile(tiles) + >>> reconstructed_image.shape + torch.Size([1, 3, 512, 512]) + """ + + def __init__( + self, + tile_size: int | Sequence, + stride: int | Sequence | None = None, + remove_border_count: int = 0, + mode: ImageUpscaleMode = ImageUpscaleMode.PADDING, + ) -> None: + self.tile_size_h, self.tile_size_w = self.__validate_size_type(tile_size) + self.random_tile_count = 4 + + if stride is not None: + self.stride_h, self.stride_w = self.__validate_size_type(stride) + + self.remove_border_count = remove_border_count + self.overlapping = not (self.stride_h == self.tile_size_h and self.stride_w == self.tile_size_w) + self.mode = mode + + if self.stride_h > self.tile_size_h or self.stride_w > self.tile_size_w: + msg = ( + "Larger stride size than kernel size produces unreliable tiling results. " + "Please ensure stride size is less than or equal than tiling size." + ) + raise StrideSizeError( + msg, + ) + + if self.mode not in (ImageUpscaleMode.PADDING, ImageUpscaleMode.INTERPOLATION): + msg = f"Unknown tiling mode {self.mode}. Available modes are padding and interpolation" + raise ValueError(msg) + + self.batch_size: int + self.num_channels: int + + self.input_h: int + self.input_w: int + + self.pad_h: int + self.pad_w: int + + self.resized_h: int + self.resized_w: int + + self.num_patches_h: int + self.num_patches_w: int + + @staticmethod + def __validate_size_type(parameter: int | Sequence) -> tuple[int, ...]: + if isinstance(parameter, int): + output = (parameter, parameter) + elif isinstance(parameter, Sequence): + output = (parameter[0], parameter[1]) + else: + msg = f"Unknown type {type(parameter)} for tile or stride size. Could be int or Sequence type." + raise TypeError(msg) + + if len(output) != 2: + msg = f"Length of the size type must be 2 for height and width. Got {len(output)} instead." + raise ValueError(msg) + + return output + + def __random_tile(self, image: torch.Tensor) -> torch.Tensor: + """Randomly crop tiles from the given image. + + Args: + image: input image to be cropped + + Returns: Randomly cropped tiles from the image + """ + return torch.vstack([T.RandomCrop(self.tile_size_h)(image) for i in range(self.random_tile_count)]) + + def __unfold(self, tensor: torch.Tensor) -> torch.Tensor: + """Unfolds tensor into tiles. + + This is the core function to perform tiling operation. + + Args: + tensor: Input tensor from which tiles are generated. + + Returns: Generated tiles + """ + # identify device type based on input tensor + device = tensor.device + + # extract and calculate parameters + batch, channels, image_h, image_w = tensor.shape + + self.num_patches_h = int((image_h - self.tile_size_h) / self.stride_h) + 1 + self.num_patches_w = int((image_w - self.tile_size_w) / self.stride_w) + 1 + + # create an empty torch tensor for output + tiles = torch.zeros( + (self.num_patches_h, self.num_patches_w, batch, channels, self.tile_size_h, self.tile_size_w), + device=device, + ) + + # fill-in output tensor with spatial patches extracted from the image + for (tile_i, tile_j), (loc_i, loc_j) in zip( + product(range(self.num_patches_h), range(self.num_patches_w)), + product( + range(0, image_h - self.tile_size_h + 1, self.stride_h), + range(0, image_w - self.tile_size_w + 1, self.stride_w), + ), + strict=True, + ): + tiles[tile_i, tile_j, :] = tensor[ + :, + :, + loc_i : (loc_i + self.tile_size_h), + loc_j : (loc_j + self.tile_size_w), + ] + + # rearrange the tiles in order [tile_count * batch, channels, tile_height, tile_width] + tiles = tiles.permute(2, 0, 1, 3, 4, 5) + return tiles.contiguous().view(-1, channels, self.tile_size_h, self.tile_size_w) + + def __fold(self, tiles: torch.Tensor) -> torch.Tensor: + """Fold the tiles back into the original tensor. + + This is the core method to reconstruct the original image from its tiled version. + + Args: + tiles: Tiles from the input image, generated via __unfold method. + + Returns: + Output that is the reconstructed version of the input tensor. + """ + # number of channels differs between image and anomaly map, so infer from input tiles. + _, num_channels, tile_size_h, tile_size_w = tiles.shape + scale_h, scale_w = (tile_size_h / self.tile_size_h), (tile_size_w / self.tile_size_w) + # identify device type based on input tensor + device = tiles.device + # calculate tile size after borders removed + reduced_tile_h = tile_size_h - (2 * self.remove_border_count) + reduced_tile_w = tile_size_w - (2 * self.remove_border_count) + # reconstructed image dimension + image_size = (self.batch_size, num_channels, int(self.resized_h * scale_h), int(self.resized_w * scale_w)) + + # rearrange input tiles in format [tile_count, batch, channel, tile_h, tile_w] + tiles = tiles.contiguous().view( + self.batch_size, + self.num_patches_h, + self.num_patches_w, + num_channels, + tile_size_h, + tile_size_w, + ) + tiles = tiles.permute(0, 3, 1, 2, 4, 5) + tiles = tiles.contiguous().view(self.batch_size, num_channels, -1, tile_size_h, tile_size_w) + tiles = tiles.permute(2, 0, 1, 3, 4) + + # remove tile borders by defined count + tiles = tiles[ + :, + :, + :, + self.remove_border_count : reduced_tile_h + self.remove_border_count, + self.remove_border_count : reduced_tile_w + self.remove_border_count, + ] + + # create tensors to store intermediate results and outputs + img = torch.zeros(image_size, device=device) + lookup = torch.zeros(image_size, device=device) + ones = torch.ones(reduced_tile_h, reduced_tile_w, device=device) + + # reconstruct image by adding patches to their respective location and + # create a lookup for patch count in every location + for patch, (loc_i, loc_j) in zip( + tiles, + product( + range( + self.remove_border_count, + int(self.resized_h * scale_h) - reduced_tile_h + 1, + int(self.stride_h * scale_h), + ), + range( + self.remove_border_count, + int(self.resized_w * scale_w) - reduced_tile_w + 1, + int(self.stride_w * scale_w), + ), + ), + strict=True, + ): + img[:, :, loc_i : (loc_i + reduced_tile_h), loc_j : (loc_j + reduced_tile_w)] += patch + lookup[:, :, loc_i : (loc_i + reduced_tile_h), loc_j : (loc_j + reduced_tile_w)] += ones + + # divide the reconstucted image by the lookup to average out the values + img = torch.divide(img, lookup) + # alternative way of removing nan values (isnan not supported by openvino) + img[img != img] = 0 # noqa: PLR0124 + + return img + + def tile(self, image: torch.Tensor, use_random_tiling: bool = False) -> torch.Tensor: + """Tiles an input image to either overlapping, non-overlapping or random patches. + + Args: + image: Input image to tile. + use_random_tiling: If True, randomly crops tiles from the image. + If False, tiles the image in a regular grid. + + Examples: + >>> from anomalib.data.utils.tiler import Tiler + >>> tiler = Tiler(tile_size=512,stride=256) + >>> image = torch.rand(size=(2, 3, 1024, 1024)) + >>> image.shape + torch.Size([2, 3, 1024, 1024]) + >>> tiles = tiler.tile(image) + >>> tiles.shape + torch.Size([18, 3, 512, 512]) + + Returns: + Tiles generated from the image. + """ + if image.dim() == 3: + image = image.unsqueeze(0) + + self.batch_size, self.num_channels, self.input_h, self.input_w = image.shape + + if self.input_h < self.tile_size_h or self.input_w < self.tile_size_w: + msg = ( + f"One of the edges of the tile size {self.tile_size_h, self.tile_size_w} is larger than " + f"that of the image {self.input_h, self.input_w}." + ) + raise ValueError( + msg, + ) + + self.resized_h, self.resized_w = compute_new_image_size( + image_size=(self.input_h, self.input_w), + tile_size=(self.tile_size_h, self.tile_size_w), + stride=(self.stride_h, self.stride_w), + ) + + image = upscale_image(image, size=(self.resized_h, self.resized_w), mode=self.mode) + + return self.__random_tile(image) if use_random_tiling else self.__unfold(image) + + def untile(self, tiles: torch.Tensor) -> torch.Tensor: + """Untiles patches to reconstruct the original input image. + + If patches, are overlapping patches, the function averages the overlapping pixels, + and return the reconstructed image. + + Args: + tiles: Tiles from the input image, generated via tile().. + + Examples: + >>> from anomalib.data.utils.tiler import Tiler + >>> tiler = Tiler(tile_size=512,stride=256) + >>> image = torch.rand(size=(2, 3, 1024, 1024)) + >>> image.shape + torch.Size([2, 3, 1024, 1024]) + >>> tiles = tiler.tile(image) + >>> tiles.shape + torch.Size([18, 3, 512, 512]) + >>> reconstructed_image = tiler.untile(tiles) + >>> reconstructed_image.shape + torch.Size([2, 3, 1024, 1024]) + >>> torch.equal(image, reconstructed_image) + True + + Returns: + Output that is the reconstructed version of the input tensor. + """ + image = self.__fold(tiles) + return downscale_image(image=image, size=(self.input_h, self.input_w), mode=self.mode) diff --git a/anomalib/data/utils/video.py b/anomalib/data/utils/video.py new file mode 100644 index 0000000000000000000000000000000000000000..7a939ea861a42678372f3044716c8a023c39d9d7 --- /dev/null +++ b/anomalib/data/utils/video.py @@ -0,0 +1,100 @@ +"""Video utils.""" + +# Copyright (C) 2023-2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +import warnings +from abc import ABC, abstractmethod +from pathlib import Path +from typing import Any + +import cv2 +import torch +from torchvision.datasets.video_utils import VideoClips + + +class ClipsIndexer(VideoClips, ABC): + """Extension of torchvision's VideoClips class that also returns the masks for each clip. + + Subclasses should implement the get_mask method. By default, the class inherits the functionality of VideoClips, + which assumes that video_paths is a list of video files. If custom behaviour is required (e.g. video_paths is a list + of folders with single-frame images), the subclass should implement at least get_clip and _compute_frame_pts. + + Args: + video_paths (list[str]): List of video paths that make up the dataset. + mask_paths (list[str]): List of paths to the masks for each video in the dataset. + """ + + def __init__( + self, + video_paths: list[str], + mask_paths: list[str], + clip_length_in_frames: int = 2, + frames_between_clips: int = 1, + ) -> None: + super().__init__( + video_paths=video_paths, + clip_length_in_frames=clip_length_in_frames, + frames_between_clips=frames_between_clips, + output_format="TCHW", + ) + self.mask_paths = mask_paths + + def last_frame_idx(self, video_idx: int) -> int: + """Return the index of the last frame for a given video.""" + return self.clips[video_idx][-1][-1].item() + + @abstractmethod + def get_mask(self, idx: int) -> torch.Tensor | None: + """Return the masks for the given index.""" + raise NotImplementedError + + def get_item(self, idx: int) -> dict[str, Any]: + """Return a dictionary containing the clip, mask, video path and frame indices.""" + with warnings.catch_warnings(): + # silence warning caused by bug in torchvision, see https://github.com/pytorch/vision/issues/5787 + warnings.simplefilter("ignore") + clip, _, _, _ = self.get_clip(idx) + + video_idx, clip_idx = self.get_clip_location(idx) + video_path = self.video_paths[video_idx] + clip_pts = self.clips[video_idx][clip_idx] + + return { + "image": clip, + "mask": self.get_mask(idx), + "video_path": video_path, + "frames": clip_pts, + "last_frame": self.last_frame_idx(video_idx), + } + + +def convert_video(input_path: Path, output_path: Path, codec: str = "MP4V") -> None: + """Convert video file to a different codec. + + Args: + input_path (Path): Path to the input video. + output_path (Path): Path to the target output video. + codec (str): fourcc code of the codec that will be used for compression of the output file. + """ + if not output_path.parent.exists(): + output_path.parent.mkdir(parents=True) + + # create video reader for input file + video_reader = cv2.VideoCapture(str(input_path)) + + # create video writer for output file + fourcc = cv2.VideoWriter_fourcc(*codec) + frame_width = int(video_reader.get(cv2.CAP_PROP_FRAME_WIDTH)) + frame_height = int(video_reader.get(cv2.CAP_PROP_FRAME_HEIGHT)) + fps = int(video_reader.get(cv2.CAP_PROP_FPS)) + video_writer = cv2.VideoWriter(str(output_path), fourcc, fps, (frame_width, frame_height)) + + # read frames + success, frame = video_reader.read() + while success: + video_writer.write(frame) + success, frame = video_reader.read() + + video_reader.release() + video_writer.release() diff --git a/anomalib/data/video/__init__.py b/anomalib/data/video/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..a9651529bff20c93f3765c527ad122171f84b275 --- /dev/null +++ b/anomalib/data/video/__init__.py @@ -0,0 +1,22 @@ +"""Anomalib Video Datasets.""" + +# Copyright (C) 2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + + +from enum import Enum + +from .avenue import Avenue +from .shanghaitech import ShanghaiTech +from .ucsd_ped import UCSDped + + +class VideoDataFormat(str, Enum): + """Supported Video Dataset Types.""" + + UCSDPED = "ucsdped" + AVENUE = "avenue" + SHANGHAITECH = "shanghaitech" + + +__all__ = ["Avenue", "ShanghaiTech", "UCSDped"] diff --git a/anomalib/data/video/avenue.py b/anomalib/data/video/avenue.py new file mode 100644 index 0000000000000000000000000000000000000000..831caa4021eb7e267fbb67eadbda2db27d30d6ed --- /dev/null +++ b/anomalib/data/video/avenue.py @@ -0,0 +1,482 @@ +"""CUHK Avenue Dataset. + +Description: + This module provides a PyTorch Dataset and PyTorch Lightning DataModule for the CUHK Avenue dataset. + If the dataset is not already present on the file system, the DataModule class will download and + extract the dataset, converting the .mat mask files to .png format. + +Reference: + - Lu, Cewu, Jianping Shi, and Jiaya Jia. "Abnormal event detection at 150 fps in Matlab." + In Proceedings of the IEEE International Conference on Computer Vision, 2013. +""" + + +# Copyright (C) 2023-2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +__all__ = ["Avenue", "AvenueDataset", "make_avenue_dataset"] + +import logging +import math +from pathlib import Path +from shutil import move +from typing import TYPE_CHECKING + +import cv2 +import numpy as np +import scipy.io +import torch +from pandas import DataFrame +from torchvision.transforms.v2 import Transform + +from anomalib import TaskType +from anomalib.data.base import AnomalibVideoDataModule, AnomalibVideoDataset +from anomalib.data.base.video import VideoTargetFrame +from anomalib.data.utils import ( + DownloadInfo, + Split, + ValSplitMode, + download_and_extract, + read_mask, + validate_path, +) +from anomalib.data.utils.video import ClipsIndexer + +if TYPE_CHECKING: + from collections.abc import Callable + +logger = logging.getLogger(__name__) + +DATASET_DOWNLOAD_INFO = DownloadInfo( + name="Avenue Dataset", + url="http://www.cse.cuhk.edu.hk/leojia/projects/detectabnormal/Avenue_Dataset.zip", + hashsum="fc9cb8432a11ca79c18aa180c72524011411b69d3b0ff27c8816e41c0de61531", +) +ANNOTATIONS_DOWNLOAD_INFO = DownloadInfo( + name="Avenue Annotations", + url="http://www.cse.cuhk.edu.hk/leojia/projects/detectabnormal/ground_truth_demo.zip", + hashsum="60fec1728ec8f73a58aad3aeb5729d70a805a47e0b8eb4bf91ab67ef06386d77", +) + + +def make_avenue_dataset(root: Path, gt_dir: Path, split: Split | str | None = None) -> DataFrame: + """Create CUHK Avenue dataset by parsing the file structure. + + The files are expected to follow the structure: + - path/to/dataset/[training_videos|testing_videos]/video_filename.avi + - path/to/ground_truth/mask_filename.mat + + Args: + root (Path): Path to dataset + gt_dir (Path): Path to the ground truth + split (Split | str | None = None, optional): Dataset split (ie., either train or test). + Defaults to ``None``. + + Example: + The following example shows how to get testing samples from Avenue dataset: + + >>> root = Path('./avenue') + >>> gt_dir = Path('./avenue/masks') + >>> samples = make_avenue_dataset(path, gt_dir, split='test') + >>> samples.head() + root folder image_path mask_path split + 0 ./avenue testing_videos ./avenue/training_videos/01.avi ./avenue/masks/01_label.mat test + 1 ./avenue testing_videos ./avenue/training_videos/02.avi ./avenue/masks/01_label.mat test + ... + + Returns: + DataFrame: an output dataframe containing samples for the requested split (ie., train or test) + """ + root = validate_path(root) + + samples_list = [(str(root),) + filename.parts[-2:] for filename in root.glob("**/*.avi")] + samples = DataFrame(samples_list, columns=["root", "folder", "image_path"]) + + samples.loc[samples.folder == "testing_videos", "mask_path"] = ( + samples.image_path.str.split(".").str[0].str.lstrip("0") + "_label.mat" + ) + samples.loc[samples.folder == "testing_videos", "mask_path"] = ( + str(gt_dir) + "/testing_label_mask/" + samples.mask_path + ) + samples.loc[samples.folder == "training_videos", "mask_path"] = "" + + samples["image_path"] = samples.root + "/" + samples.folder + "/" + samples.image_path + + samples.loc[samples.folder == "training_videos", "split"] = "train" + samples.loc[samples.folder == "testing_videos", "split"] = "test" + + if split: + samples = samples[samples.split == split] + samples = samples.reset_index(drop=True) + + return samples + + +class AvenueClipsIndexer(ClipsIndexer): + """Clips class for Avenue dataset.""" + + def get_mask(self, idx: int) -> np.ndarray | None: + """Retrieve the masks from the file system.""" + video_idx, frames_idx = self.get_clip_location(idx) + matfile = self.mask_paths[video_idx] + if matfile == "": # no gt masks available for this clip + return None + frames = self.clips[video_idx][frames_idx] + + # read masks from .png files if available, othwerise from mat files. + mask_folder = Path(matfile).with_suffix("") + if mask_folder.exists(): + mask_frames = sorted(mask_folder.glob("*")) + mask_paths = [mask_frames[idx] for idx in frames.int()] + masks = torch.stack([read_mask(mask_path, as_tensor=True) for mask_path in mask_paths]) + else: + mat = scipy.io.loadmat(matfile) + masks = np.vstack([np.stack(m) for m in mat["volLabel"]]) + masks = np.take(masks, frames, 0) + return masks + + +class AvenueDataset(AnomalibVideoDataset): + """Avenue Dataset class. + + Args: + task (TaskType): Task type, 'classification', 'detection' or 'segmentation' + split (Split): Split of the dataset, usually Split.TRAIN or Split.TEST + root (Path | str): Path to the root of the dataset + Defaults to ``./datasets/avenue``. + gt_dir (Path | str): Path to the ground truth files + Defaults to ``./datasets/avenue/ground_truth_demo``. + clip_length_in_frames (int, optional): Number of video frames in each clip. + Defaults to ``2``. + frames_between_clips (int, optional): Number of frames between each consecutive video clip. + Defaults to ``1``. + target_frame (VideoTargetFrame): Specifies the target frame in the video clip, used for ground truth retrieval. + Defaults to ``VideoTargetFrame.LAST``. + transform (Transform, optional): Transforms that should be applied to the input images. + Defaults to ``None``. + + Examples: + To create an Avenue dataset to train a classification model: + + .. code-block:: python + + transform = A.Compose([A.Resize(256, 256), A.pytorch.ToTensorV2()]) + dataset = AvenueDataset( + task="classification", + transform=transform, + split="train", + root="./datasets/avenue/", + ) + + dataset.setup() + dataset[0].keys() + + # Output: dict_keys(['image', 'video_path', 'frames', 'last_frame', 'original_image']) + + If you would like to test a segmentation model, you can use the following code: + + .. code-block:: python + + dataset = AvenueDataset( + task="segmentation", + transform=transform, + split="test", + root="./datasets/avenue/", + ) + + dataset.setup() + dataset[0].keys() + + # Output: dict_keys(['image', 'mask', 'video_path', 'frames', 'last_frame', 'original_image', 'label']) + + Avenue video dataset can also be used as an image dataset if you set the clip length to 1. This means that each + video frame will be treated as a separate sample. This is useful for training a classification model on the + Avenue dataset. The following code shows how to create an image dataset for classification: + + .. code-block:: python + + dataset = AvenueDataset( + task="classification", + transform=transform, + split="test", + root="./datasets/avenue/", + clip_length_in_frames=1, + ) + + dataset.setup() + dataset[0].keys() + # Output: dict_keys(['image', 'video_path', 'frames', 'last_frame', 'original_image', 'label']) + + dataset[0]["image"].shape + # Output: torch.Size([3, 256, 256]) + """ + + def __init__( + self, + task: TaskType, + split: Split, + root: Path | str = "./datasets/avenue", + gt_dir: Path | str = "./datasets/avenue/ground_truth_demo", + clip_length_in_frames: int = 2, + frames_between_clips: int = 1, + transform: Transform | None = None, + target_frame: VideoTargetFrame = VideoTargetFrame.LAST, + ) -> None: + super().__init__( + task=task, + clip_length_in_frames=clip_length_in_frames, + frames_between_clips=frames_between_clips, + target_frame=target_frame, + transform=transform, + ) + + self.root = root if isinstance(root, Path) else Path(root) + self.gt_dir = gt_dir if isinstance(gt_dir, Path) else Path(gt_dir) + self.split = split + self.indexer_cls: Callable = AvenueClipsIndexer + self.samples = make_avenue_dataset(self.root, self.gt_dir, self.split) + + +class Avenue(AnomalibVideoDataModule): + """Avenue DataModule class. + + Args: + root (Path | str): Path to the root of the dataset + Defaults to ``./datasets/avenue``. + gt_dir (Path | str): Path to the ground truth files + Defaults to ``./datasets/avenue/ground_truth_demo``. + clip_length_in_frames (int, optional): Number of video frames in each clip. + Defaults to ``2``. + frames_between_clips (int, optional): Number of frames between each consecutive video clip. + Defaults to ``1``. + target_frame (VideoTargetFrame): Specifies the target frame in the video clip, used for ground truth retrieval + Defaults to ``VideoTargetFrame.LAST``. + task (TaskType): Task type, 'classification', 'detection' or 'segmentation' + Defaults to ``TaskType.SEGMENTATION``. + image_size (tuple[int, int], optional): Size to which input images should be resized. + Defaults to ``None``. + transform (Transform, optional): Transforms that should be applied to the input images. + Defaults to ``None``. + train_transform (Transform, optional): Transforms that should be applied to the input images during training. + Defaults to ``None``. + eval_transform (Transform, optional): Transforms that should be applied to the input images during evaluation. + Defaults to ``None``. + train_batch_size (int, optional): Training batch size. + Defaults to ``32``. + eval_batch_size (int, optional): Test batch size. + Defaults to ``32``. + num_workers (int, optional): Number of workers. + Defaults to ``8``. + val_split_mode (ValSplitMode): Setting that determines how the validation subset is obtained. + Defaults to ``ValSplitMode.FROM_TEST``. + val_split_ratio (float): Fraction of train or test images that will be reserved for validation. + Defaults to ``0.5``. + seed (int | None, optional): Seed which may be set to a fixed value for reproducibility. + Defaults to ``None``. + + Examples: + To create a DataModule for Avenue dataset with default parameters: + + .. code-block:: python + + datamodule = Avenue() + datamodule.setup() + + i, data = next(enumerate(datamodule.train_dataloader())) + data.keys() + # Output: dict_keys(['image', 'video_path', 'frames', 'last_frame', 'original_image']) + + i, data = next(enumerate(datamodule.test_dataloader())) + data.keys() + # Output: dict_keys(['image', 'mask', 'video_path', 'frames', 'last_frame', 'original_image', 'label']) + + data["image"].shape + # Output: torch.Size([32, 2, 3, 256, 256]) + + Note that the default task type is segmentation and the dataloader returns a mask in addition to the input. + Also, it is important to note that the dataloader returns a batch of clips, where each clip is a sequence of + frames. The number of frames in each clip is determined by the ``clip_length_in_frames`` parameter. The + ``frames_between_clips`` parameter determines the number of frames between each consecutive clip. The + ``target_frame`` parameter determines which frame in the clip is used for ground truth retrieval. For example, + if ``clip_length_in_frames=2``, ``frames_between_clips=1`` and ``target_frame=VideoTargetFrame.LAST``, then the + dataloader will return a batch of clips where each clip contains two consecutive frames from the video. The + second frame in each clip will be used as the ground truth for the first frame in the clip. The following code + shows how to create a dataloader for classification: + + .. code-block:: python + + datamodule = Avenue( + task="classification", + clip_length_in_frames=2, + frames_between_clips=1, + target_frame=VideoTargetFrame.LAST + ) + datamodule.setup() + + i, data = next(enumerate(datamodule.train_dataloader())) + data.keys() + # Output: dict_keys(['image', 'video_path', 'frames', 'last_frame', 'original_image']) + + data["image"].shape + # Output: torch.Size([32, 2, 3, 256, 256]) + + """ + + def __init__( + self, + root: Path | str = "./datasets/avenue", + gt_dir: Path | str = "./datasets/avenue/ground_truth_demo", + clip_length_in_frames: int = 2, + frames_between_clips: int = 1, + target_frame: VideoTargetFrame | str = VideoTargetFrame.LAST, + task: TaskType | str = TaskType.SEGMENTATION, + image_size: tuple[int, int] | None = None, + transform: Transform | None = None, + train_transform: Transform | None = None, + eval_transform: Transform | None = None, + train_batch_size: int = 32, + eval_batch_size: int = 32, + num_workers: int = 8, + val_split_mode: ValSplitMode | str = ValSplitMode.SAME_AS_TEST, + val_split_ratio: float = 0.5, + seed: int | None = None, + ) -> None: + super().__init__( + train_batch_size=train_batch_size, + eval_batch_size=eval_batch_size, + num_workers=num_workers, + image_size=image_size, + transform=transform, + train_transform=train_transform, + eval_transform=eval_transform, + val_split_mode=val_split_mode, + val_split_ratio=val_split_ratio, + seed=seed, + ) + + self.task = TaskType(task) + self.root = Path(root) + self.gt_dir = Path(gt_dir) + self.clip_length_in_frames = clip_length_in_frames + self.frames_between_clips = frames_between_clips + self.target_frame = VideoTargetFrame(target_frame) + + def _setup(self, _stage: str | None = None) -> None: + self.train_data = AvenueDataset( + task=self.task, + transform=self.train_transform, + clip_length_in_frames=self.clip_length_in_frames, + frames_between_clips=self.frames_between_clips, + target_frame=self.target_frame, + root=self.root, + gt_dir=self.gt_dir, + split=Split.TRAIN, + ) + + self.test_data = AvenueDataset( + task=self.task, + transform=self.eval_transform, + clip_length_in_frames=self.clip_length_in_frames, + frames_between_clips=self.frames_between_clips, + target_frame=self.target_frame, + root=self.root, + gt_dir=self.gt_dir, + split=Split.TEST, + ) + + def prepare_data(self) -> None: + """Download the dataset if not available. + + This method checks if the specified dataset is available in the file system. + If not, it downloads and extracts the dataset into the appropriate directory. + + Example: + Assume the dataset is not available on the file system. + Here's how the directory structure looks before and after calling the + `prepare_data` method: + + Before: + + .. code-block:: bash + + $ tree datasets + datasets + ├── dataset1 + └── dataset2 + + Calling the method: + + .. code-block:: python + + >> datamodule = Avenue() + >> datamodule.prepare_data() + + After: + + .. code-block:: bash + + $ tree datasets + datasets + ├── dataset1 + ├── dataset2 + └── avenue + ├── ground_truth_demo + │ ├── ground_truth_show.m + │ ├── Readme.txt + │ ├── testing_label_mask + │ └── testing_videos + ├── testing_videos + │ ├── ... + │ └── 21.avi + ├── testing_vol + │ ├── ... + │ └── vol21.mat + ├── training_videos + │ ├── ... + │ └── 16.avi + └── training_vol + ├── ... + └── vol16.mat + """ + if self.root.is_dir(): + logger.info("Found the dataset.") + else: + download_and_extract(self.root, DATASET_DOWNLOAD_INFO) + download_and_extract(self.gt_dir, ANNOTATIONS_DOWNLOAD_INFO) + + # move contents to root + folder_names = ["Avenue Dataset", "ground_truth_demo"] + for root, folder_name in zip([self.root, self.gt_dir], folder_names, strict=True): + extracted_folder = root / folder_name + for filename in extracted_folder.glob("*"): + move(str(filename), str(root / filename.name)) + extracted_folder.rmdir() + + # convert masks + self._convert_masks(self.gt_dir) + + @staticmethod + def _convert_masks(gt_dir: Path) -> None: + """Convert mask files to .png. + + The masks in the Avenue datasets are provided as matlab (.mat) files. To speed up data loading, we convert the + masks into a sepaarte .png file for every video frame in the dataset. + + Args: + gt_dir (Path): Ground truth folder of the dataset. + """ + # convert masks to numpy + masks_dir = gt_dir / "testing_label_mask" + # get file names + mat_files = list(masks_dir.glob("*.mat")) + mask_folders = [matfile.with_suffix("") for matfile in mat_files] + if not all(folder.exists() for folder in mask_folders): + # convert mask files to images + logger.info("converting mat files to .png format.") + for mat_file, mask_folder in zip(mat_files, mask_folders, strict=True): + mat = scipy.io.loadmat(mat_file) + mask_folder.mkdir(parents=True, exist_ok=True) + masks = mat["volLabel"].squeeze() + for idx, mask in enumerate(masks): + filename = (mask_folder / str(idx).zfill(int(math.log10(len(masks)) + 1))).with_suffix(".png") + cv2.imwrite(str(filename), mask) diff --git a/anomalib/data/video/shanghaitech.py b/anomalib/data/video/shanghaitech.py new file mode 100644 index 0000000000000000000000000000000000000000..6c0055dd312d5698673a7bfd6fe2cfa3a7bea6a7 --- /dev/null +++ b/anomalib/data/video/shanghaitech.py @@ -0,0 +1,350 @@ +"""ShanghaiTech Campus Dataset. + +Description: + This module contains PyTorch Dataset and PyTorch + Lightning DataModule for the ShanghaiTech Campus dataset. + If the dataset is not on the file system, the DataModule class downloads and + extracts the dataset and converts video files to a format that is readable by pyav. + +License: + ShanghaiTech Campus Dataset is released under the BSD 2-Clause License. + +Reference: + - W. Liu and W. Luo, D. Lian and S. Gao. "Future Frame Prediction for Anomaly Detection -- A New Baseline." + IEEE Conference on Computer Vision and Pattern Recognition (CVPR). 2018. +""" + +# Copyright (C) 2023-2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +import logging +from pathlib import Path +from shutil import move +from typing import Any + +import numpy as np +import pandas as pd +import torch +from pandas import DataFrame +from torchvision.transforms.v2 import Transform + +from anomalib import TaskType +from anomalib.data.base import AnomalibVideoDataModule, AnomalibVideoDataset +from anomalib.data.base.video import VideoTargetFrame +from anomalib.data.utils import ( + DownloadInfo, + Split, + ValSplitMode, + download_and_extract, + read_image, + validate_path, +) +from anomalib.data.utils.video import ClipsIndexer, convert_video + +logger = logging.getLogger(__name__) + +DATASET_DOWNLOAD_INFO = DownloadInfo( + name="ShanghaiTech Dataset", + url="http://101.32.75.151:8181/dataset/shanghaitech.tar.gz", + hashsum="c13a827043b259ccf8493c9d9130486872992153a9d714fe229e523cd4c94116", +) + + +def make_shanghaitech_dataset(root: Path, scene: int, split: Split | str | None = None) -> DataFrame: + """Create ShanghaiTech dataset by parsing the file structure. + + The files are expected to follow the structure: + path/to/dataset/[training_videos|testing_videos]/video_filename.avi + path/to/ground_truth/mask_filename.mat + + Args: + root (Path): Path to dataset + scene (int): Index of the dataset scene (category) in range [1, 13] + split (Split | str | None, optional): Dataset split (ie., either train or test). Defaults to None. + + Example: + The following example shows how to get testing samples from ShanghaiTech dataset: + + >>> root = Path('./shanghaiTech') + >>> scene = 1 + >>> samples = make_avenue_dataset(path, scene, split='test') + >>> samples.head() + root image_path split mask_path + 0 shanghaitech shanghaitech/testing/frames/01_0014 test shanghaitech/testing/test_pixel_mask/01_0014.npy + 1 shanghaitech shanghaitech/testing/frames/01_0015 test shanghaitech/testing/test_pixel_mask/01_0015.npy + ... + + Returns: + DataFrame: an output dataframe containing samples for the requested split (ie., train or test) + """ + scene_prefix = str(scene).zfill(2) + + # get paths to training videos + root = validate_path(root) + train_root = root / "training/converted_videos" + train_list = [(str(train_root),) + filename.parts[-2:] for filename in train_root.glob(f"{scene_prefix}_*.avi")] + train_samples = DataFrame(train_list, columns=["root", "folder", "image_path"]) + train_samples["split"] = "train" + + # get paths to testing folders + test_root = Path(root) / "testing/frames" + test_folders = [filename for filename in sorted(test_root.glob(f"{scene_prefix}_*")) if filename.is_dir()] + test_folders = [folder for folder in test_folders if len(list(folder.glob("*.jpg"))) > 0] + test_list = [(str(test_root),) + folder.parts[-2:] for folder in test_folders] + test_samples = DataFrame(test_list, columns=["root", "folder", "image_path"]) + test_samples["split"] = "test" + + samples = pd.concat([train_samples, test_samples], ignore_index=True) + + gt_root = Path(root) / "testing/test_pixel_mask" + samples["mask_path"] = "" + samples.loc[samples.root == str(test_root), "mask_path"] = ( + str(gt_root) + "/" + samples.image_path.str.split(".").str[0] + ".npy" + ) + + samples["image_path"] = samples.root + "/" + samples.image_path + + if split: + samples = samples[samples.split == split] + samples = samples.reset_index(drop=True) + + return samples + + +class ShanghaiTechTrainClipsIndexer(ClipsIndexer): + """Clips indexer for ShanghaiTech dataset. + + The train and test subsets of the ShanghaiTech dataset use different file formats, so separate + clips indexer implementations are needed. + """ + + def get_mask(self, idx: int) -> torch.Tensor | None: + """No masks available for training set.""" + del idx # Unused argument + return None + + +class ShanghaiTechTestClipsIndexer(ClipsIndexer): + """Clips indexer for the test set of the ShanghaiTech Campus dataset. + + The train and test subsets of the ShanghaiTech dataset use different file formats, so separate + clips indexer implementations are needed. + """ + + def get_mask(self, idx: int) -> torch.Tensor | None: + """Retrieve the masks from the file system.""" + video_idx, frames_idx = self.get_clip_location(idx) + mask_file = self.mask_paths[video_idx] + if mask_file == "": # no gt masks available for this clip + return None + frames = self.clips[video_idx][frames_idx] + + vid_masks = np.load(mask_file) + return torch.tensor(np.take(vid_masks, frames, 0)) + + def _compute_frame_pts(self) -> None: + """Retrieve the number of frames in each video.""" + self.video_pts = [] + for video_path in self.video_paths: + n_frames = len(list(Path(video_path).glob("*.jpg"))) + self.video_pts.append(torch.Tensor(range(n_frames))) + + self.video_fps = [None] * len(self.video_paths) # fps information cannot be inferred from folder structure + + def get_clip(self, idx: int) -> tuple[torch.Tensor, torch.Tensor, dict[str, Any], int]: + """Get a subclip from a list of videos. + + Args: + idx (int): index of the subclip. Must be between 0 and num_clips(). + + Returns: + video (torch.Tensor) + audio (torch.Tensor) + info (Dict) + video_idx (int): index of the video in `video_paths` + """ + if idx >= self.num_clips(): + msg = f"Index {idx} out of range ({self.num_clips()} number of clips)" + raise IndexError(msg) + video_idx, clip_idx = self.get_clip_location(idx) + video_path = self.video_paths[video_idx] + clip_pts = self.clips[video_idx][clip_idx] + + frames = sorted(Path(video_path).glob("*.jpg")) + + frame_paths = [frames[pt] for pt in clip_pts.int()] + video = torch.stack([read_image(frame_path, as_tensor=True) for frame_path in frame_paths]) + + return video, torch.empty((1, 0)), {}, video_idx + + +class ShanghaiTechDataset(AnomalibVideoDataset): + """ShanghaiTech Dataset class. + + Args: + task (TaskType): Task type, 'classification', 'detection' or 'segmentation' + split (Split): Split of the dataset, usually Split.TRAIN or Split.TEST + root (Path | str): Path to the root of the dataset + scene (int): Index of the dataset scene (category) in range [1, 13] + clip_length_in_frames (int, optional): Number of video frames in each clip. + frames_between_clips (int, optional): Number of frames between each consecutive video clip. + target_frame (VideoTargetFrame): Specifies the target frame in the video clip, used for ground truth retrieval. + transform (Transform, optional): Transforms that should be applied to the input images. + Defaults to ``None``. + """ + + def __init__( + self, + task: TaskType, + split: Split, + root: Path | str = "./datasets/shanghaitech", + scene: int = 1, + clip_length_in_frames: int = 2, + frames_between_clips: int = 1, + target_frame: VideoTargetFrame = VideoTargetFrame.LAST, + transform: Transform | None = None, + ) -> None: + super().__init__( + task=task, + clip_length_in_frames=clip_length_in_frames, + frames_between_clips=frames_between_clips, + target_frame=target_frame, + transform=transform, + ) + + self.root = Path(root) + self.scene = scene + self.split = split + self.indexer_cls = ShanghaiTechTrainClipsIndexer if self.split == Split.TRAIN else ShanghaiTechTestClipsIndexer + self.samples = make_shanghaitech_dataset(self.root, self.scene, self.split) + + +class ShanghaiTech(AnomalibVideoDataModule): + """ShanghaiTech DataModule class. + + Args: + root (Path | str): Path to the root of the dataset + scene (int): Index of the dataset scene (category) in range [1, 13] + clip_length_in_frames (int, optional): Number of video frames in each clip. + frames_between_clips (int, optional): Number of frames between each consecutive video clip. + target_frame (VideoTargetFrame): Specifies the target frame in the video clip, used for ground truth retrieval + task TaskType): Task type, 'classification', 'detection' or 'segmentation' + image_size (tuple[int, int], optional): Size to which input images should be resized. + Defaults to ``None``. + transform (Transform, optional): Transforms that should be applied to the input images. + Defaults to ``None``. + train_transform (Transform, optional): Transforms that should be applied to the input images during training. + Defaults to ``None``. + eval_transform (Transform, optional): Transforms that should be applied to the input images during evaluation. + Defaults to ``None``. + train_batch_size (int, optional): Training batch size. Defaults to 32. + eval_batch_size (int, optional): Test batch size. Defaults to 32. + num_workers (int, optional): Number of workers. Defaults to 8. + val_split_mode (ValSplitMode): Setting that determines how the validation subset is obtained. + val_split_ratio (float): Fraction of train or test images that will be reserved for validation. + seed (int | None, optional): Seed which may be set to a fixed value for reproducibility. + """ + + def __init__( + self, + root: Path | str = "./datasets/shanghaitech", + scene: int = 1, + clip_length_in_frames: int = 2, + frames_between_clips: int = 1, + target_frame: VideoTargetFrame = VideoTargetFrame.LAST, + task: TaskType | str = TaskType.SEGMENTATION, + image_size: tuple[int, int] | None = None, + transform: Transform | None = None, + train_transform: Transform | None = None, + eval_transform: Transform | None = None, + train_batch_size: int = 32, + eval_batch_size: int = 32, + num_workers: int = 8, + val_split_mode: ValSplitMode = ValSplitMode.SAME_AS_TEST, + val_split_ratio: float = 0.5, + seed: int | None = None, + ) -> None: + super().__init__( + train_batch_size=train_batch_size, + eval_batch_size=eval_batch_size, + num_workers=num_workers, + image_size=image_size, + transform=transform, + train_transform=train_transform, + eval_transform=eval_transform, + val_split_mode=val_split_mode, + val_split_ratio=val_split_ratio, + seed=seed, + ) + + self.task = TaskType(task) + self.root = Path(root) + self.scene = scene + + self.clip_length_in_frames = clip_length_in_frames + self.frames_between_clips = frames_between_clips + self.target_frame = target_frame + + def _setup(self, _stage: str | None = None) -> None: + self.train_data = ShanghaiTechDataset( + task=self.task, + transform=self.train_transform, + clip_length_in_frames=self.clip_length_in_frames, + frames_between_clips=self.frames_between_clips, + target_frame=self.target_frame, + root=self.root, + scene=self.scene, + split=Split.TRAIN, + ) + + self.test_data = ShanghaiTechDataset( + task=self.task, + transform=self.eval_transform, + clip_length_in_frames=self.clip_length_in_frames, + frames_between_clips=self.frames_between_clips, + target_frame=self.target_frame, + root=self.root, + scene=self.scene, + split=Split.TEST, + ) + + def prepare_data(self) -> None: + """Download the dataset and convert video files.""" + training_root = self.root / "training" + if training_root.is_dir(): + logger.info("Found the dataset.") + else: + download_and_extract(self.root, DATASET_DOWNLOAD_INFO) + + # move contents to root + extracted_folder = self.root / "shanghaitech" + for filename in extracted_folder.glob("*"): + move(str(filename), str(self.root / filename.name)) + extracted_folder.rmdir() + + # convert images if not done already + vid_dir = training_root / "videos" + converted_vid_dir = training_root / "converted_videos" + vid_count = len(list(vid_dir.glob("*"))) + converted_vid_count = len(list(converted_vid_dir.glob("*"))) + if vid_count != converted_vid_count: + self._convert_training_videos(vid_dir, converted_vid_dir) + + @staticmethod + def _convert_training_videos(video_folder: Path, target_folder: Path) -> None: + """Re-code the training videos to ensure correct reading of frames by torchvision. + + The encoding of the raw video files in the ShanghaiTech dataset causes some problems when + reading the frames using pyav. To prevent this, we read the frames from the video files using opencv, + and write them to a new video file that can be parsed correctly with pyav. + + Args: + video_folder (Path): Path to the folder of training videos. + target_folder (Path): File system location where the converted videos will be stored. + """ + training_videos = sorted(video_folder.glob("*")) + for video_idx, video_path in enumerate(training_videos): + logger.info("Converting training video %s (%i/%i)...", video_path.name, video_idx + 1, len(training_videos)) + file_name = video_path.name + target_path = target_folder / file_name + convert_video(video_path, target_path, codec="XVID") diff --git a/anomalib/data/video/ucsd_ped.py b/anomalib/data/video/ucsd_ped.py new file mode 100644 index 0000000000000000000000000000000000000000..ba7850ecda300869a577a7603c8f00711316f3e9 --- /dev/null +++ b/anomalib/data/video/ucsd_ped.py @@ -0,0 +1,289 @@ +"""UCSD Pedestrian dataset.""" + +# Copyright (C) 2023-2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +import logging +from pathlib import Path +from shutil import move +from typing import TYPE_CHECKING, Any + +import numpy as np +import torch +from pandas import DataFrame +from torchvision.transforms.v2 import Transform + +from anomalib import TaskType +from anomalib.data.base import AnomalibVideoDataModule, AnomalibVideoDataset +from anomalib.data.base.video import VideoTargetFrame +from anomalib.data.utils import ( + DownloadInfo, + Split, + ValSplitMode, + download_and_extract, + read_image, + read_mask, + validate_path, +) +from anomalib.data.utils.video import ClipsIndexer + +if TYPE_CHECKING: + from collections.abc import Callable + +logger = logging.getLogger(__name__) + +DOWNLOAD_INFO = DownloadInfo( + name="UCSD Pedestrian", + url="http://www.svcl.ucsd.edu/projects/anomaly/UCSD_Anomaly_Dataset.tar.gz", + hashsum="2329af326951f5097fdd114c50e853957d3e569493a49d22fc082a9fd791915b", +) + +CATEGORIES = ("UCSDped1", "UCSDped2") + + +def make_ucsd_dataset(path: Path, split: str | Split | None = None) -> DataFrame: + """Create UCSD Pedestrian dataset by parsing the file structure. + + The files are expected to follow the structure: + path/to/dataset/category/split/video_id/image_filename.tif + path/to/dataset/category/split/video_id_gt/mask_filename.bmp + + Args: + path (Path): Path to dataset + split (str | Split | None, optional): Dataset split (ie., either train or test). Defaults to None. + + Example: + The following example shows how to get testing samples from UCSDped2 category: + + >>> root = Path('./UCSDped') + >>> category = 'UCSDped2' + >>> path = root / category + >>> path + PosixPath('UCSDped/UCSDped2') + + >>> samples = make_ucsd_dataset(path, split='test') + >>> samples.head() + root folder image_path mask_path split + 0 UCSDped/UCSDped2 Test UCSDped/UCSDped2/Test/Test001 UCSDped/UCSDped2/Test/Test001_gt test + 1 UCSDped/UCSDped2 Test UCSDped/UCSDped2/Test/Test002 UCSDped/UCSDped2/Test/Test002_gt test + ... + + Returns: + DataFrame: an output dataframe containing samples for the requested split (ie., train or test) + """ + path = validate_path(path) + folders = [filename for filename in sorted(path.glob("*/*")) if filename.is_dir()] + folders = [folder for folder in folders if list(folder.glob("*.tif"))] + + samples_list = [(str(path),) + folder.parts[-2:] for folder in folders] + samples = DataFrame(samples_list, columns=["root", "folder", "image_path"]) + + samples.loc[samples.folder == "Test", "mask_path"] = samples.image_path.str.split(".").str[0] + "_gt" + samples.loc[samples.folder == "Test", "mask_path"] = samples.root + "/" + samples.folder + "/" + samples.mask_path + samples.loc[samples.folder == "Train", "mask_path"] = "" + + samples["image_path"] = samples.root + "/" + samples.folder + "/" + samples.image_path + + samples.loc[samples.folder == "Train", "split"] = "train" + samples.loc[samples.folder == "Test", "split"] = "test" + + if split: + samples = samples[samples.split == split] + samples = samples.reset_index(drop=True) + + return samples + + +class UCSDpedClipsIndexer(ClipsIndexer): + """Clips class for UCSDped dataset.""" + + def get_mask(self, idx: int) -> np.ndarray | None: + """Retrieve the masks from the file system.""" + video_idx, frames_idx = self.get_clip_location(idx) + mask_folder = self.mask_paths[video_idx] + if mask_folder == "": # no gt masks available for this clip + return None + frames = self.clips[video_idx][frames_idx] + + mask_frames = sorted(Path(mask_folder).glob("*.bmp")) + mask_paths = [mask_frames[idx] for idx in frames.int()] + + return torch.stack([read_mask(mask_path, as_tensor=True) for mask_path in mask_paths]) + + def _compute_frame_pts(self) -> None: + """Retrieve the number of frames in each video.""" + self.video_pts = [] + for video_path in self.video_paths: + n_frames = len(list(Path(video_path).glob("*.tif"))) + self.video_pts.append(torch.Tensor(range(n_frames))) + + self.video_fps = [None] * len(self.video_paths) # fps information cannot be inferred from folder structure + + def get_clip(self, idx: int) -> tuple[torch.Tensor, torch.Tensor, dict[str, Any], int]: + """Get a subclip from a list of videos. + + Args: + idx (int): index of the subclip. Must be between 0 and num_clips(). + + Returns: + video (torch.Tensor) + audio (torch.Tensor) + info (dict) + video_idx (int): index of the video in `video_paths` + """ + if idx >= self.num_clips(): + msg = f"Index {idx} out of range ({self.num_clips()} number of clips)" + raise IndexError(msg) + video_idx, clip_idx = self.get_clip_location(idx) + video_path = self.video_paths[video_idx] + clip_pts = self.clips[video_idx][clip_idx] + + frames = sorted(Path(video_path).glob("*.tif")) + + frame_paths = [frames[pt] for pt in clip_pts.int()] + video = torch.stack([read_image(frame_path, as_tensor=True) for frame_path in frame_paths]) + + return video, torch.empty((1, 0)), {}, video_idx + + +class UCSDpedDataset(AnomalibVideoDataset): + """UCSDped Dataset class. + + Args: + task (TaskType): Task type, 'classification', 'detection' or 'segmentation' + root (Path | str): Path to the root of the dataset + category (str): Sub-category of the dataset, e.g. "UCSDped1" or "UCSDped2" + split (str | Split | None): Split of the dataset, usually Split.TRAIN or Split.TEST + clip_length_in_frames (int, optional): Number of video frames in each clip. + frames_between_clips (int, optional): Number of frames between each consecutive video clip. + target_frame (VideoTargetFrame): Specifies the target frame in the video clip, used for ground truth retrieval. + transform (Transform, optional): Transforms that should be applied to the input images. + Defaults to ``None``. + """ + + def __init__( + self, + task: TaskType, + root: str | Path, + category: str, + split: Split, + clip_length_in_frames: int = 2, + frames_between_clips: int = 10, + target_frame: VideoTargetFrame = VideoTargetFrame.LAST, + transform: Transform | None = None, + ) -> None: + super().__init__( + task=task, + clip_length_in_frames=clip_length_in_frames, + frames_between_clips=frames_between_clips, + target_frame=target_frame, + transform=transform, + ) + + self.root_category = Path(root) / category + self.split = split + self.indexer_cls: Callable = UCSDpedClipsIndexer + self.samples = make_ucsd_dataset(self.root_category, self.split) + + +class UCSDped(AnomalibVideoDataModule): + """UCSDped DataModule class. + + Args: + root (Path | str): Path to the root of the dataset + category (str): Sub-category of the dataset, e.g. "UCSDped1" or "UCSDped2" + clip_length_in_frames (int, optional): Number of video frames in each clip. + frames_between_clips (int, optional): Number of frames between each consecutive video clip. + target_frame (VideoTargetFrame): Specifies the target frame in the video clip, used for ground truth retrieval + task (TaskType): Task type, 'classification', 'detection' or 'segmentation' + image_size (tuple[int, int], optional): Size to which input images should be resized. + Defaults to ``None``. + transform (Transform, optional): Transforms that should be applied to the input images. + Defaults to ``None``. + train_transform (Transform, optional): Transforms that should be applied to the input images during training. + Defaults to ``None``. + eval_transform (Transform, optional): Transforms that should be applied to the input images during evaluation. + Defaults to ``None``. + train_batch_size (int, optional): Training batch size. Defaults to 32. + eval_batch_size (int, optional): Test batch size. Defaults to 32. + num_workers (int, optional): Number of workers. Defaults to 8. + val_split_mode (ValSplitMode): Setting that determines how the validation subset is obtained. + val_split_ratio (float): Fraction of train or test images that will be reserved for validation. + seed (int | None, optional): Seed which may be set to a fixed value for reproducibility. + """ + + def __init__( + self, + root: Path | str = "./datasets/ucsd", + category: str = "UCSDped2", + clip_length_in_frames: int = 2, + frames_between_clips: int = 10, + target_frame: VideoTargetFrame = VideoTargetFrame.LAST, + task: TaskType | str = TaskType.SEGMENTATION, + image_size: tuple[int, int] | None = None, + transform: Transform | None = None, + train_transform: Transform | None = None, + eval_transform: Transform | None = None, + train_batch_size: int = 8, + eval_batch_size: int = 8, + num_workers: int = 8, + val_split_mode: ValSplitMode = ValSplitMode.SAME_AS_TEST, + val_split_ratio: float = 0.5, + seed: int | None = None, + ) -> None: + super().__init__( + train_batch_size=train_batch_size, + eval_batch_size=eval_batch_size, + num_workers=num_workers, + image_size=image_size, + transform=transform, + train_transform=train_transform, + eval_transform=eval_transform, + val_split_mode=val_split_mode, + val_split_ratio=val_split_ratio, + seed=seed, + ) + + self.task = TaskType(task) + self.root = Path(root) + self.category = category + + self.clip_length_in_frames = clip_length_in_frames + self.frames_between_clips = frames_between_clips + self.target_frame = VideoTargetFrame(target_frame) + + def _setup(self, _stage: str | None = None) -> None: + self.train_data = UCSDpedDataset( + task=self.task, + transform=self.train_transform, + clip_length_in_frames=self.clip_length_in_frames, + frames_between_clips=self.frames_between_clips, + target_frame=self.target_frame, + root=self.root, + category=self.category, + split=Split.TRAIN, + ) + + self.test_data = UCSDpedDataset( + task=self.task, + transform=self.eval_transform, + clip_length_in_frames=self.clip_length_in_frames, + frames_between_clips=self.frames_between_clips, + target_frame=self.target_frame, + root=self.root, + category=self.category, + split=Split.TEST, + ) + + def prepare_data(self) -> None: + """Download the dataset if not available.""" + if (self.root / self.category).is_dir(): + logger.info("Found the dataset.") + else: + download_and_extract(self.root, DOWNLOAD_INFO) + + # move contents to root + extracted_folder = self.root / "UCSD_Anomaly_Dataset.v1p2" + for filename in extracted_folder.glob("*"): + move(str(filename), str(self.root / filename.name)) + extracted_folder.rmdir() diff --git a/anomalib/deploy/__init__.py b/anomalib/deploy/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..45581bd8dda9ba800375f4552d22442bc6147f22 --- /dev/null +++ b/anomalib/deploy/__init__.py @@ -0,0 +1,9 @@ +"""Functions for Inference and model deployment.""" + +# Copyright (C) 2022-2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +from .export import ExportType +from .inferencers import Inferencer, OpenVINOInferencer, TorchInferencer + +__all__ = ["Inferencer", "OpenVINOInferencer", "TorchInferencer", "ExportType"] diff --git a/anomalib/deploy/export.py b/anomalib/deploy/export.py new file mode 100644 index 0000000000000000000000000000000000000000..2430413fbcdb3337bde4d6815999cf832d976640 --- /dev/null +++ b/anomalib/deploy/export.py @@ -0,0 +1,86 @@ +"""Utilities for optimization and OpenVINO conversion.""" + +# Copyright (C) 2022-2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +import logging +from enum import Enum + +import torch +from torch import nn +from torchvision.transforms.v2 import CenterCrop, Compose, Resize, Transform + +from anomalib.data.transforms import ExportableCenterCrop + +logger = logging.getLogger("anomalib") + + +class ExportType(str, Enum): + """Model export type. + + Examples: + >>> from anomalib.deploy import ExportType + >>> ExportType.ONNX + 'onnx' + >>> ExportType.OPENVINO + 'openvino' + >>> ExportType.TORCH + 'torch' + """ + + ONNX = "onnx" + OPENVINO = "openvino" + TORCH = "torch" + + +class InferenceModel(nn.Module): + """Inference model for export. + + The InferenceModel is used to wrap the model and transform for exporting to torch and ONNX/OpenVINO. + + Args: + model (nn.Module): Model to export. + transform (Transform): Input transform for the model. + disable_antialias (bool, optional): Disable antialiasing in the Resize transforms of the given transform. This + is needed for ONNX/OpenVINO export, as antialiasing is not supported in the ONNX opset. + """ + + def __init__(self, model: nn.Module, transform: Transform, disable_antialias: bool = False) -> None: + super().__init__() + self.model = model + self.transform = transform + self.convert_center_crop() + if disable_antialias: + self.disable_antialias() + + def forward(self, batch: torch.Tensor) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: + """Transform the input batch and pass it through the model.""" + batch = self.transform(batch) + return self.model(batch) + + def disable_antialias(self) -> None: + """Disable antialiasing in the Resize transforms of the given transform. + + This is needed for ONNX/OpenVINO export, as antialiasing is not supported in the ONNX opset. + """ + if isinstance(self.transform, Resize): + self.transform.antialias = False + if isinstance(self.transform, Compose): + for transform in self.transform.transforms: + if isinstance(transform, Resize): + transform.antialias = False + + def convert_center_crop(self) -> None: + """Convert CenterCrop to ExportableCenterCrop for ONNX export. + + The original CenterCrop transform is not supported in ONNX export. This method replaces the CenterCrop to + ExportableCenterCrop, which is supported in ONNX export. For more details, see the implementation of + ExportableCenterCrop. + """ + if isinstance(self.transform, CenterCrop): + self.transform = ExportableCenterCrop(size=self.transform.size) + elif isinstance(self.transform, Compose): + transforms = self.transform.transforms + for index in range(len(transforms)): + if isinstance(transforms[index], CenterCrop): + transforms[index] = ExportableCenterCrop(size=transforms[index].size) diff --git a/anomalib/deploy/inferencers/__init__.py b/anomalib/deploy/inferencers/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..f47ece342516a5da3c44984d69c460ffd89f3a50 --- /dev/null +++ b/anomalib/deploy/inferencers/__init__.py @@ -0,0 +1,10 @@ +"""Inferencers for Torch and OpenVINO.""" + +# Copyright (C) 2022 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +from .base_inferencer import Inferencer +from .openvino_inferencer import OpenVINOInferencer +from .torch_inferencer import TorchInferencer + +__all__ = ["Inferencer", "OpenVINOInferencer", "TorchInferencer"] diff --git a/anomalib/deploy/inferencers/base_inferencer.py b/anomalib/deploy/inferencers/base_inferencer.py new file mode 100644 index 0000000000000000000000000000000000000000..a76dcafb06cbcf950cc6ffc0bba7e4f32e79410b --- /dev/null +++ b/anomalib/deploy/inferencers/base_inferencer.py @@ -0,0 +1,136 @@ +"""Base Inferencer for Torch and OpenVINO.""" + +# Copyright (C) 2022-2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + + +from abc import ABC, abstractmethod +from pathlib import Path +from typing import Any, cast + +import cv2 +import numpy as np +import torch +from omegaconf import DictConfig, OmegaConf +from skimage.morphology import dilation +from skimage.segmentation import find_boundaries + +from anomalib.utils.normalization.min_max import normalize as normalize_min_max +from anomalib.utils.post_processing import compute_mask +from anomalib.utils.visualization import ImageResult + + +class Inferencer(ABC): + """Abstract class for the inference. + + This is used by both Torch and OpenVINO inference. + """ + + @abstractmethod + def load_model(self, path: str | Path) -> Any: # noqa: ANN401 + """Load Model.""" + raise NotImplementedError + + @abstractmethod + def pre_process(self, image: np.ndarray) -> np.ndarray | torch.Tensor: + """Pre-process.""" + raise NotImplementedError + + @abstractmethod + def forward(self, image: np.ndarray | torch.Tensor) -> np.ndarray | torch.Tensor: + """Forward-Pass input to model.""" + raise NotImplementedError + + @abstractmethod + def post_process(self, predictions: np.ndarray | torch.Tensor, metadata: dict[str, Any] | None) -> dict[str, Any]: + """Post-Process.""" + raise NotImplementedError + + @abstractmethod + def predict(self, image: str | Path | np.ndarray | torch.Tensor) -> ImageResult: + """Predict.""" + raise NotImplementedError + + @staticmethod + def _superimpose_segmentation_mask(metadata: dict, anomaly_map: np.ndarray, image: np.ndarray) -> np.ndarray: + """Superimpose segmentation mask on top of image. + + Args: + metadata (dict): Metadata of the image which contains the image size. + anomaly_map (np.ndarray): Anomaly map which is used to extract segmentation mask. + image (np.ndarray): Image on which segmentation mask is to be superimposed. + + Returns: + np.ndarray: Image with segmentation mask superimposed. + """ + pred_mask = compute_mask(anomaly_map, 0.5) # assumes predictions are normalized. + image_height = metadata["image_shape"][0] + image_width = metadata["image_shape"][1] + pred_mask = cv2.resize(pred_mask, (image_width, image_height)) + boundaries = find_boundaries(pred_mask) + outlines = dilation(boundaries, np.ones((7, 7))) + image[outlines] = [255, 0, 0] + return image + + def __call__(self, image: np.ndarray) -> ImageResult: + """Call predict on the Image. + + Args: + image (np.ndarray): Input Image + + Returns: + ImageResult: Prediction results to be visualized. + """ + return self.predict(image) + + @staticmethod + def _normalize( + pred_scores: torch.Tensor | np.float32, + metadata: dict | DictConfig, + anomaly_maps: torch.Tensor | np.ndarray | None = None, + ) -> tuple[np.ndarray | torch.Tensor | None, float]: + """Apply normalization and resizes the image. + + Args: + pred_scores (Tensor | np.float32): Predicted anomaly score + metadata (dict | DictConfig): Meta data. Post-processing step sometimes requires + additional meta data such as image shape. This variable comprises such info. + anomaly_maps (Tensor | np.ndarray | None): Predicted raw anomaly map. + + Returns: + tuple[np.ndarray | torch.Tensor | None, float]: Post processed predictions that are ready to be + visualized and predicted scores. + """ + # min max normalization + if "min" in metadata and "max" in metadata: + if anomaly_maps is not None: + anomaly_maps = normalize_min_max( + anomaly_maps, + metadata["pixel_threshold"], + metadata["min"], + metadata["max"], + ) + pred_scores = normalize_min_max( + pred_scores, + metadata["image_threshold"], + metadata["min"], + metadata["max"], + ) + + return anomaly_maps, float(pred_scores) + + def _load_metadata(self, path: str | Path | dict | None = None) -> dict | DictConfig: + """Load the meta data from the given path. + + Args: + path (str | Path | dict | None, optional): Path to JSON file containing the metadata. + If no path is provided, it returns an empty dict. Defaults to None. + + Returns: + dict | DictConfig: Dictionary containing the metadata. + """ + metadata: dict[str, float | np.ndarray | torch.Tensor] | DictConfig = {} + if path is not None: + config = OmegaConf.load(path) + metadata = cast(DictConfig, config) + return metadata diff --git a/anomalib/deploy/inferencers/openvino_inferencer.py b/anomalib/deploy/inferencers/openvino_inferencer.py new file mode 100644 index 0000000000000000000000000000000000000000..3206b39e302801e9cf23747fa4465960008b26e1 --- /dev/null +++ b/anomalib/deploy/inferencers/openvino_inferencer.py @@ -0,0 +1,341 @@ +"""OpenVINO Inferencer implementation.""" + +# Copyright (C) 2022-2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + + +import logging +from importlib.util import find_spec +from pathlib import Path +from typing import TYPE_CHECKING, Any + +import cv2 +import numpy as np +from omegaconf import DictConfig +from PIL import Image + +from anomalib import TaskType +from anomalib.data.utils.label import LabelName +from anomalib.utils.visualization import ImageResult + +from .base_inferencer import Inferencer + +logger = logging.getLogger("anomalib") + +if find_spec("openvino") is not None: + import openvino.runtime as ov + + if TYPE_CHECKING: + from openvino.runtime import CompiledModel +else: + logger.warning("OpenVINO is not installed. Please install OpenVINO to use OpenVINOInferencer.") + + +class OpenVINOInferencer(Inferencer): + """OpenVINO implementation for the inference. + + Args: + path (str | Path): Path to the openvino onnx, xml or bin file. + metadata (str | Path | dict, optional): Path to metadata file or a dict object defining the + metadata. + Defaults to ``None``. + device (str | None, optional): Device to run the inference on (AUTO, CPU, GPU, NPU). + Defaults to ``AUTO``. + task (TaskType | None, optional): Task type. + Defaults to ``None``. + config (dict | None, optional): Configuration parameters for the inference + Defaults to ``None``. + + Examples: + Assume that we have an OpenVINO IR model and metadata files in the following structure: + + .. code-block:: bash + + $ tree weights + ./weights + ├── model.bin + ├── model.xml + └── metadata.json + + We could then create ``OpenVINOInferencer`` as follows: + + >>> from anomalib.deploy.inferencers import OpenVINOInferencer + >>> inferencer = OpenVINOInferencer( + ... path="weights/model.xml", + ... metadata="weights/metadata.json", + ... device="CPU", + ... ) + + This will ensure that the model is loaded on the ``CPU`` device and the + metadata is loaded from the ``metadata.json`` file. To make a prediction, + we can simply call the ``predict`` method: + + >>> prediction = inferencer.predict(image="path/to/image.jpg") + + Alternatively we can also pass the image as a PIL image or numpy array: + + >>> from PIL import Image + >>> image = Image.open("path/to/image.jpg") + >>> prediction = inferencer.predict(image=image) + + >>> import numpy as np + >>> image = np.random.rand(224, 224, 3) + >>> prediction = inferencer.predict(image=image) + + ``prediction`` will be an ``ImageResult`` object containing the prediction + results. For example, to visualize the heatmap, we can do the following: + + >>> from matplotlib import pyplot as plt + >>> plt.imshow(result.heatmap) + + It is also possible to visualize the true and predicted masks if the + task is ``TaskType.SEGMENTATION``: + + >>> plt.imshow(result.gt_mask) + >>> plt.imshow(result.pred_mask) + """ + + def __init__( + self, + path: str | Path | tuple[bytes, bytes], + metadata: str | Path | dict | None = None, + device: str | None = "AUTO", + task: str | None = None, + config: dict | None = None, + ) -> None: + self.device = device + + self.config = config + self.input_blob, self.output_blob, self.model = self.load_model(path) + self.metadata = super()._load_metadata(metadata) + + self.task = TaskType(task) if task else TaskType(self.metadata["task"]) + + def load_model(self, path: str | Path | tuple[bytes, bytes]) -> tuple[Any, Any, "CompiledModel"]: + """Load the OpenVINO model. + + Args: + path (str | Path | tuple[bytes, bytes]): Path to the onnx or xml and bin files + or tuple of .xml and .bin data as bytes. + + Returns: + [tuple[str, str, ExecutableNetwork]]: Input and Output blob names + together with the Executable network. + """ + core = ov.Core() + # If tuple of bytes is passed + if isinstance(path, tuple): + model = core.read_model(model=path[0], weights=path[1]) + else: + path = path if isinstance(path, Path) else Path(path) + if path.suffix in (".bin", ".xml"): + if path.suffix == ".bin": + bin_path, xml_path = path, path.with_suffix(".xml") + elif path.suffix == ".xml": + xml_path, bin_path = path, path.with_suffix(".bin") + model = core.read_model(xml_path, bin_path) + elif path.suffix == ".onnx": + model = core.read_model(path) + else: + msg = f"Path must be .onnx, .bin or .xml file. Got {path.suffix}" + raise ValueError(msg) + # Create cache folder + cache_folder = Path("cache") + cache_folder.mkdir(exist_ok=True) + core.set_property({"CACHE_DIR": cache_folder}) + + compile_model = core.compile_model(model=model, device_name=self.device, config=self.config) + + input_blob = compile_model.input(0) + output_blob = compile_model.output(0) + + return input_blob, output_blob, compile_model + + def pre_process(self, image: np.ndarray) -> np.ndarray: + """Pre-process the input image by applying transformations. + + Args: + image (np.ndarray): Input image. + + Returns: + np.ndarray: pre-processed image. + """ + processed_image = image + + if len(processed_image.shape) == 3: + processed_image = np.expand_dims(processed_image, axis=0) + + if processed_image.shape[-1] == 3: + processed_image = processed_image.transpose(0, 3, 1, 2) + + return processed_image + + def predict( + self, + image: str | Path | np.ndarray, + metadata: dict[str, Any] | None = None, + ) -> ImageResult: + """Perform a prediction for a given input image. + + The main workflow is (i) pre-processing, (ii) forward-pass, (iii) post-process. + + Args: + image (Union[str, np.ndarray]): Input image whose output is to be predicted. + It could be either a path to image or numpy array itself. + + metadata: Metadata information such as shape, threshold. + + Returns: + ImageResult: Prediction results to be visualized. + """ + # Convert file path or string to image if necessary + if isinstance(image, str | Path): + image = Image.open(image) + + # Convert PIL image to numpy array + if isinstance(image, Image.Image): + image = np.array(image, dtype=np.float32) + if not isinstance(image, np.ndarray): + msg = f"Input image must be a numpy array or a path to an image. Got {type(image)}" + raise TypeError(msg) + + # Resize image to model input size if not dynamic + if self.input_blob.partial_shape[2].is_static and self.input_blob.partial_shape[3].is_static: + image = cv2.resize(image, tuple(list(self.input_blob.shape)[2:][::-1])) + + # Normalize numpy array to range [0, 1] + if image.dtype != np.float32: + image = image.astype(np.float32) + if image.max() > 1.0: + image /= 255.0 + + # Check if metadata is provided, if not use the default metadata. + if metadata is None: + metadata = self.metadata if hasattr(self, "metadata") else {} + metadata["image_shape"] = image.shape[:2] + + processed_image = self.pre_process(image) + predictions = self.forward(processed_image) + output = self.post_process(predictions, metadata=metadata) + + return ImageResult( + image=(image * 255).astype(np.uint8), + pred_score=output["pred_score"], + pred_label=output["pred_label"], + anomaly_map=output["anomaly_map"], + pred_mask=output["pred_mask"], + pred_boxes=output["pred_boxes"], + box_labels=output["box_labels"], + ) + + def forward(self, image: np.ndarray) -> np.ndarray: + """Forward-Pass input tensor to the model. + + Args: + image (np.ndarray): Input tensor. + + Returns: + np.ndarray: Output predictions. + """ + return self.model(image) + + def post_process(self, predictions: np.ndarray, metadata: dict | DictConfig | None = None) -> dict[str, Any]: + """Post process the output predictions. + + Args: + predictions (np.ndarray): Raw output predicted by the model. + metadata (Dict, optional): Metadata. Post-processing step sometimes requires + additional metadata such as image shape. This variable comprises such info. + Defaults to None. + + Returns: + dict[str, Any]: Post processed prediction results. + """ + if metadata is None: + metadata = self.metadata + + predictions = predictions[self.output_blob] + + # Initialize the result variables. + anomaly_map: np.ndarray | None = None + pred_label: LabelName | None = None + pred_mask: float | None = None + + # If predictions returns a single value, this means that the task is + # classification, and the value is the classification prediction score. + if len(predictions.shape) == 1: + task = TaskType.CLASSIFICATION + pred_score = predictions + else: + task = TaskType.SEGMENTATION + anomaly_map = predictions.squeeze() + pred_score = anomaly_map.reshape(-1).max() + + # Common practice in anomaly detection is to assign anomalous + # label to the prediction if the prediction score is greater + # than the image threshold. + if "image_threshold" in metadata: + pred_idx = pred_score >= metadata["image_threshold"] + pred_label = LabelName.ABNORMAL if pred_idx else LabelName.NORMAL + + if task == TaskType.CLASSIFICATION: + _, pred_score = self._normalize(pred_scores=pred_score, metadata=metadata) + elif task in (TaskType.SEGMENTATION, TaskType.DETECTION): + if "pixel_threshold" in metadata: + pred_mask = (anomaly_map >= metadata["pixel_threshold"]).astype(np.uint8) + + anomaly_map, pred_score = self._normalize( + pred_scores=pred_score, + anomaly_maps=anomaly_map, + metadata=metadata, + ) + if anomaly_map is None: + msg = "Anomaly map cannot be None." + raise ValueError(msg) + + if "image_shape" in metadata and anomaly_map.shape != metadata["image_shape"]: + image_height = metadata["image_shape"][0] + image_width = metadata["image_shape"][1] + anomaly_map = cv2.resize(anomaly_map, (image_width, image_height)) + + if pred_mask is not None: + pred_mask = cv2.resize(pred_mask, (image_width, image_height)) + else: + msg = f"Unknown task type: {task}" + raise ValueError(msg) + + if self.task == TaskType.DETECTION: + pred_boxes = self._get_boxes(pred_mask) + box_labels = np.ones(pred_boxes.shape[0]) + else: + pred_boxes = None + box_labels = None + + return { + "anomaly_map": anomaly_map, + "pred_label": pred_label, + "pred_score": pred_score, + "pred_mask": pred_mask, + "pred_boxes": pred_boxes, + "box_labels": box_labels, + } + + @staticmethod + def _get_boxes(mask: np.ndarray) -> np.ndarray: + """Get bounding boxes from masks. + + Args: + mask (np.ndarray): Input mask of shape (H, W) + + Returns: + np.ndarray: array of shape (N, 4) containing the bounding box coordinates of the objects in the masks + in xyxy format. + """ + _, comps = cv2.connectedComponents(mask) + + labels = np.unique(comps) + boxes = [] + for label in labels[labels != 0]: + y_loc, x_loc = np.where(comps == label) + boxes.append([np.min(x_loc), np.min(y_loc), np.max(x_loc), np.max(y_loc)]) + return np.stack(boxes) if boxes else np.empty((0, 4)) diff --git a/anomalib/deploy/inferencers/torch_inferencer.py b/anomalib/deploy/inferencers/torch_inferencer.py new file mode 100644 index 0000000000000000000000000000000000000000..8f18f0146da0ab108765876b8361b05ac60abab7 --- /dev/null +++ b/anomalib/deploy/inferencers/torch_inferencer.py @@ -0,0 +1,323 @@ +"""Torch inference implementations.""" + +# Copyright (C) 2022-2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + + +from collections.abc import Sequence +from pathlib import Path +from typing import Any + +import cv2 +import numpy as np +import torch +from omegaconf import DictConfig +from torch import nn + +from anomalib import TaskType +from anomalib.data import LabelName +from anomalib.data.utils import read_image +from anomalib.data.utils.boxes import masks_to_boxes +from anomalib.utils.visualization import ImageResult + +from .base_inferencer import Inferencer + + +class TorchInferencer(Inferencer): + """PyTorch implementation for the inference. + + Args: + path (str | Path): Path to Torch model weights. + device (str): Device to use for inference. Options are ``auto``, + ``cpu``, ``cuda``. + Defaults to ``auto``. + + Examples: + Assume that we have a Torch ``pt`` model and metadata files in the + following structure: + + >>> from anomalib.deploy.inferencers import TorchInferencer + >>> inferencer = TorchInferencer(path="path/to/torch/model.pt", device="cpu") + + This will ensure that the model is loaded on the ``CPU`` device. To make + a prediction, we can simply call the ``predict`` method: + + >>> from anomalib.data.utils import read_image + >>> image = read_image("path/to/image.jpg") + >>> result = inferencer.predict(image) + + ``result`` will be an ``ImageResult`` object containing the prediction + results. For example, to visualize the heatmap, we can do the following: + + >>> from matplotlib import pyplot as plt + >>> plt.imshow(result.heatmap) + + It is also possible to visualize the true and predicted masks if the + task is ``TaskType.SEGMENTATION``: + + >>> plt.imshow(result.gt_mask) + >>> plt.imshow(result.pred_mask) + """ + + def __init__( + self, + path: str | Path, + device: str = "auto", + ) -> None: + self.device = self._get_device(device) + + # Load the model weights and metadata + self.checkpoint = self._load_checkpoint(path) + self.model = self.load_model(path) + self.metadata = self._load_metadata(path) + + @staticmethod + def _get_device(device: str) -> torch.device: + """Get the device to use for inference. + + Args: + device (str): Device to use for inference. Options are auto, cpu, cuda. + + Returns: + torch.device: Device to use for inference. + """ + if device not in ("auto", "cpu", "cuda", "gpu"): + msg = f"Unknown device {device}" + raise ValueError(msg) + + if device == "auto": + device = "cuda" if torch.cuda.is_available() else "cpu" + elif device == "gpu": + device = "cuda" + return torch.device(device) + + def _load_checkpoint(self, path: str | Path) -> dict: + """Load the checkpoint. + + Args: + path (str | Path): Path to the torch ckpt file. + + Returns: + dict: Dictionary containing the model and metadata. + """ + if isinstance(path, str): + path = Path(path) + + if path.suffix not in (".pt", ".pth"): + msg = f"Unknown torch checkpoint file format {path.suffix}. Make sure you save the Torch model." + raise ValueError(msg) + + return torch.load(path, map_location=self.device) + + def _load_metadata(self, path: str | Path | dict | None = None) -> dict | DictConfig: + """Load metadata from file. + + Args: + path (str | Path | dict): Path to the model pt file. + + Returns: + dict: Dictionary containing the metadata. + """ + metadata: dict | DictConfig + + if isinstance(path, dict): + metadata = path + elif isinstance(path, str | Path): + checkpoint = self._load_checkpoint(path) + + # Torch model should ideally contain the metadata in the checkpoint. + # Check if the metadata is present in the checkpoint. + if "metadata" not in checkpoint: + msg = ( + "``metadata`` is not found in the checkpoint. Please ensure that you save the model as Torch model." + ) + raise KeyError( + msg, + ) + metadata = checkpoint["metadata"] + else: + msg = f"Unknown ``path`` type {type(path)}" + raise TypeError(msg) + + return metadata + + def load_model(self, path: str | Path) -> nn.Module: + """Load the PyTorch model. + + Args: + path (str | Path): Path to the Torch model. + + Returns: + (nn.Module): Torch model. + """ + checkpoint = self._load_checkpoint(path) + if "model" not in checkpoint: + msg = "``model`` is not found in the checkpoint. Please check the checkpoint file." + raise KeyError(msg) + + model = checkpoint["model"] + model.eval() + return model.to(self.device) + + def predict( + self, + image: str | Path | torch.Tensor, + metadata: dict[str, Any] | None = None, + ) -> ImageResult: + """Perform a prediction for a given input image. + + The main workflow is (i) pre-processing, (ii) forward-pass, (iii) post-process. + + Args: + image (Union[str, np.ndarray]): Input image whose output is to be predicted. + It could be either a path to image or numpy array itself. + + metadata: Metadata information such as shape, threshold. + + Returns: + ImageResult: Prediction results to be visualized. + """ + if metadata is None: + metadata = self.metadata if hasattr(self, "metadata") else {} + if isinstance(image, str | Path): + image = read_image(image, as_tensor=True) + + metadata["image_shape"] = image.shape[-2:] + + processed_image = self.pre_process(image) + predictions = self.forward(processed_image) + output = self.post_process(predictions, metadata=metadata) + + return ImageResult( + image=(image.numpy().transpose(1, 2, 0) * 255).astype(np.uint8), + pred_score=output["pred_score"], + pred_label=output["pred_label"], + anomaly_map=output["anomaly_map"], + pred_mask=output["pred_mask"], + pred_boxes=output["pred_boxes"], + box_labels=output["box_labels"], + ) + + def pre_process(self, image: np.ndarray) -> torch.Tensor: + """Pre process the input image. + + Args: + image (np.ndarray): Input image + + Returns: + Tensor: pre-processed image. + """ + if len(image) == 3: + image = image.unsqueeze(0) + + return image.to(self.device) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """Forward-Pass input tensor to the model. + + Args: + image (torch.Tensor): Input tensor. + + Returns: + Tensor: Output predictions. + """ + return self.model(image) + + def post_process( + self, + predictions: torch.Tensor | list[torch.Tensor] | dict[str, torch.Tensor], + metadata: dict | DictConfig | None = None, + ) -> dict[str, Any]: + """Post process the output predictions. + + Args: + predictions (Tensor | list[torch.Tensor] | dict[str, torch.Tensor]): Raw output predicted by the model. + metadata (dict, optional): Meta data. Post-processing step sometimes requires + additional meta data such as image shape. This variable comprises such info. + Defaults to None. + + Returns: + dict[str, str | float | np.ndarray]: Post processed prediction results. + """ + if metadata is None: + metadata = self.metadata + + # Some models return a Tensor while others return a list or dictionary. Handle both cases. + # TODO(ashwinvaidya17): Wrap this post-processing stage within the model's forward pass. + # CVS-122674 + + # Case I: Predictions could be a tensor. + if isinstance(predictions, torch.Tensor): + anomaly_map = predictions.detach().cpu().numpy() + pred_score = anomaly_map.reshape(-1).max() + + # Case II: Predictions could be a dictionary of tensors. + elif isinstance(predictions, dict): + if "anomaly_map" in predictions: + anomaly_map = predictions["anomaly_map"].detach().cpu().numpy() + else: + msg = "``anomaly_map`` not found in the predictions." + raise KeyError(msg) + + if "pred_score" in predictions: + pred_score = predictions["pred_score"].detach().cpu().numpy() + else: + pred_score = anomaly_map.reshape(-1).max() + + # Case III: Predictions could be a list of tensors. + elif isinstance(predictions, Sequence): + if isinstance(predictions[1], (torch.Tensor)): + pred_score, anomaly_map = predictions + anomaly_map = anomaly_map.detach().cpu().numpy() + pred_score = pred_score.detach().cpu().numpy() + else: + pred_score, anomaly_map = predictions + pred_score = pred_score.detach() + else: + msg = ( + f"Unknown prediction type {type(predictions)}. " + "Expected torch.Tensor, list[torch.Tensor] or dict[str, torch.Tensor]." + ) + raise TypeError(msg) + + # Common practice in anomaly detection is to assign anomalous + # label to the prediction if the prediction score is greater + # than the image threshold. + pred_label: LabelName | None = None + if "image_threshold" in metadata: + pred_idx = pred_score >= metadata["image_threshold"] + pred_label = LabelName.ABNORMAL if pred_idx else LabelName.NORMAL + + pred_mask: np.ndarray | None = None + if "pixel_threshold" in metadata: + pred_mask = (anomaly_map >= metadata["pixel_threshold"]).squeeze().astype(np.uint8) + + anomaly_map = anomaly_map.squeeze() + anomaly_map, pred_score = self._normalize(anomaly_maps=anomaly_map, pred_scores=pred_score, metadata=metadata) + + if isinstance(anomaly_map, torch.Tensor): + anomaly_map = anomaly_map.detach().cpu().numpy() + + if "image_shape" in metadata and anomaly_map.shape != metadata["image_shape"]: + image_height = metadata["image_shape"][0] + image_width = metadata["image_shape"][1] + anomaly_map = cv2.resize(anomaly_map, (image_width, image_height)) + + if pred_mask is not None: + pred_mask = cv2.resize(pred_mask, (image_width, image_height)) + + if self.metadata["task"] == TaskType.DETECTION: + pred_boxes = masks_to_boxes(torch.from_numpy(pred_mask))[0][0].numpy() + box_labels = np.ones(pred_boxes.shape[0]) + else: + pred_boxes = None + box_labels = None + + return { + "anomaly_map": anomaly_map, + "pred_label": pred_label, + "pred_score": pred_score, + "pred_mask": pred_mask, + "pred_boxes": pred_boxes, + "box_labels": box_labels, + } diff --git a/anomalib/engine/__init__.py b/anomalib/engine/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..3bb239f5a55f199e39c4499a74b982b04bbb285c --- /dev/null +++ b/anomalib/engine/__init__.py @@ -0,0 +1,8 @@ +"""Anomalib engine.""" + +# Copyright (C) 2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +from .engine import Engine + +__all__ = ["Engine"] diff --git a/anomalib/engine/engine.py b/anomalib/engine/engine.py new file mode 100644 index 0000000000000000000000000000000000000000..6dbaa15a10f3f5fb9478a4a1af0082ee95c6018f --- /dev/null +++ b/anomalib/engine/engine.py @@ -0,0 +1,955 @@ +"""Implements custom trainer for Anomalib.""" + +# Copyright (C) 2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +import logging +from collections.abc import Iterable +from pathlib import Path +from typing import Any + +import torch +from lightning.pytorch.callbacks import Callback +from lightning.pytorch.loggers import Logger +from lightning.pytorch.trainer import Trainer +from lightning.pytorch.utilities.types import _EVALUATE_OUTPUT, _PREDICT_OUTPUT, EVAL_DATALOADERS, TRAIN_DATALOADERS +from torch.utils.data import DataLoader, Dataset +from torchvision.transforms.v2 import Transform + +from anomalib import LearningType, TaskType +from anomalib.callbacks.checkpoint import ModelCheckpoint +from anomalib.callbacks.metrics import _MetricsCallback +from anomalib.callbacks.normalization import get_normalization_callback +from anomalib.callbacks.normalization.base import NormalizationCallback +from anomalib.callbacks.post_processor import _PostProcessorCallback +from anomalib.callbacks.thresholding import _ThresholdCallback +from anomalib.callbacks.timer import TimerCallback +from anomalib.callbacks.visualizer import _VisualizationCallback +from anomalib.data import AnomalibDataModule, AnomalibDataset, PredictDataset +from anomalib.deploy import ExportType +from anomalib.models import AnomalyModule +from anomalib.utils.normalization import NormalizationMethod +from anomalib.utils.path import create_versioned_dir +from anomalib.utils.types import NORMALIZATION, THRESHOLD +from anomalib.utils.visualization import ImageVisualizer + +logger = logging.getLogger(__name__) + + +class UnassignedError(Exception): + """Unassigned error.""" + + +class _TrainerArgumentsCache: + """Cache arguments. + + Since the Engine class accepts PyTorch Lightning Trainer arguments, we store these arguments using this class + before the trainer is instantiated. + + Args: + (**kwargs): Trainer arguments that are cached + + Example: + >>> conf = OmegaConf.load("config.yaml") + >>> cache = _TrainerArgumentsCache(**conf.trainer) + >>> cache.args + { + ... + 'max_epochs': 100, + 'val_check_interval': 0 + } + >>> model = Padim(layers=["layer1", "layer2", "layer3"], input_size=(256, 256), backbone="resnet18") + >>> cache.update(model) + Overriding max_epochs from 100 with 1 for Padim + Overriding val_check_interval from 0 with 1.0 for Padim + >>> cache.args + { + ... + 'max_epochs': 1, + 'val_check_interval': 1.0 + } + """ + + def __init__(self, **kwargs) -> None: + self._cached_args = {**kwargs} + + def update(self, model: AnomalyModule) -> None: + """Replace cached arguments with arguments retrieved from the model. + + Args: + model (AnomalyModule): The model used for training + """ + for key, value in model.trainer_arguments.items(): + if key in self._cached_args and self._cached_args[key] != value: + logger.info( + f"Overriding {key} from {self._cached_args[key]} with {value} for {model.__class__.__name__}", + ) + self._cached_args[key] = value + + def requires_update(self, model: AnomalyModule) -> bool: + return any(self._cached_args.get(key, None) != value for key, value in model.trainer_arguments.items()) + + @property + def args(self) -> dict[str, Any]: + return self._cached_args + + +class Engine: + """Anomalib Engine. + + .. note:: + + Refer to PyTorch Lightning's Trainer for a list of parameters for + details on other Trainer parameters. + + Args: + callbacks (list[Callback]): Add a callback or list of callbacks. + normalization (NORMALIZATION, optional): Normalization method. + Defaults to NormalizationMethod.MIN_MAX. + threshold (THRESHOLD): + Thresholding method. Defaults to "F1AdaptiveThreshold". + task (TaskType, optional): Task type. Defaults to TaskType.SEGMENTATION. + image_metrics (list[str] | str | dict[str, dict[str, Any]] | None, optional): Image metrics to be used for + evaluation. Defaults to None. + pixel_metrics (list[str] | str | dict[str, dict[str, Any]] | None, optional): Pixel metrics to be used for + evaluation. Defaults to None. + default_root_dir (str, optional): Default root directory for the trainer. + The results will be saved in this directory. + Defaults to ``results``. + **kwargs: PyTorch Lightning Trainer arguments. + """ + + def __init__( + self, + callbacks: list[Callback] | None = None, + normalization: NORMALIZATION = NormalizationMethod.MIN_MAX, + threshold: THRESHOLD = "F1AdaptiveThreshold", + task: TaskType | str = TaskType.SEGMENTATION, + image_metrics: list[str] | str | dict[str, dict[str, Any]] | None = None, + pixel_metrics: list[str] | str | dict[str, dict[str, Any]] | None = None, + logger: Logger | Iterable[Logger] | bool | None = None, + default_root_dir: str | Path = "results", + **kwargs, + ) -> None: + # TODO(ashwinvaidya17): Add model argument to engine constructor + # https://github.com/openvinotoolkit/anomalib/issues/1639 + if callbacks is None: + callbacks = [] + + # Cache the Lightning Trainer arguments. + logger = False if logger is None else logger + self._cache = _TrainerArgumentsCache( + callbacks=[*callbacks], + logger=logger, + default_root_dir=Path(default_root_dir), + **kwargs, + ) + + self.normalization = normalization + self.threshold = threshold + self.task = TaskType(task) + self.image_metric_names = image_metrics if image_metrics else ["AUROC", "F1Score"] + + # pixel metrics are only used for segmentation tasks. + self.pixel_metric_names = None + if self.task == TaskType.SEGMENTATION: + self.pixel_metric_names = pixel_metrics if pixel_metrics is not None else ["AUROC", "F1Score"] + + self._trainer: Trainer | None = None + + @property + def trainer(self) -> Trainer: + """Property to get the trainer. + + Raises: + UnassignedError: When the trainer is not assigned yet. + + Returns: + Trainer: Lightning Trainer. + """ + if not self._trainer: + msg = "``self.trainer`` is not assigned yet." + raise UnassignedError(msg) + return self._trainer + + @property + def model(self) -> AnomalyModule: + """Property to get the model. + + Raises: + UnassignedError: When the model is not assigned yet. + + Returns: + AnomalyModule: Anomaly model. + """ + if not self.trainer.model: + msg = "Trainer does not have a model assigned yet." + raise UnassignedError(msg) + return self.trainer.lightning_module + + @property + def normalization_callback(self) -> NormalizationCallback | None: + """The ``NormalizationCallback`` callback in the trainer.callbacks list, or ``None`` if it doesn't exist. + + Returns: + NormalizationCallback | None: Normalization callback, if available. + + Raises: + ValueError: If there are multiple normalization callbacks. + """ + callbacks = [callback for callback in self.trainer.callbacks if isinstance(callback, NormalizationCallback)] + if len(callbacks) > 1: + msg = ( + f"Trainer can only have one normalization callback but multiple found: {callbacks}. " + "Please check your configuration. Exiting to avoid unexpected behavior." + ) + raise ValueError(msg) + return callbacks[0] if len(callbacks) > 0 else None + + @property + def threshold_callback(self) -> _ThresholdCallback | None: + """The ``ThresholdCallback`` callback in the trainer.callbacks list, or ``None`` if it doesn't exist. + + Returns: + _ThresholdCallback | None: Threshold callback, if available. + + Raises: + ValueError: If there are multiple threshold callbacks. + """ + callbacks = [callback for callback in self.trainer.callbacks if isinstance(callback, _ThresholdCallback)] + if len(callbacks) > 1: + msg = ( + f"Trainer can only have one thresholding callback but multiple found: {callbacks}. " + "Please check your configuration. Exiting to avoid unexpected behavior." + ) + raise ValueError(msg) + return callbacks[0] if len(callbacks) > 0 else None + + @property + def checkpoint_callback(self) -> ModelCheckpoint | None: + """The ``ModelCheckpoint`` callback in the trainer.callbacks list, or ``None`` if it doesn't exist. + + Returns: + ModelCheckpoint | None: ModelCheckpoint callback, if available. + """ + if self._trainer is None: + return None + return self.trainer.checkpoint_callback + + @property + def best_model_path(self) -> str | None: + """The path to the best model checkpoint. + + Returns: + str: Path to the best model checkpoint. + """ + if self.checkpoint_callback is None: + return None + return self.checkpoint_callback.best_model_path + + def _setup_workspace( + self, + model: AnomalyModule, + train_dataloaders: TRAIN_DATALOADERS | None = None, + val_dataloaders: EVAL_DATALOADERS | None = None, + test_dataloaders: EVAL_DATALOADERS | None = None, + datamodule: AnomalibDataModule | None = None, + dataset: AnomalibDataset | None = None, + versioned_dir: bool = False, + ) -> None: + """Setup the workspace for the model. + + This method sets up the default root directory for the model based on + the model name, dataset name, and category. Model checkpoints, logs, and + other artifacts will be saved in this directory. + + Args: + model (AnomalyModule): Input model. + train_dataloaders (TRAIN_DATALOADERS | None, optional): Train dataloaders. + Defaults to ``None``. + val_dataloaders (EVAL_DATALOADERS | None, optional): Validation dataloaders. + Defaults to ``None``. + test_dataloaders (EVAL_DATALOADERS | None, optional): Test dataloaders. + Defaults to ``None``. + datamodule (AnomalibDataModule | None, optional): Lightning datamodule. + Defaults to ``None``. + dataset (AnomalibDataset | None, optional): Anomalib dataset. + Defaults to ``None``. + versioned_dir (bool, optional): Whether to create a versioned directory. + Defaults to ``True``. + + Raises: + TypeError: If the dataloader type is unknown. + """ + # - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - # + # 1. Get the dataset name and category from the dataloaders, datamodule, or dataset. + dataset_name: str = "" + category: str | None = None + + # Check datamodule and dataset directly + if datamodule is not None: + dataset_name = datamodule.name + category = datamodule.category + elif dataset is not None: + dataset_name = dataset.name + category = dataset.category + + # Check dataloaders if dataset_name and category are not set + dataloaders = [train_dataloaders, val_dataloaders, test_dataloaders] + if not dataset_name or category is None: + for dataloader in dataloaders: + if dataloader is not None: + if hasattr(dataloader, "train_data"): + dataset_name = getattr(dataloader.train_data, "name", "") + category = getattr(dataloader.train_data, "category", "") + break + if dataset_name and category is not None: + break + + # Check if category is None and set it to empty string + category = category if category is not None else "" + + # - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - # + # 2. Update the default root directory + root_dir = Path(self._cache.args["default_root_dir"]) / model.name / dataset_name / category + self._cache.args["default_root_dir"] = create_versioned_dir(root_dir) if versioned_dir else root_dir / "latest" + + def _setup_trainer(self, model: AnomalyModule) -> None: + """Instantiate the trainer based on the model parameters.""" + # Check if the cache requires an update + if self._cache.requires_update(model): + self._cache.update(model) + + # Setup anomalib callbacks to be used with the trainer + self._setup_anomalib_callbacks() + + # Temporarily set devices to 1 to avoid issues with multiple processes + self._cache.args["devices"] = 1 + + # Instantiate the trainer if it is not already instantiated + if self._trainer is None: + self._trainer = Trainer(**self._cache.args) + + def _setup_dataset_task( + self, + *dataloaders: EVAL_DATALOADERS | TRAIN_DATALOADERS | AnomalibDataModule | None, + ) -> None: + """Override the dataloader task with the task passed to the Engine. + + Args: + dataloaders (TRAIN_DATALOADERS | EVAL_DATALOADERS): Dataloaders to be used for training or evaluation. + """ + for dataloader in dataloaders: + if dataloader is not None and isinstance(dataloader, AnomalibDataModule): + for attribute in ("train_data", "val_data", "test_data"): + if hasattr(dataloader, attribute): + data: AnomalibDataset = getattr(dataloader, attribute) + if data.task != self.task: + logger.info( + f"Overriding task from {data.task} with {self.task} for {dataloader.__class__}", + ) + data.task = self.task + + @staticmethod + def _setup_transform( + model: AnomalyModule, + datamodule: AnomalibDataModule | None = None, + dataloaders: EVAL_DATALOADERS | TRAIN_DATALOADERS | None = None, + ckpt_path: Path | str | None = None, + ) -> None: + """Implements the logic for setting the transform at the start of each run. + + Any transform passed explicitly to the datamodule takes precedence. Otherwise, if a checkpoint path is provided, + we can load the transform from the checkpoint. If no transform is provided, we use the default transform from + the model. + + Args: + model (AnomalyModule): The model to assign the transform to. + datamodule (AnomalibDataModule | None): The datamodule to assign the transform from. + defaults to ``None``. + dataloaders (EVAL_DATALOADERS | TRAIN_DATALOADERS | None): Dataloaders to assign the transform to. + defaults to ``None``. + ckpt_path (str): The path to the checkpoint. + defaults to ``None``. + + Returns: + Transform: The transform loaded from the checkpoint. + """ + if isinstance(dataloaders, DataLoader): + dataloaders = [dataloaders] + + # get transform + if datamodule and datamodule.transform: + # a transform passed explicitly to the datamodule takes precedence + transform = datamodule.transform + elif dataloaders and any(getattr(dl.dataset, "transform", None) for dl in dataloaders): + # if dataloaders are provided, we use the transform from the first dataloader that has a transform + transform = next(dl.dataset.transform for dl in dataloaders if getattr(dl.dataset, "transform", None)) + elif ckpt_path is not None: + # if a checkpoint path is provided, we can load the transform from the checkpoint + checkpoint = torch.load(ckpt_path, map_location=model.device) + transform = checkpoint["transform"] + elif model.transform is None: + # if no transform is provided, we use the default transform from the model + image_size = datamodule.image_size if datamodule else None + transform = model.configure_transforms(image_size) + else: + transform = model.transform + + # update transform in model + model.set_transform(transform) + # The dataloaders don't have access to the trainer and/or model, so we need to set the transforms manually + if dataloaders: + for dataloader in dataloaders: + if not getattr(dataloader.dataset, "transform", None): + dataloader.dataset.transform = transform + + def _setup_anomalib_callbacks(self) -> None: + """Set up callbacks for the trainer.""" + _callbacks: list[Callback] = [] + + # Add ModelCheckpoint if it is not in the callbacks list. + has_checkpoint_callback = any(isinstance(c, ModelCheckpoint) for c in self._cache.args["callbacks"]) + if has_checkpoint_callback is False: + _callbacks.append( + ModelCheckpoint( + dirpath=self._cache.args["default_root_dir"] / "weights" / "lightning", + filename="model", + auto_insert_metric_name=False, + ), + ) + + # Add the post-processor callbacks. + _callbacks.append(_PostProcessorCallback()) + + # Add the the normalization callback. + normalization_callback = get_normalization_callback(self.normalization) + if normalization_callback is not None: + _callbacks.append(normalization_callback) + + # Add the thresholding and metrics callbacks. + _callbacks.append(_ThresholdCallback(self.threshold)) + _callbacks.append(_MetricsCallback(self.task, self.image_metric_names, self.pixel_metric_names)) + + _callbacks.append( + _VisualizationCallback( + visualizers=ImageVisualizer(task=self.task), + save=True, + root=self._cache.args["default_root_dir"] / "images", + ), + ) + + _callbacks.append(TimerCallback()) + + # Combine the callbacks, and update the trainer callbacks. + self._cache.args["callbacks"] = _callbacks + self._cache.args["callbacks"] + + def _should_run_validation( + self, + model: AnomalyModule, + dataloaders: EVAL_DATALOADERS | None, + datamodule: AnomalibDataModule | None, + ckpt_path: str | Path | None, + ) -> bool: + """Check if we need to run validation to collect normalization statistics and thresholds. + + If a checkpoint path is provided, we don't need to run validation because we can load the model from the + checkpoint and use the normalization metrics and thresholds from the checkpoint. + + We need to run validation if the model is configured with normalization enabled, but no normalization metrics + have been collected yet. Similarly, we need to run validation if the model is configured with adaptive + thresholding enabled, but no thresholds have been computed yet. + + We can only run validation if we have validation data available, so we check if the dataloaders or datamodule + are available. If neither is available, we can't run validation. + + Args: + model (AnomalyModule): Model passed to the entrypoint. + dataloaders (EVAL_DATALOADERS | None): Dataloaders passed to the entrypoint. + datamodule (AnomalibDataModule | None): Lightning datamodule passed to the entrypoint. + ckpt_path (str | Path | None): Checkpoint path passed to the entrypoint. + + Returns: + bool: Whether it is needed to run a validation sequence. + """ + # validation before predict is only necessary for zero-/few-shot models + if model.learning_type not in [LearningType.ZERO_SHOT, LearningType.FEW_SHOT]: + return False + # check if a checkpoint path is provided + if ckpt_path is not None: + return False + # check if the model needs to be validated + needs_normalization = self.normalization_callback is not None and not hasattr(model, "normalization_metrics") + needs_thresholding = self.threshold_callback is not None and not hasattr(model, "image_threshold") + # check if the model can be validated (i.e. validation data is available) + return (needs_normalization or needs_thresholding) and (dataloaders is not None or datamodule is not None) + + def fit( + self, + model: AnomalyModule, + train_dataloaders: TRAIN_DATALOADERS | None = None, + val_dataloaders: EVAL_DATALOADERS | None = None, + datamodule: AnomalibDataModule | None = None, + ckpt_path: str | Path | None = None, + ) -> None: + """Fit the model using the trainer. + + Args: + model (AnomalyModule): Model to be trained. + train_dataloaders (TRAIN_DATALOADERS | None, optional): Train dataloaders. + Defaults to None. + val_dataloaders (EVAL_DATALOADERS | None, optional): Validation dataloaders. + Defaults to None. + datamodule (AnomalibDataModule | None, optional): Lightning datamodule. + If provided, dataloaders will be instantiated from this. + Defaults to None. + ckpt_path (str | None, optional): Checkpoint path. If provided, the model will be loaded from this path. + Defaults to None. + + CLI Usage: + 1. you can pick a model, and you can run through the MVTec dataset. + ```python + anomalib fit --model anomalib.models.Padim + ``` + 2. Of course, you can override the various values with commands. + ```python + anomalib fit --model anomalib.models.Padim --data --trainer.max_epochs 3 + ``` + 4. If you have a ready configuration file, run it like this. + ```python + anomalib fit --config + ``` + """ + if ckpt_path: + ckpt_path = Path(ckpt_path).resolve() + + self._setup_workspace( + model=model, + train_dataloaders=train_dataloaders, + val_dataloaders=val_dataloaders, + datamodule=datamodule, + versioned_dir=True, + ) + self._setup_trainer(model) + self._setup_dataset_task(train_dataloaders, val_dataloaders, datamodule) + self._setup_transform(model, datamodule=datamodule, ckpt_path=ckpt_path) + if model.learning_type in [LearningType.ZERO_SHOT, LearningType.FEW_SHOT]: + # if the model is zero-shot or few-shot, we only need to run validate for normalization and thresholding + self.trainer.validate(model, val_dataloaders, datamodule=datamodule, ckpt_path=ckpt_path) + else: + self.trainer.fit(model, train_dataloaders, val_dataloaders, datamodule, ckpt_path) + + def validate( + self, + model: AnomalyModule | None = None, + dataloaders: EVAL_DATALOADERS | None = None, + ckpt_path: str | Path | None = None, + verbose: bool = True, + datamodule: AnomalibDataModule | None = None, + ) -> _EVALUATE_OUTPUT | None: + """Validate the model using the trainer. + + Args: + model (AnomalyModule | None, optional): Model to be validated. + Defaults to None. + dataloaders (EVAL_DATALOADERS | None, optional): Dataloaders to be used for + validation. + Defaults to None. + ckpt_path (str | None, optional): Checkpoint path. If provided, the model will be loaded from this path. + Defaults to None. + verbose (bool, optional): Boolean to print the validation results. + Defaults to True. + datamodule (AnomalibDataModule | None, optional): A :class:`~lightning.pytorch.core.datamodule + AnomalibDataModule` that defines the + :class:`~lightning.pytorch.core.hooks.DataHooks.val_dataloader` hook. + Defaults to None. + + Returns: + _EVALUATE_OUTPUT | None: Validation results. + + CLI Usage: + 1. you can pick a model. + ```python + anomalib validate --model anomalib.models.Padim + ``` + 2. Of course, you can override the various values with commands. + ```python + anomalib validate --model anomalib.models.Padim --data + ``` + 4. If you have a ready configuration file, run it like this. + ```python + anomalib validate --config + ``` + """ + if ckpt_path: + ckpt_path = Path(ckpt_path).resolve() + if model: + self._setup_trainer(model) + self._setup_dataset_task(dataloaders) + self._setup_transform(model or self.model, datamodule=datamodule, ckpt_path=ckpt_path) + return self.trainer.validate(model, dataloaders, ckpt_path, verbose, datamodule) + + def test( + self, + model: AnomalyModule | None = None, + dataloaders: EVAL_DATALOADERS | None = None, + ckpt_path: str | Path | None = None, + verbose: bool = True, + datamodule: AnomalibDataModule | None = None, + ) -> _EVALUATE_OUTPUT: + """Test the model using the trainer. + + Sets up the trainer and the dataset task if not already set up. Then validates the model if needed and + finally tests the model. + + Args: + model (AnomalyModule | None, optional): + The model to be tested. + Defaults to None. + dataloaders (EVAL_DATALOADERS | None, optional): + An iterable or collection of iterables specifying test samples. + Defaults to None. + ckpt_path (str | None, optional): + Either ``"best"``, ``"last"``, ``"hpc"`` or path to the checkpoint you wish to test. + If ``None`` and the model instance was passed, use the current weights. + Otherwise, the best model checkpoint from the previous ``trainer.fit`` call will be loaded + if a checkpoint callback is configured. + Defaults to None. + verbose (bool, optional): + If True, prints the test results. + Defaults to True. + datamodule (AnomalibDataModule | None, optional): + A :class:`~lightning.pytorch.core.datamodule.AnomalibDataModule` that defines + the :class:`~lightning.pytorch.core.hooks.DataHooks.test_dataloader` hook. + Defaults to None. + + Returns: + _EVALUATE_OUTPUT: A List of dictionaries containing the test results. 1 dict per dataloader. + + Examples: + # fit and test a one-class model + >>> from anomalib.data import MVTec + >>> from anomalib.models import Padim + >>> from anomalib.engine import Engine + + >>> datamodule = MVTec() + >>> model = Padim() + >>> model.learning_type + + + >>> engine = Engine() + >>> engine.fit(model, datamodule=datamodule) + >>> engine.test(model, datamodule=datamodule) + + # Test a zero-shot model + >>> from anomalib.data import MVTec + >>> from anomalib.models import Padim + >>> from anomalib.engine import Engine + + >>> datamodule = MVTec(image_size=240, normalization="clip") + >>> model = Padim() + >>> model.learning_type + + + >>> engine = Engine() + >>> engine.test(model, datamodule=datamodule) + + CLI Usage: + 1. you can pick a model. + ```python + anomalib test --model anomalib.models.Padim + ``` + 2. Of course, you can override the various values with commands. + ```python + anomalib test --model anomalib.models.Padim --data + ``` + 4. If you have a ready configuration file, run it like this. + ```python + anomalib test --config + ``` + """ + if ckpt_path: + ckpt_path = Path(ckpt_path).resolve() + + self._setup_workspace(model=model or self.model, datamodule=datamodule, test_dataloaders=dataloaders) + + if model: + self._setup_trainer(model) + elif not self.model: + msg = "`Engine.test()` requires an `AnomalyModule` when it hasn't been passed in a previous run." + raise RuntimeError(msg) + + self._setup_dataset_task(dataloaders) + self._setup_transform(model or self.model, datamodule=datamodule, ckpt_path=ckpt_path) + if self._should_run_validation(model or self.model, dataloaders, datamodule, ckpt_path): + logger.info("Running validation before testing to collect normalization metrics and/or thresholds.") + self.trainer.validate(model, dataloaders, None, verbose=False, datamodule=datamodule) + return self.trainer.test(model, dataloaders, ckpt_path, verbose, datamodule) + + def predict( + self, + model: AnomalyModule | None = None, + dataloaders: EVAL_DATALOADERS | None = None, + datamodule: AnomalibDataModule | None = None, + dataset: Dataset | PredictDataset | None = None, + return_predictions: bool | None = None, + ckpt_path: str | Path | None = None, + data_path: str | Path | None = None, + ) -> _PREDICT_OUTPUT | None: + """Predict using the model using the trainer. + + Sets up the trainer and the dataset task if not already set up. Then validates the model if needed and a + validation dataloader is available. Finally, predicts using the model. + + Args: + model (AnomalyModule | None, optional): + Model to be used for prediction. + Defaults to None. + dataloaders (EVAL_DATALOADERS | None, optional): + An iterable or collection of iterables specifying predict samples. + Defaults to None. + datamodule (AnomalibDataModule | None, optional): + A :class:`~lightning.pytorch.core.datamodule.AnomalibDataModule` that defines + the :class:`~lightning.pytorch.core.hooks.DataHooks.predict_dataloader` hook. + The datamodule can also be a dataset that will be wrapped in a torch Dataloader. + Defaults to None. + dataset (Dataset | PredictDataset | None, optional): + A :class:`~torch.utils.data.Dataset` or :class:`~anomalib.data.PredictDataset` that will be used + to create a dataloader. Defaults to None. + return_predictions (bool | None, optional): + Whether to return predictions. + ``True`` by default except when an accelerator that spawns processes is used (not supported). + Defaults to None. + ckpt_path (str | None, optional): + Either ``"best"``, ``"last"``, ``"hpc"`` or path to the checkpoint you wish to predict. + If ``None`` and the model instance was passed, use the current weights. + Otherwise, the best model checkpoint from the previous ``trainer.fit`` call will be loaded + if a checkpoint callback is configured. + Defaults to None. + data_path (str | Path | None): + Path to the image or folder containing images to generate predictions for. + Defaults to None. + + Returns: + _PREDICT_OUTPUT | None: Predictions. + + CLI Usage: + 1. you can pick a model. + ```python + anomalib predict --model anomalib.models.Padim + anomalib predict --model Padim \ + --data datasets/MVTec/bottle/test/broken_large + ``` + 2. Of course, you can override the various values with commands. + ```python + anomalib predict --model anomalib.models.Padim \ + --data + ``` + 4. If you have a ready configuration file, run it like this. + ```python + anomalib predict --config --return_predictions + ``` + 5. You can also point to a folder with image or a single image instead of passing a dataset. + ```python + anomalib predict --model Padim --data --ckpt_path + ``` + """ + if not (model or self.model): + msg = "`Engine.predict()` requires an `AnomalyModule` when it hasn't been passed in a previous run." + raise ValueError(msg) + + if ckpt_path: + ckpt_path = Path(ckpt_path).resolve() + + self._setup_workspace(model=model or self.model, datamodule=datamodule, test_dataloaders=dataloaders) + + if model: + self._setup_trainer(model) + + if not ckpt_path: + logger.warning("ckpt_path is not provided. Model weights will not be loaded.") + + # Collect dataloaders + if dataloaders is None: + dataloaders = [] + elif isinstance(dataloaders, DataLoader): + dataloaders = [dataloaders] + elif not isinstance(dataloaders, list): + msg = f"Unknown type for dataloaders {type(dataloaders)}" + raise TypeError(msg) + if dataset is not None: + dataloaders.append(DataLoader(dataset)) + if data_path is not None: + dataloaders.append(DataLoader(PredictDataset(data_path))) + dataloaders = dataloaders or None + + self._setup_dataset_task(dataloaders, datamodule) + self._setup_transform(model or self.model, datamodule=datamodule, dataloaders=dataloaders, ckpt_path=ckpt_path) + + if self._should_run_validation(model or self.model, None, datamodule, ckpt_path): + logger.info("Running validation before predicting to collect normalization metrics and/or thresholds.") + self.trainer.validate( + model, + dataloaders=None, + ckpt_path=None, + verbose=False, + datamodule=datamodule, + ) + + return self.trainer.predict(model, dataloaders, datamodule, return_predictions, ckpt_path) + + def train( + self, + model: AnomalyModule, + train_dataloaders: TRAIN_DATALOADERS | None = None, + val_dataloaders: EVAL_DATALOADERS | None = None, + test_dataloaders: EVAL_DATALOADERS | None = None, + datamodule: AnomalibDataModule | None = None, + ckpt_path: str | Path | None = None, + ) -> _EVALUATE_OUTPUT: + """Fits the model and then calls test on it. + + Args: + model (AnomalyModule): Model to be trained. + train_dataloaders (TRAIN_DATALOADERS | None, optional): Train dataloaders. + Defaults to None. + val_dataloaders (EVAL_DATALOADERS | None, optional): Validation dataloaders. + Defaults to None. + test_dataloaders (EVAL_DATALOADERS | None, optional): Test dataloaders. + Defaults to None. + datamodule (AnomalibDataModule | None, optional): Lightning datamodule. + If provided, dataloaders will be instantiated from this. + Defaults to None. + ckpt_path (str | None, optional): Checkpoint path. If provided, the model will be loaded from this path. + Defaults to None. + + CLI Usage: + 1. you can pick a model, and you can run through the MVTec dataset. + ```python + anomalib train --model anomalib.models.Padim --data MVTec + ``` + 2. Of course, you can override the various values with commands. + ```python + anomalib train --model anomalib.models.Padim --data --trainer.max_epochs 3 + ``` + 4. If you have a ready configuration file, run it like this. + ```python + anomalib train --config + ``` + """ + if ckpt_path: + ckpt_path = Path(ckpt_path).resolve() + self._setup_workspace( + model, + train_dataloaders, + val_dataloaders, + test_dataloaders, + datamodule, + versioned_dir=True, + ) + self._setup_trainer(model) + self._setup_dataset_task( + train_dataloaders, + val_dataloaders, + test_dataloaders, + datamodule, + ) + self._setup_transform(model, datamodule=datamodule, ckpt_path=ckpt_path) + if model.learning_type in [LearningType.ZERO_SHOT, LearningType.FEW_SHOT]: + # if the model is zero-shot or few-shot, we only need to run validate for normalization and thresholding + self.trainer.validate(model, val_dataloaders, None, verbose=False, datamodule=datamodule) + else: + self.trainer.fit(model, train_dataloaders, val_dataloaders, datamodule, ckpt_path) + self.trainer.test(model, test_dataloaders, ckpt_path=ckpt_path, datamodule=datamodule) + + def export( + self, + model: AnomalyModule, + export_type: ExportType | str, + export_root: str | Path | None = None, + input_size: tuple[int, int] | None = None, + transform: Transform | None = None, + ov_args: dict[str, Any] | None = None, + ckpt_path: str | Path | None = None, + ) -> Path | None: + r"""Export the model in PyTorch, ONNX or OpenVINO format. + + Args: + model (AnomalyModule): Trained model. + export_type (ExportType): Export type. + export_root (str | Path | None, optional): Path to the output directory. If it is not set, the model is + exported to trainer.default_root_dir. Defaults to None. + input_size (tuple[int, int] | None, optional): A statis input shape for the model, which is exported to ONNX + and OpenVINO format. Defaults to None. + transform (Transform | None, optional): Input transform to include in the exported model. If not provided, + the engine will try to use the default transform from the model. + Defaults to ``None``. + ov_args (dict[str, Any] | None, optional): This is optional and used only for OpenVINO's model optimizer. + Defaults to None. + ckpt_path (str | Path | None): Checkpoint path. If provided, the model will be loaded from this path. + + Returns: + Path: Path to the exported model. + + Raises: + ValueError: If Dataset, Datamodule, and transform are not provided. + TypeError: If path to the transform file is not a string or Path. + + CLI Usage: + 1. To export as a torch ``.pt`` file you can run the following command. + ```python + anomalib export --model Padim --export_mode torch --ckpt_path + ``` + 2. To export as an ONNX ``.onnx`` file you can run the following command. + ```python + anomalib export --model Padim --export_mode onnx --ckpt_path \ + --input_size "[256,256]" + ``` + 3. To export as an OpenVINO ``.xml`` and ``.bin`` file you can run the following command. + ```python + anomalib export --model Padim --export_mode openvino --ckpt_path \ + --input_size "[256,256]" + ``` + 4. You can also overrride OpenVINO model optimizer by adding the ``--ov_args.`` arguments. + ```python + anomalib export --model Padim --export_mode openvino --ckpt_path \ + --input_size "[256,256]" --ov_args.compress_to_fp16 False + ``` + """ + export_type = ExportType(export_type) + self._setup_trainer(model) + if ckpt_path: + ckpt_path = Path(ckpt_path).resolve() + model = model.__class__.load_from_checkpoint(ckpt_path) + + if export_root is None: + export_root = Path(self.trainer.default_root_dir) + + exported_model_path: Path | None = None + if export_type == ExportType.TORCH: + exported_model_path = model.to_torch( + export_root=export_root, + transform=transform, + task=self.task, + ) + elif export_type == ExportType.ONNX: + exported_model_path = model.to_onnx( + export_root=export_root, + input_size=input_size, + transform=transform, + task=self.task, + ) + elif export_type == ExportType.OPENVINO: + exported_model_path = model.to_openvino( + export_root=export_root, + input_size=input_size, + transform=transform, + task=self.task, + ov_args=ov_args, + ) + else: + logging.error(f"Export type {export_type} is not supported yet.") + + if exported_model_path: + logging.info(f"Exported model to {exported_model_path}") + return exported_model_path diff --git a/anomalib/loggers/__init__.py b/anomalib/loggers/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..9da16b8138a4df55937db576df733a450ef6158f --- /dev/null +++ b/anomalib/loggers/__init__.py @@ -0,0 +1,54 @@ +"""Load PyTorch Lightning Loggers.""" + +# Copyright (C) 2022-2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + + +import logging + +from rich.logging import RichHandler + +__all__ = [ + "configure_logger", + "get_experiment_logger", +] + +try: + from .comet import AnomalibCometLogger # noqa: F401 + from .mlflow import AnomalibMLFlowLogger # noqa: F401 + from .tensorboard import AnomalibTensorBoardLogger # noqa: F401 + from .wandb import AnomalibWandbLogger # noqa: F401 + + __all__.extend( + [ + "AnomalibCometLogger", + "AnomalibTensorBoardLogger", + "AnomalibWandbLogger", + "AnomalibMLFlowLogger", + ], + ) +except ImportError: + print("To use any logger install it using `anomalib install -v`") + + +def configure_logger(level: int | str = logging.INFO) -> None: + """Get console logger by name. + + Args: + level (int | str, optional): Logger Level. Defaults to logging.INFO. + + Returns: + Logger: The expected logger. + """ + if isinstance(level, str): + level = logging.getLevelName(level) + + format_string = "%(asctime)s - %(name)s - %(levelname)s - %(message)s" + logging.basicConfig(format=format_string, level=level) + logging.getLogger().addHandler(RichHandler(rich_tracebacks=True)) + + # Set Pytorch Lightning logs to have a the consistent formatting with anomalib. + for handler in logging.getLogger("lightning.pytorch").handlers: + handler.setFormatter(logging.Formatter(format_string)) + handler.setLevel(level) + logging.getLogger("lightning.pytorch").addHandler(RichHandler(rich_tracebacks=True)) diff --git a/anomalib/loggers/base.py b/anomalib/loggers/base.py new file mode 100644 index 0000000000000000000000000000000000000000..e1c043510fc581360ff7d3bf0160a2748b6b5138 --- /dev/null +++ b/anomalib/loggers/base.py @@ -0,0 +1,19 @@ +"""Base logger for image logging consistency across all loggers used in anomalib.""" + +# Copyright (C) 2022-2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + + +from abc import abstractmethod + +import numpy as np +from matplotlib.figure import Figure + + +class ImageLoggerBase: + """Adds a common interface for logging the images.""" + + @abstractmethod + def add_image(self, image: np.ndarray | Figure, name: str | None = None, **kwargs) -> None: + """Interface to log images in the respective loggers.""" + raise NotImplementedError diff --git a/anomalib/loggers/comet.py b/anomalib/loggers/comet.py new file mode 100644 index 0000000000000000000000000000000000000000..913977187a527d1e31802065819e6b9a65cf0b61 --- /dev/null +++ b/anomalib/loggers/comet.py @@ -0,0 +1,136 @@ +"""comet logger with add image interface.""" + +# Copyright (C) 2022-2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + + +import numpy as np +from matplotlib.figure import Figure + +try: + from lightning.pytorch.loggers.comet import CometLogger +except ModuleNotFoundError: + print("To use comet logger install it using `pip install comet-ml`") +from lightning.pytorch.utilities import rank_zero_only + +from .base import ImageLoggerBase + + +class AnomalibCometLogger(ImageLoggerBase, CometLogger): + """Logger for comet. + + Adds interface for ``add_image`` in the logger rather than calling the + experiment object. + + .. note:: + Same as the CometLogger provided by PyTorch Lightning and the doc string + is reproduced below. + + Track your parameters, metrics, source code and more using + `Comet `_. + + Install it with pip: + + .. code-block:: bash + + pip install comet-ml + + Comet requires either an API Key (online mode) or a local directory path + (offline mode). + + Args: + api_key: Required in online mode. API key, found on Comet.ml. If not + given, this will be loaded from the environment variable + COMET_API_KEY or ~/.comet.config if either exists. + Defaults to ``None``. + save_dir: Required in offline mode. The path for the directory to save + local comet logs. If given, this also sets the directory for saving + checkpoints. + Defaults to ``None``. + project_name: Optional. Send your experiment to a specific project. + Otherwise will be sent to Uncategorized Experiments. + If the project name does not already exist, Comet.ml will create a + new project. + Defaults to ``None``. + rest_api_key: Optional. Rest API key found in Comet.ml settings. + This is used to determine version number + Defaults to ``None``. + experiment_name: Optional. String representing the name for this + particular experiment on Comet.ml. + Defaults to ``None``. + experiment_key: Optional. If set, restores from existing experiment. + Defaults to ``None``. + offline: If api_key and save_dir are both given, this determines whether + the experiment will be in online or offline mode. This is useful if + you use save_dir to control the checkpoints directory and have a + ~/.comet.config file but still want to run offline experiments. + Defaults to ``None``. + prefix: A string to put at the beginning of metric keys. + Defaults to ``""``. + kwargs: Additional arguments like `workspace`, `log_code`, etc. used by + :class:`CometExperiment` can be passed as keyword arguments in this + logger. + + Raises: + ModuleNotFoundError: + If required Comet package is not installed on the device. + MisconfigurationException: + If neither ``api_key`` nor ``save_dir`` are passed as arguments. + + Example: + >>> from anomalib.loggers import AnomalibCometLogger + >>> from anomalib.engine import Engine + ... + >>> comet_logger = AnomalibCometLogger() + >>> engine = Engine(logger=comet_logger) + + See Also: + - `Comet Documentation `__ + """ + + def __init__( + self, + api_key: str | None = None, + save_dir: str | None = None, + project_name: str | None = None, + rest_api_key: str | None = None, + experiment_name: str | None = None, + experiment_key: str | None = None, + offline: bool = False, + prefix: str = "", + **kwargs, + ) -> None: + super().__init__( + api_key=api_key, + save_dir=save_dir, + project_name=project_name, + rest_api_key=rest_api_key, + experiment_name=experiment_name, + experiment_key=experiment_key, + offline=offline, + prefix=prefix, + **kwargs, + ) + self.experiment.log_other("Created from", "Anomalib") + + @rank_zero_only + def add_image(self, image: np.ndarray | Figure, name: str | None = None, **kwargs) -> None: + """Interface to add image to comet logger. + + Args: + image (np.ndarray | Figure): Image to log. + name (str | None): The tag of the image + Defaults to ``None``. + kwargs: Accepts only `global_step` (int). The step at which to log the image. + """ + if "global_step" not in kwargs: + msg = "`global_step` is required for comet logger" + raise ValueError(msg) + + global_step = kwargs["global_step"] + # Need to call different functions of `Experiment` for Figure vs np.ndarray + + if isinstance(image, Figure): + self.experiment.log_figure(figure_name=name, figure=image, step=global_step) + else: + self.experiment.log_image(name=name, image_data=image, step=global_step) diff --git a/anomalib/loggers/mlflow.py b/anomalib/loggers/mlflow.py new file mode 100644 index 0000000000000000000000000000000000000000..2bc383507235ffcec5074a8a53b35938102dbe4f --- /dev/null +++ b/anomalib/loggers/mlflow.py @@ -0,0 +1,103 @@ +"""MLFlow logger with add image interface.""" + +from typing import Literal + +import numpy as np +from lightning.pytorch.loggers.mlflow import MLFlowLogger +from lightning.pytorch.utilities import rank_zero_only +from matplotlib.figure import Figure + +from anomalib.utils.exceptions.imports import try_import + +from .base import ImageLoggerBase + +try_import("mlflow") + + +class AnomalibMLFlowLogger(ImageLoggerBase, MLFlowLogger): + """Logger for MLFlow. + + Adds interface for ``add_image`` in the logger rather than calling the + experiment object. + + .. note:: + Same as the MLFlowLogger provided by PyTorch Lightning and the doc string is reproduced below. + + Track your parameters, metrics, source code and more using + `MLFlow `_. + + Install it with pip: + + .. code-block:: bash + + pip install mlflow + + Args: + experiment_name: The name of the experiment. + run_name: Name of the new run. + The `run_name` is internally stored as a ``mlflow.runName`` tag. + If the ``mlflow.runName`` tag has already been set in `tags`, the value is overridden by the `run_name`. + tracking_uri: Address of local or remote tracking server. + If not provided, defaults to `MLFLOW_TRACKING_URI` environment variable if set, otherwise it falls + back to `file:`. + save_dir: A path to a local directory where the MLflow runs get saved. + Defaults to `./mlruns` if `tracking_uri` is not provided. + Has no effect if `tracking_uri` is provided. + log_model: Log checkpoints created by `ModelCheckpoint` as MLFlow artifacts. + + - if ``log_model == 'all'``, checkpoints are logged during training. + - if ``log_model == True``, checkpoints are logged at the end of training, \ + except when `save_top_k == -1` which also logs every checkpoint during training. + - if ``log_model == False`` (default), no checkpoint is logged. + + prefix: A string to put at the beginning of metric keys. Defaults to ``''``. + kwargs: Additional arguments like `tags`, `artifact_location` etc. used by + `MLFlowExperiment` can be passed as keyword arguments in this logger. + + Example: + >>> from anomalib.loggers import AnomalibMLFlowLogger + >>> from anomalib.engine import Engine + ... + >>> mlflow_logger = AnomalibMLFlowLogger() + >>> engine = Engine(logger=mlflow_logger) + + See Also: + - `MLFlow Documentation `_. + """ + + def __init__( + self, + experiment_name: str | None = "anomalib_logs", + run_name: str | None = None, + tracking_uri: str | None = None, + save_dir: str | None = "./mlruns", + log_model: Literal[True, False, "all"] | None = False, + prefix: str | None = "", + **kwargs, + ) -> None: + super().__init__( + experiment_name=experiment_name, + run_name=run_name, + tracking_uri=tracking_uri, + save_dir=save_dir, + log_model=log_model, + prefix=prefix, + **kwargs, + ) + + @rank_zero_only + def add_image(self, image: np.ndarray | Figure, name: str | None = None, **kwargs) -> None: + """Interface to log images in the mlflow loggers. + + Args: + image (np.ndarray | Figure): Image to log. + name (str | None): The tag of the image defaults to ``None``. + kwargs: Additional keyword arguments that are only used if `image` is of type Figure. + These arguments are passed directly to the method that saves the figure. + If `image` is a NumPy array, `kwargs` has no effect. + """ + # Need to call different functions of `Experiment` for Figure vs np.ndarray + if isinstance(image, Figure): + self.experiment.log_figure(run_id=self.run_id, figure=image, artifact_file=name, **kwargs) + else: + self.experiment.log_image(run_id=self.run_id, image=image, artifact_file=name) diff --git a/anomalib/loggers/tensorboard.py b/anomalib/loggers/tensorboard.py new file mode 100644 index 0000000000000000000000000000000000000000..77becd4f04fa608d5a6933d8b69bca4a3be3d1cd --- /dev/null +++ b/anomalib/loggers/tensorboard.py @@ -0,0 +1,105 @@ +"""Tensorboard logger with add image interface.""" + +# Copyright (C) 2022-2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + + +from pathlib import Path + +import numpy as np +from matplotlib.figure import Figure + +try: + from lightning.pytorch.loggers.tensorboard import TensorBoardLogger +except ModuleNotFoundError: + print("To use tensorboard logger install it using `pip install tensorboard`") +from lightning.pytorch.utilities import rank_zero_only + +from .base import ImageLoggerBase + + +class AnomalibTensorBoardLogger(ImageLoggerBase, TensorBoardLogger): + """Logger for tensorboard. + + Adds interface for `add_image` in the logger rather than calling the experiment object. + + .. note:: + Same as the Tensorboard Logger provided by PyTorch Lightning and the doc string is reproduced below. + + Logs are saved to + ``os.path.join(save_dir, name, version)``. This is the default logger in Lightning, it comes + preinstalled. + + Example: + >>> from anomalib.engine import Engine + >>> from anomalib.loggers import AnomalibTensorBoardLogger + ... + >>> logger = AnomalibTensorBoardLogger("tb_logs", name="my_model") + >>> engine = Engine(logger=logger) + + Args: + save_dir (str): Save directory + name (str | None): Experiment name. Defaults to ``'default'``. + If it is the empty string then no per-experiment subdirectory is used. + Default: ``'default'``. + version (int | str | None): Experiment version. If version is not + specified the logger inspects the save directory for existing + versions, then automatically assigns the next available version. + If it is a string then it is used as the run-specific subdirectory + name, otherwise ``'version_${version}'`` is used. + Defaults to ``None`` + log_graph (bool): Adds the computational graph to tensorboard. This + requires that the user has defined the `self.example_input_array` + attribute in their model. + Defaults to ``False``. + default_hp_metric (bool): Enables a placeholder metric with key + ``hp_metric`` when ``log_hyperparams`` is called without a metric + (otherwise calls to log_hyperparams without a metric are ignored). + Defaults to ``True``. + prefix (str): A string to put at the beginning of metric keys. + Defaults to ``''``. + **kwargs: Additional arguments like `comment`, `filename_suffix`, etc. + used by :class:`SummaryWriter` can be passed as keyword arguments in + this logger. + """ + + def __init__( + self, + save_dir: str, + name: str | None = "default", + version: int | str | None = None, + log_graph: bool = False, + default_hp_metric: bool = True, + prefix: str = "", + **kwargs, + ) -> None: + super().__init__( + save_dir, + name=name, + version=version, + log_graph=log_graph, + default_hp_metric=default_hp_metric, + prefix=prefix, + **kwargs, + ) + Path(save_dir).mkdir(parents=True, exist_ok=True) + + @rank_zero_only + def add_image(self, image: np.ndarray | Figure, name: str | None = None, **kwargs) -> None: + """Interface to add image to tensorboard logger. + + Args: + image (np.ndarray | Figure): Image to log + name (str | None): The tag of the image + Defaults to ``None``. + kwargs: Accepts only `global_step` (int). The step at which to log the image. + """ + if "global_step" not in kwargs: + msg = "`global_step` is required for tensorboard logger" + raise ValueError(msg) + + # Need to call different functions of `SummaryWriter` for Figure vs np.ndarray + if isinstance(image, Figure): + self.experiment.add_figure(figure=image, tag=name, close=False, **kwargs) + else: + self.experiment.add_image(img_tensor=image, tag=name, dataformats="HWC", **kwargs) diff --git a/anomalib/loggers/wandb.py b/anomalib/loggers/wandb.py new file mode 100644 index 0000000000000000000000000000000000000000..0a23c251920605fd21a4cc1cf4669ea8dba7b9e7 --- /dev/null +++ b/anomalib/loggers/wandb.py @@ -0,0 +1,150 @@ +"""wandb logger with add image interface.""" + +# Copyright (C) 2022-2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +from typing import TYPE_CHECKING, Literal + +import numpy as np +from lightning.fabric.utilities.types import _PATH +from lightning.pytorch.loggers.wandb import WandbLogger +from lightning.pytorch.utilities import rank_zero_only +from matplotlib.figure import Figure + +from anomalib.utils.exceptions import try_import + +from .base import ImageLoggerBase + +if try_import("wandb"): + import wandb + +if TYPE_CHECKING: + from wandb.sdk.lib import RunDisabled + from wandb.sdk.wandb_run import Run + + +class AnomalibWandbLogger(ImageLoggerBase, WandbLogger): + """Logger for wandb. + + Adds interface for `add_image` in the logger rather than calling the experiment object. + + .. note:: + Same as the wandb Logger provided by PyTorch Lightning and the doc string is reproduced below. + + Log using `Weights and Biases `_. + + Install it with pip: + + .. code-block:: bash + + $ pip install wandb + + Args: + name: Display name for the run. + Defaults to ``None``. + save_dir: Path where data is saved (wandb dir by default). + Defaults to ``None``. + version: Sets the version, mainly used to resume a previous run. + offline: Run offline (data can be streamed later to wandb servers). + Defaults to ``False``. + dir: Alias for save_dir. + id: Sets the version, mainly used to resume a previous run. + Defaults to ``None``. + anonymous: Enables or explicitly disables anonymous logging. + Defaults to ``None``. + version: Same as id. + Defaults to ``None``. + project: The name of the project to which this run will belong. + Defaults to ``None``. + log_model: Save checkpoints in wandb dir to upload on W&B servers. + Defaults to ``False``. + experiment: WandB experiment object. Automatically set when creating a run. + Defaults to ``None``. + prefix: A string to put at the beginning of metric keys. + Defaults to ``''``. + **kwargs: Arguments passed to :func:`wandb.init` like `entity`, `group`, `tags`, etc. + + Raises: + ImportError: + If required WandB package is not installed on the device. + MisconfigurationException: + If both ``log_model`` and ``offline``is set to ``True``. + + Example: + >>> from anomalib.loggers import AnomalibWandbLogger + >>> from anomalib.engine import Engine + ... + >>> wandb_logger = AnomalibWandbLogger() + >>> engine = Engine(logger=wandb_logger) + + .. note:: + When logging manually through `wandb.log` or `trainer.logger.experiment.log`, + make sure to use `commit=False` so the logging step does not increase. + + See Also: + - `Tutorial `__ + on how to use W&B with PyTorch Lightning + - `W&B Documentation `__ + + """ + + def __init__( + self, + name: str | None = None, + save_dir: _PATH = ".", + version: str | None = None, + offline: bool = False, + dir: _PATH | None = None, # kept to match wandb init # noqa: A002 + id: str | None = None, # kept to match wandb init # noqa: A002 + anonymous: bool | None = None, + project: str | None = None, + log_model: Literal["all"] | bool = False, + experiment: "Run | RunDisabled | None" = None, + prefix: str = "", + checkpoint_name: str | None = None, + **kwargs, + ) -> None: + super().__init__( + name=name, + save_dir=save_dir, + version=version, + offline=offline, + dir=dir, + id=id, + anonymous=anonymous, + project=project, + log_model=log_model, + experiment=experiment, + prefix=prefix, + checkpoint_name=checkpoint_name, + **kwargs, + ) + self.image_list: list[wandb.Image] = [] # Cache images + + @rank_zero_only + def add_image(self, image: np.ndarray | Figure, name: str | None = None, **kwargs) -> None: + """Interface to add image to wandb logger. + + Args: + image (np.ndarray | Figure): Image to log + name (str | None): The tag of the image + Defaults to ``None``. + kwargs: Additional arguments to `wandb.Image` + """ + del kwargs # Unused argument. + + image = wandb.Image(image, caption=name) + self.image_list.append(image) + + @rank_zero_only + def save(self) -> None: + """Upload images to wandb server. + + .. note:: + There is a limit on the number of images that can be logged together to the `wandb` server. + """ + super().save() + if len(self.image_list) > 1: + wandb.log({"Predictions": self.image_list}) + self.image_list = [] + self.image_list = [] diff --git a/anomalib/metrics/__init__.py b/anomalib/metrics/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..3a59b7a846b98c7269fb81ab65cbab5658ca16b7 --- /dev/null +++ b/anomalib/metrics/__init__.py @@ -0,0 +1,201 @@ +"""Custom anomaly evaluation metrics.""" + +# Copyright (C) 2022-2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +import importlib +import logging +from collections.abc import Callable +from typing import Any + +import torchmetrics +from omegaconf import DictConfig, ListConfig + +from .anomaly_score_distribution import AnomalyScoreDistribution +from .aupr import AUPR +from .aupro import AUPRO +from .auroc import AUROC +from .collection import AnomalibMetricCollection +from .f1_max import F1Max +from .f1_score import F1Score +from .min_max import MinMax +from .precision_recall_curve import BinaryPrecisionRecallCurve +from .pro import PRO +from .spro import SPRO +from .threshold import F1AdaptiveThreshold, ManualThreshold + +__all__ = [ + "AUROC", + "AUPR", + "AUPRO", + "AnomalyScoreDistribution", + "BinaryPrecisionRecallCurve", + "F1AdaptiveThreshold", + "F1Max", + "F1Score", + "ManualThreshold", + "MinMax", + "PRO", + "SPRO", +] + +logger = logging.getLogger(__name__) + + +def metric_collection_from_names(metric_names: list[str], prefix: str | None) -> AnomalibMetricCollection: + """Create a metric collection from a list of metric names. + + The function will first try to retrieve the metric from the metrics defined in Anomalib metrics module, + then in TorchMetrics package. + + Args: + metric_names (list[str]): List of metric names to be included in the collection. + prefix (str | None): prefix to assign to the metrics in the collection. + + Returns: + AnomalibMetricCollection: Collection of metrics. + """ + metrics_module = importlib.import_module("anomalib.metrics") + metrics = AnomalibMetricCollection([], prefix=prefix) + for metric_name in metric_names: + if hasattr(metrics_module, metric_name): + metric_cls = getattr(metrics_module, metric_name) + metrics.add_metrics(metric_cls()) + elif hasattr(torchmetrics, metric_name): + try: + metric_cls = getattr(torchmetrics, metric_name) + metrics.add_metrics(metric_cls()) + except TypeError: + msg = f"Incorrect constructor arguments for {metric_name} metric from TorchMetrics package." + logger.warning(msg) + else: + msg = f"No metric with name {metric_name} found in Anomalib metrics or TorchMetrics." + logger.warning(msg) + return metrics + + +def _validate_metrics_dict(metrics: dict[str, dict[str, Any]]) -> None: + """Check the assumptions about metrics config dict. + + - Keys are metric names + - Values are dictionaries. + - Internal dictionaries: + - have key "class_path" and its value is of type str + - have key init_args" and its value is of type dict). + + """ + if not all(isinstance(metric, str) for metric in metrics): + msg = f"All keys (metric names) must be strings, found {sorted(metrics.keys())}" + raise TypeError(msg) + + if not all(isinstance(metric, DictConfig | dict) for metric in metrics.values()): + msg = f"All values must be dictionaries, found {list(metrics.values())}" + raise TypeError(msg) + + if not all("class_path" in metric and isinstance(metric["class_path"], str) for metric in metrics.values()): + msg = "All internal dictionaries must have a 'class_path' key whose value is of type str." + raise ValueError(msg) + + if not all( + "init_args" in metric and isinstance(metric["init_args"], dict) or isinstance(metric["init_args"], DictConfig) + for metric in metrics.values() + ): + msg = "All internal dictionaries must have a 'init_args' key whose value is of type dict." + raise ValueError(msg) + + +def _get_class_from_path(class_path: str) -> Callable: + """Get a class from a module assuming the string format is `package.subpackage.module.ClassName`.""" + module_name, class_name = class_path.rsplit(".", 1) + module = importlib.import_module(module_name) + if not hasattr(module, class_name): + msg = f"Class {class_name} not found in module {module_name}" + raise AttributeError(msg) + return getattr(module, class_name) + + +def metric_collection_from_dicts(metrics: dict[str, dict[str, Any]], prefix: str | None) -> AnomalibMetricCollection: + """Create a metric collection from a dict of "metric name" -> "metric specifications". + + Example: + metrics = { + "PixelWiseF1Score": { + "class_path": "torchmetrics.F1Score", + "init_args": {}, + }, + "PixelWiseAUROC": { + "class_path": "anomalib.metrics.AUROC", + "init_args": { + }, + }, + } + + In the config file, the same specifications (for pixel-wise metrics) look like: + + ```yaml + metrics: + pixel: + PixelWiseF1Score: + class_path: torchmetrics.F1Score + init_args: {} + PixelWiseAUROC: + class_path: anomalib.metrics.AUROC + + ``` + + Args: + metrics (dict[str, dict[str, Any]]): keys are metric names, values are dictionaries. + Internal dict[str, Any] keys are "class_path" (value is string) and "init_args" (value is dict), + following the convention in Pytorch Lightning CLI. + + prefix (str | None): prefix to assign to the metrics in the collection. + + Returns: + AnomalibMetricCollection: Collection of metrics. + """ + _validate_metrics_dict(metrics) + metrics_collection = {} + for name, dict_ in metrics.items(): + class_path = dict_["class_path"] + kwargs = dict_["init_args"] + cls = _get_class_from_path(class_path) + metrics_collection[name] = cls(**kwargs) + return AnomalibMetricCollection(metrics_collection, prefix=prefix) + + +def create_metric_collection( + metrics: list[str] | dict[str, dict[str, Any]], + prefix: str | None = None, +) -> AnomalibMetricCollection: + """Create a metric collection from a list of metric names or dictionaries. + + This function will dispatch the actual creation to the appropriate function depending on the input type: + + - if list[str] (names of metrics): see `metric_collection_from_names` + - if dict[str, dict[str, Any]] (path and init args of a class): see `metric_collection_from_dicts` + + The function will first try to retrieve the metric from the metrics defined in Anomalib metrics module, + then in TorchMetrics package. + + Args: + metrics (list[str] | dict[str, dict[str, Any]]): List of metrics or dictionaries to create metric collection. + prefix (str | None): Prefix to assign to the metrics in the collection. + + Returns: + AnomalibMetricCollection: Collection of metrics. + """ + # fallback is using the names + + if isinstance(metrics, ListConfig | list): + if not all(isinstance(metric, str) for metric in metrics): + msg = f"All metrics must be strings, found {metrics}" + raise TypeError(msg) + + return metric_collection_from_names(metrics, prefix) + + if isinstance(metrics, DictConfig | dict): + _validate_metrics_dict(metrics) + return metric_collection_from_dicts(metrics, prefix) + + msg = f"metrics must be a list or a dict, found {type(metrics)}" + raise ValueError(msg) diff --git a/anomalib/metrics/anomaly_score_distribution.py b/anomalib/metrics/anomaly_score_distribution.py new file mode 100644 index 0000000000000000000000000000000000000000..4c629859a0c78392b35d851bd21898681a9f152b --- /dev/null +++ b/anomalib/metrics/anomaly_score_distribution.py @@ -0,0 +1,59 @@ +"""Module that computes the parameters of the normal data distribution of the training set.""" + +# Copyright (C) 2022-2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + + +import torch +from torchmetrics import Metric + + +class AnomalyScoreDistribution(Metric): + """Mean and standard deviation of the anomaly scores of normal training data.""" + + def __init__(self, **kwargs) -> None: + super().__init__(**kwargs) + self.anomaly_maps: list[torch.Tensor] = [] + self.anomaly_scores: list[torch.Tensor] = [] + + self.add_state("image_mean", torch.empty(0), persistent=True) + self.add_state("image_std", torch.empty(0), persistent=True) + self.add_state("pixel_mean", torch.empty(0), persistent=True) + self.add_state("pixel_std", torch.empty(0), persistent=True) + + self.image_mean = torch.empty(0) + self.image_std = torch.empty(0) + self.pixel_mean = torch.empty(0) + self.pixel_std = torch.empty(0) + + def update( + self, + *args, + anomaly_scores: torch.Tensor | None = None, + anomaly_maps: torch.Tensor | None = None, + **kwargs, + ) -> None: + """Update the precision-recall curve metric.""" + del args, kwargs # These variables are not used. + + if anomaly_maps is not None: + self.anomaly_maps.append(anomaly_maps) + if anomaly_scores is not None: + self.anomaly_scores.append(anomaly_scores) + + def compute(self) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + """Compute stats.""" + anomaly_scores = torch.hstack(self.anomaly_scores) + anomaly_scores = torch.log(anomaly_scores) + + self.image_mean = anomaly_scores.mean() + self.image_std = anomaly_scores.std() + + if self.anomaly_maps: + anomaly_maps = torch.vstack(self.anomaly_maps) + anomaly_maps = torch.log(anomaly_maps).cpu() + + self.pixel_mean = anomaly_maps.mean(dim=0).squeeze() + self.pixel_std = anomaly_maps.std(dim=0).squeeze() + + return self.image_mean, self.image_std, self.pixel_mean, self.pixel_std diff --git a/anomalib/metrics/aupr.py b/anomalib/metrics/aupr.py new file mode 100644 index 0000000000000000000000000000000000000000..bfb3dc8bedc2665b15322fac61e2d10d1e299ec0 --- /dev/null +++ b/anomalib/metrics/aupr.py @@ -0,0 +1,109 @@ +"""Implementation of AUROC metric based on TorchMetrics.""" + +# Copyright (C) 2022-2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + + +import torch +from matplotlib.figure import Figure +from torchmetrics.classification import BinaryPrecisionRecallCurve +from torchmetrics.utilities.compute import auc +from torchmetrics.utilities.data import dim_zero_cat + +from .plotting_utils import plot_figure + + +class AUPR(BinaryPrecisionRecallCurve): + """Area under the PR curve. + + This metric computes the area under the precision-recall curve. + + Args: + kwargs: Additional arguments to the TorchMetrics base class. + + Examples: + To compute the metric for a set of predictions and ground truth targets: + + >>> true = torch.tensor([0, 1, 1, 1, 0, 0, 0, 0, 1, 1]) + >>> pred = torch.tensor([0.59, 0.35, 0.72, 0.33, 0.73, 0.81, 0.30, 0.05, 0.04, 0.48]) + + >>> metric = AUPR() + >>> metric(pred, true) + tensor(0.4899) + + It is also possible to update the metric state incrementally within batches: + + >>> for batch in dataloader: + ... # Compute prediction and target tensors + ... metric.update(pred, true) + >>> metric.compute() + + Once the metric has been computed, we can plot the PR curve: + + >>> figure, title = metric.generate_figure() + """ + + def compute(self) -> torch.Tensor: + """First compute PR curve, then compute area under the curve. + + Returns: + Value of the AUPR metric + """ + prec: torch.Tensor + rec: torch.Tensor + + prec, rec = self._compute() + return auc(rec, prec, reorder=True) + + def update(self, preds: torch.Tensor, target: torch.Tensor) -> None: + """Update state with new values. + + Need to flatten new values as PrecicionRecallCurve expects them in this format for binary classification. + + Args: + preds (torch.Tensor): predictions of the model + target (torch.Tensor): ground truth targets + """ + super().update(preds.flatten(), target.flatten()) + + def _compute(self) -> tuple[torch.Tensor, torch.Tensor]: + """Compute prec/rec value pairs. + + Returns: + Tuple containing Tensors for rec and prec + """ + prec: torch.Tensor + rec: torch.Tensor + prec, rec, _ = super().compute() + return (prec, rec) + + def generate_figure(self) -> tuple[Figure, str]: + """Generate a figure containing the PR curve as well as the random baseline and the AUC. + + Returns: + tuple[Figure, str]: Tuple containing both the PR curve and the figure title to be used for logging + """ + prec, rec = self._compute() + aupr = self.compute() + + xlim = (0.0, 1.0) + ylim = (0.0, 1.0) + xlabel = "Precision" + ylabel = "Recall" + loc = "best" + title = "AUPR" + + fig, axis = plot_figure(rec, prec, aupr, xlim, ylim, xlabel, ylabel, loc, title) + + # Baseline in PR-curve is the prevalence of the positive class + rate = (dim_zero_cat(self.target) == 1).sum() / (dim_zero_cat(self.target).size(0)) + axis.plot( + (0, 1), + (rate.detach().cpu(), rate.detach().cpu()), + color="navy", + lw=2, + linestyle="--", + figure=fig, + ) + + return fig, title diff --git a/anomalib/metrics/aupro.py b/anomalib/metrics/aupro.py new file mode 100644 index 0000000000000000000000000000000000000000..0024769cc74e0b9f1f2268ae713fac8c262685b9 --- /dev/null +++ b/anomalib/metrics/aupro.py @@ -0,0 +1,294 @@ +"""Implementation of AUPRO score based on TorchMetrics.""" + +# Copyright (C) 2022-2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + + +from collections.abc import Callable +from typing import Any + +import torch +from matplotlib.figure import Figure +from torchmetrics import Metric +from torchmetrics.functional.classification import binary_roc +from torchmetrics.utilities.compute import auc +from torchmetrics.utilities.data import dim_zero_cat + +from anomalib.metrics.pro import connected_components_cpu, connected_components_gpu + +from .binning import thresholds_between_0_and_1, thresholds_between_min_and_max +from .plotting_utils import plot_figure + + +class AUPRO(Metric): + """Area under per region overlap (AUPRO) Metric. + + Args: + dist_sync_on_step (bool): Synchronize metric state across processes at each ``forward()`` + before returning the value at the step. Default: ``False`` + process_group (Optional[Any]): Specify the process group on which synchronization is called. + Default: ``None`` (which selects the entire world) + dist_sync_fn (Optional[Callable]): Callback that performs the allgather operation on the metric state. + When ``None``, DDP will be used to perform the allgather. + Default: ``None`` + fpr_limit (float): Limit for the false positive rate. Defaults to ``0.3``. + num_thresholds (int): Number of thresholds to use for computing the roc curve. Defaults to ``None``. + If ``None``, the roc curve is computed with the thresholds returned by + ``torchmetrics.functional.classification.thresholds``. + + Examples: + >>> import torch + >>> from anomalib.metrics import AUPRO + ... + >>> labels = torch.randint(low=0, high=2, size=(1, 10, 5), dtype=torch.float32) + >>> preds = torch.rand_like(labels) + ... + >>> aupro = AUPRO(fpr_limit=0.3) + >>> aupro(preds, labels) + tensor(0.4321) + + Increasing the fpr_limit will increase the AUPRO value: + + >>> aupro = AUPRO(fpr_limit=0.7) + >>> aupro(preds, labels) + tensor(0.5271) + """ + + is_differentiable: bool = False + higher_is_better: bool | None = None + full_state_update: bool = False + preds: list[torch.Tensor] + target: list[torch.Tensor] + # When not None, the computation is performed in constant-memory by computing the roc curve + # for fixed thresholds buckets/thresholds. + # Warning: The thresholds are evenly distributed between the min and max predictions + # if all predictions are inside [0, 1]. Otherwise, the thresholds are evenly distributed between 0 and 1. + # This warning can be removed when https://github.com/Lightning-AI/torchmetrics/issues/1526 is fixed + # and the roc curve is computed with deactivated formatting + num_thresholds: int | None + + def __init__( + self, + dist_sync_on_step: bool = False, + process_group: Any | None = None, # noqa: ANN401 + dist_sync_fn: Callable | None = None, + fpr_limit: float = 0.3, + num_thresholds: int | None = None, + ) -> None: + super().__init__( + dist_sync_on_step=dist_sync_on_step, + process_group=process_group, + dist_sync_fn=dist_sync_fn, + ) + + self.add_state("preds", default=[], dist_reduce_fx="cat") + self.add_state("target", default=[], dist_reduce_fx="cat") + self.register_buffer("fpr_limit", torch.tensor(fpr_limit)) + self.num_thresholds = num_thresholds + + def update(self, preds: torch.Tensor, target: torch.Tensor) -> None: + """Update state with new values. + + Args: + preds (torch.Tensor): predictions of the model + target (torch.Tensor): ground truth targets + """ + self.target.append(target) + self.preds.append(preds) + + def perform_cca(self) -> torch.Tensor: + """Perform the Connected Component Analysis on the self.target tensor. + + Raises: + ValueError: ValueError is raised if self.target doesn't conform with requirements imposed by kornia for + connected component analysis. + + Returns: + Tensor: Components labeled from 0 to N. + """ + target = dim_zero_cat(self.target) + + # check and prepare target for labeling via kornia + if target.min() < 0 or target.max() > 1: + msg = ( + "kornia.contrib.connected_components expects input to lie in the interval [0, 1], " + f"but found interval was [{target.min()}, {target.max()}]." + ) + raise ValueError( + msg, + ) + target = target.unsqueeze(1) # kornia expects N1HW format + target = target.type(torch.float) # kornia expects FloatTensor + return connected_components_gpu(target) if target.is_cuda else connected_components_cpu(target) + + def compute_pro( + self, + cca: torch.Tensor, + target: torch.Tensor, + preds: torch.Tensor, + ) -> tuple[torch.Tensor, torch.Tensor]: + """Compute the pro/fpr value-pairs until the fpr specified by self.fpr_limit. + + It leverages the fact that the overlap corresponds to the tpr, and thus computes the overall + PRO curve by aggregating per-region tpr/fpr values produced by ROC-construction. + + Returns: + tuple[torch.Tensor, torch.Tensor]: tuple containing final fpr and tpr values. + """ + if self.num_thresholds is not None: + # binary_roc is applying a sigmoid on the predictions before computing the roc curve + # when some predictions are out of [0, 1], the binning between min and max predictions + # cannot be applied in that case. This can be removed when + # https://github.com/Lightning-AI/torchmetrics/issues/1526 is fixed and + # the roc curve is computed with deactivated formatting. + + if torch.all((preds >= 0) * (preds <= 1)): + thresholds = thresholds_between_min_and_max(preds, self.num_thresholds, self.device) + else: + thresholds = thresholds_between_0_and_1(self.num_thresholds, self.device) + + else: + thresholds = None + + # compute the global fpr-size + fpr: torch.Tensor = binary_roc( + preds=preds, + target=target, + thresholds=thresholds, + )[0] # only need fpr + output_size = torch.where(fpr <= self.fpr_limit)[0].size(0) + + # compute the PRO curve by aggregating per-region tpr/fpr curves/values. + tpr = torch.zeros(output_size, device=preds.device, dtype=torch.float) + fpr = torch.zeros(output_size, device=preds.device, dtype=torch.float) + new_idx = torch.arange(0, output_size, device=preds.device, dtype=torch.float) + + # Loop over the labels, computing per-region tpr/fpr curves, and aggregating them. + # Note that, since the groundtruth is different for every all to `roc`, we also get + # different/unique tpr/fpr curves (i.e. len(_fpr_idx) is different for every call). + # We therefore need to resample per-region curves to a fixed sampling ratio (defined above). + labels = cca.unique()[1:] # 0 is background + background = cca == 0 + _fpr: torch.Tensor + _tpr: torch.Tensor + for label in labels: + interp: bool = False + new_idx[-1] = output_size - 1 + mask = cca == label + # Need to calculate label-wise roc on union of background & mask, as otherwise we wrongly consider other + # label in labels as FPs. We also don't need to return the thresholds + _fpr, _tpr = binary_roc( + preds=preds[background | mask], + target=mask[background | mask], + thresholds=thresholds, + )[:-1] + + # catch edge-case where ROC only has fpr vals > self.fpr_limit + if _fpr[_fpr <= self.fpr_limit].max() == 0: + _fpr_limit = _fpr[_fpr > self.fpr_limit].min() + else: + _fpr_limit = self.fpr_limit + + _fpr_idx = torch.where(_fpr <= _fpr_limit)[0] + # if computed roc curve is not specified sufficiently close to self.fpr_limit, + # we include the closest higher tpr/fpr pair and linearly interpolate the tpr/fpr point at self.fpr_limit + if not torch.allclose(_fpr[_fpr_idx].max(), self.fpr_limit): + _tmp_idx = torch.searchsorted(_fpr, self.fpr_limit) + _fpr_idx = torch.cat([_fpr_idx, _tmp_idx.unsqueeze_(0)]) + _slope = 1 - ((_fpr[_tmp_idx] - self.fpr_limit) / (_fpr[_tmp_idx] - _fpr[_tmp_idx - 1])) + interp = True + + _fpr = _fpr[_fpr_idx] + _tpr = _tpr[_fpr_idx] + + _fpr_idx = _fpr_idx.float() + _fpr_idx /= _fpr_idx.max() + _fpr_idx *= new_idx.max() + + if interp: + # last point will be sampled at self.fpr_limit + new_idx[-1] = _fpr_idx[-2] + ((_fpr_idx[-1] - _fpr_idx[-2]) * _slope) + + _tpr = self.interp1d(_fpr_idx, _tpr, new_idx) + _fpr = self.interp1d(_fpr_idx, _fpr, new_idx) + tpr += _tpr + fpr += _fpr + + # Actually perform the averaging + tpr /= labels.size(0) + fpr /= labels.size(0) + return fpr, tpr + + def _compute(self) -> tuple[torch.Tensor, torch.Tensor]: + """Compute the PRO curve. + + Perform the Connected Component Analysis first then compute the PRO curve. + + Returns: + tuple[torch.Tensor, torch.Tensor]: tuple containing final fpr and tpr values. + """ + cca = self.perform_cca().flatten() + target = dim_zero_cat(self.target).flatten() + preds = dim_zero_cat(self.preds).flatten() + + return self.compute_pro(cca=cca, target=target, preds=preds) + + def compute(self) -> torch.Tensor: + """Fist compute PRO curve, then compute and scale area under the curve. + + Returns: + Tensor: Value of the AUPRO metric + """ + fpr, tpr = self._compute() + + aupro = auc(fpr, tpr, reorder=True) + return aupro / fpr[-1] # normalize the area + + def generate_figure(self) -> tuple[Figure, str]: + """Generate a figure containing the PRO curve and the AUPRO. + + Returns: + tuple[Figure, str]: Tuple containing both the figure and the figure title to be used for logging + """ + fpr, tpr = self._compute() + aupro = self.compute() + + xlim = (0.0, self.fpr_limit.detach_().cpu().numpy()) + ylim = (0.0, 1.0) + xlabel = "Global FPR" + ylabel = "Averaged Per-Region TPR" + loc = "lower right" + title = "PRO" + + fig, _axis = plot_figure(fpr, tpr, aupro, xlim, ylim, xlabel, ylabel, loc, title) + + return fig, "PRO" + + @staticmethod + def interp1d(old_x: torch.Tensor, old_y: torch.Tensor, new_x: torch.Tensor) -> torch.Tensor: + """Interpolate a 1D signal linearly to new sampling points. + + Args: + old_x (torch.Tensor): original 1-D x values (same size as y) + old_y (torch.Tensor): original 1-D y values (same size as x) + new_x (torch.Tensor): x-values where y should be interpolated at + + Returns: + Tensor: y-values at corresponding new_x values. + """ + # Compute slope + eps = torch.finfo(old_y.dtype).eps + slope = (old_y[1:] - old_y[:-1]) / (eps + (old_x[1:] - old_x[:-1])) + + # Prepare idx for linear interpolation + idx = torch.searchsorted(old_x, new_x) + + # searchsorted looks for the index where the values must be inserted + # to preserve order, but we actually want the preceeding index. + idx -= 1 + # we clamp the index, because the number of intervals = old_x.size(0) -1, + # and the left neighbour should hence be at most number of intervals -1, i.e. old_x.size(0) - 2 + idx = torch.clamp(idx, 0, old_x.size(0) - 2) + + # perform actual linear interpolation + return old_y[idx] + slope[idx] * (new_x - old_x[idx]) diff --git a/anomalib/metrics/auroc.py b/anomalib/metrics/auroc.py new file mode 100644 index 0000000000000000000000000000000000000000..fb714110634ab6ddb0d2539c5f8059fe6eaa2dcb --- /dev/null +++ b/anomalib/metrics/auroc.py @@ -0,0 +1,102 @@ +"""Implementation of AUROC metric based on TorchMetrics.""" + +# Copyright (C) 2022-2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + + +import torch +from matplotlib.figure import Figure +from torchmetrics.classification.roc import BinaryROC +from torchmetrics.utilities.compute import auc + +from .plotting_utils import plot_figure + + +class AUROC(BinaryROC): + """Area under the ROC curve. + + Examples: + >>> import torch + >>> from anomalib.metrics import AUROC + ... + >>> preds = torch.tensor([0.13, 0.26, 0.08, 0.92, 0.03]) + >>> target = torch.tensor([0, 0, 1, 1, 0]) + ... + >>> auroc = AUROC() + >>> auroc(preds, target) + tensor(0.6667) + + It is possible to update the metric state incrementally: + + >>> auroc.update(preds[:2], target[:2]) + >>> auroc.update(preds[2:], target[2:]) + >>> auroc.compute() + tensor(0.6667) + + To plot the ROC curve, use the ``generate_figure`` method: + + >>> fig, title = auroc.generate_figure() + """ + + def compute(self) -> torch.Tensor: + """First compute ROC curve, then compute area under the curve. + + Returns: + Tensor: Value of the AUROC metric + """ + tpr: torch.Tensor + fpr: torch.Tensor + + fpr, tpr = self._compute() + return auc(fpr, tpr, reorder=True) + + def update(self, preds: torch.Tensor, target: torch.Tensor) -> None: + """Update state with new values. + + Need to flatten new values as ROC expects them in this format for binary classification. + + Args: + preds (torch.Tensor): predictions of the model + target (torch.Tensor): ground truth targets + """ + super().update(preds.flatten(), target.flatten()) + + def _compute(self) -> tuple[torch.Tensor, torch.Tensor]: + """Compute fpr/tpr value pairs. + + Returns: + Tuple containing Tensors for fpr and tpr + """ + tpr: torch.Tensor + fpr: torch.Tensor + fpr, tpr, _thresholds = super().compute() + return (fpr, tpr) + + def generate_figure(self) -> tuple[Figure, str]: + """Generate a figure containing the ROC curve, the baseline and the AUROC. + + Returns: + tuple[Figure, str]: Tuple containing both the figure and the figure title to be used for logging + """ + fpr, tpr = self._compute() + auroc = self.compute() + + xlim = (0.0, 1.0) + ylim = (0.0, 1.0) + xlabel = "False Positive Rate" + ylabel = "True Positive Rate" + loc = "lower right" + title = "ROC" + + fig, axis = plot_figure(fpr, tpr, auroc, xlim, ylim, xlabel, ylabel, loc, title) + + axis.plot( + [0, 1], + [0, 1], + color="navy", + lw=2, + linestyle="--", + figure=fig, + ) + + return fig, title diff --git a/anomalib/metrics/binning.py b/anomalib/metrics/binning.py new file mode 100644 index 0000000000000000000000000000000000000000..b56c23480064a66ea9d8183fc84c12af329068c4 --- /dev/null +++ b/anomalib/metrics/binning.py @@ -0,0 +1,40 @@ +"""Binning functions for metrics.""" + +# Copyright (C) 2023-2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +import torch +from torch import linspace + + +def thresholds_between_min_and_max( + preds: torch.Tensor, + num_thresholds: int = 100, + device: torch.device | None = None, +) -> torch.Tensor: + """Threshold values between min and max of the predictions. + + Args: + preds (torch.Tensor): Predictions. + num_thresholds (int, optional): Number of thresholds to generate. Defaults to 100. + device (torch_device | None, optional): Device to use for computation. Defaults to None. + + Returns: + Tensor: + Array of size ``num_thresholds`` that contains evenly spaced values + between ``preds.min()`` and ``preds.max()`` on ``device``. + """ + return linspace(start=preds.min(), end=preds.max(), steps=num_thresholds, device=device) + + +def thresholds_between_0_and_1(num_thresholds: int = 100, device: torch.device | None = None) -> torch.Tensor: + """Threshold values between 0 and 1. + + Args: + num_thresholds (int, optional): Number of thresholds to generate. Defaults to 100. + device (torch_device | None, optional): Device to use for computation. Defaults to None. + + Returns: + Tensor: Threshold values between 0 and 1. + """ + return linspace(start=0, end=1, steps=num_thresholds, device=device) diff --git a/anomalib/metrics/collection.py b/anomalib/metrics/collection.py new file mode 100644 index 0000000000000000000000000000000000000000..47c17a3a442273b28ca761b61fccbdb0c9e2d998 --- /dev/null +++ b/anomalib/metrics/collection.py @@ -0,0 +1,45 @@ +"""Anomalib Metric Collection.""" + +# Copyright (C) 2022-2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +import logging + +from torchmetrics import MetricCollection + +logger = logging.getLogger(__name__) + + +class AnomalibMetricCollection(MetricCollection): + """Extends the MetricCollection class for use in the Anomalib pipeline.""" + + def __init__(self, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) + self._update_called = False + self._threshold = 0.5 + + def set_threshold(self, threshold_value: float) -> None: + """Update the threshold value for all metrics that have the threshold attribute.""" + self._threshold = threshold_value + for metric in self.values(): + if hasattr(metric, "threshold"): + metric.threshold = threshold_value + + def set_update_called(self, val: bool) -> None: + """Set the flag indicating whether the update method has been called.""" + self._update_called = val + + def update(self, *args, **kwargs) -> None: + """Add data to the metrics.""" + super().update(*args, **kwargs) + self._update_called = True + + @property + def update_called(self) -> bool: + """Returns a boolean indicating if the update method has been called at least once.""" + return self._update_called + + @property + def threshold(self) -> float: + """Return the value of the anomaly threshold.""" + return self._threshold diff --git a/anomalib/metrics/f1_max.py b/anomalib/metrics/f1_max.py new file mode 100644 index 0000000000000000000000000000000000000000..8b9b42f3051f2cc8049b448228ec820aae37844d --- /dev/null +++ b/anomalib/metrics/f1_max.py @@ -0,0 +1,100 @@ +"""Implementation of F1Max score based on TorchMetrics.""" + +# Copyright (C) 2022-2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +import logging + +import torch +from torchmetrics import Metric + +from anomalib.metrics.precision_recall_curve import BinaryPrecisionRecallCurve + +logger = logging.getLogger(__name__) + + +class F1Max(Metric): + """F1Max Metric for Computing the Maximum F1 Score. + + This class is designed to calculate the maximum F1 score from the precision- + recall curve for binary classification tasks. The F1 score is a harmonic + mean of precision and recall, offering a balance between these two metrics. + The maximum F1 score (F1-Max) is particularly useful in scenarios where an + optimal balance between precision and recall is desired, such as in + imbalanced datasets or when both false positives and false negatives carry + significant costs. + + After computing the F1Max score, the class also identifies and stores the + threshold that yields this maximum F1 score, which providing insight into + the optimal point for the classification decision. + + Args: + **kwargs: Variable keyword arguments that can be passed to the parent class. + + Attributes: + full_state_update (bool): Indicates whether the metric requires updating + the entire state. Set to False for this metric as it calculates the + F1 score based on the current state without needing historical data. + precision_recall_curve (BinaryPrecisionRecallCurve): Utility to compute + precision and recall values across different thresholds. + threshold (torch.Tensor): Stores the threshold value that results in the + maximum F1 score. + + Examples: + >>> from anomalib.metrics import F1Max + >>> import torch + + >>> preds = torch.tensor([0.1, 0.4, 0.35, 0.8]) + >>> target = torch.tensor([0, 0, 1, 1]) + + >>> f1_max = F1Max() + >>> f1_max.update(preds, target) + + >>> optimal_f1_score = f1_max.compute() + >>> print(f"Optimal F1 Score: {f1_max_score}") + >>> print(f"Optimal Threshold: {f1_max.threshold}") + + Note: + - Use `update` method to input predictions and target labels. + - Use `compute` method to calculate the maximum F1 score after all + updates. + - Use `reset` method to clear the current state and prepare for a new + set of calculations. + """ + + full_state_update: bool = False + + def __init__(self, **kwargs) -> None: + super().__init__(**kwargs) + + self.precision_recall_curve = BinaryPrecisionRecallCurve() + + self.threshold: torch.Tensor + + def update(self, preds: torch.Tensor, target: torch.Tensor, *args, **kwargs) -> None: + """Update the precision-recall curve metric.""" + del args, kwargs # These variables are not used. + + self.precision_recall_curve.update(preds, target) + + def compute(self) -> torch.Tensor: + """Compute the value of the optimal F1 score. + + Compute the F1 scores while varying the threshold. Store the optimal + threshold as attribute and return the maximum value of the F1 score. + + Returns: + Value of the F1 score at the optimal threshold. + """ + precision: torch.Tensor + recall: torch.Tensor + thresholds: torch.Tensor + + precision, recall, thresholds = self.precision_recall_curve.compute() + f1_score = (2 * precision * recall) / (precision + recall + 1e-10) + self.threshold = thresholds[torch.argmax(f1_score)] + return torch.max(f1_score) + + def reset(self) -> None: + """Reset the metric.""" + self.precision_recall_curve.reset() diff --git a/anomalib/metrics/f1_score.py b/anomalib/metrics/f1_score.py new file mode 100644 index 0000000000000000000000000000000000000000..0477e8306d328454d9733679a7b77e37c30a15aa --- /dev/null +++ b/anomalib/metrics/f1_score.py @@ -0,0 +1,37 @@ +"""F1 Score metric. + +This is added for convenience. +""" + +# Copyright (C) 2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + + +import logging +from typing import Any, Literal + +from torchmetrics.classification import BinaryF1Score + +logger = logging.getLogger(__name__) + + +class F1Score(BinaryF1Score): + """This is a wrapper around torchmetrics' BinaryF1Score. + + The idea behind this is to retain the current configuration otherwise the one from + torchmetrics requires ``task`` as a parameter. + """ + + def __init__( + self, + threshold: float = 0.5, + multidim_average: Literal["global"] | Literal["samplewise"] = "global", + ignore_index: int | None = None, + validate_args: bool = True, + **kwargs: Any, # noqa: ANN401 + ) -> None: + super().__init__(threshold, multidim_average, ignore_index, validate_args, **kwargs) + logger.warning( + "F1Score class exists for backwards compatibility. It will be removed in v1.1." + " Please use BinaryF1Score from torchmetrics instead", + ) diff --git a/anomalib/metrics/min_max.py b/anomalib/metrics/min_max.py new file mode 100644 index 0000000000000000000000000000000000000000..77a76afaaa546ce6663492f27baa3c874ad5842c --- /dev/null +++ b/anomalib/metrics/min_max.py @@ -0,0 +1,56 @@ +"""Module that tracks the min and max values of the observations in each batch.""" + +# Copyright (C) 2022-2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + + +import torch +from torchmetrics import Metric + + +class MinMax(Metric): + """Track the min and max values of the observations in each batch. + + Args: + full_state_update (bool, optional): Whether to update the state with the + new values. + Defaults to ``True``. + kwargs: Any keyword arguments. + + Examples: + >>> from anomalib.metrics import MinMax + >>> import torch + ... + >>> predictions = torch.tensor([0.0807, 0.6329, 0.0559, 0.9860, 0.3595]) + >>> minmax = MinMax() + >>> minmax(predictions) + (tensor(0.0559), tensor(0.9860)) + + It is possible to update the minmax values with a new tensor of predictions. + + >>> new_predictions = torch.tensor([0.3251, 0.3169, 0.3072, 0.6247, 0.9999]) + >>> minmax.update(new_predictions) + >>> minmax.compute() + (tensor(0.0559), tensor(0.9999)) + """ + + full_state_update: bool = True + + def __init__(self, **kwargs) -> None: + super().__init__(**kwargs) + self.add_state("min", torch.tensor(float("inf")), persistent=True) + self.add_state("max", torch.tensor(float("-inf")), persistent=True) + + self.min = torch.tensor(float("inf")) + self.max = torch.tensor(float("-inf")) + + def update(self, predictions: torch.Tensor, *args, **kwargs) -> None: + """Update the min and max values.""" + del args, kwargs # These variables are not used. + + self.max = torch.max(self.max, torch.max(predictions)) + self.min = torch.min(self.min, torch.min(predictions)) + + def compute(self) -> tuple[torch.Tensor, torch.Tensor]: + """Return min and max values.""" + return self.min, self.max diff --git a/anomalib/metrics/plotting_utils.py b/anomalib/metrics/plotting_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..d304b3f9f5e6e0cb80a3e17b010a9065659f3d92 --- /dev/null +++ b/anomalib/metrics/plotting_utils.py @@ -0,0 +1,82 @@ +"""Helper functions to generate ROC-style plots of various metrics.""" + +# Copyright (C) 2022-2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + + +import torch +from matplotlib import pyplot as plt +from matplotlib.axis import Axis +from matplotlib.figure import Figure + + +def plot_figure( + x_vals: torch.Tensor, + y_vals: torch.Tensor, + auc: torch.Tensor, + xlim: tuple[float, float], + ylim: tuple[float, float], + xlabel: str, + ylabel: str, + loc: str, + title: str, + sample_points: int = 1000, +) -> tuple[Figure, Axis]: + """Generate a simple, ROC-style plot, where x_vals is plotted against y_vals. + + Note that a subsampling is applied if > sample_points are present in x/y, as matplotlib plotting draws + every single plot which takes very long, especially for high-resolution segmentations. + + Args: + x_vals (torch.Tensor): x values to plot + y_vals (torch.Tensor): y values to plot + auc (torch.Tensor): normalized area under the curve spanned by x_vals, y_vals + xlim (tuple[float, float]): displayed range for x-axis + ylim (tuple[float, float]): displayed range for y-axis + xlabel (str): label of x axis + ylabel (str): label of y axis + loc (str): string-based legend location, for details see + https://matplotlib.org/stable/api/_as_gen/matplotlib.axes.Axes.legend.html + title (str): title of the plot + sample_points (int): number of sampling points to subsample x_vals/y_vals with + Defaults to ``1000``. + + Returns: + tuple[Figure, Axis]: Figure and the contained Axis + """ + fig, axis = plt.subplots() + + x_vals = x_vals.detach().cpu() + y_vals = y_vals.detach().cpu() + + if sample_points < x_vals.size(0): + possible_idx = range(x_vals.size(0)) + interval = len(possible_idx) // sample_points + + idx = [0] # make sure to start at first point + idx.extend(possible_idx[::interval]) + idx.append(possible_idx[-1]) # also include last point + + idx = torch.tensor( + idx, + device=x_vals.device, + ) + x_vals = torch.index_select(x_vals, 0, idx) + y_vals = torch.index_select(y_vals, 0, idx) + + axis.plot( + x_vals, + y_vals, + color="darkorange", + figure=fig, + lw=2, + label=f"AUC: {auc.detach().cpu():0.2f}", + ) + + axis.set_xlim(xlim) + axis.set_ylim(ylim) + axis.set_xlabel(xlabel) + axis.set_ylabel(ylabel) + axis.legend(loc=loc) + axis.set_title(title) + return fig, axis diff --git a/anomalib/metrics/precision_recall_curve.py b/anomalib/metrics/precision_recall_curve.py new file mode 100644 index 0000000000000000000000000000000000000000..10b6f7aa274b124c79b4d7de326fa05c346bc60e --- /dev/null +++ b/anomalib/metrics/precision_recall_curve.py @@ -0,0 +1,60 @@ +"""Custom PrecisionRecallCurve. + +The one in torchmetrics adds a sigmoid operation on top of the thresholds. +See: https://github.com/Lightning-AI/torchmetrics/issues/1526 +""" + +# Copyright (C) 2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + + +from torch import Tensor +from torchmetrics.classification import BinaryPrecisionRecallCurve as _BinaryPrecisionRecallCurve +from torchmetrics.functional.classification.precision_recall_curve import ( + _adjust_threshold_arg, + _binary_precision_recall_curve_update, +) + + +class BinaryPrecisionRecallCurve(_BinaryPrecisionRecallCurve): + """Binary precision-recall curve with without threshold prediction normalization.""" + + @staticmethod + def _binary_precision_recall_curve_format( + preds: Tensor, + target: Tensor, + thresholds: int | list[float] | Tensor | None = None, + ignore_index: int | None = None, + ) -> tuple[Tensor, Tensor, Tensor | None]: + """Similar to torchmetrics' ``_binary_precision_recall_curve_format`` except it does not apply sigmoid.""" + preds = preds.flatten() + target = target.flatten() + if ignore_index is not None: + idx = target != ignore_index + preds = preds[idx] + target = target[idx] + + thresholds = _adjust_threshold_arg(thresholds, preds.device) + return preds, target, thresholds + + def update(self, preds: Tensor, target: Tensor) -> None: + """Update metric state with new predictions and targets. + + Unlike the base class, this accepts raw predictions and targets. + + Args: + preds (Tensor): Predicted probabilities + target (Tensor): Ground truth labels + """ + preds, target, _ = BinaryPrecisionRecallCurve._binary_precision_recall_curve_format( + preds, + target, + self.thresholds, + self.ignore_index, + ) + state = _binary_precision_recall_curve_update(preds, target, self.thresholds) + if isinstance(state, Tensor): + self.confmat += state + else: + self.preds.append(state[0]) + self.target.append(state[1]) diff --git a/anomalib/metrics/pro.py b/anomalib/metrics/pro.py new file mode 100644 index 0000000000000000000000000000000000000000..17743374081ccd68ecf3f0fbee2f7aa8cdea97e5 --- /dev/null +++ b/anomalib/metrics/pro.py @@ -0,0 +1,126 @@ +"""Implementation of PRO metric based on TorchMetrics.""" + +# Copyright (C) 2022-2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + + +import torch +from torchmetrics import Metric +from torchmetrics.functional import recall +from torchmetrics.utilities.data import dim_zero_cat + +from anomalib.utils.cv import connected_components_cpu, connected_components_gpu + + +class PRO(Metric): + """Per-Region Overlap (PRO) Score. + + This metric computes the macro average of the per-region overlap between the + predicted anomaly masks and the ground truth masks. + + Args: + threshold (float): Threshold used to binarize the predictions. + Defaults to ``0.5``. + kwargs: Additional arguments to the TorchMetrics base class. + + Example: + Import the metric from the package: + + >>> import torch + >>> from anomalib.metrics import PRO + + Create random ``preds`` and ``labels`` tensors: + + >>> labels = torch.randint(low=0, high=2, size=(1, 10, 5), dtype=torch.float32) + >>> preds = torch.rand_like(labels) + + Compute the PRO score for labels and preds: + + >>> pro = PRO(threshold=0.5) + >>> pro.update(preds, labels) + >>> pro.compute() + tensor(0.5433) + + .. note:: + Note that the example above shows random predictions and labels. + Therefore, the PRO score above may not be reproducible. + + """ + + target: list[torch.Tensor] + preds: list[torch.Tensor] + + def __init__(self, threshold: float = 0.5, **kwargs) -> None: + super().__init__(**kwargs) + self.threshold = threshold + + self.add_state("preds", default=[], dist_reduce_fx="cat") + self.add_state("target", default=[], dist_reduce_fx="cat") + + def update(self, predictions: torch.Tensor, targets: torch.Tensor) -> None: + """Compute the PRO score for the current batch. + + Args: + predictions (torch.Tensor): Predicted anomaly masks (Bx1xHxW) + targets (torch.Tensor): Ground truth anomaly masks (Bx1xHxW) + + Example: + To update the metric state for the current batch, use the ``update`` method: + + >>> pro.update(preds, labels) + """ + self.target.append(targets) + self.preds.append(predictions) + + def compute(self) -> torch.Tensor: + """Compute the macro average of the PRO score across all regions in all batches. + + Example: + To compute the metric based on the state accumulated from multiple batches, use the ``compute`` method: + + >>> pro.compute() + tensor(0.5433) + """ + target = dim_zero_cat(self.target) + preds = dim_zero_cat(self.preds) + + target = target.unsqueeze(1).type(torch.float) # kornia expects N1HW and FloatTensor format + comps = connected_components_gpu(target) if target.is_cuda else connected_components_cpu(target) + return pro_score(preds, comps, threshold=self.threshold) + + +def pro_score(predictions: torch.Tensor, comps: torch.Tensor, threshold: float = 0.5) -> torch.Tensor: + """Calculate the PRO score for a batch of predictions. + + Args: + predictions (torch.Tensor): Predicted anomaly masks (Bx1xHxW) + comps: (torch.Tensor): Labeled connected components (BxHxW). The components should be labeled from 0 to N + threshold (float): When predictions are passed as float, the threshold is used to binarize the predictions. + + Returns: + torch.Tensor: Scalar value representing the average PRO score for the input batch. + """ + if predictions.dtype == torch.float: + predictions = predictions > threshold + + n_comps = len(comps.unique()) + + preds = comps.clone() + # match the shapes in case one of the tensors is N1HW + preds = preds.reshape(predictions.shape) + preds[~predictions] = 0 + if n_comps == 1: # only background + return torch.Tensor([1.0]) + + # Even though ignore_index is set to 0, the final average computed with "macro" + # takes the entire length of the tensor into account. That's why we need to manually + # subtract 1 from the number of components after taking the sum + recall_tensor = recall( + preds.flatten(), + comps.flatten(), + task="multiclass", + num_classes=n_comps, + average=None, + ignore_index=0, + ) + return recall_tensor.sum() / (n_comps - 1) diff --git a/anomalib/metrics/spro.py b/anomalib/metrics/spro.py new file mode 100644 index 0000000000000000000000000000000000000000..c59091ee5f2bda550076edf4294f9a8bf01b5da8 --- /dev/null +++ b/anomalib/metrics/spro.py @@ -0,0 +1,215 @@ +"""Implementation of SPRO metric based on TorchMetrics.""" + +# Copyright (C) 2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +import json +import logging +from pathlib import Path +from typing import Any + +import torch +from torchmetrics import Metric + +from anomalib.data.utils import validate_path + +logger = logging.getLogger(__name__) + + +class SPRO(Metric): + """Saturated Per-Region Overlap (SPRO) Score. + + This metric computes the macro average of the saturated per-region overlap between the + predicted anomaly masks and the ground truth masks. + + Args: + threshold (float): Threshold used to binarize the predictions. + Defaults to ``0.5``. + saturation_config (str | Path): Path to the saturation configuration file. + Defaults: ``None`` (which the score is equivalent to PRO metric, but with the 'region' are + separated by mask files. + kwargs: Additional arguments to the TorchMetrics base class. + + Example: + Import the metric from the package: + + >>> import torch + >>> from anomalib.metrics import SPRO + + Create random ``preds`` and ``labels`` tensors: + + >>> labels = torch.randint(low=0, high=2, size=(2, 10, 5), dtype=torch.float32) + >>> labels = [labels] + >>> preds = torch.rand_like(labels[0][:1]) + + Compute the SPRO score for labels and preds: + + >>> spro = SPRO(threshold=0.5) + >>> spro.update(preds, labels) + >>> spro.compute() + tensor(0.6333) + + .. note:: + Note that the example above shows random predictions and labels. + Therefore, the SPRO score above may not be reproducible. + + """ + + def __init__(self, threshold: float = 0.5, saturation_config: str | Path | None = None, **kwargs) -> None: + super().__init__(**kwargs) + self.threshold = threshold + self.saturation_config = load_saturation_config(saturation_config) if saturation_config is not None else None + if self.saturation_config is None: + logger.warning( + "The saturation_config attribute is empty, the threshold is set to the defect area." + "This is equivalent to PRO metric but with the 'region' are separated by mask files", + ) + self.add_state("score", default=torch.tensor(0.0), dist_reduce_fx="sum") + self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum") + + def update(self, predictions: torch.Tensor, masks: list[torch.Tensor]) -> None: + """Compute the SPRO score for the current batch. + + Args: + predictions (torch.Tensor): Predicted anomaly masks. + masks (list[torch.Tensor]): Ground truth anomaly masks with original height and width. Each element in the + list is a tensor list of masks for the corresponding image. + + Example: + To update the metric state for the current batch, use the ``update`` method: + + >>> spro.update(preds, labels) + """ + score, total = spro_score( + predictions=predictions, + targets=masks, + threshold=self.threshold, + saturation_config=self.saturation_config, + ) + self.score += score + self.total += total + + def compute(self) -> torch.Tensor: + """Compute the macro average of the SPRO score across all masks in all batches. + + Example: + To compute the metric based on the state accumulated from multiple batches, use the ``compute`` method: + + >>> spro.compute() + tensor(0.5433) + """ + if self.total == 0: # only background/normal images + return torch.Tensor([1.0]) + return self.score / self.total + + +def spro_score( + predictions: torch.Tensor, + targets: list[torch.Tensor], + threshold: float = 0.5, + saturation_config: dict | None = None, +) -> torch.Tensor: + """Calculate the SPRO score for a batch of predictions. + + Args: + predictions (torch.Tensor): Predicted anomaly masks. + targets: (list[torch.Tensor]): Ground truth anomaly masks with original height and width. Each element in the + list is a tensor list of masks for the corresponding image. + threshold (float): When predictions are passed as float, the threshold is used to binarize the predictions. + Defaults: ``0.5``. + saturation_config (dict): Saturations configuration for each label (pixel value) as the keys. + Defaults: ``None`` (which the score is equivalent to PRO metric, but with the 'region' are + separated by mask files. + + Returns: + torch.Tensor: Scalar value representing the average SPRO score for the input batch. + """ + # Add batch dim if not exist + if len(predictions.shape) == 2: + predictions = predictions.unsqueeze(0) + + # Resize the prediction to have the same size as the target mask + predictions = torch.nn.functional.interpolate(predictions.unsqueeze(1), targets[0].shape[-2:]) + + # Apply threshold to binary predictions + if predictions.dtype == torch.float: + predictions = predictions > threshold + + score = torch.tensor(0.0) + total = 0 + # Iterate for each image in the batch + for i, target in enumerate(targets): + # Iterate for each ground-truth mask per image + for mask in target: + label = torch.max(mask) + if label == 0: # Skip if only normal/background + continue + # Calculate true positive + target_per_label = mask == label + true_pos = torch.sum(predictions[i] & target_per_label) + + # Calculate the anomalous area of the ground-truth + defect_area = torch.sum(target_per_label) + + if saturation_config is not None: + # Adjust saturation threshold based on configuration + saturation_per_label = saturation_config[label.int().item()] + saturation_threshold = saturation_per_label["saturation_threshold"] + + if saturation_per_label["relative_saturation"]: + saturation_threshold *= defect_area + + # Check if threshold is larger than defect area + if saturation_threshold > defect_area: + warning_msg = ( + f"Saturation threshold for label {label.int().item()} is larger than defect area. " + "Setting it to defect area." + ) + logger.warning(warning_msg) + saturation_threshold = defect_area + else: + # Handle case when saturation_config is empty + saturation_threshold = defect_area + + # Update score with minimum of true_pos/saturation_threshold and 1.0 + score += torch.minimum(true_pos / saturation_threshold, torch.tensor(1.0)) + total += 1 + return score, total + + +def load_saturation_config(config_path: str | Path) -> dict[int, Any] | None: + """Load saturation configurations from a JSON file. + + Args: + config_path (str | Path): Path to the saturation configuration file. + + Returns: + Dict | None: A dictionary with pixel values as keys and the corresponding configurations as values. + Return None if the config file is not found. + + Example JSON format in the config file of MVTec LOCO dataset: + [ + { + "defect_name": "1_additional_pushpin", + "pixel_value": 255, + "saturation_threshold": 6300, + "relative_saturation": false + }, + { + "defect_name": "2_additional_pushpins", + "pixel_value": 254, + "saturation_threshold": 12600, + "relative_saturation": false + }, + ... + ] + """ + try: + config_path = validate_path(config_path) + with Path.open(config_path) as file: + configs = json.load(file) + # Create a dictionary with pixel values as keys + return {conf["pixel_value"]: conf for conf in configs} + except FileNotFoundError: + logger.warning("The saturation config file %s does not exist. Returning None.", config_path) + return None diff --git a/anomalib/metrics/threshold/__init__.py b/anomalib/metrics/threshold/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..5cfa996cac6daec6a62666130eea308a6689a408 --- /dev/null +++ b/anomalib/metrics/threshold/__init__.py @@ -0,0 +1,10 @@ +"""Thresholding metrics.""" + +# Copyright (C) 2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +from .base import BaseThreshold +from .f1_adaptive_threshold import F1AdaptiveThreshold +from .manual_threshold import ManualThreshold + +__all__ = ["BaseThreshold", "F1AdaptiveThreshold", "ManualThreshold"] diff --git a/anomalib/metrics/threshold/base.py b/anomalib/metrics/threshold/base.py new file mode 100644 index 0000000000000000000000000000000000000000..6bee389b3caf8422b6026103c8a873ecea79bbad --- /dev/null +++ b/anomalib/metrics/threshold/base.py @@ -0,0 +1,35 @@ +"""Base class for thresholding metrics.""" + +# Copyright (C) 2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +from abc import ABC + +import torch +from torchmetrics import Metric + + +class BaseThreshold(Metric, ABC): + """Base class for thresholding metrics.""" + + def __init__(self, **kwargs) -> None: + super().__init__(**kwargs) + + def compute(self) -> torch.Tensor: + """Compute the threshold. + + Returns: + Value of the optimal threshold. + """ + msg = "Subclass of BaseAnomalyScoreThreshold must implement the compute method" + raise NotImplementedError(msg) + + def update(self, *args, **kwargs) -> None: # noqa: ARG002 + """Update the metric state. + + Args: + *args: Any positional arguments. + **kwargs: Any keyword arguments. + """ + msg = "Subclass of BaseAnomalyScoreThreshold must implement the update method" + raise NotImplementedError(msg) diff --git a/anomalib/metrics/threshold/f1_adaptive_threshold.py b/anomalib/metrics/threshold/f1_adaptive_threshold.py new file mode 100644 index 0000000000000000000000000000000000000000..07b962c81898ce3d36878f893b09c6c0355e5ffe --- /dev/null +++ b/anomalib/metrics/threshold/f1_adaptive_threshold.py @@ -0,0 +1,89 @@ +"""Implementation of F1AdaptiveThreshold based on TorchMetrics.""" + +# Copyright (C) 2023-2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + + +import logging + +import torch + +from anomalib.metrics.precision_recall_curve import BinaryPrecisionRecallCurve + +from .base import BaseThreshold + +logger = logging.getLogger(__name__) + + +class F1AdaptiveThreshold(BinaryPrecisionRecallCurve, BaseThreshold): + """Anomaly Score Threshold. + + This class computes/stores the threshold that determines the anomalous label + given anomaly scores. It initially computes the adaptive threshold to find + the optimal f1_score and stores the computed adaptive threshold value. + + Args: + default_value: Default value of the threshold. + Defaults to ``0.5``. + + Examples: + To find the best threshold that maximizes the F1 score, we could run the + following: + + >>> from anomalib.metrics import F1AdaptiveThreshold + >>> import torch + ... + >>> labels = torch.tensor([0, 0, 0, 1, 1]) + >>> preds = torch.tensor([2.3, 1.6, 2.6, 7.9, 3.3]) + ... + >>> adaptive_threshold = F1AdaptiveThreshold(default_value=0.5) + >>> threshold = adaptive_threshold(preds, labels) + >>> threshold + tensor(3.3000) + """ + + def __init__(self, default_value: float = 0.5, **kwargs) -> None: + super().__init__(**kwargs) + + self.add_state("value", default=torch.tensor(default_value), persistent=True) + self.value = torch.tensor(default_value) + + def compute(self) -> torch.Tensor: + """Compute the threshold that yields the optimal F1 score. + + Compute the F1 scores while varying the threshold. Store the optimal + threshold as attribute and return the maximum value of the F1 score. + + Returns: + Value of the F1 score at the optimal threshold. + """ + precision: torch.Tensor + recall: torch.Tensor + thresholds: torch.Tensor + + if not any(1 in batch for batch in self.target): + msg = ( + "The validation set does not contain any anomalous images. As a result, the adaptive threshold will " + "take the value of the highest anomaly score observed in the normal validation images, which may lead " + "to poor predictions. For a more reliable adaptive threshold computation, please add some anomalous " + "images to the validation set." + ) + logging.warning(msg) + + precision, recall, thresholds = super().compute() + f1_score = (2 * precision * recall) / (precision + recall + 1e-10) + if thresholds.dim() == 0: + # special case where recall is 1.0 even for the highest threshold. + # In this case 'thresholds' will be scalar. + self.value = thresholds + else: + self.value = thresholds[torch.argmax(f1_score)] + return self.value + + def __repr__(self) -> str: + """Return threshold value within the string representation. + + Returns: + str: String representation of the class. + """ + return f"{super().__repr__()} (value={self.value:.2f})" diff --git a/anomalib/metrics/threshold/manual_threshold.py b/anomalib/metrics/threshold/manual_threshold.py new file mode 100644 index 0000000000000000000000000000000000000000..6dc19d23044f2a2c14f28bd8fdec56089131687b --- /dev/null +++ b/anomalib/metrics/threshold/manual_threshold.py @@ -0,0 +1,67 @@ +"""Container to hold manual threshold values for image and pixel metrics.""" + +# Copyright (C) 2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + + +import torch + +from .base import BaseThreshold + + +class ManualThreshold(BaseThreshold): + """Initialize Manual Threshold. + + Args: + default_value (float, optional): Default threshold value. + Defaults to ``0.5``. + kwargs: Any keyword arguments. + + Examples: + >>> from anomalib.metrics import ManualThreshold + >>> import torch + ... + >>> manual_threshold = ManualThreshold(default_value=0.5) + ... + >>> labels = torch.randint(low=0, high=2, size=(5,)) + >>> preds = torch.rand(5) + ... + >>> threshold = manual_threshold(preds, labels) + >>> threshold + tensor(0.5000, dtype=torch.float64) + + As the threshold is manually set, the threshold value is the same as the + ``default_value``. + + >>> labels = torch.randint(low=0, high=2, size=(5,)) + >>> preds = torch.rand(5) + >>> threshold = manual_threshold(preds2, labels2) + >>> threshold + tensor(0.5000, dtype=torch.float64) + + The threshold value remains the same even if the inputs change. + """ + + def __init__(self, default_value: float = 0.5, **kwargs) -> None: + super().__init__(**kwargs) + self.add_state("value", default=torch.tensor(default_value, dtype=torch.float64), persistent=True) + self.value = torch.tensor(default_value, dtype=torch.float64) + + def compute(self) -> torch.Tensor: + """Compute the threshold. + + In case of manual thresholding, the threshold is already set and does not need to be computed. + + Returns: + torch.Tensor: Value of the optimal threshold. + """ + return self.value + + def update(self, *args, **kwargs) -> None: + """Do nothing. + + Args: + *args: Any positional arguments. + **kwargs: Any keyword arguments. + """ + del args, kwargs # Unused arguments. diff --git a/anomalib/models/__init__.py b/anomalib/models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..722cd1dfe525f55312105414f0fe2dbd196c8032 --- /dev/null +++ b/anomalib/models/__init__.py @@ -0,0 +1,178 @@ +"""Load Anomaly Model.""" + +# Copyright (C) 2022-2023 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + + +import logging +from importlib import import_module + +from jsonargparse import Namespace +from omegaconf import DictConfig, OmegaConf + +from anomalib.models.components import AnomalyModule +from anomalib.utils.path import convert_to_snake_case + +from .image import ( + Cfa, + Cflow, + Csflow, + Dfkde, + Dfm, + Draem, + Dsr, + EfficientAd, + Fastflow, + Ganomaly, + Padim, + Patchcore, + ReverseDistillation, + Rkde, + Stfpm, + Uflow, + WinClip, +) +from .video import AiVad + + +class UnknownModelError(ModuleNotFoundError): + pass + + +__all__ = [ + "Cfa", + "Cflow", + "Csflow", + "Dfkde", + "Dfm", + "Draem", + "Dsr", + "EfficientAd", + "Fastflow", + "Ganomaly", + "Padim", + "Patchcore", + "ReverseDistillation", + "Rkde", + "Stfpm", + "Uflow", + "AiVad", + "WinClip", +] + +logger = logging.getLogger(__name__) + + +def convert_snake_to_pascal_case(snake_case: str) -> str: + """Convert snake_case to PascalCase. + + Args: + snake_case (str): Input string in snake_case + + Returns: + str: Output string in PascalCase + + Examples: + >>> _convert_snake_to_pascal_case("efficient_ad") + EfficientAd + + >>> _convert_snake_to_pascal_case("patchcore") + Patchcore + """ + return "".join(word.capitalize() for word in snake_case.split("_")) + + +def get_available_models() -> set[str]: + """Get set of available models. + + Returns: + set[str]: List of available models. + + Example: + >>> get_available_models() + ['ai_vad', 'cfa', 'cflow', 'csflow', 'dfkde', 'dfm', 'draem', 'efficient_ad', 'fastflow', ...] + """ + return {convert_to_snake_case(cls.__name__) for cls in AnomalyModule.__subclasses__()} + + +def _get_model_class_by_name(name: str) -> type[AnomalyModule]: + """Retrieves an anomaly model based on its name. + + Args: + name (str): The name of the model to retrieve. The name is case insensitive. + + Raises: + UnknownModelError: If the model is not found. + + Returns: + type[AnomalyModule]: Anomaly Model + """ + logger.info("Loading the model.") + model_class: type[AnomalyModule] | None = None + + name = convert_snake_to_pascal_case(name).lower() + for model in AnomalyModule.__subclasses__(): + if name == model.__name__.lower(): + model_class = model + if model_class is None: + logger.exception(f"Could not find the model {name}. Available models are {get_available_models()}") + raise UnknownModelError + + return model_class + + +def get_model(model: DictConfig | str | dict | Namespace, *args, **kwdargs) -> AnomalyModule: + """Get Anomaly Model. + + Args: + model (DictConfig | str): Can either be a configuration or a string. + *args: Variable length argument list for model init. + **kwdargs: Arbitrary keyword arguments for model init. + + Examples: + >>> get_model("Padim") + >>> get_model("efficient_ad") + >>> get_model("Patchcore", input_size=(100, 100)) + >>> get_model({"class_path": "Padim"}) + >>> get_model({"class_path": "Patchcore"}, input_size=(100, 100)) + >>> get_model({"class_path": "Padim", "init_args": {"input_size": (100, 100)}}) + >>> get_model({"class_path": "anomalib.models.Padim", "init_args": {"input_size": (100, 100)}}}) + + Raises: + TypeError: If unsupported type is passed. + + Returns: + AnomalyModule: Anomaly Model + """ + _model: AnomalyModule + if isinstance(model, str): + _model_class = _get_model_class_by_name(model) + _model = _model_class(*args, **kwdargs) + elif isinstance(model, DictConfig | Namespace | dict): + if isinstance(model, dict): + model = OmegaConf.create(model) + try: + if len(model.class_path.split(".")) > 1: + module = import_module(".".join(model.class_path.split(".")[:-1])) + else: + module = import_module("anomalib.models") + except ModuleNotFoundError as exception: + logger.exception( + f"Could not find the module {model.class_path}. Available models are {get_available_models()}", + ) + raise UnknownModelError from exception + try: + model_class = getattr(module, model.class_path.split(".")[-1]) + init_args = model.get("init_args", {}) + if len(kwdargs) > 0: + init_args.update(kwdargs) + _model = model_class(*args, **init_args) + except AttributeError as exception: + logger.exception( + f"Could not find the model {model.class_path}. Available models are {get_available_models()}", + ) + raise UnknownModelError from exception + else: + logger.error(f"Unsupported type {type(model)} for model configuration.") + raise TypeError + return _model diff --git a/anomalib/models/components/__init__.py b/anomalib/models/components/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..b37daafefe47e33a3b3a0c8238958cbfd8239d4f --- /dev/null +++ b/anomalib/models/components/__init__.py @@ -0,0 +1,26 @@ +"""Components used within the models.""" + +# Copyright (C) 2022-2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +from .base import AnomalyModule, BufferListMixin, DynamicBufferMixin, MemoryBankMixin +from .dimensionality_reduction import PCA, SparseRandomProjection +from .feature_extractors import TimmFeatureExtractor, TorchFXFeatureExtractor +from .filters import GaussianBlur2d +from .sampling import KCenterGreedy +from .stats import GaussianKDE, MultiVariateGaussian + +__all__ = [ + "AnomalyModule", + "BufferListMixin", + "DynamicBufferMixin", + "MemoryBankMixin", + "GaussianKDE", + "GaussianBlur2d", + "KCenterGreedy", + "MultiVariateGaussian", + "PCA", + "SparseRandomProjection", + "TimmFeatureExtractor", + "TorchFXFeatureExtractor", +] diff --git a/anomalib/models/components/base/__init__.py b/anomalib/models/components/base/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..b535c910cbf3cc37302b7b3919e1e9eff5c09eb8 --- /dev/null +++ b/anomalib/models/components/base/__init__.py @@ -0,0 +1,11 @@ +"""Base classes for all anomaly components.""" + +# Copyright (C) 2022-2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +from .anomaly_module import AnomalyModule +from .buffer_list import BufferListMixin +from .dynamic_buffer import DynamicBufferMixin +from .memory_bank_module import MemoryBankMixin + +__all__ = ["AnomalyModule", "BufferListMixin", "DynamicBufferMixin", "MemoryBankMixin"] diff --git a/anomalib/models/components/base/anomaly_module.py b/anomalib/models/components/base/anomaly_module.py new file mode 100644 index 0000000000000000000000000000000000000000..303430074856607fa33ee5ec497188d3d32ecb02 --- /dev/null +++ b/anomalib/models/components/base/anomaly_module.py @@ -0,0 +1,278 @@ +"""Base Anomaly Module for Training Task.""" + +# Copyright (C) 2022-2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +import importlib +import logging +from abc import ABC, abstractmethod +from collections import OrderedDict +from typing import TYPE_CHECKING, Any + +import lightning.pytorch as pl +import torch +from lightning.pytorch.trainer.states import TrainerFn +from lightning.pytorch.utilities.types import STEP_OUTPUT +from torch import nn +from torchvision.transforms.v2 import Compose, Normalize, Resize, Transform + +from anomalib import LearningType +from anomalib.metrics import AnomalibMetricCollection +from anomalib.metrics.threshold import BaseThreshold + +from .export_mixin import ExportMixin + +if TYPE_CHECKING: + from lightning.pytorch.callbacks import Callback + from torchmetrics import Metric + + +logger = logging.getLogger(__name__) + + +class AnomalyModule(ExportMixin, pl.LightningModule, ABC): + """AnomalyModule to train, validate, predict and test images. + + Acts as a base class for all the Anomaly Modules in the library. + """ + + def __init__(self) -> None: + super().__init__() + logger.info("Initializing %s model.", self.__class__.__name__) + + self.save_hyperparameters() + self.model: nn.Module + self.loss: nn.Module + self.callbacks: list[Callback] + + self.image_threshold: BaseThreshold + self.pixel_threshold: BaseThreshold + + self.normalization_metrics: Metric + + self.image_metrics: AnomalibMetricCollection + self.pixel_metrics: AnomalibMetricCollection + self.semantic_pixel_metrics: AnomalibMetricCollection + + self._transform: Transform | None = None + self._input_size: tuple[int, int] | None = None + + self._is_setup = False # flag to track if setup has been called from the trainer + + @property + def name(self) -> str: + """Name of the model.""" + return self.__class__.__name__ + + def setup(self, stage: str | None = None) -> None: + """Calls the _setup method to build the model if the model is not already built.""" + if getattr(self, "model", None) is None or not self._is_setup: + self._setup() + if isinstance(stage, TrainerFn): + # only set the flag if the stage is a TrainerFn, which means the setup has been called from a trainer + self._is_setup = True + + def _setup(self) -> None: + """The _setup method is used to build the torch model dynamically or adjust something about them. + + The model implementer may override this method to build the model. This is useful when the model cannot be set + in the `__init__` method because it requires some information or data that is not available at the time of + initialization. + """ + + def forward(self, batch: dict[str, str | torch.Tensor], *args, **kwargs) -> Any: # noqa: ANN401 + """Perform the forward-pass by passing input tensor to the module. + + Args: + batch (dict[str, str | torch.Tensor]): Input batch. + *args: Arguments. + **kwargs: Keyword arguments. + + Returns: + Tensor: Output tensor from the model. + """ + del args, kwargs # These variables are not used. + + return self.model(batch) + + def validation_step(self, batch: dict[str, str | torch.Tensor], *args, **kwargs) -> STEP_OUTPUT: + """To be implemented in the subclasses.""" + raise NotImplementedError + + def predict_step( + self, + batch: dict[str, str | torch.Tensor], + batch_idx: int, + dataloader_idx: int = 0, + ) -> STEP_OUTPUT: + """Step function called during :meth:`~lightning.pytorch.trainer.Trainer.predict`. + + By default, it calls :meth:`~lightning.pytorch.core.lightning.LightningModule.forward`. + Override to add any processing logic. + + Args: + batch (Any): Current batch + batch_idx (int): Index of current batch + dataloader_idx (int): Index of the current dataloader + + Return: + Predicted output + """ + del dataloader_idx # These variables are not used. + + return self.validation_step(batch, batch_idx) + + def test_step(self, batch: dict[str, str | torch.Tensor], batch_idx: int, *args, **kwargs) -> STEP_OUTPUT: + """Calls validation_step for anomaly map/score calculation. + + Args: + batch (dict[str, str | torch.Tensor]): Input batch + batch_idx (int): Batch index + args: Arguments. + kwargs: Keyword arguments. + + Returns: + Dictionary containing images, features, true labels and masks. + These are required in `validation_epoch_end` for feature concatenation. + """ + del args, kwargs # These variables are not used. + + return self.predict_step(batch, batch_idx) + + @property + @abstractmethod + def trainer_arguments(self) -> dict[str, Any]: + """Arguments used to override the trainer parameters so as to train the model correctly.""" + raise NotImplementedError + + def _save_to_state_dict(self, destination: OrderedDict, prefix: str, keep_vars: bool) -> None: + if hasattr(self, "image_threshold"): + destination[ + "image_threshold_class" + ] = f"{self.image_threshold.__class__.__module__}.{self.image_threshold.__class__.__name__}" + if hasattr(self, "pixel_threshold"): + destination[ + "pixel_threshold_class" + ] = f"{self.pixel_threshold.__class__.__module__}.{self.pixel_threshold.__class__.__name__}" + if hasattr(self, "normalization_metrics"): + normalization_class = self.normalization_metrics.__class__ + destination["normalization_class"] = f"{normalization_class.__module__}.{normalization_class.__name__}" + + return super()._save_to_state_dict(destination, prefix, keep_vars) + + def load_state_dict(self, state_dict: OrderedDict[str, Any], strict: bool = True) -> Any: # noqa: ANN401 + """Initialize auxiliary object.""" + if "image_threshold_class" in state_dict: + self.image_threshold = self._get_instance(state_dict, "image_threshold_class") + if "pixel_threshold_class" in state_dict: + self.pixel_threshold = self._get_instance(state_dict, "pixel_threshold_class") + if "normalization_class" in state_dict: + self.normalization_metrics = self._get_instance(state_dict, "normalization_class") + # Used to load metrics if there is any related data in state_dict + self._load_metrics(state_dict) + + return super().load_state_dict(state_dict, strict) + + def _load_metrics(self, state_dict: OrderedDict[str, torch.Tensor]) -> None: + """Load metrics from saved checkpoint.""" + self._add_metrics("pixel", state_dict) + self._add_metrics("image", state_dict) + + def _add_metrics(self, name: str, state_dict: OrderedDict[str, torch.Tensor]) -> None: + """Sets the pixel/image metrics. + + Args: + name (str): is it pixel or image. + state_dict (OrderedDict[str, Tensor]): state dict of the model. + """ + metric_keys = [key for key in state_dict if key.startswith(f"{name}_metrics")] + if any(metric_keys): + if not hasattr(self, f"{name}_metrics"): + setattr(self, f"{name}_metrics", AnomalibMetricCollection([], prefix=f"{name}_")) + metrics = getattr(self, f"{name}_metrics") + for key in metric_keys: + class_name = key.split(".")[1] + try: + metrics_module = importlib.import_module("anomalib.metrics") + metrics_cls = getattr(metrics_module, class_name) + except (ImportError, AttributeError) as exception: + msg = f"Class {class_name} not found in module anomalib.metrics" + raise ImportError(msg) from exception + logger.info("Loading %s metrics from state dict", class_name) + metrics.add_metrics(metrics_cls()) + + def _get_instance(self, state_dict: OrderedDict[str, Any], dict_key: str) -> BaseThreshold: + """Get the threshold class from the ``state_dict``.""" + class_path = state_dict.pop(dict_key) + module = importlib.import_module(".".join(class_path.split(".")[:-1])) + return getattr(module, class_path.split(".")[-1])() + + @property + @abstractmethod + def learning_type(self) -> LearningType: + """Learning type of the model.""" + raise NotImplementedError + + @property + def transform(self) -> Transform: + """Retrieve the model-specific transform. + + If a transform has been set using `set_transform`, it will be returned. Otherwise, we will use the + model-specific default transform, conditioned on the input size. + """ + return self._transform + + def set_transform(self, transform: Transform) -> None: + """Update the transform linked to the model instance.""" + self._transform = transform + + def configure_transforms(self, image_size: tuple[int, int] | None = None) -> Transform: + """Default transforms. + + The default transform is resize to 256x256 and normalize to ImageNet stats. Individual models can override + this method to provide custom transforms. + """ + logger.warning( + "No implementation of `configure_transforms` was provided in the Lightning model. Using default " + "transforms from the base class. This may not be suitable for your use case. Please override " + "`configure_transforms` in your model.", + ) + image_size = image_size or (256, 256) + return Compose( + [ + Resize(image_size, antialias=True), + Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), + ], + ) + + @property + def input_size(self) -> tuple[int, int] | None: + """Return the effective input size of the model. + + The effective input size is the size of the input tensor after the transform has been applied. If the transform + is not set, or if the transform does not change the shape of the input tensor, this method will return None. + """ + transform = self.transform or self.configure_transforms() + if transform is None: + return None + dummy_input = torch.zeros(1, 3, 1, 1) + output_shape = transform(dummy_input).shape[-2:] + if output_shape == (1, 1): + return None + return output_shape[-2:] + + def on_save_checkpoint(self, checkpoint: dict[str, Any]) -> None: + """Called when saving the model to a checkpoint. + + Saves the transform to the checkpoint. + """ + checkpoint["transform"] = self.transform + + def on_load_checkpoint(self, checkpoint: dict[str, Any]) -> None: + """Called when loading the model from a checkpoint. + + Loads the transform from the checkpoint and calls setup to ensure that the torch model is built before loading + the state dict. + """ + self._transform = checkpoint["transform"] + self.setup("load_checkpoint") diff --git a/anomalib/models/components/base/buffer_list.py b/anomalib/models/components/base/buffer_list.py new file mode 100644 index 0000000000000000000000000000000000000000..9fcd9c1c31a2ae04ef2fce98518e44630d577f87 --- /dev/null +++ b/anomalib/models/components/base/buffer_list.py @@ -0,0 +1,108 @@ +"""Buffer List Mixin.""" + +# Copyright (C) 2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + + +import torch +from torch import nn + + +class BufferListMixin(nn.Module): + """Buffer List Mixin. + + This mixin is used to allow registering a list of tensors as buffers in a pytorch module. + + Example: + >>> class MyModule(BufferListMixin, nn.Module): + ... def __init__(self): + ... super().__init__() + ... tensor_list = [torch.ones(3) * i for i in range(3)] + ... self.register_buffer_list("my_buffer_list", tensor_list) + >>> module = MyModule() + >>> # The buffer list can be accessed as a regular attribute + >>> module.my_buffer_list + [ + tensor([0., 0., 0.]), + tensor([1., 1., 1.]), + tensor([2., 2., 2.]) + ] + >>> # We can update the buffer list at any time + >>> new_tensor_list = [torch.ones(3) * i + 10 for i in range(3)] + >>> module.register_buffer_list("my_buffer_list", new_tensor_list) + >>> module.my_buffer_list + [ + tensor([10., 10., 10.]), + tensor([11., 11., 11.]), + tensor([12., 12., 12.]) + ] + >>> # Move to GPU. Since the tensors are registered as buffers, device placement is handled automatically + >>> module.cuda() + >>> module.my_buffer_list + [ + tensor([10., 10., 10.], device='cuda:0'), + tensor([11., 11., 11.], device='cuda:0'), + tensor([12., 12., 12.], device='cuda:0') + ] + """ + + def register_buffer_list(self, name: str, values: list[torch.Tensor], persistent: bool = True, **kwargs) -> None: + """Register a list of tensors as buffers in a pytorch module. + + Each tensor is registered as a buffer with the name `_name_i` where `i` is the index of the tensor in the list. + To update and retrieve the list of tensors, we dynamically assign a descriptor attribute to the class. + + Args: + name (str): Name of the buffer list. + values (list[torch.Tensor]): List of tensors to register as buffers. + persistent (bool, optional): Whether the buffers should be saved as part of the module state_dict. + Defaults to True. + **kwargs: Additional keyword arguments to pass to `torch.nn.Module.register_buffer`. + """ + for i, value in enumerate(values): + self.register_buffer(f"_{name}_{i}", value, persistent=persistent, **kwargs) + + setattr(BufferListMixin, name, BufferListDescriptor(name, len(values))) + + +class BufferListDescriptor: + """Buffer List Descriptor. + + This descriptor is used to allow registering a list of tensors as buffers in a pytorch module. + + Args: + name (str): Name of the buffer list. + length (int): Length of the buffer list. + """ + + def __init__(self, name: str, length: int) -> None: + self.name = name + self.length = length + + def __get__(self, instance: object, object_type: type | None = None) -> list[torch.Tensor]: + """Get the list of tensors. + + Each element of the buffer list is stored as a buffer with the name `name_i` where `i` is the index of the + element in the list. We use list comprehension to retrieve the list of tensors. + + Args: + instance (object): Instance of the class. + object_type (Any, optional): Type of the class. Defaults to None. + + Returns: + list[torch.Tensor]: Contents of the buffer list. + """ + del object_type + return [getattr(instance, f"_{self.name}_{i}") for i in range(self.length)] + + def __set__(self, instance: object, values: list[torch.Tensor]) -> None: + """Set the list of tensors. + + Assigns a new list of tensors to the buffer list by updating the individual buffer attributes. + + Args: + instance (object): Instance of the class. + values (list[torch.Tensor]): List of tensors to set. + """ + for i, value in enumerate(values): + setattr(instance, f"_{self.name}_{i}", value) diff --git a/anomalib/models/components/base/dynamic_buffer.py b/anomalib/models/components/base/dynamic_buffer.py new file mode 100644 index 0000000000000000000000000000000000000000..e1c6ad6bd64933d235b319619829e4dab3983fe5 --- /dev/null +++ b/anomalib/models/components/base/dynamic_buffer.py @@ -0,0 +1,57 @@ +"""Dynamic Buffer Mixin.""" + +# Copyright (C) 2022-2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +from abc import ABC + +import torch +from torch import nn + + +class DynamicBufferMixin(nn.Module, ABC): + """This mixin allows loading variables from the state dict even in the case of shape mismatch.""" + + def get_tensor_attribute(self, attribute_name: str) -> torch.Tensor: + """Get attribute of the tensor given the name. + + Args: + attribute_name (str): Name of the tensor + + Raises: + ValueError: `attribute_name` is not a torch Tensor + + Returns: + Tensor: torch.Tensor attribute + """ + attribute = getattr(self, attribute_name) + if isinstance(attribute, torch.Tensor): + return attribute + + msg = f"Attribute with name '{attribute_name}' is not a torch Tensor" + raise ValueError(msg) + + def _load_from_state_dict(self, state_dict: dict, prefix: str, *args) -> None: + """Resizes the local buffers to match those stored in the state dict. + + Overrides method from parent class. + + Args: + state_dict (dict): State dictionary containing weights + prefix (str): Prefix of the weight file. + *args: Variable length argument list. + """ + persistent_buffers = {k: v for k, v in self._buffers.items() if k not in self._non_persistent_buffers_set} + local_buffers = {k: v for k, v in persistent_buffers.items() if v is not None} + + for param in local_buffers: + for key in state_dict: + if ( + key.startswith(prefix) + and key[len(prefix) :].split(".")[0] == param + and local_buffers[param].shape != state_dict[key].shape + ): + attribute = self.get_tensor_attribute(param) + attribute.resize_(state_dict[key].shape) + + super()._load_from_state_dict(state_dict, prefix, *args) diff --git a/anomalib/models/components/base/export_mixin.py b/anomalib/models/components/base/export_mixin.py new file mode 100644 index 0000000000000000000000000000000000000000..b0ac4449ef0febdbd1a9d3dd73408d25250b9907 --- /dev/null +++ b/anomalib/models/components/base/export_mixin.py @@ -0,0 +1,292 @@ +"""Mixin for exporting models to disk.""" + +# Copyright (C) 2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + + +import json +import logging +from collections.abc import Callable +from pathlib import Path +from tempfile import TemporaryDirectory +from typing import TYPE_CHECKING, Any + +import numpy as np +import torch +from torch import nn +from torchvision.transforms.v2 import Transform + +from anomalib import TaskType +from anomalib.deploy.export import ExportType, InferenceModel +from anomalib.utils.exceptions import try_import + +if TYPE_CHECKING: + from torch.types import Number + +logger = logging.getLogger(__name__) + + +class ExportMixin: + """This mixin allows exporting models to torch and ONNX/OpenVINO.""" + + model: nn.Module + transform: Transform + configure_transforms: Callable + device: torch.device + + def to_torch( + self, + export_root: Path | str, + transform: Transform | None = None, + task: TaskType | None = None, + ) -> Path: + """Export AnomalibModel to torch. + + Args: + export_root (Path): Path to the output folder. + transform (Transform, optional): Input transforms used for the model. If not provided, the transform is + taken from the model. + Defaults to ``None``. + task (TaskType | None): Task type. + Defaults to ``None``. + + Returns: + Path: Path to the exported pytorch model. + + Examples: + Assume that we have a model to train and we want to export it to torch format. + + >>> from anomalib.data import Visa + >>> from anomalib.models import Patchcore + >>> from anomalib.engine import Engine + ... + >>> datamodule = Visa() + >>> model = Patchcore() + >>> engine = Engine() + ... + >>> engine.fit(model, datamodule) + + Now that we have a model trained, we can export it to torch format. + + >>> model.to_torch( + ... export_root="path/to/export", + ... transform=datamodule.test_data.transform, + ... task=datamodule.test_data.task, + ... ) + """ + transform = transform or self.transform or self.configure_transforms() + inference_model = InferenceModel(model=self.model, transform=transform) + export_root = _create_export_root(export_root, ExportType.TORCH) + metadata = self._get_metadata(task=task) + pt_model_path = export_root / "model.pt" + torch.save( + obj={"model": inference_model, "metadata": metadata}, + f=pt_model_path, + ) + return pt_model_path + + def to_onnx( + self, + export_root: Path | str, + input_size: tuple[int, int] | None = None, + transform: Transform | None = None, + task: TaskType | None = None, + ) -> Path: + """Export model to onnx. + + Args: + export_root (Path): Path to the root folder of the exported model. + input_size (tuple[int, int] | None, optional): Image size used as the input for onnx converter. + Defaults to None. + transform (Transform, optional): Input transforms used for the model. If not provided, the transform is + taken from the model. + Defaults to ``None``. + task (TaskType | None): Task type. + Defaults to ``None``. + + Returns: + Path: Path to the exported onnx model. + + Examples: + Export the Lightning Model to ONNX: + + >>> from anomalib.models import Patchcore + >>> from anomalib.data import Visa + ... + >>> datamodule = Visa() + >>> model = Patchcore() + ... + >>> model.to_onnx( + ... export_root="path/to/export", + ... transform=datamodule.test_data.transform, + ... task=datamodule.test_data.task + ... ) + + Using Custom Transforms: + This example shows how to use a custom ``Compose`` object for the ``transform`` argument. + + >>> model.to_onnx( + ... export_root="path/to/export", + ... task="segmentation", + ... ) + """ + transform = transform or self.transform or self.configure_transforms() + inference_model = InferenceModel(model=self.model, transform=transform, disable_antialias=True) + export_root = _create_export_root(export_root, ExportType.ONNX) + input_shape = torch.zeros((1, 3, *input_size)) if input_size else torch.zeros((1, 3, 1, 1)) + dynamic_axes = ( + None if input_size else {"input": {0: "batch_size", 2: "height", 3: "weight"}, "output": {0: "batch_size"}} + ) + _write_metadata_to_json(self._get_metadata(task), export_root) + onnx_path = export_root / "model.onnx" + torch.onnx.export( + inference_model, + input_shape.to(self.device), + str(onnx_path), + opset_version=14, + dynamic_axes=dynamic_axes, + input_names=["input"], + output_names=["output"], + ) + + return onnx_path + + def to_openvino( + self, + export_root: Path | str, + input_size: tuple[int, int] | None = None, + transform: Transform | None = None, + ov_args: dict[str, Any] | None = None, + task: TaskType | None = None, + ) -> Path: + """Convert onnx model to OpenVINO IR. + + Args: + export_root (Path): Path to the export folder. + input_size (tuple[int, int] | None, optional): Input size of the model. Used for adding metadata to the IR. + Defaults to None. + transform (Transform, optional): Input transforms used for the model. If not provided, the transform is + taken from the model. + Defaults to ``None``. + ov_args: Model optimizer arguments for OpenVINO model conversion. + Defaults to ``None``. + task (TaskType | None): Task type. + Defaults to ``None``. + + Returns: + Path: Path to the exported onnx model. + + Raises: + ModuleNotFoundError: If OpenVINO is not installed. + + Returns: + Path: Path to the exported OpenVINO IR. + + Examples: + Export the Lightning Model to OpenVINO IR: + This example demonstrates how to export the Lightning Model to OpenVINO IR. + + >>> from anomalib.models import Patchcore + >>> from anomalib.data import Visa + ... + >>> datamodule = Visa() + >>> model = Patchcore() + ... + >>> model.to_openvino( + ... export_root="path/to/export", + ... transform=datamodule.test_data.transform, + ... task=datamodule.test_data.task + ... ) + + Using Custom Transforms: + This example shows how to use a custom ``Transform`` object for the ``transform`` argument. + + >>> from torchvision.transforms.v2 import Resize + >>> transform = Resize(224, 224) + ... + >>> model.to_openvino( + ... export_root="path/to/export", + ... transform=transform, + ... task="segmentation", + ... ) + """ + if not try_import("openvino"): + logger.exception("Could not find OpenVINO. Please check OpenVINO installation.") + raise ModuleNotFoundError + + import openvino as ov + + with TemporaryDirectory() as onnx_directory: + model_path = self.to_onnx(onnx_directory, input_size, transform, task) + export_root = _create_export_root(export_root, ExportType.OPENVINO) + ov_model_path = export_root / "model.xml" + ov_args = {} if ov_args is None else ov_args + # fp16 compression is enabled by default + compress_to_fp16 = ov_args.get("compress_to_fp16", True) + + model = ov.convert_model(model_path, **ov_args) + ov.save_model(model, ov_model_path, compress_to_fp16=compress_to_fp16) + _write_metadata_to_json(self._get_metadata(task), export_root) + + return ov_model_path + + def _get_metadata( + self, + task: TaskType | None = None, + ) -> dict[str, Any]: + """Get metadata for the exported model. + + Args: + task (TaskType | None): Task type. + Defaults to None. + + Returns: + dict[str, Any]: Metadata for the exported model. + """ + model_metadata = {} + cached_metadata: dict[str, Number | torch.Tensor] = {} + for threshold_name in ("image_threshold", "pixel_threshold"): + if hasattr(self, threshold_name): + cached_metadata[threshold_name] = getattr(self, threshold_name).cpu().value.item() + if hasattr(self, "normalization_metrics") and self.normalization_metrics.state_dict() is not None: + for key, value in self.normalization_metrics.state_dict().items(): + cached_metadata[key] = value.cpu() + # Remove undefined values by copying in a new dict + for key, val in cached_metadata.items(): + if not np.isinf(val).all(): + model_metadata[key] = val + del cached_metadata + metadata = {"task": task, **model_metadata} + + # Convert torch tensors to python lists or values for json serialization. + for key, value in metadata.items(): + if isinstance(value, torch.Tensor): + metadata[key] = value.numpy().tolist() + + return metadata + + +def _write_metadata_to_json(metadata: dict[str, Any], export_root: Path) -> None: + """Write metadata to json file. + + Args: + metadata (dict[str, Any]): Metadata to export. + export_root (Path): Path to the exported model. + """ + with (export_root / "metadata.json").open("w", encoding="utf-8") as metadata_file: + json.dump(metadata, metadata_file, ensure_ascii=False, indent=4) + + +def _create_export_root(export_root: str | Path, export_type: ExportType) -> Path: + """Create export directory. + + Args: + export_root (str | Path): Path to the root folder of the exported model. + export_type (ExportType): Mode to export the model. Torch, ONNX or OpenVINO. + + Returns: + Path: Path to the export directory. + """ + export_root = Path(export_root) / "weights" / export_type.value + export_root.mkdir(parents=True, exist_ok=True) + return export_root diff --git a/anomalib/models/components/base/memory_bank_module.py b/anomalib/models/components/base/memory_bank_module.py new file mode 100644 index 0000000000000000000000000000000000000000..44fd507ab6c49a0e09a551c4301c2f679f1c64aa --- /dev/null +++ b/anomalib/models/components/base/memory_bank_module.py @@ -0,0 +1,44 @@ +"""Memory Bank Module.""" + +# Copyright (C) 2023-2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + + +from abc import abstractmethod + +import torch +from torch import nn + + +class MemoryBankMixin(nn.Module): + """Memory Bank Lightning Module. + + This module is used to implement memory bank lightning modules. + It checks if the model is fitted before validation starts. + """ + + def __init__(self, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) + self.register_buffer("_is_fitted", torch.tensor([False])) + self._is_fitted: torch.Tensor + + @abstractmethod + def fit(self) -> None: + """Fit the model to the data.""" + msg = ( + f"fit method not implemented for {self.__class__.__name__}. " + "To use a memory-bank module, implement ``fit.``" + ) + raise NotImplementedError(msg) + + def on_validation_start(self) -> None: + """Ensure that the model is fitted before validation starts.""" + if not self._is_fitted: + self.fit() + self._is_fitted = torch.tensor([True]) + + def on_train_epoch_end(self) -> None: + """Ensure that the model is fitted before validation starts.""" + if not self._is_fitted: + self.fit() + self._is_fitted = torch.tensor([True]) diff --git a/anomalib/models/components/classification/__init__.py b/anomalib/models/components/classification/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..7767cb466e0e3bf0ab28b974a1f275c3482b51b1 --- /dev/null +++ b/anomalib/models/components/classification/__init__.py @@ -0,0 +1,5 @@ +"""Classification modules.""" + +from .kde_classifier import FeatureScalingMethod, KDEClassifier + +__all__ = ["KDEClassifier", "FeatureScalingMethod"] diff --git a/anomalib/models/components/classification/kde_classifier.py b/anomalib/models/components/classification/kde_classifier.py new file mode 100644 index 0000000000000000000000000000000000000000..4d9193a635e1bb4811f087f595a504a277ddde67 --- /dev/null +++ b/anomalib/models/components/classification/kde_classifier.py @@ -0,0 +1,161 @@ +"""Kernel Density Estimation Classifier.""" + +# Copyright (C) 2022-2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + + +import logging +import random +from enum import Enum + +import torch +from torch import nn + +from anomalib.models.components import PCA, GaussianKDE + +logger = logging.getLogger(__name__) + + +class FeatureScalingMethod(str, Enum): + """Determines how the feature embeddings are scaled.""" + + NORM = "norm" # scale to unit vector length + SCALE = "scale" # scale to max length observed in training (preserve relative magnitude) + + +class KDEClassifier(nn.Module): + """Classification module for KDE-based anomaly detection. + + Args: + n_pca_components (int, optional): Number of PCA components. Defaults to 16. + feature_scaling_method (FeatureScalingMethod, optional): Scaling method applied to features before passing to + KDE. Options are `norm` (normalize to unit vector length) and `scale` (scale to max length observed in + training). + max_training_points (int, optional): Maximum number of training points to fit the KDE model. Defaults to 40000. + """ + + def __init__( + self, + n_pca_components: int = 16, + feature_scaling_method: FeatureScalingMethod = FeatureScalingMethod.SCALE, + max_training_points: int = 40000, + ) -> None: + super().__init__() + + self.n_pca_components = n_pca_components + self.feature_scaling_method = feature_scaling_method + self.max_training_points = max_training_points + + self.pca_model = PCA(n_components=self.n_pca_components) + self.kde_model = GaussianKDE() + + self.register_buffer("max_length", torch.empty([])) + self.max_length = torch.empty([]) + + def pre_process( + self, + feature_stack: torch.Tensor, + max_length: torch.Tensor | None = None, + ) -> tuple[torch.Tensor, torch.Tensor]: + """Pre-process the CNN features. + + Args: + feature_stack (torch.Tensor): Features extracted from CNN + max_length (Tensor | None): Used to unit normalize the feature_stack vector. If ``max_len`` is not + provided, the length is calculated from the ``feature_stack``. Defaults to None. + + Returns: + (Tuple): Stacked features and length + """ + if max_length is None: + max_length = torch.max(torch.linalg.norm(feature_stack, ord=2, dim=1)) + + if self.feature_scaling_method == FeatureScalingMethod.NORM: + feature_stack /= torch.linalg.norm(feature_stack, ord=2, dim=1)[:, None] + elif self.feature_scaling_method == FeatureScalingMethod.SCALE: + feature_stack /= max_length + else: + msg = "Unknown pre-processing mode. Available modes are: Normalized and Scale." + raise RuntimeError(msg) + return feature_stack, max_length + + def fit(self, embeddings: torch.Tensor) -> bool: + """Fit a kde model to embeddings. + + Args: + embeddings (torch.Tensor): Input embeddings to fit the model. + + Returns: + Boolean confirming whether the training is successful. + """ + if embeddings.shape[0] < self.n_pca_components: + logger.info("Not enough features to commit. Not making a model.") + return False + + # if max training points is non-zero and smaller than number of staged features, select random subset + if embeddings.shape[0] > self.max_training_points: + selected_idx = torch.tensor(random.sample(range(embeddings.shape[0]), self.max_training_points)) + selected_features = embeddings[selected_idx] + else: + selected_features = embeddings + + feature_stack = self.pca_model.fit_transform(selected_features) + feature_stack, max_length = self.pre_process(feature_stack) + self.max_length = max_length + self.kde_model.fit(feature_stack) + + return True + + def compute_kde_scores(self, features: torch.Tensor, as_log_likelihood: bool | None = False) -> torch.Tensor: + """Compute the KDE scores. + + The scores calculated from the KDE model are converted to densities. If `as_log_likelihood` is set to true then + the log of the scores are calculated. + + Args: + features (torch.Tensor): Features to which the PCA model is fit. + as_log_likelihood (bool | None, optional): If true, gets log likelihood scores. Defaults to False. + + Returns: + (torch.Tensor): Score + """ + features = self.pca_model.transform(features) + features, _ = self.pre_process(features, self.max_length) + # Scores are always assumed to be passed as a density + kde_scores = self.kde_model(features) + + # add small constant to avoid zero division in log computation + kde_scores += 1e-300 + + if as_log_likelihood: + kde_scores = torch.log(kde_scores) + + return kde_scores + + @staticmethod + def compute_probabilities(scores: torch.Tensor) -> torch.Tensor: + """Convert density scores to anomaly probabilities (see https://www.desmos.com/calculator/ifju7eesg7). + + Args: + scores (torch.Tensor): density of an image. + + Returns: + probability that image with {density} is anomalous + """ + return 1 / (1 + torch.exp(0.05 * (scores - 12))) + + def predict(self, features: torch.Tensor) -> torch.Tensor: + """Predicts the probability that the features belong to the anomalous class. + + Args: + features (torch.Tensor): Feature from which the output probabilities are detected. + + Returns: + Detection probabilities + """ + scores = self.compute_kde_scores(features, as_log_likelihood=True) + return self.compute_probabilities(scores) + + def forward(self, features: torch.Tensor) -> torch.Tensor: + """Make predictions on extracted features.""" + return self.predict(features) diff --git a/anomalib/models/components/cluster/__init__.py b/anomalib/models/components/cluster/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e3ce0455af2e9a4621c6dbf8c31afb0820213f8b --- /dev/null +++ b/anomalib/models/components/cluster/__init__.py @@ -0,0 +1,9 @@ +"""Clustering algorithm implementations using PyTorch.""" + +# Copyright (C) 2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +from .gmm import GaussianMixture +from .kmeans import KMeans + +__all__ = ["GaussianMixture", "KMeans"] diff --git a/anomalib/models/components/cluster/gmm.py b/anomalib/models/components/cluster/gmm.py new file mode 100644 index 0000000000000000000000000000000000000000..b7f94693b2797d907b72b23a477141543c10df38 --- /dev/null +++ b/anomalib/models/components/cluster/gmm.py @@ -0,0 +1,176 @@ +"""Pytorch implementation of Gaussian Mixture Model.""" + +# Copyright (C) 2023-2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +import logging + +import torch +from torch.distributions.multivariate_normal import MultivariateNormal +from torch.nn.functional import one_hot + +from anomalib.models.components.base import DynamicBufferMixin +from anomalib.models.components.cluster.kmeans import KMeans + +logger = logging.getLogger(__name__) + + +class GaussianMixture(DynamicBufferMixin): + """Gaussian Mixture Model. + + Args: + n_components (int): Number of components. + n_iter (int): Maximum number of iterations to perform. + Defaults to ``100``. + tol (float): Convergence threshold. + Defaults to ``1e-3``. + + Example: + The following examples shows how to fit a Gaussian Mixture Model to some data and get the cluster means and + predicted labels and log-likelihood scores of the data. + + .. code-block:: python + + >>> import torch + >>> from anomalib.models.components.cluster import GaussianMixture + >>> model = GaussianMixture(n_components=2) + >>> data = torch.tensor( + ... [ + ... [2, 1], [2, 2], [2, 3], + ... [7, 5], [8, 5], [9, 5], + ... ] + ... ).float() + >>> model.fit(data) + >>> model.means # get the means of the gaussians + tensor([[8., 5.], + [2., 2.]]) + >>> model.predict(data) # get the predicted cluster label of each sample + tensor([1, 1, 1, 0, 0, 0]) + >>> model.score_samples(data) # get the log-likelihood score of each sample + tensor([3.8295, 4.5795, 3.8295, 3.8295, 4.5795, 3.8295]) + """ + + def __init__(self, n_components: int, n_iter: int = 100, tol: float = 1e-3) -> None: + super().__init__() + self.n_components = n_components + self.tol = tol + self.n_iter = n_iter + + self.register_buffer("means", torch.Tensor()) + self.register_buffer("covariances", torch.Tensor()) + self.register_buffer("weights", torch.Tensor()) + + self.means: torch.Tensor + self.covariances: torch.Tensor + self.weights: torch.Tensor + + def fit(self, data: torch.Tensor) -> None: + """Fit the model to the data. + + Args: + data (Tensor): Data to fit the model to. Tensor of shape (n_samples, n_features). + """ + self._initialize_parameters_kmeans(data) + + log_likelihood_old = 0 + converged = False + for _ in range(self.n_iter): + # E-step + log_likelihood_new, resp = self._e_step(data) + # M-step + self._m_step(data, resp) + + # Check for convergence + if torch.abs(log_likelihood_new - log_likelihood_old) < self.tol: + converged = True + break + log_likelihood_old = log_likelihood_new + + if not converged: + logger.warning( + f"GMM did not converge after {self.n_iter} iterations. \ + Consider increasing the number of iterations.", + ) + + def _initialize_parameters_kmeans(self, data: torch.Tensor) -> None: + """Initialize parameters with K-means. + + Args: + data (Tensor): Data to fit the model to. Tensor of shape (n_samples, n_features). + """ + labels, _ = KMeans(n_clusters=self.n_components).fit(data) + resp = one_hot(labels, num_classes=self.n_components).float() + self._m_step(data, resp) + + def _e_step(self, data: torch.Tensor) -> torch.Tensor: + """Perform the E-step to estimate the responsibilities of the gaussians. + + Args: + data (Tensor): Data to fit the model to. Tensor of shape (n_samples, n_features). + + Returns: + Tensor: log probability of the data given the gaussians. + Tensor: Tensor of shape (n_samples, n_components) containing the responsibilities. + """ + weighted_log_prob = self._estimate_weighted_log_prob(data) + log_prob_norm = torch.logsumexp(weighted_log_prob, axis=1) + log_resp = weighted_log_prob - torch.logsumexp(weighted_log_prob, dim=1, keepdim=True) + return torch.mean(log_prob_norm), torch.exp(log_resp) + + def _m_step(self, data: torch.Tensor, resp: torch.Tensor) -> None: + """Perform the M-step to update the parameters of the gaussians. + + Args: + data (Tensor): Data to fit the model to. Tensor of shape (n_samples, n_features). + resp (Tensor): Tensor of shape (n_samples, n_components) containing the responsibilities. + """ + cluster_counts = resp.sum(axis=0) # number of points in each cluster + self.weights = resp.mean(axis=0) # new weights + self.means = (resp.T @ data) / cluster_counts[:, None] # new means + + diff = data.unsqueeze(0) - self.means.unsqueeze(1) + weighted_diff = diff * resp.T.unsqueeze(-1) + covariances = torch.bmm(weighted_diff.transpose(-2, -1), diff) / cluster_counts.view(-1, 1, 1) + # Add a small constant for numerical stability + self.covariances = covariances + torch.eye(data.shape[1], device=data.device) * 1e-6 # new covariances + + def _estimate_weighted_log_prob(self, data: torch.Tensor) -> torch.Tensor: + """Estimate the log probability of the data given the gaussian parameters. + + Args: + data (Tensor): Data to fit the model to. Tensor of shape (n_samples, n_features). + + Returns: + Tensor: Tensor of shape (n_samples, n_components) containing the log-probabilities of each sample. + """ + log_prob = torch.stack( + [ + MultivariateNormal(self.means[comp], self.covariances[comp]).log_prob(data) + for comp in range(self.n_components) + ], + dim=1, + ) + return log_prob + torch.log(self.weights) + + def score_samples(self, data: torch.Tensor) -> torch.Tensor: + """Assign a likelihood score to each sample in the data. + + Args: + data (Tensor): Samples to assign scores to. Tensor of shape (n_samples, n_features). + + Returns: + Tensor: Tensor of shape (n_samples,) containing the log-likelihood score of each sample. + """ + return torch.logsumexp(self._estimate_weighted_log_prob(data), dim=1) + + def predict(self, data: torch.Tensor) -> torch.Tensor: + """Predict the cluster labels of the data. + + Args: + data (Tensor): Samples to assign to clusters. Tensor of shape (n_samples, n_features). + + Returns: + Tensor: Tensor of shape (n_samples,) containing the predicted cluster label of each sample. + """ + _, resp = self._e_step(data) + return torch.argmax(resp, axis=1) diff --git a/anomalib/models/components/cluster/kmeans.py b/anomalib/models/components/cluster/kmeans.py new file mode 100644 index 0000000000000000000000000000000000000000..908a3e3faeea4ee570de1d88f3862af809d92ec3 --- /dev/null +++ b/anomalib/models/components/cluster/kmeans.py @@ -0,0 +1,70 @@ +"""KMeans clustering algorithm implementation using PyTorch.""" + +# Copyright (C) 2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +import torch + + +class KMeans: + """Initialize the KMeans object. + + Args: + n_clusters (int): The number of clusters to create. + max_iter (int, optional)): The maximum number of iterations to run the algorithm. Defaults to 10. + """ + + def __init__(self, n_clusters: int, max_iter: int = 10) -> None: + self.n_clusters = n_clusters + self.max_iter = max_iter + + def fit(self, inputs: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + """Fit the K-means algorithm to the input data. + + Args: + inputs (torch.Tensor): Input data of shape (batch_size, n_features). + + Returns: + tuple: A tuple containing the labels of the input data with respect to the identified clusters + and the cluster centers themselves. The labels have a shape of (batch_size,) and the + cluster centers have a shape of (n_clusters, n_features). + + Raises: + ValueError: If the number of clusters is less than or equal to 0. + """ + batch_size, _ = inputs.shape + + # Initialize centroids randomly from the data points + centroid_indices = torch.randint(0, batch_size, (self.n_clusters,)) + self.cluster_centers_ = inputs[centroid_indices] + + # Run the k-means algorithm for max_iter iterations + for _ in range(self.max_iter): + # Compute the distance between each data point and each centroid + distances = torch.cdist(inputs, self.cluster_centers_) + + # Assign each data point to the closest centroid + self.labels_ = torch.argmin(distances, dim=1) + + # Update the centroids to be the mean of the data points assigned to them + for j in range(self.n_clusters): + mask = self.labels_ == j + if mask.any(): + self.cluster_centers_[j] = inputs[mask].mean(dim=0) + # this line returns labels and centoids of the results + return self.labels_, self.cluster_centers_ + + def predict(self, inputs: torch.Tensor) -> torch.Tensor: + """Predict the labels of input data based on the fitted model. + + Args: + inputs (torch.Tensor): Input data of shape (batch_size, n_features). + + Returns: + torch.Tensor: The predicted labels of the input data with respect to the identified clusters. + + Raises: + AttributeError: If the KMeans object has not been fitted to input data. + """ + distances = torch.cdist(inputs, self.cluster_centers_) + return torch.argmin(distances, dim=1) diff --git a/anomalib/models/components/dimensionality_reduction/__init__.py b/anomalib/models/components/dimensionality_reduction/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..d69c691bf0165c337c3273fe228ded823b9e32c8 --- /dev/null +++ b/anomalib/models/components/dimensionality_reduction/__init__.py @@ -0,0 +1,9 @@ +"""Algorithms for decomposition and dimensionality reduction.""" + +# Copyright (C) 2022 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +from .pca import PCA +from .random_projection import SparseRandomProjection + +__all__ = ["PCA", "SparseRandomProjection"] diff --git a/anomalib/models/components/dimensionality_reduction/pca.py b/anomalib/models/components/dimensionality_reduction/pca.py new file mode 100644 index 0000000000000000000000000000000000000000..ea64975cd9d2b63264efbc256a5ce5c8e1ac941e --- /dev/null +++ b/anomalib/models/components/dimensionality_reduction/pca.py @@ -0,0 +1,161 @@ +"""Principle Component Analysis (PCA) with PyTorch.""" + +# Copyright (C) 2022-2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + + +import torch + +from anomalib.models.components.base import DynamicBufferMixin + + +class PCA(DynamicBufferMixin): + """Principle Component Analysis (PCA). + + Args: + n_components (float): Number of components. Can be either integer number of components + or a ratio between 0-1. + + Example: + >>> import torch + >>> from anomalib.models.components import PCA + + Create a PCA model with 2 components: + + >>> pca = PCA(n_components=2) + + Create a random embedding and fit a PCA model. + + >>> embedding = torch.rand(1000, 5).cuda() + >>> pca = PCA(n_components=2) + >>> pca.fit(embedding) + + Apply transformation: + + >>> transformed = pca.transform(embedding) + >>> transformed.shape + torch.Size([1000, 2]) + """ + + def __init__(self, n_components: int | float) -> None: + super().__init__() + self.n_components = n_components + + self.register_buffer("singular_vectors", torch.Tensor()) + self.register_buffer("mean", torch.Tensor()) + self.register_buffer("num_components", torch.Tensor()) + + self.singular_vectors: torch.Tensor + self.singular_values: torch.Tensor + self.mean: torch.Tensor + self.num_components: torch.Tensor + + def fit(self, dataset: torch.Tensor) -> None: + """Fits the PCA model to the dataset. + + Args: + dataset (torch.Tensor): Input dataset to fit the model. + + Example: + >>> pca.fit(embedding) + >>> pca.singular_vectors + tensor([9.6053, 9.2763], device='cuda:0') + + >>> pca.mean + tensor([0.4859, 0.4959, 0.4906, 0.5010, 0.5042], device='cuda:0') + """ + mean = dataset.mean(dim=0) + dataset -= mean + + _, sig, v_h = torch.linalg.svd(dataset.double(), full_matrices=False) + num_components: int + if self.n_components <= 1: + variance_ratios = torch.cumsum(sig * sig, dim=0) / torch.sum(sig * sig) + num_components = torch.nonzero(variance_ratios >= self.n_components)[0] + else: + num_components = int(self.n_components) + + self.num_components = torch.Tensor([num_components]) + + self.singular_vectors = v_h.transpose(-2, -1)[:, :num_components].float() + self.singular_values = sig[:num_components].float() + self.mean = mean + + def fit_transform(self, dataset: torch.Tensor) -> torch.Tensor: + """Fit and transform PCA to dataset. + + Args: + dataset (torch.Tensor): Dataset to which the PCA if fit and transformed + + Returns: + Transformed dataset + + Example: + >>> pca.fit_transform(embedding) + >>> transformed_embedding = pca.fit_transform(embedding) + >>> transformed_embedding.shape + torch.Size([1000, 2]) + """ + mean = dataset.mean(dim=0) + dataset -= mean + num_components = int(self.n_components) + self.num_components = torch.Tensor([num_components]) + + v_h = torch.linalg.svd(dataset)[-1] + self.singular_vectors = v_h.transpose(-2, -1)[:, :num_components] + self.mean = mean + + return torch.matmul(dataset, self.singular_vectors) + + def transform(self, features: torch.Tensor) -> torch.Tensor: + """Transform the features based on singular vectors calculated earlier. + + Args: + features (torch.Tensor): Input features + + Returns: + Transformed features + + Example: + >>> pca.transform(embedding) + >>> transformed_embedding = pca.transform(embedding) + + >>> embedding.shape + torch.Size([1000, 5]) + # + >>> transformed_embedding.shape + torch.Size([1000, 2]) + """ + features -= self.mean + return torch.matmul(features, self.singular_vectors) + + def inverse_transform(self, features: torch.Tensor) -> torch.Tensor: + """Inverses the transformed features. + + Args: + features (torch.Tensor): Transformed features + + Returns: + Inverse features + + Example: + >>> inverse_embedding = pca.inverse_transform(transformed_embedding) + >>> inverse_embedding.shape + torch.Size([1000, 5]) + """ + return torch.matmul(features, self.singular_vectors.transpose(-2, -1)) + + def forward(self, features: torch.Tensor) -> torch.Tensor: + """Transform the features. + + Args: + features (torch.Tensor): Input features + + Returns: + Transformed features + + Example: + >>> pca(embedding).shape + torch.Size([1000, 2]) + """ + return self.transform(features) diff --git a/anomalib/models/components/dimensionality_reduction/random_projection.py b/anomalib/models/components/dimensionality_reduction/random_projection.py new file mode 100644 index 0000000000000000000000000000000000000000..4a684d77b3781c44b9c0f925de903863f750b105 --- /dev/null +++ b/anomalib/models/components/dimensionality_reduction/random_projection.py @@ -0,0 +1,159 @@ +"""Random Sparse Projector. + +Sparse Random Projection using PyTorch Operations +""" + +# Copyright (C) 2022-2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + + +import numpy as np +import torch +from sklearn.utils.random import sample_without_replacement + + +class NotFittedError(ValueError, AttributeError): + """Raise Exception if estimator is used before fitting.""" + + +class SparseRandomProjection: + """Sparse Random Projection using PyTorch operations. + + Args: + eps (float, optional): Minimum distortion rate parameter for calculating + Johnson-Lindenstrauss minimum dimensions. + Defaults to ``0.1``. + random_state (int | None, optional): Uses the seed to set the random + state for sample_without_replacement function. + Defaults to ``None``. + + Example: + To fit and transform the embedding tensor, use the following code: + + .. code-block:: python + + import torch + from anomalib.models.components import SparseRandomProjection + + sparse_embedding = torch.rand(1000, 5).cuda() + model = SparseRandomProjection(eps=0.1) + + Fit the model and transform the embedding tensor: + + .. code-block:: python + + model.fit(sparse_embedding) + projected_embedding = model.transform(sparse_embedding) + + print(projected_embedding.shape) + # Output: torch.Size([1000, 5920]) + """ + + def __init__(self, eps: float = 0.1, random_state: int | None = None) -> None: + self.n_components: int + self.sparse_random_matrix: torch.Tensor + self.eps = eps + self.random_state = random_state + + def _sparse_random_matrix(self, n_features: int) -> torch.Tensor: + """Random sparse matrix. Based on https://web.stanford.edu/~hastie/Papers/Ping/KDD06_rp.pdf. + + Args: + n_features (int): Dimentionality of the original source space + + Returns: + Tensor: Sparse matrix of shape (n_components, n_features). + The generated Gaussian random matrix is in CSR (compressed sparse row) + format. + """ + # Density 'auto'. Factorize density + density = 1 / np.sqrt(n_features) + + if density == 1: + # skip index generation if totally dense + binomial = torch.distributions.Binomial(total_count=1, probs=0.5) + components = binomial.sample((self.n_components, n_features)) * 2 - 1 + components = 1 / np.sqrt(self.n_components) * components + + else: + # Sparse matrix is not being generated here as it is stored as dense anyways + components = torch.zeros((self.n_components, n_features), dtype=torch.float32) + for i in range(self.n_components): + # find the indices of the non-zero components for row i + nnz_idx = torch.distributions.Binomial(total_count=n_features, probs=density).sample() + # get nnz_idx column indices + # pylint: disable=not-callable + c_idx = torch.tensor( + sample_without_replacement( + n_population=n_features, + n_samples=nnz_idx, + random_state=self.random_state, + ), + dtype=torch.int32, + ) + data = torch.distributions.Binomial(total_count=1, probs=0.5).sample(sample_shape=c_idx.size()) * 2 - 1 + # assign data to only those columns + components[i, c_idx] = data + + components *= np.sqrt(1 / density) / np.sqrt(self.n_components) + + return components + + def _johnson_lindenstrauss_min_dim(self, n_samples: int, eps: float = 0.1) -> int | np.integer: + """Find a 'safe' number of components to randomly project to. + + Ref eqn 2.1 https://cseweb.ucsd.edu/~dasgupta/papers/jl.pdf + + Args: + n_samples (int): Number of samples used to compute safe components + eps (float, optional): Minimum distortion rate. Defaults to 0.1. + """ + denominator = (eps**2 / 2) - (eps**3 / 3) + return (4 * np.log(n_samples) / denominator).astype(np.int64) + + def fit(self, embedding: torch.Tensor) -> "SparseRandomProjection": + """Generate sparse matrix from the embedding tensor. + + Args: + embedding (torch.Tensor): embedding tensor for generating embedding + + Returns: + (SparseRandomProjection): Return self to be used as + + >>> model = SparseRandomProjection() + >>> model = model.fit() + """ + n_samples, n_features = embedding.shape + device = embedding.device + + self.n_components = self._johnson_lindenstrauss_min_dim(n_samples=n_samples, eps=self.eps) + + # Generate projection matrix + # torch can't multiply directly on sparse matrix and moving sparse matrix to cuda throws error + # (Could not run 'aten::empty_strided' with arguments from the 'SparseCsrCUDA' backend) + # hence sparse matrix is stored as a dense matrix on the device + self.sparse_random_matrix = self._sparse_random_matrix(n_features=n_features).to(device) + + return self + + def transform(self, embedding: torch.Tensor) -> torch.Tensor: + """Project the data by using matrix product with the random matrix. + + Args: + embedding (torch.Tensor): Embedding of shape (n_samples, n_features) + The input data to project into a smaller dimensional space + + Returns: + projected_embedding (torch.Tensor): Sparse matrix of shape + (n_samples, n_components) Projected array. + + Example: + >>> projected_embedding = model.transform(embedding) + >>> projected_embedding.shape + torch.Size([1000, 5920]) + """ + if self.sparse_random_matrix is None: + msg = "`fit()` has not been called on SparseRandomProjection yet." + raise NotFittedError(msg) + + return embedding @ self.sparse_random_matrix.T.float() diff --git a/anomalib/models/components/feature_extractors/__init__.py b/anomalib/models/components/feature_extractors/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..5092056967c531388988f4dda29f7d70432aecf1 --- /dev/null +++ b/anomalib/models/components/feature_extractors/__init__.py @@ -0,0 +1,15 @@ +"""Feature extractors.""" + +# Copyright (C) 2022-2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +from .timm import TimmFeatureExtractor +from .torchfx import BackboneParams, TorchFXFeatureExtractor +from .utils import dryrun_find_featuremap_dims + +__all__ = [ + "BackboneParams", + "dryrun_find_featuremap_dims", + "TimmFeatureExtractor", + "TorchFXFeatureExtractor", +] diff --git a/anomalib/models/components/feature_extractors/timm.py b/anomalib/models/components/feature_extractors/timm.py new file mode 100644 index 0000000000000000000000000000000000000000..ae81dfb2c433cf9f0313acff40e4e802028bf7ce --- /dev/null +++ b/anomalib/models/components/feature_extractors/timm.py @@ -0,0 +1,129 @@ +"""Feature Extractor. + +This script extracts features from a CNN network +""" + +# Copyright (C) 2022-2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +import logging +from collections.abc import Sequence + +import timm +import torch +from torch import nn + +logger = logging.getLogger(__name__) + + +class TimmFeatureExtractor(nn.Module): + """Extract features from a CNN. + + Args: + backbone (nn.Module): The backbone to which the feature extraction hooks are attached. + layers (Iterable[str]): List of layer names of the backbone to which the hooks are attached. + pre_trained (bool): Whether to use a pre-trained backbone. Defaults to True. + requires_grad (bool): Whether to require gradients for the backbone. Defaults to False. + Models like ``stfpm`` use the feature extractor model as a trainable network. In such cases gradient + computation is required. + + Example: + .. code-block:: python + + import torch + from anomalib.models.components.feature_extractors import TimmFeatureExtractor + + model = TimmFeatureExtractor(model="resnet18", layers=['layer1', 'layer2', 'layer3']) + input = torch.rand((32, 3, 256, 256)) + features = model(input) + + print([layer for layer in features.keys()]) + # Output: ['layer1', 'layer2', 'layer3'] + + print([feature.shape for feature in features.values()]() + # Output: [torch.Size([32, 64, 64, 64]), torch.Size([32, 128, 32, 32]), torch.Size([32, 256, 16, 16])] + """ + + def __init__( + self, + backbone: str, + layers: Sequence[str], + pre_trained: bool = True, + requires_grad: bool = False, + ) -> None: + super().__init__() + + # Extract backbone-name and weight-URI from the backbone string. + if "__AT__" in backbone: + backbone, uri = backbone.split("__AT__") + pretrained_cfg = timm.models.registry.get_pretrained_cfg(backbone) + # Override pretrained_cfg["url"] to use different pretrained weights. + pretrained_cfg["url"] = uri + else: + pretrained_cfg = None + + self.backbone = backbone + self.layers = list(layers) + self.idx = self._map_layer_to_idx() + self.requires_grad = requires_grad + self.feature_extractor = timm.create_model( + backbone, + pretrained=pre_trained, + pretrained_cfg=pretrained_cfg, + features_only=True, + exportable=True, + out_indices=self.idx, + ) + self.out_dims = self.feature_extractor.feature_info.channels() + self._features = {layer: torch.empty(0) for layer in self.layers} + + def _map_layer_to_idx(self) -> list[int]: + """Map set of layer names to indices of model. + + Returns: + list[int]: Feature map extracted from the CNN. + """ + idx = [] + model = timm.create_model( + self.backbone, + pretrained=False, + features_only=True, + exportable=True, + ) + # model.feature_info.info returns list of dicts containing info, inside which "module" contains layer name + layer_names = [info["module"] for info in model.feature_info.info] + for layer in self.layers: + try: + idx.append(layer_names.index(layer)) + except ValueError: # noqa: PERF203 + msg = f"Layer {layer} not found in model {self.backbone}. Available layers: {layer_names}" + logger.warning(msg) + # Remove unfound key from layer dict + self.layers.remove(layer) + + return idx + + def forward(self, inputs: torch.Tensor) -> dict[str, torch.Tensor]: + """Forward-pass input tensor into the CNN. + + Args: + inputs (torch.Tensor): Input tensor + + Returns: + Feature map extracted from the CNN + + Example: + .. code-block:: python + + model = TimmFeatureExtractor(model="resnet50", layers=['layer3']) + input = torch.rand((32, 3, 256, 256)) + features = model.forward(input) + + """ + if self.requires_grad: + features = dict(zip(self.layers, self.feature_extractor(inputs), strict=True)) + else: + self.feature_extractor.eval() + with torch.no_grad(): + features = dict(zip(self.layers, self.feature_extractor(inputs), strict=True)) + return features diff --git a/anomalib/models/components/feature_extractors/torchfx.py b/anomalib/models/components/feature_extractors/torchfx.py new file mode 100644 index 0000000000000000000000000000000000000000..a8c42632ab666e6d3cff824421881fce09f42853 --- /dev/null +++ b/anomalib/models/components/feature_extractors/torchfx.py @@ -0,0 +1,234 @@ +"""Feature Extractor based on TorchFX.""" + +# Copyright (C) 2022-2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + + +import importlib +from collections.abc import Callable +from dataclasses import dataclass, field + +import torch +from torch import nn +from torch.fx.graph_module import GraphModule +from torchvision.models._api import WeightsEnum +from torchvision.models.feature_extraction import create_feature_extractor + + +@dataclass +class BackboneParams: + """Used for serializing the backbone.""" + + class_path: str | type[nn.Module] + init_args: dict = field(default_factory=dict) + + +class TorchFXFeatureExtractor(nn.Module): + """Extract features from a CNN. + + Args: + backbone (str | BackboneParams | dict | nn.Module): The backbone to which the feature extraction hooks are + attached. If the name is provided, the model is loaded from torchvision. Otherwise, the model class can be + provided and it will try to load the weights from the provided weights file. Last, an instance of nn.Module + can also be passed directly. + return_nodes (Iterable[str]): List of layer names of the backbone to which the hooks are attached. + You can find the names of these nodes by using ``get_graph_node_names`` function. + weights (str | WeightsEnum | None): Weights enum to use for the model. Torchvision models require + ``WeightsEnum``. These enums are defined in ``torchvision.models.``. You can pass the weights + path for custom models. + requires_grad (bool): Models like ``stfpm`` use the feature extractor for training. In such cases we should + set ``requires_grad`` to ``True``. Default is ``False``. + tracer_kwargs (dict | None): a dictionary of keyword arguments for NodePathTracer (which passes them onto + it's parent class torch.fx.Tracer). Can be used to allow not tracing through a list of problematic + modules, by passing a list of `leaf_modules` as one of the `tracer_kwargs`. + + Example: + With torchvision models: + + .. code-block:: python + + import torch + from anomalib.models.components.feature_extractors import TorchFXFeatureExtractor + from torchvision.models.efficientnet import EfficientNet_B5_Weights + + feature_extractor = TorchFXFeatureExtractor( + backbone="efficientnet_b5", + return_nodes=["features.6.8"], + weights=EfficientNet_B5_Weights.DEFAULT + ) + + input = torch.rand((32, 3, 256, 256)) + features = feature_extractor(input) + + print([layer for layer in features.keys()]) + # Output: ["features.6.8"] + + print([feature.shape for feature in features.values()]) + # Output: [torch.Size([32, 304, 8, 8])] + + With custom models: + + .. code-block:: python + + import torch + from anomalib.models.components.feature_extractors import TorchFXFeatureExtractor + + feature_extractor = TorchFXFeatureExtractor( + "path.to.CustomModel", ["linear_relu_stack.3"], weights="path/to/weights.pth" + ) + + input = torch.randn(1, 1, 28, 28) + features = feature_extractor(input) + + print([layer for layer in features.keys()]) + # Output: ["linear_relu_stack.3"] + + with model instances: + + .. code-block:: python + + import torch + from anomalib.models.components.feature_extractors import TorchFXFeatureExtractor + from timm import create_model + + model = create_model("resnet18", pretrained=True) + feature_extractor = TorchFXFeatureExtractor(model, ["layer1"]) + + input = torch.rand((32, 3, 256, 256)) + features = feature_extractor(input) + + print([layer for layer in features.keys()]) + # Output: ["layer1"] + + print([feature.shape for feature in features.values()]) + # Output: [torch.Size([32, 64, 64, 64])] + """ + + def __init__( + self, + backbone: str | BackboneParams | dict | nn.Module, + return_nodes: list[str], + weights: str | WeightsEnum | None = None, + requires_grad: bool = False, + tracer_kwargs: dict | None = None, + ) -> None: + super().__init__() + if isinstance(backbone, dict): + backbone = BackboneParams(**backbone) + elif isinstance(backbone, str): + backbone = BackboneParams(class_path=backbone) + elif not isinstance(backbone, nn.Module | BackboneParams): + msg = f"backbone needs to be of type str | BackboneParams | dict | nn.Module, but was type {type(backbone)}" + raise TypeError(msg) + + self.feature_extractor = self.initialize_feature_extractor( + backbone, + return_nodes, + weights, + requires_grad, + tracer_kwargs, + ) + + def initialize_feature_extractor( + self, + backbone: BackboneParams | nn.Module, + return_nodes: list[str], + weights: str | WeightsEnum | None = None, + requires_grad: bool = False, + tracer_kwargs: dict | None = None, + ) -> GraphModule: + """Extract features from a CNN. + + Args: + backbone (BackboneParams | nn.Module): The backbone to which the feature extraction hooks are attached. + If the name is provided for BackboneParams, the model is loaded from torchvision. Otherwise, the model + class can be provided and it will try to load the weights from the provided weights file. Last, an + instance of the model can be provided as well, which will be used as-is. + return_nodes (Iterable[str]): List of layer names of the backbone to which the hooks are attached. + You can find the names of these nodes by using ``get_graph_node_names`` function. + weights (str | WeightsEnum | None): Weights enum to use for the model. Torchvision models require + ``WeightsEnum``. These enums are defined in ``torchvision.models.``. You can pass the weights + path for custom models. + requires_grad (bool): Models like ``stfpm`` use the feature extractor for training. In such cases we should + set ``requires_grad`` to ``True``. Default is ``False``. + tracer_kwargs (dict | None): a dictionary of keyword arguments for NodePathTracer (which passes them onto + it's parent class torch.fx.Tracer). Can be used to allow not tracing through a list of problematic + modules, by passing a list of `leaf_modules` as one of the `tracer_kwargs`. + + Returns: + Feature Extractor based on TorchFX. + """ + if isinstance(backbone, nn.Module): + backbone_model = backbone + elif isinstance(backbone.class_path, str): + backbone_class = self._get_backbone_class(backbone.class_path) + backbone_model = backbone_class(weights=weights, **backbone.init_args) + else: + backbone_class = backbone.class_path + backbone_model = backbone_class(**backbone.init_args) + + if isinstance(weights, WeightsEnum): # torchvision models + feature_extractor = create_feature_extractor(model=backbone_model, return_nodes=return_nodes) + elif weights is not None: + if not isinstance(weights, str): + msg = "Weights should point to a path" + raise TypeError(msg) + + model_weights = torch.load(weights) + if "state_dict" in model_weights: + model_weights = model_weights["state_dict"] + backbone_model.load_state_dict(model_weights) + + feature_extractor = create_feature_extractor(backbone_model, return_nodes, tracer_kwargs=tracer_kwargs) + + if not requires_grad: + feature_extractor.eval() + for param in feature_extractor.parameters(): + param.requires_grad_(False) # noqa: FBT003 + + return feature_extractor + + @staticmethod + def _get_backbone_class(backbone: str) -> Callable[..., nn.Module]: + """Get the backbone class from the provided path. + + If only the model name is provided, it will try to load the model from torchvision. + + Example: + >>> from anomalib.models.components.feature_extractors import TorchFXFeatureExtractor + >>> TorchFXFeatureExtractor._get_backbone_class("efficientnet_b5") + torchvision.models.efficientnet.EfficientNet> + + >>> TorchFXFeatureExtractor._get_backbone_class("path.to.CustomModel") + + + Args: + backbone (str): Path to the backbone class. + + Returns: + Backbone class. + """ + try: + if len(backbone.split(".")) > 1: + # assumes that the entire class path is provided + models = importlib.import_module(".".join(backbone.split(".")[:-1])) + backbone_class = getattr(models, backbone.split(".")[-1]) + else: + models = importlib.import_module("torchvision.models") + backbone_class = getattr(models, backbone) + except ModuleNotFoundError as exception: + msg = f"Backbone {backbone} not found in torchvision.models nor in {backbone} module." + raise ModuleNotFoundError( + msg, + ) from exception + + return backbone_class + + def forward(self, inputs: torch.Tensor) -> dict[str, torch.Tensor]: + """Extract features from the input.""" + return self.feature_extractor(inputs) diff --git a/anomalib/models/components/feature_extractors/utils.py b/anomalib/models/components/feature_extractors/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..71e50f7361943536cbc810cbd450a98c72e51bcc --- /dev/null +++ b/anomalib/models/components/feature_extractors/utils.py @@ -0,0 +1,29 @@ +"""Utility functions to manipulate feature extractors.""" + +# Copyright (C) 2022-2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +import torch +from torch.fx.graph_module import GraphModule + +from .timm import TimmFeatureExtractor + + +def dryrun_find_featuremap_dims( + feature_extractor: TimmFeatureExtractor | GraphModule, + input_size: tuple[int, int], + layers: list[str], +) -> dict[str, dict[str, int | tuple[int, int]]]: + """Dry run an empty image of `input_size` size to get the featuremap tensors' dimensions (num_features, resolution). + + Returns: + tuple[int, int]: maping of `layer -> dimensions dict` + Each `dimension dict` has two keys: `num_features` (int) and `resolution`(tuple[int, int]). + """ + device = next(feature_extractor.parameters()).device + dryrun_input = torch.empty(1, 3, *input_size).to(device) + dryrun_features = feature_extractor(dryrun_input) + return { + layer: {"num_features": dryrun_features[layer].shape[1], "resolution": dryrun_features[layer].shape[2:]} + for layer in layers + } diff --git a/anomalib/models/components/filters/__init__.py b/anomalib/models/components/filters/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..2878948759cff466a13f7df1b8a5c07e3516a8be --- /dev/null +++ b/anomalib/models/components/filters/__init__.py @@ -0,0 +1,5 @@ +"""Implements filters used by models.""" + +from .blur import GaussianBlur2d + +__all__ = ["GaussianBlur2d"] diff --git a/anomalib/models/components/filters/blur.py b/anomalib/models/components/filters/blur.py new file mode 100644 index 0000000000000000000000000000000000000000..986214707db7dc047dad60f740b7edbb3ef7cde1 --- /dev/null +++ b/anomalib/models/components/filters/blur.py @@ -0,0 +1,98 @@ +"""Gaussian blurring via pytorch.""" + +# Copyright (C) 2022-2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +import torch +from kornia.filters import get_gaussian_kernel2d +from kornia.filters.filter import _compute_padding +from kornia.filters.kernels import normalize_kernel2d +from torch import nn +from torch.nn import functional as F # noqa: N812 + + +def compute_kernel_size(sigma_val: float) -> int: + """Compute kernel size from sigma value. + + Args: + sigma_val (float): Sigma value. + + Returns: + int: Kernel size. + """ + return 2 * int(4.0 * sigma_val + 0.5) + 1 + + +class GaussianBlur2d(nn.Module): + """Compute GaussianBlur in 2d. + + Makes use of kornia functions, but most notably the kernel is not computed + during the forward pass, and does not depend on the input size. As a caveat, + the number of channels that are expected have to be provided during initialization. + """ + + def __init__( + self, + sigma: float | tuple[float, float], + channels: int = 1, + kernel_size: int | tuple[int, int] | None = None, + normalize: bool = True, + border_type: str = "reflect", + padding: str = "same", + ) -> None: + """Initialize model, setup kernel etc.. + + Args: + sigma (float | tuple[float, float]): standard deviation to use for constructing the Gaussian kernel. + channels (int): channels of the input. Defaults to 1. + kernel_size (int | tuple[int, int] | None): size of the Gaussian kernel to use. Defaults to None. + normalize (bool, optional): Whether to normalize the kernel or not (i.e. all elements sum to 1). + Defaults to True. + border_type (str, optional): Border type to use for padding of the input. Defaults to "reflect". + padding (str, optional): Type of padding to apply. Defaults to "same". + """ + super().__init__() + sigma = sigma if isinstance(sigma, tuple) else (sigma, sigma) + self.channels = channels + + if kernel_size is None: + kernel_size = (compute_kernel_size(sigma[0]), compute_kernel_size(sigma[1])) + else: + kernel_size = kernel_size if isinstance(kernel_size, tuple) else (kernel_size, kernel_size) + + self.kernel: torch.Tensor + self.register_buffer("kernel", get_gaussian_kernel2d(kernel_size=kernel_size, sigma=sigma)) + if normalize: + self.kernel = normalize_kernel2d(self.kernel) + + self.kernel = self.kernel.view(1, 1, *self.kernel.shape[-2:]) + + self.kernel = self.kernel.expand(self.channels, -1, -1, -1) + self.border_type = border_type + self.padding = padding + self.height, self.width = self.kernel.shape[-2:] + self.padding_shape = _compute_padding([self.height, self.width]) + + def forward(self, input_tensor: torch.Tensor) -> torch.Tensor: + """Blur the input with the computed Gaussian. + + Args: + input_tensor (torch.Tensor): Input tensor to be blurred. + + Returns: + Tensor: Blurred output tensor. + """ + batch, channel, height, width = input_tensor.size() + + if self.padding == "same": + input_tensor = F.pad(input_tensor, self.padding_shape, mode=self.border_type) + + # convolve the tensor with the kernel. + output = F.conv2d(input_tensor, self.kernel, groups=self.channels, padding=0, stride=1) + + if self.padding == "same": + out = output.view(batch, channel, height, width) + else: + out = output.view(batch, channel, height - self.height + 1, width - self.width + 1) + + return out diff --git a/anomalib/models/components/flow/__init__.py b/anomalib/models/components/flow/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..dca2e7b9e620820c64f17dbec2cd06ae548ef7cf --- /dev/null +++ b/anomalib/models/components/flow/__init__.py @@ -0,0 +1,8 @@ +"""All In One Block Layer.""" + +# Copyright (C) 2022 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +from .all_in_one_block import AllInOneBlock + +__all__ = ["AllInOneBlock"] diff --git a/anomalib/models/components/flow/all_in_one_block.py b/anomalib/models/components/flow/all_in_one_block.py new file mode 100644 index 0000000000000000000000000000000000000000..0f517c2552a2e84ba527b32d72ae387dc87df603 --- /dev/null +++ b/anomalib/models/components/flow/all_in_one_block.py @@ -0,0 +1,343 @@ +"""All In One Block Layer.""" + +# Copyright (c) https://github.com/vislearn/FrEIA +# SPDX-License-Identifier: MIT + +# Copyright (C) 2022-2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + + +import logging +from collections.abc import Callable +from typing import Any + +import torch +from FrEIA.modules import InvertibleModule +from scipy.stats import special_ortho_group +from torch import nn +from torch.nn import functional as F # noqa: N812 + +logger = logging.getLogger(__name__) + + +def _global_scale_sigmoid_activation(input_tensor: torch.Tensor) -> torch.Tensor: + """Global scale sigmoid activation. + + Args: + input_tensor (torch.Tensor): Input tensor + + Returns: + Tensor: Sigmoid activation + """ + return 10 * torch.sigmoid(input_tensor - 2.0) + + +def _global_scale_softplus_activation(input_tensor: torch.Tensor) -> torch.Tensor: + """Global scale softplus activation. + + Args: + input_tensor (torch.Tensor): Input tensor + + Returns: + Tensor: Softplus activation + """ + softplus = nn.Softplus(beta=0.5) + return 0.1 * softplus(input_tensor) + + +def _global_scale_exp_activation(input_tensor: torch.Tensor) -> torch.Tensor: + """Global scale exponential activation. + + Args: + input_tensor (torch.Tensor): Input tensor + + Returns: + Tensor: Exponential activation + """ + return torch.exp(input_tensor) + + +class AllInOneBlock(InvertibleModule): + r"""Module combining the most common operations in a normalizing flow or similar model. + + It combines affine coupling, permutation, and global affine transformation + ('ActNorm'). It can also be used as GIN coupling block, perform learned + householder permutations, and use an inverted pre-permutation. The affine + transformation includes a soft clamping mechanism, first used in Real-NVP. + The block as a whole performs the following computation: + + .. math:: + + y = V R \; \Psi(s_\mathrm{global}) \odot \mathrm{Coupling}\Big(R^{-1} V^{-1} x\Big)+ t_\mathrm{global} + + - The inverse pre-permutation of x (i.e. :math:`R^{-1} V^{-1}`) is optional (see + ``reverse_permutation`` below). + - The learned householder reflection matrix + :math:`V` is also optional all together (see ``learned_householder_permutation`` + below). + - For the coupling, the input is split into :math:`x_1, x_2` along + the channel dimension. Then the output of the coupling operation is the + two halves :math:`u = \mathrm{concat}(u_1, u_2)`. + + .. math:: + + u_1 &= x_1 \odot \exp \Big( \alpha \; \mathrm{tanh}\big( s(x_2) \big)\Big) + t(x_2) \\ + u_2 &= x_2 + + Because :math:`\mathrm{tanh}(s) \in [-1, 1]`, this clamping mechanism prevents + exploding values in the exponential. The hyperparameter :math:`\alpha` can be adjusted. + + Args: + subnet_constructor: class or callable ``f``, called as ``f(channels_in, channels_out)`` and + should return a torch.nn.Module. Predicts coupling coefficients :math:`s, t`. + affine_clamping: clamp the output of the multiplicative coefficients before + exponentiation to +/- ``affine_clamping`` (see :math:`\alpha` above). + gin_block: Turn the block into a GIN block from Sorrenson et al, 2019. + Makes it so that the coupling operations as a whole is volume preserving. + global_affine_init: Initial value for the global affine scaling :math:`s_\mathrm{global}`. + global_affine_init: ``'SIGMOID'``, ``'SOFTPLUS'``, or ``'EXP'``. Defines the activation to be used + on the beta for the global affine scaling (:math:`\Psi` above). + permute_soft: bool, whether to sample the permutation matrix :math:`R` from :math:`SO(N)`, + or to use hard permutations instead. Note, ``permute_soft=True`` is very slow + when working with >512 dimensions. + learned_householder_permutation: Int, if >0, turn on the matrix :math:`V` above, that represents + multiple learned householder reflections. Slow if large number. + Dubious whether it actually helps network performance. + reverse_permutation: Reverse the permutation before the block, as introduced by Putzky + et al, 2019. Turns on the :math:`R^{-1} V^{-1}` pre-multiplication above. + """ + + def __init__( + self, + dims_in: list[tuple[int]], + dims_c: list[tuple[int]] | None = None, + subnet_constructor: Callable | None = None, + affine_clamping: float = 2.0, + gin_block: bool = False, + global_affine_init: float = 1.0, + global_affine_type: str = "SOFTPLUS", + permute_soft: bool = False, + learned_householder_permutation: int = 0, + reverse_permutation: bool = False, + ) -> None: + if dims_c is None: + dims_c = [] + super().__init__(dims_in, dims_c) + + channels = dims_in[0][0] + # rank of the tensors means 1d, 2d, 3d tensor etc. + self.input_rank = len(dims_in[0]) - 1 + # tuple containing all dims except for batch-dim (used at various points) + self.sum_dims = tuple(range(1, 2 + self.input_rank)) + + if len(dims_c) == 0: + self.conditional = False + self.condition_channels = 0 + else: + if tuple(dims_c[0][1:]) != tuple(dims_in[0][1:]): + msg = f"Dimensions of input and condition don't agree: {dims_c} vs {dims_in}." + raise ValueError(msg) + + self.conditional = True + self.condition_channels = sum(dc[0] for dc in dims_c) + + split_len1 = channels - channels // 2 + split_len2 = channels // 2 + self.splits = [split_len1, split_len2] + + try: + self.permute_function = {0: F.linear, 1: F.conv1d, 2: F.conv2d, 3: F.conv3d}[self.input_rank] + except KeyError: + msg = f"Data is {1 + self.input_rank}D. Must be 1D-4D." + raise ValueError(msg) from None + + self.in_channels = channels + self.clamp = affine_clamping + self.GIN = gin_block + self.reverse_pre_permute = reverse_permutation + self.householder = learned_householder_permutation + + if permute_soft and channels > 512: + msg = ( + "Soft permutation will take a very long time to initialize " + f"with {channels} feature channels. Consider using hard permutation instead." + ) + logger.warning(msg) + + # global_scale is used as the initial value for the global affine scale + # (pre-activation). It is computed such that + # the 'magic numbers' (specifically for sigmoid) scale the activation to + # a sensible range. + if global_affine_type == "SIGMOID": + global_scale = 2.0 - torch.log(torch.tensor([10.0 / global_affine_init - 1.0])) + self.global_scale_activation = _global_scale_sigmoid_activation + elif global_affine_type == "SOFTPLUS": + global_scale = 2.0 * torch.log(torch.exp(torch.tensor(0.5 * 10.0 * global_affine_init)) - 1) + self.global_scale_activation = _global_scale_softplus_activation + elif global_affine_type == "EXP": + global_scale = torch.log(torch.tensor(global_affine_init)) + self.global_scale_activation = _global_scale_exp_activation + else: + message = 'Global affine activation must be "SIGMOID", "SOFTPLUS" or "EXP"' + raise ValueError(message) + + self.global_scale = nn.Parameter(torch.ones(1, self.in_channels, *([1] * self.input_rank)) * global_scale) + self.global_offset = nn.Parameter(torch.zeros(1, self.in_channels, *([1] * self.input_rank))) + + if permute_soft: + w = special_ortho_group.rvs(channels) + else: + indices = torch.randperm(channels) + w = torch.zeros((channels, channels)) + w[torch.arange(channels), indices] = 1.0 + + if self.householder: + # instead of just the permutation matrix w, the learned housholder + # permutation keeps track of reflection vectors vk, in addition to a + # random initial permutation w_0. + self.vk_householder = nn.Parameter(0.2 * torch.randn(self.householder, channels), requires_grad=True) + self.w_perm = None + self.w_perm_inv = None + self.w_0 = nn.Parameter(torch.FloatTensor(w), requires_grad=False) + else: + self.w_perm = nn.Parameter( + torch.FloatTensor(w).view(channels, channels, *([1] * self.input_rank)), + requires_grad=False, + ) + self.w_perm_inv = nn.Parameter( + torch.FloatTensor(w.T).view(channels, channels, *([1] * self.input_rank)), + requires_grad=False, + ) + + if subnet_constructor is None: + message = "Please supply a callable subnet_constructor function or object (see docstring)" + raise ValueError(message) + self.subnet = subnet_constructor(self.splits[0] + self.condition_channels, 2 * self.splits[1]) + self.last_jac = None + + def _construct_householder_permutation(self) -> torch.Tensor: + """Compute a permutation matrix from the reflection vectors that are learned internally as nn.Parameters.""" + w = self.w_0 + for vk in self.vk_householder: + w = torch.mm(w, torch.eye(self.in_channels).to(w.device) - 2 * torch.ger(vk, vk) / torch.dot(vk, vk)) + + for _ in range(self.input_rank): + w = w.unsqueeze(-1) + return w + + def _permute(self, x: torch.Tensor, rev: bool = False) -> tuple[Any, float | torch.Tensor]: + """Perform the permutation and scaling after the coupling operation. + + Returns transformed outputs and the LogJacDet of the scaling operation. + + Args: + x (torch.Tensor): Input tensor + rev (bool, optional): Reverse the permutation. Defaults to False. + + Returns: + tuple[Any, float | torch.Tensor]: Transformed outputs and the LogJacDet of the scaling operation. + """ + if self.GIN: + scale = 1.0 + perm_log_jac = 0.0 + else: + scale = self.global_scale_activation(self.global_scale) + perm_log_jac = torch.sum(torch.log(scale)) + + if rev: + return ((self.permute_function(x, self.w_perm_inv) - self.global_offset) / scale, perm_log_jac) + + return (self.permute_function(x * scale + self.global_offset, self.w_perm), perm_log_jac) + + def _pre_permute(self, x: torch.Tensor, rev: bool = False) -> torch.Tensor: + """Permute before the coupling block. + + It is only used if reverse_permutation is set. + """ + if rev: + return self.permute_function(x, self.w_perm) + + return self.permute_function(x, self.w_perm_inv) + + def _affine(self, x: torch.Tensor, a: torch.Tensor, rev: bool = False) -> tuple[Any, torch.Tensor]: + """Perform affine coupling operation. + + Given the passive half, and the pre-activation outputs of the + coupling subnetwork, perform the affine coupling operation. + Returns both the transformed inputs and the LogJacDet. + """ + # the entire coupling coefficient tensor is scaled down by a + # factor of ten for stability and easier initialization. + a *= 0.1 + ch = x.shape[1] + + sub_jac = self.clamp * torch.tanh(a[:, :ch]) + if self.GIN: + sub_jac -= torch.mean(sub_jac, dim=self.sum_dims, keepdim=True) + + if not rev: + return (x * torch.exp(sub_jac) + a[:, ch:], torch.sum(sub_jac, dim=self.sum_dims)) + + return ((x - a[:, ch:]) * torch.exp(-sub_jac), -torch.sum(sub_jac, dim=self.sum_dims)) + + def forward( + self, + x: torch.Tensor, + c: list | None = None, + rev: bool = False, + jac: bool = True, + ) -> tuple[tuple[torch.Tensor], torch.Tensor]: + """See base class docstring.""" + del jac # Unused argument. + + if c is None: + c = [] + + if self.householder: + self.w_perm = self._construct_householder_permutation() + if rev or self.reverse_pre_permute: + self.w_perm_inv = self.w_perm.transpose(0, 1).contiguous() + + if rev: + x, global_scaling_jac = self._permute(x[0], rev=True) + x = (x,) + elif self.reverse_pre_permute: + x = (self._pre_permute(x[0], rev=False),) + + x1, x2 = torch.split(x[0], self.splits, dim=1) + + x1c = torch.cat([x1, *c], 1) if self.conditional else x1 + + if not rev: + a1 = self.subnet(x1c) + x2, j2 = self._affine(x2, a1) + else: + a1 = self.subnet(x1c) + x2, j2 = self._affine(x2, a1, rev=True) + + log_jac_det = j2 + x_out = torch.cat((x1, x2), 1) + + if not rev: + x_out, global_scaling_jac = self._permute(x_out, rev=False) + elif self.reverse_pre_permute: + x_out = self._pre_permute(x_out, rev=True) + + # add the global scaling Jacobian to the total. + # trick to get the total number of non-channel dimensions: + # number of elements of the first channel of the first batch member + n_pixels = x_out[0, :1].numel() + log_jac_det += (-1) ** rev * n_pixels * global_scaling_jac + + return (x_out,), log_jac_det + + def output_dims(self, input_dims: list[tuple[int]]) -> list[tuple[int]]: + """Output dimensions of the layer. + + Args: + input_dims (list[tuple[int]]): Input dimensions. + + Returns: + list[tuple[int]]: Output dimensions. + """ + return input_dims diff --git a/anomalib/models/components/layers/__init__.py b/anomalib/models/components/layers/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..b2937cfe0c7b8d52e2b06846193af8d2db73f9c0 --- /dev/null +++ b/anomalib/models/components/layers/__init__.py @@ -0,0 +1,8 @@ +"""Neural network layers.""" + +# Copyright (C) 2022 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +from .sspcab import SSPCAB + +__all__ = ["SSPCAB"] diff --git a/anomalib/models/components/layers/sspcab.py b/anomalib/models/components/layers/sspcab.py new file mode 100644 index 0000000000000000000000000000000000000000..ee8ce4e8b56aaa3fe6f514b597bbee26d0875065 --- /dev/null +++ b/anomalib/models/components/layers/sspcab.py @@ -0,0 +1,78 @@ +"""SSPCAB: Self-Supervised Predictive Convolutional Attention Block for reconstruction-based models. + +Paper https://arxiv.org/abs/2111.09099 +""" + +# Copyright (C) 2022-2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +import torch +from torch import nn +from torch.nn import functional as F # noqa: N812 + + +class AttentionModule(nn.Module): + """Squeeze and excitation block that acts as the attention module in SSPCAB. + + Args: + channels (int): Number of input channels. + reduction_ratio (int): Reduction ratio of the attention module. + """ + + def __init__(self, in_channels: int, reduction_ratio: int = 8) -> None: + super().__init__() + + out_channels = in_channels // reduction_ratio + self.fc1 = nn.Linear(in_channels, out_channels) + self.fc2 = nn.Linear(out_channels, in_channels) + + def forward(self, inputs: torch.Tensor) -> torch.Tensor: + """Forward pass through the attention module.""" + # reduce feature map to 1d vector through global average pooling + avg_pooled = inputs.mean(dim=(2, 3)) + + # squeeze and excite + act = self.fc1(avg_pooled) + act = F.relu(act) + act = self.fc2(act) + act = F.sigmoid(act) + + # multiply with input + return inputs * act.view(act.shape[0], act.shape[1], 1, 1) + + +class SSPCAB(nn.Module): + """SSPCAB block. + + Args: + in_channels (int): Number of input channels. + kernel_size (int): Size of the receptive fields of the masked convolution kernel. + dilation (int): Dilation factor of the masked convolution kernel. + reduction_ratio (int): Reduction ratio of the attention module. + """ + + def __init__(self, in_channels: int, kernel_size: int = 1, dilation: int = 1, reduction_ratio: int = 8) -> None: + super().__init__() + + self.pad = kernel_size + dilation + self.crop = kernel_size + 2 * dilation + 1 + + self.masked_conv1 = nn.Conv2d(in_channels=in_channels, out_channels=in_channels, kernel_size=kernel_size) + self.masked_conv2 = nn.Conv2d(in_channels=in_channels, out_channels=in_channels, kernel_size=kernel_size) + self.masked_conv3 = nn.Conv2d(in_channels=in_channels, out_channels=in_channels, kernel_size=kernel_size) + self.masked_conv4 = nn.Conv2d(in_channels=in_channels, out_channels=in_channels, kernel_size=kernel_size) + + self.attention_module = AttentionModule(in_channels=in_channels, reduction_ratio=reduction_ratio) + + def forward(self, inputs: torch.Tensor) -> torch.Tensor: + """Forward pass through the SSPCAB block.""" + # compute masked convolution + padded = F.pad(inputs, (self.pad,) * 4) + masked_out = torch.zeros_like(inputs) + masked_out += self.masked_conv1(padded[..., : -self.crop, : -self.crop]) + masked_out += self.masked_conv2(padded[..., : -self.crop, self.crop :]) + masked_out += self.masked_conv3(padded[..., self.crop :, : -self.crop]) + masked_out += self.masked_conv4(padded[..., self.crop :, self.crop :]) + + # apply channel attention module + return self.attention_module(masked_out) diff --git a/anomalib/models/components/sampling/__init__.py b/anomalib/models/components/sampling/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..47c842123f0bb8ff49b77cb0581016dcb57f789e --- /dev/null +++ b/anomalib/models/components/sampling/__init__.py @@ -0,0 +1,8 @@ +"""Sampling methods.""" + +# Copyright (C) 2022 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +from .k_center_greedy import KCenterGreedy + +__all__ = ["KCenterGreedy"] diff --git a/anomalib/models/components/sampling/k_center_greedy.py b/anomalib/models/components/sampling/k_center_greedy.py new file mode 100644 index 0000000000000000000000000000000000000000..788f2e66834f61a607dfd9779bf59b8a8b98b6c1 --- /dev/null +++ b/anomalib/models/components/sampling/k_center_greedy.py @@ -0,0 +1,130 @@ +"""k-Center Greedy Method. + +Returns points that minimizes the maximum distance of any point to a center. +- https://arxiv.org/abs/1708.00489 +""" + +# Copyright (C) 2022-2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +import torch +from rich.progress import track +from torch.nn import functional as F # noqa: N812 + +from anomalib.models.components.dimensionality_reduction import SparseRandomProjection + + +class KCenterGreedy: + """Implements k-center-greedy method. + + Args: + embedding (torch.Tensor): Embedding vector extracted from a CNN + sampling_ratio (float): Ratio to choose coreset size from the embedding size. + + Example: + >>> embedding.shape + torch.Size([219520, 1536]) + >>> sampler = KCenterGreedy(embedding=embedding) + >>> sampled_idxs = sampler.select_coreset_idxs() + >>> coreset = embedding[sampled_idxs] + >>> coreset.shape + torch.Size([219, 1536]) + """ + + def __init__(self, embedding: torch.Tensor, sampling_ratio: float) -> None: + self.embedding = embedding + self.coreset_size = int(embedding.shape[0] * sampling_ratio) + self.model = SparseRandomProjection(eps=0.9) + + self.features: torch.Tensor + self.min_distances: torch.Tensor = None + self.n_observations = self.embedding.shape[0] + + def reset_distances(self) -> None: + """Reset minimum distances.""" + self.min_distances = None + + def update_distances(self, cluster_centers: list[int]) -> None: + """Update min distances given cluster centers. + + Args: + cluster_centers (list[int]): indices of cluster centers + """ + if cluster_centers: + centers = self.features[cluster_centers] + + distance = F.pairwise_distance(self.features, centers, p=2).reshape(-1, 1) + + if self.min_distances is None: + self.min_distances = distance + else: + self.min_distances = torch.minimum(self.min_distances, distance) + + def get_new_idx(self) -> int: + """Get index value of a sample. + + Based on minimum distance of the cluster + + Returns: + int: Sample index + """ + if isinstance(self.min_distances, torch.Tensor): + idx = int(torch.argmax(self.min_distances).item()) + else: + msg = f"self.min_distances must be of type Tensor. Got {type(self.min_distances)}" + raise TypeError(msg) + + return idx + + def select_coreset_idxs(self, selected_idxs: list[int] | None = None) -> list[int]: + """Greedily form a coreset to minimize the maximum distance of a cluster. + + Args: + selected_idxs: index of samples already selected. Defaults to an empty set. + + Returns: + indices of samples selected to minimize distance to cluster centers + """ + if selected_idxs is None: + selected_idxs = [] + + if self.embedding.ndim == 2: + self.model.fit(self.embedding) + self.features = self.model.transform(self.embedding) + self.reset_distances() + else: + self.features = self.embedding.reshape(self.embedding.shape[0], -1) + self.update_distances(cluster_centers=selected_idxs) + + selected_coreset_idxs: list[int] = [] + idx = int(torch.randint(high=self.n_observations, size=(1,)).item()) + for _ in track(range(self.coreset_size), description="Selecting Coreset Indices."): + self.update_distances(cluster_centers=[idx]) + idx = self.get_new_idx() + if idx in selected_idxs: + msg = "New indices should not be in selected indices." + raise ValueError(msg) + self.min_distances[idx] = 0 + selected_coreset_idxs.append(idx) + + return selected_coreset_idxs + + def sample_coreset(self, selected_idxs: list[int] | None = None) -> torch.Tensor: + """Select coreset from the embedding. + + Args: + selected_idxs: index of samples already selected. Defaults to an empty set. + + Returns: + Tensor: Output coreset + + Example: + >>> embedding.shape + torch.Size([219520, 1536]) + >>> sampler = KCenterGreedy(...) + >>> coreset = sampler.sample_coreset() + >>> coreset.shape + torch.Size([219, 1536]) + """ + idxs = self.select_coreset_idxs(selected_idxs) + return self.embedding[idxs] diff --git a/anomalib/models/components/stats/__init__.py b/anomalib/models/components/stats/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..c65aef1caf900de651d4179297f4f888761b19b3 --- /dev/null +++ b/anomalib/models/components/stats/__init__.py @@ -0,0 +1,9 @@ +"""Statistical functions.""" + +# Copyright (C) 2022 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +from .kde import GaussianKDE +from .multi_variate_gaussian import MultiVariateGaussian + +__all__ = ["GaussianKDE", "MultiVariateGaussian"] diff --git a/anomalib/models/components/stats/kde.py b/anomalib/models/components/stats/kde.py new file mode 100644 index 0000000000000000000000000000000000000000..da9d4da5785a546ca91e2f6eed7fa10fce344df2 --- /dev/null +++ b/anomalib/models/components/stats/kde.py @@ -0,0 +1,95 @@ +"""Gaussian Kernel Density Estimation.""" + +# Copyright (C) 2022-2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + + +import math + +import torch + +from anomalib.models.components.base import DynamicBufferMixin + + +class GaussianKDE(DynamicBufferMixin): + """Gaussian Kernel Density Estimation. + + Args: + dataset (Tensor | None, optional): Dataset on which to fit the KDE model. Defaults to None. + """ + + def __init__(self, dataset: torch.Tensor | None = None) -> None: + super().__init__() + + if dataset is not None: + self.fit(dataset) + + self.register_buffer("bw_transform", torch.Tensor()) + self.register_buffer("dataset", torch.Tensor()) + self.register_buffer("norm", torch.Tensor()) + + self.bw_transform = torch.Tensor() + self.dataset = torch.Tensor() + self.norm = torch.Tensor() + + def forward(self, features: torch.Tensor) -> torch.Tensor: + """Get the KDE estimates from the feature map. + + Args: + features (torch.Tensor): Feature map extracted from the CNN + + Returns: KDE Estimates + """ + features = torch.matmul(features, self.bw_transform) + + estimate = torch.zeros(features.shape[0]).to(features.device) + for i in range(features.shape[0]): + embedding = ((self.dataset - features[i]) ** 2).sum(dim=1) + embedding = torch.exp(-embedding / 2) * self.norm + estimate[i] = torch.mean(embedding) + + return estimate + + def fit(self, dataset: torch.Tensor) -> None: + """Fit a KDE model to the input dataset. + + Args: + dataset (torch.Tensor): Input dataset. + + Returns: + None + """ + num_samples, dimension = dataset.shape + + # compute scott's bandwidth factor + factor = num_samples ** (-1 / (dimension + 4)) + + cov_mat = self.cov(dataset.T) + inv_cov_mat = torch.linalg.inv(cov_mat) + inv_cov = inv_cov_mat / factor**2 + + # transform data to account for bandwidth + bw_transform = torch.linalg.cholesky(inv_cov) + dataset = torch.matmul(dataset, bw_transform) + + # + norm = torch.prod(torch.diag(bw_transform)) + norm *= math.pow((2 * math.pi), (-dimension / 2)) + + self.bw_transform = bw_transform + self.dataset = dataset + self.norm = norm + + @staticmethod + def cov(tensor: torch.Tensor) -> torch.Tensor: + """Calculate the unbiased covariance matrix. + + Args: + tensor (torch.Tensor): Input tensor from which covariance matrix is computed. + + Returns: + Output covariance matrix. + """ + mean = torch.mean(tensor, dim=1) + tensor -= mean[:, None] + return torch.matmul(tensor, tensor.T) / (tensor.size(1) - 1) diff --git a/anomalib/models/components/stats/multi_variate_gaussian.py b/anomalib/models/components/stats/multi_variate_gaussian.py new file mode 100644 index 0000000000000000000000000000000000000000..b05edfb827755e0497a2d24303ac96d4d7e03567 --- /dev/null +++ b/anomalib/models/components/stats/multi_variate_gaussian.py @@ -0,0 +1,136 @@ +"""Multi Variate Gaussian Distribution.""" + +# Copyright (C) 2022-2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +from typing import Any + +import torch +from torch import nn + +from anomalib.models.components.base import DynamicBufferMixin + + +class MultiVariateGaussian(DynamicBufferMixin, nn.Module): + """Multi Variate Gaussian Distribution.""" + + def __init__(self) -> None: + super().__init__() + + self.register_buffer("mean", torch.empty(0)) + self.register_buffer("inv_covariance", torch.empty(0)) + + self.mean: torch.Tensor + self.inv_covariance: torch.Tensor + + @staticmethod + def _cov( + observations: torch.Tensor, + rowvar: bool = False, + bias: bool = False, + ddof: int | None = None, + aweights: torch.Tensor | None = None, + ) -> torch.Tensor: + """Estimates covariance matrix like numpy.cov. + + Args: + observations (torch.Tensor): A 1-D or 2-D array containing multiple variables and observations. + Each row of `m` represents a variable, and each column a single + observation of all those variables. Also see `rowvar` below. + rowvar (bool): If `rowvar` is True (default), then each row represents a + variable, with observations in the columns. Otherwise, the relationship + is transposed: each column represents a variable, while the rows + contain observations. Defaults to False. + bias (bool): Default normalization (False) is by ``(N - 1)``, where ``N`` is the + number of observations given (unbiased estimate). If `bias` is True, + then normalization is by ``N``. These values can be overridden by using + the keyword ``ddof`` in numpy versions >= 1.5. Defaults to False + ddof (int | None): If not ``None`` the default value implied by `bias` is overridden. + Note that ``ddof=1`` will return the unbiased estimate, even if both + `fweights` and `aweights` are specified, and ``ddof=0`` will return + the simple average. See the notes for the details. The default value + is ``None``. + aweights (torch.Tensor): 1-D array of observation vector weights. These relative weights are + typically large for observations considered "important" and smaller for + observations considered less "important". If ``ddof=0`` the array of + weights can be used to assign probabilities to observation vectors. (Default value = None) + + + Returns: + The covariance matrix of the variables. + """ + # ensure at least 2D + if observations.dim() == 1: + observations = observations.view(-1, 1) + + # treat each column as a data point, each row as a variable + if rowvar and observations.shape[0] != 1: + observations = observations.t() + + if ddof is None: + ddof = 1 if bias == 0 else 0 + + weights = aweights + weights_sum: Any + + if weights is not None: + if not torch.is_tensor(weights): + weights = torch.tensor(weights, dtype=torch.float) # pylint: disable=not-callable + weights_sum = torch.sum(weights) + avg = torch.sum(observations * (weights / weights_sum)[:, None], 0) + else: + avg = torch.mean(observations, 0) + + # Determine the normalization + if weights is None: + fact = observations.shape[0] - ddof + elif ddof == 0: + fact = weights_sum + elif aweights is None: + fact = weights_sum - ddof + else: + fact = weights_sum - ddof * torch.sum(weights * weights) / weights_sum + + observations_m = observations.sub(avg.expand_as(observations)) + + x_transposed = observations_m.t() if weights is None else torch.mm(torch.diag(weights), observations_m).t() + + covariance = torch.mm(x_transposed, observations_m) + covariance = covariance / fact + + return covariance.squeeze() + + def forward(self, embedding: torch.Tensor) -> list[torch.Tensor]: + """Calculate multivariate Gaussian distribution. + + Args: + embedding (torch.Tensor): CNN features whose dimensionality is reduced via either random sampling or PCA. + + Returns: + mean and inverse covariance of the multi-variate gaussian distribution that fits the features. + """ + device = embedding.device + + batch, channel, height, width = embedding.size() + embedding_vectors = embedding.view(batch, channel, height * width) + self.mean = torch.mean(embedding_vectors, dim=0) + covariance = torch.zeros(size=(channel, channel, height * width), device=device) + identity = torch.eye(channel).to(device) + for i in range(height * width): + covariance[:, :, i] = self._cov(embedding_vectors[:, :, i], rowvar=False) + 0.01 * identity + + # calculate inverse covariance as we need only the inverse + self.inv_covariance = torch.linalg.inv(covariance.permute(2, 0, 1)) + + return [self.mean, self.inv_covariance] + + def fit(self, embedding: torch.Tensor) -> list[torch.Tensor]: + """Fit multi-variate gaussian distribution to the input embedding. + + Args: + embedding (torch.Tensor): Embedding vector extracted from CNN. + + Returns: + Mean and the covariance of the embedding. + """ + return self.forward(embedding) diff --git a/anomalib/models/image/README.md b/anomalib/models/image/README.md new file mode 100644 index 0000000000000000000000000000000000000000..683a102bc280dfa3374b139b694e4de5d96b856f --- /dev/null +++ b/anomalib/models/image/README.md @@ -0,0 +1,61 @@ +# Anomalib Image Models + +## 📝 Description + +This sub-package contains the models for handling image datasets in anomalib. + +The anomalib.models.image subpackage provides: + +- Classes and functions to define image anomaly models. +- Models for image-based anomaly classification, detection or segmentation. + +## ⚠️ Note + +The models in anomalib.models.image can also handle video datasets by converting them to frame-based image datasets. +This feature allows the application of the same models and techniques to video data. + +## 💡 Examples + +
+Using the EfficientAD model on an Image Dataset such as Visa + +```python +# Import the necessary modules +from anomalib.data import Visa +from anomalib.models import EfficientAD +from anomalib.engine import Engine + +# Load the ViSA dataset, model and engine. +datamodule = Visa() +model = EfficientAD() +engine = Engine() + +# Train the model +engine.train(model, datamodule) +``` + +
+ +
+Using the EfficientAD model on a Video Dataset such as Avenue + +To use an image model to train on a video dataset, we need to convert the video dataset to a frame-based image dataset. To do this, we could use `clip_length_in_frames=1` when loading the dataset. + +```python +# Import the necessary modules +from anomalib.data import Avenue +from anomalib.models import EfficientAD +from anomalib.engine import Engine + +# Load the folder, model and engine. +# Set the clip_length_in_frames to 1 to convert the video dataset to a +# frame-based image dataset. +datamodule = Avenue(clip_length_in_frames=1) +model = EfficientAD() +engine = Engine() + +# Train the model +engine.train(model, datamodule) +``` + +
diff --git a/anomalib/models/image/__init__.py b/anomalib/models/image/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..8478747f014a91b7b30e66c55a197194789178a6 --- /dev/null +++ b/anomalib/models/image/__init__.py @@ -0,0 +1,42 @@ +"""Anomalib Image Models.""" + +# Copyright (C) 2023 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +from .cfa import Cfa +from .cflow import Cflow +from .csflow import Csflow +from .dfkde import Dfkde +from .dfm import Dfm +from .draem import Draem +from .dsr import Dsr +from .efficient_ad import EfficientAd +from .fastflow import Fastflow +from .ganomaly import Ganomaly +from .padim import Padim +from .patchcore import Patchcore +from .reverse_distillation import ReverseDistillation +from .rkde import Rkde +from .stfpm import Stfpm +from .uflow import Uflow +from .winclip import WinClip + +__all__ = [ + "Cfa", + "Cflow", + "Csflow", + "Dfkde", + "Dfm", + "Draem", + "Dsr", + "EfficientAd", + "Fastflow", + "Ganomaly", + "Padim", + "Patchcore", + "ReverseDistillation", + "Rkde", + "Stfpm", + "Uflow", + "WinClip", +] diff --git a/anomalib/models/image/cfa/README.md b/anomalib/models/image/cfa/README.md new file mode 100755 index 0000000000000000000000000000000000000000..d174bfe26f0a31f3632f1538e22348bdffaf0a33 --- /dev/null +++ b/anomalib/models/image/cfa/README.md @@ -0,0 +1,161 @@ +# CFA: Coupled-hypersphere-based Feature Adaptation for Target-Oriented Anomaly Localization + +[![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/cfa-coupled-hypersphere-based-feature/anomaly-detection-on-mvtec-ad)](https://paperswithcode.com/sota/anomaly-detection-on-mvtec-ad?p=cfa-coupled-hypersphere-based-feature) + +This is the implementation of the [CFA](https://arxiv.org/abs/2206.04325) paper. The original implementation could be found [sungwool/cfa_for_anomaly_localization](https://github.com/sungwool/cfa_for_anomaly_localization). + +Model Type: Segmentation + +## Description + +Coupled-hypersphere-based Feature Adaptation (CFA) localizes anomalies using features adapted to the target dataset. CFA consists of (1) a learnable patch descriptor that learns and embeds target-oriented features and (2) a scalable memory bank independent of the size of the target dataset. By applying a patch descriptor and memory bank to a pretrained CNN, CFA also employs transfer learning to increase the normal feature density so that abnormal features can be easily distinguished. + +## Architecture + +![Cfa Architecture](/docs/source/images/cfa/architecture.png "Cfa Architecture") + +## Usage + +`python tools/train.py --model cfa` + +## Benchmark + +All results gathered with seed `0`. + +## [MVTec AD Dataset](https://www.mvtec.com/company/research/datasets/mvtec-ad) + +--- + +### NOTE + +When the numbers are produced, early stopping callback (patience: 5) is used. It might be possible to achieve higher-metrics by increasing the patience. + +--- + +### Image-Level AUC + +| | ResNet-18 | Wide ResNet50 | +| ---------- | :-------: | :-----------: | +| Bottle | 0.991 | 0.998 | +| Cable | 0.947 | 0.979 | +| Capsule | 0.858 | 0.872 | +| Carpet | 0.953 | 0.978 | +| Grid | 0.947 | 0.961 | +| Hazelnut | 0.995 | 1.000 | +| Leather | 0.999 | 0.990 | +| Metal_nut | 0.932 | 0.995 | +| Pill | 0.887 | 0.946 | +| Screw | 0.625 | 0.703 | +| Tile | 1.000 | 0.999 | +| Toothbrush | 0.994 | 1.000 | +| Transistor | 0.895 | 0.957 | +| Wood | 1.000 | 0.994 | +| Zipper | 0.919 | 0.967 | +| Average | 0.930 | 0.956 | + +### Image F1 Score + +| | ResNet-18 | Wide ResNet50 | +| ---------- | :-------: | :-----------: | +| Bottle | 0.983 | 0.984 | +| Cable | 0.907 | 0.962 | +| Capsule | 0.938 | 0.946 | +| Carpet | 0.956 | 0.961 | +| Grid | 0.946 | 0.957 | +| Hazelnut | 0.996 | 1.000 | +| Leather | 0.995 | 0.973 | +| Metal_nut | 0.958 | 0.984 | +| Pill | 0.920 | 0.952 | +| Screw | 0.858 | 0.855 | +| Tile | 1.000 | 0.994 | +| Toothbrush | 0.984 | 1.000 | +| Transistor | 0.795 | 0.907 | +| Wood | 1.000 | 0.983 | +| Zipper | 0.949 | 0.975 | +| Average | 0.946 | 0.962 | + +### Pixel-Level AUC + +| | ResNet-18 | Wide ResNet50 | +| ---------- | :-------: | :-----------: | +| Bottle | 0.986 | 0.989 | +| Cable | 0.984 | 0.988 | +| Capsule | 0.987 | 0.989 | +| Carpet | 0.970 | 0.980 | +| Grid | 0.973 | 0.954 | +| Hazelnut | 0.987 | 0.985 | +| Leather | 0.992 | 0.989 | +| Metal_nut | 0.981 | 0.992 | +| Pill | 0.981 | 0.988 | +| Screw | 0.973 | 0.979 | +| Tile | 0.978 | 0.985 | +| Toothbrush | 0.990 | 0.991 | +| Transistor | 0.964 | 0.977 | +| Wood | 0.964 | 0.974 | +| Zipper | 0.978 | 0.990 | +| Average | 0.979 | 0.983 | + +### Pixel-Level AUPRO + +| | ResNet-18 | Wide ResNet50 | +| ---------- | :-------: | :-----------: | +| Bottle | 0.940 | 0.947 | +| Cable | 0.902 | 0.940 | +| Capsule | 0.946 | 0.939 | +| Carpet | 0.910 | 0.919 | +| Grid | 0.911 | 0.862 | +| Hazelnut | 0.931 | 0.930 | +| Leather | 0.974 | 0.955 | +| Metal_nut | 0.912 | 0.931 | +| Pill | 0.935 | 0.947 | +| Screw | 0.884 | 0.906 | +| Tile | 0.892 | 0.906 | +| Toothbrush | 0.895 | 0.899 | +| Transistor | 0.895 | 0.930 | +| Wood | 0.898 | 0.893 | +| Zipper | 0.925 | 0.958 | +| Average | 0.917 | 0.924 | + +### Pixel F1 Score + +| | ResNet-18 | Wide ResNet50 | +| ---------- | :-------: | :-----------: | +| Bottle | 0.751 | 0.789 | +| Cable | 0.661 | 0.674 | +| Capsule | 0.507 | 0.500 | +| Carpet | 0.549 | 0.578 | +| Grid | 0.316 | 0.280 | +| Hazelnut | 0.598 | 0.561 | +| Leather | 0.461 | 0.378 | +| Metal_nut | 0.819 | 0.874 | +| Pill | 0.689 | 0.679 | +| Screw | 0.212 | 0.301 | +| Tile | 0.740 | 0.768 | +| Toothbrush | 0.609 | 0.627 | +| Transistor | 0.570 | 0.666 | +| Wood | 0.564 | 0.627 | +| Zipper | 0.561 | 0.668 | +| Average | 0.574 | 0.598 | + +### Sample Results + +![Sample Result 1](/docs/source/images/cfa/results/0.png "Sample Result 1") + +![Sample Result 2](/docs/source/images/cfa/results/1.png "Sample Result 2") + +![Sample Result 3](/docs/source/images/cfa/results/2.png "Sample Result 3") + +## Reference + +[1] + +## Citation + +```tex +@article{lee2022cfa, + title={CFA: Coupled-hypersphere-based Feature Adaptation for Target-Oriented Anomaly Localization}, + author={Lee, Sungwook and Lee, Seunghyun and Song, Byung Cheol}, + journal={arXiv preprint arXiv:2206.04325}, + year={2022} +} +``` diff --git a/anomalib/models/image/cfa/__init__.py b/anomalib/models/image/cfa/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..def95441cbc24108247f60b917c01aba2f3abbb7 --- /dev/null +++ b/anomalib/models/image/cfa/__init__.py @@ -0,0 +1,13 @@ +"""Implementatation of the CFA Model. + +CFA: Coupled-hypersphere-based Feature Adaptation for Target-Oriented Anomaly Localization + +Paper https://arxiv.org/abs/2206.04325 +""" + +# Copyright (C) 2022-2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +from .lightning_model import Cfa + +__all__ = ["Cfa"] diff --git a/anomalib/models/image/cfa/anomaly_map.py b/anomalib/models/image/cfa/anomaly_map.py new file mode 100644 index 0000000000000000000000000000000000000000..216c1b558958f1ac8968ecf5bf7b55ff73ba1515 --- /dev/null +++ b/anomalib/models/image/cfa/anomaly_map.py @@ -0,0 +1,86 @@ +"""Anomaly Map Generator for the CFA model implementation.""" + +# Copyright (C) 2022-2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + + +import torch +from einops import rearrange +from torch import nn +from torch.nn import functional as F # noqa: N812 + +from anomalib.models.components import GaussianBlur2d + + +class AnomalyMapGenerator(nn.Module): + """Generate Anomaly Heatmap.""" + + def __init__( + self, + num_nearest_neighbors: int, + sigma: int = 4, + ) -> None: + super().__init__() + self.num_nearest_neighbors = num_nearest_neighbors + self.sigma = sigma + + def compute_score(self, distance: torch.Tensor, scale: tuple[int, int]) -> torch.Tensor: + """Compute score based on the distance. + + Args: + distance (torch.Tensor): Distance tensor computed using target oriented + features. + scale (tuple[int, int]): Height and width of the largest feature + map. + + Returns: + Tensor: Score value. + """ + distance = torch.sqrt(distance) + distance = distance.topk(self.num_nearest_neighbors, largest=False).values # noqa: PD011 + distance = (F.softmin(distance, dim=-1)[:, :, 0]) * distance[:, :, 0] + distance = distance.unsqueeze(-1) + + score = rearrange(distance, "b (h w) c -> b c h w", h=scale[0], w=scale[1]) + return score.detach() + + def compute_anomaly_map( + self, + score: torch.Tensor, + image_size: tuple[int, int] | torch.Size | None = None, + ) -> torch.Tensor: + """Compute anomaly map based on the score. + + Args: + score (torch.Tensor): Score tensor. + image_size (tuple[int, int] | torch.Size | None, optional): Size of the input image. + + Returns: + Tensor: Anomaly map. + """ + anomaly_map = score.mean(dim=1, keepdim=True) + if image_size is not None: + anomaly_map = F.interpolate(anomaly_map, size=image_size, mode="bilinear", align_corners=False) + + gaussian_blur = GaussianBlur2d(sigma=self.sigma).to(score.device) + return gaussian_blur(anomaly_map) # pylint: disable=not-callable + + def forward(self, **kwargs) -> torch.Tensor: + """Return anomaly map. + + Raises: + ``distance`` and ``scale`` keys are not found. + + Returns: + Tensor: Anomaly heatmap. + """ + if not ("distance" in kwargs and "scale" in kwargs): + msg = f"Expected keys `distance` and `scale. Found {kwargs.keys()}" + raise ValueError(msg) + + distance: torch.Tensor = kwargs["distance"] + scale: tuple[int, int] = kwargs["scale"] + image_size: tuple[int, int] | torch.Size | None = kwargs.get("image_size", None) + + score = self.compute_score(distance=distance, scale=scale) + return self.compute_anomaly_map(score, image_size=image_size) diff --git a/anomalib/models/image/cfa/lightning_model.py b/anomalib/models/image/cfa/lightning_model.py new file mode 100644 index 0000000000000000000000000000000000000000..11d957cacf8da7e032ae59c4d29fbe4ee5e8c66d --- /dev/null +++ b/anomalib/models/image/cfa/lightning_model.py @@ -0,0 +1,147 @@ +"""Lightning Implementatation of the CFA Model. + +CFA: Coupled-hypersphere-based Feature Adaptation for Target-Oriented Anomaly Localization + +Paper https://arxiv.org/abs/2206.04325 +""" + +# Copyright (C) 2022-2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + + +import logging +from typing import Any + +import torch +from lightning.pytorch.utilities.types import STEP_OUTPUT + +from anomalib import LearningType +from anomalib.models.components import AnomalyModule + +from .loss import CfaLoss +from .torch_model import CfaModel + +logger = logging.getLogger(__name__) + +__all__ = ["Cfa"] + + +class Cfa(AnomalyModule): + """CFA: Coupled-hypersphere-based Feature Adaptation for Target-Oriented Anomaly Localization. + + Args: + backbone (str): Backbone CNN network + Defaults to ``"wide_resnet50_2"``. + gamma_c (int, optional): gamma_c value from the paper. + Defaults to ``1``. + gamma_d (int, optional): gamma_d value from the paper. + Defaults to ``1``. + num_nearest_neighbors (int): Number of nearest neighbors. + Defaults to ``3``. + num_hard_negative_features (int): Number of hard negative features. + Defaults to ``3``. + radius (float): Radius of the hypersphere to search the soft boundary. + Defaults to ``1e-5``. + """ + + def __init__( + self, + backbone: str = "wide_resnet50_2", + gamma_c: int = 1, + gamma_d: int = 1, + num_nearest_neighbors: int = 3, + num_hard_negative_features: int = 3, + radius: float = 1e-5, + ) -> None: + super().__init__() + self.model: CfaModel = CfaModel( + backbone=backbone, + gamma_c=gamma_c, + gamma_d=gamma_d, + num_nearest_neighbors=num_nearest_neighbors, + num_hard_negative_features=num_hard_negative_features, + radius=radius, + ) + self.loss = CfaLoss( + num_nearest_neighbors=num_nearest_neighbors, + num_hard_negative_features=num_hard_negative_features, + radius=radius, + ) + + def on_train_start(self) -> None: + """Initialize the centroid for the memory bank computation.""" + self.model.initialize_centroid(data_loader=self.trainer.datamodule.train_dataloader()) + + def training_step(self, batch: dict[str, str | torch.Tensor], *args, **kwargs) -> STEP_OUTPUT: + """Perform the training step for the CFA model. + + Args: + batch (dict[str, str | torch.Tensor]): Batch input. + *args: Arguments. + **kwargs: Keyword arguments. + + Returns: + STEP_OUTPUT: Loss value. + """ + del args, kwargs # These variables are not used. + + distance = self.model(batch["image"]) + loss = self.loss(distance) + return {"loss": loss} + + def validation_step(self, batch: dict[str, str | torch.Tensor], *args, **kwargs) -> STEP_OUTPUT: + """Perform the validation step for the CFA model. + + Args: + batch (dict[str, str | torch.Tensor]): Input batch. + *args: Arguments. + **kwargs: Keyword arguments. + + Returns: + dict: Anomaly map computed by the model. + """ + del args, kwargs # These variables are not used. + + batch["anomaly_maps"] = self.model(batch["image"]) + return batch + + def backward(self, loss: torch.Tensor, *args, **kwargs) -> None: + """Perform backward-pass for the CFA model. + + Args: + loss (torch.Tensor): Loss value. + *args: Arguments. + **kwargs: Keyword arguments. + """ + del args, kwargs # These variables are not used. + + # TODO(samet-akcay): Investigate why retain_graph is needed. + # CVS-122673 + loss.backward(retain_graph=True) + + @property + def trainer_arguments(self) -> dict[str, Any]: + """CFA specific trainer arguments.""" + return {"gradient_clip_val": 0, "num_sanity_val_steps": 0} + + def configure_optimizers(self) -> torch.optim.Optimizer: + """Configure optimizers for the CFA Model. + + Returns: + Optimizer: Adam optimizer for each decoder + """ + return torch.optim.AdamW( + params=self.model.parameters(), + lr=1e-3, + weight_decay=5e-4, + amsgrad=True, + ) + + @property + def learning_type(self) -> LearningType: + """Return the learning type of the model. + + Returns: + LearningType: Learning type of the model. + """ + return LearningType.ONE_CLASS diff --git a/anomalib/models/image/cfa/loss.py b/anomalib/models/image/cfa/loss.py new file mode 100644 index 0000000000000000000000000000000000000000..c4f8d48eafedde48e181f22f95b07a57f2993ae7 --- /dev/null +++ b/anomalib/models/image/cfa/loss.py @@ -0,0 +1,44 @@ +"""Loss function for the Cfa Model Implementation.""" + +# Copyright (C) 2022-2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + + +import torch +from torch import nn + + +class CfaLoss(nn.Module): + """Cfa Loss. + + Args: + num_nearest_neighbors (int): Number of nearest neighbors. + num_hard_negative_features (int): Number of hard negative features. + radius (float): Radius of the hypersphere to search the soft boundary. + """ + + def __init__(self, num_nearest_neighbors: int, num_hard_negative_features: int, radius: float) -> None: + super().__init__() + self.num_nearest_neighbors = num_nearest_neighbors + self.num_hard_negative_features = num_hard_negative_features + self.radius = torch.ones(1, requires_grad=True) * radius + + def forward(self, distance: torch.Tensor) -> torch.Tensor: + """Compute the CFA loss. + + Args: + distance (torch.Tensor): Distance computed using target oriented features. + + Returns: + Tensor: CFA loss. + """ + num_neighbors = self.num_nearest_neighbors + self.num_hard_negative_features + distance = distance.topk(num_neighbors, largest=False).values # noqa: PD011 + + score = distance[:, :, : self.num_nearest_neighbors] - (self.radius**2).to(distance.device) + l_att = torch.mean(torch.max(torch.zeros_like(score), score)) + + score = (self.radius**2).to(distance.device) - distance[:, :, self.num_hard_negative_features :] + l_rep = torch.mean(torch.max(torch.zeros_like(score), score - 0.1)) + + return (l_att + l_rep) * 1000 diff --git a/anomalib/models/image/cfa/torch_model.py b/anomalib/models/image/cfa/torch_model.py new file mode 100644 index 0000000000000000000000000000000000000000..25667d28c2f8966f52f6e71c174888127f79deba --- /dev/null +++ b/anomalib/models/image/cfa/torch_model.py @@ -0,0 +1,389 @@ +"""Torch Implementatation of the CFA Model. + +CFA: Coupled-hypersphere-based Feature Adaptation for Target-Oriented Anomaly Localization + +Paper https://arxiv.org/abs/2206.04325 +""" + +# Copyright (C) 2022-2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + + +import torch +import torchvision +from einops import rearrange +from sklearn.cluster import KMeans +from torch import nn +from torch.fx.graph_module import GraphModule +from torch.nn import functional as F # noqa: N812 +from torch.nn.common_types import _size_2_t +from torch.utils.data import DataLoader +from torchvision.models.feature_extraction import create_feature_extractor +from tqdm import tqdm + +from anomalib.models.components import DynamicBufferMixin +from anomalib.models.components.feature_extractors import dryrun_find_featuremap_dims + +from .anomaly_map import AnomalyMapGenerator + +SUPPORTED_BACKBONES = ("vgg19_bn", "resnet18", "wide_resnet50_2", "efficientnet_b5") + + +def get_return_nodes(backbone: str) -> list[str]: + """Get the return nodes for a given backbone. + + Args: + backbone (str): The name of the backbone. Must be one of + {"resnet18", "wide_resnet50_2", "vgg19_bn", "efficientnet_b5"}. + + Raises: + NotImplementedError: If the backbone is "efficientnet_b5". + ValueError: If the backbone is not one of the supported backbones. + + Returns: + list[str]: A list of return nodes for the given backbone. + """ + if backbone == "efficientnet_b5": + msg = "EfficientNet feature extractor has not implemented yet." + raise NotImplementedError(msg) + + return_nodes: list[str] + if backbone in ("resnet18", "wide_resnet50_2"): + return_nodes = ["layer1", "layer2", "layer3"] + elif backbone == "vgg19_bn": + return_nodes = ["features.25", "features.38", "features.52"] + else: + msg = f"Backbone {backbone} is not supported. Supported backbones are {SUPPORTED_BACKBONES}." + raise ValueError(msg) + return return_nodes + + +# TODO(samet-akcay): Replace this with the new torchfx feature extractor. +# CVS-122673 +def get_feature_extractor(backbone: str, return_nodes: list[str]) -> GraphModule: + """Get the feature extractor from the backbone CNN. + + Args: + backbone (str): Backbone CNN network + return_nodes (list[str]): A list of return nodes for the given backbone. + + Raises: + NotImplementedError: When the backbone is efficientnet_b5 + ValueError: When the backbone is not supported + + Returns: + GraphModule: Feature extractor. + """ + model = getattr(torchvision.models, backbone)(pretrained=True) + feature_extractor = create_feature_extractor(model=model, return_nodes=return_nodes) + feature_extractor.eval() + + return feature_extractor + + +class CfaModel(DynamicBufferMixin): + """Torch implementation of the CFA Model. + + Args: + backbone (str): Backbone CNN network. + gamma_c (int): gamma_c parameter from the paper. + gamma_d (int): gamma_d parameter from the paper. + num_nearest_neighbors (int): Number of nearest neighbors. + num_hard_negative_features (int): Number of hard negative features. + radius (float): Radius of the hypersphere to search the soft boundary. + """ + + def __init__( + self, + backbone: str, + gamma_c: int, + gamma_d: int, + num_nearest_neighbors: int, + num_hard_negative_features: int, + radius: float, + ) -> None: + super().__init__() + self.gamma_c = gamma_c + self.gamma_d = gamma_d + + self.num_nearest_neighbors = num_nearest_neighbors + self.num_hard_negative_features = num_hard_negative_features + + self.register_buffer("memory_bank", torch.tensor(0.0)) + self.memory_bank: torch.Tensor + + self.backbone = backbone + return_nodes = get_return_nodes(backbone) + self.feature_extractor = get_feature_extractor(backbone, return_nodes) + + self.descriptor = Descriptor(self.gamma_d, backbone) + self.radius = torch.ones(1, requires_grad=True) * radius + + self.anomaly_map_generator = AnomalyMapGenerator( + num_nearest_neighbors=num_nearest_neighbors, + ) + + def get_scale(self, input_size: tuple[int, int] | torch.Size) -> torch.Size: + """Get the scale of the feature map. + + Args: + input_size (tuple[int, int]): Input size of the image tensor. + """ + feature_map_metadata = dryrun_find_featuremap_dims( + feature_extractor=self.feature_extractor, + input_size=input_size, + layers=get_return_nodes(self.backbone), + ) + # Scale is to get the largest feature map dimensions of different layers + # of the feature extractor. In a typical feature extractor, the first + # layer has the highest resolution. + resolution = next(iter(feature_map_metadata.values()))["resolution"] + if isinstance(resolution, int): + scale = (resolution,) * 2 + elif isinstance(resolution, tuple): + scale = resolution + else: + msg = f"Unknown type {type(resolution)} for `resolution`. Expected types are either int or tuple[int, int]." + raise TypeError(msg) + return scale + + def initialize_centroid(self, data_loader: DataLoader) -> None: + """Initialize the Centroid of the Memory Bank. + + Args: + data_loader (DataLoader): Train Dataloader. + + Returns: + Tensor: Memory Bank. + """ + device = next(self.feature_extractor.parameters()).device + with torch.no_grad(): + for i, data in enumerate(tqdm(data_loader)): + batch = data["image"].to(device) + features = self.feature_extractor(batch) + features = list(features.values()) + target_features = self.descriptor(features) + self.memory_bank = ((self.memory_bank * i) + target_features.mean(dim=0, keepdim=True)) / (i + 1) + + self.memory_bank = rearrange(self.memory_bank, "b c h w -> (b h w) c") + + scale = self.get_scale(batch.shape[-2:]) + + if self.gamma_c > 1: + # TODO(samet-akcay): Create PyTorch KMeans class. + # CVS-122673 + k_means = KMeans(n_clusters=(scale[0] * scale[1]) // self.gamma_c, max_iter=3000) + cluster_centers = k_means.fit(self.memory_bank.cpu()).cluster_centers_ + self.memory_bank = torch.tensor(cluster_centers, requires_grad=False).to(device) + + self.memory_bank = rearrange(self.memory_bank, "h w -> w h") + + def compute_distance(self, target_oriented_features: torch.Tensor) -> torch.Tensor: + """Compute distance using target oriented features. + + Args: + target_oriented_features (torch.Tensor): Target oriented features computed + using the descriptor. + + Returns: + Tensor: Distance tensor. + """ + if target_oriented_features.ndim == 4: + target_oriented_features = rearrange(target_oriented_features, "b c h w -> b (h w) c") + + features = target_oriented_features.pow(2).sum(dim=2, keepdim=True) + centers = self.memory_bank.pow(2).sum(dim=0, keepdim=True).to(features.device) + f_c = 2 * torch.matmul(target_oriented_features, (self.memory_bank.to(features.device))) + return features + centers - f_c + + def forward(self, input_tensor: torch.Tensor) -> torch.Tensor: + """Forward pass. + + Args: + input_tensor (torch.Tensor): Input tensor. + + Raises: + ValueError: When the memory bank is not initialized. + + Returns: + Tensor: Loss or anomaly map depending on the train/eval mode. + """ + if self.memory_bank.ndim == 0: + msg = "Memory bank is not initialized. Run `initialize_centroid` method first." + raise ValueError(msg) + + self.feature_extractor.eval() + with torch.no_grad(): + features = self.feature_extractor(input_tensor) + features = list(features.values()) + + target_features = self.descriptor(features) + distance = self.compute_distance(target_features) + + return ( + distance + if self.training + else self.anomaly_map_generator( + distance=distance, + scale=target_features.shape[-2:], + image_size=input_tensor.shape[-2:], + ) + ) + + +class Descriptor(nn.Module): + """Descriptor module.""" + + def __init__(self, gamma_d: int, backbone: str) -> None: + super().__init__() + + self.backbone = backbone + if self.backbone not in SUPPORTED_BACKBONES: + msg = f"Supported backbones are {SUPPORTED_BACKBONES}. Got {self.backbone} instead." + raise ValueError(msg) + + # TODO(samet-akcay): Automatically infer the number of dims + # CVS-122673 + backbone_dims = {"vgg19_bn": 1280, "resnet18": 448, "wide_resnet50_2": 1792, "efficientnet_b5": 568} + dim = backbone_dims[backbone] + out_channels = 2 * dim // gamma_d if backbone == "efficientnet_b5" else dim // gamma_d + + self.layer = CoordConv2d(in_channels=dim, out_channels=out_channels, kernel_size=1) + + def forward(self, features: list[torch.Tensor] | dict[str, torch.Tensor]) -> torch.Tensor: + """Forward pass.""" + if isinstance(features, dict): + features = list(features.values()) + + patch_features: torch.Tensor | None = None + for feature in features: + pooled_features = ( + F.avg_pool2d(feature, 3, 1, 1) / feature.size(1) + if self.backbone == "efficientnet_b5" + else F.avg_pool2d(feature, 3, 1, 1) + ) + patch_features = ( + pooled_features + if patch_features is None + else torch.cat((patch_features, F.interpolate(feature, patch_features.size(2), mode="bilinear")), dim=1) + ) + + return self.layer(patch_features) + + +class CoordConv2d(nn.Conv2d): + """CoordConv layer as in the paper. + + MIT License + Copyright (c) 2018 Walsvid + + Link to the paper: https://arxiv.org/abs/1807.03247 + Link to the PyTorch implementation: https://github.com/walsvid/CoordConv + """ + + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: _size_2_t, + stride: _size_2_t = 1, + padding: str | _size_2_t = 0, + dilation: _size_2_t = 1, + groups: int = 1, + bias: bool = True, + with_r: bool = False, + ) -> None: + super().__init__( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + groups=groups, + bias=bias, + ) + # AddCoord layer. + self.add_coords = AddCoords(with_r) + + # Create conv layer on top of add_coords layer. + self.conv2d = nn.Conv2d( + in_channels=in_channels + 2 + int(with_r), # 2 for rank-2 tensor, 1 for r if with_r + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + groups=groups, + bias=bias, + ) + + def forward(self, input_tensor: torch.Tensor) -> torch.Tensor: # pylint: disable=arguments-renamed + """Forward pass. + + Args: + input_tensor (torch.Tensor): Input tensor. + + Returns: + Tensor: Output tensor after applying the CoordConv layer. + """ + out = self.add_coords(input_tensor) + return self.conv2d(out) + + +class AddCoords(nn.Module): + """Add coords to a tensor. + + MIT License + Copyright (c) 2018 Walsvid + + Link to the paper: https://arxiv.org/abs/1807.03247 + Link to the PyTorch implementation: https://github.com/walsvid/CoordConv + """ + + def __init__(self, with_r: bool = False) -> None: + super().__init__() + self.with_r = with_r + + def forward(self, input_tensor: torch.Tensor) -> torch.Tensor: + """Forward pass. + + Args: + input_tensor (torch.Tensor): Input tensor + + Returns: + Tensor: Output tensor with added coordinates. + """ + # NOTE: This is a modified version of the original implementation, + # which only supports rank 2 tensors. + batch, _, x_dim, y_dim = input_tensor.shape + xx_ones = torch.ones([1, 1, 1, y_dim], dtype=torch.int32) + yy_ones = torch.ones([1, 1, 1, x_dim], dtype=torch.int32) + + xx_range = torch.arange(x_dim, dtype=torch.int32) + yy_range = torch.arange(y_dim, dtype=torch.int32) + xx_range = xx_range[None, None, :, None] + yy_range = yy_range[None, None, :, None] + + xx_channel = torch.matmul(xx_range, xx_ones) + yy_channel = torch.matmul(yy_range, yy_ones) + + # Transpose y + yy_channel = yy_channel.permute(0, 1, 3, 2) + + xx_channel = xx_channel.float() / (x_dim - 1) + yy_channel = yy_channel.float() / (y_dim - 1) + + xx_channel = xx_channel * 2 - 1 + yy_channel = yy_channel * 2 - 1 + + xx_channel = xx_channel.repeat(batch, 1, 1, 1).to(input_tensor.device) + yy_channel = yy_channel.repeat(batch, 1, 1, 1).to(input_tensor.device) + + out = torch.cat([input_tensor, xx_channel, yy_channel], dim=1) + + if self.with_r: + rr_channel = torch.sqrt(torch.pow(xx_channel - 0.5, 2) + torch.pow(yy_channel - 0.5, 2)) + out = torch.cat([out, rr_channel], dim=1) + + return out diff --git a/anomalib/models/image/cflow/README.md b/anomalib/models/image/cflow/README.md new file mode 100644 index 0000000000000000000000000000000000000000..b9345e3e9972767cd9c8f34b755b3ccee5d8030b --- /dev/null +++ b/anomalib/models/image/cflow/README.md @@ -0,0 +1,49 @@ +# CFLOW-AD: Real-Time Unsupervised Anomaly Detection with Localization via Conditional Normalizing Flows + +This is the implementation of the [CFLOW-AD](https://arxiv.org/pdf/2107.12571v1.pdf) paper. This code is modified form of the [official repository](https://github.com/gudovskiy/cflow-ad). + +Model Type: Segmentation + +## Description + +CFLOW model is based on a conditional normalizing flow framework adopted for anomaly detection with localization. It consists of a discriminatively pretrained encoder followed by a multi-scale generative decoders. The encoder extracts features with multi-scale pyramid pooling to capture both global and local semantic information with the growing from top to bottom receptive fields. Pooled features are processed by a set of decoders to explicitly estimate likelihood of the encoded features. The estimated multi-scale likelyhoods are upsampled to input size and added up to produce the anomaly map. + +## Architecture + +![CFlow Architecture](/docs/source/images/cflow/architecture.jpg "CFlow Architecture") + +## Usage + +`python tools/train.py --model cflow` + +## Benchmark + +All results gathered with seed `42`. + +## [MVTec AD Dataset](https://www.mvtec.com/company/research/datasets/mvtec-ad) + +### Image-Level AUC + +| | Avg | Carpet | Grid | Leather | Tile | Wood | Bottle | Cable | Capsule | Hazelnut | Metal Nut | Pill | Screw | Toothbrush | Transistor | Zipper | +| -------------- | :---: | :----: | :---: | :-----: | :---: | :---: | :----: | :---: | :-----: | :------: | :-------: | :---: | :---: | :--------: | :--------: | :----: | +| Wide ResNet-50 | 0.962 | 0.986 | 0.962 | 1.0 | 0.999 | 0.993 | 1.0 | 0.893 | 0.945 | 1.0 | 0.995 | 0.924 | 0.908 | 0.897 | 0.943 | 0.984 | + +### Pixel-Level AUC + +| | Avg | Carpet | Grid | Leather | Tile | Wood | Bottle | Cable | Capsule | Hazelnut | Metal Nut | Pill | Screw | Toothbrush | Transistor | Zipper | +| -------------- | :---: | :----: | :---: | :-----: | :---: | :---: | :----: | :---: | :-----: | :------: | :-------: | :---: | :---: | :--------: | :--------: | :----: | +| Wide ResNet-50 | 0.971 | 0.986 | 0.968 | 0.993 | 0.968 | 0.924 | 0.981 | 0.955 | 0.988 | 0.990 | 0.982 | 0.983 | 0.979 | 0.985 | 0.897 | 0.980 | + +### Image F1 Score + +| | Avg | Carpet | Grid | Leather | Tile | Wood | Bottle | Cable | Capsule | Hazelnut | Metal Nut | Pill | Screw | Toothbrush | Transistor | Zipper | +| -------------- | :---: | :----: | :---: | :-----: | :---: | :---: | :----: | :---: | :-----: | :------: | :-------: | :---: | :---: | :--------: | :--------: | :----: | +| Wide ResNet-50 | 0.944 | 0.972 | 0.932 | 1.000 | 0.988 | 0.967 | 1.000 | 0.832 | 0.939 | 1.000 | 0.979 | 0.924 | 0.971 | 0.870 | 0.818 | 0.967 | + +### Sample Results + +![Sample Result 1](/docs/source/images/cflow/results/0.png "Sample Result 1") + +![Sample Result 2](/docs/source/images/cflow/results/1.png "Sample Result 2") + +![Sample Result 3](/docs/source/images/cflow/results/2.png "Sample Result 3") diff --git a/anomalib/models/image/cflow/__init__.py b/anomalib/models/image/cflow/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..d6d4bfde71b0c6a21dbb33d72f84f42241734bb5 --- /dev/null +++ b/anomalib/models/image/cflow/__init__.py @@ -0,0 +1,8 @@ +"""Real-Time Unsupervised Anomaly Detection via Conditional Normalizing Flows.""" + +# Copyright (C) 2022-2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +from .lightning_model import Cflow + +__all__ = ["Cflow"] diff --git a/anomalib/models/image/cflow/anomaly_map.py b/anomalib/models/image/cflow/anomaly_map.py new file mode 100644 index 0000000000000000000000000000000000000000..7c710d0e8299236c91b93f0a9f48d299eb95899c --- /dev/null +++ b/anomalib/models/image/cflow/anomaly_map.py @@ -0,0 +1,93 @@ +"""Anomaly Map Generator for CFlow model implementation.""" + +# Copyright (C) 2022-2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + + +from collections.abc import Sequence +from typing import cast + +import torch +from torch import nn +from torch.nn import functional as F # noqa: N812 + + +class AnomalyMapGenerator(nn.Module): + """Generate Anomaly Heatmap.""" + + def __init__( + self, + pool_layers: Sequence[str], + ) -> None: + super().__init__() + self.distance = torch.nn.PairwiseDistance(p=2, keepdim=True) + self.pool_layers: Sequence[str] = pool_layers + + def compute_anomaly_map( + self, + distribution: list[torch.Tensor], + height: list[int], + width: list[int], + image_size: tuple[int, int] | torch.Size | None, + ) -> torch.Tensor: + """Compute the layer map based on likelihood estimation. + + Args: + distribution (list[torch.Tensor]): List of likelihoods for each layer. + height (list[int]): List of heights of the feature maps. + width (list[int]): List of widths of the feature maps. + image_size (tuple[int, int] | torch.Size | None): Size of the input image. + + Returns: + Final Anomaly Map + + """ + layer_maps: list[torch.Tensor] = [] + for layer_idx in range(len(self.pool_layers)): + layer_distribution = distribution[layer_idx].clone().detach() + # Normalize the likelihoods to (-Inf:0] and convert to probs in range [0:1] + layer_probabilities = torch.exp(layer_distribution - layer_distribution.max()) + layer_map = layer_probabilities.reshape(-1, height[layer_idx], width[layer_idx]) + # upsample + if image_size is not None: + layer_map = F.interpolate( + layer_map.unsqueeze(1), + size=image_size, + mode="bilinear", + align_corners=True, + ).squeeze(1) + layer_maps.append(layer_map) + # score aggregation + score_map = torch.zeros_like(layer_maps[0]) + for layer_idx in range(len(self.pool_layers)): + score_map += layer_maps[layer_idx] + + # Invert probs to anomaly scores + return score_map.max() - score_map + + def forward(self, **kwargs: list[torch.Tensor] | list[int] | list[list]) -> torch.Tensor: + """Return anomaly_map. + + Expects `distribution`, `height` and 'width' keywords to be passed explicitly + + Example: + >>> anomaly_map_generator = AnomalyMapGenerator(image_size=tuple(hparams.model.input_size), + >>> pool_layers=pool_layers) + >>> output = self.anomaly_map_generator(distribution=dist, height=height, width=width) + + Raises: + ValueError: `distribution`, `height` and 'width' keys are not found + + Returns: + torch.Tensor: anomaly map + """ + if not ("distribution" in kwargs and "height" in kwargs and "width" in kwargs): + msg = f"Expected keys `distribution`, `height` and `width`. Found {kwargs.keys()}" + raise KeyError(msg) + + # placate mypy + distribution: list[torch.Tensor] = cast(list[torch.Tensor], kwargs["distribution"]) + height: list[int] = cast(list[int], kwargs["height"]) + width: list[int] = cast(list[int], kwargs["width"]) + image_size: tuple[int, int] | torch.Size | None = kwargs.get("image_size", None) + return self.compute_anomaly_map(distribution, height, width, image_size) diff --git a/anomalib/models/image/cflow/lightning_model.py b/anomalib/models/image/cflow/lightning_model.py new file mode 100644 index 0000000000000000000000000000000000000000..9275cc75d401033d2c6ad312b9df31104af33c95 --- /dev/null +++ b/anomalib/models/image/cflow/lightning_model.py @@ -0,0 +1,213 @@ +"""Cflow. + +Real-Time Unsupervised Anomaly Detection via Conditional Normalizing Flows. + +For more details, see the paper: `Real-Time Unsupervised Anomaly Detection via +Conditional Normalizing Flows `_. +""" + +# Copyright (C) 2022-2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +__all__ = ["Cflow"] + +from collections.abc import Sequence +from typing import Any + +import einops +import torch +from lightning.pytorch.utilities.types import STEP_OUTPUT +from torch import optim +from torch.nn import functional as F # noqa: N812 +from torch.optim import Optimizer + +from anomalib import LearningType +from anomalib.models.components import AnomalyModule + +from .torch_model import CflowModel +from .utils import get_logp, positional_encoding_2d + + +class Cflow(AnomalyModule): + """PL Lightning Module for the CFLOW algorithm. + + Args: + backbone (str, optional): Backbone CNN architecture. + Defaults to ``"wide_resnet50_2"``. + layers (Sequence[str], optional): Layers to extract features from. + Defaults to ``("layer2", "layer3", "layer4")``. + pre_trained (bool, optional): Whether to use pre-trained weights. + Defaults to ``True``. + fiber_batch_size (int, optional): Fiber batch size. + Defaults to ``64``. + decoder (str, optional): Decoder architecture. + Defaults to ``"freia-cflow"``. + condition_vector (int, optional): Condition vector size. + Defaults to ``128``. + coupling_blocks (int, optional): Number of coupling blocks. + Defaults to ``8``. + clamp_alpha (float, optional): Clamping value for the alpha parameter. + Defaults to ``1.9``. + permute_soft (bool, optional): Whether to use soft permutation. + Defaults to ``False``. + lr (float, optional): Learning rate. + Defaults to ``0.0001``. + """ + + def __init__( + self, + backbone: str = "wide_resnet50_2", + layers: Sequence[str] = ("layer2", "layer3", "layer4"), + pre_trained: bool = True, + fiber_batch_size: int = 64, + decoder: str = "freia-cflow", + condition_vector: int = 128, + coupling_blocks: int = 8, + clamp_alpha: float = 1.9, + permute_soft: bool = False, + lr: float = 0.0001, + ) -> None: + super().__init__() + + self.model: CflowModel = CflowModel( + backbone=backbone, + pre_trained=pre_trained, + layers=layers, + fiber_batch_size=fiber_batch_size, + decoder=decoder, + condition_vector=condition_vector, + coupling_blocks=coupling_blocks, + clamp_alpha=clamp_alpha, + permute_soft=permute_soft, + ) + self.automatic_optimization = False + # TODO(ashwinvaidya17): LR should be part of optimizer in config.yaml since cflow has custom optimizer. + # CVS-122670 + self.learning_rate = lr + + def configure_optimizers(self) -> Optimizer: + """Configure optimizers for each decoder. + + Returns: + Optimizer: Adam optimizer for each decoder + """ + decoders_parameters = [] + for decoder_idx in range(len(self.model.pool_layers)): + decoders_parameters.extend(list(self.model.decoders[decoder_idx].parameters())) + + return optim.Adam( + params=decoders_parameters, + lr=self.learning_rate, + ) + + def training_step(self, batch: dict[str, str | torch.Tensor], *args, **kwargs) -> STEP_OUTPUT: + """Perform the training step of CFLOW. + + For each batch, decoder layers are trained with a dynamic fiber batch size. + Training step is performed manually as multiple training steps are involved + per batch of input images + + Args: + batch (dict[str, str | torch.Tensor]): Input batch + *args: Arguments. + **kwargs: Keyword arguments. + + Returns: + Loss value for the batch + + """ + del args, kwargs # These variables are not used. + + opt = self.optimizers() + + images: torch.Tensor = batch["image"] + activation = self.model.encoder(images) + avg_loss = torch.zeros([1], dtype=torch.float64).to(images.device) + + height = [] + width = [] + for layer_idx, layer in enumerate(self.model.pool_layers): + encoder_activations = activation[layer].detach() # BxCxHxW + + batch_size, dim_feature_vector, im_height, im_width = encoder_activations.size() + image_size = im_height * im_width + embedding_length = batch_size * image_size # number of rows in the conditional vector + + height.append(im_height) + width.append(im_width) + # repeats positional encoding for the entire batch 1 C H W to B C H W + pos_encoding = einops.repeat( + positional_encoding_2d(self.model.condition_vector, im_height, im_width).unsqueeze(0), + "b c h w-> (tile b) c h w", + tile=batch_size, + ).to(images.device) + c_r = einops.rearrange(pos_encoding, "b c h w -> (b h w) c") # BHWxP + e_r = einops.rearrange(encoder_activations, "b c h w -> (b h w) c") # BHWxC + perm = torch.randperm(embedding_length) # BHW + decoder = self.model.decoders[layer_idx].to(images.device) + + fiber_batches = embedding_length // self.model.fiber_batch_size # number of fiber batches + if fiber_batches <= 0: + msg = "Make sure we have enough fibers, otherwise decrease N or batch-size!" + raise ValueError(msg) + + for batch_num in range(fiber_batches): # per-fiber processing + opt.zero_grad() + if batch_num < (fiber_batches - 1): + idx = torch.arange( + batch_num * self.model.fiber_batch_size, + (batch_num + 1) * self.model.fiber_batch_size, + ) + else: # When non-full batch is encountered batch_num * N will go out of bounds + idx = torch.arange(batch_num * self.model.fiber_batch_size, embedding_length) + # get random vectors + c_p = c_r[perm[idx]] # NxP + e_p = e_r[perm[idx]] # NxC + # decoder returns the transformed variable z and the log Jacobian determinant + p_u, log_jac_det = decoder(e_p, [c_p]) + # + decoder_log_prob = get_logp(dim_feature_vector, p_u, log_jac_det) + log_prob = decoder_log_prob / dim_feature_vector # likelihood per dim + loss = -F.logsigmoid(log_prob) + self.manual_backward(loss.mean()) + opt.step() + avg_loss += loss.sum() + + self.log("train_loss", avg_loss.item(), on_epoch=True, prog_bar=True, logger=True) + return {"loss": avg_loss} + + def validation_step(self, batch: dict[str, str | torch.Tensor], *args, **kwargs) -> STEP_OUTPUT: + """Perform the validation step of CFLOW. + + Similar to the training step, encoder features + are extracted from the CNN for each batch, and anomaly + map is computed. + + Args: + batch (dict[str, str | torch.Tensor]): Input batch + *args: Arguments. + **kwargs: Keyword arguments. + + Returns: + Dictionary containing images, anomaly maps, true labels and masks. + These are required in `validation_epoch_end` for feature concatenation. + + """ + del args, kwargs # These variables are not used. + + batch["anomaly_maps"] = self.model(batch["image"]) + return batch + + @property + def trainer_arguments(self) -> dict[str, Any]: + """C-FLOW specific trainer arguments.""" + return {"gradient_clip_val": 0, "num_sanity_val_steps": 0} + + @property + def learning_type(self) -> LearningType: + """Return the learning type of the model. + + Returns: + LearningType: Learning type of the model. + """ + return LearningType.ONE_CLASS diff --git a/anomalib/models/image/cflow/torch_model.py b/anomalib/models/image/cflow/torch_model.py new file mode 100644 index 0000000000000000000000000000000000000000..dfd7c564b39d9b82807a2d1cf42e185d459b8e61 --- /dev/null +++ b/anomalib/models/image/cflow/torch_model.py @@ -0,0 +1,154 @@ +"""PyTorch model for CFlow model implementation.""" + +# Copyright (C) 2022-2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +from collections.abc import Sequence + +import einops +import torch +from torch import nn + +from anomalib.models.components import TimmFeatureExtractor + +from .anomaly_map import AnomalyMapGenerator +from .utils import cflow_head, get_logp, positional_encoding_2d + + +class CflowModel(nn.Module): + """CFLOW: Conditional Normalizing Flows. + + Args: + backbone (str): Backbone CNN architecture. + layers (Sequence[str]): Layers to extract features from. + pre_trained (bool): Whether to use pre-trained weights. + Defaults to ``True``. + fiber_batch_size (int): Fiber batch size. + Defaults to ``64``. + decoder (str): Decoder architecture. + Defaults to ``"freia-cflow"``. + condition_vector (int): Condition vector size. + Defaults to ``128``. + coupling_blocks (int): Number of coupling blocks. + Defaults to ``8``. + clamp_alpha (float): Clamping value for the alpha parameter. + Defaults to ``1.9``. + permute_soft (bool): Whether to use soft permutation. + Defaults to ``False``. + """ + + def __init__( + self, + backbone: str, + layers: Sequence[str], + pre_trained: bool = True, + fiber_batch_size: int = 64, + decoder: str = "freia-cflow", + condition_vector: int = 128, + coupling_blocks: int = 8, + clamp_alpha: float = 1.9, + permute_soft: bool = False, + ) -> None: + super().__init__() + + self.backbone = backbone + self.fiber_batch_size = fiber_batch_size + self.condition_vector: int = condition_vector + self.dec_arch = decoder + self.pool_layers = layers + + self.encoder = TimmFeatureExtractor( + backbone=self.backbone, + layers=self.pool_layers, + pre_trained=pre_trained, + ).eval() + self.pool_dims = self.encoder.out_dims + self.decoders = nn.ModuleList( + [ + cflow_head( + condition_vector=self.condition_vector, + coupling_blocks=coupling_blocks, + clamp_alpha=clamp_alpha, + n_features=pool_dim, + permute_soft=permute_soft, + ) + for pool_dim in self.pool_dims + ], + ) + + # encoder model is fixed + for parameters in self.encoder.parameters(): + parameters.requires_grad = False + + self.anomaly_map_generator = AnomalyMapGenerator(pool_layers=self.pool_layers) + + def forward(self, images: torch.Tensor) -> torch.Tensor: + """Forward-pass images into the network to extract encoder features and compute probability. + + Args: + images: Batch of images. + + Returns: + Predicted anomaly maps. + + """ + self.encoder.eval() + self.decoders.eval() + with torch.no_grad(): + activation = self.encoder(images) + + distribution = [torch.Tensor(0).to(images.device) for _ in self.pool_layers] + + height: list[int] = [] + width: list[int] = [] + for layer_idx, layer in enumerate(self.pool_layers): + encoder_activations = activation[layer] # BxCxHxW + + batch_size, dim_feature_vector, im_height, im_width = encoder_activations.size() + image_size = im_height * im_width + embedding_length = batch_size * image_size # number of rows in the conditional vector + + height.append(im_height) + width.append(im_width) + # repeats positional encoding for the entire batch 1 C H W to B C H W + pos_encoding = einops.repeat( + positional_encoding_2d(self.condition_vector, im_height, im_width).unsqueeze(0), + "b c h w-> (tile b) c h w", + tile=batch_size, + ).to(images.device) + c_r = einops.rearrange(pos_encoding, "b c h w -> (b h w) c") # BHWxP + e_r = einops.rearrange(encoder_activations, "b c h w -> (b h w) c") # BHWxC + decoder = self.decoders[layer_idx].to(images.device) + + # Sometimes during validation, the last batch E / N is not a whole number. Hence we need to add 1. + # It is assumed that during training that E / N is a whole number as no errors were discovered during + # testing. In case it is observed in the future, we can use only this line and ensure that FIB is at + # least 1 or set `drop_last` in the dataloader to drop the last non-full batch. + fiber_batches = embedding_length // self.fiber_batch_size + int( + embedding_length % self.fiber_batch_size > 0, + ) + + for batch_num in range(fiber_batches): # per-fiber processing + if batch_num < (fiber_batches - 1): + idx = torch.arange(batch_num * self.fiber_batch_size, (batch_num + 1) * self.fiber_batch_size) + else: # When non-full batch is encountered batch_num+1 * N will go out of bounds + idx = torch.arange(batch_num * self.fiber_batch_size, embedding_length) + c_p = c_r[idx] # NxP + e_p = e_r[idx] # NxC + # decoder returns the transformed variable z and the log Jacobian determinant + with torch.no_grad(): + p_u, log_jac_det = decoder(e_p, [c_p]) + # + decoder_log_prob = get_logp(dim_feature_vector, p_u, log_jac_det) + log_prob = decoder_log_prob / dim_feature_vector # likelihood per dim + distribution[layer_idx] = torch.cat((distribution[layer_idx], log_prob)) + + output = self.anomaly_map_generator( + distribution=distribution, + height=height, + width=width, + image_size=images.shape[-2:], + ) + self.decoders.train() + + return output.to(images.device) diff --git a/anomalib/models/image/cflow/utils.py b/anomalib/models/image/cflow/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..9db2bfe574863a982afa2c2d0cac94108568e8ce --- /dev/null +++ b/anomalib/models/image/cflow/utils.py @@ -0,0 +1,119 @@ +"""Helper functions for CFlow implementation.""" + +# Copyright (C) 2022-2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + + +import logging +import math + +import numpy as np +import torch +from FrEIA.framework import SequenceINN +from torch import nn + +from anomalib.models.components.flow import AllInOneBlock + +logger = logging.getLogger(__name__) + + +def get_logp(dim_feature_vector: int, p_u: torch.Tensor, logdet_j: torch.Tensor) -> torch.Tensor: + """Return the log likelihood estimation. + + Args: + dim_feature_vector (int): Dimensions of the condition vector + p_u (torch.Tensor): Random variable u + logdet_j (torch.Tensor): log of determinant of jacobian returned from the invertable decoder + + Returns: + Tensor: Log probability + """ + ln_sqrt_2pi = -np.log(np.sqrt(2 * np.pi)) # ln(sqrt(2*pi)) + return dim_feature_vector * ln_sqrt_2pi - 0.5 * torch.sum(p_u**2, 1) + logdet_j + + +def positional_encoding_2d(condition_vector: int, height: int, width: int) -> torch.Tensor: + """Create embedding to store relative position of the feature vector using sine and cosine functions. + + Args: + condition_vector (int): Length of the condition vector + height (int): H of the positions + width (int): W of the positions + + Raises: + ValueError: Cannot generate encoding with conditional vector length not as multiple of 4 + + Returns: + Tensor: condition_vector x HEIGHT x WIDTH position matrix + """ + if condition_vector % 4 != 0: + msg = f"Cannot use sin/cos positional encoding with odd dimension (got dim={condition_vector})" + raise ValueError(msg) + pos_encoding = torch.zeros(condition_vector, height, width) + # Each dimension use half of condition_vector + condition_vector = condition_vector // 2 + div_term = torch.exp(torch.arange(0.0, condition_vector, 2) * -(math.log(1e4) / condition_vector)) + pos_w = torch.arange(0.0, width).unsqueeze(1) + pos_h = torch.arange(0.0, height).unsqueeze(1) + pos_encoding[0:condition_vector:2, :, :] = ( + torch.sin(pos_w * div_term).transpose(0, 1).unsqueeze(1).repeat(1, height, 1) + ) + pos_encoding[1:condition_vector:2, :, :] = ( + torch.cos(pos_w * div_term).transpose(0, 1).unsqueeze(1).repeat(1, height, 1) + ) + pos_encoding[condition_vector::2, :, :] = ( + torch.sin(pos_h * div_term).transpose(0, 1).unsqueeze(2).repeat(1, 1, width) + ) + pos_encoding[condition_vector + 1 :: 2, :, :] = ( + torch.cos(pos_h * div_term).transpose(0, 1).unsqueeze(2).repeat(1, 1, width) + ) + return pos_encoding + + +def subnet_fc(dims_in: int, dims_out: int) -> nn.Sequential: + """Subnetwork which predicts the affine coefficients. + + Args: + dims_in (int): input dimensions + dims_out (int): output dimensions + + Returns: + nn.Sequential: Feed-forward subnetwork + """ + return nn.Sequential(nn.Linear(dims_in, 2 * dims_in), nn.ReLU(), nn.Linear(2 * dims_in, dims_out)) + + +def cflow_head( + condition_vector: int, + coupling_blocks: int, + clamp_alpha: float, + n_features: int, + permute_soft: bool = False, +) -> SequenceINN: + """Create invertible decoder network. + + Args: + condition_vector (int): length of the condition vector + coupling_blocks (int): number of coupling blocks to build the decoder + clamp_alpha (float): clamping value to avoid exploding values + n_features (int): number of decoder features + permute_soft (bool): Whether to sample the permutation matrix :math:`R` from :math:`SO(N)`, + or to use hard permutations instead. Note, ``permute_soft=True`` is very slow + when working with >512 dimensions. + + Returns: + SequenceINN: decoder network block + """ + coder = SequenceINN(n_features) + logger.info("CNF coder: %d", n_features) + for _ in range(coupling_blocks): + coder.append( + AllInOneBlock, + cond=0, + cond_shape=(condition_vector,), + subnet_constructor=subnet_fc, + affine_clamping=clamp_alpha, + global_affine_type="SOFTPLUS", + permute_soft=permute_soft, + ) + return coder diff --git a/anomalib/models/image/csflow/README.md b/anomalib/models/image/csflow/README.md new file mode 100644 index 0000000000000000000000000000000000000000..ab979ddfee441b33332688821995c2ee1cc5f527 --- /dev/null +++ b/anomalib/models/image/csflow/README.md @@ -0,0 +1,96 @@ +# Fully Convolutional Cross-Scale-Flows for Image-based Defect Detection + +This is the implementation of the [CS-Flow](https://arxiv.org/pdf/2110.02855.pdf) paper. This code is modified form of the [official repository](https://github.com/marco-rudolph/cs-flow). + +Model Type: Segmentation + +## Description + +The central idea of the paper is to handle fine-grained representations by incorporating global and local image context. This is done by taking multiple scales when extracting features and using a fully-convolutional normalizing flow to process the scales jointly. This can be seen in Figure 1. + +In each cross-scale coupling block, the input tensor is split into two parts across the channel dimension. Similar to RealNVP, each part is used to compute the scale and translate parameters for the affine transform. This is done with the help of cross-scale convolution layers as shown in Figure 2. These are point wise operations. As shown in the figure, the subnetworks are $r_1$ and $r_2$ and their outputs are $[s_1, t_1]$ and $[s_2, t_2]$. Then, the output of the coupling blocks are defined as. + +$$ +y_{out,2} = y_{in,2} \odot e^{\gamma_1s_1(y_{in,1}) + \gamma_1t_1(y_{in,1})}\\ +y_{out,1} = y_{in,1} \odot e^{\gamma_2s_2(y_{out,2}) + \gamma_2t_2(y_{out,2})} +$$ + +Here, $\gamma_1$ and $\gamma_2$ are learnable parameters for each block. + +Figure 3 shows the architecture of the subnetworks in detail. + +The anomaly score for each local position $(i,j)$ of the feature map $y^s$ at scale $s$ is computed by aggregating values along the channel dimension with $||z^s_{i,j}||^2_2$. Here $z$ is the latent variable and $z^s_{i,j}$ is the output of the final coupling block at scale $s$ for the local position $(i,j)$. Thus anomalies can be localized by marking image regions with high norm in output feature tensors $z^s$. + +## Architecture + +![CS-Flow Architecture](/docs/source/images/cs_flow/architecture1.jpg "CS-Flow Architecture") + +![Architecture of a Coupling Block](/docs/source/images/cs_flow/architecture2.jpg "Architecture of a Coupling Block") + +![Architecture of network predicting scale and shift parameters.](/docs/source/images/cs_flow/architecture3.jpg "Architecture of network predicting scale and shift parameters.") + +## Usage + +`python tools/train.py --model cs_flow` + +## Benchmark + +All results gathered with seed `42`. + +## [MVTec AD Dataset](https://www.mvtec.com/company/research/datasets/mvtec-ad) + +> The following table is generated with image size of 768 and generating the anomaly map from all the three scales unlike the paper. Initial experiments showed that the anomaly map from all the three scales gives better results than the one from the largest scale. + +### Image AUROC - 768 Image Size + +| | Average | Carpet | Grid | Leather | Tile | Wood | Bottle | Cable | Capsule | Hazelnut | Metal_nut | Pill | Screw | Toothbrush | Transistor | Zipper | +| :-------------- | ------: | -----: | ----: | ------: | ----: | ----: | -----: | ----: | ------: | -------: | --------: | ---: | ----: | ---------: | ---------: | -----: | +| EfficientNet-B5 | 0.987 | 1 | 0.989 | 1 | 0.998 | 0.998 | 1 | 0.996 | 0.981 | 0.994 | 1 | 0.98 | 0.95 | 0.919 | 1 | 0.999 | + +### Pixel AUROC - 768 Image Size + +| | Average | Carpet | Grid | Leather | Tile | Wood | Bottle | Cable | Capsule | Hazelnut | Metal_nut | Pill | Screw | Toothbrush | Transistor | Zipper | +| :-------------- | ------: | -----: | ----: | ------: | ----: | ----: | -----: | ----: | ------: | -------: | --------: | ---: | ----: | ---------: | ---------: | -----: | +| EfficientNet-B5 | 0.921 | 0.936 | 0.878 | 0.917 | 0.872 | 0.782 | 0.889 | 0.935 | 0.961 | 0.957 | 0.953 | 0.95 | 0.947 | 0.951 | 0.974 | 0.919 | + +### Pixel F1Score - 768 Image Size + +| | Average | Carpet | Grid | Leather | Tile | Wood | Bottle | Cable | Capsule | Hazelnut | Metal_nut | Pill | Screw | Toothbrush | Transistor | Zipper | +| :-------------- | ------: | -----: | ----: | ------: | ---: | ----: | -----: | ----: | ------: | -------: | --------: | ----: | ----: | ---------: | ---------: | -----: | +| EfficientNet-B5 | 0.33 | 0.219 | 0.104 | 0.144 | 0.41 | 0.211 | 0.357 | 0.375 | 0.333 | 0.375 | 0.689 | 0.458 | 0.094 | 0.342 | 0.597 | 0.238 | + +### Image F1 Score - 768 Image Size + +| | Average | Carpet | Grid | Leather | Tile | Wood | Bottle | Cable | Capsule | Hazelnut | Metal_nut | Pill | Screw | Toothbrush | Transistor | Zipper | +| :-------------- | ------: | -----: | ----: | ------: | ----: | ----: | -----: | ----: | ------: | -------: | --------: | ----: | ----: | ---------: | ---------: | -----: | +| EfficientNet-B5 | 0.985 | 1 | 0.991 | 1 | 0.988 | 0.992 | 1 | 0.973 | 0.977 | 0.979 | 0.995 | 0.975 | 0.975 | 0.952 | 0.988 | 0.996 | + +> For fair comparison with other algorithms, the following results are computed with image size of 256. + +### Image AUROC - 256 Image Size + +| | Average | Carpet | Grid | Leather | Tile | Wood | Bottle | Cable | Capsule | Hazelnut | Metal_nut | Pill | Screw | Toothbrush | Transistor | Zipper | +| :-------------- | ------: | -----: | ----: | ------: | ----: | ----: | -----: | ----: | ------: | -------: | --------: | ----: | ----: | ---------: | ---------: | -----: | +| EfficientNet-B5 | 0.972 | 0.995 | 0.982 | 1 | 0.972 | 0.988 | 1 | 0.97 | 0.907 | 0.995 | 0.972 | 0.953 | 0.896 | 0.969 | 0.987 | 0.987 | + +### Pixel AUROC - 256 Image Size + +| | Average | Carpet | Grid | Leather | Tile | Wood | Bottle | Cable | Capsule | Hazelnut | Metal_nut | Pill | Screw | Toothbrush | Transistor | Zipper | +| :-------------- | ------: | -----: | ----: | ------: | ----: | ----: | -----: | ----: | ------: | -------: | --------: | ---: | ----: | ---------: | ---------: | -----: | +| EfficientNet B5 | 0.845 | 0.847 | 0.746 | 0.851 | 0.775 | 0.677 | 0.853 | 0.863 | 0.882 | 0.895 | 0.932 | 0.92 | 0.779 | 0.892 | 0.96 | 0.803 | + +### Pixel F1Score - 256 Image Size + +| | Average | Carpet | Grid | Leather | Tile | Wood | Bottle | Cable | Capsule | Hazelnut | Metal_nut | Pill | Screw | Toothbrush | Transistor | Zipper | +| :-------------- | ------: | -----: | ----: | ------: | ----: | ----: | -----: | ----: | ------: | -------: | --------: | ----: | ----: | ---------: | ---------: | -----: | +| EfficientNet B5 | 0.231 | 0.108 | 0.069 | 0.048 | 0.306 | 0.127 | 0.303 | 0.21 | 0.165 | 0.215 | 0.659 | 0.412 | 0.017 | 0.214 | 0.513 | 0.106 | + +### Image F1 Score - 256 Image Size + +| | Average | Carpet | Grid | Leather | Tile | Wood | Bottle | Cable | Capsule | Hazelnut | Metal_nut | Pill | Screw | Toothbrush | Transistor | Zipper | +| :-------------- | ------: | -----: | ----: | ------: | ----: | ----: | -----: | ----: | ------: | -------: | --------: | ----: | ----: | ---------: | ---------: | -----: | +| EfficientNet B5 | 0.965 | 0.983 | 0.982 | 1 | 0.957 | 0.966 | 1 | 0.945 | 0.944 | 0.986 | 0.963 | 0.965 | 0.906 | 0.949 | 0.938 | 0.987 | + +### Sample Results + +### TODO: Add results diff --git a/anomalib/models/image/csflow/__init__.py b/anomalib/models/image/csflow/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..f53d606823f5cb95ba5f40cd0fd73286e6570c08 --- /dev/null +++ b/anomalib/models/image/csflow/__init__.py @@ -0,0 +1,8 @@ +"""Fully Convolutional Cross-Scale-Flows for Image-based Defect Detection.""" + +# Copyright (C) 2022-2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +from .lightning_model import Csflow + +__all__ = ["Csflow"] diff --git a/anomalib/models/image/csflow/anomaly_map.py b/anomalib/models/image/csflow/anomaly_map.py new file mode 100644 index 0000000000000000000000000000000000000000..a98b766e0499042489019ff4ddd7eb4e26c93da7 --- /dev/null +++ b/anomalib/models/image/csflow/anomaly_map.py @@ -0,0 +1,68 @@ +"""Anomaly Map Generator for CS-Flow model.""" + +# Copyright (C) 2022-2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + + +from enum import Enum + +import torch +from torch import nn +from torch.nn import functional as F # noqa: N812 + + +class AnomalyMapMode(str, Enum): + """Generate anomaly map from all the scales or the max.""" + + ALL = "all" + MAX = "max" + + +class AnomalyMapGenerator(nn.Module): + """Anomaly Map Generator for CS-Flow model. + + Args: + input_dims (tuple[int, int, int]): Input dimensions. + mode (AnomalyMapMode): Anomaly map mode. + Defaults to ``AnomalyMapMode.ALL``. + """ + + def __init__(self, input_dims: tuple[int, int, int], mode: AnomalyMapMode = AnomalyMapMode.ALL) -> None: + super().__init__() + self.mode = mode + self.input_dims = input_dims + + def forward(self, inputs: torch.Tensor) -> torch.Tensor: + """Get anomaly maps by taking mean of the z-distributions across channels. + + By default it computes anomaly maps for all the scales as it gave better performance on initial tests. + Use ``AnomalyMapMode.MAX`` for the largest scale as mentioned in the paper. + + Args: + inputs (torch.Tensor): z-distributions for the three scales. + mode (AnomalyMapMode): Anomaly map mode. + + Returns: + Tensor: Anomaly maps. + """ + anomaly_map: torch.Tensor + if self.mode == AnomalyMapMode.ALL: + anomaly_map = torch.ones(inputs[0].shape[0], 1, *self.input_dims[1:]).to(inputs[0].device) + for z_dist in inputs: + mean_z = (z_dist**2).mean(dim=1, keepdim=True) + anomaly_map *= F.interpolate( + mean_z, + size=self.input_dims[1:], + mode="bilinear", + align_corners=False, + ) + else: + mean_z = (inputs[0] ** 2).mean(dim=1, keepdim=True) + anomaly_map = F.interpolate( + mean_z, + size=self.input_dims[1:], + mode="bilinear", + align_corners=False, + ) + + return anomaly_map diff --git a/anomalib/models/image/csflow/lightning_model.py b/anomalib/models/image/csflow/lightning_model.py new file mode 100644 index 0000000000000000000000000000000000000000..a759000aa2d411e15f97c176a676290ce80c1a16 --- /dev/null +++ b/anomalib/models/image/csflow/lightning_model.py @@ -0,0 +1,134 @@ +"""Fully Convolutional Cross-Scale-Flows for Image-based Defect Detection. + +https://arxiv.org/pdf/2110.02855.pdf +""" + +# Copyright (C) 2022-2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +import logging +from typing import Any + +import torch +from lightning.pytorch.utilities.types import STEP_OUTPUT + +from anomalib import LearningType +from anomalib.models.components import AnomalyModule + +from .loss import CsFlowLoss +from .torch_model import CsFlowModel + +logger = logging.getLogger(__name__) + +__all__ = ["Csflow"] + + +class Csflow(AnomalyModule): + """Fully Convolutional Cross-Scale-Flows for Image-based Defect Detection. + + Args: + n_coupling_blocks (int): Number of coupling blocks in the model. + Defaults to ``4``. + cross_conv_hidden_channels (int): Number of hidden channels in the cross convolution. + Defaults to ``1024``. + clamp (int): Clamp value for glow layer. + Defaults to ``3``. + num_channels (int): Number of channels in the model. + Defaults to ``3``. + """ + + def __init__( + self, + cross_conv_hidden_channels: int = 1024, + n_coupling_blocks: int = 4, + clamp: int = 3, + num_channels: int = 3, + ) -> None: + super().__init__() + + self.cross_conv_hidden_channels = cross_conv_hidden_channels + self.n_coupling_blocks = n_coupling_blocks + self.clamp = clamp + self.num_channels = num_channels + + self.loss = CsFlowLoss() + + self.model: CsFlowModel + + def _setup(self) -> None: + if self.input_size is None: + msg = "CsFlow needs input size to build torch model." + raise ValueError(msg) + + self.model = CsFlowModel( + input_size=self.input_size, + cross_conv_hidden_channels=self.cross_conv_hidden_channels, + n_coupling_blocks=self.n_coupling_blocks, + clamp=self.clamp, + num_channels=self.num_channels, + ) + self.model.feature_extractor.eval() + + def training_step(self, batch: dict[str, str | torch.Tensor], *args, **kwargs) -> STEP_OUTPUT: + """Perform the training step of CS-Flow. + + Args: + batch (dict[str, str | torch.Tensor]): Input batch + args: Arguments. + kwargs: Keyword arguments. + + Returns: + Loss value + """ + del args, kwargs # These variables are not used. + + z_dist, jacobians = self.model(batch["image"]) + loss = self.loss(z_dist, jacobians) + self.log("train_loss", loss.item(), on_epoch=True, prog_bar=True, logger=True) + return {"loss": loss} + + def validation_step(self, batch: dict[str, str | torch.Tensor], *args, **kwargs) -> STEP_OUTPUT: + """Perform the validation step for CS Flow. + + Args: + batch (torch.Tensor): Input batch + args: Arguments. + kwargs: Keyword arguments. + + Returns: + dict[str, torch.Tensor]: Dictionary containing the anomaly map, scores, etc. + """ + del args, kwargs # These variables are not used. + + anomaly_maps, anomaly_scores = self.model(batch["image"]) + batch["anomaly_maps"] = anomaly_maps + batch["pred_scores"] = anomaly_scores + return batch + + @property + def trainer_arguments(self) -> dict[str, Any]: + """CS-Flow-specific trainer arguments.""" + return {"gradient_clip_val": 1, "num_sanity_val_steps": 0} + + def configure_optimizers(self) -> torch.optim.Optimizer: + """Configure optimizers. + + Returns: + Optimizer: Adam optimizer + """ + return torch.optim.Adam( + self.parameters(), + lr=2e-4, + eps=1e-04, + weight_decay=1e-5, + betas=(0.5, 0.9), + ) + + @property + def learning_type(self) -> LearningType: + """Return the learning type of the model. + + Returns: + LearningType: Learning type of the model. + """ + return LearningType.ONE_CLASS diff --git a/anomalib/models/image/csflow/loss.py b/anomalib/models/image/csflow/loss.py new file mode 100644 index 0000000000000000000000000000000000000000..4e7809be7a77ada2600c16a17ceff50ea557057d --- /dev/null +++ b/anomalib/models/image/csflow/loss.py @@ -0,0 +1,24 @@ +"""Loss function for the CS-Flow Model Implementation.""" + +# Copyright (C) 2022-2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +import torch +from torch import nn + + +class CsFlowLoss(nn.Module): + """Loss function for the CS-Flow Model Implementation.""" + + def forward(self, z_dist: torch.Tensor, jacobians: torch.Tensor) -> torch.Tensor: + """Compute the loss CS-Flow. + + Args: + z_dist (torch.Tensor): Latent space image mappings from NF. + jacobians (torch.Tensor): Jacobians of the distribution + + Returns: + Loss value + """ + z_dist = torch.cat([z_dist[i].reshape(z_dist[i].shape[0], -1) for i in range(len(z_dist))], dim=1) + return torch.mean(0.5 * torch.sum(z_dist**2, dim=(1,)) - jacobians) / z_dist.shape[1] diff --git a/anomalib/models/image/csflow/torch_model.py b/anomalib/models/image/csflow/torch_model.py new file mode 100644 index 0000000000000000000000000000000000000000..13fa98f5ac562408f6a802b71725c40eb3c56106 --- /dev/null +++ b/anomalib/models/image/csflow/torch_model.py @@ -0,0 +1,606 @@ +"""PyTorch model for CS-Flow implementation.""" + + +# Original Code +# Copyright (c) 2021 marco-rudolph +# https://github.com/marco-rudolph/cs-flow +# SPDX-License-Identifier: MIT +# +# Modified +# Copyright (C) 2022-2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +from math import exp + +import numpy as np +import torch +from FrEIA.framework import GraphINN, InputNode, Node, OutputNode +from FrEIA.modules import InvertibleModule +from torch import nn +from torch.nn import functional as F # noqa: N812 +from torchvision.models.efficientnet import EfficientNet_B5_Weights + +from anomalib.models.components.feature_extractors import TorchFXFeatureExtractor + +from .anomaly_map import AnomalyMapGenerator, AnomalyMapMode + + +class CrossConvolutions(nn.Module): + """Cross convolution for the three scales. + + Args: + in_channels (int): Number of input channels. + channels (int): Number of output channels in the hidden convolution and the upscaling layers. + channels_hidden (int, optional): Number of input channels in the hidden convolution layers. + Defaults to ``512``. + kernel_size (int, optional): Kernel size of the convolution layers. + Defaults to ``3``. + leaky_slope (float, optional): Slope of the leaky ReLU activation. + Defaults to ``0.1``. + batch_norm (bool, optional): Whether to use batch normalization. + Defaults to ``False``. + use_gamma (bool, optional): Whether to use gamma parameters for the cross convolutions. + Defaults to ``True``. + """ + + def __init__( + self, + in_channels: int, + channels: int, + channels_hidden: int = 512, + kernel_size: int = 3, + leaky_slope: float = 0.1, + batch_norm: bool = False, + use_gamma: bool = True, + ) -> None: + super().__init__() + + pad = kernel_size // 2 + self.leaky_slope = leaky_slope + pad_mode = "zeros" + self.use_gamma = use_gamma + self.gamma0 = nn.Parameter(torch.zeros(1)) + self.gamma1 = nn.Parameter(torch.zeros(1)) + self.gamma2 = nn.Parameter(torch.zeros(1)) + + self.conv_scale0_0 = nn.Conv2d( + in_channels, + channels_hidden, + kernel_size=kernel_size, + padding=pad, + bias=not batch_norm, + padding_mode=pad_mode, + ) + + self.conv_scale1_0 = nn.Conv2d( + in_channels, + channels_hidden, + kernel_size=kernel_size, + padding=pad, + bias=not batch_norm, + padding_mode=pad_mode, + ) + self.conv_scale2_0 = nn.Conv2d( + in_channels, + channels_hidden, + kernel_size=kernel_size, + padding=pad, + bias=not batch_norm, + padding_mode=pad_mode, + ) + self.conv_scale0_1 = nn.Conv2d( + channels_hidden * 1, + channels, # + kernel_size=kernel_size, + padding=pad, + bias=not batch_norm, + padding_mode=pad_mode, + dilation=1, + ) + self.conv_scale1_1 = nn.Conv2d( + channels_hidden * 1, + channels, # + kernel_size=kernel_size, + padding=pad * 1, + bias=not batch_norm, + padding_mode=pad_mode, + dilation=1, + ) + self.conv_scale2_1 = nn.Conv2d( + channels_hidden * 1, + channels, # + kernel_size=kernel_size, + padding=pad, + bias=not batch_norm, + padding_mode=pad_mode, + ) + + self.upsample = nn.Upsample(scale_factor=2, mode="bilinear", align_corners=False) + + self.up_conv10 = nn.Conv2d( + channels_hidden, + channels, + kernel_size=kernel_size, + padding=pad, + bias=True, + padding_mode=pad_mode, + ) + + self.up_conv21 = nn.Conv2d( + channels_hidden, + channels, + kernel_size=kernel_size, + padding=pad, + bias=True, + padding_mode=pad_mode, + ) + + self.down_conv01 = nn.Conv2d( + channels_hidden, + channels, + kernel_size=kernel_size, + padding=pad, + bias=not batch_norm, + stride=2, + padding_mode=pad_mode, + dilation=1, + ) + + self.down_conv12 = nn.Conv2d( + channels_hidden, + channels, + kernel_size=kernel_size, + padding=pad, + bias=not batch_norm, + stride=2, + padding_mode=pad_mode, + dilation=1, + ) + + self.leaky_relu = nn.LeakyReLU(self.leaky_slope) + + def forward(self, scale0: int, scale1: int, scale2: int) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Apply the cross convolution to the three scales. + + This block is represented in figure 4 of the paper. + + Returns: + tuple[torch.Tensor, torch.Tensor, torch.Tensor]: Tensors indicating scale and transform parameters + as a single tensor for each scale. The scale parameters are the first part across channel dimension + and the transform parameters are the second. + """ + # Increase the number of channels to hidden channel length via convolutions and apply leaky ReLU. + out0 = self.conv_scale0_0(scale0) + out1 = self.conv_scale1_0(scale1) + out2 = self.conv_scale2_0(scale2) + + lr0 = self.leaky_relu(out0) + lr1 = self.leaky_relu(out1) + lr3 = self.leaky_relu(out2) + + # Decrease the number of channels to scale and transform split length. + out0 = self.conv_scale0_1(lr0) + out1 = self.conv_scale1_1(lr1) + out2 = self.conv_scale2_1(lr3) + + # Upsample the smaller scales. + y1_up = self.up_conv10(self.upsample(lr1)) + y2_up = self.up_conv21(self.upsample(lr3)) + + # Downsample the larger scales. + y0_down = self.down_conv01(lr0) + y1_down = self.down_conv12(lr1) + + # Do element-wise sum on cross-scale outputs. + out0 = out0 + y1_up + out1 = out1 + y0_down + y2_up + out2 = out2 + y1_down + + if self.use_gamma: + out0 = out0 * self.gamma0 + out1 = out1 * self.gamma1 + out2 = out2 * self.gamma2 + # even channel split is performed outside this block + return out0, out1, out2 + + +class ParallelPermute(InvertibleModule): + """Permutes input vector in a random but fixed way. + + Args: + dim (list[tuple[int]]): Dimension of the input vector. + seed (float | None=None): Seed for the random permutation. + """ + + def __init__(self, dims_in: list[tuple[int]], seed: int | None = None) -> None: + super().__init__(dims_in) + self.n_inputs: int = len(dims_in) + self.in_channels = [dims_in[i][0] for i in range(self.n_inputs)] + self.seed = seed + + perm, perm_inv = self.get_random_perm(0) + self.perm = [perm] # stores the random order of channels + self.perm_inv = [perm_inv] # stores the inverse mapping to recover the original order of channels + + for i in range(1, self.n_inputs): + perm, perm_inv = self.get_random_perm(i) + self.perm.append(perm) + self.perm_inv.append(perm_inv) + + def get_random_perm(self, index: int) -> tuple[torch.Tensor, torch.Tensor]: + """Return a random permutation of the channels for each input. + + Args: + index (int): index of the input + + Returns: + tuple[torch.Tensor, torch.Tensor]: permutation and inverse permutation + """ + perm = np.random.default_rng(self.seed).permutation(self.in_channels[index]) + perm_inv = np.zeros_like(perm) + for idx, permutation in enumerate(perm): + perm_inv[permutation] = idx + + perm = torch.LongTensor(perm) + perm_inv = torch.LongTensor(perm_inv) + return perm, perm_inv + + # pylint: disable=unused-argument + def forward( + self, + input_tensor: list[torch.Tensor], + rev: bool = False, + jac: bool = True, + ) -> tuple[list[torch.Tensor], float]: + """Apply the permutation to the input. + + Args: + input_tensor: list of input tensors + rev: if True, applies the inverse permutation + Defaults to ``False``. + jac: (unused) if True, computes the log determinant of the Jacobian + Defaults to ``True``. + + Returns: + tuple[torch.Tensor, torch.Tensor]: output tensor and log determinant of the Jacobian + """ + del jac # Unused argument. + + if not rev: + return [input_tensor[i][:, self.perm[i]] for i in range(self.n_inputs)], 0.0 + + return [input_tensor[i][:, self.perm_inv[i]] for i in range(self.n_inputs)], 0.0 + + def output_dims(self, input_dims: list[tuple[int]]) -> list[tuple[int]]: + """Return the output dimensions of the module.""" + return input_dims + + +class ParallelGlowCouplingLayer(InvertibleModule): + """Coupling block that follows the GLOW design but is applied to all the scales in parallel. + + Args: + dims_in (list[tuple[int]]): list of dimensions of the input tensors + subnet_args (dict): arguments of the subnet + clamp (float): clamp value for the output of the subnet + Defaults to ``5.0``. + """ + + def __init__(self, dims_in: list[tuple[int]], subnet_args: dict, clamp: float = 5.0) -> None: + super().__init__(dims_in) + channels = dims_in[0][0] + self.ndims = len(dims_in[0]) + + self.split_len1 = channels // 2 + self.split_len2 = channels - channels // 2 + + self.clamp = clamp + + self.max_s = exp(clamp) + self.min_s = exp(-clamp) + + self.cross_convolution1 = CrossConvolutions(self.split_len1, self.split_len2 * 2, **subnet_args) + self.cross_convolution2 = CrossConvolutions(self.split_len2, self.split_len1 * 2, **subnet_args) + + def exp(self, input_tensor: torch.Tensor) -> torch.Tensor: + """Exponentiates the input and, optionally, clamps it to avoid numerical issues.""" + if self.clamp > 0: + return torch.exp(self.log_e(input_tensor)) + return torch.exp(input_tensor) + + def log_e(self, input_tensor: torch.Tensor) -> torch.Tensor: + """Return log of input. And optionally clamped to avoid numerical issues.""" + if self.clamp > 0: + return self.clamp * 0.636 * torch.atan(input_tensor / self.clamp) + return input_tensor + + def forward( + self, + input_tensor: list[torch.Tensor], + rev: bool = False, + jac: bool = True, + ) -> tuple[list[torch.Tensor], torch.Tensor]: + """Apply GLOW coupling for the three scales.""" + del jac # Unused argument. + + # Even channel split. The two splits are used by cross-scale convolution to compute scale and transform + # parameters. + x01, x02 = ( + input_tensor[0].narrow(1, 0, self.split_len1), + input_tensor[0].narrow(1, self.split_len1, self.split_len2), + ) + x11, x12 = ( + input_tensor[1].narrow(1, 0, self.split_len1), + input_tensor[1].narrow(1, self.split_len1, self.split_len2), + ) + x21, x22 = ( + input_tensor[2].narrow(1, 0, self.split_len1), + input_tensor[2].narrow(1, self.split_len1, self.split_len2), + ) + + if not rev: + # Outputs of cross convolutions at three scales + r02, r12, r22 = self.cross_convolution2(x02, x12, x22) + + # Scale and transform parameters are obtained by splitting the output of cross convolutions. + s02, t02 = r02[:, : self.split_len1], r02[:, self.split_len1 :] + s12, t12 = r12[:, : self.split_len1], r12[:, self.split_len1 :] + s22, t22 = r22[:, : self.split_len1], r22[:, self.split_len1 :] + + # apply element wise affine transformation on the first part + y01 = self.exp(s02) * x01 + t02 + y11 = self.exp(s12) * x11 + t12 + y21 = self.exp(s22) * x21 + t22 + + r01, r11, r21 = self.cross_convolution1(y01, y11, y21) + + s01, t01 = r01[:, : self.split_len2], r01[:, self.split_len2 :] + s11, t11 = r11[:, : self.split_len2], r11[:, self.split_len2 :] + s21, t21 = r21[:, : self.split_len2], r21[:, self.split_len2 :] + + # apply element wise affine transformation on the second part + y02 = self.exp(s01) * x02 + t01 + y12 = self.exp(s11) * x12 + t11 + y22 = self.exp(s21) * x22 + t21 + + else: # names of x and y are swapped! + # Inverse affine transformation at three scales. + r01, r11, r21 = self.cross_convolution1(x01, x11, x21) + + s01, t01 = r01[:, : self.split_len2], r01[:, self.split_len2 :] + s11, t11 = r11[:, : self.split_len2], r11[:, self.split_len2 :] + s21, t21 = r21[:, : self.split_len2], r21[:, self.split_len2 :] + + y02 = (x02 - t01) / self.exp(s01) + y12 = (x12 - t11) / self.exp(s11) + y22 = (x22 - t21) / self.exp(s21) + + r02, r12, r22 = self.cross_convolution2(y02, y12, y22) + + s02, t02 = r02[:, : self.split_len2], r01[:, self.split_len2 :] + s12, t12 = r12[:, : self.split_len2], r11[:, self.split_len2 :] + s22, t22 = r22[:, : self.split_len2], r21[:, self.split_len2 :] + + y01 = (x01 - t02) / self.exp(s02) + y11 = (x11 - t12) / self.exp(s12) + y21 = (x21 - t22) / self.exp(s22) + + # Concatenate the outputs of the three scales to get three transformed outputs that have the same shape as the + # inputs. + z_dist0 = torch.cat((y01, y02), 1) + z_dist1 = torch.cat((y11, y12), 1) + z_dist2 = torch.cat((y21, y22), 1) + + z_dist0 = torch.clamp(z_dist0, -1e6, 1e6) + z_dist1 = torch.clamp(z_dist1, -1e6, 1e6) + z_dist2 = torch.clamp(z_dist2, -1e6, 1e6) + + jac0 = torch.sum(self.log_e(s01), dim=(1, 2, 3)) + torch.sum(self.log_e(s02), dim=(1, 2, 3)) + jac1 = torch.sum(self.log_e(s11), dim=(1, 2, 3)) + torch.sum(self.log_e(s12), dim=(1, 2, 3)) + jac2 = torch.sum(self.log_e(s21), dim=(1, 2, 3)) + torch.sum(self.log_e(s22), dim=(1, 2, 3)) + + # Since Jacobians are only used for computing loss and summed in the loss, the idea is to sum them here + return [z_dist0, z_dist1, z_dist2], torch.stack([jac0, jac1, jac2], dim=1).sum() + + def output_dims(self, input_dims: list[tuple[int]]) -> list[tuple[int]]: + """Output dimensions of the module.""" + return input_dims + + +class CrossScaleFlow(nn.Module): + """Cross scale coupling layer. + + Args: + input_dims (tuple[int, int, int]): Input dimensions of the module. + n_coupling_blocks (int): Number of coupling blocks. + clamp (float): Clamp value for the inputs. + corss_conv_hidden_channels (int): Number of hidden channels in the cross convolution. + """ + + def __init__( + self, + input_dims: tuple[int, int, int], + n_coupling_blocks: int, + clamp: float, + cross_conv_hidden_channels: int, + ) -> None: + super().__init__() + self.input_dims = input_dims + self.n_coupling_blocks = n_coupling_blocks + self.kernel_sizes = [3] * (n_coupling_blocks - 1) + [5] + self.clamp = clamp + self.cross_conv_hidden_channels = cross_conv_hidden_channels + self.graph = self._create_graph() + + def _create_graph(self) -> GraphINN: + nodes: list[Node] = [] + # 304 is the number of features extracted from EfficientNet-B5 feature extractor + input_nodes = [ + InputNode(304, (self.input_dims[1] // 32), (self.input_dims[2] // 32), name="input"), + InputNode(304, (self.input_dims[1] // 64), (self.input_dims[2] // 64), name="input2"), + InputNode(304, (self.input_dims[1] // 128), (self.input_dims[2] // 128), name="input3"), + ] + nodes.extend(input_nodes) + + for coupling_block in range(self.n_coupling_blocks): + if coupling_block == 0: + node_to_permute = [nodes[-3].out0, nodes[-2].out0, nodes[-1].out0] + else: + node_to_permute = [nodes[-1].out0, nodes[-1].out1, nodes[-1].out2] + + permute_node = Node( + inputs=node_to_permute, + module_type=ParallelPermute, + module_args={"seed": coupling_block}, + name=f"permute_{coupling_block}", + ) + nodes.extend([permute_node]) + coupling_layer_node = Node( + inputs=[nodes[-1].out0, nodes[-1].out1, nodes[-1].out2], + module_type=ParallelGlowCouplingLayer, + module_args={ + "clamp": self.clamp, + "subnet_args": { + "channels_hidden": self.cross_conv_hidden_channels, + "kernel_size": self.kernel_sizes[coupling_block], + }, + }, + name=f"fc1_{coupling_block}", + ) + nodes.extend([coupling_layer_node]) + + output_nodes = [ + OutputNode([nodes[-1].out0], name="output_end0"), + OutputNode([nodes[-1].out1], name="output_end1"), + OutputNode([nodes[-1].out2], name="output_end2"), + ] + nodes.extend(output_nodes) + return GraphINN(nodes) + + def forward(self, inputs: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + """Forward pass. + + Args: + inputs (torch.Tensor): Input tensor. + + Returns: + tuple[torch.Tensor, torch.Tensor]: Output tensor and log determinant of Jacobian. + """ + return self.graph(inputs) + + +class MultiScaleFeatureExtractor(nn.Module): + """Multi-scale feature extractor. + + Uses 36th layer of EfficientNet-B5 to extract features. + + Args: + n_scales (int): Number of scales for input image. + input_size (tuple[int, int]): Size of input image. + """ + + def __init__(self, n_scales: int, input_size: tuple[int, int]) -> None: + super().__init__() + + self.n_scales = n_scales + self.input_size = input_size + self.feature_extractor = TorchFXFeatureExtractor( + backbone="efficientnet_b5", + weights=EfficientNet_B5_Weights.DEFAULT, + return_nodes=["features.6.8"], + ) + + def forward(self, input_tensor: torch.Tensor) -> list[torch.Tensor]: + """Extract features at three scales. + + Args: + input_tensor (torch.Tensor): Input images. + + Returns: + list[torch.Tensor]: List of tensors containing features at three scales. + """ + output = [] + for scale in range(self.n_scales): + feat_s = ( + F.interpolate( + input_tensor, + size=(self.input_size[0] // (2**scale), self.input_size[1] // (2**scale)), + ) + if scale > 0 + else input_tensor + ) + feat_s = self.feature_extractor(feat_s)["features.6.8"] + + output.append(feat_s) + return output + + +class CsFlowModel(nn.Module): + """CS Flow Module. + + Args: + input_size (tuple[int, int]): Input image size. + cross_conv_hidden_channels (int): Number of hidden channels in the cross convolution. + n_coupling_blocks (int): Number of coupling blocks. + Defaults to ``4``. + clamp (float): Clamp value for the coupling blocks. + Defaults to ``3``. + num_channels (int): Number of channels in the input image. + Defaults to ``3``. + """ + + def __init__( + self, + input_size: tuple[int, int], + cross_conv_hidden_channels: int, + n_coupling_blocks: int = 4, + clamp: int = 3, + num_channels: int = 3, + ) -> None: + super().__init__() + self.input_dims = (num_channels, *input_size) + self.clamp = clamp + self.cross_conv_hidden_channels = cross_conv_hidden_channels + self.feature_extractor = MultiScaleFeatureExtractor(n_scales=3, input_size=input_size).eval() + self.graph = CrossScaleFlow( + input_dims=self.input_dims, + n_coupling_blocks=n_coupling_blocks, + clamp=clamp, + cross_conv_hidden_channels=cross_conv_hidden_channels, + ) + self.anomaly_map_generator = AnomalyMapGenerator(input_dims=self.input_dims, mode=AnomalyMapMode.ALL) + + def forward(self, images: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + """Forward method of the model. + + Args: + images (torch.Tensor): Input images. + + Returns: + tuple[torch.Tensor, torch.Tensor]: During training: tuple containing the z_distribution for three scales + and the sum of log determinant of the Jacobian. During evaluation: tuple containing anomaly maps + and anomaly scores + """ + features = self.feature_extractor(images) + if self.training: + output = self.graph(features) + else: + z_dist, _ = self.graph(features) # Ignore Jacobians + anomaly_scores = self._compute_anomaly_scores(z_dist) + anomaly_maps = self.anomaly_map_generator(z_dist) + output = anomaly_maps, anomaly_scores + return output + + def _compute_anomaly_scores(self, z_dists: torch.Tensor) -> torch.Tensor: + """Get anomaly scores from the latent distribution. + + Args: + z_dists (torch.Tensor): Latent distribution. + + Returns: + Tensor: Anomaly scores. + """ + # z_dist is a 3 length list of tensors with shape b x 304 x fx x fy + flat_maps = [z_dist.reshape(z_dist.shape[0], -1) for z_dist in z_dists] + flat_maps_tensor = torch.cat(flat_maps, dim=1) + return torch.mean(flat_maps_tensor**2 / 2, dim=1) diff --git a/anomalib/models/image/dfkde/README.md b/anomalib/models/image/dfkde/README.md new file mode 100644 index 0000000000000000000000000000000000000000..17e0d5483ff31e2708c8ff2fd9435b20abbde1d6 --- /dev/null +++ b/anomalib/models/image/dfkde/README.md @@ -0,0 +1,39 @@ +# Deep Feature Kernel Density Estimation + +Model Type: Classification + +## Description + +Fast anomaly classification algorithm that consists of a deep feature extraction stage followed by anomaly classification stage consisting of PCA and Gaussian Kernel Density Estimation. + +### Feature Extraction + +Features are extracted by feeding the images through a ResNet50 backbone, which was pre-trained on ImageNet. The output of the penultimate layer (average pooling layer) of the network is used to obtain a semantic feature vector with a fixed length of 2048. + +### Anomaly Detection + +In the anomaly classification stage, the features are first reduced to the first 16 principal components. Gaussian Kernel Density is then used to obtain an estimate of the probability density of new examples, based on the collection of training features obtained during the training phase. + +## Usage + +`python tools/train.py --model dfkde` + +## Benchmark + +All results gathered with seed `42`. + +## [MVTec AD Dataset](https://www.mvtec.com/company/research/datasets/mvtec-ad) + +### Image-Level AUC + +| | Avg | Carpet | Grid | Leather | Tile | Wood | Bottle | Cable | Capsule | Hazelnut | Metal Nut | Pill | Screw | Toothbrush | Transistor | Zipper | +| -------------- | :---: | :----: | :---: | :-----: | :---: | :---: | :----: | :---: | :-----: | :------: | :-------: | :---: | :---: | :--------: | :--------: | :----: | +| ResNet-18 | 0.762 | 0.646 | 0.577 | 0.669 | 0.965 | 0.863 | 0.951 | 0.751 | 0.698 | 0.806 | 0.729 | 0.607 | 0.694 | 0.767 | 0.839 | 0.866 | +| Wide ResNet-50 | 0.774 | 0.708 | 0.422 | 0.905 | 0.959 | 0.903 | 0.936 | 0.746 | 0.853 | 0.736 | 0.687 | 0.749 | 0.574 | 0.697 | 0.843 | 0.892 | + +### Image F1 Score + +| | Avg | Carpet | Grid | Leather | Tile | Wood | Bottle | Cable | Capsule | Hazelnut | Metal Nut | Pill | Screw | Toothbrush | Transistor | Zipper | +| -------------- | :---: | :----: | :---: | :-----: | :---: | :---: | :----: | :---: | :-----: | :------: | :-------: | :---: | :---: | :--------: | :--------: | :----: | +| ResNet-18 | 0.872 | 0.864 | 0.844 | 0.854 | 0.960 | 0.898 | 0.942 | 0.793 | 0.908 | 0.827 | 0.894 | 0.916 | 0.859 | 0.853 | 0.756 | 0.916 | +| Wide ResNet-50 | 0.875 | 0.907 | 0.844 | 0.905 | 0.945 | 0.914 | 0.946 | 0.790 | 0.914 | 0.817 | 0.894 | 0.922 | 0.855 | 0.845 | 0.722 | 0.910 | diff --git a/anomalib/models/image/dfkde/__init__.py b/anomalib/models/image/dfkde/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..9930fcea7175a122397ee146619b26d079a192d8 --- /dev/null +++ b/anomalib/models/image/dfkde/__init__.py @@ -0,0 +1,8 @@ +"""Deep Feature Kernel Density Estimation model.""" + +# Copyright (C) 2022-2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +from .lightning_model import Dfkde + +__all__ = ["Dfkde"] diff --git a/anomalib/models/image/dfkde/lightning_model.py b/anomalib/models/image/dfkde/lightning_model.py new file mode 100644 index 0000000000000000000000000000000000000000..9662d5b7ae80a9cdae500e339b50a84b0bbcdf23 --- /dev/null +++ b/anomalib/models/image/dfkde/lightning_model.py @@ -0,0 +1,121 @@ +"""DFKDE: Deep Feature Kernel Density Estimation.""" + +# Copyright (C) 2022-2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + + +import logging +from collections.abc import Sequence +from typing import Any + +import torch +from lightning.pytorch.utilities.types import STEP_OUTPUT + +from anomalib import LearningType +from anomalib.models.components import AnomalyModule, MemoryBankMixin +from anomalib.models.components.classification import FeatureScalingMethod + +from .torch_model import DfkdeModel + +logger = logging.getLogger(__name__) + + +class Dfkde(MemoryBankMixin, AnomalyModule): + """DFKDE: Deep Feature Kernel Density Estimation. + + Args: + backbone (str): Pre-trained model backbone. + Defaults to ``"resnet18"``. + layers (Sequence[str], optional): Layers to extract features from. + Defaults to ``("layer4",)``. + pre_trained (bool, optional): Boolean to check whether to use a pre_trained backbone. + Defaults to ``True``. + n_pca_components (int, optional): Number of PCA components. + Defaults to ``16``. + feature_scaling_method (FeatureScalingMethod, optional): Feature scaling method. + Defaults to ``FeatureScalingMethod.SCALE``. + max_training_points (int, optional): Number of training points to fit the KDE model. + Defaults to ``40000``. + """ + + def __init__( + self, + backbone: str = "resnet18", + layers: Sequence[str] = ("layer4",), + pre_trained: bool = True, + n_pca_components: int = 16, + feature_scaling_method: FeatureScalingMethod = FeatureScalingMethod.SCALE, + max_training_points: int = 40000, + ) -> None: + super().__init__() + + self.model = DfkdeModel( + layers=layers, + backbone=backbone, + pre_trained=pre_trained, + n_pca_components=n_pca_components, + feature_scaling_method=feature_scaling_method, + max_training_points=max_training_points, + ) + + self.embeddings: list[torch.Tensor] = [] + + @staticmethod + def configure_optimizers() -> None: # pylint: disable=arguments-differ + """DFKDE doesn't require optimization, therefore returns no optimizers.""" + return + + def training_step(self, batch: dict[str, str | torch.Tensor], *args, **kwargs) -> None: + """Perform the training step of DFKDE. For each batch, features are extracted from the CNN. + + Args: + batch (batch: dict[str, str | torch.Tensor]): Batch containing image filename, image, label and mask + args: Arguments. + kwargs: Keyword arguments. + + Returns: + Deep CNN features. + """ + del args, kwargs # These variables are not used. + + embedding = self.model(batch["image"]) + self.embeddings.append(embedding) + + def fit(self) -> None: + """Fit a KDE Model to the embedding collected from the training set.""" + embeddings = torch.vstack(self.embeddings) + + logger.info("Fitting a KDE model to the embedding collected from the training set.") + self.model.classifier.fit(embeddings) + + def validation_step(self, batch: dict[str, str | torch.Tensor], *args, **kwargs) -> STEP_OUTPUT: + """Perform the validation step of DFKDE. + + Similar to the training step, features are extracted from the CNN for each batch. + + Args: + batch (dict[str, str | torch.Tensor]): Input batch + args: Arguments. + kwargs: Keyword arguments. + + Returns: + Dictionary containing probability, prediction and ground truth values. + """ + del args, kwargs # These variables are not used. + + batch["pred_scores"] = self.model(batch["image"]) + return batch + + @property + def trainer_arguments(self) -> dict[str, Any]: + """Return DFKDE-specific trainer arguments.""" + return {"gradient_clip_val": 0, "max_epochs": 1, "num_sanity_val_steps": 0} + + @property + def learning_type(self) -> LearningType: + """Return the learning type of the model. + + Returns: + LearningType: Learning type of the model. + """ + return LearningType.ONE_CLASS diff --git a/anomalib/models/image/dfkde/torch_model.py b/anomalib/models/image/dfkde/torch_model.py new file mode 100644 index 0000000000000000000000000000000000000000..1df08dd6a637f2b122f6844ca19f4a4c9cecc5a2 --- /dev/null +++ b/anomalib/models/image/dfkde/torch_model.py @@ -0,0 +1,87 @@ +"""Normality model of DFKDE.""" + +# Copyright (C) 2022-2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + + +import logging +from collections.abc import Sequence + +import torch +from torch import nn +from torch.nn import functional as F # noqa: N812 + +from anomalib.models.components import TimmFeatureExtractor +from anomalib.models.components.classification import FeatureScalingMethod, KDEClassifier + +logger = logging.getLogger(__name__) + + +class DfkdeModel(nn.Module): + """Normality Model for the DFKDE algorithm. + + Args: + backbone (str): Pre-trained model backbone. + layers (Sequence[str]): Layers to extract features from. + pre_trained (bool, optional): Boolean to check whether to use a pre_trained backbone. + Defaults to ``True``. + n_pca_components (int, optional): Number of PCA components. + Defaults to ``16``. + feature_scaling_method (FeatureScalingMethod, optional): Feature scaling method. + Defaults to ``FeatureScalingMethod.SCALE``. + max_training_points (int, optional): Number of training points to fit the KDE model. + Defaults to ``40000``. + """ + + def __init__( + self, + backbone: str, + layers: Sequence[str], + pre_trained: bool = True, + n_pca_components: int = 16, + feature_scaling_method: FeatureScalingMethod = FeatureScalingMethod.SCALE, + max_training_points: int = 40000, + ) -> None: + super().__init__() + + self.feature_extractor = TimmFeatureExtractor(backbone=backbone, pre_trained=pre_trained, layers=layers).eval() + + self.classifier = KDEClassifier( + n_pca_components=n_pca_components, + feature_scaling_method=feature_scaling_method, + max_training_points=max_training_points, + ) + + def get_features(self, batch: torch.Tensor) -> torch.Tensor: + """Extract features from the pretrained network. + + Args: + batch (torch.Tensor): Image batch. + + Returns: + Tensor: torch.Tensor containing extracted features. + """ + self.feature_extractor.eval() + layer_outputs = self.feature_extractor(batch) + for layer in layer_outputs: + batch_size = len(layer_outputs[layer]) + layer_outputs[layer] = F.adaptive_avg_pool2d(input=layer_outputs[layer], output_size=(1, 1)) + layer_outputs[layer] = layer_outputs[layer].view(batch_size, -1) + return torch.cat(list(layer_outputs.values())).detach() + + def forward(self, batch: torch.Tensor) -> torch.Tensor: + """Prediction by normality model. + + Args: + batch (torch.Tensor): Input images. + + Returns: + Tensor: Predictions + """ + # 1. apply feature extraction + features = self.get_features(batch) + if self.training: + return features + + # 2. apply density estimation + return self.classifier(features) diff --git a/anomalib/models/image/dfm/README.md b/anomalib/models/image/dfm/README.md new file mode 100644 index 0000000000000000000000000000000000000000..e658b0c38357530c7ffdf6cf0352781fb0a481cd --- /dev/null +++ b/anomalib/models/image/dfm/README.md @@ -0,0 +1,43 @@ +# Probabilistic Modeling of Deep Features for Out-of-Distribution and Adversarial Detection + +This is the implementation of [DFM](https://arxiv.org/pdf/1909.11786.pdf) paper. + +Model Type: Classification + +## Description + +Fast anomaly classification algorithm that consists of a deep feature extraction stage followed by anomaly classification stage consisting of PCA and class-conditional Gaussian Density Estimation. + +### Feature Extraction + +Features are extracted by feeding the images through a ResNet18 backbone, which was pre-trained on ImageNet. The output of the penultimate layer (average pooling layer) of the network is used to obtain a semantic feature vector with a fixed length of 2048. + +### Anomaly Detection + +In the anomaly classification stage, class-conditional PCA transformations and Gaussian Density models are learned. Two types of scores are calculated (i) Feature-reconstruction scores (norm of the difference between the high-dimensional pre-image of a reduced dimension feature and the original high-dimensional feature), and (ii) Negative log-likelihood under the learnt density models. Anomaly map generation is supported only with the feature-reconstruction based scores. Image level anomaly detection is supported by both score types. + +## Usage + +`python tools/train.py --model dfm` + +## Benchmark + +All results gathered with seed `42`. + +## [MVTec AD Dataset](https://www.mvtec.com/company/research/datasets/mvtec-ad) + +> Note: Metrics for ResNet 18 were calculated with pooling kernel size of 2 while for Wide ResNet 50, kernel size of 4 was used. + +### Image-Level AUC + +| | Avg | Carpet | Grid | Leather | Tile | Wood | Bottle | Cable | Capsule | Hazelnut | Metal Nut | Pill | Screw | Toothbrush | Transistor | Zipper | +| -------------- | :---: | :----: | :---: | :-----: | :---: | :---: | :----: | :---: | :-----: | :------: | :-------: | :---: | :---: | :--------: | :--------: | :----: | +| ResNet-18 | 0.936 | 0.817 | 0.736 | 0.993 | 0.966 | 0.977 | 1 | 0.956 | 0.944 | 0.994 | 0.922 | 0.961 | 0.89 | 0.969 | 0.939 | 0.969 | +| Wide ResNet-50 | 0.943 | 0.855 | 0.784 | 0.997 | 0.995 | 0.975 | 0.999 | 0.969 | 0.924 | 0.978 | 0.939 | 0.962 | 0.873 | 0.969 | 0.971 | 0.961 | + +### Image F1 Score + +| | Avg | Carpet | Grid | Leather | Tile | Wood | Bottle | Cable | Capsule | Hazelnut | Metal Nut | Pill | Screw | Toothbrush | Transistor | Zipper | +| -------------- | :---: | :----: | :---: | :-----: | :---: | :--: | :----: | :---: | :-----: | :------: | :-------: | :---: | :---: | :--------: | :--------: | :----: | +| ResNet-18 | 0.943 | 0.895 | 0.871 | 0.978 | 0.958 | 0.96 | 1 | 0.935 | 0.965 | 0.966 | 0.942 | 0.956 | 0.914 | 0.966 | 0.868 | 0.964 | +| Wide ResNet-50 | 0.950 | 0.915 | 0.87 | 0.995 | 0.988 | 0.96 | 0.992 | 0.939 | 0.965 | 0.971 | 0.942 | 0.956 | 0.906 | 0.966 | 0.914 | 0.971 | diff --git a/anomalib/models/image/dfm/__init__.py b/anomalib/models/image/dfm/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..c003420afc1b1da102a825ea3355213f3be06044 --- /dev/null +++ b/anomalib/models/image/dfm/__init__.py @@ -0,0 +1,8 @@ +"""Deep Feature Extraction (DFM) model.""" + +# Copyright (C) 2022-2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +from .lightning_model import Dfm + +__all__ = ["Dfm"] diff --git a/anomalib/models/image/dfm/lightning_model.py b/anomalib/models/image/dfm/lightning_model.py new file mode 100644 index 0000000000000000000000000000000000000000..fa68f7afcec3a61bc1a4d4d45e1312ba22ec214f --- /dev/null +++ b/anomalib/models/image/dfm/lightning_model.py @@ -0,0 +1,129 @@ +"""DFM: Deep Feature Modeling. + +https://arxiv.org/abs/1909.11786 +""" + +# Copyright (C) 2022-2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + + +import logging +from typing import Any + +import torch +from lightning.pytorch.utilities.types import STEP_OUTPUT + +from anomalib import LearningType +from anomalib.models.components import AnomalyModule, MemoryBankMixin + +from .torch_model import DFMModel + +logger = logging.getLogger(__name__) + + +class Dfm(MemoryBankMixin, AnomalyModule): + """DFM: Deep Featured Kernel Density Estimation. + + Args: + backbone (str): Backbone CNN network + Defaults to ``"resnet50"``. + layer (str): Layer to extract features from the backbone CNN + Defaults to ``"layer3"``. + pre_trained (bool, optional): Boolean to check whether to use a pre_trained backbone. + Defaults to ``True``. + pooling_kernel_size (int, optional): Kernel size to pool features extracted from the CNN. + Defaults to ``4``. + pca_level (float, optional): Ratio from which number of components for PCA are calculated. + Defaults to ``0.97``. + score_type (str, optional): Scoring type. Options are `fre` and `nll`. + Defaults to ``fre``. + """ + + def __init__( + self, + backbone: str = "resnet50", + layer: str = "layer3", + pre_trained: bool = True, + pooling_kernel_size: int = 4, + pca_level: float = 0.97, + score_type: str = "fre", + ) -> None: + super().__init__() + + self.model: DFMModel = DFMModel( + backbone=backbone, + pre_trained=pre_trained, + layer=layer, + pooling_kernel_size=pooling_kernel_size, + n_comps=pca_level, + score_type=score_type, + ) + self.embeddings: list[torch.Tensor] = [] + self.score_type = score_type + + @staticmethod + def configure_optimizers() -> None: # pylint: disable=arguments-differ + """DFM doesn't require optimization, therefore returns no optimizers.""" + return + + def training_step(self, batch: dict[str, str | torch.Tensor], *args, **kwargs) -> None: + """Perform the training step of DFM. + + For each batch, features are extracted from the CNN. + + Args: + batch (dict[str, str | torch.Tensor]): Input batch + args: Arguments. + kwargs: Keyword arguments. + + Returns: + Deep CNN features. + """ + del args, kwargs # These variables are not used. + + embedding = self.model.get_features(batch["image"]).squeeze() + self.embeddings.append(embedding) + + def fit(self) -> None: + """Fit a PCA transformation and a Gaussian model to dataset.""" + logger.info("Aggregating the embedding extracted from the training set.") + embeddings = torch.vstack(self.embeddings) + + logger.info("Fitting a PCA and a Gaussian model to dataset.") + self.model.fit(embeddings) + + def validation_step(self, batch: dict[str, str | torch.Tensor], *args, **kwargs) -> STEP_OUTPUT: + """Perform the validation step of DFM. + + Similar to the training step, features are extracted from the CNN for each batch. + + Args: + batch (dict[str, str | torch.Tensor]): Input batch + args: Arguments. + kwargs: Keyword arguments. + + Returns: + Dictionary containing FRE anomaly scores and anomaly maps. + """ + del args, kwargs # These variables are not used. + + if self.score_type == "fre": + batch["pred_scores"], batch["anomaly_maps"] = self.model(batch["image"]) + elif self.score_type == "nll": + batch["pred_scores"], _ = self.model(batch["image"]) + + return batch + + @property + def trainer_arguments(self) -> dict[str, Any]: + """Return DFM-specific trainer arguments.""" + return {"gradient_clip_val": 0, "max_epochs": 1, "num_sanity_val_steps": 0} + + @property + def learning_type(self) -> LearningType: + """Return the learning type of the model. + + Returns: + LearningType: Learning type of the model. + """ + return LearningType.ONE_CLASS diff --git a/anomalib/models/image/dfm/torch_model.py b/anomalib/models/image/dfm/torch_model.py new file mode 100644 index 0000000000000000000000000000000000000000..ff7ea30f91e171194db596e3cf553ba0b7001dff --- /dev/null +++ b/anomalib/models/image/dfm/torch_model.py @@ -0,0 +1,180 @@ +"""PyTorch model for DFM model implementation.""" + +# Copyright (C) 2022-2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + + +import math + +import torch +from torch import nn +from torch.nn import functional as F # noqa: N812 + +from anomalib.models.components import PCA, DynamicBufferMixin, TimmFeatureExtractor + + +class SingleClassGaussian(DynamicBufferMixin): + """Model Gaussian distribution over a set of points.""" + + def __init__(self) -> None: + super().__init__() + self.register_buffer("mean_vec", torch.Tensor()) + self.register_buffer("u_mat", torch.Tensor()) + self.register_buffer("sigma_mat", torch.Tensor()) + + self.mean_vec: torch.Tensor + self.u_mat: torch.Tensor + self.sigma_mat: torch.Tensor + + def fit(self, dataset: torch.Tensor) -> None: + """Fit a Gaussian model to dataset X. + + Covariance matrix is not calculated directly using: + ``C = X.X^T`` + Instead, it is represented in terms of the Singular Value Decomposition of X: + ``X = U.S.V^T`` + Hence, + ``C = U.S^2.U^T`` + This simplifies the calculation of the log-likelihood without requiring full matrix inversion. + + Args: + dataset (torch.Tensor): Input dataset to fit the model. + """ + num_samples = dataset.shape[1] + self.mean_vec = torch.mean(dataset, dim=1) + data_centered = (dataset - self.mean_vec.reshape(-1, 1)) / math.sqrt(num_samples) + self.u_mat, self.sigma_mat, _ = torch.linalg.svd(data_centered, full_matrices=False) + + def score_samples(self, features: torch.Tensor) -> torch.Tensor: + """Compute the NLL (negative log likelihood) scores. + + Args: + features (torch.Tensor): semantic features on which density modeling is performed. + + Returns: + nll (torch.Tensor): Torch tensor of scores + """ + features_transformed = torch.matmul(features - self.mean_vec, self.u_mat / self.sigma_mat) + return torch.sum(features_transformed * features_transformed, dim=1) + 2 * torch.sum(torch.log(self.sigma_mat)) + + def forward(self, dataset: torch.Tensor) -> None: + """Provide the same functionality as `fit`. + + Transforms the input dataset based on singular values calculated earlier. + + Args: + dataset (torch.Tensor): Input dataset + """ + self.fit(dataset) + + +class DFMModel(nn.Module): + """Model for the DFM algorithm. + + Args: + backbone (str): Pre-trained model backbone. + layer (str): Layer from which to extract features. + pre_trained (bool, optional): Boolean to check whether to use a pre_trained backbone. + Defaults to ``True``. + pooling_kernel_size (int, optional): Kernel size to pool features extracted from the CNN. + Defaults to ``4``. + n_comps (float, optional): Ratio from which number of components for PCA are calculated. + Defaults to ``0.97``. + score_type (str, optional): Scoring type. Options are `fre` and `nll`. Anomaly + Defaults to ``fre``. Segmentation is supported with `fre` only. + If using `nll`, set `task` in config.yaml to classification Defaults to ``classification``. + """ + + def __init__( + self, + backbone: str, + layer: str, + pre_trained: bool = True, + pooling_kernel_size: int = 4, + n_comps: float = 0.97, + score_type: str = "fre", + ) -> None: + super().__init__() + self.backbone = backbone + self.pooling_kernel_size = pooling_kernel_size + self.n_components = n_comps + self.pca_model = PCA(n_components=self.n_components) + self.gaussian_model = SingleClassGaussian() + self.score_type = score_type + self.layer = layer + self.feature_extractor = TimmFeatureExtractor( + backbone=self.backbone, + pre_trained=pre_trained, + layers=[layer], + ).eval() + + def fit(self, dataset: torch.Tensor) -> None: + """Fit a pca transformation and a Gaussian model to dataset. + + Args: + dataset (torch.Tensor): Input dataset to fit the model. + """ + self.pca_model.fit(dataset) + if self.score_type == "nll": + features_reduced = self.pca_model.transform(dataset) + self.gaussian_model.fit(features_reduced.T) + + def score(self, features: torch.Tensor, feature_shapes: tuple) -> torch.Tensor: + """Compute scores. + + Scores are either PCA-based feature reconstruction error (FRE) scores or + the Gaussian density-based NLL scores + + Args: + features (torch.Tensor): semantic features on which PCA and density modeling is performed. + feature_shapes (tuple): shape of `features` tensor. Used to generate anomaly map of correct shape. + + Returns: + score (torch.Tensor): numpy array of scores + """ + feats_projected = self.pca_model.transform(features) + if self.score_type == "nll": + score = self.gaussian_model.score_samples(feats_projected) + elif self.score_type == "fre": + feats_reconstructed = self.pca_model.inverse_transform(feats_projected) + fre = torch.square(features - feats_reconstructed).reshape(feature_shapes) + score_map = torch.unsqueeze(torch.sum(fre, dim=1), 1) + score = torch.sum(torch.square(features - feats_reconstructed), dim=1) + else: + msg = f"unsupported score type: {self.score_type}" + raise ValueError(msg) + + return (score, None) if self.score_type == "nll" else (score, score_map) + + def get_features(self, batch: torch.Tensor) -> torch.Tensor: + """Extract features from the pretrained network. + + Args: + batch (torch.Tensor): Image batch. + + Returns: + Tensor: torch.Tensor containing extracted features. + """ + self.feature_extractor.eval() + features = self.feature_extractor(batch)[self.layer] + batch_size = len(features) + if self.pooling_kernel_size > 1: + features = F.avg_pool2d(input=features, kernel_size=self.pooling_kernel_size) + feature_shapes = features.shape + features = features.view(batch_size, -1).detach() + return features if self.training else (features, feature_shapes) + + def forward(self, batch: torch.Tensor) -> torch.Tensor: + """Compute score from input images. + + Args: + batch (torch.Tensor): Input images + + Returns: + Tensor: Scores + """ + feature_vector, feature_shapes = self.get_features(batch) + score, score_map = self.score(feature_vector.view(feature_vector.shape[:2]), feature_shapes) + if score_map is not None: + score_map = F.interpolate(score_map, size=batch.shape[-2:], mode="bilinear", align_corners=False) + return score, score_map diff --git a/anomalib/models/image/draem/LICENSE b/anomalib/models/image/draem/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..7025d18fb4e1f8a5ae53c457494868af7de716bd --- /dev/null +++ b/anomalib/models/image/draem/LICENSE @@ -0,0 +1,29 @@ +Copyright (c) 2022 Intel Corporation +SPDX-License-Identifier: Apache-2.0 + +Some files in this folder are based on the original DRAEM implementation by VitjanZ + +Original license: +---------------- + + MIT License + + Copyright (c) 2021 VitjanZ + + Permission is hereby granted, free of charge, to any person obtaining a copy + of this software and associated documentation files (the "Software"), to deal + in the Software without restriction, including without limitation the rights + to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + copies of the Software, and to permit persons to whom the Software is + furnished to do so, subject to the following conditions: + + The above copyright notice and this permission notice shall be included in all + copies or substantial portions of the Software. + + THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + SOFTWARE. diff --git a/anomalib/models/image/draem/README.md b/anomalib/models/image/draem/README.md new file mode 100644 index 0000000000000000000000000000000000000000..007faa0220534e89584774bef45d1de59921f4fc --- /dev/null +++ b/anomalib/models/image/draem/README.md @@ -0,0 +1,23 @@ +# DRÆM – A discriminatively trained reconstruction embedding for surface anomaly detection + +This is the implementation of the [DRAEM](https://arxiv.org/pdf/2108.07610v2.pdf) paper. + +Model Type: Segmentation + +## Description + +DRAEM is a reconstruction based algorithm that consists of a reconstructive subnetwork and a discriminative subnetwork. DRAEM is trained on simulated anomaly images, generated by augmenting normal input images from the training set with a random Perlin noise mask extracted from an unrelated source of image data. The reconstructive subnetwork is an autoencoder architecture that is trained to reconstruct the original input images from the augmented images. The reconstructive submodel is trained using a combination of L2 loss and Structural Similarity loss. The input of the discriminative subnetwork consists of the channel-wise concatenation of the (augmented) input image and the output of the reconstructive subnetwork. The output of the discriminative subnetwork is an anomaly map that contains the predicted anomaly scores for each pixel location. The discriminative subnetwork is trained using Focal Loss. + +For optimal results, DRAEM requires specifying the path to a folder of image data that will be used as the source of the anomalous pixel regions in the simulated anomaly images. The path can be specified by editing the value of the `model.anomaly_source_path` parameter in the `config.yaml` file. The authors of the original paper recommend using the [DTD](https://www.robots.ox.ac.uk/~vgg/data/dtd/) dataset as anomaly source. + +## Architecture + +![DRAEM Architecture](/docs/source/images/draem/architecture.png "DRAEM Architecture") + +## Usage + +`python tools/train.py --model draem` + +## Benchmark + +### TODO: Add results + benchmark diff --git a/anomalib/models/image/draem/__init__.py b/anomalib/models/image/draem/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..4c8b06fa1dd55095501b5158a1930fd22db29415 --- /dev/null +++ b/anomalib/models/image/draem/__init__.py @@ -0,0 +1,8 @@ +"""DRAEM model.""" + +# Copyright (C) 2022-2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +from .lightning_model import Draem + +__all__ = ["Draem"] diff --git a/anomalib/models/image/draem/lightning_model.py b/anomalib/models/image/draem/lightning_model.py new file mode 100644 index 0000000000000000000000000000000000000000..037b397f93d704bfefdc323ef6ad753b90c1d522 --- /dev/null +++ b/anomalib/models/image/draem/lightning_model.py @@ -0,0 +1,153 @@ +"""DRÆM - A discriminatively trained reconstruction embedding for surface anomaly detection. + +Paper https://arxiv.org/abs/2108.07610 +""" + +# Copyright (C) 2022-2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + + +from collections.abc import Callable +from typing import Any + +import torch +from lightning.pytorch.utilities.types import STEP_OUTPUT +from torch import nn + +from anomalib import LearningType +from anomalib.data.utils import Augmenter +from anomalib.models.components import AnomalyModule + +from .loss import DraemLoss +from .torch_model import DraemModel + +__all__ = ["Draem"] + + +class Draem(AnomalyModule): + """DRÆM: A discriminatively trained reconstruction embedding for surface anomaly detection. + + Args: + enable_sspcab (bool): Enable SSPCAB training. + Defaults to ``False``. + sspcab_lambda (float): SSPCAB loss weight. + Defaults to ``0.1``. + anomaly_source_path (str | None): Path to folder that contains the anomaly source images. Random noise will + be used if left empty. + Defaults to ``None``. + """ + + def __init__( + self, + enable_sspcab: bool = False, + sspcab_lambda: float = 0.1, + anomaly_source_path: str | None = None, + beta: float | tuple[float, float] = (0.1, 1.0), + ) -> None: + super().__init__() + + self.augmenter = Augmenter(anomaly_source_path, beta=beta) + self.model = DraemModel(sspcab=enable_sspcab) + self.loss = DraemLoss() + self.sspcab = enable_sspcab + + if self.sspcab: + self.sspcab_activations: dict = {} + self.setup_sspcab() + self.sspcab_loss = nn.MSELoss() + self.sspcab_lambda = sspcab_lambda + + def setup_sspcab(self) -> None: + """Prepare the model for the SSPCAB training step by adding forward hooks for the SSPCAB layer activations.""" + + def get_activation(name: str) -> Callable: + """Retrieve the activations. + + Args: + name (str): Identifier for the retrieved activations. + """ + + def hook(_, __, output: torch.Tensor) -> None: # noqa: ANN001 + """Create hook for retrieving the activations. + + Args: + _: Placeholder for the module input. + __: Placeholder for the module output. + output (torch.Tensor): The output tensor of the module. + """ + self.sspcab_activations[name] = output + + return hook + + self.model.reconstructive_subnetwork.encoder.mp4.register_forward_hook(get_activation("input")) + self.model.reconstructive_subnetwork.encoder.block5.register_forward_hook(get_activation("output")) + + def training_step(self, batch: dict[str, str | torch.Tensor], *args, **kwargs) -> STEP_OUTPUT: + """Perform the training step of DRAEM. + + Feeds the original image and the simulated anomaly + image through the network and computes the training loss. + + Args: + batch (dict[str, str | torch.Tensor]): Batch containing image filename, image, label and mask + args: Arguments. + kwargs: Keyword arguments. + + Returns: + Loss dictionary + """ + del args, kwargs # These variables are not used. + + input_image = batch["image"] + # Apply corruption to input image + augmented_image, anomaly_mask = self.augmenter.augment_batch(input_image) + # Generate model prediction + reconstruction, prediction = self.model(augmented_image) + # Compute loss + loss = self.loss(input_image, reconstruction, anomaly_mask, prediction) + + if self.sspcab: + loss += self.sspcab_lambda * self.sspcab_loss( + self.sspcab_activations["input"], + self.sspcab_activations["output"], + ) + + self.log("train_loss", loss.item(), on_epoch=True, prog_bar=True, logger=True) + return {"loss": loss} + + def validation_step(self, batch: dict[str, str | torch.Tensor], *args, **kwargs) -> STEP_OUTPUT: + """Perform the validation step of DRAEM. The Softmax predictions of the anomalous class are used as anomaly map. + + Args: + batch (dict[str, str | torch.Tensor]): Batch of input images + args: Arguments. + kwargs: Keyword arguments. + + Returns: + Dictionary to which predicted anomaly maps have been added. + """ + del args, kwargs # These variables are not used. + + prediction = self.model(batch["image"]) + batch["anomaly_maps"] = prediction + return batch + + @property + def trainer_arguments(self) -> dict[str, Any]: + """Return DRÆM-specific trainer arguments.""" + return {"gradient_clip_val": 0, "num_sanity_val_steps": 0} + + def configure_optimizers(self) -> torch.optim.Optimizer: + """Configure the Adam optimizer.""" + optimizer = torch.optim.Adam(params=self.model.parameters(), lr=0.0001) + scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[400, 600], gamma=0.1) + return [optimizer], [scheduler] + + @property + def learning_type(self) -> LearningType: + """Return the learning type of the model. + + Returns: + LearningType: Learning type of the model. + """ + return LearningType.ONE_CLASS diff --git a/anomalib/models/image/draem/loss.py b/anomalib/models/image/draem/loss.py new file mode 100644 index 0000000000000000000000000000000000000000..1cef702e154fb9aa2ef6eeb7ea3c1a95ccad5232 --- /dev/null +++ b/anomalib/models/image/draem/loss.py @@ -0,0 +1,36 @@ +"""Loss function for the DRAEM model implementation.""" + +# Copyright (C) 2022-2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +import torch +from kornia.losses import FocalLoss, SSIMLoss +from torch import nn + + +class DraemLoss(nn.Module): + """Overall loss function of the DRAEM model. + + The total loss consists of the sum of the L2 loss and Focal loss between the reconstructed image and the input + image, and the Structural Similarity loss between the predicted and GT anomaly masks. + """ + + def __init__(self) -> None: + super().__init__() + + self.l2_loss = nn.modules.loss.MSELoss() + self.focal_loss = FocalLoss(alpha=1, reduction="mean") + self.ssim_loss = SSIMLoss(window_size=11) + + def forward( + self, + input_image: torch.Tensor, + reconstruction: torch.Tensor, + anomaly_mask: torch.Tensor, + prediction: torch.Tensor, + ) -> torch.Tensor: + """Compute the loss over a batch for the DRAEM model.""" + l2_loss_val = self.l2_loss(reconstruction, input_image) + focal_loss_val = self.focal_loss(prediction, anomaly_mask.squeeze(1).long()) + ssim_loss_val = self.ssim_loss(reconstruction, input_image) * 2 + return l2_loss_val + ssim_loss_val + focal_loss_val diff --git a/anomalib/models/image/draem/torch_model.py b/anomalib/models/image/draem/torch_model.py new file mode 100644 index 0000000000000000000000000000000000000000..86d489204f9c53272db5a590741be9f9ab483d25 --- /dev/null +++ b/anomalib/models/image/draem/torch_model.py @@ -0,0 +1,521 @@ +"""PyTorch model for the DRAEM model implementation.""" + +# Original Code +# Copyright (c) 2021 VitjanZ +# https://github.com/VitjanZ/DRAEM. +# SPDX-License-Identifier: MIT +# +# Modified +# Copyright (C) 2022-2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + + +import torch +from torch import nn + +from anomalib.models.components.layers import SSPCAB + + +class DraemModel(nn.Module): + """DRAEM PyTorch model consisting of the reconstructive and discriminative sub networks. + + Args: + sspcab (bool): Enable SSPCAB training. + Defaults to ``False``. + """ + + def __init__(self, sspcab: bool = False) -> None: + super().__init__() + self.reconstructive_subnetwork = ReconstructiveSubNetwork(sspcab=sspcab) + self.discriminative_subnetwork = DiscriminativeSubNetwork(in_channels=6, out_channels=2) + + def forward(self, batch: torch.Tensor) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: + """Compute the reconstruction and anomaly mask from an input image. + + Args: + batch (torch.Tensor): batch of input images + + Returns: + Predicted confidence values of the anomaly mask. During training the reconstructed input images are + returned as well. + """ + reconstruction = self.reconstructive_subnetwork(batch) + concatenated_inputs = torch.cat([batch, reconstruction], axis=1) + prediction = self.discriminative_subnetwork(concatenated_inputs) + if self.training: + return reconstruction, prediction + return torch.softmax(prediction, dim=1)[:, 1, ...] + + +class ReconstructiveSubNetwork(nn.Module): + """Autoencoder model that encodes and reconstructs the input image. + + Args: + in_channels (int): Number of input channels. + Defaults to ``3``. + out_channels (int): Number of output channels. + Defaults to ``3``. + base_width (int): Base dimensionality of the layers of the autoencoder. + Defaults to ``128``. + sspcab (bool): Enable SSPCAB training. + Defaults to ``False``. + """ + + def __init__( + self, + in_channels: int = 3, + out_channels: int = 3, + base_width: int = 128, + sspcab: bool = False, + ) -> None: + super().__init__() + self.encoder = EncoderReconstructive(in_channels, base_width, sspcab=sspcab) + self.decoder = DecoderReconstructive(base_width, out_channels=out_channels) + + def forward(self, batch: torch.Tensor) -> torch.Tensor: + """Encode and reconstruct the input images. + + Args: + batch (torch.Tensor): Batch of input images + + Returns: + Batch of reconstructed images. + """ + encoded = self.encoder(batch) + return self.decoder(encoded) + + +class DiscriminativeSubNetwork(nn.Module): + """Discriminative model that predicts the anomaly mask from the original image and its reconstruction. + + Args: + in_channels (int): Number of input channels. + Defaults to ``3``. + out_channels (int): Number of output channels. + Defaults to ``3``. + base_width (int): Base dimensionality of the layers of the autoencoder. + Defaults to ``64``. + """ + + def __init__(self, in_channels: int = 3, out_channels: int = 3, base_width: int = 64) -> None: + super().__init__() + self.encoder_segment = EncoderDiscriminative(in_channels, base_width) + self.decoder_segment = DecoderDiscriminative(base_width, out_channels=out_channels) + + def forward(self, batch: torch.Tensor) -> torch.Tensor: + """Generate the predicted anomaly masks for a batch of input images. + + Args: + batch (torch.Tensor): Batch of inputs consisting of the concatenation of the original images + and their reconstructions. + + Returns: + Activations of the output layer corresponding to the normal and anomalous class scores on the pixel level. + """ + act1, act2, act3, act4, act5, act6 = self.encoder_segment(batch) + return self.decoder_segment(act1, act2, act3, act4, act5, act6) + + +class EncoderDiscriminative(nn.Module): + """Encoder part of the discriminator network. + + Args: + in_channels (int): Number of input channels. + base_width (int): Base dimensionality of the layers of the autoencoder. + """ + + def __init__(self, in_channels: int, base_width: int) -> None: + super().__init__() + self.block1 = nn.Sequential( + nn.Conv2d(in_channels, base_width, kernel_size=3, padding=1), + nn.BatchNorm2d(base_width), + nn.ReLU(inplace=True), + nn.Conv2d(base_width, base_width, kernel_size=3, padding=1), + nn.BatchNorm2d(base_width), + nn.ReLU(inplace=True), + ) + self.mp1 = nn.Sequential(nn.MaxPool2d(2)) + self.block2 = nn.Sequential( + nn.Conv2d(base_width, base_width * 2, kernel_size=3, padding=1), + nn.BatchNorm2d(base_width * 2), + nn.ReLU(inplace=True), + nn.Conv2d(base_width * 2, base_width * 2, kernel_size=3, padding=1), + nn.BatchNorm2d(base_width * 2), + nn.ReLU(inplace=True), + ) + self.mp2 = nn.Sequential(nn.MaxPool2d(2)) + self.block3 = nn.Sequential( + nn.Conv2d(base_width * 2, base_width * 4, kernel_size=3, padding=1), + nn.BatchNorm2d(base_width * 4), + nn.ReLU(inplace=True), + nn.Conv2d(base_width * 4, base_width * 4, kernel_size=3, padding=1), + nn.BatchNorm2d(base_width * 4), + nn.ReLU(inplace=True), + ) + self.mp3 = nn.Sequential(nn.MaxPool2d(2)) + self.block4 = nn.Sequential( + nn.Conv2d(base_width * 4, base_width * 8, kernel_size=3, padding=1), + nn.BatchNorm2d(base_width * 8), + nn.ReLU(inplace=True), + nn.Conv2d(base_width * 8, base_width * 8, kernel_size=3, padding=1), + nn.BatchNorm2d(base_width * 8), + nn.ReLU(inplace=True), + ) + self.mp4 = nn.Sequential(nn.MaxPool2d(2)) + self.block5 = nn.Sequential( + nn.Conv2d(base_width * 8, base_width * 8, kernel_size=3, padding=1), + nn.BatchNorm2d(base_width * 8), + nn.ReLU(inplace=True), + nn.Conv2d(base_width * 8, base_width * 8, kernel_size=3, padding=1), + nn.BatchNorm2d(base_width * 8), + nn.ReLU(inplace=True), + ) + + self.mp5 = nn.Sequential(nn.MaxPool2d(2)) + self.block6 = nn.Sequential( + nn.Conv2d(base_width * 8, base_width * 8, kernel_size=3, padding=1), + nn.BatchNorm2d(base_width * 8), + nn.ReLU(inplace=True), + nn.Conv2d(base_width * 8, base_width * 8, kernel_size=3, padding=1), + nn.BatchNorm2d(base_width * 8), + nn.ReLU(inplace=True), + ) + + def forward( + self, + batch: torch.Tensor, + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + """Convert the inputs to the salient space by running them through the encoder network. + + Args: + batch (torch.Tensor): Batch of inputs consisting of the concatenation of the original images + and their reconstructions. + + Returns: + Computed feature maps for each of the layers in the encoder sub network. + """ + act1 = self.block1(batch) + mp1 = self.mp1(act1) + act2 = self.block2(mp1) + mp2 = self.mp3(act2) + act3 = self.block3(mp2) + mp3 = self.mp3(act3) + act4 = self.block4(mp3) + mp4 = self.mp4(act4) + act5 = self.block5(mp4) + mp5 = self.mp5(act5) + act6 = self.block6(mp5) + return act1, act2, act3, act4, act5, act6 + + +class DecoderDiscriminative(nn.Module): + """Decoder part of the discriminator network. + + Args: + base_width (int): Base dimensionality of the layers of the autoencoder. + out_channels (int): Number of output channels. + Defaults to ``1``. + """ + + def __init__(self, base_width: int, out_channels: int = 1) -> None: + super().__init__() + + self.up_b = nn.Sequential( + nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True), + nn.Conv2d(base_width * 8, base_width * 8, kernel_size=3, padding=1), + nn.BatchNorm2d(base_width * 8), + nn.ReLU(inplace=True), + ) + self.db_b = nn.Sequential( + nn.Conv2d(base_width * (8 + 8), base_width * 8, kernel_size=3, padding=1), + nn.BatchNorm2d(base_width * 8), + nn.ReLU(inplace=True), + nn.Conv2d(base_width * 8, base_width * 8, kernel_size=3, padding=1), + nn.BatchNorm2d(base_width * 8), + nn.ReLU(inplace=True), + ) + + self.up1 = nn.Sequential( + nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True), + nn.Conv2d(base_width * 8, base_width * 4, kernel_size=3, padding=1), + nn.BatchNorm2d(base_width * 4), + nn.ReLU(inplace=True), + ) + self.db1 = nn.Sequential( + nn.Conv2d(base_width * (4 + 8), base_width * 4, kernel_size=3, padding=1), + nn.BatchNorm2d(base_width * 4), + nn.ReLU(inplace=True), + nn.Conv2d(base_width * 4, base_width * 4, kernel_size=3, padding=1), + nn.BatchNorm2d(base_width * 4), + nn.ReLU(inplace=True), + ) + + self.up2 = nn.Sequential( + nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True), + nn.Conv2d(base_width * 4, base_width * 2, kernel_size=3, padding=1), + nn.BatchNorm2d(base_width * 2), + nn.ReLU(inplace=True), + ) + self.db2 = nn.Sequential( + nn.Conv2d(base_width * (2 + 4), base_width * 2, kernel_size=3, padding=1), + nn.BatchNorm2d(base_width * 2), + nn.ReLU(inplace=True), + nn.Conv2d(base_width * 2, base_width * 2, kernel_size=3, padding=1), + nn.BatchNorm2d(base_width * 2), + nn.ReLU(inplace=True), + ) + + self.up3 = nn.Sequential( + nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True), + nn.Conv2d(base_width * 2, base_width, kernel_size=3, padding=1), + nn.BatchNorm2d(base_width), + nn.ReLU(inplace=True), + ) + self.db3 = nn.Sequential( + nn.Conv2d(base_width * (2 + 1), base_width, kernel_size=3, padding=1), + nn.BatchNorm2d(base_width), + nn.ReLU(inplace=True), + nn.Conv2d(base_width, base_width, kernel_size=3, padding=1), + nn.BatchNorm2d(base_width), + nn.ReLU(inplace=True), + ) + + self.up4 = nn.Sequential( + nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True), + nn.Conv2d(base_width, base_width, kernel_size=3, padding=1), + nn.BatchNorm2d(base_width), + nn.ReLU(inplace=True), + ) + self.db4 = nn.Sequential( + nn.Conv2d(base_width * 2, base_width, kernel_size=3, padding=1), + nn.BatchNorm2d(base_width), + nn.ReLU(inplace=True), + nn.Conv2d(base_width, base_width, kernel_size=3, padding=1), + nn.BatchNorm2d(base_width), + nn.ReLU(inplace=True), + ) + + self.fin_out = nn.Sequential(nn.Conv2d(base_width, out_channels, kernel_size=3, padding=1)) + + def forward( + self, + act1: torch.Tensor, + act2: torch.Tensor, + act3: torch.Tensor, + act4: torch.Tensor, + act5: torch.Tensor, + act6: torch.Tensor, + ) -> torch.Tensor: + """Compute predicted anomaly class scores from the intermediate outputs of the encoder sub network. + + Args: + act1 (torch.Tensor): Encoder activations of the first block of convolutional layers. + act2 (torch.Tensor): Encoder activations of the second block of convolutional layers. + act3 (torch.Tensor): Encoder activations of the third block of convolutional layers. + act4 (torch.Tensor): Encoder activations of the fourth block of convolutional layers. + act5 (torch.Tensor): Encoder activations of the fifth block of convolutional layers. + act6 (torch.Tensor): Encoder activations of the sixth block of convolutional layers. + + Returns: + Predicted anomaly class scores per pixel. + """ + up_b = self.up_b(act6) + cat_b = torch.cat((up_b, act5), dim=1) + db_b = self.db_b(cat_b) + + up1 = self.up1(db_b) + cat1 = torch.cat((up1, act4), dim=1) + db1 = self.db1(cat1) + + up2 = self.up2(db1) + cat2 = torch.cat((up2, act3), dim=1) + db2 = self.db2(cat2) + + up3 = self.up3(db2) + cat3 = torch.cat((up3, act2), dim=1) + db3 = self.db3(cat3) + + up4 = self.up4(db3) + cat4 = torch.cat((up4, act1), dim=1) + db4 = self.db4(cat4) + + return self.fin_out(db4) + + +class EncoderReconstructive(nn.Module): + """Encoder part of the reconstructive network. + + Args: + in_channels (int): Number of input channels. + base_width (int): Base dimensionality of the layers of the autoencoder. + sspcab (bool): Enable SSPCAB training. + Defaults to ``False``. + """ + + def __init__(self, in_channels: int, base_width: int, sspcab: bool = False) -> None: + super().__init__() + self.block1 = nn.Sequential( + nn.Conv2d(in_channels, base_width, kernel_size=3, padding=1), + nn.BatchNorm2d(base_width), + nn.ReLU(inplace=True), + nn.Conv2d(base_width, base_width, kernel_size=3, padding=1), + nn.BatchNorm2d(base_width), + nn.ReLU(inplace=True), + ) + self.mp1 = nn.Sequential(nn.MaxPool2d(2)) + self.block2 = nn.Sequential( + nn.Conv2d(base_width, base_width * 2, kernel_size=3, padding=1), + nn.BatchNorm2d(base_width * 2), + nn.ReLU(inplace=True), + nn.Conv2d(base_width * 2, base_width * 2, kernel_size=3, padding=1), + nn.BatchNorm2d(base_width * 2), + nn.ReLU(inplace=True), + ) + self.mp2 = nn.Sequential(nn.MaxPool2d(2)) + self.block3 = nn.Sequential( + nn.Conv2d(base_width * 2, base_width * 4, kernel_size=3, padding=1), + nn.BatchNorm2d(base_width * 4), + nn.ReLU(inplace=True), + nn.Conv2d(base_width * 4, base_width * 4, kernel_size=3, padding=1), + nn.BatchNorm2d(base_width * 4), + nn.ReLU(inplace=True), + ) + self.mp3 = nn.Sequential(nn.MaxPool2d(2)) + self.block4 = nn.Sequential( + nn.Conv2d(base_width * 4, base_width * 8, kernel_size=3, padding=1), + nn.BatchNorm2d(base_width * 8), + nn.ReLU(inplace=True), + nn.Conv2d(base_width * 8, base_width * 8, kernel_size=3, padding=1), + nn.BatchNorm2d(base_width * 8), + nn.ReLU(inplace=True), + ) + self.mp4 = nn.Sequential(nn.MaxPool2d(2)) + if sspcab: + self.block5 = SSPCAB(base_width * 8) + else: + self.block5 = nn.Sequential( + nn.Conv2d(base_width * 8, base_width * 8, kernel_size=3, padding=1), + nn.BatchNorm2d(base_width * 8), + nn.ReLU(inplace=True), + nn.Conv2d(base_width * 8, base_width * 8, kernel_size=3, padding=1), + nn.BatchNorm2d(base_width * 8), + nn.ReLU(inplace=True), + ) + + def forward(self, batch: torch.Tensor) -> torch.Tensor: + """Encode a batch of input images to the salient space. + + Args: + batch (torch.Tensor): Batch of input images. + + Returns: + Feature maps extracted from the bottleneck layer. + """ + act1 = self.block1(batch) + mp1 = self.mp1(act1) + act2 = self.block2(mp1) + mp2 = self.mp3(act2) + act3 = self.block3(mp2) + mp3 = self.mp3(act3) + act4 = self.block4(mp3) + mp4 = self.mp4(act4) + return self.block5(mp4) + + +class DecoderReconstructive(nn.Module): + """Decoder part of the reconstructive network. + + Args: + base_width (int): Base dimensionality of the layers of the autoencoder. + out_channels (int): Number of output channels. + Defaults to ``1``. + """ + + def __init__(self, base_width: int, out_channels: int = 1) -> None: + super().__init__() + + self.up1 = nn.Sequential( + nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True), + nn.Conv2d(base_width * 8, base_width * 8, kernel_size=3, padding=1), + nn.BatchNorm2d(base_width * 8), + nn.ReLU(inplace=True), + ) + self.db1 = nn.Sequential( + nn.Conv2d(base_width * 8, base_width * 8, kernel_size=3, padding=1), + nn.BatchNorm2d(base_width * 8), + nn.ReLU(inplace=True), + nn.Conv2d(base_width * 8, base_width * 4, kernel_size=3, padding=1), + nn.BatchNorm2d(base_width * 4), + nn.ReLU(inplace=True), + ) + + self.up2 = nn.Sequential( + nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True), + nn.Conv2d(base_width * 4, base_width * 4, kernel_size=3, padding=1), + nn.BatchNorm2d(base_width * 4), + nn.ReLU(inplace=True), + ) + self.db2 = nn.Sequential( + nn.Conv2d(base_width * 4, base_width * 4, kernel_size=3, padding=1), + nn.BatchNorm2d(base_width * 4), + nn.ReLU(inplace=True), + nn.Conv2d(base_width * 4, base_width * 2, kernel_size=3, padding=1), + nn.BatchNorm2d(base_width * 2), + nn.ReLU(inplace=True), + ) + + self.up3 = nn.Sequential( + nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True), + nn.Conv2d(base_width * 2, base_width * 2, kernel_size=3, padding=1), + nn.BatchNorm2d(base_width * 2), + nn.ReLU(inplace=True), + ) + # cat with base*1 + self.db3 = nn.Sequential( + nn.Conv2d(base_width * 2, base_width * 2, kernel_size=3, padding=1), + nn.BatchNorm2d(base_width * 2), + nn.ReLU(inplace=True), + nn.Conv2d(base_width * 2, base_width * 1, kernel_size=3, padding=1), + nn.BatchNorm2d(base_width * 1), + nn.ReLU(inplace=True), + ) + + self.up4 = nn.Sequential( + nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True), + nn.Conv2d(base_width, base_width, kernel_size=3, padding=1), + nn.BatchNorm2d(base_width), + nn.ReLU(inplace=True), + ) + self.db4 = nn.Sequential( + nn.Conv2d(base_width * 1, base_width, kernel_size=3, padding=1), + nn.BatchNorm2d(base_width), + nn.ReLU(inplace=True), + nn.Conv2d(base_width, base_width, kernel_size=3, padding=1), + nn.BatchNorm2d(base_width), + nn.ReLU(inplace=True), + ) + + self.fin_out = nn.Sequential(nn.Conv2d(base_width, out_channels, kernel_size=3, padding=1)) + + def forward(self, act5: torch.Tensor) -> torch.Tensor: + """Reconstruct the image from the activations of the bottleneck layer. + + Args: + act5 (torch.Tensor): Activations of the bottleneck layer. + + Returns: + Batch of reconstructed images. + """ + up1 = self.up1(act5) + db1 = self.db1(up1) + + up2 = self.up2(db1) + db2 = self.db2(up2) + + up3 = self.up3(db2) + db3 = self.db3(up3) + + up4 = self.up4(db3) + db4 = self.db4(up4) + + return self.fin_out(db4) diff --git a/anomalib/models/image/dsr/LICENSE b/anomalib/models/image/dsr/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..0fbbaf4eb90cad537e97d3610b27a59fabb7057b --- /dev/null +++ b/anomalib/models/image/dsr/LICENSE @@ -0,0 +1,209 @@ +Copyright (c) 2023 Intel Corporation +SPDX-License-Identifier: Apache-2.0 + +Some files in this folder are based on the original DSR implementation by VitjanZ + +Original license: +---------------- + + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. diff --git a/anomalib/models/image/dsr/README.md b/anomalib/models/image/dsr/README.md new file mode 100644 index 0000000000000000000000000000000000000000..fcf330817d5385b1ba72064415d4c3814819bde8 --- /dev/null +++ b/anomalib/models/image/dsr/README.md @@ -0,0 +1,21 @@ +# DSR – A Dual Subspace Re-Projection Network for Surface Anomaly Detection + +This is the implementation of the [DSR](https://link.springer.com/chapter/10.1007/978-3-031-19821-2_31) paper. + +Model Type: Segmentation + +## Description + +DSR is a quantized-feature based algorithm that consists of an autoencoder with one encoder and two decoders, coupled with an anomaly detection module. DSR learns a codebook of quantized representations on ImageNet, which are then used to encode input images. These quantized representations also serve to sample near-in-distribution anomalies, since they do not rely on external datasets. Training takes place in three phases. The encoder and "general object decoder", as well as the codebook, are pretrained on ImageNet. Defects are then generated at the feature level using the codebook on the quantized representations, and are used to train the object-specific decoder as well as the anomaly detection module. In the final phase of training, the upsampling module is trained on simulated image-level smudges in order to output more robust anomaly maps. + +## Architecture + +![DSR Architecture](/docs/source/images/dsr/architecture.png "DSR Architecture") + +## Usage + +`python tools/train.py --model dsr` + +## Benchmark + +Benchmarking results are not yet available for this algorithm. Please check again later. diff --git a/anomalib/models/image/dsr/__init__.py b/anomalib/models/image/dsr/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..54e53d5d6fb778f59b779b1a9886ce41e6a9a004 --- /dev/null +++ b/anomalib/models/image/dsr/__init__.py @@ -0,0 +1,8 @@ +"""DSR model.""" + +# Copyright (C) 2023-2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +from .lightning_model import Dsr + +__all__ = ["Dsr"] diff --git a/anomalib/models/image/dsr/anomaly_generator.py b/anomalib/models/image/dsr/anomaly_generator.py new file mode 100644 index 0000000000000000000000000000000000000000..cf663f70bababa1ac8171e2715fdbce6b74c6595 --- /dev/null +++ b/anomalib/models/image/dsr/anomaly_generator.py @@ -0,0 +1,77 @@ +"""Anomaly generator for the DSR model implementation.""" + +# Copyright (C) 2023-2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + + +import imgaug.augmenters as iaa +import numpy as np +import torch +from torch import Tensor, nn + +from anomalib.data.utils.generators.perlin import _rand_perlin_2d_np + + +class DsrAnomalyGenerator(nn.Module): + """Anomaly generator of the DSR model. + + The anomaly is generated using a Perlin noise generator on the two quantized representations of an image. + This generator is only used during the second phase of training! The third phase requires generating + smudges over the input images. + + Args: + p_anomalous (float, optional): Probability to generate an anomalous image. + """ + + def __init__( + self, + p_anomalous: float = 0.5, + ) -> None: + super().__init__() + + self.p_anomalous = p_anomalous + self.rot = iaa.Sequential([iaa.Affine(rotate=(-90, 90))]) + + def generate_anomaly(self, height: int, width: int) -> Tensor: + """Generate an anomalous mask. + + Args: + height (int): Height of generated mask. + width (int): Width of generated mask. + + Returns: + Tensor: Generated mask. + """ + min_perlin_scale = 0 + perlin_scale = 6 + perlin_scalex = 2 ** (torch.randint(min_perlin_scale, perlin_scale, (1,)).numpy()[0]) + perlin_scaley = 2 ** (torch.randint(min_perlin_scale, perlin_scale, (1,)).numpy()[0]) + threshold = 0.5 + perlin_noise_np = _rand_perlin_2d_np((height, width), (perlin_scalex, perlin_scaley)) + perlin_noise_np = self.rot(image=perlin_noise_np) + mask = np.where(perlin_noise_np > threshold, np.ones_like(perlin_noise_np), np.zeros_like(perlin_noise_np)) + mask = np.expand_dims(mask, axis=2).astype(np.float32) + + return torch.from_numpy(mask) + + def augment_batch(self, batch: Tensor) -> Tensor: + """Generate anomalous augmentations for a batch of input images. + + Args: + batch (Tensor): Batch of input images + + Returns: + Tensor: Ground truth masks corresponding to the anomalous perturbations. + """ + batch_size, _, height, width = batch.shape + + # Collect perturbations + masks_list: list[Tensor] = [] + for _ in range(batch_size): + if torch.rand(1) > self.p_anomalous: # include normal samples + masks_list.append(torch.zeros((1, height, width))) + else: + mask = self.generate_anomaly(height, width) + masks_list.append(mask.permute((2, 0, 1))) + + return torch.stack(masks_list).to(batch.device) diff --git a/anomalib/models/image/dsr/lightning_model.py b/anomalib/models/image/dsr/lightning_model.py new file mode 100644 index 0000000000000000000000000000000000000000..e9eb4d269305538c28972c24aac4378b55f00247 --- /dev/null +++ b/anomalib/models/image/dsr/lightning_model.py @@ -0,0 +1,192 @@ +"""DSR - A Dual Subspace Re-Projection Network for Surface Anomaly Detection. + +Paper https://link.springer.com/chapter/10.1007/978-3-031-19821-2_31 +""" + +# Copyright (C) 2023-2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +import logging +from pathlib import Path +from typing import Any + +import torch +from lightning.pytorch.utilities.types import STEP_OUTPUT, OptimizerLRScheduler +from torch import Tensor + +from anomalib import LearningType +from anomalib.data.utils import DownloadInfo, download_and_extract +from anomalib.data.utils.augmenter import Augmenter +from anomalib.models.components import AnomalyModule +from anomalib.models.image.dsr.anomaly_generator import DsrAnomalyGenerator +from anomalib.models.image.dsr.loss import DsrSecondStageLoss, DsrThirdStageLoss +from anomalib.models.image.dsr.torch_model import DsrModel + +__all__ = ["Dsr"] + +logger = logging.getLogger(__name__) + +WEIGHTS_DOWNLOAD_INFO = DownloadInfo( + name="vq_model_pretrained_128_4096.pckl", + url="https://github.com/openvinotoolkit/anomalib/releases/download/dsr_pretrained_weights/dsr_vq_model_pretrained.zip", + hashsum="52fe7504ec8e9df70b4382f287ab26269dcfe000cd7a7e146a52c6f146f34afb", +) + + +class Dsr(AnomalyModule): + """DSR: A Dual Subspace Re-Projection Network for Surface Anomaly Detection. + + Args: + latent_anomaly_strength (float): Strength of the generated anomalies in the latent space. Defaults to 0.2 + upsampling_train_ratio (float): Ratio of training steps for the upsampling module. Defaults to 0.7 + """ + + def __init__(self, latent_anomaly_strength: float = 0.2, upsampling_train_ratio: float = 0.7) -> None: + super().__init__() + + self.automatic_optimization = False + self.upsampling_train_ratio = upsampling_train_ratio + + self.quantized_anomaly_generator = DsrAnomalyGenerator() + self.perlin_generator = Augmenter() + self.model = DsrModel(latent_anomaly_strength) + self.second_stage_loss = DsrSecondStageLoss() + self.third_stage_loss = DsrThirdStageLoss() + + self.second_phase: int + + def prepare_pretrained_model(self) -> Path: + """Download pre-trained models if they don't exist.""" + pretrained_models_dir = Path("./pre_trained/") + if not (pretrained_models_dir / "vq_model_pretrained_128_4096.pckl").is_file(): + download_and_extract(pretrained_models_dir, WEIGHTS_DOWNLOAD_INFO) + return pretrained_models_dir / "vq_model_pretrained_128_4096.pckl" + + def configure_optimizers( + self, + ) -> OptimizerLRScheduler: + """Configure the Adam optimizer for training phases 2 and 3. + + Does not train the discrete model (phase 1) + + Returns: + dict[str, torch.optim.Optimizer | torch.optim.lr_scheduler.LRScheduler]: Dictionary of optimizers + """ + num_steps = max( + self.trainer.max_steps // len(self.trainer.datamodule.train_dataloader()), + self.trainer.max_epochs, + ) + self.second_phase = int(num_steps * self.upsampling_train_ratio) + anneal = int(0.8 * self.second_phase) + optimizer_d = torch.optim.Adam( + params=list(self.model.image_reconstruction_network.parameters()) + + list(self.model.subspace_restriction_module_hi.parameters()) + + list(self.model.subspace_restriction_module_lo.parameters()) + + list(self.model.anomaly_detection_module.parameters()), + lr=0.0002, + ) + scheduler_d = torch.optim.lr_scheduler.StepLR(optimizer_d, anneal, gamma=0.1) + + optimizer_u = torch.optim.Adam(params=self.model.upsampling_module.parameters(), lr=0.0002) + + return ({"optimizer": optimizer_d, "lr_scheduler": scheduler_d}, {"optimizer": optimizer_u}) + + def on_train_start(self) -> None: + """Load pretrained weights of the discrete model when starting training.""" + ckpt: Path = self.prepare_pretrained_model() + self.model.load_pretrained_discrete_model_weights(ckpt, self.device) + + def on_train_epoch_start(self) -> None: + """Display a message when starting to train the upsampling module.""" + if self.current_epoch == self.second_phase: + logger.info("Now training upsampling module.") + + def training_step(self, batch: dict[str, str | Tensor]) -> STEP_OUTPUT: + """Training Step of DSR. + + Feeds the original image and the simulated anomaly mask during first phase. During + second phase, feeds a generated anomalous image to train the upsampling module. + + Args: + batch (dict[str, str | Tensor]): Batch containing image filename, image, label and mask + + Returns: + STEP_OUTPUT: Loss dictionary + """ + ph1_opt, ph2_opt = self.optimizers() + + if self.current_epoch < self.second_phase: + # we are not yet training the upsampling module: we are only using the first optimizer + input_image = batch["image"] + # Create anomaly masks + anomaly_mask = self.quantized_anomaly_generator.augment_batch(input_image) + # Generate model prediction + model_outputs = self.model(input_image, anomaly_mask) + # Compute loss + loss = self.second_stage_loss( + model_outputs["recon_feat_hi"], + model_outputs["recon_feat_lo"], + model_outputs["embedding_bot"], + model_outputs["embedding_top"], + input_image, + model_outputs["obj_spec_image"], + model_outputs["anomaly_map"], + model_outputs["true_anomaly_map"], + ) + + # compute manual optimizer step + ph1_opt.zero_grad() + self.manual_backward(loss) + ph1_opt.step() + + else: + # we are training the upsampling module + input_image = batch["image"] + # Generate anomalies + input_image, anomaly_maps = self.perlin_generator.augment_batch(input_image) + # Get model prediction + model_outputs = self.model(input_image) + # Calculate loss + loss = self.third_stage_loss(model_outputs["anomaly_map"], anomaly_maps) + + # compute manual optimizer step + ph2_opt.zero_grad() + self.manual_backward(loss) + ph2_opt.step() + + self.log("train_loss", loss.item(), on_epoch=True, prog_bar=True, logger=True) + return {"loss": loss} + + def validation_step(self, batch: dict[str, str | Tensor], *args, **kwargs) -> STEP_OUTPUT: + """Validation step of DSR. + + The Softmax predictions of the anomalous class are used as anomaly map. + + Args: + batch (dict[str, str | Tensor]): Batch of input images + *args: unused + **kwargs: unused + + Returns: + STEP_OUTPUT: Dictionary to which predicted anomaly maps have been added. + """ + del args, kwargs # These variables are not used. + + model_outputs = self.model(batch["image"]) + batch["anomaly_maps"] = model_outputs["anomaly_map"] + batch["pred_scores"] = model_outputs["pred_score"] + return batch + + @property + def trainer_arguments(self) -> dict[str, Any]: + """Required trainer arguments.""" + return {"num_sanity_val_steps": 0} + + @property + def learning_type(self) -> LearningType: + """Return the learning type of the model. + + Returns: + LearningType: Learning type of the model. + """ + return LearningType.ONE_CLASS diff --git a/anomalib/models/image/dsr/loss.py b/anomalib/models/image/dsr/loss.py new file mode 100644 index 0000000000000000000000000000000000000000..f1020b9d3451a3fe9e671ada26775d3e761cac8b --- /dev/null +++ b/anomalib/models/image/dsr/loss.py @@ -0,0 +1,80 @@ +"""Loss function for the DSR model implementation.""" + +# Copyright (C) 2023-2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +from kornia.losses import FocalLoss +from torch import Tensor, nn + + +class DsrSecondStageLoss(nn.Module): + """Overall loss function of the second training phase of the DSR model. + + The total loss consists of: + - MSE loss between non-anomalous quantized input image and anomalous subspace-reconstructed + non-quantized input (hi and lo) + - MSE loss between input image and reconstructed image through object-specific decoder, + - Focal loss between computed segmentation mask and ground truth mask. + """ + + def __init__(self) -> None: + super().__init__() + + self.l2_loss = nn.modules.loss.MSELoss() + self.focal_loss = FocalLoss(alpha=1, reduction="mean") + + def forward( + self, + recon_nq_hi: Tensor, + recon_nq_lo: Tensor, + qu_hi: Tensor, + qu_lo: Tensor, + input_image: Tensor, + gen_img: Tensor, + seg: Tensor, + anomaly_mask: Tensor, + ) -> Tensor: + """Compute the loss over a batch for the DSR model. + + Args: + recon_nq_hi (Tensor): Reconstructed non-quantized hi feature + recon_nq_lo (Tensor): Reconstructed non-quantized lo feature + qu_hi (Tensor): Non-defective quantized hi feature + qu_lo (Tensor): Non-defective quantized lo feature + input_image (Tensor): Original image + gen_img (Tensor): Object-specific decoded image + seg (Tensor): Computed anomaly map + anomaly_mask (Tensor): Ground truth anomaly map + + Returns: + Tensor: Total loss + """ + l2_loss_hi_val = self.l2_loss(recon_nq_hi, qu_hi) + l2_loss_lo_val = self.l2_loss(recon_nq_lo, qu_lo) + l2_loss_img_val = self.l2_loss(input_image, gen_img) * 10 + focal_loss_val = self.focal_loss(seg, anomaly_mask.squeeze(1).long()) + return l2_loss_hi_val + l2_loss_lo_val + l2_loss_img_val + focal_loss_val + + +class DsrThirdStageLoss(nn.Module): + """Overall loss function of the third training phase of the DSR model. + + The loss consists of a focal loss between the computed segmentation mask and the ground truth mask. + """ + + def __init__(self) -> None: + super().__init__() + + self.focal_loss = FocalLoss(alpha=1, reduction="mean") + + def forward(self, pred_mask: Tensor, true_mask: Tensor) -> Tensor: + """Compute the loss over a batch for the DSR model. + + Args: + pred_mask (Tensor): Computed anomaly map + true_mask (Tensor): Ground truth anomaly map + + Returns: + Tensor: Total loss + """ + return self.focal_loss(pred_mask, true_mask.squeeze(1).long()) diff --git a/anomalib/models/image/dsr/torch_model.py b/anomalib/models/image/dsr/torch_model.py new file mode 100644 index 0000000000000000000000000000000000000000..e5f11a66187bbaf8fc37e0f1cad051cc50014ca9 --- /dev/null +++ b/anomalib/models/image/dsr/torch_model.py @@ -0,0 +1,1293 @@ +"""PyTorch model for the DSR model implementation.""" + +# Original Code +# Copyright (c) 2022 VitjanZ +# https://github.com/VitjanZ/DSR_anomaly_detection. +# SPDX-License-Identifier: Apache-2.0 +# +# Modified +# Copyright (C) 2023-2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + + +from collections.abc import Callable +from pathlib import Path + +import torch +import torch.nn.functional as F # noqa: N812 +from torch import nn + + +class DsrModel(nn.Module): + """DSR PyTorch model. + + Consists of the discrete latent model, image reconstruction network, + subspace restriction modules, anomaly detection module and upsampling module. + + Args: + embedding_dim (int): Dimension of codebook embeddings. + num_embeddings (int): Number of embeddings. + latent_anomaly_strength (float): Strength of the generated anomalies in the latent space. + num_hiddens (int): Number of output channels in residual layers. + num_residual_layers (int): Number of residual layers. + num_residual_hiddens (int): Number of intermediate channels. + """ + + def __init__( + self, + latent_anomaly_strength: float = 0.2, + embedding_dim: int = 128, + num_embeddings: int = 4096, + num_hiddens: int = 128, + num_residual_layers: int = 2, + num_residual_hiddens: int = 64, + ) -> None: + super().__init__() + + self.image_dim: int = 3 + self.anomaly_map_dim: int = 2 + self.latent_anomaly_strength: float = latent_anomaly_strength + + self.discrete_latent_model = DiscreteLatentModel( + num_hiddens=num_hiddens, + num_residual_layers=num_residual_layers, + num_residual_hiddens=num_residual_hiddens, + num_embeddings=num_embeddings, + embedding_dim=embedding_dim, + ) + + self.image_reconstruction_network = ImageReconstructionNetwork( + in_channels=embedding_dim * 2, + num_hiddens=num_hiddens, + num_residual_layers=num_residual_layers, + num_residual_hiddens=num_residual_hiddens, + ) + + self.subspace_restriction_module_lo = SubspaceRestrictionModule(base_width=embedding_dim) + self.subspace_restriction_module_hi = SubspaceRestrictionModule(base_width=embedding_dim) + + self.anomaly_detection_module = AnomalyDetectionModule( + in_channels=2 * self.image_dim, + out_channels=self.anomaly_map_dim, + base_width=64, + ) + + self.upsampling_module = UpsamplingModule( + in_channels=(2 * self.image_dim) + self.anomaly_map_dim, + out_channels=self.anomaly_map_dim, + base_width=64, + ) + + for parameters in self.discrete_latent_model.parameters(): + parameters.requires_grad = False + + def load_pretrained_discrete_model_weights(self, ckpt: Path, device: torch.device | str | None = None) -> None: + """Load pre-trained model weights.""" + self.discrete_latent_model.load_state_dict(torch.load(ckpt, map_location=device)) + + def forward( + self, + batch: torch.Tensor, + anomaly_map_to_generate: torch.Tensor | None = None, + ) -> dict[str, torch.Tensor]: + """Compute the anomaly mask from an input image. + + Args: + batch (torch.Tensor): Batch of input images. + anomaly_map_to_generate (torch.Tensor | None): anomaly map to use to generate quantized defects. + If not training phase 2, should be None. + + Returns: + dict[str, torch.Tensor]: + If testing: + - "anomaly_map": Upsampled anomaly map + - "pred_score": Image score + If training phase 2: + - "recon_feat_hi": Reconstructed non-quantized hi features of defect (F~_hi) + - "recon_feat_lo": Reconstructed non-quantized lo features of defect (F~_lo) + - "embedding_bot": Quantized features of non defective img (Q_hi) + - "embedding_top": Quantized features of non defective img (Q_lo) + - "obj_spec_image": Object-specific-decoded image (I_spc) + - "anomaly_map": Predicted segmentation mask (M) + - "true_mask": Resized ground-truth anomaly map (M_gt) + If training phase 3: + - "anomaly_map": Reconstructed anomaly map + """ + outputs: dict[str, torch.Tensor] + + # Generate latent embeddings decoded image via general object decoder + if anomaly_map_to_generate is None: + # either evaluating or training phase 3 + with torch.no_grad(): + latent_model_outputs = self.discrete_latent_model(batch) + gen_image = latent_model_outputs["recon_image"] + embd_top = latent_model_outputs["quantized_t"] + embd_bot = latent_model_outputs["quantized_b"] + + # Get embedders from the discrete latent model + embedder_bot = self.discrete_latent_model.vq_vae_bot + embedder_top = self.discrete_latent_model.vq_vae_top + + # Copy embeddings in order to input them to the subspace restriction module + anomaly_embedding_bot_copy = embd_bot.clone() + anomaly_embedding_top_copy = embd_top.clone() + + # Apply subspace restriction module to copied embeddings + _, recon_embd_bot = self.subspace_restriction_module_hi(anomaly_embedding_bot_copy, embedder_bot) + _, recon_embd_top = self.subspace_restriction_module_lo(anomaly_embedding_top_copy, embedder_top) + + # Upscale top (lo) embedding + up_quantized_recon_t = self.discrete_latent_model.upsample_t(recon_embd_top) + + # Concat embeddings and reconstruct image (object specific decoder) + quant_join = torch.cat((up_quantized_recon_t, recon_embd_bot), dim=1) + obj_spec_image = self.image_reconstruction_network(quant_join) + + # Anomaly detection module + out_mask = self.anomaly_detection_module(obj_spec_image, gen_image) + out_mask_sm = torch.softmax(out_mask, dim=1) + + # Mask upsampling and score calculation + upsampled_mask = self.upsampling_module(obj_spec_image, gen_image, out_mask_sm) + out_mask_sm_up = torch.softmax(upsampled_mask, dim=1) + + # if training phase 3, return upsampled softmax mask + if self.training: + outputs = {"anomaly_map": out_mask_sm_up} + # if testing, extract image score + else: + out_mask_averaged = torch.nn.functional.avg_pool2d( + out_mask_sm[:, 1:, :, :], + 21, + stride=1, + padding=21 // 2, + ).detach() + image_score = torch.amax(out_mask_averaged, dim=(2, 3)).squeeze() + + # prevent crash when image_score is a single value (batch size of 1) + if image_score.size() == torch.Size([]): + image_score = image_score.unsqueeze(0) + + out_mask_cv = out_mask_sm_up[:, 1, :, :] + + outputs = {"anomaly_map": out_mask_cv, "pred_score": image_score} + + elif anomaly_map_to_generate is not None and self.training: + # we are in phase two + + # Generate anomaly strength factors + anom_str_lo = ( + torch.rand(batch.shape[0]) * (1.0 - self.latent_anomaly_strength) + self.latent_anomaly_strength + ).cuda() + anom_str_hi = ( + torch.rand(batch.shape[0]) * (1.0 - self.latent_anomaly_strength) + self.latent_anomaly_strength + ).cuda() + + # Generate image through general object decoder, and defective & non defective quantized feature maps. + with torch.no_grad(): + latent_model_outputs = self.discrete_latent_model( + batch, + anomaly_map_to_generate, + anom_str_lo, + anom_str_hi, + ) + gen_image_def = latent_model_outputs["recon_image"] + true_anomaly_map = latent_model_outputs["anomaly_mask"] + embd_top = latent_model_outputs["quantized_t"] + embd_bot = latent_model_outputs["quantized_b"] + embd_top_def = latent_model_outputs["anomaly_embedding_lo"] + embd_bot_def = latent_model_outputs["anomaly_embedding_hi"] + + # Restore the features to normality with the Subspace restriction modules + recon_feat_hi, recon_embeddings_hi = self.subspace_restriction_module_hi( + embd_bot_def, + self.discrete_latent_model.vq_vae_bot, + ) + recon_feat_lo, recon_embeddings_lo = self.subspace_restriction_module_lo( + embd_top_def, + self.discrete_latent_model.vq_vae_top, + ) + + # Reconstruct the image from the reconstructed features + # with the object-specific image reconstruction module + up_quantized_recon_t = self.discrete_latent_model.upsample_t(recon_embeddings_lo) + quant_join = torch.cat((up_quantized_recon_t, recon_embeddings_hi), dim=1) + spec_image_def = self.image_reconstruction_network(quant_join) + + # Generate the anomaly segmentation map + out_mask = self.anomaly_detection_module(spec_image_def.detach(), gen_image_def.detach()) + out_mask_sm = torch.softmax(out_mask, dim=1) + + # Outputs + outputs = { + "recon_feat_hi": recon_feat_hi, + "recon_feat_lo": recon_feat_lo, + "embedding_bot": embd_bot, + "embedding_top": embd_top, + "obj_spec_image": spec_image_def, + "anomaly_map": out_mask_sm, + "true_anomaly_map": true_anomaly_map, + } + else: + msg = "There should not be an anomaly map to generate when not training" + raise RuntimeError(msg) + + return outputs + + +class SubspaceRestrictionModule(nn.Module): + """Subspace Restriction Module. + + Subspace restriction module that restricts the appearance subspace into configurations + that agree with normal appearances and applies quantization. + + Args: + base_width (int): Base dimensionality of the layers of the autoencoder. + """ + + def __init__(self, base_width: int) -> None: + super().__init__() + + self.unet = SubspaceRestrictionNetwork(in_channels=base_width, out_channels=base_width, base_width=base_width) + + def forward(self, batch: torch.Tensor, quantization: Callable) -> tuple[torch.Tensor, torch.Tensor]: + """Generate the quantized anomaly-free representation of an anomalous image. + + Args: + batch (torch.Tensor): Batch of input images. + quantization (function | object): Quantization function. + + Returns: + Reconstructed batch of non-quantized features and corresponding quantized features. + """ + batch = self.unet(batch) + quantized_b = quantization(batch) + return batch, quantized_b + + +class SubspaceRestrictionNetwork(nn.Module): + """Subspace Restriction Network. + + Subspace restriction network that reconstructs the input image into a + non-quantized configuration that agrees with normal appearances. + + Args: + in_channels (int): Number of input channels. + out_channels (int): Number of output channels. + base_width (int): Base dimensionality of the layers of the autoencoder. + """ + + def __init__(self, in_channels: int = 64, out_channels: int = 64, base_width: int = 64) -> None: + super().__init__() + self.base_width = base_width + self.encoder = FeatureEncoder(in_channels, self.base_width) + self.decoder = FeatureDecoder(self.base_width, out_channels=out_channels) + + def forward(self, batch: torch.Tensor) -> torch.Tensor: + """Reconstruct non-quantized representation from batch. + + Generate non-quantized feature maps from potentially anomalous images, to + be quantized into non-anomalous quantized representations. + + Args: + batch (torch.Tensor): Batch of input images. + + Returns: + Reconstructed non-quantized representation. + """ + b1, b2, b3 = self.encoder(batch) + return self.decoder(b1, b2, b3) + + +class FeatureEncoder(nn.Module): + """Feature encoder for the subspace restriction network. + + Args: + in_channels (int): Number of input channels. + base_width (int): Base dimensionality of the layers of the autoencoder. + """ + + def __init__(self, in_channels: int, base_width: int) -> None: + super().__init__() + self.block1 = nn.Sequential( + nn.Conv2d(in_channels, base_width, kernel_size=3, padding=1), + nn.InstanceNorm2d(base_width), + nn.ReLU(inplace=True), + nn.Conv2d(base_width, base_width, kernel_size=3, padding=1), + nn.InstanceNorm2d(base_width), + nn.ReLU(inplace=True), + ) + self.mp1 = nn.Sequential(nn.MaxPool2d(2)) + self.block2 = nn.Sequential( + nn.Conv2d(base_width, base_width * 2, kernel_size=3, padding=1), + nn.InstanceNorm2d(base_width * 2), + nn.ReLU(inplace=True), + nn.Conv2d(base_width * 2, base_width * 2, kernel_size=3, padding=1), + nn.InstanceNorm2d(base_width * 2), + nn.ReLU(inplace=True), + ) + self.mp2 = nn.Sequential(nn.MaxPool2d(2)) + self.block3 = nn.Sequential( + nn.Conv2d(base_width * 2, base_width * 4, kernel_size=3, padding=1), + nn.InstanceNorm2d(base_width * 4), + nn.ReLU(inplace=True), + nn.Conv2d(base_width * 4, base_width * 4, kernel_size=3, padding=1), + nn.InstanceNorm2d(base_width * 4), + nn.ReLU(inplace=True), + ) + + def forward(self, batch: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Encode a batch of input features to the latent space. + + Args: + batch (torch.Tensor): Batch of input images. + + Returns: + Encoded feature maps. + """ + b1 = self.block1(batch) + mp1 = self.mp1(b1) + b2 = self.block2(mp1) + mp2 = self.mp2(b2) + b3 = self.block3(mp2) + return b1, b2, b3 + + +class FeatureDecoder(nn.Module): + """Feature decoder for the subspace restriction network. + + Args: + base_width (int): Base dimensionality of the layers of the autoencoder. + out_channels (int): Number of output channels. + """ + + def __init__(self, base_width: int, out_channels: int = 1) -> None: + super().__init__() + + self.up2 = nn.Sequential( + nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True), + nn.Conv2d(base_width * 4, base_width * 2, kernel_size=3, padding=1), + nn.InstanceNorm2d(base_width * 2), + nn.ReLU(inplace=True), + ) + + self.db2 = nn.Sequential( + nn.Conv2d(base_width * 2, base_width * 2, kernel_size=3, padding=1), + nn.InstanceNorm2d(base_width * 2), + nn.ReLU(inplace=True), + nn.Conv2d(base_width * 2, base_width * 2, kernel_size=3, padding=1), + nn.InstanceNorm2d(base_width * 2), + nn.ReLU(inplace=True), + ) + + self.up3 = nn.Sequential( + nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True), + nn.Conv2d(base_width * 2, base_width, kernel_size=3, padding=1), + nn.InstanceNorm2d(base_width), + nn.ReLU(inplace=True), + ) + self.db3 = nn.Sequential( + nn.Conv2d(base_width, base_width, kernel_size=3, padding=1), + nn.InstanceNorm2d(base_width), + nn.ReLU(inplace=True), + nn.Conv2d(base_width, base_width, kernel_size=3, padding=1), + nn.InstanceNorm2d(base_width), + nn.ReLU(inplace=True), + ) + + self.fin_out = nn.Sequential(nn.Conv2d(base_width, out_channels, kernel_size=3, padding=1)) + + def forward(self, _: torch.Tensor, __: torch.Tensor, b3: torch.Tensor) -> torch.Tensor: + """Decode a batch of latent features to a non-quantized representation. + + Args: + _ (torch.Tensor): Top latent feature layer. + __ (torch.Tensor): Middle latent feature layer. + b3 (torch.Tensor): Bottom latent feature layer. + + Returns: + Decoded non-quantized representation. + """ + up2 = self.up2(b3) + db2 = self.db2(up2) + + up3 = self.up3(db2) + db3 = self.db3(up3) + + return self.fin_out(db3) + + +class Residual(nn.Module): + """Residual layer. + + Args: + in_channels (int): Number of input channels. + out_channels (int): Number of output channels. + num_residual_hiddens (int): Number of intermediate channels. + """ + + def __init__(self, in_channels: int, out_channels: int, num_residual_hiddens: int) -> None: + super().__init__() + self._block = nn.Sequential( + nn.ReLU(inplace=True), + nn.Conv2d( + in_channels=in_channels, + out_channels=num_residual_hiddens, + kernel_size=3, + stride=1, + padding=1, + bias=False, + ), + nn.ReLU(inplace=True), + nn.Conv2d(in_channels=num_residual_hiddens, out_channels=out_channels, kernel_size=1, stride=1, bias=False), + ) + + def forward(self, batch: torch.Tensor) -> torch.Tensor: + """Compute residual layer. + + Args: + batch (torch.Tensor): Batch of input images. + + Returns: + Computed feature maps. + """ + return batch + self._block(batch) + + +class ResidualStack(nn.Module): + """Stack of residual layers. + + Args: + in_channels (int): Number of input channels. + num_hiddens (int): Number of output channels in residual layers. + num_residual_layers (int): Number of residual layers. + num_residual_hiddens (int): Number of intermediate channels. + """ + + def __init__(self, in_channels: int, num_hiddens: int, num_residual_layers: int, num_residual_hiddens: int) -> None: + super().__init__() + self._num_residual_layers = num_residual_layers + self._layers = nn.ModuleList( + [Residual(in_channels, num_hiddens, num_residual_hiddens) for _ in range(self._num_residual_layers)], + ) + + def forward(self, batch: torch.Tensor) -> torch.Tensor: + """Compute residual stack. + + Args: + batch (torch.Tensor): Batch of input images. + + Returns: + Computed feature maps. + """ + for i in range(self._num_residual_layers): + batch = self._layers[i](batch) + return F.relu(batch) + + +class ImageReconstructionNetwork(nn.Module): + """Image Reconstruction Network. + + Image reconstruction network that reconstructs the image from a quantized + representation. + + Args: + in_channels (int): Number of input channels. + num_hiddens (int): Number of output channels in residual layers. + num_residual_layers (int): Number of residual layers. + num_residual_hiddens (int): Number of intermediate channels. + """ + + def __init__(self, in_channels: int, num_hiddens: int, num_residual_layers: int, num_residual_hiddens: int) -> None: + super().__init__() + norm_layer = nn.InstanceNorm2d + self.block1 = nn.Sequential( + nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1), + norm_layer(in_channels), + nn.ReLU(inplace=True), + nn.Conv2d(in_channels, in_channels * 2, kernel_size=3, padding=1), + norm_layer(in_channels * 2), + nn.ReLU(inplace=True), + ) + self.mp1 = nn.Sequential(nn.MaxPool2d(2)) + self.block2 = nn.Sequential( + nn.Conv2d(in_channels * 2, in_channels * 2, kernel_size=3, padding=1), + norm_layer(in_channels * 2), + nn.ReLU(inplace=True), + nn.Conv2d(in_channels * 2, in_channels * 4, kernel_size=3, padding=1), + norm_layer(in_channels * 4), + nn.ReLU(inplace=True), + ) + self.mp2 = nn.Sequential(nn.MaxPool2d(2)) + + self.pre_vq_conv = nn.Conv2d(in_channels=in_channels * 4, out_channels=64, kernel_size=1, stride=1) + + self.upblock1 = nn.ConvTranspose2d(in_channels=64, out_channels=64, kernel_size=4, stride=2, padding=1) + + self.upblock2 = nn.ConvTranspose2d(in_channels=64, out_channels=64, kernel_size=4, stride=2, padding=1) + + self._conv_1 = nn.Conv2d(in_channels=64, out_channels=num_hiddens, kernel_size=3, stride=1, padding=1) + + self._residual_stack = ResidualStack( + in_channels=num_hiddens, + num_hiddens=num_hiddens, + num_residual_layers=num_residual_layers, + num_residual_hiddens=num_residual_hiddens, + ) + + self._conv_trans_1 = nn.ConvTranspose2d( + in_channels=num_hiddens, + out_channels=num_hiddens // 2, + kernel_size=4, + stride=2, + padding=1, + ) + + self._conv_trans_2 = nn.ConvTranspose2d( + in_channels=num_hiddens // 2, + out_channels=3, + kernel_size=4, + stride=2, + padding=1, + ) + + def forward(self, inputs: torch.Tensor) -> torch.Tensor: + """Reconstructs an image from a quantized representation. + + Args: + inputs (torch.Tensor): Quantized features. + + Returns: + Reconstructed image. + """ + batch = self.block1(inputs) + batch = self.mp1(batch) + batch = self.block2(batch) + batch = self.mp2(batch) + batch = self.pre_vq_conv(batch) + + batch = self.upblock1(batch) + batch = F.relu(batch) + batch = self.upblock2(batch) + batch = F.relu(batch) + batch = self._conv_1(batch) + + batch = self._residual_stack(batch) + + batch = self._conv_trans_1(batch) + batch = F.relu(batch) + + return self._conv_trans_2(batch) + + +class UnetEncoder(nn.Module): + """Encoder of the Unet network. + + Args: + in_channels (int): Number of input channels. + base_width (int): Base dimensionality of the layers of the autoencoder. + """ + + def __init__(self, in_channels: int, base_width: int) -> None: + super().__init__() + norm_layer = nn.InstanceNorm2d + self.block1 = nn.Sequential( + nn.Conv2d(in_channels, base_width, kernel_size=3, padding=1), + norm_layer(base_width), + nn.ReLU(inplace=True), + nn.Conv2d(base_width, base_width, kernel_size=3, padding=1), + norm_layer(base_width), + nn.ReLU(inplace=True), + ) + self.mp1 = nn.Sequential(nn.MaxPool2d(2)) + self.block2 = nn.Sequential( + nn.Conv2d(base_width, base_width * 2, kernel_size=3, padding=1), + norm_layer(base_width * 2), + nn.ReLU(inplace=True), + nn.Conv2d(base_width * 2, base_width * 2, kernel_size=3, padding=1), + norm_layer(base_width * 2), + nn.ReLU(inplace=True), + ) + self.mp2 = nn.Sequential(nn.MaxPool2d(2)) + self.block3 = nn.Sequential( + nn.Conv2d(base_width * 2, base_width * 4, kernel_size=3, padding=1), + norm_layer(base_width * 4), + nn.ReLU(inplace=True), + nn.Conv2d(base_width * 4, base_width * 4, kernel_size=3, padding=1), + norm_layer(base_width * 4), + nn.ReLU(inplace=True), + ) + self.mp3 = nn.Sequential(nn.MaxPool2d(2)) + self.block4 = nn.Sequential( + nn.Conv2d(base_width * 4, base_width * 4, kernel_size=3, padding=1), + norm_layer(base_width * 4), + nn.ReLU(inplace=True), + nn.Conv2d(base_width * 4, base_width * 4, kernel_size=3, padding=1), + norm_layer(base_width * 4), + nn.ReLU(inplace=True), + ) + + def forward(self, batch: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + """Encodes batch of images into a latent representation. + + Args: + batch (torch.Tensor): Quantized features. + + Returns: + Latent representations of the input batch. + """ + b1 = self.block1(batch) + mp1 = self.mp1(b1) + b2 = self.block2(mp1) + mp2 = self.mp2(b2) + b3 = self.block3(mp2) + mp3 = self.mp3(b3) + b4 = self.block4(mp3) + return b1, b2, b3, b4 + + +class UnetDecoder(nn.Module): + """Decoder of the Unet network. + + Args: + base_width (int): Base dimensionality of the layers of the autoencoder. + out_channels (int): Number of output channels. + """ + + def __init__(self, base_width: int, out_channels: int = 1) -> None: + super().__init__() + norm_layer = nn.InstanceNorm2d + self.up1 = nn.Sequential( + nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True), + nn.Conv2d(base_width * 4, base_width * 4, kernel_size=3, padding=1), + norm_layer(base_width * 4), + nn.ReLU(inplace=True), + ) + # cat with base*4 + self.db1 = nn.Sequential( + nn.Conv2d(base_width * (4 + 4), base_width * 4, kernel_size=3, padding=1), + norm_layer(base_width * 4), + nn.ReLU(inplace=True), + nn.Conv2d(base_width * 4, base_width * 4, kernel_size=3, padding=1), + norm_layer(base_width * 4), + nn.ReLU(inplace=True), + ) + + self.up2 = nn.Sequential( + nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True), + nn.Conv2d(base_width * 4, base_width * 2, kernel_size=3, padding=1), + norm_layer(base_width * 2), + nn.ReLU(inplace=True), + ) + # cat with base*2 + self.db2 = nn.Sequential( + nn.Conv2d(base_width * (2 + 2), base_width * 2, kernel_size=3, padding=1), + norm_layer(base_width * 2), + nn.ReLU(inplace=True), + nn.Conv2d(base_width * 2, base_width * 2, kernel_size=3, padding=1), + norm_layer(base_width * 2), + nn.ReLU(inplace=True), + ) + + self.up3 = nn.Sequential( + nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True), + nn.Conv2d(base_width * 2, base_width, kernel_size=3, padding=1), + norm_layer(base_width), + nn.ReLU(inplace=True), + ) + # cat with base*1 + self.db3 = nn.Sequential( + nn.Conv2d(base_width * (1 + 1), base_width, kernel_size=3, padding=1), + norm_layer(base_width), + nn.ReLU(inplace=True), + nn.Conv2d(base_width, base_width, kernel_size=3, padding=1), + norm_layer(base_width), + nn.ReLU(inplace=True), + ) + + self.fin_out = nn.Sequential(nn.Conv2d(base_width, out_channels, kernel_size=3, padding=1)) + + def forward(self, b1: torch.Tensor, b2: torch.Tensor, b3: torch.Tensor, b4: torch.Tensor) -> torch.Tensor: + """Decodes latent represnetations into an image. + + Args: + b1 (torch.Tensor): First (top level) quantized feature map. + b2 (torch.Tensor): Second quantized feature map. + b3 (torch.Tensor): Third quantized feature map. + b4 (torch.Tensor): Fourth (bottom level) quantized feature map. + + Returns: + Reconstructed image. + """ + up1 = self.up1(b4) + cat1 = torch.cat((up1, b3), dim=1) + db1 = self.db1(cat1) + + up2 = self.up2(db1) + cat2 = torch.cat((up2, b2), dim=1) + db2 = self.db2(cat2) + + up3 = self.up3(db2) + cat3 = torch.cat((up3, b1), dim=1) + db3 = self.db3(cat3) + + return self.fin_out(db3) + + +class UnetModel(nn.Module): + """Autoencoder model that reconstructs the input image. + + Args: + in_channels (int): Number of input channels. + out_channels (int): Number of output channels. + base_width (int): Base dimensionality of the layers of the autoencoder. + """ + + def __init__(self, in_channels: int = 64, out_channels: int = 64, base_width: int = 64) -> None: + super().__init__() + self.encoder = UnetEncoder(in_channels, base_width) + self.decoder = UnetDecoder(base_width, out_channels=out_channels) + + def forward(self, batch: torch.Tensor) -> torch.Tensor: + """Reconstructs an input batch of images. + + Args: + batch (torch.Tensor): Batch of input images. + + Returns: + Reconstructed images. + """ + b1, b2, b3, b4 = self.encoder(batch) + return self.decoder(b1, b2, b3, b4) + + +class AnomalyDetectionModule(nn.Module): + """Anomaly detection module. + + Module that detects the preseßnce of an anomaly by comparing two images reconstructed by + the object specific decoder and the general object decoder. + + Args: + in_channels (int): Number of input channels. + out_channels (int): Number of output channels. + base_width (int): Base dimensionality of the layers of the autoencoder. + """ + + def __init__(self, in_channels: int, out_channels: int, base_width: int) -> None: + super().__init__() + self.unet = UnetModel(in_channels, out_channels, base_width) + + def forward(self, batch_real: torch.Tensor, batch_anomaly: torch.Tensor) -> torch.Tensor: + """Computes the anomaly map over corresponding real and anomalous images. + + Args: + batch_real (torch.Tensor): Batch of real, non defective images. + batch_anomaly (torch.Tensor): Batch of potentially anomalous images. + + Returns: + The anomaly segmentation map. + """ + img_x = torch.cat((batch_real, batch_anomaly), dim=1) + return self.unet(img_x) + + +class UpsamplingModule(nn.Module): + """Module that upsamples the generated anomaly mask to full resolution. + + Args: + in_channels (int): Number of input channels. + out_channels (int): Number of output channels. + base_width (int): Base dimensionality of the layers of the autoencoder. + """ + + def __init__(self, in_channels: int = 8, out_channels: int = 2, base_width: int = 64) -> None: + super().__init__() + self.unet = UnetModel(in_channels, out_channels, base_width) + + def forward( + self, + batch_real: torch.Tensor, + batch_anomaly: torch.Tensor, + batch_segmentation_map: torch.Tensor, + ) -> torch.Tensor: + """Computes upsampled segmentation maps. + + Args: + batch_real (torch.Tensor): Batch of real, non defective images. + batch_anomaly (torch.Tensor): Batch of potentially anomalous images. + batch_segmentation_map (torch.Tensor): Batch of anomaly segmentation maps. + + Returns: + Upsampled anomaly segmentation maps. + """ + img_x = torch.cat((batch_real, batch_anomaly, batch_segmentation_map), dim=1) + return self.unet(img_x) + + +class VectorQuantizer(nn.Module): + """Module that quantizes a given feature map using learned quantization codebooks. + + Args: + num_embeddings (int): Size of embedding codebook. + embedding_dim (int): Dimension of embeddings. + """ + + def __init__(self, num_embeddings: int, embedding_dim: int) -> None: + # Source for the VectorQuantizer module: https://github.com/zalandoresearch/pytorch-vq-vae + super().__init__() + + self._embedding_dim = embedding_dim + self._num_embeddings = num_embeddings + + self._embedding = nn.Embedding(self._num_embeddings, self._embedding_dim) + self._embedding.weight.data.normal_() + + # necessary to correctly load the checkpoint file + self.register_buffer("_ema_cluster_size", torch.zeros(num_embeddings)) + self._ema_w = nn.Parameter(torch.Tensor(num_embeddings, self._embedding_dim)) + self._ema_w.data.normal_() + + @property + def embedding(self) -> torch.Tensor: + """Return embedding.""" + return self._embedding + + def forward(self, inputs: torch.Tensor) -> torch.Tensor: + """Calculates quantized feature map. + + Args: + inputs (torch.Tensor): Non-quantized feature maps. + + Returns: + Quantized feature maps. + """ + # convert inputs from BCHW -> BHWC + inputs = inputs.permute(0, 2, 3, 1).contiguous() + input_shape = inputs.shape + + # Flatten input + flat_input = inputs.view(-1, self._embedding_dim) + + # Calculate distances + distances = ( + torch.sum(flat_input**2, dim=1, keepdim=True) + + torch.sum(self._embedding.weight**2, dim=1) + - 2 * torch.matmul(flat_input, self._embedding.weight.t()) + ) + + # Encoding + encoding_indices = torch.argmin(distances, dim=1).unsqueeze(1) + encodings = torch.zeros(encoding_indices.shape[0], self._num_embeddings, device=inputs.device) + encodings.scatter_(1, encoding_indices, 1) + + # Quantize and unflatten + quantized = torch.matmul(encodings, self._embedding.weight).view(input_shape) + quantized = inputs + (quantized - inputs).detach() + + # convert quantized from BHWC -> BCHW + return quantized.permute(0, 3, 1, 2).contiguous() + + +class EncoderBot(nn.Module): + """Encoder module for bottom quantized feature maps. + + Args: + in_channels (int): Number of input channels. + num_hiddens (int): Number of hidden channels. + num_residual_layers (int): Number of residual layers in residual stacks. + num_residual_hiddens (int): Number of channels in residual layers. + """ + + def __init__(self, in_channels: int, num_hiddens: int, num_residual_layers: int, num_residual_hiddens: int) -> None: + super().__init__() + + self._conv_1 = nn.Conv2d( + in_channels=in_channels, + out_channels=num_hiddens // 2, + kernel_size=4, + stride=2, + padding=1, + ) + self._conv_2 = nn.Conv2d( + in_channels=num_hiddens // 2, + out_channels=num_hiddens, + kernel_size=4, + stride=2, + padding=1, + ) + self._conv_3 = nn.Conv2d(in_channels=num_hiddens, out_channels=num_hiddens, kernel_size=3, stride=1, padding=1) + self._residual_stack = ResidualStack( + in_channels=num_hiddens, + num_hiddens=num_hiddens, + num_residual_layers=num_residual_layers, + num_residual_hiddens=num_residual_hiddens, + ) + + def forward(self, batch: torch.Tensor) -> torch.Tensor: + """Encode inputs to be quantized into the bottom feature map. + + Args: + batch (torch.Tensor): Batch of input images. + + Returns: + Encoded feature maps. + """ + x = self._conv_1(batch) + x = F.relu(x) + + x = self._conv_2(x) + x = F.relu(x) + + x = self._conv_3(x) + return self._residual_stack(x) + + +class EncoderTop(nn.Module): + """Encoder module for top quantized feature maps. + + Args: + in_channels (int): Number of input channels. + num_hiddens (int): Number of hidden channels. + num_residual_layers (int): Number of residual layers in residual stacks. + num_residual_hiddens (int): Number of channels in residual layers. + """ + + def __init__(self, in_channels: int, num_hiddens: int, num_residual_layers: int, num_residual_hiddens: int) -> None: + super().__init__() + + self._conv_1 = nn.Conv2d(in_channels=in_channels, out_channels=num_hiddens, kernel_size=4, stride=2, padding=1) + self._conv_2 = nn.Conv2d(in_channels=num_hiddens, out_channels=num_hiddens, kernel_size=3, stride=1, padding=1) + self._residual_stack = ResidualStack( + in_channels=num_hiddens, + num_hiddens=num_hiddens, + num_residual_layers=num_residual_layers, + num_residual_hiddens=num_residual_hiddens, + ) + + def forward(self, batch: torch.Tensor) -> torch.Tensor: + """Encode inputs to be quantized into the top feature map. + + Args: + batch (torch.Tensor): Batch of input images. + + Returns: + Encoded feature maps. + """ + x = self._conv_1(batch) + x = F.relu(x) + + x = self._conv_2(x) + x = F.relu(x) + + return self._residual_stack(x) + + +class DecoderBot(nn.Module): + """General appearance decoder module to reconstruct images while keeping possible anomalies. + + Args: + in_channels (int): Number of input channels. + num_hiddens (int): Number of hidden channels. + num_residual_layers (int): Number of residual layers in residual stack. + num_residual_hiddens (int): Number of channels in residual layers. + """ + + def __init__(self, in_channels: int, num_hiddens: int, num_residual_layers: int, num_residual_hiddens: int) -> None: + super().__init__() + + self._conv_1 = nn.Conv2d(in_channels=in_channels, out_channels=num_hiddens, kernel_size=3, stride=1, padding=1) + + self._residual_stack = ResidualStack( + in_channels=num_hiddens, + num_hiddens=num_hiddens, + num_residual_layers=num_residual_layers, + num_residual_hiddens=num_residual_hiddens, + ) + + self._conv_trans_1 = nn.ConvTranspose2d( + in_channels=num_hiddens, + out_channels=num_hiddens // 2, + kernel_size=4, + stride=2, + padding=1, + ) + + self._conv_trans_2 = nn.ConvTranspose2d( + in_channels=num_hiddens // 2, + out_channels=3, + kernel_size=4, + stride=2, + padding=1, + ) + + def forward(self, inputs: torch.Tensor) -> torch.Tensor: + """Decode quantized feature maps into an image. + + Args: + inputs (torch.Tensor): Quantized feature maps. + + Returns: + Decoded image. + """ + x = self._conv_1(inputs) + + x = self._residual_stack(x) + + x = self._conv_trans_1(x) + x = F.relu(x) + + return self._conv_trans_2(x) + + +class DiscreteLatentModel(nn.Module): + """Discrete Latent Model. + + Autoencoder quantized model that encodes the input images into quantized feature maps and generates + a reconstructed image using the general appearance decoder. + + Args: + num_hiddens (int): Number of hidden channels. + num_residual_layers (int): Number of residual layers in residual stacks. + num_residual_hiddens (int): Number of channels in residual layers. + num_embeddings (int): Size of embedding dictionary. + embedding_dim (int): Dimension of embeddings. + """ + + def __init__( + self, + num_hiddens: int, + num_residual_layers: int, + num_residual_hiddens: int, + num_embeddings: int, + embedding_dim: int, + ) -> None: + super().__init__() + + self._encoder_t = EncoderTop(num_hiddens, num_hiddens, num_residual_layers, num_residual_hiddens) + + self._encoder_b = EncoderBot(3, num_hiddens, num_residual_layers, num_residual_hiddens) + + self._pre_vq_conv_bot = nn.Conv2d( + in_channels=num_hiddens + embedding_dim, + out_channels=embedding_dim, + kernel_size=1, + stride=1, + ) + + self._pre_vq_conv_top = nn.Conv2d(in_channels=num_hiddens, out_channels=embedding_dim, kernel_size=1, stride=1) + + self._vq_vae_top = VectorQuantizer(num_embeddings, embedding_dim) + + self._vq_vae_bot = VectorQuantizer(num_embeddings, embedding_dim) + + self._decoder_b = DecoderBot(embedding_dim * 2, num_hiddens, num_residual_layers, num_residual_hiddens) + + self.upsample_t = nn.ConvTranspose2d(embedding_dim, embedding_dim, 4, stride=2, padding=1) + + @property + def vq_vae_top(self) -> VectorQuantizer: + """Return ``self._vq_vae_top``.""" + return self._vq_vae_top + + @property + def vq_vae_bot(self) -> VectorQuantizer: + """Return ``self._vq_vae_bot``.""" + return self._vq_vae_bot + + def generate_fake_anomalies_joined( + self, + features: torch.Tensor, + embeddings: torch.Tensor, + memory_torch_original: torch.Tensor, + mask: torch.Tensor, + strength: torch.Tensor, + ) -> torch.Tensor: + """Generate quantized anomalies. + + Args: + features (torch.Tensor): Features on which the anomalies will be generated. + embeddings (torch.Tensor): Embeddings to use to generate the anomalies. + memory_torch_original (torch.Tensor): Weight of embeddings. + mask (torch.Tensor): Original anomaly mask. + strength (float): Strength of generated anomaly. + + Returns: + torch.Tensor: Anomalous embedding. + """ + random_embeddings = torch.zeros( + (embeddings.shape[0], embeddings.shape[2] * embeddings.shape[3], memory_torch_original.shape[1]), + ) + inputs = features.permute(0, 2, 3, 1).contiguous() + + for k in range(embeddings.shape[0]): + memory_torch = memory_torch_original + flat_input = inputs[k].view(-1, memory_torch.shape[1]) + + distances_b = ( + torch.sum(flat_input**2, dim=1, keepdim=True) + + torch.sum(memory_torch**2, dim=1) + - 2 * torch.matmul(flat_input, memory_torch.t()) + ) + + percentage_vectors = strength[k] + topk = max(1, min(int(percentage_vectors * memory_torch.shape[0]) + 1, memory_torch.shape[0] - 1)) + _, topk_indices = torch.topk(distances_b, topk, dim=1, largest=False) + topk_indices = topk_indices[:, int(memory_torch.shape[0] * 0.05) :] + topk = topk_indices.shape[1] + + random_indices_hik = torch.randint(topk, size=(topk_indices.shape[0],)) + random_indices_t = topk_indices[torch.arange(random_indices_hik.shape[0]), random_indices_hik] + random_embeddings[k] = memory_torch[random_indices_t, :] + random_embeddings = random_embeddings.reshape( + (random_embeddings.shape[0], embeddings.shape[2], embeddings.shape[3], random_embeddings.shape[2]), + ) + random_embeddings_tensor = random_embeddings.permute(0, 3, 1, 2).cuda() + + down_ratio_y = int(mask.shape[2] / embeddings.shape[2]) + down_ratio_x = int(mask.shape[3] / embeddings.shape[3]) + anomaly_mask = torch.nn.functional.max_pool2d(mask, (down_ratio_y, down_ratio_x)).float() + + return anomaly_mask * random_embeddings_tensor + (1.0 - anomaly_mask) * embeddings + + def forward( + self, + batch: torch.Tensor, + anomaly_mask: torch.Tensor | None = None, + anom_str_lo: torch.Tensor | None = None, + anom_str_hi: torch.Tensor | None = None, + ) -> dict[str, torch.Tensor]: + """Generate quantized feature maps. + + Generates quantized feature maps of batch of input images as well as their + reconstruction based on the general appearance decoder. + + Args: + batch (Tensor): Batch of input images. + anomaly_mask (Tensor | None): Anomaly mask to be used to generate anomalies on + the quantized feature maps. + anom_str_lo (torch.Tensor | None): Strength of generated anomaly lo. + anom_str_hi (torch.Tensor | None): Strength of generated anomaly hi. + + Returns: + dict[str, torch.Tensor]: + If generating an anomaly mask: + - General object decoder-decoded anomalous image + - Reshaped ground truth anomaly map + - Non defective quantized lo feature + - Non defective quantized hi feature + - Non quantized subspace encoded defective lo feature + - Non quantized subspace encoded defective hi feature + Else: + - General object decoder-decoded image + - Quantized lo feature + - Quantized hi feature + """ + # Encoder Hi + enc_b = self._encoder_b(batch) + + # Encoder Lo -- F_Lo + enc_t = self._encoder_t(enc_b) + zt = self._pre_vq_conv_top(enc_t) + + # Quantize F_Lo with K_Lo + quantized_t = self._vq_vae_top(zt) + + # Upsample Q_Lo + up_quantized_t = self.upsample_t(quantized_t) + + # Concatenate and transform the output of Encoder_Hi and upsampled Q_lo -- F_Hi + feat = torch.cat((enc_b, up_quantized_t), dim=1) + zb = self._pre_vq_conv_bot(feat) + + # Quantize F_Hi with K_Hi + quantized_b = self._vq_vae_bot(zb) + + # generate anomalies + anomaly_embedding_hi = None + anomaly_embedding_lo = None + + # define outputs + outputs = {"quantized_b": quantized_b, "quantized_t": quantized_t} + + if anomaly_mask is not None: + # Generate feature-based anomalies on F_lo + anomaly_embedding_lo = self.generate_fake_anomalies_joined( + zt, + quantized_t, + self._vq_vae_top.embedding.weight, + anomaly_mask, + anom_str_lo, + ) + + up_quantized_t_defect = self.upsample_t(anomaly_embedding_lo) + feat_defect = torch.cat((enc_b, up_quantized_t_defect), dim=1) + zb_defect = self._pre_vq_conv_bot(feat_defect) + quantized_b_defect = self._vq_vae_bot(zb_defect) + + # Generate feature-based anomalies on F_hi + anomaly_embedding_hi = self.generate_fake_anomalies_joined( + zb_defect, + quantized_b_defect, + self._vq_vae_bot.embedding.weight, + anomaly_mask, + anom_str_hi, + ) + + # get anomaly embeddings + use_both = torch.randint(0, 2, (batch.shape[0], 1, 1, 1)).cuda().float() + use_lo = torch.randint(0, 2, (batch.shape[0], 1, 1, 1)).cuda().float() + use_hi = 1 - use_lo + + anomaly_embedding_hi_usebot = self.generate_fake_anomalies_joined( + zb, + quantized_b, + self._vq_vae_bot.embedding.weight, + anomaly_mask, + anom_str_hi, + ) + + anomaly_embedding_lo_usebot = quantized_t + anomaly_embedding_hi_usetop = quantized_b + anomaly_embedding_lo_usetop = anomaly_embedding_lo + anomaly_embedding_hi_not_both = use_hi * anomaly_embedding_hi_usebot + use_lo * anomaly_embedding_hi_usetop + anomaly_embedding_lo_not_both = use_hi * anomaly_embedding_lo_usebot + use_lo * anomaly_embedding_lo_usetop + anomaly_embedding_hi = ( + (anomaly_embedding_hi * use_both + anomaly_embedding_hi_not_both * (1.0 - use_both)).detach().clone() + ) + anomaly_embedding_lo = ( + (anomaly_embedding_lo * use_both + anomaly_embedding_lo_not_both * (1.0 - use_both)).detach().clone() + ) + + anomaly_embedding_hi_copy = anomaly_embedding_hi.clone() + anomaly_embedding_lo_copy = anomaly_embedding_lo.clone() + + # apply the general appearance decoder to the anomaly embeddings + up_quantized_anomaly_t = self.upsample_t(anomaly_embedding_lo_copy) + quant_join_anomaly = torch.cat((up_quantized_anomaly_t, anomaly_embedding_hi_copy), dim=1) + recon_image = self._decoder_b(quant_join_anomaly) + + # Resize the ground truth anomaly map to closely match the augmented features + down_ratio_x_hi = int(anomaly_mask.shape[3] / quantized_b_defect.shape[3]) + anomaly_mask_hi = torch.nn.functional.max_pool2d(anomaly_mask, (down_ratio_x_hi, down_ratio_x_hi)).float() + anomaly_mask_hi = torch.nn.functional.interpolate(anomaly_mask_hi, scale_factor=down_ratio_x_hi) + down_ratio_x_lo = int(anomaly_mask.shape[3] / quantized_t.shape[3]) + anomaly_mask_lo = torch.nn.functional.max_pool2d(anomaly_mask, (down_ratio_x_lo, down_ratio_x_lo)).float() + anomaly_mask_lo = torch.nn.functional.interpolate(anomaly_mask_lo, scale_factor=down_ratio_x_lo) + anomaly_mask = anomaly_mask_lo * use_both + (anomaly_mask_lo * use_lo + anomaly_mask_hi * use_hi) * ( + 1.0 - use_both + ) + + # reminder : top = lo, bot = hi! + outputs["recon_image"] = recon_image + outputs["anomaly_mask"] = anomaly_mask + outputs["anomaly_embedding_lo"] = anomaly_embedding_lo + outputs["anomaly_embedding_hi"] = anomaly_embedding_hi + + else: + # Concatenate Q_Hi and Q_Lo and input it into the General appearance decoder + quant_join = torch.cat((up_quantized_t, quantized_b), dim=1) + recon_image = self._decoder_b(quant_join) + + outputs["recon_image"] = recon_image + + return outputs diff --git a/anomalib/models/image/efficient_ad/README.md b/anomalib/models/image/efficient_ad/README.md new file mode 100644 index 0000000000000000000000000000000000000000..67da16fd953c4896387ed167b6d9a5574cabffa3 --- /dev/null +++ b/anomalib/models/image/efficient_ad/README.md @@ -0,0 +1,41 @@ +# EfficientAd + +This is the implementation of the [EfficientAd](https://arxiv.org/pdf/2303.14535.pdf) paper. It is based on https://github.com/rximg/EfficientAd and https://github.com/nelson1425/EfficientAd/ + +Model Type: Segmentation + +## Description + +Fast anomaly segmentation algorithm that consists of a distilled pre-trained teacher model, a student model and an autoencoder. It detects local anomalies via the teacher-student discrepany and global anomalies via the student-autoencoder discrepancy. + +### Feature Extraction + +Features are extracted from a pre-trained teacher model and used to train a student model and an autoencoder model. To hinder the student from imitating the teacher on anomalies, Imagenet images are used in the loss function. + +### Anomaly Detection + +Anomalies are detected as the difference in output feature maps between the teacher model, the student model and the autoencoder model. + +## Usage + +`anomalib train --model EfficientAd --data anomalib.data.MVTec --data.train_batch_size 1` + +## Benchmark + +All results gathered with seed `42`. + +## [MVTec AD Dataset](https://www.mvtec.com/company/research/datasets/mvtec-ad) + +### Image-Level AUC + +| | Avg | Carpet | Grid | Leather | Tile | Wood | Bottle | Cable | Capsule | Hazelnut | Metal Nut | Pill | Screw | Toothbrush | Transistor | Zipper | +| ------------- | :---: | :----: | :---: | :-----: | :---: | :---: | :----: | :---: | :-----: | :------: | :-------: | :---: | :---: | :--------: | :--------: | :----: | +| EfficientAd-S | 0.982 | 0.982 | 1.000 | 0.997 | 1.000 | 0.986 | 1.000 | 0.952 | 0.950 | 0.952 | 0.979 | 0.987 | 0.960 | 0.997 | 0.999 | 0.994 | +| EfficientAd-M | 0.975 | 0.972 | 0.998 | 1.000 | 0.999 | 0.984 | 0.991 | 0.945 | 0.957 | 0.948 | 0.989 | 0.926 | 0.975 | 1.000 | 0.965 | 0.971 | + +### Image F1 Score + +| | Avg | Carpet | Grid | Leather | Tile | Wood | Bottle | Cable | Capsule | Hazelnut | Metal Nut | Pill | Screw | Toothbrush | Transistor | Zipper | +| ------------- | :---: | :----: | :---: | :-----: | :---: | :---: | :----: | :---: | :-----: | :------: | :-------: | :---: | :---: | :--------: | :--------: | :----: | +| EfficientAd-S | 0.970 | 0.966 | 1.000 | 0.995 | 1.000 | 0.975 | 1.000 | 0.907 | 0.956 | 0.897 | 0.978 | 0.982 | 0.944 | 0.984 | 0.988 | 0.983 | +| EfficientAd-M | 0.966 | 0.977 | 0.991 | 1.000 | 0.994 | 0.967 | 0.984 | 0.922 | 0.969 | 0.884 | 0.984 | 0.952 | 0.955 | 1.000 | 0.929 | 0.979 | diff --git a/anomalib/models/image/efficient_ad/__init__.py b/anomalib/models/image/efficient_ad/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..d8b6f5f2b002db8d32843e48f5160409e673feb2 --- /dev/null +++ b/anomalib/models/image/efficient_ad/__init__.py @@ -0,0 +1,11 @@ +"""EfficientAd: Accurate Visual Anomaly Detection at Millisecond-Level Latencies. + +https://arxiv.org/pdf/2303.14535.pdf. +""" + +# Copyright (C) 2023-2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +from .lightning_model import EfficientAd + +__all__ = ["EfficientAd"] diff --git a/anomalib/models/image/efficient_ad/lightning_model.py b/anomalib/models/image/efficient_ad/lightning_model.py new file mode 100644 index 0000000000000000000000000000000000000000..03bc643af36dd72f1325048a6cb69e05d513c9b9 --- /dev/null +++ b/anomalib/models/image/efficient_ad/lightning_model.py @@ -0,0 +1,329 @@ +"""EfficientAd: Accurate Visual Anomaly Detection at Millisecond-Level Latencies. + +https://arxiv.org/pdf/2303.14535.pdf. +""" + +# Copyright (C) 2023-2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + + +import logging +from pathlib import Path +from typing import Any + +import torch +import tqdm +from lightning.pytorch.utilities.types import STEP_OUTPUT +from torch.utils.data import DataLoader +from torchvision.datasets import ImageFolder +from torchvision.transforms.v2 import CenterCrop, Compose, Normalize, RandomGrayscale, Resize, ToTensor, Transform + +from anomalib import LearningType +from anomalib.data.utils import DownloadInfo, download_and_extract +from anomalib.models.components import AnomalyModule + +from .torch_model import EfficientAdModel, EfficientAdModelSize, reduce_tensor_elems + +logger = logging.getLogger(__name__) + +IMAGENETTE_DOWNLOAD_INFO = DownloadInfo( + name="imagenette2.tgz", + url="https://s3.amazonaws.com/fast-ai-imageclas/imagenette2.tgz", + hashsum="6cbfac238434d89fe99e651496f0812ebc7a10fa62bd42d6874042bf01de4efd", +) + +WEIGHTS_DOWNLOAD_INFO = DownloadInfo( + name="efficientad_pretrained_weights.zip", + url="https://github.com/openvinotoolkit/anomalib/releases/download/efficientad_pretrained_weights/efficientad_pretrained_weights.zip", + hashsum="c09aeaa2b33f244b3261a5efdaeae8f8284a949470a4c5a526c61275fe62684a", +) + + +class EfficientAd(AnomalyModule): + """PL Lightning Module for the EfficientAd algorithm. + + Args: + imagenet_dir (Path|str): directory path for the Imagenet dataset + Defaults to ``./datasets/imagenette``. + teacher_out_channels (int): number of convolution output channels + Defaults to ``384``. + model_size (str): size of student and teacher model + Defaults to ``EfficientAdModelSize.S``. + lr (float): learning rate + Defaults to ``0.0001``. + weight_decay (float): optimizer weight decay + Defaults to ``0.00001``. + padding (bool): use padding in convoluional layers + Defaults to ``False``. + pad_maps (bool): relevant if padding is set to False. In this case, pad_maps = True pads the + output anomaly maps so that their size matches the size in the padding = True case. + Defaults to ``True``. + """ + + def __init__( + self, + imagenet_dir: Path | str = "./datasets/imagenette", + teacher_out_channels: int = 384, + model_size: EfficientAdModelSize = EfficientAdModelSize.S, + lr: float = 0.0001, + weight_decay: float = 0.00001, + padding: bool = False, + pad_maps: bool = True, + ) -> None: + super().__init__() + + self.imagenet_dir = Path(imagenet_dir) + self.model_size = model_size + self.model: EfficientAdModel = EfficientAdModel( + teacher_out_channels=teacher_out_channels, + model_size=model_size, + padding=padding, + pad_maps=pad_maps, + ) + self.batch_size = 1 # imagenet dataloader batch_size is 1 according to the paper + self.lr = lr + self.weight_decay = weight_decay + + def prepare_pretrained_model(self) -> None: + """Prepare the pretrained teacher model.""" + pretrained_models_dir = Path("./pre_trained/") + if not (pretrained_models_dir / "efficientad_pretrained_weights").is_dir(): + download_and_extract(pretrained_models_dir, WEIGHTS_DOWNLOAD_INFO) + teacher_path = ( + pretrained_models_dir / "efficientad_pretrained_weights" / f"pretrained_teacher_{self.model_size.value}.pth" + ) + logger.info(f"Load pretrained teacher model from {teacher_path}") + self.model.teacher.load_state_dict(torch.load(teacher_path, map_location=torch.device(self.device))) + + def prepare_imagenette_data(self, image_size: tuple[int, int] | torch.Size) -> None: + """Prepare ImageNette dataset transformations. + + Args: + image_size (tuple[int, int] | torch.Size): Image size. + """ + self.data_transforms_imagenet = Compose( + [ + Resize((image_size[0] * 2, image_size[1] * 2)), + RandomGrayscale(p=0.3), + CenterCrop((image_size[0], image_size[1])), + ToTensor(), + ], + ) + + if not self.imagenet_dir.is_dir(): + download_and_extract(self.imagenet_dir, IMAGENETTE_DOWNLOAD_INFO) + imagenet_dataset = ImageFolder(self.imagenet_dir, transform=self.data_transforms_imagenet) + self.imagenet_loader = DataLoader(imagenet_dataset, batch_size=self.batch_size, shuffle=True, pin_memory=True) + self.imagenet_iterator = iter(self.imagenet_loader) + + @torch.no_grad() + def teacher_channel_mean_std(self, dataloader: DataLoader) -> dict[str, torch.Tensor]: + """Calculate the mean and std of the teacher models activations. + + Adapted from https://math.stackexchange.com/a/2148949 + + Args: + dataloader (DataLoader): Dataloader of the respective dataset. + + Returns: + dict[str, torch.Tensor]: Dictionary of channel-wise mean and std + """ + arrays_defined = False + n: torch.Tensor | None = None + chanel_sum: torch.Tensor | None = None + chanel_sum_sqr: torch.Tensor | None = None + + for batch in tqdm.tqdm(dataloader, desc="Calculate teacher channel mean & std", position=0, leave=True): + y = self.model.teacher(batch["image"].to(self.device)) + if not arrays_defined: + _, num_channels, _, _ = y.shape + n = torch.zeros((num_channels,), dtype=torch.int64, device=y.device) + chanel_sum = torch.zeros((num_channels,), dtype=torch.float32, device=y.device) + chanel_sum_sqr = torch.zeros((num_channels,), dtype=torch.float32, device=y.device) + arrays_defined = True + + n += y[:, 0].numel() + chanel_sum += torch.sum(y, dim=[0, 2, 3]) + chanel_sum_sqr += torch.sum(y**2, dim=[0, 2, 3]) + + if n is None: + msg = "The value of 'n' cannot be None." + raise ValueError(msg) + + channel_mean = chanel_sum / n + + channel_std = (torch.sqrt((chanel_sum_sqr / n) - (channel_mean**2))).float()[None, :, None, None] + channel_mean = channel_mean.float()[None, :, None, None] + + return {"mean": channel_mean, "std": channel_std} + + @torch.no_grad() + def map_norm_quantiles(self, dataloader: DataLoader) -> dict[str, torch.Tensor]: + """Calculate 90% and 99.5% quantiles of the student(st) and autoencoder(ae). + + Args: + dataloader (DataLoader): Dataloader of the respective dataset. + + Returns: + dict[str, torch.Tensor]: Dictionary of both the 90% and 99.5% quantiles + of both the student and autoencoder feature maps. + """ + maps_st = [] + maps_ae = [] + logger.info("Calculate Validation Dataset Quantiles") + for batch in tqdm.tqdm(dataloader, desc="Calculate Validation Dataset Quantiles", position=0, leave=True): + for img, label in zip(batch["image"], batch["label"], strict=True): + if label == 0: # only use good images of validation set! + output = self.model(img.to(self.device), normalize=False) + map_st = output["map_st"] + map_ae = output["map_ae"] + maps_st.append(map_st) + maps_ae.append(map_ae) + + qa_st, qb_st = self._get_quantiles_of_maps(maps_st) + qa_ae, qb_ae = self._get_quantiles_of_maps(maps_ae) + return {"qa_st": qa_st, "qa_ae": qa_ae, "qb_st": qb_st, "qb_ae": qb_ae} + + def _get_quantiles_of_maps(self, maps: list[torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor]: + """Calculate 90% and 99.5% quantiles of the given anomaly maps. + + If the total number of elements in the given maps is larger than 16777216 + the returned quantiles are computed on a random subset of the given + elements. + + Args: + maps (list[torch.Tensor]): List of anomaly maps. + + Returns: + tuple[torch.Tensor, torch.Tensor]: Two scalars - the 90% and the 99.5% quantile. + """ + maps_flat = reduce_tensor_elems(torch.cat(maps)) + qa = torch.quantile(maps_flat, q=0.9).to(self.device) + qb = torch.quantile(maps_flat, q=0.995).to(self.device) + return qa, qb + + def configure_optimizers(self) -> torch.optim.Optimizer: + """Configure optimizers.""" + optimizer = torch.optim.Adam( + list(self.model.student.parameters()) + list(self.model.ae.parameters()), + lr=self.lr, + weight_decay=self.weight_decay, + ) + + if self.trainer.max_epochs < 0 and self.trainer.max_steps < 0: + msg = "A finite number of steps or epochs must be defined" + raise ValueError(msg) + + # lightning stops training when either 'max_steps' or 'max_epochs' is reached (earliest), + # so actual training steps need to be determined here + if self.trainer.max_epochs < 0: + # max_epochs not set + num_steps = self.trainer.max_steps + elif self.trainer.max_steps < 0: + # max_steps not set -> determine steps as 'max_epochs' * 'steps in a single training epoch' + num_steps = self.trainer.max_epochs * len(self.trainer.datamodule.train_dataloader()) + else: + num_steps = min( + self.trainer.max_steps, + self.trainer.max_epochs * len(self.trainer.datamodule.train_dataloader()), + ) + + scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=int(0.95 * num_steps), gamma=0.1) + return {"optimizer": optimizer, "lr_scheduler": scheduler} + + def on_train_start(self) -> None: + """Called before the first training epoch. + + First check if EfficientAd-specific parameters are set correctly (train_batch_size of 1 + and no Imagenet normalization in transforms), then sets up the pretrained teacher model, + then prepares the imagenette data, and finally calculates or loads + the channel-wise mean and std of the training dataset and push to the model. + """ + if self.trainer.datamodule.train_batch_size != 1: + msg = "train_batch_size for EfficientAd should be 1." + raise ValueError(msg) + if self._transform and any(isinstance(transform, Normalize) for transform in self._transform.transforms): + msg = "Transforms for EfficientAd should not contain Normalize." + raise ValueError(msg) + + sample = next(iter(self.trainer.train_dataloader)) + image_size = sample["image"].shape[-2:] + self.prepare_pretrained_model() + self.prepare_imagenette_data(image_size) + if not self.model.is_set(self.model.mean_std): + channel_mean_std = self.teacher_channel_mean_std(self.trainer.datamodule.train_dataloader()) + self.model.mean_std.update(channel_mean_std) + + def training_step(self, batch: dict[str, str | torch.Tensor], *args, **kwargs) -> dict[str, torch.Tensor]: + """Perform the training step for EfficientAd returns the student, autoencoder and combined loss. + + Args: + batch (batch: dict[str, str | torch.Tensor]): Batch containing image filename, image, label and mask + args: Additional arguments. + kwargs: Additional keyword arguments. + + Returns: + Loss. + """ + del args, kwargs # These variables are not used. + + try: + # infinite dataloader; [0] getting the image not the label + batch_imagenet = next(self.imagenet_iterator)[0].to(self.device) + except StopIteration: + self.imagenet_iterator = iter(self.imagenet_loader) + batch_imagenet = next(self.imagenet_iterator)[0].to(self.device) + + loss_st, loss_ae, loss_stae = self.model(batch=batch["image"], batch_imagenet=batch_imagenet) + + loss = loss_st + loss_ae + loss_stae + self.log("train_st", loss_st.item(), on_epoch=True, prog_bar=True, logger=True) + self.log("train_ae", loss_ae.item(), on_epoch=True, prog_bar=True, logger=True) + self.log("train_stae", loss_stae.item(), on_epoch=True, prog_bar=True, logger=True) + self.log("train_loss", loss.item(), on_epoch=True, prog_bar=True, logger=True) + return {"loss": loss} + + def on_validation_start(self) -> None: + """Calculate the feature map quantiles of the validation dataset and push to the model.""" + map_norm_quantiles = self.map_norm_quantiles(self.trainer.datamodule.val_dataloader()) + self.model.quantiles.update(map_norm_quantiles) + + def validation_step(self, batch: dict[str, str | torch.Tensor], *args, **kwargs) -> STEP_OUTPUT: + """Perform the validation step of EfficientAd returns anomaly maps for the input image batch. + + Args: + batch (dict[str, str | torch.Tensor]): Input batch + args: Additional arguments. + kwargs: Additional keyword arguments. + + Returns: + Dictionary containing anomaly maps. + """ + del args, kwargs # These variables are not used. + + batch["anomaly_maps"] = self.model(batch["image"])["anomaly_map"] + + return batch + + @property + def trainer_arguments(self) -> dict[str, Any]: + """Return EfficientAD trainer arguments.""" + return {"num_sanity_val_steps": 0} + + @property + def learning_type(self) -> LearningType: + """Return the learning type of the model. + + Returns: + LearningType: Learning type of the model. + """ + return LearningType.ONE_CLASS + + def configure_transforms(self, image_size: tuple[int, int] | None = None) -> Transform: + """Default transform for EfficientAd. Imagenet normalization applied in forward.""" + image_size = image_size or (256, 256) + return Compose( + [ + Resize(image_size, antialias=True), + ], + ) diff --git a/anomalib/models/image/efficient_ad/torch_model.py b/anomalib/models/image/efficient_ad/torch_model.py new file mode 100644 index 0000000000000000000000000000000000000000..4c3768b3980fb1de05573adf758c8ff5dda634ba --- /dev/null +++ b/anomalib/models/image/efficient_ad/torch_model.py @@ -0,0 +1,427 @@ +"""Torch model for student, teacher and autoencoder model in EfficientAd.""" + +# Copyright (C) 2023-2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + + +import logging +import math +from enum import Enum + +import numpy as np +import torch +from torch import nn +from torch.nn import functional as F # noqa: N812 +from torchvision import transforms + +logger = logging.getLogger(__name__) + + +def imagenet_norm_batch(x: torch.Tensor) -> torch.Tensor: + """Normalize batch of images with ImageNet mean and std. + + Args: + x (torch.Tensor): Input batch. + + Returns: + torch.Tensor: Normalized batch using the ImageNet mean and std. + """ + mean = torch.tensor([0.485, 0.456, 0.406])[None, :, None, None].to(x.device) + std = torch.tensor([0.229, 0.224, 0.225])[None, :, None, None].to(x.device) + return (x - mean) / std + + +def reduce_tensor_elems(tensor: torch.Tensor, m: int = 2**24) -> torch.Tensor: + """Reduce tensor elements. + + This function flatten n-dimensional tensors, selects m elements from it + and returns the selected elements as tensor. It is used to select + at most 2**24 for torch.quantile operation, as it is the maximum + supported number of elements. + https://github.com/pytorch/pytorch/blob/b9f81a483a7879cd3709fd26bcec5f1ee33577e6/aten/src/ATen/native/Sorting.cpp#L291. + + Args: + tensor (torch.Tensor): input tensor from which elements are selected + m (int): number of maximum tensor elements. + Defaults to ``2**24`` + + Returns: + Tensor: reduced tensor + """ + tensor = torch.flatten(tensor) + if len(tensor) > m: + # select a random subset with m elements. + perm = torch.randperm(len(tensor), device=tensor.device) + idx = perm[:m] + tensor = tensor[idx] + return tensor + + +class EfficientAdModelSize(str, Enum): + """Supported EfficientAd model sizes.""" + + M = "medium" + S = "small" + + +class SmallPatchDescriptionNetwork(nn.Module): + """Patch Description Network small. + + Args: + out_channels (int): number of convolution output channels + padding (bool): use padding in convoluional layers + Defaults to ``False``. + """ + + def __init__(self, out_channels: int, padding: bool = False) -> None: + super().__init__() + pad_mult = 1 if padding else 0 + self.conv1 = nn.Conv2d(3, 128, kernel_size=4, stride=1, padding=3 * pad_mult) + self.conv2 = nn.Conv2d(128, 256, kernel_size=4, stride=1, padding=3 * pad_mult) + self.conv3 = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1 * pad_mult) + self.conv4 = nn.Conv2d(256, out_channels, kernel_size=4, stride=1, padding=0 * pad_mult) + self.avgpool1 = nn.AvgPool2d(kernel_size=2, stride=2, padding=1 * pad_mult) + self.avgpool2 = nn.AvgPool2d(kernel_size=2, stride=2, padding=1 * pad_mult) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Perform a forward pass through the network. + + Args: + x (torch.Tensor): Input batch. + + Returns: + torch.Tensor: Output from the network. + """ + x = imagenet_norm_batch(x) + x = F.relu(self.conv1(x)) + x = self.avgpool1(x) + x = F.relu(self.conv2(x)) + x = self.avgpool2(x) + x = F.relu(self.conv3(x)) + return self.conv4(x) + + +class MediumPatchDescriptionNetwork(nn.Module): + """Patch Description Network medium. + + Args: + out_channels (int): number of convolution output channels + padding (bool): use padding in convoluional layers + Defaults to ``False``. + """ + + def __init__(self, out_channels: int, padding: bool = False) -> None: + super().__init__() + pad_mult = 1 if padding else 0 + self.conv1 = nn.Conv2d(3, 256, kernel_size=4, stride=1, padding=3 * pad_mult) + self.conv2 = nn.Conv2d(256, 512, kernel_size=4, stride=1, padding=3 * pad_mult) + self.conv3 = nn.Conv2d(512, 512, kernel_size=1, stride=1, padding=0 * pad_mult) + self.conv4 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1 * pad_mult) + self.conv5 = nn.Conv2d(512, out_channels, kernel_size=4, stride=1, padding=0 * pad_mult) + self.conv6 = nn.Conv2d(out_channels, out_channels, kernel_size=1, stride=1, padding=0 * pad_mult) + self.avgpool1 = nn.AvgPool2d(kernel_size=2, stride=2, padding=1 * pad_mult) + self.avgpool2 = nn.AvgPool2d(kernel_size=2, stride=2, padding=1 * pad_mult) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Perform a forward pass through the network. + + Args: + x (torch.Tensor): Input batch. + + Returns: + torch.Tensor: Output from the network. + """ + x = imagenet_norm_batch(x) + x = F.relu(self.conv1(x)) + x = self.avgpool1(x) + x = F.relu(self.conv2(x)) + x = self.avgpool2(x) + x = F.relu(self.conv3(x)) + x = F.relu(self.conv4(x)) + x = F.relu(self.conv5(x)) + return self.conv6(x) + + +class Encoder(nn.Module): + """Autoencoder Encoder model.""" + + def __init__(self) -> None: + super().__init__() + self.enconv1 = nn.Conv2d(3, 32, kernel_size=4, stride=2, padding=1) + self.enconv2 = nn.Conv2d(32, 32, kernel_size=4, stride=2, padding=1) + self.enconv3 = nn.Conv2d(32, 64, kernel_size=4, stride=2, padding=1) + self.enconv4 = nn.Conv2d(64, 64, kernel_size=4, stride=2, padding=1) + self.enconv5 = nn.Conv2d(64, 64, kernel_size=4, stride=2, padding=1) + self.enconv6 = nn.Conv2d(64, 64, kernel_size=8, stride=1, padding=0) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Perform the forward pass through the network. + + Args: + x (torch.Tensor): Input batch. + + Returns: + torch.Tensor: Output from the network. + """ + x = F.relu(self.enconv1(x)) + x = F.relu(self.enconv2(x)) + x = F.relu(self.enconv3(x)) + x = F.relu(self.enconv4(x)) + x = F.relu(self.enconv5(x)) + return self.enconv6(x) + + +class Decoder(nn.Module): + """Autoencoder Decoder model. + + Args: + out_channels (int): number of convolution output channels + padding (int): use padding in convoluional layers + """ + + def __init__(self, out_channels: int, padding: int, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) + self.padding = padding + # use ceil to match output shape of PDN + self.deconv1 = nn.Conv2d(64, 64, kernel_size=4, stride=1, padding=2) + self.deconv2 = nn.Conv2d(64, 64, kernel_size=4, stride=1, padding=2) + self.deconv3 = nn.Conv2d(64, 64, kernel_size=4, stride=1, padding=2) + self.deconv4 = nn.Conv2d(64, 64, kernel_size=4, stride=1, padding=2) + self.deconv5 = nn.Conv2d(64, 64, kernel_size=4, stride=1, padding=2) + self.deconv6 = nn.Conv2d(64, 64, kernel_size=4, stride=1, padding=2) + self.deconv7 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1) + self.deconv8 = nn.Conv2d(64, out_channels, kernel_size=3, stride=1, padding=1) + self.dropout1 = nn.Dropout(p=0.2) + self.dropout2 = nn.Dropout(p=0.2) + self.dropout3 = nn.Dropout(p=0.2) + self.dropout4 = nn.Dropout(p=0.2) + self.dropout5 = nn.Dropout(p=0.2) + self.dropout6 = nn.Dropout(p=0.2) + + def forward(self, x: torch.Tensor, image_size: tuple[int, int] | torch.Size) -> torch.Tensor: + """Perform a forward pass through the network. + + Args: + x (torch.Tensor): Input batch. + image_size (tuple): size of input images. + + Returns: + torch.Tensor: Output from the network. + """ + last_upsample = ( + math.ceil(image_size[0] / 4) if self.padding else math.ceil(image_size[0] / 4) - 8, + math.ceil(image_size[1] / 4) if self.padding else math.ceil(image_size[1] / 4) - 8, + ) + x = F.interpolate(x, size=(image_size[0] // 64 - 1, image_size[1] // 64 - 1), mode="bilinear") + x = F.relu(self.deconv1(x)) + x = self.dropout1(x) + x = F.interpolate(x, size=(image_size[0] // 32, image_size[1] // 32), mode="bilinear") + x = F.relu(self.deconv2(x)) + x = self.dropout2(x) + x = F.interpolate(x, size=(image_size[0] // 16 - 1, image_size[1] // 16 - 1), mode="bilinear") + x = F.relu(self.deconv3(x)) + x = self.dropout3(x) + x = F.interpolate(x, size=(image_size[0] // 8, image_size[1] // 8), mode="bilinear") + x = F.relu(self.deconv4(x)) + x = self.dropout4(x) + x = F.interpolate(x, size=(image_size[0] // 4 - 1, image_size[1] // 4 - 1), mode="bilinear") + x = F.relu(self.deconv5(x)) + x = self.dropout5(x) + x = F.interpolate(x, size=(image_size[0] // 2 - 1, image_size[1] // 2 - 1), mode="bilinear") + x = F.relu(self.deconv6(x)) + x = self.dropout6(x) + x = F.interpolate(x, size=last_upsample, mode="bilinear") + x = F.relu(self.deconv7(x)) + return self.deconv8(x) + + +class AutoEncoder(nn.Module): + """EfficientAd Autoencoder. + + Args: + out_channels (int): number of convolution output channels + padding (int): use padding in convoluional layers + """ + + def __init__(self, out_channels: int, padding: int, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) + self.encoder = Encoder() + self.decoder = Decoder(out_channels, padding) + + def forward(self, x: torch.Tensor, image_size: tuple[int, int] | torch.Size) -> torch.Tensor: + """Perform the forward pass through the network. + + Args: + x (torch.Tensor): Input batch. + image_size (tuple): size of input images. + + Returns: + torch.Tensor: Output from the network. + """ + x = imagenet_norm_batch(x) + x = self.encoder(x) + return self.decoder(x, image_size) + + +class EfficientAdModel(nn.Module): + """EfficientAd model. + + Args: + teacher_out_channels (int): number of convolution output channels of the pre-trained teacher model + model_size (str): size of student and teacher model + padding (bool): use padding in convoluional layers + Defaults to ``False``. + pad_maps (bool): relevant if padding is set to False. In this case, pad_maps = True pads the + output anomaly maps so that their size matches the size in the padding = True case. + Defaults to ``True``. + """ + + def __init__( + self, + teacher_out_channels: int, + model_size: EfficientAdModelSize = EfficientAdModelSize.S, + padding: bool = False, + pad_maps: bool = True, + ) -> None: + super().__init__() + + self.pad_maps = pad_maps + self.teacher: MediumPatchDescriptionNetwork | SmallPatchDescriptionNetwork + self.student: MediumPatchDescriptionNetwork | SmallPatchDescriptionNetwork + + if model_size == EfficientAdModelSize.M: + self.teacher = MediumPatchDescriptionNetwork(out_channels=teacher_out_channels, padding=padding).eval() + self.student = MediumPatchDescriptionNetwork(out_channels=teacher_out_channels * 2, padding=padding) + + elif model_size == EfficientAdModelSize.S: + self.teacher = SmallPatchDescriptionNetwork(out_channels=teacher_out_channels, padding=padding).eval() + self.student = SmallPatchDescriptionNetwork(out_channels=teacher_out_channels * 2, padding=padding) + + else: + msg = f"Unknown model size {model_size}" + raise ValueError(msg) + + self.ae: AutoEncoder = AutoEncoder(out_channels=teacher_out_channels, padding=padding) + self.teacher_out_channels: int = teacher_out_channels + + self.mean_std: nn.ParameterDict = nn.ParameterDict( + { + "mean": torch.zeros((1, self.teacher_out_channels, 1, 1)), + "std": torch.zeros((1, self.teacher_out_channels, 1, 1)), + }, + ) + + self.quantiles: nn.ParameterDict = nn.ParameterDict( + { + "qa_st": torch.tensor(0.0), + "qb_st": torch.tensor(0.0), + "qa_ae": torch.tensor(0.0), + "qb_ae": torch.tensor(0.0), + }, + ) + + def is_set(self, p_dic: nn.ParameterDict) -> bool: + """Check if any of the parameters in the parameter dictionary is set. + + Args: + p_dic (nn.ParameterDict): Parameter dictionary. + + Returns: + bool: Boolean indicating whether any of the parameters in the parameter dictionary is set. + """ + return any(value.sum() != 0 for _, value in p_dic.items()) + + def choose_random_aug_image(self, image: torch.Tensor) -> torch.Tensor: + """Choose a random augmentation function and apply it to the input image. + + Args: + image (torch.Tensor): Input image. + + Returns: + Tensor: Augmented image. + """ + transform_functions = [ + transforms.functional.adjust_brightness, + transforms.functional.adjust_contrast, + transforms.functional.adjust_saturation, + ] + # Sample an augmentation coefficient λ from the uniform distribution U(0.8, 1.2) + coefficient = np.random.default_rng().uniform(0.8, 1.2) + transform_function = np.random.default_rng().choice(transform_functions) + return transform_function(image, coefficient) + + def forward( + self, + batch: torch.Tensor, + batch_imagenet: torch.Tensor | None = None, + normalize: bool = True, + ) -> torch.Tensor | dict: + """Perform the forward-pass of the EfficientAd models. + + Args: + batch (torch.Tensor): Input images. + batch_imagenet (torch.Tensor): ImageNet batch. Defaults to None. + normalize (bool): Normalize anomaly maps or not + + Returns: + Tensor: Predictions + """ + image_size = batch.shape[-2:] + with torch.no_grad(): + teacher_output = self.teacher(batch) + if self.is_set(self.mean_std): + teacher_output = (teacher_output - self.mean_std["mean"]) / self.mean_std["std"] + + student_output = self.student(batch) + distance_st = torch.pow(teacher_output - student_output[:, : self.teacher_out_channels, :, :], 2) + + if self.training: + # Student loss + distance_st = reduce_tensor_elems(distance_st) + d_hard = torch.quantile(distance_st, 0.999) + loss_hard = torch.mean(distance_st[distance_st >= d_hard]) + student_output_penalty = self.student(batch_imagenet)[:, : self.teacher_out_channels, :, :] + loss_penalty = torch.mean(student_output_penalty**2) + loss_st = loss_hard + loss_penalty + + # Autoencoder and Student AE Loss + aug_img = self.choose_random_aug_image(batch) + ae_output_aug = self.ae(aug_img, image_size) + + with torch.no_grad(): + teacher_output_aug = self.teacher(aug_img) + if self.is_set(self.mean_std): + teacher_output_aug = (teacher_output_aug - self.mean_std["mean"]) / self.mean_std["std"] + + student_output_ae_aug = self.student(aug_img)[:, self.teacher_out_channels :, :, :] + + distance_ae = torch.pow(teacher_output_aug - ae_output_aug, 2) + distance_stae = torch.pow(ae_output_aug - student_output_ae_aug, 2) + + loss_ae = torch.mean(distance_ae) + loss_stae = torch.mean(distance_stae) + return (loss_st, loss_ae, loss_stae) + + # Eval mode. + with torch.no_grad(): + ae_output = self.ae(batch, image_size) + + map_st = torch.mean(distance_st, dim=1, keepdim=True) + map_stae = torch.mean( + (ae_output - student_output[:, self.teacher_out_channels :]) ** 2, + dim=1, + keepdim=True, + ) + + if self.pad_maps: + map_st = F.pad(map_st, (4, 4, 4, 4)) + map_stae = F.pad(map_stae, (4, 4, 4, 4)) + map_st = F.interpolate(map_st, size=image_size, mode="bilinear") + map_stae = F.interpolate(map_stae, size=image_size, mode="bilinear") + + if self.is_set(self.quantiles) and normalize: + map_st = 0.1 * (map_st - self.quantiles["qa_st"]) / (self.quantiles["qb_st"] - self.quantiles["qa_st"]) + map_stae = 0.1 * (map_stae - self.quantiles["qa_ae"]) / (self.quantiles["qb_ae"] - self.quantiles["qa_ae"]) + + map_combined = 0.5 * map_st + 0.5 * map_stae + return {"anomaly_map": map_combined, "map_st": map_st, "map_ae": map_stae} diff --git a/anomalib/models/image/fastflow/README.md b/anomalib/models/image/fastflow/README.md new file mode 100644 index 0000000000000000000000000000000000000000..9dfc7dda65f4772c75828f42b2fc256431da5369 --- /dev/null +++ b/anomalib/models/image/fastflow/README.md @@ -0,0 +1,117 @@ +# FastFlow: Unsupervised Anomaly Detection and Localization via 2D Normalizing Flows + +This is the implementation of the [FastFlow](https://arxiv.org/abs/2111.07677) paper. This code is developed by utilizing the torch model implemented in [https://github.com/gathierry/FastFlow](https://github.com/gathierry/FastFlow). + +Model Type: Segmentation + +## Description + +FastFlow is a two-dimensional normalizing flow-based probability distribution estimator. It can be used as a plug-in module with any deep feature extractor, such as ResNet and vision transformer, for unsupervised anomaly detection and localisation. In the training phase, FastFlow learns to transform the input visual feature into a tractable distribution, and in the inference phase, it assesses the likelihood of identifying anomalies. + +## Architecture + +![FastFlow Architecture](/docs/source/images/fastflow/architecture.jpg "FastFlow Architecture") + +## Usage + +`python tools/train.py --model fastflow` + +## Benchmark + +All results gathered with seed `0`. + +## [MVTec AD Dataset](https://www.mvtec.com/company/research/datasets/mvtec-ad) + +> **_NOTE:_** When the numbers are produced, early stopping callback (patience: 3) is used. It might be possible to achieve higher-metrics by increasing the patience. + +### Image-Level AUC + +| | ResNet-18 | Wide ResNet50 | DeiT | CaiT | +| ---------- | :-------: | :-----------: | :---: | :---: | +| Bottle | 1.000 | 1.000 | 0.905 | 0.986 | +| Cable | 0.891 | 0.962 | 0.942 | 0.839 | +| Capsule | 0.900 | 0.963 | 0.819 | 0.913 | +| Carpet | 0.979 | 0.994 | 0.999 | 1.000 | +| Grid | 0.988 | 1.000 | 0.991 | 0.979 | +| Hazelnut | 0.846 | 0.994 | 0.900 | 0.948 | +| Leather | 1.000 | 0.999 | 0.999 | 0.991 | +| Metal_nut | 0.963 | 0.995 | 0.911 | 0.963 | +| Pill | 0.916 | 0.942 | 0.910 | 0.916 | +| Screw | 0.521 | 0.839 | 0.705 | 0.791 | +| Tile | 0.967 | 1.000 | 0.993 | 0.998 | +| Toothbrush | 0.844 | 0.836 | 0.850 | 0.886 | +| Transistor | 0.938 | 0.979 | 0.993 | 0.983 | +| Wood | 0.978 | 0.992 | 0.979 | 0.989 | +| Zipper | 0.878 | 0.951 | 0.981 | 0.977 | +| Average | 0.907 | 0.963 | 0.925 | 0.944 | + +### Pixel-Level AUC + +| | ResNet-18 | Wide ResNet50 | DeiT | CaiT | +| ---------- | :-------: | :-----------: | :---: | :---: | +| Bottle | 0.983 | 0.986 | 0.991 | 0.984 | +| Cable | 0.954 | 0.972 | 0.973 | 0.981 | +| Capsule | 0.985 | 0.990 | 0.979 | 0.991 | +| Carpet | 0.983 | 0.991 | 0.991 | 0.992 | +| Grid | 0.985 | 0.992 | 0.980 | 0.979 | +| Hazelnut | 0.953 | 0.980 | 0.989 | 0.993 | +| Leather | 0.996 | 0.996 | 0.995 | 0.996 | +| Metal_nut | 0.972 | 0.988 | 0.978 | 0.973 | +| Pill | 0.972 | 0.976 | 0.985 | 0.992 | +| Screw | 0.926 | 0.966 | 0.945 | 0.979 | +| Tile | 0.944 | 0.966 | 0.951 | 0.960 | +| Toothbrush | 0.979 | 0.980 | 0.985 | 0.992 | +| Transistor | 0.964 | 0.971 | 0.949 | 0.960 | +| Wood | 0.956 | 0.941 | 0.952 | 0.954 | +| Zipper | 0.965 | 0.985 | 0.978 | 0.979 | +| Average | 0.968 | 0.979 | 0.975 | 0.980 | + +### Image F1 Score + +| | ResNet-18 | Wide ResNet50 | DeiT | CaiT | +| ---------- | :-------: | :-----------: | :---: | :---: | +| Bottle | 0.976 | 0.952 | 0.741 | 0.977 | +| Cable | 0.851 | 0.918 | 0.848 | 0.835 | +| Capsule | 0.937 | 0.952 | 0.905 | 0.928 | +| Carpet | 0.955 | 0.983 | 0.994 | 0.973 | +| Grid | 0.941 | 0.974 | 0.982 | 0.948 | +| Hazelnut | 0.852 | 0.979 | 0.828 | 0.900 | +| Leather | 0.995 | 0.974 | 0.995 | 0.963 | +| Metal_nut | 0.925 | 0.969 | 0.899 | 0.916 | +| Pill | 0.946 | 0.949 | 0.949 | 0.616 | +| Screw | 0.853 | 0.893 | 0.868 | 0.979 | +| Tile | 0.947 | 0.994 | 0.976 | 0.994 | +| Toothbrush | 0.875 | 0.870 | 0.833 | 0.833 | +| Transistor | 0.779 | 0.854 | 0.873 | 0.909 | +| Wood | 0.983 | 0.968 | 0.944 | 0.967 | +| Zipper | 0.921 | 0.975 | 0.958 | 0.933 | +| Average | 0.916 | 0.947 | 0.906 | 0.911 | + +### Pixel F1 Score + +| | ResNet-18 | Wide ResNet50 | DeiT | CaiT | +| ---------- | :-------: | :-----------: | :---: | :---: | +| Bottle | 0.670 | 0.733 | 0.753 | 0.725 | +| Cable | 0.547 | 0.564 | 0.487 | 0.608 | +| Capsule | 0.472 | 0.490 | 0.399 | 0.497 | +| Carpet | 0.573 | 0.598 | 0.586 | 0.606 | +| Grid | 0.412 | 0.481 | 0.393 | 0.410 | +| Hazelnut | 0.522 | 0.545 | 0.643 | 0.706 | +| Leather | 0.560 | 0.576 | 0.504 | 0.516 | +| Metal_nut | 0.728 | 0.754 | 0.766 | 0.737 | +| Pill | 0.589 | 0.611 | 0.709 | 0.617 | +| Screw | 0.061 | 0.660 | 0.269 | 0.370 | +| Tile | 0.569 | 0.660 | 0.655 | 0.660 | +| Toothbrush | 0.479 | 0.481 | 0.524 | 0.535 | +| Transistor | 0.558 | 0.573 | 0.527 | 0.567 | +| Wood | 0.557 | 0.488 | 0.614 | 0.572 | +| Zipper | 0.492 | 0.621 | 0.522 | 0.504 | +| Average | 0.519 | 0.589 | 0.557 | 0.575 | + +### Sample Results + +![Sample Result 1](/docs/source/images/fastflow/results/0.png "Sample Result 1") + +![Sample Result 2](/docs/source/images/fastflow/results/1.png "Sample Result 2") + +![Sample Result 3](/docs/source/images/fastflow/results/2.png "Sample Result 3") diff --git a/anomalib/models/image/fastflow/__init__.py b/anomalib/models/image/fastflow/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..7abb420e3339235fd406cef8da35278c96f9a55d --- /dev/null +++ b/anomalib/models/image/fastflow/__init__.py @@ -0,0 +1,10 @@ +"""FastFlow Algorithm Implementation.""" + +# Copyright (C) 2022-2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +from .lightning_model import Fastflow +from .loss import FastflowLoss +from .torch_model import FastflowModel + +__all__ = ["FastflowModel", "FastflowLoss", "Fastflow"] diff --git a/anomalib/models/image/fastflow/anomaly_map.py b/anomalib/models/image/fastflow/anomaly_map.py new file mode 100644 index 0000000000000000000000000000000000000000..880b3e491796d85873190df1f7c291ad94b9aab7 --- /dev/null +++ b/anomalib/models/image/fastflow/anomaly_map.py @@ -0,0 +1,50 @@ +"""FastFlow Anomaly Map Generator Implementation.""" + +# Copyright (C) 2022-2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + + +import torch +from omegaconf import ListConfig +from torch import nn +from torch.nn import functional as F # noqa: N812 + + +class AnomalyMapGenerator(nn.Module): + """Generate Anomaly Heatmap. + + Args: + input_size (ListConfig | tuple): Input size. + """ + + def __init__(self, input_size: ListConfig | tuple) -> None: + super().__init__() + self.input_size = input_size if isinstance(input_size, tuple) else tuple(input_size) + + def forward(self, hidden_variables: list[torch.Tensor]) -> torch.Tensor: + """Generate Anomaly Heatmap. + + This implementation generates the heatmap based on the flow maps + computed from the normalizing flow (NF) FastFlow blocks. Each block + yields a flow map, which overall is stacked and averaged to an anomaly + map. + + Args: + hidden_variables (list[torch.Tensor]): List of hidden variables from each NF FastFlow block. + + Returns: + Tensor: Anomaly Map. + """ + flow_maps: list[torch.Tensor] = [] + for hidden_variable in hidden_variables: + log_prob = -torch.mean(hidden_variable**2, dim=1, keepdim=True) * 0.5 + prob = torch.exp(log_prob) + flow_map = F.interpolate( + input=-prob, + size=self.input_size, + mode="bilinear", + align_corners=False, + ) + flow_maps.append(flow_map) + flow_maps = torch.stack(flow_maps, dim=-1) + return torch.mean(flow_maps, dim=-1) diff --git a/anomalib/models/image/fastflow/lightning_model.py b/anomalib/models/image/fastflow/lightning_model.py new file mode 100644 index 0000000000000000000000000000000000000000..e6f2df0780d31ccc636434dc3aecf5d461e84b4d --- /dev/null +++ b/anomalib/models/image/fastflow/lightning_model.py @@ -0,0 +1,130 @@ +"""FastFlow Lightning Model Implementation. + +https://arxiv.org/abs/2111.07677 +""" + +# Copyright (C) 2022-2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +from typing import Any + +import torch +from lightning.pytorch.utilities.types import STEP_OUTPUT +from torch import optim + +from anomalib import LearningType +from anomalib.models.components import AnomalyModule + +from .loss import FastflowLoss +from .torch_model import FastflowModel + + +class Fastflow(AnomalyModule): + """PL Lightning Module for the FastFlow algorithm. + + Args: + backbone (str): Backbone CNN network + Defaults to ``resnet18``. + pre_trained (bool, optional): Boolean to check whether to use a pre_trained backbone. + Defaults to ``True``. + flow_steps (int, optional): Flow steps. + Defaults to ``8``. + conv3x3_only (bool, optinoal): Use only conv3x3 in fast_flow model. + Defaults to ``False``. + hidden_ratio (float, optional): Ratio to calculate hidden var channels. + Defaults to ``1.0`. + """ + + def __init__( + self, + backbone: str = "resnet18", + pre_trained: bool = True, + flow_steps: int = 8, + conv3x3_only: bool = False, + hidden_ratio: float = 1.0, + ) -> None: + super().__init__() + + self.backbone = backbone + self.pre_trained = pre_trained + self.flow_steps = flow_steps + self.conv3x3_only = conv3x3_only + self.hidden_ratio = hidden_ratio + + self.model: FastflowModel + self.loss = FastflowLoss() + + def _setup(self) -> None: + if self.input_size is None: + msg = "Fastflow needs input size to build torch model." + raise ValueError(msg) + + self.model = FastflowModel( + input_size=self.input_size, + backbone=self.backbone, + pre_trained=self.pre_trained, + flow_steps=self.flow_steps, + conv3x3_only=self.conv3x3_only, + hidden_ratio=self.hidden_ratio, + ) + + def training_step(self, batch: dict[str, str | torch.Tensor], *args, **kwargs) -> STEP_OUTPUT: + """Perform the training step input and return the loss. + + Args: + batch (batch: dict[str, str | torch.Tensor]): Input batch + args: Additional arguments. + kwargs: Additional keyword arguments. + + Returns: + STEP_OUTPUT: Dictionary containing the loss value. + """ + del args, kwargs # These variables are not used. + + hidden_variables, jacobians = self.model(batch["image"]) + loss = self.loss(hidden_variables, jacobians) + self.log("train_loss", loss.item(), on_epoch=True, prog_bar=True, logger=True) + return {"loss": loss} + + def validation_step(self, batch: dict[str, str | torch.Tensor], *args, **kwargs) -> STEP_OUTPUT: + """Perform the validation step and return the anomaly map. + + Args: + batch (dict[str, str | torch.Tensor]): Input batch + args: Additional arguments. + kwargs: Additional keyword arguments. + + Returns: + STEP_OUTPUT | None: batch dictionary containing anomaly-maps. + """ + del args, kwargs # These variables are not used. + + anomaly_maps = self.model(batch["image"]) + batch["anomaly_maps"] = anomaly_maps + return batch + + @property + def trainer_arguments(self) -> dict[str, Any]: + """Return FastFlow trainer arguments.""" + return {"gradient_clip_val": 0, "num_sanity_val_steps": 0} + + def configure_optimizers(self) -> torch.optim.Optimizer: + """Configure optimizers for each decoder. + + Returns: + Optimizer: Adam optimizer for each decoder + """ + return optim.Adam( + params=self.model.parameters(), + lr=0.001, + weight_decay=0.00001, + ) + + @property + def learning_type(self) -> LearningType: + """Return the learning type of the model. + + Returns: + LearningType: Learning type of the model. + """ + return LearningType.ONE_CLASS diff --git a/anomalib/models/image/fastflow/loss.py b/anomalib/models/image/fastflow/loss.py new file mode 100644 index 0000000000000000000000000000000000000000..d7861905493b65bb88cee5ecb0bf059d858a61e5 --- /dev/null +++ b/anomalib/models/image/fastflow/loss.py @@ -0,0 +1,27 @@ +"""Loss function for the FastFlow Model Implementation.""" + +# Copyright (C) 2022-2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + + +import torch +from torch import nn + + +class FastflowLoss(nn.Module): + """FastFlow Loss.""" + + def forward(self, hidden_variables: list[torch.Tensor], jacobians: list[torch.Tensor]) -> torch.Tensor: + """Calculate the Fastflow loss. + + Args: + hidden_variables (list[torch.Tensor]): Hidden variables from the fastflow model. f: X -> Z + jacobians (list[torch.Tensor]): Log of the jacobian determinants from the fastflow model. + + Returns: + Tensor: Fastflow loss computed based on the hidden variables and the log of the Jacobians. + """ + loss = torch.tensor(0.0, device=hidden_variables[0].device) # pylint: disable=not-callable + for hidden_variable, jacobian in zip(hidden_variables, jacobians, strict=True): + loss += torch.mean(0.5 * torch.sum(hidden_variable**2, dim=(1, 2, 3)) - jacobian) + return loss diff --git a/anomalib/models/image/fastflow/torch_model.py b/anomalib/models/image/fastflow/torch_model.py new file mode 100644 index 0000000000000000000000000000000000000000..7796c83b29d9745f56c99cc1d428c4daf809191a --- /dev/null +++ b/anomalib/models/image/fastflow/torch_model.py @@ -0,0 +1,274 @@ +"""FastFlow Torch Model Implementation.""" + +# Original Code +# Copyright (c) 2022 @gathierry +# https://github.com/gathierry/FastFlow/. +# SPDX-License-Identifier: Apache-2.0 +# +# Modified +# Copyright (C) 2022-2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + + +from collections.abc import Callable + +import timm +import torch +from FrEIA.framework import SequenceINN +from timm.models.cait import Cait +from timm.models.vision_transformer import VisionTransformer +from torch import nn + +from anomalib.models.components.flow import AllInOneBlock + +from .anomaly_map import AnomalyMapGenerator + + +def subnet_conv_func(kernel_size: int, hidden_ratio: float) -> Callable: + """Subnet Convolutional Function. + + Callable class or function ``f``, called as ``f(channels_in, channels_out)`` and + should return a torch.nn.Module. + Predicts coupling coefficients :math:`s, t`. + + Args: + kernel_size (int): Kernel Size + hidden_ratio (float): Hidden ratio to compute number of hidden channels. + + Returns: + Callable: Sequential for the subnet constructor. + """ + + def subnet_conv(in_channels: int, out_channels: int) -> nn.Sequential: + hidden_channels = int(in_channels * hidden_ratio) + # NOTE: setting padding="same" in nn.Conv2d breaks the onnx export so manual padding required. + # TODO(ashwinvaidya17): Use padding="same" in nn.Conv2d once PyTorch v2.1 is released + # CVS-122671 + padding = 2 * (kernel_size // 2 - ((1 + kernel_size) % 2), kernel_size // 2) + return nn.Sequential( + nn.ZeroPad2d(padding), + nn.Conv2d(in_channels, hidden_channels, kernel_size), + nn.ReLU(), + nn.ZeroPad2d(padding), + nn.Conv2d(hidden_channels, out_channels, kernel_size), + ) + + return subnet_conv + + +def create_fast_flow_block( + input_dimensions: list[int], + conv3x3_only: bool, + hidden_ratio: float, + flow_steps: int, + clamp: float = 2.0, +) -> SequenceINN: + """Create NF Fast Flow Block. + + This is to create Normalizing Flow (NF) Fast Flow model block based on + Figure 2 and Section 3.3 in the paper. + + Args: + input_dimensions (list[int]): Input dimensions (Channel, Height, Width) + conv3x3_only (bool): Boolean whether to use conv3x3 only or conv3x3 and conv1x1. + hidden_ratio (float): Ratio for the hidden layer channels. + flow_steps (int): Flow steps. + clamp (float, optional): Clamp. + Defaults to ``2.0``. + + Returns: + SequenceINN: FastFlow Block. + """ + nodes = SequenceINN(*input_dimensions) + for i in range(flow_steps): + kernel_size = 1 if i % 2 == 1 and not conv3x3_only else 3 + nodes.append( + AllInOneBlock, + subnet_constructor=subnet_conv_func(kernel_size, hidden_ratio), + affine_clamping=clamp, + permute_soft=False, + ) + return nodes + + +class FastflowModel(nn.Module): + """FastFlow. + + Unsupervised Anomaly Detection and Localization via 2D Normalizing Flows. + + Args: + input_size (tuple[int, int]): Model input size. + backbone (str): Backbone CNN network + pre_trained (bool, optional): Boolean to check whether to use a pre_trained backbone. + Defaults to ``True``. + flow_steps (int, optional): Flow steps. + Defaults to ``8``. + conv3x3_only (bool, optinoal): Use only conv3x3 in fast_flow model. + Defaults to ``False``. + hidden_ratio (float, optional): Ratio to calculate hidden var channels. + Defaults to ``1.0``. + + Raises: + ValueError: When the backbone is not supported. + """ + + def __init__( + self, + input_size: tuple[int, int], + backbone: str, + pre_trained: bool = True, + flow_steps: int = 8, + conv3x3_only: bool = False, + hidden_ratio: float = 1.0, + ) -> None: + super().__init__() + + self.input_size = input_size + + if backbone in ("cait_m48_448", "deit_base_distilled_patch16_384"): + self.feature_extractor = timm.create_model(backbone, pretrained=pre_trained) + channels = [768] + scales = [16] + elif backbone in ("resnet18", "wide_resnet50_2"): + self.feature_extractor = timm.create_model( + backbone, + pretrained=pre_trained, + features_only=True, + out_indices=[1, 2, 3], + ) + channels = self.feature_extractor.feature_info.channels() + scales = self.feature_extractor.feature_info.reduction() + + # for transformers, use their pretrained norm w/o grad + # for resnets, self.norms are trainable LayerNorm + self.norms = nn.ModuleList() + for channel, scale in zip(channels, scales, strict=True): + self.norms.append( + nn.LayerNorm( + [channel, int(input_size[0] / scale), int(input_size[1] / scale)], + elementwise_affine=True, + ), + ) + else: + msg = ( + f"Backbone {backbone} is not supported. List of available backbones are " + "[cait_m48_448, deit_base_distilled_patch16_384, resnet18, wide_resnet50_2]." + ) + raise ValueError(msg) + + for parameter in self.feature_extractor.parameters(): + parameter.requires_grad = False + + self.fast_flow_blocks = nn.ModuleList() + for channel, scale in zip(channels, scales, strict=True): + self.fast_flow_blocks.append( + create_fast_flow_block( + input_dimensions=[channel, int(input_size[0] / scale), int(input_size[1] / scale)], + conv3x3_only=conv3x3_only, + hidden_ratio=hidden_ratio, + flow_steps=flow_steps, + ), + ) + self.anomaly_map_generator = AnomalyMapGenerator(input_size=input_size) + + def forward(self, input_tensor: torch.Tensor) -> torch.Tensor | list[torch.Tensor] | tuple[list[torch.Tensor]]: + """Forward-Pass the input to the FastFlow Model. + + Args: + input_tensor (torch.Tensor): Input tensor. + + Returns: + Tensor | list[torch.Tensor] | tuple[list[torch.Tensor]]: During training, return + (hidden_variables, log-of-the-jacobian-determinants). + During the validation/test, return the anomaly map. + """ + return_val: torch.Tensor | list[torch.Tensor] | tuple[list[torch.Tensor]] + + self.feature_extractor.eval() + if isinstance(self.feature_extractor, VisionTransformer): + features = self._get_vit_features(input_tensor) + elif isinstance(self.feature_extractor, Cait): + features = self._get_cait_features(input_tensor) + else: + features = self._get_cnn_features(input_tensor) + + # Compute the hidden variable f: X -> Z and log-likelihood of the jacobian + # (See Section 3.3 in the paper.) + # NOTE: output variable has z, and jacobian tuple for each fast-flow blocks. + hidden_variables: list[torch.Tensor] = [] + log_jacobians: list[torch.Tensor] = [] + for fast_flow_block, feature in zip(self.fast_flow_blocks, features, strict=True): + hidden_variable, log_jacobian = fast_flow_block(feature) + hidden_variables.append(hidden_variable) + log_jacobians.append(log_jacobian) + + return_val = (hidden_variables, log_jacobians) + + if not self.training: + return_val = self.anomaly_map_generator(hidden_variables) + + return return_val + + def _get_cnn_features(self, input_tensor: torch.Tensor) -> list[torch.Tensor]: + """Get CNN-based features. + + Args: + input_tensor (torch.Tensor): Input Tensor. + + Returns: + list[torch.Tensor]: List of features. + """ + features = self.feature_extractor(input_tensor) + return [self.norms[i](feature) for i, feature in enumerate(features)] + + def _get_cait_features(self, input_tensor: torch.Tensor) -> list[torch.Tensor]: + """Get Class-Attention-Image-Transformers (CaiT) features. + + Args: + input_tensor (torch.Tensor): Input Tensor. + + Returns: + list[torch.Tensor]: List of features. + """ + feature = self.feature_extractor.patch_embed(input_tensor) + feature = feature + self.feature_extractor.pos_embed + feature = self.feature_extractor.pos_drop(feature) + for i in range(41): # paper Table 6. Block Index = 40 + feature = self.feature_extractor.blocks[i](feature) + batch_size, _, num_channels = feature.shape + feature = self.feature_extractor.norm(feature) + feature = feature.permute(0, 2, 1) + feature = feature.reshape(batch_size, num_channels, self.input_size[0] // 16, self.input_size[1] // 16) + return [feature] + + def _get_vit_features(self, input_tensor: torch.Tensor) -> list[torch.Tensor]: + """Get Vision Transformers (ViT) features. + + Args: + input_tensor (torch.Tensor): Input Tensor. + + Returns: + list[torch.Tensor]: List of features. + """ + feature = self.feature_extractor.patch_embed(input_tensor) + cls_token = self.feature_extractor.cls_token.expand(feature.shape[0], -1, -1) + if self.feature_extractor.dist_token is None: + feature = torch.cat((cls_token, feature), dim=1) + else: + feature = torch.cat( + ( + cls_token, + self.feature_extractor.dist_token.expand(feature.shape[0], -1, -1), + feature, + ), + dim=1, + ) + feature = self.feature_extractor.pos_drop(feature + self.feature_extractor.pos_embed) + for i in range(8): # paper Table 6. Block Index = 7 + feature = self.feature_extractor.blocks[i](feature) + feature = self.feature_extractor.norm(feature) + feature = feature[:, 2:, :] + batch_size, _, num_channels = feature.shape + feature = feature.permute(0, 2, 1) + feature = feature.reshape(batch_size, num_channels, self.input_size[0] // 16, self.input_size[1] // 16) + return [feature] diff --git a/anomalib/models/image/ganomaly/README.md b/anomalib/models/image/ganomaly/README.md new file mode 100644 index 0000000000000000000000000000000000000000..9ffa5376f390b129b018cb334912b3e370003f90 --- /dev/null +++ b/anomalib/models/image/ganomaly/README.md @@ -0,0 +1,37 @@ +# GANomaly: Semi-Supervised Anomaly Detection via Adversarial Training + +This is the implementation of the [GANomaly](https://arxiv.org/abs/1805.06725) paper. + +Model Type: Classification + +## Description + +GANomaly uses the conditional GAN approach to train a Generator to produce images of the normal data. This Generator consists of an encoder-decoder-encoder architecture to generate the normal images. The distance between the latent vector $z$ between the first encoder-decoder and the output vector $\hat{z}$ is minimized during training. + +The key idea here is that, during inference, when an anomalous image is passed through the first encoder the latent vector $z$ will not be able to capture the data correctly. This would leave to poor reconstruction $\hat{x}$ thus resulting in a very different $\hat{z}$. The difference between $z$ and $\hat{z}$ gives the anomaly score. + +## Architecture + +![GANomaly Architecture](/docs/source/images/ganomaly/architecture.jpg "GANomaly Architecture") + +## Usage + +`python tools/train.py --model ganomaly` + +## Benchmark + +All results gathered with seed `42`. + +## [MVTec AD Dataset](https://www.mvtec.com/company/research/datasets/mvtec-ad) + +### Image-Level AUC + +| | Avg | Carpet | Grid | Leather | Tile | Wood | Bottle | Cable | Capsule | Hazelnut | Metal Nut | Pill | Screw | Toothbrush | Transistor | Zipper | +| --- | :---: | :----: | :---: | :-----: | :---: | :---: | :----: | :---: | :-----: | :------: | :-------: | :---: | :---: | :--------: | :--------: | :----: | +| | 0.421 | 0.203 | 0.404 | 0.413 | 0.408 | 0.744 | 0.251 | 0.457 | 0.682 | 0.537 | 0.270 | 0.472 | 0.231 | 0.372 | 0.440 | 0.434 | + +### Image F1 Score + +| | Avg | Carpet | Grid | Leather | Tile | Wood | Bottle | Cable | Capsule | Hazelnut | Metal Nut | Pill | Screw | Toothbrush | Transistor | Zipper | +| --- | :---: | :----: | :---: | :-----: | :---: | :---: | :----: | :---: | :-----: | :------: | :-------: | :---: | :---: | :--------: | :--------: | :----: | +| | 0.834 | 0.864 | 0.844 | 0.852 | 0.836 | 0.863 | 0.863 | 0.760 | 0.905 | 0.777 | 0.894 | 0.916 | 0.853 | 0.833 | 0.571 | 0.881 | diff --git a/anomalib/models/image/ganomaly/__init__.py b/anomalib/models/image/ganomaly/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..ec872b077d392749fbc4a31532ec37de8a8ac738 --- /dev/null +++ b/anomalib/models/image/ganomaly/__init__.py @@ -0,0 +1,8 @@ +"""GANomaly Model.""" + +# Copyright (C) 2022-2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +from .lightning_model import Ganomaly + +__all__ = ["Ganomaly"] diff --git a/anomalib/models/image/ganomaly/lightning_model.py b/anomalib/models/image/ganomaly/lightning_model.py new file mode 100644 index 0000000000000000000000000000000000000000..3088780897e7bdf286c5444141786a975f5da831 --- /dev/null +++ b/anomalib/models/image/ganomaly/lightning_model.py @@ -0,0 +1,260 @@ +"""GANomaly: Semi-Supervised Anomaly Detection via Adversarial Training. + +https://arxiv.org/abs/1805.06725 +""" + +# Copyright (C) 2022-2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + + +import logging +from typing import Any + +import torch +from lightning.pytorch.utilities.types import STEP_OUTPUT +from torch import optim + +from anomalib import LearningType +from anomalib.models.components import AnomalyModule + +from .loss import DiscriminatorLoss, GeneratorLoss +from .torch_model import GanomalyModel + +logger = logging.getLogger(__name__) + + +class Ganomaly(AnomalyModule): + """PL Lightning Module for the GANomaly Algorithm. + + Args: + batch_size (int): Batch size. + Defaults to ``32``. + n_features (int): Number of features layers in the CNNs. + Defaults to ``64``. + latent_vec_size (int): Size of autoencoder latent vector. + Defaults to ``100``. + extra_layers (int, optional): Number of extra layers for encoder/decoder. + Defaults to ``0``. + add_final_conv_layer (bool, optional): Add convolution layer at the end. + Defaults to ``True``. + wadv (int, optional): Weight for adversarial loss. + Defaults to ``1``. + wcon (int, optional): Image regeneration weight. + Defaults to ``50``. + wenc (int, optional): Latent vector encoder weight. + Defaults to ``1``. + lr (float, optional): Learning rate. + Defaults to ``0.0002``. + beta1 (float, optional): Adam beta1. + Defaults to ``0.5``. + beta2 (float, optional): Adam beta2. + Defaults to ``0.999``. + """ + + def __init__( + self, + batch_size: int = 32, + n_features: int = 64, + latent_vec_size: int = 100, + extra_layers: int = 0, + add_final_conv_layer: bool = True, + wadv: int = 1, + wcon: int = 50, + wenc: int = 1, + lr: float = 0.0002, + beta1: float = 0.5, + beta2: float = 0.999, + ) -> None: + super().__init__() + + self.n_features = n_features + self.latent_vec_size = latent_vec_size + self.extra_layers = extra_layers + self.add_final_conv_layer = add_final_conv_layer + + self.real_label = torch.ones(size=(batch_size,), dtype=torch.float32) + self.fake_label = torch.zeros(size=(batch_size,), dtype=torch.float32) + + self.min_scores: torch.Tensor = torch.tensor(float("inf"), dtype=torch.float32) # pylint: disable=not-callable + self.max_scores: torch.Tensor = torch.tensor(float("-inf"), dtype=torch.float32) # pylint: disable=not-callable + + self.generator_loss = GeneratorLoss(wadv, wcon, wenc) + self.discriminator_loss = DiscriminatorLoss() + self.automatic_optimization = False + + # TODO(ashwinvaidya17): LR should be part of optimizer in config.yaml! + # CVS-122670 + self.learning_rate = lr + self.beta1 = beta1 + self.beta2 = beta2 + + self.model: GanomalyModel + + def _setup(self) -> None: + if self.input_size is None: + msg = "GANomaly needs input size to build torch model." + raise ValueError(msg) + + self.model = GanomalyModel( + input_size=self.input_size, + num_input_channels=3, + n_features=self.n_features, + latent_vec_size=self.latent_vec_size, + extra_layers=self.extra_layers, + add_final_conv_layer=self.add_final_conv_layer, + ) + + def _reset_min_max(self) -> None: + """Reset min_max scores.""" + self.min_scores = torch.tensor(float("inf"), dtype=torch.float32) # pylint: disable=not-callable + self.max_scores = torch.tensor(float("-inf"), dtype=torch.float32) # pylint: disable=not-callable + + def configure_optimizers(self) -> list[optim.Optimizer]: + """Configure optimizers for each decoder. + + Returns: + Optimizer: Adam optimizer for each decoder + """ + optimizer_d = optim.Adam( + self.model.discriminator.parameters(), + lr=self.learning_rate, + betas=(self.beta1, self.beta2), + ) + optimizer_g = optim.Adam( + self.model.generator.parameters(), + lr=self.learning_rate, + betas=(self.beta1, self.beta2), + ) + return [optimizer_d, optimizer_g] + + def training_step( + self, + batch: dict[str, str | torch.Tensor], + batch_idx: int, + ) -> STEP_OUTPUT: + """Perform the training step. + + Args: + batch (dict[str, str | torch.Tensor]): Input batch containing images. + batch_idx (int): Batch index. + optimizer_idx (int): Optimizer which is being called for current training step. + + Returns: + STEP_OUTPUT: Loss + """ + del batch_idx # `batch_idx` variables is not used. + d_opt, g_opt = self.optimizers() + + # forward pass + padded, fake, latent_i, latent_o = self.model(batch["image"]) + pred_real, _ = self.model.discriminator(padded) + + # generator update + pred_fake, _ = self.model.discriminator(fake) + g_loss = self.generator_loss(latent_i, latent_o, padded, fake, pred_real, pred_fake) + + g_opt.zero_grad() + self.manual_backward(g_loss, retain_graph=True) + g_opt.step() + + # discrimator update + pred_fake, _ = self.model.discriminator(fake.detach()) + d_loss = self.discriminator_loss(pred_real, pred_fake) + + d_opt.zero_grad() + self.manual_backward(d_loss) + d_opt.step() + + self.log_dict( + {"generator_loss": g_loss.item(), "discriminator_loss": d_loss.item()}, + on_epoch=True, + prog_bar=True, + logger=True, + ) + return {"generator_loss": g_loss, "discriminator_loss": d_loss} + + def on_validation_start(self) -> None: + """Reset min and max values for current validation epoch.""" + self._reset_min_max() + return super().on_validation_start() + + def validation_step(self, batch: dict[str, str | torch.Tensor], *args, **kwargs) -> STEP_OUTPUT: + """Update min and max scores from the current step. + + Args: + batch (dict[str, str | torch.Tensor]): Predicted difference between z and z_hat. + args: Additional arguments. + kwargs: Additional keyword arguments. + + Returns: + (STEP_OUTPUT): Output predictions. + """ + del args, kwargs # Unused arguments. + + batch["pred_scores"] = self.model(batch["image"]) + self.max_scores = max(self.max_scores, torch.max(batch["pred_scores"])) + self.min_scores = min(self.min_scores, torch.min(batch["pred_scores"])) + return batch + + def on_validation_batch_end( + self, + outputs: STEP_OUTPUT, + batch: Any, # noqa: ANN401 + batch_idx: int, + dataloader_idx: int = 0, + ) -> None: + """Normalize outputs based on min/max values.""" + outputs["pred_scores"] = self._normalize(outputs["pred_scores"]) + super().on_validation_batch_end(outputs, batch, batch_idx, dataloader_idx=dataloader_idx) + + def on_test_start(self) -> None: + """Reset min max values before test batch starts.""" + self._reset_min_max() + return super().on_test_start() + + def test_step(self, batch: dict[str, str | torch.Tensor], batch_idx: int, *args, **kwargs) -> STEP_OUTPUT: + """Update min and max scores from the current step.""" + del args, kwargs # Unused arguments. + + super().test_step(batch, batch_idx) + self.max_scores = max(self.max_scores, torch.max(batch["pred_scores"])) + self.min_scores = min(self.min_scores, torch.min(batch["pred_scores"])) + return batch + + def on_test_batch_end( + self, + outputs: STEP_OUTPUT, + batch: Any, # noqa: ANN401 + batch_idx: int, + dataloader_idx: int = 0, + ) -> None: + """Normalize outputs based on min/max values.""" + outputs["pred_scores"] = self._normalize(outputs["pred_scores"]) + super().on_test_batch_end(outputs, batch, batch_idx, dataloader_idx=dataloader_idx) + + def _normalize(self, scores: torch.Tensor) -> torch.Tensor: + """Normalize the scores based on min/max of entire dataset. + + Args: + scores (torch.Tensor): Un-normalized scores. + + Returns: + Tensor: Normalized scores. + """ + return (scores - self.min_scores.to(scores.device)) / ( + self.max_scores.to(scores.device) - self.min_scores.to(scores.device) + ) + + @property + def trainer_arguments(self) -> dict[str, Any]: + """Return GANomaly trainer arguments.""" + return {"gradient_clip_val": 0, "num_sanity_val_steps": 0} + + @property + def learning_type(self) -> LearningType: + """Return the learning type of the model. + + Returns: + LearningType: Learning type of the model. + """ + return LearningType.ONE_CLASS diff --git a/anomalib/models/image/ganomaly/loss.py b/anomalib/models/image/ganomaly/loss.py new file mode 100644 index 0000000000000000000000000000000000000000..bd648ab97abf062432122b73abd2955050eb7e9e --- /dev/null +++ b/anomalib/models/image/ganomaly/loss.py @@ -0,0 +1,89 @@ +"""Loss function for the GANomaly Model Implementation.""" + +# Copyright (C) 2022-2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + + +import torch +from torch import nn + + +class GeneratorLoss(nn.Module): + """Generator loss for the GANomaly model. + + Args: + wadv (int, optional): Weight for adversarial loss. + Defaults to ``1``. + wcon (int, optional): Image regeneration weight. + Defaults to ``50``. + wenc (int, optional): Latent vector encoder weight. + Defaults to ``1``. + """ + + def __init__(self, wadv: int = 1, wcon: int = 50, wenc: int = 1) -> None: + super().__init__() + + self.loss_enc = nn.SmoothL1Loss() + self.loss_adv = nn.MSELoss() + self.loss_con = nn.L1Loss() + + self.wadv = wadv + self.wcon = wcon + self.wenc = wenc + + def forward( + self, + latent_i: torch.Tensor, + latent_o: torch.Tensor, + images: torch.Tensor, + fake: torch.Tensor, + pred_real: torch.Tensor, + pred_fake: torch.Tensor, + ) -> torch.Tensor: + """Compute the loss for a batch. + + Args: + latent_i (torch.Tensor): Latent features of the first encoder. + latent_o (torch.Tensor): Latent features of the second encoder. + images (torch.Tensor): Real image that served as input of the generator. + fake (torch.Tensor): Generated image. + pred_real (torch.Tensor): Discriminator predictions for the real image. + pred_fake (torch.Tensor): Discriminator predictions for the fake image. + + Returns: + Tensor: The computed generator loss. + """ + error_enc = self.loss_enc(latent_i, latent_o) + error_con = self.loss_con(images, fake) + error_adv = self.loss_adv(pred_real, pred_fake) + + return error_adv * self.wadv + error_con * self.wcon + error_enc * self.wenc + + +class DiscriminatorLoss(nn.Module): + """Discriminator loss for the GANomaly model.""" + + def __init__(self) -> None: + super().__init__() + + self.loss_bce = nn.BCELoss() + + def forward(self, pred_real: torch.Tensor, pred_fake: torch.Tensor) -> torch.Tensor: + """Compute the loss for a predicted batch. + + Args: + pred_real (torch.Tensor): Discriminator predictions for the real image. + pred_fake (torch.Tensor): Discriminator predictions for the fake image. + + Returns: + Tensor: The computed discriminator loss. + """ + error_discriminator_real = self.loss_bce( + pred_real, + torch.ones(size=pred_real.shape, dtype=torch.float32, device=pred_real.device), + ) + error_discriminator_fake = self.loss_bce( + pred_fake, + torch.zeros(size=pred_fake.shape, dtype=torch.float32, device=pred_fake.device), + ) + return (error_discriminator_fake + error_discriminator_real) * 0.5 diff --git a/anomalib/models/image/ganomaly/torch_model.py b/anomalib/models/image/ganomaly/torch_model.py new file mode 100644 index 0000000000000000000000000000000000000000..320faf3d5bebeb6ce2529bf1e9415fa3b12d776c --- /dev/null +++ b/anomalib/models/image/ganomaly/torch_model.py @@ -0,0 +1,369 @@ +"""Torch models defining encoder, decoder, Generator and Discriminator. + +Code adapted from https://github.com/samet-akcay/ganomaly. +""" + +# Copyright (c) 2018-2022 Samet Akcay, Durham University, UK +# SPDX-License-Identifier: MIT +# +# Copyright (C) 2020-2022 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + + +import math + +import torch +from torch import nn + +from anomalib.data.utils.image import pad_nextpow2 + + +class Encoder(nn.Module): + """Encoder Network. + + Args: + input_size (tuple[int, int]): Size of input image + latent_vec_size (int): Size of latent vector z + num_input_channels (int): Number of input channels in the image + n_features (int): Number of features per convolution layer + extra_layers (int): Number of extra layers since the network uses only a single encoder layer by default. + Defaults to ``0``. + add_final_conv_layer (bool): Add a final convolution layer in the encoder. + Defaults to ``True``. + """ + + def __init__( + self, + input_size: tuple[int, int], + latent_vec_size: int, + num_input_channels: int, + n_features: int, + extra_layers: int = 0, + add_final_conv_layer: bool = True, + ) -> None: + super().__init__() + + self.input_layers = nn.Sequential() + self.input_layers.add_module( + f"initial-conv-{num_input_channels}-{n_features}", + nn.Conv2d(num_input_channels, n_features, kernel_size=4, stride=2, padding=4, bias=False), + ) + self.input_layers.add_module(f"initial-relu-{n_features}", nn.LeakyReLU(0.2, inplace=True)) + + # Extra Layers + self.extra_layers = nn.Sequential() + + for layer in range(extra_layers): + self.extra_layers.add_module( + f"extra-layers-{layer}-{n_features}-conv", + nn.Conv2d(n_features, n_features, kernel_size=3, stride=1, padding=1, bias=False), + ) + self.extra_layers.add_module(f"extra-layers-{layer}-{n_features}-batchnorm", nn.BatchNorm2d(n_features)) + self.extra_layers.add_module(f"extra-layers-{layer}-{n_features}-relu", nn.LeakyReLU(0.2, inplace=True)) + + # Create pyramid features to reach latent vector + self.pyramid_features = nn.Sequential() + pyramid_dim = min(*input_size) // 2 # Use the smaller dimension to create pyramid. + while pyramid_dim > 4: + in_features = n_features + out_features = n_features * 2 + self.pyramid_features.add_module( + f"pyramid-{in_features}-{out_features}-conv", + nn.Conv2d(in_features, out_features, kernel_size=4, stride=2, padding=1, bias=False), + ) + self.pyramid_features.add_module(f"pyramid-{out_features}-batchnorm", nn.BatchNorm2d(out_features)) + self.pyramid_features.add_module(f"pyramid-{out_features}-relu", nn.LeakyReLU(0.2, inplace=True)) + n_features = out_features + pyramid_dim = pyramid_dim // 2 + + # Final conv + if add_final_conv_layer: + self.final_conv_layer = nn.Conv2d( + n_features, + latent_vec_size, + kernel_size=4, + stride=1, + padding=0, + bias=False, + ) + + def forward(self, input_tensor: torch.Tensor) -> torch.Tensor: + """Return latent vectors.""" + output = self.input_layers(input_tensor) + output = self.extra_layers(output) + output = self.pyramid_features(output) + if self.final_conv_layer is not None: + output = self.final_conv_layer(output) + + return output + + +class Decoder(nn.Module): + """Decoder Network. + + Args: + input_size (tuple[int, int]): Size of input image + latent_vec_size (int): Size of latent vector z + num_input_channels (int): Number of input channels in the image + n_features (int): Number of features per convolution layer + extra_layers (int): Number of extra layers since the network uses only a single encoder layer by default. + Defaults to ``0``. + """ + + def __init__( + self, + input_size: tuple[int, int], + latent_vec_size: int, + num_input_channels: int, + n_features: int, + extra_layers: int = 0, + ) -> None: + super().__init__() + + self.latent_input = nn.Sequential() + + # Calculate input channel size to recreate inverse pyramid + exp_factor = math.ceil(math.log(min(input_size) // 2, 2)) - 2 + n_input_features = n_features * (2**exp_factor) + + # CNN layer for latent vector input + self.latent_input.add_module( + f"initial-{latent_vec_size}-{n_input_features}-convt", + nn.ConvTranspose2d( + latent_vec_size, + n_input_features, + kernel_size=4, + stride=1, + padding=0, + bias=False, + ), + ) + self.latent_input.add_module(f"initial-{n_input_features}-batchnorm", nn.BatchNorm2d(n_input_features)) + self.latent_input.add_module(f"initial-{n_input_features}-relu", nn.ReLU(inplace=True)) + + # Create inverse pyramid + self.inverse_pyramid = nn.Sequential() + pyramid_dim = min(*input_size) // 2 # Use the smaller dimension to create pyramid. + while pyramid_dim > 4: + in_features = n_input_features + out_features = n_input_features // 2 + self.inverse_pyramid.add_module( + f"pyramid-{in_features}-{out_features}-convt", + nn.ConvTranspose2d( + in_features, + out_features, + kernel_size=4, + stride=2, + padding=1, + bias=False, + ), + ) + self.inverse_pyramid.add_module(f"pyramid-{out_features}-batchnorm", nn.BatchNorm2d(out_features)) + self.inverse_pyramid.add_module(f"pyramid-{out_features}-relu", nn.ReLU(inplace=True)) + n_input_features = out_features + pyramid_dim = pyramid_dim // 2 + + # Extra Layers + self.extra_layers = nn.Sequential() + for layer in range(extra_layers): + self.extra_layers.add_module( + f"extra-layers-{layer}-{n_input_features}-conv", + nn.Conv2d(n_input_features, n_input_features, kernel_size=3, stride=1, padding=1, bias=False), + ) + self.extra_layers.add_module( + f"extra-layers-{layer}-{n_input_features}-batchnorm", + nn.BatchNorm2d(n_input_features), + ) + self.extra_layers.add_module( + f"extra-layers-{layer}-{n_input_features}-relu", + nn.LeakyReLU(0.2, inplace=True), + ) + + # Final layers + self.final_layers = nn.Sequential() + self.final_layers.add_module( + f"final-{n_input_features}-{num_input_channels}-convt", + nn.ConvTranspose2d( + n_input_features, + num_input_channels, + kernel_size=4, + stride=2, + padding=1, + bias=False, + ), + ) + self.final_layers.add_module(f"final-{num_input_channels}-tanh", nn.Tanh()) + + def forward(self, input_tensor: torch.Tensor) -> torch.Tensor: + """Return generated image.""" + output = self.latent_input(input_tensor) + output = self.inverse_pyramid(output) + output = self.extra_layers(output) + return self.final_layers(output) + + +class Discriminator(nn.Module): + """Discriminator. + + Made of only one encoder layer which takes x and x_hat to produce a score. + + Args: + input_size (tuple[int, int]): Input image size. + num_input_channels (int): Number of image channels. + n_features (int): Number of feature maps in each convolution layer. + extra_layers (int, optional): Add extra intermediate layers. + Defaults to ``0``. + """ + + def __init__( + self, + input_size: tuple[int, int], + num_input_channels: int, + n_features: int, + extra_layers: int = 0, + ) -> None: + super().__init__() + encoder = Encoder(input_size, 1, num_input_channels, n_features, extra_layers) + layers = [] + for block in encoder.children(): + if isinstance(block, nn.Sequential): + layers.extend(list(block.children())) + else: + layers.append(block) + + self.features = nn.Sequential(*layers[:-1]) + self.classifier = nn.Sequential(layers[-1]) + self.classifier.add_module("Sigmoid", nn.Sigmoid()) + + def forward(self, input_tensor: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + """Return class of object and features.""" + features = self.features(input_tensor) + classifier = self.classifier(features) + classifier = classifier.view(-1, 1).squeeze(1) + return classifier, features + + +class Generator(nn.Module): + """Generator model. + + Made of an encoder-decoder-encoder architecture. + + Args: + input_size (tuple[int, int]): Size of input data. + latent_vec_size (int): Dimension of latent vector produced between the first encoder-decoder. + num_input_channels (int): Number of channels in input image. + n_features (int): Number of feature maps in each convolution layer. + extra_layers (int, optional): Extra intermediate layers in the encoder/decoder. + Defaults to ``0``. + add_final_conv_layer (bool, optional): Add a final convolution layer in the decoder. + Defaults to ``True``. + """ + + def __init__( + self, + input_size: tuple[int, int], + latent_vec_size: int, + num_input_channels: int, + n_features: int, + extra_layers: int = 0, + add_final_conv_layer: bool = True, + ) -> None: + super().__init__() + self.encoder1 = Encoder( + input_size, + latent_vec_size, + num_input_channels, + n_features, + extra_layers, + add_final_conv_layer, + ) + self.decoder = Decoder(input_size, latent_vec_size, num_input_channels, n_features, extra_layers) + self.encoder2 = Encoder( + input_size, + latent_vec_size, + num_input_channels, + n_features, + extra_layers, + add_final_conv_layer, + ) + + def forward(self, input_tensor: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Return generated image and the latent vectors.""" + latent_i = self.encoder1(input_tensor) + gen_image = self.decoder(latent_i) + latent_o = self.encoder2(gen_image) + return gen_image, latent_i, latent_o + + +class GanomalyModel(nn.Module): + """Ganomaly Model. + + Args: + input_size (tuple[int, int]): Input dimension. + num_input_channels (int): Number of input channels. + n_features (int): Number of features layers in the CNNs. + latent_vec_size (int): Size of autoencoder latent vector. + extra_layers (int, optional): Number of extra layers for encoder/decoder. + Defaults to ``0``. + add_final_conv_layer (bool, optional): Add convolution layer at the end. + Defaults to ``True``. + """ + + def __init__( + self, + input_size: tuple[int, int], + num_input_channels: int, + n_features: int, + latent_vec_size: int, + extra_layers: int = 0, + add_final_conv_layer: bool = True, + ) -> None: + super().__init__() + self.generator: Generator = Generator( + input_size=input_size, + latent_vec_size=latent_vec_size, + num_input_channels=num_input_channels, + n_features=n_features, + extra_layers=extra_layers, + add_final_conv_layer=add_final_conv_layer, + ) + self.discriminator: Discriminator = Discriminator( + input_size=input_size, + num_input_channels=num_input_channels, + n_features=n_features, + extra_layers=extra_layers, + ) + self.weights_init(self.generator) + self.weights_init(self.discriminator) + + @staticmethod + def weights_init(module: nn.Module) -> None: + """Initialize DCGAN weights. + + Args: + module (nn.Module): [description] + """ + classname = module.__class__.__name__ + if classname.find("Conv") != -1: + nn.init.normal_(module.weight.data, 0.0, 0.02) + elif classname.find("BatchNorm") != -1: + nn.init.normal_(module.weight.data, 1.0, 0.02) + nn.init.constant_(module.bias.data, 0) + + def forward( + self, + batch: torch.Tensor, + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor] | torch.Tensor: + """Get scores for batch. + + Args: + batch (torch.Tensor): Images + + Returns: + Tensor: Regeneration scores. + """ + padded_batch = pad_nextpow2(batch) + fake, latent_i, latent_o = self.generator(padded_batch) + if self.training: + return padded_batch, fake, latent_i, latent_o + return torch.mean(torch.pow((latent_i - latent_o), 2), dim=1).view(-1) # convert nx1x1 to n diff --git a/anomalib/models/image/padim/README.md b/anomalib/models/image/padim/README.md new file mode 100644 index 0000000000000000000000000000000000000000..35aaeec9952b8a3ba62a10f285a44a663ed94638 --- /dev/null +++ b/anomalib/models/image/padim/README.md @@ -0,0 +1,54 @@ +# PaDiM: A Patch Distribution Modeling Framework for Anomaly Detection and Localization + +This is the implementation of the [PaDiM](https://arxiv.org/pdf/2011.08785.pdf) paper. + +Model Type: Segmentation + +## Description + +PaDiM is a patch based algorithm. It relies on a pre-trained CNN feature extractor. The image is broken into patches and embeddings are extracted from each patch using different layers of the feature extractors. The activation vectors from different layers are concatenated to get embedding vectors carrying information from different semantic levels and resolutions. This helps encode fine grained and global contexts. However, since the generated embedding vectors may carry redundant information, dimensions are reduced using random selection. A multivariate gaussian distribution is generated for each patch embedding across the entire training batch. Thus, for each patch of the set of training images, we have a different multivariate gaussian distribution. These gaussian distributions are represented as a matrix of gaussian parameters. + +During inference, Mahalanobis distance is used to score each patch position of the test image. It uses the inverse of the covariance matrix calculated for the patch during training. The matrix of Mahalanobis distances forms the anomaly map with higher scores indicating anomalous regions. + +## Architecture + +![PaDiM Architecture](/docs/source/images/padim/architecture.jpg "PaDiM Architecture") + +## Usage + +`python tools/train.py --model padim` + +## Benchmark + +All results gathered with seed `42`. + +## [MVTec AD Dataset](https://www.mvtec.com/company/research/datasets/mvtec-ad) + +### Image-Level AUC + +| | Avg | Carpet | Grid | Leather | Tile | Wood | Bottle | Cable | Capsule | Hazelnut | Metal Nut | Pill | Screw | Toothbrush | Transistor | Zipper | +| -------------- | :---: | :----: | :---: | :-----: | :---: | :---: | :----: | :---: | :-----: | :------: | :-------: | :---: | :---: | :--------: | :--------: | :----: | +| ResNet-18 | 0.891 | 0.945 | 0.857 | 0.982 | 0.950 | 0.976 | 0.994 | 0.844 | 0.901 | 0.750 | 0.961 | 0.863 | 0.759 | 0.889 | 0.920 | 0.780 | +| Wide ResNet-50 | 0.950 | 0.995 | 0.942 | 1.0 | 0.974 | 0.993 | 0.999 | 0.878 | 0.927 | 0.964 | 0.989 | 0.939 | 0.845 | 0.942 | 0.976 | 0.882 | + +### Pixel-Level AUC + +| | Avg | Carpet | Grid | Leather | Tile | Wood | Bottle | Cable | Capsule | Hazelnut | Metal Nut | Pill | Screw | Toothbrush | Transistor | Zipper | +| -------------- | :---: | :----: | :---: | :-----: | :---: | :---: | :----: | :---: | :-----: | :------: | :-------: | :---: | :---: | :--------: | :--------: | :----: | +| ResNet-18 | 0.968 | 0.984 | 0.918 | 0.994 | 0.934 | 0.947 | 0.983 | 0.965 | 0.984 | 0.978 | 0.970 | 0.957 | 0.978 | 0.988 | 0.968 | 0.979 | +| Wide ResNet-50 | 0.979 | 0.991 | 0.970 | 0.993 | 0.955 | 0.957 | 0.985 | 0.970 | 0.988 | 0.985 | 0.982 | 0.966 | 0.988 | 0.991 | 0.976 | 0.986 | + +### Image F1 Score + +| | Avg | Carpet | Grid | Leather | Tile | Wood | Bottle | Cable | Capsule | Hazelnut | Metal Nut | Pill | Screw | Toothbrush | Transistor | Zipper | +| -------------- | :---: | :----: | :---: | :-----: | :---: | :---: | :----: | :---: | :-----: | :------: | :-------: | :---: | :---: | :--------: | :--------: | :----: | +| ResNet-18 | 0.916 | 0.930 | 0.893 | 0.984 | 0.934 | 0.952 | 0.976 | 0.858 | 0.960 | 0.836 | 0.974 | 0.932 | 0.879 | 0.923 | 0.796 | 0.915 | +| Wide ResNet-50 | 0.951 | 0.989 | 0.930 | 1.0 | 0.960 | 0.983 | 0.992 | 0.856 | 0.982 | 0.937 | 0.978 | 0.946 | 0.895 | 0.952 | 0.914 | 0.947 | + +### Sample Results + +![Sample Result 1](/docs/source/images/padim/results/0.png "Sample Result 1") + +![Sample Result 2](/docs/source/images/padim/results/1.png "Sample Result 2") + +![Sample Result 3](/docs/source/images/padim/results/2.png "Sample Result 3") diff --git a/anomalib/models/image/padim/__init__.py b/anomalib/models/image/padim/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..944e8f20c3376297d2487b4808f573d0dee7dc10 --- /dev/null +++ b/anomalib/models/image/padim/__init__.py @@ -0,0 +1,8 @@ +"""PADIM model.""" + +# Copyright (C) 2022-2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +from .lightning_model import Padim + +__all__ = ["Padim"] diff --git a/anomalib/models/image/padim/anomaly_map.py b/anomalib/models/image/padim/anomaly_map.py new file mode 100644 index 0000000000000000000000000000000000000000..4edd8c890b59f3e5afab3c3959df24110af5d692 --- /dev/null +++ b/anomalib/models/image/padim/anomaly_map.py @@ -0,0 +1,133 @@ +"""Anomaly Map Generator for the PaDiM model implementation.""" + +# Copyright (C) 2022-2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + + +import torch +from torch import nn +from torch.nn import functional as F # noqa: N812 + +from anomalib.models.components import GaussianBlur2d + + +class AnomalyMapGenerator(nn.Module): + """Generate Anomaly Heatmap. + + Args: + image_size (ListConfig, tuple): Size of the input image. The anomaly map is upsampled to this dimension. + sigma (int, optional): Standard deviation for Gaussian Kernel. + Defaults to ``4``. + """ + + def __init__(self, sigma: int = 4) -> None: + super().__init__() + kernel_size = 2 * int(4.0 * sigma + 0.5) + 1 + self.blur = GaussianBlur2d(kernel_size=(kernel_size, kernel_size), sigma=(sigma, sigma), channels=1) + + @staticmethod + def compute_distance(embedding: torch.Tensor, stats: list[torch.Tensor]) -> torch.Tensor: + """Compute anomaly score to the patch in position(i,j) of a test image. + + Ref: Equation (2), Section III-C of the paper. + + Args: + embedding (torch.Tensor): Embedding Vector + stats (list[torch.Tensor]): Mean and Covariance Matrix of the multivariate Gaussian distribution + + Returns: + Anomaly score of a test image via mahalanobis distance. + """ + batch, channel, height, width = embedding.shape + embedding = embedding.reshape(batch, channel, height * width) + + # calculate mahalanobis distances + mean, inv_covariance = stats + delta = (embedding - mean).permute(2, 0, 1) + + distances = (torch.matmul(delta, inv_covariance) * delta).sum(2).permute(1, 0) + distances = distances.reshape(batch, 1, height, width) + return distances.clamp(0).sqrt() + + def up_sample(self, distance: torch.Tensor, image_size: tuple[int, int] | torch.Size) -> torch.Tensor: + """Up sample anomaly score to match the input image size. + + Args: + distance (torch.Tensor): Anomaly score computed via the mahalanobis distance. + image_size (tuple[int, int] | torch.Size): Size to which the anomaly map should be upsampled. + + Returns: + Resized distance matrix matching the input image size + """ + return F.interpolate( + distance, + size=image_size, + mode="bilinear", + align_corners=False, + ) + + def smooth_anomaly_map(self, anomaly_map: torch.Tensor) -> torch.Tensor: + """Apply gaussian smoothing to the anomaly map. + + Args: + anomaly_map (torch.Tensor): Anomaly score for the test image(s). + + Returns: + Filtered anomaly scores + """ + return self.blur(anomaly_map) + + def compute_anomaly_map( + self, + embedding: torch.Tensor, + mean: torch.Tensor, + inv_covariance: torch.Tensor, + image_size: tuple[int, int] | torch.Size | None = None, + ) -> torch.Tensor: + """Compute anomaly score. + + Scores are calculated based on embedding vector, mean and inv_covariance of the multivariate gaussian + distribution. + + Args: + embedding (torch.Tensor): Embedding vector extracted from the test set. + mean (torch.Tensor): Mean of the multivariate gaussian distribution + inv_covariance (torch.Tensor): Inverse Covariance matrix of the multivariate gaussian distribution. + image_size (tuple[int, int] | torch.Size, optional): Size to which the anomaly map should be upsampled. + + Returns: + Output anomaly score. + """ + score_map = self.compute_distance( + embedding=embedding, + stats=[mean.to(embedding.device), inv_covariance.to(embedding.device)], + ) + if image_size: + score_map = self.up_sample(score_map, image_size) + return self.smooth_anomaly_map(score_map) + + def forward(self, **kwargs) -> torch.Tensor: + """Return anomaly_map. + + Expects `embedding`, `mean` and `covariance` keywords to be passed explicitly. + + Example: + >>> anomaly_map_generator = AnomalyMapGenerator(image_size=input_size) + >>> output = anomaly_map_generator(embedding=embedding, mean=mean, covariance=covariance) + + Raises: + ValueError: `embedding`. `mean` or `covariance` keys are not found + + Returns: + torch.Tensor: anomaly map + """ + if not ("embedding" in kwargs and "mean" in kwargs and "inv_covariance" in kwargs): + msg = f"Expected keys `embedding`, `mean` and `covariance`. Found {kwargs.keys()}" + raise ValueError(msg) + + embedding: torch.Tensor = kwargs["embedding"] + mean: torch.Tensor = kwargs["mean"] + inv_covariance: torch.Tensor = kwargs["inv_covariance"] + image_size: tuple[int, int] | torch.Size = kwargs.get("image_size", None) + + return self.compute_anomaly_map(embedding, mean, inv_covariance, image_size=image_size) diff --git a/anomalib/models/image/padim/lightning_model.py b/anomalib/models/image/padim/lightning_model.py new file mode 100644 index 0000000000000000000000000000000000000000..4912553291258b75e684e683d67ead57e19049cd --- /dev/null +++ b/anomalib/models/image/padim/lightning_model.py @@ -0,0 +1,133 @@ +"""PaDiM: a Patch Distribution Modeling Framework for Anomaly Detection and Localization. + +Paper https://arxiv.org/abs/2011.08785 +""" + +# Copyright (C) 2022-2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +import logging + +import torch +from lightning.pytorch.utilities.types import STEP_OUTPUT +from torchvision.transforms.v2 import Compose, Normalize, Resize, Transform + +from anomalib import LearningType +from anomalib.models.components import AnomalyModule, MemoryBankMixin + +from .torch_model import PadimModel + +logger = logging.getLogger(__name__) + +__all__ = ["Padim"] + + +class Padim(MemoryBankMixin, AnomalyModule): + """PaDiM: a Patch Distribution Modeling Framework for Anomaly Detection and Localization. + + Args: + backbone (str): Backbone CNN network + Defaults to ``resnet18``. + layers (list[str]): Layers to extract features from the backbone CNN + Defaults to ``["layer1", "layer2", "layer3"]``. + pre_trained (bool, optional): Boolean to check whether to use a pre_trained backbone. + Defaults to ``True``. + n_features (int, optional): Number of features to retain in the dimension reduction step. + Default values from the paper are available for: resnet18 (100), wide_resnet50_2 (550). + Defaults to ``None``. + """ + + def __init__( + self, + backbone: str = "resnet18", + layers: list[str] = ["layer1", "layer2", "layer3"], # noqa: B006 + pre_trained: bool = True, + n_features: int | None = None, + ) -> None: + super().__init__() + + self.model: PadimModel = PadimModel( + backbone=backbone, + pre_trained=pre_trained, + layers=layers, + n_features=n_features, + ) + + self.stats: list[torch.Tensor] = [] + self.embeddings: list[torch.Tensor] = [] + + @staticmethod + def configure_optimizers() -> None: + """PADIM doesn't require optimization, therefore returns no optimizers.""" + return + + def training_step(self, batch: dict[str, str | torch.Tensor], *args, **kwargs) -> None: + """Perform the training step of PADIM. For each batch, hierarchical features are extracted from the CNN. + + Args: + batch (dict[str, str | torch.Tensor]): Batch containing image filename, image, label and mask + args: Additional arguments. + kwargs: Additional keyword arguments. + + Returns: + Hierarchical feature map + """ + del args, kwargs # These variables are not used. + + embedding = self.model(batch["image"]) + self.embeddings.append(embedding.cpu()) + + def fit(self) -> None: + """Fit a Gaussian to the embedding collected from the training set.""" + logger.info("Aggregating the embedding extracted from the training set.") + embeddings = torch.vstack(self.embeddings) + + logger.info("Fitting a Gaussian to the embedding collected from the training set.") + self.stats = self.model.gaussian.fit(embeddings) + + def validation_step(self, batch: dict[str, str | torch.Tensor], *args, **kwargs) -> STEP_OUTPUT: + """Perform a validation step of PADIM. + + Similar to the training step, hierarchical features are extracted from the CNN for each batch. + + Args: + batch (dict[str, str | torch.Tensor]): Input batch + args: Additional arguments. + kwargs: Additional keyword arguments. + + Returns: + Dictionary containing images, features, true labels and masks. + These are required in `validation_epoch_end` for feature concatenation. + """ + del args, kwargs # These variables are not used. + + batch["anomaly_maps"] = self.model(batch["image"]) + return batch + + @property + def trainer_arguments(self) -> dict[str, int | float]: + """Return PADIM trainer arguments. + + Since the model does not require training, we limit the max_epochs to 1. + Since we need to run training epoch before validation, we also set the sanity steps to 0 + """ + return {"max_epochs": 1, "val_check_interval": 1.0, "num_sanity_val_steps": 0} + + @property + def learning_type(self) -> LearningType: + """Return the learning type of the model. + + Returns: + LearningType: Learning type of the model. + """ + return LearningType.ONE_CLASS + + def configure_transforms(self, image_size: tuple[int, int] | None = None) -> Transform: + """Default transform for Padim.""" + image_size = image_size or (256, 256) + return Compose( + [ + Resize(image_size, antialias=True), + Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), + ], + ) diff --git a/anomalib/models/image/padim/torch_model.py b/anomalib/models/image/padim/torch_model.py new file mode 100644 index 0000000000000000000000000000000000000000..f45dde1f79c6b3ae68605fb7cf22a780b3e305f4 --- /dev/null +++ b/anomalib/models/image/padim/torch_model.py @@ -0,0 +1,168 @@ +"""PyTorch model for the PaDiM model implementation.""" + +# Copyright (C) 2022-2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +from random import sample +from typing import TYPE_CHECKING + +import torch +from torch import nn +from torch.nn import functional as F # noqa: N812 + +from anomalib.models.components import MultiVariateGaussian, TimmFeatureExtractor +from anomalib.models.components.feature_extractors import dryrun_find_featuremap_dims + +from .anomaly_map import AnomalyMapGenerator + +if TYPE_CHECKING: + from anomalib.data.utils.tiler import Tiler + +# defaults from the paper +_N_FEATURES_DEFAULTS = { + "resnet18": 100, + "wide_resnet50_2": 550, +} + + +def _deduce_dims( + feature_extractor: TimmFeatureExtractor, + input_size: tuple[int, int], + layers: list[str], +) -> tuple[int, int]: + """Run a dry run to deduce the dimensions of the extracted features. + + Important: `layers` is assumed to be ordered and the first (layers[0]) + is assumed to be the layer with largest resolution. + + Returns: + tuple[int, int]: Dimensions of the extracted features: (n_dims_original, n_patches) + """ + dimensions_mapping = dryrun_find_featuremap_dims(feature_extractor, input_size, layers) + + # the first layer in `layers` has the largest resolution + first_layer_resolution = dimensions_mapping[layers[0]]["resolution"] + n_patches = torch.tensor(first_layer_resolution).prod().int().item() + + # the original embedding size is the sum of the channels of all layers + n_features_original = sum(dimensions_mapping[layer]["num_features"] for layer in layers) # type: ignore[misc] + + return n_features_original, n_patches + + +class PadimModel(nn.Module): + """Padim Module. + + Args: + layers (list[str]): Layers used for feature extraction + backbone (str, optional): Pre-trained model backbone. Defaults to "resnet18". + Defaults to ``resnet18``. + pre_trained (bool, optional): Boolean to check whether to use a pre_trained backbone. + Defaults to ``True``. + n_features (int, optional): Number of features to retain in the dimension reduction step. + Default values from the paper are available for: resnet18 (100), wide_resnet50_2 (550). + Defaults to ``None``. + """ + + def __init__( + self, + backbone: str = "resnet18", + layers: list[str] = ["layer1", "layer2", "layer3"], # noqa: B006 + pre_trained: bool = True, + n_features: int | None = None, + ) -> None: + super().__init__() + self.tiler: Tiler | None = None + + self.backbone = backbone + self.layers = layers + self.feature_extractor = TimmFeatureExtractor( + backbone=self.backbone, + layers=layers, + pre_trained=pre_trained, + ).eval() + self.n_features_original = sum(self.feature_extractor.out_dims) + self.n_features = n_features or _N_FEATURES_DEFAULTS.get(self.backbone) + if self.n_features is None: + msg = ( + f"n_features must be specified for backbone {self.backbone}. " + f"Default values are available for: {sorted(_N_FEATURES_DEFAULTS.keys())}" + ) + raise ValueError(msg) + + if not (0 < self.n_features <= self.n_features_original): + msg = f"For backbone {self.backbone}, 0 < n_features <= {self.n_features_original}, found {self.n_features}" + raise ValueError(msg) + + # Since idx is randomly selected, save it with model to get same results + self.register_buffer( + "idx", + torch.tensor(sample(range(self.n_features_original), self.n_features)), + ) + self.idx: torch.Tensor + self.loss = None + self.anomaly_map_generator = AnomalyMapGenerator() + + self.gaussian = MultiVariateGaussian() + + def forward(self, input_tensor: torch.Tensor) -> torch.Tensor: + """Forward-pass image-batch (N, C, H, W) into model to extract features. + + Args: + input_tensor: Image-batch (N, C, H, W) + input_tensor: torch.Tensor: + + Returns: + Features from single/multiple layers. + + Example: + >>> x = torch.randn(32, 3, 224, 224) + >>> features = self.extract_features(input_tensor) + >>> features.keys() + dict_keys(['layer1', 'layer2', 'layer3']) + + >>> [v.shape for v in features.values()] + [torch.Size([32, 64, 56, 56]), + torch.Size([32, 128, 28, 28]), + torch.Size([32, 256, 14, 14])] + """ + output_size = input_tensor.shape[-2:] + if self.tiler: + input_tensor = self.tiler.tile(input_tensor) + + with torch.no_grad(): + features = self.feature_extractor(input_tensor) + embeddings = self.generate_embedding(features) + + if self.tiler: + embeddings = self.tiler.untile(embeddings) + + if self.training: + output = embeddings + else: + output = self.anomaly_map_generator( + embedding=embeddings, + mean=self.gaussian.mean, + inv_covariance=self.gaussian.inv_covariance, + image_size=output_size, + ) + return output + + def generate_embedding(self, features: dict[str, torch.Tensor]) -> torch.Tensor: + """Generate embedding from hierarchical feature map. + + Args: + features (dict[str, torch.Tensor]): Hierarchical feature map from a CNN (ResNet18 or WideResnet) + + Returns: + Embedding vector + """ + embeddings = features[self.layers[0]] + for layer in self.layers[1:]: + layer_embedding = features[layer] + layer_embedding = F.interpolate(layer_embedding, size=embeddings.shape[-2:], mode="nearest") + embeddings = torch.cat((embeddings, layer_embedding), 1) + + # subsample embeddings + idx = self.idx.to(embeddings.device) + return torch.index_select(embeddings, 1, idx) diff --git a/anomalib/models/image/patchcore/README.md b/anomalib/models/image/patchcore/README.md new file mode 100644 index 0000000000000000000000000000000000000000..9d49ffd6f84797b9b324dca52206ed0be3cb171d --- /dev/null +++ b/anomalib/models/image/patchcore/README.md @@ -0,0 +1,54 @@ +# PatchCore + +This is the implementation of the [PatchCore](https://arxiv.org/pdf/2106.08265.pdf) paper. + +Model Type: Segmentation + +## Description + +The PatchCore algorithm is based on the idea that an image can be classified as anomalous as soon as a single patch is anomalous. The input image is tiled. These tiles act as patches which are fed into the neural network. It consists of a single pre-trained network which is used to extract "mid" level features patches. The "mid" level here refers to the feature extraction layer of the neural network model. Lower level features are generally too broad and higher level features are specific to the dataset the model is trained on. The features extracted during training phase are stored in a memory bank of neighbourhood aware patch level features. + +During inference this memory bank is coreset subsampled. Coreset subsampling generates a subset which best approximates the structure of the available set and allows for approximate solution finding. This subset helps reduce the search cost associated with nearest neighbour search. The anomaly score is taken as the maximum distance between the test patch in the test patch collection to each respective nearest neighbour. + +## Architecture + +![PatchCore Architecture](/docs/source/images/patchcore/architecture.jpg "PatchCore Architecture") + +## Usage + +`python tools/train.py --model patchcore` + +## Benchmark + +All results gathered with seed `42`. + +## [MVTec AD Dataset](https://www.mvtec.com/company/research/datasets/mvtec-ad) + +### Image-Level AUC + +| | Avg | Carpet | Grid | Leather | Tile | Wood | Bottle | Cable | Capsule | Hazelnut | Metal Nut | Pill | Screw | Toothbrush | Transistor | Zipper | +| -------------- | :---: | :----: | :---: | :-----: | :---: | :---: | :----: | :---: | :-----: | :------: | :-------: | :---: | :---: | :--------: | :--------: | :----: | +| Wide ResNet-50 | 0.980 | 0.984 | 0.959 | 1.000 | 1.000 | 0.989 | 1.000 | 0.990 | 0.982 | 1.000 | 0.994 | 0.924 | 0.960 | 0.933 | 1.000 | 0.982 | +| ResNet-18 | 0.973 | 0.970 | 0.947 | 1.000 | 0.997 | 0.997 | 1.000 | 0.986 | 0.965 | 1.000 | 0.991 | 0.916 | 0.943 | 0.931 | 0.996 | 0.953 | + +### Pixel-Level AUC + +| | Avg | Carpet | Grid | Leather | Tile | Wood | Bottle | Cable | Capsule | Hazelnut | Metal Nut | Pill | Screw | Toothbrush | Transistor | Zipper | +| -------------- | :---: | :----: | :---: | :-----: | :---: | :---: | :----: | :---: | :-----: | :------: | :-------: | :---: | :---: | :--------: | :--------: | :----: | +| Wide ResNet-50 | 0.980 | 0.988 | 0.968 | 0.991 | 0.961 | 0.934 | 0.984 | 0.988 | 0.988 | 0.987 | 0.989 | 0.980 | 0.989 | 0.988 | 0.981 | 0.983 | +| ResNet-18 | 0.976 | 0.986 | 0.955 | 0.990 | 0.943 | 0.933 | 0.981 | 0.984 | 0.986 | 0.986 | 0.986 | 0.974 | 0.991 | 0.988 | 0.974 | 0.983 | + +### Image F1 Score + +| | Avg | Carpet | Grid | Leather | Tile | Wood | Bottle | Cable | Capsule | Hazelnut | Metal Nut | Pill | Screw | Toothbrush | Transistor | Zipper | +| -------------- | :---: | :----: | :---: | :-----: | :---: | :---: | :----: | :---: | :-----: | :------: | :-------: | :---: | :---: | :--------: | :--------: | :----: | +| Wide ResNet-50 | 0.976 | 0.971 | 0.974 | 1.000 | 1.000 | 0.967 | 1.000 | 0.968 | 0.982 | 1.000 | 0.984 | 0.940 | 0.943 | 0.938 | 1.000 | 0.979 | +| ResNet-18 | 0.970 | 0.949 | 0.946 | 1.000 | 0.982 | 0.992 | 1.000 | 0.978 | 0.969 | 1.000 | 0.989 | 0.940 | 0.932 | 0.935 | 0.974 | 0.967 | + +### Sample Results + +![Sample Result 1](/docs/source/images/patchcore/results/0.png "Sample Result 1") + +![Sample Result 2](/docs/source/images/patchcore/results/1.png "Sample Result 2") + +![Sample Result 3](/docs/source/images/patchcore/results/2.png "Sample Result 3") diff --git a/anomalib/models/image/patchcore/__init__.py b/anomalib/models/image/patchcore/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..1e69fa85713dbbb32925bfd9e177c969e0521df1 --- /dev/null +++ b/anomalib/models/image/patchcore/__init__.py @@ -0,0 +1,8 @@ +"""PatchCore model.""" + +# Copyright (C) 2022-2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +from .lightning_model import Patchcore + +__all__ = ["Patchcore"] diff --git a/anomalib/models/image/patchcore/anomaly_map.py b/anomalib/models/image/patchcore/anomaly_map.py new file mode 100644 index 0000000000000000000000000000000000000000..8cbd2d668e0c782bef5f4db00a092fd9260946c3 --- /dev/null +++ b/anomalib/models/image/patchcore/anomaly_map.py @@ -0,0 +1,73 @@ +"""Anomaly Map Generator for the PatchCore model implementation.""" + +# Copyright (C) 2022-2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + + +import torch +from torch import nn +from torch.nn import functional as F # noqa: N812 + +from anomalib.models.components import GaussianBlur2d + + +class AnomalyMapGenerator(nn.Module): + """Generate Anomaly Heatmap. + + Args: + The anomaly map is upsampled to this dimension. + sigma (int, optional): Standard deviation for Gaussian Kernel. + Defaults to ``4``. + """ + + def __init__( + self, + sigma: int = 4, + ) -> None: + super().__init__() + kernel_size = 2 * int(4.0 * sigma + 0.5) + 1 + self.blur = GaussianBlur2d(kernel_size=(kernel_size, kernel_size), sigma=(sigma, sigma), channels=1) + + def compute_anomaly_map( + self, + patch_scores: torch.Tensor, + image_size: tuple[int, int] | torch.Size | None = None, + ) -> torch.Tensor: + """Pixel Level Anomaly Heatmap. + + Args: + patch_scores (torch.Tensor): Patch-level anomaly scores + image_size (tuple[int, int] | torch.Size, optional): Size of the input image. + The anomaly map is upsampled to this dimension. + Defaults to None. + + Returns: + Tensor: Map of the pixel-level anomaly scores + """ + if image_size is None: + anomaly_map = patch_scores + else: + anomaly_map = F.interpolate(patch_scores, size=(image_size[0], image_size[1])) + return self.blur(anomaly_map) + + def forward( + self, + patch_scores: torch.Tensor, + image_size: tuple[int, int] | torch.Size | None = None, + ) -> torch.Tensor: + """Return anomaly_map and anomaly_score. + + Args: + patch_scores (torch.Tensor): Patch-level anomaly scores + image_size (tuple[int, int] | torch.Size, optional): Size of the input image. + The anomaly map is upsampled to this dimension. + Defaults to None. + + Example: + >>> anomaly_map_generator = AnomalyMapGenerator() + >>> map = anomaly_map_generator(patch_scores=patch_scores) + + Returns: + Tensor: anomaly_map + """ + return self.compute_anomaly_map(patch_scores, image_size) diff --git a/anomalib/models/image/patchcore/lightning_model.py b/anomalib/models/image/patchcore/lightning_model.py new file mode 100644 index 0000000000000000000000000000000000000000..ca0b2081d4cea54664338b5f8cae6825c2e36102 --- /dev/null +++ b/anomalib/models/image/patchcore/lightning_model.py @@ -0,0 +1,141 @@ +"""Towards Total Recall in Industrial Anomaly Detection. + +Paper https://arxiv.org/abs/2106.08265. +""" + +# Copyright (C) 2022-2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +import logging +from collections.abc import Sequence +from typing import Any + +import torch +from lightning.pytorch.utilities.types import STEP_OUTPUT +from torchvision.transforms.v2 import CenterCrop, Compose, Normalize, Resize, Transform + +from anomalib import LearningType +from anomalib.models.components import AnomalyModule, MemoryBankMixin + +from .torch_model import PatchcoreModel + +logger = logging.getLogger(__name__) + + +class Patchcore(MemoryBankMixin, AnomalyModule): + """PatchcoreLightning Module to train PatchCore algorithm. + + Args: + backbone (str): Backbone CNN network + Defaults to ``wide_resnet50_2``. + layers (list[str]): Layers to extract features from the backbone CNN + Defaults to ``["layer2", "layer3"]``. + pre_trained (bool, optional): Boolean to check whether to use a pre_trained backbone. + Defaults to ``True``. + coreset_sampling_ratio (float, optional): Coreset sampling ratio to subsample embedding. + Defaults to ``0.1``. + num_neighbors (int, optional): Number of nearest neighbors. + Defaults to ``9``. + """ + + def __init__( + self, + backbone: str = "wide_resnet50_2", + layers: Sequence[str] = ("layer2", "layer3"), + pre_trained: bool = True, + coreset_sampling_ratio: float = 0.1, + num_neighbors: int = 9, + ) -> None: + super().__init__() + + self.model: PatchcoreModel = PatchcoreModel( + backbone=backbone, + pre_trained=pre_trained, + layers=layers, + num_neighbors=num_neighbors, + ) + self.coreset_sampling_ratio = coreset_sampling_ratio + self.embeddings: list[torch.Tensor] = [] + + def configure_optimizers(self) -> None: + """Configure optimizers. + + Returns: + None: Do not set optimizers by returning None. + """ + return + + def training_step(self, batch: dict[str, str | torch.Tensor], *args, **kwargs) -> None: + """Generate feature embedding of the batch. + + Args: + batch (dict[str, str | torch.Tensor]): Batch containing image filename, image, label and mask + args: Additional arguments. + kwargs: Additional keyword arguments. + + Returns: + dict[str, np.ndarray]: Embedding Vector + """ + del args, kwargs # These variables are not used. + + embedding = self.model(batch["image"]) + self.embeddings.append(embedding) + + def fit(self) -> None: + """Apply subsampling to the embedding collected from the training set.""" + logger.info("Aggregating the embedding extracted from the training set.") + embeddings = torch.vstack(self.embeddings) + + logger.info("Applying core-set subsampling to get the embedding.") + self.model.subsample_embedding(embeddings, self.coreset_sampling_ratio) + + def validation_step(self, batch: dict[str, str | torch.Tensor], *args, **kwargs) -> STEP_OUTPUT: + """Get batch of anomaly maps from input image batch. + + Args: + batch (dict[str, str | torch.Tensor]): Batch containing image filename, image, label and mask + args: Additional arguments. + kwargs: Additional keyword arguments. + + Returns: + dict[str, Any]: Image filenames, test images, GT and predicted label/masks + """ + # These variables are not used. + del args, kwargs + + # Get anomaly maps and predicted scores from the model. + output = self.model(batch["image"]) + + # Add anomaly maps and predicted scores to the batch. + batch["anomaly_maps"] = output["anomaly_map"] + batch["pred_scores"] = output["pred_score"] + + return batch + + @property + def trainer_arguments(self) -> dict[str, Any]: + """Return Patchcore trainer arguments.""" + return {"gradient_clip_val": 0, "max_epochs": 1, "num_sanity_val_steps": 0} + + @property + def learning_type(self) -> LearningType: + """Return the learning type of the model. + + Returns: + LearningType: Learning type of the model. + """ + return LearningType.ONE_CLASS + + def configure_transforms(self, image_size: tuple[int, int] | None = None) -> Transform: + """Default transform for Padim.""" + image_size = image_size or (256, 256) + # scale center crop size proportional to image size + height, width = image_size + center_crop_size = (int(height * (224 / 256)), int(width * (224 / 256))) + return Compose( + [ + Resize(image_size, antialias=True), + CenterCrop(center_crop_size), + Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), + ], + ) diff --git a/anomalib/models/image/patchcore/torch_model.py b/anomalib/models/image/patchcore/torch_model.py new file mode 100644 index 0000000000000000000000000000000000000000..a2ceb32b91fcc019eadb17c78449d07cdd4b0137 --- /dev/null +++ b/anomalib/models/image/patchcore/torch_model.py @@ -0,0 +1,233 @@ +"""PyTorch model for the PatchCore model implementation.""" + +# Copyright (C) 2022-2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +from collections.abc import Sequence +from typing import TYPE_CHECKING + +import torch +from torch import nn +from torch.nn import functional as F # noqa: N812 + +from anomalib.models.components import DynamicBufferMixin, KCenterGreedy, TimmFeatureExtractor + +from .anomaly_map import AnomalyMapGenerator + +if TYPE_CHECKING: + from anomalib.data.utils.tiler import Tiler + + +class PatchcoreModel(DynamicBufferMixin, nn.Module): + """Patchcore Module. + + Args: + layers (list[str]): Layers used for feature extraction + backbone (str, optional): Pre-trained model backbone. + Defaults to ``resnet18``. + pre_trained (bool, optional): Boolean to check whether to use a pre_trained backbone. + Defaults to ``True``. + num_neighbors (int, optional): Number of nearest neighbors. + Defaults to ``9``. + """ + + def __init__( + self, + layers: Sequence[str], + backbone: str = "wide_resnet50_2", + pre_trained: bool = True, + num_neighbors: int = 9, + ) -> None: + super().__init__() + self.tiler: Tiler | None = None + + self.backbone = backbone + self.layers = layers + self.num_neighbors = num_neighbors + + self.feature_extractor = TimmFeatureExtractor( + backbone=self.backbone, + pre_trained=pre_trained, + layers=self.layers, + ).eval() + self.feature_pooler = torch.nn.AvgPool2d(3, 1, 1) + self.anomaly_map_generator = AnomalyMapGenerator() + + self.register_buffer("memory_bank", torch.Tensor()) + self.memory_bank: torch.Tensor + + def forward(self, input_tensor: torch.Tensor) -> torch.Tensor | dict[str, torch.Tensor]: + """Return Embedding during training, or a tuple of anomaly map and anomaly score during testing. + + Steps performed: + 1. Get features from a CNN. + 2. Generate embedding based on the features. + 3. Compute anomaly map in test mode. + + Args: + input_tensor (torch.Tensor): Input tensor + + Returns: + Tensor | dict[str, torch.Tensor]: Embedding for training, anomaly map and anomaly score for testing. + """ + output_size = input_tensor.shape[-2:] + if self.tiler: + input_tensor = self.tiler.tile(input_tensor) + + with torch.no_grad(): + features = self.feature_extractor(input_tensor) + + features = {layer: self.feature_pooler(feature) for layer, feature in features.items()} + embedding = self.generate_embedding(features) + + if self.tiler: + embedding = self.tiler.untile(embedding) + + batch_size, _, width, height = embedding.shape + embedding = self.reshape_embedding(embedding) + + if self.training: + output = embedding + else: + # apply nearest neighbor search + patch_scores, locations = self.nearest_neighbors(embedding=embedding, n_neighbors=1) + # reshape to batch dimension + patch_scores = patch_scores.reshape((batch_size, -1)) + locations = locations.reshape((batch_size, -1)) + # compute anomaly score + pred_score = self.compute_anomaly_score(patch_scores, locations, embedding) + # reshape to w, h + patch_scores = patch_scores.reshape((batch_size, 1, width, height)) + # get anomaly map + anomaly_map = self.anomaly_map_generator(patch_scores, output_size) + + output = {"anomaly_map": anomaly_map, "pred_score": pred_score} + + return output + + def generate_embedding(self, features: dict[str, torch.Tensor]) -> torch.Tensor: + """Generate embedding from hierarchical feature map. + + Args: + features: Hierarchical feature map from a CNN (ResNet18 or WideResnet) + features: dict[str:Tensor]: + + Returns: + Embedding vector + """ + embeddings = features[self.layers[0]] + for layer in self.layers[1:]: + layer_embedding = features[layer] + layer_embedding = F.interpolate(layer_embedding, size=embeddings.shape[-2:], mode="bilinear") + embeddings = torch.cat((embeddings, layer_embedding), 1) + + return embeddings + + @staticmethod + def reshape_embedding(embedding: torch.Tensor) -> torch.Tensor: + """Reshape Embedding. + + Reshapes Embedding to the following format: + - [Batch, Embedding, Patch, Patch] to [Batch*Patch*Patch, Embedding] + + Args: + embedding (torch.Tensor): Embedding tensor extracted from CNN features. + + Returns: + Tensor: Reshaped embedding tensor. + """ + embedding_size = embedding.size(1) + return embedding.permute(0, 2, 3, 1).reshape(-1, embedding_size) + + def subsample_embedding(self, embedding: torch.Tensor, sampling_ratio: float) -> None: + """Subsample embedding based on coreset sampling and store to memory. + + Args: + embedding (np.ndarray): Embedding tensor from the CNN + sampling_ratio (float): Coreset sampling ratio + """ + # Coreset Subsampling + sampler = KCenterGreedy(embedding=embedding, sampling_ratio=sampling_ratio) + coreset = sampler.sample_coreset() + self.memory_bank = coreset + + @staticmethod + def euclidean_dist(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: + """Calculate pair-wise distance between row vectors in x and those in y. + + Replaces torch cdist with p=2, as cdist is not properly exported to onnx and openvino format. + Resulting matrix is indexed by x vectors in rows and y vectors in columns. + + Args: + x: input tensor 1 + y: input tensor 2 + + Returns: + Matrix of distances between row vectors in x and y. + """ + x_norm = x.pow(2).sum(dim=-1, keepdim=True) # |x| + y_norm = y.pow(2).sum(dim=-1, keepdim=True) # |y| + # row distance can be rewritten as sqrt(|x| - 2 * x @ y.T + |y|.T) + res = x_norm - 2 * torch.matmul(x, y.transpose(-2, -1)) + y_norm.transpose(-2, -1) + return res.clamp_min_(0).sqrt_() + + def nearest_neighbors(self, embedding: torch.Tensor, n_neighbors: int) -> tuple[torch.Tensor, torch.Tensor]: + """Nearest Neighbours using brute force method and euclidean norm. + + Args: + embedding (torch.Tensor): Features to compare the distance with the memory bank. + n_neighbors (int): Number of neighbors to look at + + Returns: + Tensor: Patch scores. + Tensor: Locations of the nearest neighbor(s). + """ + distances = self.euclidean_dist(embedding, self.memory_bank) + if n_neighbors == 1: + # when n_neighbors is 1, speed up computation by using min instead of topk + patch_scores, locations = distances.min(1) + else: + patch_scores, locations = distances.topk(k=n_neighbors, largest=False, dim=1) + return patch_scores, locations + + def compute_anomaly_score( + self, + patch_scores: torch.Tensor, + locations: torch.Tensor, + embedding: torch.Tensor, + ) -> torch.Tensor: + """Compute Image-Level Anomaly Score. + + Args: + patch_scores (torch.Tensor): Patch-level anomaly scores + locations: Memory bank locations of the nearest neighbor for each patch location + embedding: The feature embeddings that generated the patch scores + + Returns: + Tensor: Image-level anomaly scores + """ + # Don't need to compute weights if num_neighbors is 1 + if self.num_neighbors == 1: + return patch_scores.amax(1) + batch_size, num_patches = patch_scores.shape + # 1. Find the patch with the largest distance to it's nearest neighbor in each image + max_patches = torch.argmax(patch_scores, dim=1) # indices of m^test,* in the paper + # m^test,* in the paper + max_patches_features = embedding.reshape(batch_size, num_patches, -1)[torch.arange(batch_size), max_patches] + # 2. Find the distance of the patch to it's nearest neighbor, and the location of the nn in the membank + score = patch_scores[torch.arange(batch_size), max_patches] # s^* in the paper + nn_index = locations[torch.arange(batch_size), max_patches] # indices of m^* in the paper + # 3. Find the support samples of the nearest neighbor in the membank + nn_sample = self.memory_bank[nn_index, :] # m^* in the paper + # indices of N_b(m^*) in the paper + memory_bank_effective_size = self.memory_bank.shape[0] # edge case when memory bank is too small + _, support_samples = self.nearest_neighbors( + nn_sample, + n_neighbors=min(self.num_neighbors, memory_bank_effective_size), + ) + # 4. Find the distance of the patch features to each of the support samples + distances = self.euclidean_dist(max_patches_features.unsqueeze(1), self.memory_bank[support_samples]) + # 5. Apply softmax to find the weights + weights = (1 - F.softmax(distances.squeeze(1), 1))[..., 0] + # 6. Apply the weight factor to the score + return weights * score # s in the paper diff --git a/anomalib/models/image/reverse_distillation/LICENSE b/anomalib/models/image/reverse_distillation/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..9ec89476404e1091be7b06b3873cbd1184fd01fe --- /dev/null +++ b/anomalib/models/image/reverse_distillation/LICENSE @@ -0,0 +1,29 @@ +Copyright (c) 2022 Intel Corporation +SPDX-License-Identifier: Apache-2.0 + +Some files in this folder are based on the original Reverse Distillation implementation by hq-deng + +Original license +---------------- + + MIT License + + Copyright (c) 2022 hq-deng + + Permission is hereby granted, free of charge, to any person obtaining a copy + of this software and associated documentation files (the "Software"), to deal + in the Software without restriction, including without limitation the rights + to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + copies of the Software, and to permit persons to whom the Software is + furnished to do so, subject to the following conditions: + + The above copyright notice and this permission notice shall be included in all + copies or substantial portions of the Software. + + THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + SOFTWARE. diff --git a/anomalib/models/image/reverse_distillation/README.md b/anomalib/models/image/reverse_distillation/README.md new file mode 100644 index 0000000000000000000000000000000000000000..8b22685f411bf2f3e3708f0afcbd16e741665e63 --- /dev/null +++ b/anomalib/models/image/reverse_distillation/README.md @@ -0,0 +1,51 @@ +# Anomaly Detection via Reverse Distillation from One-Class Embedding + +This is the implementation of the [Reverse Distillation](https://arxiv.org/pdf/2201.10703v2.pdf) paper. + +Model Type: Segmentation + +## Description + +Reverse Distillation model consists of three networks. The first is a pre-trained feature extractor (E). The next two are the one-class bottleneck embedding (OCBE) and the student decoder network (D). The backbone E is a ResNet model pre-trained on ImageNet dataset. During the forward pass, features from three ResNet block are extracted. These features are encoded by concatenating the three feature maps using the multi-scale feature fusion block of OCBE and passed to the decoder D. The decoder network is symmetrical to the feature extractor but reversed. During training, outputs from these symmetrical blocks are forced to be similar to the corresponding feature extractor layers by using cosine distance as the loss metric. + +During testing, a similar step is followed but this time the cosine distance between the feature maps is used to indicate the presence of anomalies. The distance maps from all the three layers are up-sampled to the image size and added (or multiplied) to produce the final feature map. Gaussian blur is applied to the output map to make it smoother. Finally, the anomaly map is generated by applying min-max normalization on the output map. + +## Architecture + +![Anomaly Detection via Reverse Distillation from One-Class Embedding Architecture](/docs/source/images/reverse_distillation/architecture.png "Reverse Distillation Architecture") + +## Usage + +`python tools/train.py --model reverse_distillation` + +## Benchmark + +All results gathered with seed `42`, train batch size `16`. + +## [MVTec AD Dataset](https://www.mvtec.com/company/research/datasets/mvtec-ad) + +### Image-Level AUC + +| | Avg | Carpet | Grid | Leather | Tile | Wood | Bottle | Cable | Capsule | Hazelnut | Metal Nut | Pill | Screw | Toothbrush | Transistor | Zipper | +| -------------- | :---: | :----: | :---: | :-----: | :---: | :---: | :----: | :---: | :-----: | :------: | :-------: | :---: | :---: | :--------: | :--------: | :----: | +| Wide ResNet-50 | 0.985 | 0.984 | 1.000 | 1.000 | 1.000 | 0.997 | 1.000 | 0.966 | 0.974 | 1.000 | 1.000 | 0.972 | 0.985 | 0.953 | 0.970 | 0.978 | + +### Pixel-Level AUC + +| | Avg | Carpet | Grid | Leather | Tile | Wood | Bottle | Cable | Capsule | Hazelnut | Metal Nut | Pill | Screw | Toothbrush | Transistor | Zipper | +| -------------- | :---: | :----: | :---: | :-----: | :---: | :---: | :----: | :---: | :-----: | :------: | :-------: | :---: | :---: | :--------: | :--------: | :----: | +| Wide ResNet-50 | 0.969 | 0.988 | 0.992 | 0.991 | 0.954 | 0.947 | 0.984 | 0.964 | 0.987 | 0.988 | 0.969 | 0.975 | 0.996 | 0.991 | 0.893 | 0.984 | + +### Image F1 Score + +| | Avg | Carpet | Grid | Leather | Tile | Wood | Bottle | Cable | Capsule | Hazelnut | Metal Nut | Pill | Screw | Toothbrush | Transistor | Zipper | +| -------------- | :---: | :----: | :---: | :-----: | :---: | :---: | :----: | :---: | :-----: | :------: | :-------: | :---: | :---: | :--------: | :--------: | :----: | +| Wide ResNet-50 | 0.976 | 0.977 | 1.000 | 1.000 | 0.994 | 0.992 | 0.984 | 0.930 | 0.982 | 1.000 | 1.000 | 0.967 | 0.963 | 0.952 | 0.927 | 0.975 | + +### Sample Results + +![Sample Result 1](/docs/source/images/reverse_distillation/results/0.png "Sample Result 1") + +![Sample Result 2](/docs/source/images/reverse_distillation/results/1.png "Sample Result 2") + +![Sample Result 3](/docs/source/images/reverse_distillation/results/2.png "Sample Result 3") diff --git a/anomalib/models/image/reverse_distillation/__init__.py b/anomalib/models/image/reverse_distillation/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..7dd60dcb25cc3d942aa379a55996c9f38d815f7a --- /dev/null +++ b/anomalib/models/image/reverse_distillation/__init__.py @@ -0,0 +1,8 @@ +"""Reverse Distillation Model.""" + +# Copyright (C) 2022-2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +from .lightning_model import ReverseDistillation + +__all__ = ["ReverseDistillation"] diff --git a/anomalib/models/image/reverse_distillation/anomaly_map.py b/anomalib/models/image/reverse_distillation/anomaly_map.py new file mode 100644 index 0000000000000000000000000000000000000000..94e591cdfe4c7d2c11903cef543b31c6b59e8a57 --- /dev/null +++ b/anomalib/models/image/reverse_distillation/anomaly_map.py @@ -0,0 +1,96 @@ +"""Compute Anomaly map.""" + +# Original Code +# Copyright (c) 2022 hq-deng +# https://github.com/hq-deng/RD4AD +# SPDX-License-Identifier: MIT +# +# Modified +# Copyright (C) 2022-2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + + +from enum import Enum + +import torch +from omegaconf import ListConfig +from torch import nn +from torch.nn import functional as F # noqa: N812 + +from anomalib.models.components import GaussianBlur2d + + +class AnomalyMapGenerationMode(str, Enum): + """Type of mode when generating anomaly imape.""" + + ADD = "add" + MULTIPLY = "multiply" + + +class AnomalyMapGenerator(nn.Module): + """Generate Anomaly Heatmap. + + Args: + image_size (ListConfig, tuple): Size of original image used for upscaling the anomaly map. + sigma (int): Standard deviation of the gaussian kernel used to smooth anomaly map. + Defaults to ``4``. + mode (AnomalyMapGenerationMode, optional): Operation used to generate anomaly map. + Options are ``AnomalyMapGenerationMode.ADD`` and ``AnomalyMapGenerationMode.MULTIPLY``. + Defaults to ``AnomalyMapGenerationMode.MULTIPLY``. + + Raises: + ValueError: In case modes other than multiply and add are passed. + """ + + def __init__( + self, + image_size: ListConfig | tuple, + sigma: int = 4, + mode: AnomalyMapGenerationMode = AnomalyMapGenerationMode.MULTIPLY, + ) -> None: + super().__init__() + self.image_size = image_size if isinstance(image_size, tuple) else tuple(image_size) + self.sigma = sigma + self.kernel_size = 2 * int(4.0 * sigma + 0.5) + 1 + + if mode not in (AnomalyMapGenerationMode.ADD, AnomalyMapGenerationMode.MULTIPLY): + msg = f"Found mode {mode}. Only multiply and add are supported." + raise ValueError(msg) + self.mode = mode + + def forward(self, student_features: list[torch.Tensor], teacher_features: list[torch.Tensor]) -> torch.Tensor: + """Compute anomaly map given encoder and decoder features. + + Args: + student_features (list[torch.Tensor]): List of encoder features + teacher_features (list[torch.Tensor]): List of decoder features + + Returns: + Tensor: Anomaly maps of length batch. + """ + if self.mode == AnomalyMapGenerationMode.MULTIPLY: + anomaly_map = torch.ones( + [student_features[0].shape[0], 1, *self.image_size], + device=student_features[0].device, + ) # b c h w + elif self.mode == AnomalyMapGenerationMode.ADD: + anomaly_map = torch.zeros( + [student_features[0].shape[0], 1, *self.image_size], + device=student_features[0].device, + ) + + for student_feature, teacher_feature in zip(student_features, teacher_features, strict=True): + distance_map = 1 - F.cosine_similarity(student_feature, teacher_feature) + distance_map = torch.unsqueeze(distance_map, dim=1) + distance_map = F.interpolate(distance_map, size=self.image_size, mode="bilinear", align_corners=True) + if self.mode == AnomalyMapGenerationMode.MULTIPLY: + anomaly_map *= distance_map + elif self.mode == AnomalyMapGenerationMode.ADD: + anomaly_map += distance_map + + gaussian_blur = GaussianBlur2d( + kernel_size=(self.kernel_size, self.kernel_size), + sigma=(self.sigma, self.sigma), + ).to(student_features[0].device) + + return gaussian_blur(anomaly_map) diff --git a/anomalib/models/image/reverse_distillation/components/__init__.py b/anomalib/models/image/reverse_distillation/components/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..631eba439a79405ff76f6b23be988dc73904d7e4 --- /dev/null +++ b/anomalib/models/image/reverse_distillation/components/__init__.py @@ -0,0 +1,20 @@ +"""PyTorch modules for Reverse Distillation.""" + +# Copyright (C) 2022-2024 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions +# and limitations under the License. + +from .bottleneck import get_bottleneck_layer +from .de_resnet import get_decoder + +__all__ = ["get_bottleneck_layer", "get_decoder"] diff --git a/anomalib/models/image/reverse_distillation/components/bottleneck.py b/anomalib/models/image/reverse_distillation/components/bottleneck.py new file mode 100644 index 0000000000000000000000000000000000000000..c2f6c7df9433da21a15ca8236d2474ee398ab318 --- /dev/null +++ b/anomalib/models/image/reverse_distillation/components/bottleneck.py @@ -0,0 +1,167 @@ +"""Torch model defining the bottleneck layer.""" + +# Original Code +# Copyright (c) 2022 hq-deng +# https://github.com/hq-deng/RD4AD +# SPDX-License-Identifier: MIT +# +# Modified +# Copyright (C) 2022-2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + + +from collections.abc import Callable + +import torch +from torch import nn +from torchvision.models.resnet import BasicBlock, Bottleneck + + +def conv3x3(in_planes: int, out_planes: int, stride: int = 1, groups: int = 1, dilation: int = 1) -> nn.Conv2d: + """3x3 convolution with padding.""" + return nn.Conv2d( + in_planes, + out_planes, + kernel_size=3, + stride=stride, + padding=dilation, + groups=groups, + bias=False, + dilation=dilation, + ) + + +def conv1x1(in_planes: int, out_planes: int, stride: int = 1) -> nn.Conv2d: + """1x1 convolution.""" + return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) + + +class OCBE(nn.Module): + """One-Class Bottleneck Embedding module. + + Args: + block (Bottleneck): Expansion value is extracted from this block. + layers (int): Numbers of OCE layers to create after multiscale feature fusion. + groups (int, optional): Number of blocked connections from input channels to output channels. + Defaults to 1. + width_per_group (int, optional): Number of layers in each intermediate convolution layer. Defaults to 64. + norm_layer (Callable[..., nn.Module] | None, optional): Batch norm layer to use. Defaults to None. + """ + + def __init__( + self, + block: Bottleneck | BasicBlock, + layers: int, + groups: int = 1, + width_per_group: int = 64, + norm_layer: Callable[..., nn.Module] | None = None, + ) -> None: + super().__init__() + if norm_layer is None: + norm_layer = nn.BatchNorm2d + self._norm_layer = norm_layer + self.groups = groups + self.base_width = width_per_group + self.inplanes = 256 * block.expansion + self.dilation = 1 + self.bn_layer = self._make_layer(block, 512, layers, stride=2) + + self.conv1 = conv3x3(64 * block.expansion, 128 * block.expansion, 2) + self.bn1 = norm_layer(128 * block.expansion) + self.relu = nn.ReLU(inplace=True) + self.conv2 = conv3x3(128 * block.expansion, 256 * block.expansion, 2) + self.bn2 = norm_layer(256 * block.expansion) + self.conv3 = conv3x3(128 * block.expansion, 256 * block.expansion, 2) + self.bn3 = norm_layer(256 * block.expansion) + + # self.conv4 and self.bn4 are from the original code: + # https://github.com/hq-deng/RD4AD/blob/6554076872c65f8784f6ece8cfb39ce77e1aee12/resnet.py#L412 + self.conv4 = conv1x1(1024 * block.expansion, 512 * block.expansion, 1) + self.bn4 = norm_layer(512 * block.expansion) + + for module in self.modules(): + if isinstance(module, nn.Conv2d): + nn.init.kaiming_normal_(module.weight, mode="fan_out", nonlinearity="relu") + elif isinstance(module, nn.BatchNorm2d | nn.GroupNorm): + nn.init.constant_(module.weight, 1) + nn.init.constant_(module.bias, 0) + + def _make_layer( + self, + block: type[Bottleneck | BasicBlock], + planes: int, + blocks: int, + stride: int = 1, + dilate: bool = False, + ) -> nn.Sequential: + norm_layer = self._norm_layer + downsample = None + previous_dilation = self.dilation + if dilate: + self.dilation *= stride + stride = 1 + if stride != 1 or self.inplanes != planes * block.expansion: + downsample = nn.Sequential( + conv1x1(self.inplanes * 3, planes * block.expansion, stride), + norm_layer(planes * block.expansion), + ) + + layers = [] + layers.append( + block( + self.inplanes * 3, + planes, + stride, + downsample, + self.groups, + self.base_width, + previous_dilation, + norm_layer, + ), + ) + self.inplanes = planes * block.expansion + layers.extend( + [ + block( + self.inplanes, + planes, + groups=self.groups, + base_width=self.base_width, + dilation=self.dilation, + norm_layer=norm_layer, + ) + for _ in range(1, blocks) + ], + ) + + return nn.Sequential(*layers) + + def forward(self, features: list[torch.Tensor]) -> torch.Tensor: + """Forward-pass of Bottleneck layer. + + Args: + features (list[torch.Tensor]): List of features extracted from the encoder. + + Returns: + Tensor: Output of the bottleneck layer + """ + # Always assumes that features has length of 3 + feature0 = self.relu(self.bn2(self.conv2(self.relu(self.bn1(self.conv1(features[0])))))) + feature1 = self.relu(self.bn3(self.conv3(features[1]))) + feature_cat = torch.cat([feature0, feature1, features[2]], 1) + output = self.bn_layer(feature_cat) + + return output.contiguous() + + +def get_bottleneck_layer(backbone: str, **kwargs) -> OCBE: + """Get appropriate bottleneck layer based on the name of the backbone. + + Args: + backbone (str): Name of the backbone. + kwargs: Additional keyword arguments. + + Returns: + Bottleneck_layer: One-Class Bottleneck Embedding module. + """ + return OCBE(BasicBlock, 2, **kwargs) if backbone in ("resnet18", "resnet34") else OCBE(Bottleneck, 3, **kwargs) diff --git a/anomalib/models/image/reverse_distillation/components/de_resnet.py b/anomalib/models/image/reverse_distillation/components/de_resnet.py new file mode 100644 index 0000000000000000000000000000000000000000..93bda2c83e7b696cad9b121e80f0ffa55ae46270 --- /dev/null +++ b/anomalib/models/image/reverse_distillation/components/de_resnet.py @@ -0,0 +1,355 @@ +"""Torch model defining the decoder.""" + +# Original Code +# Copyright (c) 2022 hq-deng +# https://github.com/hq-deng/RD4AD +# SPDX-License-Identifier: MIT +# +# Modified +# Copyright (C) 2022-2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + + +from collections.abc import Callable + +import torch +from torch import nn +from torchvision.models.resnet import conv1x1, conv3x3 + + +class DecoderBasicBlock(nn.Module): + """Basic block for decoder ResNet architecture. + + Args: + inplanes (int): Number of input channels. + planes (int): Number of output channels. + stride (int, optional): Stride for convolution and de-convolution layers. Defaults to 1. + upsample (nn.Module | None, optional): Module used for upsampling output. Defaults to None. + groups (int, optional): Number of blocked connections from input channels to output channels. + Defaults to 1. + base_width (int, optional): Number of layers in each intermediate convolution layer. Defaults to 64. + dilation (int, optional): Spacing between kernel elements. Defaults to 1. + norm_layer (Callable[..., nn.Module] | None, optional): Batch norm layer to use.Defaults to None. + + Raises: + ValueError: If groups are not equal to 1 and base width is not 64. + NotImplementedError: If dilation is greater than 1. + """ + + expansion: int = 1 + + def __init__( + self, + inplanes: int, + planes: int, + stride: int = 1, + upsample: nn.Module | None = None, + groups: int = 1, + base_width: int = 64, + dilation: int = 1, + norm_layer: Callable[..., nn.Module] | None = None, + ) -> None: + super().__init__() + if norm_layer is None: + norm_layer = nn.BatchNorm2d + if groups != 1 or base_width != 64: + msg = "BasicBlock only supports groups=1 and base_width=64" + raise ValueError(msg) + if dilation > 1: + msg = "Dilation > 1 not supported in BasicBlock" + raise NotImplementedError(msg) + # Both self.conv1 and self.downsample layers downsample the input when stride != 2 + if stride == 2: + self.conv1 = nn.ConvTranspose2d( + inplanes, + planes, + kernel_size=2, + stride=stride, + groups=groups, + bias=False, + dilation=dilation, + ) + else: + self.conv1 = conv3x3(inplanes, planes, stride) + self.bn1 = norm_layer(planes) + self.relu = nn.ReLU(inplace=True) + self.conv2 = conv3x3(planes, planes) + self.bn2 = norm_layer(planes) + self.upsample = upsample + self.stride = stride + + def forward(self, batch: torch.Tensor) -> torch.Tensor: + """Forward-pass of de-resnet block.""" + identity = batch + + out = self.conv1(batch) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + + if self.upsample is not None: + identity = self.upsample(batch) + + out += identity + return self.relu(out) + + +class DecoderBottleneck(nn.Module): + """Bottleneck for Decoder. + + Args: + inplanes (int): Number of input channels. + planes (int): Number of output channels. + stride (int, optional): Stride for convolution and de-convolution layers. Defaults to 1. + upsample (nn.Module | None, optional): Module used for upsampling output. Defaults to None. + groups (int, optional): Number of blocked connections from input channels to output channels. + Defaults to 1. + base_width (int, optional): Number of layers in each intermediate convolution layer. Defaults to 64. + dilation (int, optional): Spacing between kernel elements. Defaults to 1. + norm_layer (Callable[..., nn.Module] | None, optional): Batch norm layer to use.Defaults to None. + """ + + expansion: int = 4 + + def __init__( + self, + inplanes: int, + planes: int, + stride: int = 1, + upsample: nn.Module | None = None, + groups: int = 1, + base_width: int = 64, + dilation: int = 1, + norm_layer: Callable[..., nn.Module] | None = None, + ) -> None: + super().__init__() + if norm_layer is None: + norm_layer = nn.BatchNorm2d + width = int(planes * (base_width / 64.0)) * groups + # Both self.conv2 and self.downsample layers downsample the input when stride != 2 + self.conv1 = conv1x1(inplanes, width) + self.bn1 = norm_layer(width) + if stride == 2: + self.conv2 = nn.ConvTranspose2d( + width, + width, + kernel_size=2, + stride=stride, + groups=groups, + bias=False, + dilation=dilation, + ) + else: + self.conv2 = conv3x3(width, width, stride, groups, dilation) + self.bn2 = norm_layer(width) + self.conv3 = conv1x1(width, planes * self.expansion) + self.bn3 = norm_layer(planes * self.expansion) + self.relu = nn.ReLU(inplace=True) + self.upsample = upsample + self.stride = stride + + def forward(self, batch: torch.Tensor) -> torch.Tensor: + """Forward-pass of de-resnet bottleneck block.""" + identity = batch + + out = self.conv1(batch) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + out = self.relu(out) + + out = self.conv3(out) + out = self.bn3(out) + + if self.upsample is not None: + identity = self.upsample(batch) + + out += identity + return self.relu(out) + + +class ResNet(nn.Module): + """ResNet model for decoder. + + Args: + block (Type[DecoderBasicBlock | DecoderBottleneck]): Type of block to use in a layer. + layers (list[int]): List to specify number for blocks per layer. + zero_init_residual (bool, optional): If true, initializes the last batch norm in each layer to zero. + Defaults to False. + groups (int, optional): Number of blocked connections per layer from input channels to output channels. + Defaults to 1. + width_per_group (int, optional): Number of layers in each intermediate convolution layer.. Defaults to 64. + norm_layer (Callable[..., nn.Module] | None, optional): Batch norm layer to use. Defaults to None. + """ + + def __init__( + self, + block: type[DecoderBasicBlock | DecoderBottleneck], + layers: list[int], + zero_init_residual: bool = False, + groups: int = 1, + width_per_group: int = 64, + norm_layer: Callable[..., nn.Module] | None = None, + ) -> None: + super().__init__() + if norm_layer is None: + norm_layer = nn.BatchNorm2d + self._norm_layer = norm_layer + + self.inplanes = 512 * block.expansion + self.dilation = 1 + self.groups = groups + self.base_width = width_per_group + self.layer1 = self._make_layer(block, 256, layers[0], stride=2) + self.layer2 = self._make_layer(block, 128, layers[1], stride=2) + self.layer3 = self._make_layer(block, 64, layers[2], stride=2) + + for module in self.modules(): + if isinstance(module, nn.Conv2d): + nn.init.kaiming_normal_(module.weight, mode="fan_out", nonlinearity="relu") + elif isinstance(module, nn.BatchNorm2d | nn.GroupNorm): + nn.init.constant_(module.weight, 1) + nn.init.constant_(module.bias, 0) + + # Zero-initialize the last BN in each residual branch, + # so that the residual branch starts with zeros, and each residual block behaves like an identity. + # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677 + if zero_init_residual: + for module in self.modules(): + if isinstance(module, DecoderBottleneck): + nn.init.constant_(module.bn3.weight, 0) # type: ignore[arg-type] + elif isinstance(module, DecoderBasicBlock): + nn.init.constant_(module.bn2.weight, 0) # type: ignore[arg-type] + + def _make_layer( + self, + block: type[DecoderBasicBlock | DecoderBottleneck], + planes: int, + blocks: int, + stride: int = 1, + ) -> nn.Sequential: + norm_layer = self._norm_layer + upsample = None + previous_dilation = self.dilation + if stride != 1 or self.inplanes != planes * block.expansion: + upsample = nn.Sequential( + nn.ConvTranspose2d( + self.inplanes, + planes * block.expansion, + kernel_size=2, + stride=stride, + groups=self.groups, + bias=False, + dilation=self.dilation, + ), + norm_layer(planes * block.expansion), + ) + + layers = [] + layers.append( + block(self.inplanes, planes, stride, upsample, self.groups, self.base_width, previous_dilation, norm_layer), + ) + self.inplanes = planes * block.expansion + layers.extend( + [ + block( + self.inplanes, + planes, + groups=self.groups, + base_width=self.base_width, + dilation=self.dilation, + norm_layer=norm_layer, + ) + for _ in range(1, blocks) + ], + ) + + return nn.Sequential(*layers) + + def forward(self, batch: torch.Tensor) -> list[torch.Tensor]: + """Forward pass for Decoder ResNet. Returns list of features.""" + feature_a = self.layer1(batch) # 512*8*8->256*16*16 + feature_b = self.layer2(feature_a) # 256*16*16->128*32*32 + feature_c = self.layer3(feature_b) # 128*32*32->64*64*64 + + return [feature_c, feature_b, feature_a] + + +def _resnet(block: type[DecoderBasicBlock | DecoderBottleneck], layers: list[int], **kwargs) -> ResNet: + return ResNet(block, layers, **kwargs) + + +def de_resnet18() -> ResNet: + """ResNet-18 model.""" + return _resnet(DecoderBasicBlock, [2, 2, 2, 2]) + + +def de_resnet34() -> ResNet: + """ResNet-34 model.""" + return _resnet(DecoderBasicBlock, [3, 4, 6, 3]) + + +def de_resnet50() -> ResNet: + """ResNet-50 model.""" + return _resnet(DecoderBottleneck, [3, 4, 6, 3]) + + +def de_resnet101() -> ResNet: + """ResNet-101 model.""" + return _resnet(DecoderBottleneck, [3, 4, 23, 3]) + + +def de_resnet152() -> ResNet: + """ResNet-152 model.""" + return _resnet(DecoderBottleneck, [3, 8, 36, 3]) + + +def de_resnext50_32x4d() -> ResNet: + """ResNeXt-50 32x4d model.""" + return _resnet(DecoderBottleneck, [3, 4, 6, 3], groups=32, width_per_group=4) + + +def de_resnext101_32x8d() -> ResNet: + """ResNeXt-101 32x8d model.""" + return _resnet(DecoderBottleneck, [3, 4, 23, 3], groups=32, width_per_group=8) + + +def de_wide_resnet50_2() -> ResNet: + """Wide ResNet-50-2 model.""" + return _resnet(DecoderBottleneck, [3, 4, 6, 3], width_per_group=128) + + +def de_wide_resnet101_2() -> ResNet: + """Wide ResNet-101-2 model.""" + return _resnet(DecoderBottleneck, [3, 4, 23, 3], width_per_group=128) + + +def get_decoder(name: str) -> ResNet: + """Get decoder model based on the name of the backbone. + + Args: + name (str): Name of the backbone. + + Returns: + ResNet: Decoder ResNet architecture. + """ + if name in ( + "resnet18", + "resnet34", + "resnet50", + "resnet101", + "resnet152", + "resnext50_32x4d", + "resnext101_32x8d", + "wide_resnet50_2", + "wide_resnet101_2", + ): + decoder = globals()[f"de_{name}"] + else: + msg = f"Decoder with architecture {name} not supported" + raise ValueError(msg) + return decoder() diff --git a/anomalib/models/image/reverse_distillation/lightning_model.py b/anomalib/models/image/reverse_distillation/lightning_model.py new file mode 100644 index 0000000000000000000000000000000000000000..5684e52f1e7f68a522bcc96c44a58b891dd9e6d9 --- /dev/null +++ b/anomalib/models/image/reverse_distillation/lightning_model.py @@ -0,0 +1,133 @@ +"""Anomaly Detection via Reverse Distillation from One-Class Embedding. + +https://arxiv.org/abs/2201.10703v2 +""" + +# Copyright (C) 2022-2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +from collections.abc import Sequence +from typing import Any + +import torch +from lightning.pytorch.utilities.types import STEP_OUTPUT +from torch import optim + +from anomalib import LearningType +from anomalib.models.components import AnomalyModule + +from .anomaly_map import AnomalyMapGenerationMode +from .loss import ReverseDistillationLoss +from .torch_model import ReverseDistillationModel + + +class ReverseDistillation(AnomalyModule): + """PL Lightning Module for Reverse Distillation Algorithm. + + Args: + backbone (str): Backbone of CNN network + Defaults to ``wide_resnet50_2``. + layers (list[str]): Layers to extract features from the backbone CNN + Defaults to ``["layer1", "layer2", "layer3"]``. + anomaly_map_mode (AnomalyMapGenerationMode, optional): Mode to generate anomaly map. + Defaults to ``AnomalyMapGenerationMode.ADD``. + pre_trained (bool, optional): Boolean to check whether to use a pre_trained backbone. + Defaults to ``True``. + """ + + def __init__( + self, + backbone: str = "wide_resnet50_2", + layers: Sequence[str] = ("layer1", "layer2", "layer3"), + anomaly_map_mode: AnomalyMapGenerationMode = AnomalyMapGenerationMode.ADD, + pre_trained: bool = True, + ) -> None: + super().__init__() + + self.backbone = backbone + self.pre_trained = pre_trained + self.layers = layers + self.anomaly_map_mode = anomaly_map_mode + + self.model: ReverseDistillationModel + self.loss = ReverseDistillationLoss() + + def _setup(self) -> None: + if self.input_size is None: + msg = "Input size is required for Reverse Distillation model." + raise ValueError(msg) + + self.model = ReverseDistillationModel( + backbone=self.backbone, + pre_trained=self.pre_trained, + layers=self.layers, + input_size=self.input_size, + anomaly_map_mode=self.anomaly_map_mode, + ) + + def configure_optimizers(self) -> optim.Adam: + """Configure optimizers for decoder and bottleneck. + + Returns: + Optimizer: Adam optimizer for each decoder + """ + return optim.Adam( + params=list(self.model.decoder.parameters()) + list(self.model.bottleneck.parameters()), + lr=0.005, + betas=(0.5, 0.99), + ) + + def training_step(self, batch: dict[str, str | torch.Tensor], *args, **kwargs) -> STEP_OUTPUT: + """Perform a training step of Reverse Distillation Model. + + Features are extracted from three layers of the Encoder model. These are passed to the bottleneck layer + that are passed to the decoder network. The loss is then calculated based on the cosine similarity between the + encoder and decoder features. + + Args: + batch (batch: dict[str, str | torch.Tensor]): Input batch + args: Additional arguments. + kwargs: Additional keyword arguments. + + Returns: + Feature Map + """ + del args, kwargs # These variables are not used. + + loss = self.loss(*self.model(batch["image"])) + self.log("train_loss", loss.item(), on_epoch=True, prog_bar=True, logger=True) + return {"loss": loss} + + def validation_step(self, batch: dict[str, str | torch.Tensor], *args, **kwargs) -> STEP_OUTPUT: + """Perform a validation step of Reverse Distillation Model. + + Similar to the training step, encoder/decoder features are extracted from the CNN for each batch, and + anomaly map is computed. + + Args: + batch (dict[str, str | torch.Tensor]): Input batch + args: Additional arguments. + kwargs: Additional keyword arguments. + + Returns: + Dictionary containing images, anomaly maps, true labels and masks. + These are required in `validation_epoch_end` for feature concatenation. + """ + del args, kwargs # These variables are not used. + + batch["anomaly_maps"] = self.model(batch["image"]) + return batch + + @property + def trainer_arguments(self) -> dict[str, Any]: + """Return Reverse Distillation trainer arguments.""" + return {"gradient_clip_val": 0, "num_sanity_val_steps": 0} + + @property + def learning_type(self) -> LearningType: + """Return the learning type of the model. + + Returns: + LearningType: Learning type of the model. + """ + return LearningType.ONE_CLASS diff --git a/anomalib/models/image/reverse_distillation/loss.py b/anomalib/models/image/reverse_distillation/loss.py new file mode 100644 index 0000000000000000000000000000000000000000..75c9ab5d9852312a697854285590da3033bd3d4d --- /dev/null +++ b/anomalib/models/image/reverse_distillation/loss.py @@ -0,0 +1,38 @@ +"""Loss function for Reverse Distillation.""" + +# Copyright (C) 2022-2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + + +import torch +from torch import nn + + +class ReverseDistillationLoss(nn.Module): + """Loss function for Reverse Distillation.""" + + def forward(self, encoder_features: list[torch.Tensor], decoder_features: list[torch.Tensor]) -> torch.Tensor: + """Compute cosine similarity loss based on features from encoder and decoder. + + Based on the official code: + https://github.com/hq-deng/RD4AD/blob/6554076872c65f8784f6ece8cfb39ce77e1aee12/main.py#L33C25-L33C25 + Calculates loss from flattened arrays of features, see https://github.com/hq-deng/RD4AD/issues/22 + + Args: + encoder_features (list[torch.Tensor]): List of features extracted from encoder + decoder_features (list[torch.Tensor]): List of features extracted from decoder + + Returns: + Tensor: Cosine similarity loss + """ + cos_loss = torch.nn.CosineSimilarity() + loss_sum = 0 + for encoder_feature, decoder_feature in zip(encoder_features, decoder_features, strict=True): + loss_sum += torch.mean( + 1 + - cos_loss( + encoder_feature.view(encoder_feature.shape[0], -1), + decoder_feature.view(decoder_feature.shape[0], -1), + ), + ) + return loss_sum diff --git a/anomalib/models/image/reverse_distillation/torch_model.py b/anomalib/models/image/reverse_distillation/torch_model.py new file mode 100644 index 0000000000000000000000000000000000000000..17ae61a070afdf67e883a4287b953c3d3be8662f --- /dev/null +++ b/anomalib/models/image/reverse_distillation/torch_model.py @@ -0,0 +1,87 @@ +"""PyTorch model for Reverse Distillation.""" + +# Copyright (C) 2022-2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + + +from collections.abc import Sequence +from typing import TYPE_CHECKING + +import torch +from torch import nn + +from anomalib.models.components import TimmFeatureExtractor + +from .anomaly_map import AnomalyMapGenerationMode, AnomalyMapGenerator +from .components import get_bottleneck_layer, get_decoder + +if TYPE_CHECKING: + from anomalib.data.utils.tiler import Tiler + + +class ReverseDistillationModel(nn.Module): + """Reverse Distillation Model. + + To reproduce results in the paper, use torchvision model for the encoder: + self.encoder = torchvision.models.wide_resnet50_2(pretrained=True) + + Args: + backbone (str): Name of the backbone used for encoder and decoder. + input_size (tuple[int, int]): Size of input image. + layers (list[str]): Name of layers from which the features are extracted. + anomaly_map_mode (str): Mode used to generate anomaly map. Options are between ``multiply`` and ``add``. + pre_trained (bool, optional): Boolean to check whether to use a pre_trained backbone. + Defaults to ``True``. + """ + + def __init__( + self, + backbone: str, + input_size: tuple[int, int], + layers: Sequence[str], + anomaly_map_mode: AnomalyMapGenerationMode, + pre_trained: bool = True, + ) -> None: + super().__init__() + self.tiler: Tiler | None = None + + encoder_backbone = backbone + self.encoder = TimmFeatureExtractor(backbone=encoder_backbone, pre_trained=pre_trained, layers=layers) + self.bottleneck = get_bottleneck_layer(backbone) + self.decoder = get_decoder(backbone) + + self.anomaly_map_generator = AnomalyMapGenerator(image_size=input_size, mode=anomaly_map_mode) + + def forward(self, images: torch.Tensor) -> torch.Tensor | list[torch.Tensor] | tuple[list[torch.Tensor]]: + """Forward-pass images to the network. + + During the training mode the model extracts features from encoder and decoder networks. + During evaluation mode, it returns the predicted anomaly map. + + Args: + images (torch.Tensor): Batch of images + + Returns: + torch.Tensor | list[torch.Tensor] | tuple[list[torch.Tensor]]: Encoder and decoder features + in training mode, else anomaly maps. + """ + self.encoder.eval() + + if self.tiler: + images = self.tiler.tile(images) + encoder_features = self.encoder(images) + encoder_features = list(encoder_features.values()) + decoder_features = self.decoder(self.bottleneck(encoder_features)) + + if self.tiler: + for i, features in enumerate(encoder_features): + encoder_features[i] = self.tiler.untile(features) + for i, features in enumerate(decoder_features): + decoder_features[i] = self.tiler.untile(features) + + if self.training: + output = encoder_features, decoder_features + else: + output = self.anomaly_map_generator(encoder_features, decoder_features) + + return output diff --git a/anomalib/models/image/rkde/README.md b/anomalib/models/image/rkde/README.md new file mode 100644 index 0000000000000000000000000000000000000000..17beaf27171aca92f3aa9f8d962f35fb052e1fea --- /dev/null +++ b/anomalib/models/image/rkde/README.md @@ -0,0 +1,45 @@ +# Region-Based Kernel Density Estimation (RKDE) + +This is the implementation of the paper [Region Based Anomaly Detection With Real-Time +Training and Analysis](https://ieeexplore.ieee.org/abstract/document/8999287). + +Model Type: Detection + +## Description + +Three-stage anomaly detection consisting of region extraction to obtain a set of region-of-interest proposals for each image, feature extraction to obtain a fixed-length feature vector for each region proposal, and density estimation to classify the region proposals as normal vs. anomalous. + +Both the region extractor and the feature extractor rely on pre-trained convolutional neural networks. The density estimation stage uses Kernel Density Estimation (KDE). + +### Region Extraction + +Region proposals are obtained in the form of bounding boxes by feeding the images through a Faster-RCNN object detector with a ResNet50 backbone, pretrained on MS COCO. Depending on the chosen settings, the region proposals are obtained by taking either the final bounding box predictions of the classification heads, or the region proposals of the Region Proposal Network (RPN). Any detections with the `background` label are discarded, after which the raw region proposals are post-processed by discarding small bounding boxes, applying NMS (across all class labels), and discarding regions with a low confidence score. The minimum region size, IOU threshold used during NMS, and the confidence score threshold can be configured from the config file. + +### Feature Extraction + +The feature extractor consists of a Fast-RCNN model with an AlexNet backbone, which was trained in a multi-task setting on the MS COCO and Visual Genome datasets (see paper for more details). The ROI align layer ensures that the feature maps produced by the convolutional layers are cropped to the bounding box coordinates obtained in the region extraction stage. The activations of the final shared fully connected layer are retrieved to obtain a feature embeddings for each region proposal. + +### Density Estimation + +The classification module uses Kernel Density Estimation (KDE) to estimate the probability density function of the feature space. The KDE model is fitted on the collection of features extracted from the training images. During inference, features extracted from the regions in the inference images are evaluated against the KDE model to obtain a density estimation for each region proposal. The estimates density serves as a 'normality score', which is converted to a normal/anomalous label using Anomalib's thresholding mechanism. + +Before fitting the KDE model, the dimensionality of the feature vectors is reduced using Principal Component Analysis (PCA). Depending on the chosen settings, the features are then scaled to unit vector length or the maximum vector length observed in the training set. + +## Usage and parameters + +`python tools/train.py --model rkde` + +| Parameter | Affects Stage | Description | Type | Options | +| :----------------------- | :----------------- | :--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | :----- | :------------ | +| roi_stage | Region Extraction | Processing stage from which the region proposals are retrieved. `rpn`: raw predictions of the region proposal network. `rcnn`: final detection outputs of the classification heads. | string | [rpn, rcnn] | +| roi_score_threshold | Region Extraction | Minimum class score for the region proposals. Regions with a confidence score below this value are discarded. When stage is `rcnn`, class score is used. When stage is `rpn`, objectness score is used. | float | | +| min_box_size | Region Extraction | Minimum size in pixels for the region proposals. Regions with a hight or width smaller than this value will be discarded. | int | | +| iou_threshold | Region Extraction | Intersection-Over-Union threshold used in Non-Maximum-Suppression when post-processing detections. Regions are discarded when their IoU with a higher-confidence region is above this value. | float | | +| max_detections_per_image | Region Extraction | Maximum number of region proposals N allowed per image. When the number of raw proposals is higher than this value, only the top N scoring proposals will be kept. | int | | +| n_pca_components | Density Estimation | Number of principal components to which the features are reduced before applying KDE. | int | | +| max_training_points | Density Estimation | Maximum number of training features on which the KDE model is fitted. When more training features are available, a random selection of features will be discarded. | int | | +| feature_scaling_method | Density Estimation | Determines how the features are scaled before applying KDE. `norm`: the features are normalized to unit vector length. `scale`: The features are normalized to the max vector length observed in training. | string | [norm, scale] | + +## Benchmark + +N/A diff --git a/anomalib/models/image/rkde/__init__.py b/anomalib/models/image/rkde/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..3918801a72cc5170505168c42782d0230221565f --- /dev/null +++ b/anomalib/models/image/rkde/__init__.py @@ -0,0 +1,9 @@ +"""Region-based Anomaly Detection with Real Time Training and Analysis.""" + +# Copyright (C) 2022-2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + + +from .lightning_model import Rkde + +__all__ = ["Rkde"] diff --git a/anomalib/models/image/rkde/feature_extractor.py b/anomalib/models/image/rkde/feature_extractor.py new file mode 100644 index 0000000000000000000000000000000000000000..117cd71027458f20fe6b1402e3d8706f5fcc4022 --- /dev/null +++ b/anomalib/models/image/rkde/feature_extractor.py @@ -0,0 +1,78 @@ +"""Region-based Anomaly Detection with Real Time Training and Analysis. + +Feature Extractor. +""" + +# Copyright (C) 2022-2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +import torch +from torch import nn +from torchvision.ops import RoIAlign +from torchvision.transforms import Normalize, Resize + +from anomalib.data.utils.boxes import scale_boxes + +WEIGHTS_URL = "https://github.com/openvinotoolkit/anomalib/releases/download/rkde-weights/rkde_feature_extractor.pth" + + +class FeatureExtractor(nn.Module): + """Feature Extractor module for Region-based anomaly detection.""" + + def __init__(self) -> None: + super().__init__() + + self.transform = nn.Sequential( + Resize(size=600, max_size=1000), + Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), + ) + + self.features = nn.Sequential( + nn.Conv2d(3, 64, kernel_size=11, stride=4, padding=2), + nn.ReLU(inplace=True), + nn.MaxPool2d(kernel_size=3, stride=2), + nn.Conv2d(64, 192, kernel_size=5, padding=2), + nn.ReLU(inplace=True), + nn.MaxPool2d(kernel_size=3, stride=2), + nn.Conv2d(192, 384, kernel_size=3, padding=1), + nn.ReLU(inplace=True), + nn.Conv2d(384, 256, kernel_size=3, padding=1), + nn.ReLU(inplace=True), + nn.Conv2d(256, 256, kernel_size=3, padding=1), + nn.ReLU(inplace=True), + ) + + self.roi_align = RoIAlign(output_size=(6, 6), spatial_scale=1 / 16, sampling_ratio=0) + + self.classifier = nn.Sequential( + nn.Linear(256 * 6 * 6, 4096), + nn.ReLU(inplace=True), + nn.Linear(4096, 4096), + nn.ReLU(inplace=True), + ) + + # load the pre-trained weights from url + self.load_state_dict(torch.hub.load_state_dict_from_url(WEIGHTS_URL, progress=False)) + + @torch.no_grad() + def forward(self, batch: torch.Tensor, rois: torch.Tensor) -> torch.Tensor: + """Perform a forward pass of the feature extractor. + + Args: + batch (torch.Tensor): Batch of input images of shape [B, C, H, W]. + rois (torch.Tensor): torch.Tensor of shape [N, 5] describing the regions-of-interest in the batch. + + Returns: + Tensor: torch.Tensor containing a 4096-dimensional feature vector for every RoI location. + """ + # Apply the feature extractor transforms + transformed_batch = self.transform(batch) + + # Scale the RoIs to the effective input size of the feature extractor. + rois[:, 1:] = scale_boxes(rois[:, 1:], batch.shape[-2:], transformed_batch.shape[-2:]) + + # Forward pass through the backbone + features = self.features(transformed_batch) + features = self.roi_align(features, rois) + features = torch.flatten(features, 1) + return self.classifier(features) diff --git a/anomalib/models/image/rkde/lightning_model.py b/anomalib/models/image/rkde/lightning_model.py new file mode 100644 index 0000000000000000000000000000000000000000..8800a570888cd86386a0695d618821281954b61c --- /dev/null +++ b/anomalib/models/image/rkde/lightning_model.py @@ -0,0 +1,146 @@ +"""Region Based Anomaly Detection With Real-Time Training and Analysis. + +https://ieeexplore.ieee.org/abstract/document/8999287 +""" + +# Copyright (C) 2022-2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + + +import logging +from typing import Any + +import torch +from lightning.pytorch.utilities.types import STEP_OUTPUT + +from anomalib import LearningType +from anomalib.models.components import AnomalyModule, MemoryBankMixin +from anomalib.models.components.classification import FeatureScalingMethod + +from .region_extractor import RoiStage +from .torch_model import RkdeModel + +logger = logging.getLogger(__name__) + + +class Rkde(MemoryBankMixin, AnomalyModule): + """Region Based Anomaly Detection With Real-Time Training and Analysis. + + Args: + roi_stage (RoiStage, optional): Processing stage from which rois are extracted. + Defaults to ``RoiStage.RCNN``. + roi_score_threshold (float, optional): Mimumum confidence score for the region proposals. + Defaults to ``0.001``. + min_size (int, optional): Minimum size in pixels for the region proposals. + Defaults to ``25``. + iou_threshold (float, optional): Intersection-Over-Union threshold used during NMS. + Defaults to ``0.3``. + max_detections_per_image (int, optional): Maximum number of region proposals per image. + Defaults to ``100``. + n_pca_components (int, optional): Number of PCA components. + Defaults to ``16``. + feature_scaling_method (FeatureScalingMethod, optional): Scaling method applied to features before passing to + KDE. Options are `norm` (normalize to unit vector length) and `scale` (scale to max length observed in + training). + Defaults to ``FeatureScalingMethod.SCALE``. + max_training_points (int, optional): Maximum number of training points to fit the KDE model. + Defaults to ``40000``. + """ + + def __init__( + self, + roi_stage: RoiStage = RoiStage.RCNN, + roi_score_threshold: float = 0.001, + min_box_size: int = 25, + iou_threshold: float = 0.3, + max_detections_per_image: int = 100, + n_pca_components: int = 16, + feature_scaling_method: FeatureScalingMethod = FeatureScalingMethod.SCALE, + max_training_points: int = 40000, + ) -> None: + super().__init__() + + self.model: RkdeModel = RkdeModel( + roi_stage=roi_stage, + roi_score_threshold=roi_score_threshold, + min_box_size=min_box_size, + iou_threshold=iou_threshold, + max_detections_per_image=max_detections_per_image, + n_pca_components=n_pca_components, + feature_scaling_method=feature_scaling_method, + max_training_points=max_training_points, + ) + self.embeddings: list[torch.Tensor] = [] + + @staticmethod + def configure_optimizers() -> None: + """RKDE doesn't require optimization, therefore returns no optimizers.""" + return + + def training_step(self, batch: dict[str, str | torch.Tensor], *args, **kwargs) -> None: + """Perform a training Step of RKDE. For each batch, features are extracted from the CNN. + + Args: + batch (dict[str, str | torch.Tensor]): Batch containing image filename, image, label and mask + args: Additional arguments. + kwargs: Additional keyword arguments. + + Returns: + Deep CNN features. + """ + del args, kwargs # These variables are not used. + + features = self.model(batch["image"]) + self.embeddings.append(features) + + def fit(self) -> None: + """Fit a KDE Model to the embedding collected from the training set.""" + embeddings = torch.vstack(self.embeddings) + + logger.info("Fitting a KDE model to the embedding collected from the training set.") + self.model.fit(embeddings) + + def validation_step(self, batch: dict[str, str | torch.Tensor], *args, **kwargs) -> STEP_OUTPUT: + """Perform a validation Step of RKde. + + Similar to the training step, features are extracted from the CNN for each batch. + + Args: + batch (dict[str, str | torch.Tensor]): Batch containing image filename, image, label and mask + args: Additional arguments. + kwargs: Additional keyword arguments. + + Returns: + Dictionary containing probability, prediction and ground truth values. + """ + del args, kwargs # These variables are not used. + + # get batched model predictions + boxes, scores = self.model(batch["image"]) + + # convert batched predictions to list format + image: torch.Tensor = batch["image"] + batch_size = image.shape[0] + indices = boxes[:, 0] + batch["pred_boxes"] = [boxes[indices == i, 1:] for i in range(batch_size)] + batch["box_scores"] = [scores[indices == i] for i in range(batch_size)] + + return batch + + @property + def trainer_arguments(self) -> dict[str, Any]: + """Return R-KDE trainer arguments. + + Returns: + dict[str, Any]: Arguments for the trainer. + """ + return {"gradient_clip_val": 0, "max_epochs": 1, "num_sanity_val_steps": 0} + + @property + def learning_type(self) -> LearningType: + """Return the learning type of the model. + + Returns: + LearningType: Learning type of the model. + """ + return LearningType.ONE_CLASS diff --git a/anomalib/models/image/rkde/region_extractor.py b/anomalib/models/image/rkde/region_extractor.py new file mode 100644 index 0000000000000000000000000000000000000000..4a154da7aea218c135eb60b41f08d955aaf5c6e5 --- /dev/null +++ b/anomalib/models/image/rkde/region_extractor.py @@ -0,0 +1,146 @@ +"""Region-based Anomaly Detection with Real Time Training and Analysis. + +Region Extractor. +""" + +# Copyright (C) 2022-2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + + +from enum import Enum + +import torch +from torch import nn +from torchvision.models.detection import fasterrcnn_resnet50_fpn +from torchvision.ops import boxes as box_ops + +from anomalib.data.utils.boxes import scale_boxes + + +class RoiStage(str, Enum): + """Processing stage from which rois are extracted.""" + + RCNN = "rcnn" + RPN = "rpn" + + +class RegionExtractor(nn.Module): + """Extracts regions from the image. + + Args: + stage (RoiStage, optional): Processing stage from which rois are extracted. + Defaults to ``RoiStage.RCNN``. + score_threshold (float, optional): Mimumum confidence score for the region proposals. + Defaults to ``0.001``. + min_size (int, optional): Minimum size in pixels for the region proposals. + Defaults to ``25``. + iou_threshold (float, optional): Intersection-Over-Union threshold used during NMS. + Defaults to ``0.3``. + max_detections_per_image (int, optional): Maximum number of region proposals per image. + Defaults to ``100``. + """ + + def __init__( + self, + stage: RoiStage = RoiStage.RCNN, + score_threshold: float = 0.001, + min_size: int = 25, + iou_threshold: float = 0.3, + max_detections_per_image: int = 100, + ) -> None: + super().__init__() + + # Affects global behaviour of the region extractor + self.stage = stage + self.min_size = min_size + self.iou_threshold = iou_threshold + self.max_detections_per_image = max_detections_per_image + + # Affects behaviour depending on roi stage + rpn_top_n = max_detections_per_image if self.stage == RoiStage.RPN else 1000 + rpn_score_thresh = score_threshold if self.stage == RoiStage.RPN else 0.0 + + # Create the model + self.faster_rcnn = fasterrcnn_resnet50_fpn( + pretrained=True, + rpn_post_nms_top_n_test=rpn_top_n, + rpn_score_thresh=rpn_score_thresh, + box_score_thresh=score_threshold, + box_nms_thresh=1.0, # this disables nms (we apply custom label-agnostic nms during post-processing) + box_detections_per_img=1000, # this disables filtering top-k predictions (we apply our own after nms) + ) + + @torch.no_grad() + def forward(self, batch: torch.Tensor) -> torch.Tensor: + """Forward pass of the model. + + Args: + batch (torch.Tensor): Batch of input images of shape [B, C, H, W]. + + Raises: + ValueError: When ``stage`` is not one of ``rcnn`` or ``rpn``. + + Returns: + Tensor: Predicted regions, tensor of shape [N, 5] where N is the number of predicted regions in the batch, + and where each row describes the index of the image in the batch and the 4 bounding box coordinates. + """ + if self.training: + msg = "Should not be in training mode" + raise ValueError(msg) + + if self.stage == RoiStage.RCNN: + # get rois from rcnn output + predictions = self.faster_rcnn(batch) + all_regions = [prediction["boxes"] for prediction in predictions] + all_scores = [prediction["scores"] for prediction in predictions] + elif self.stage == RoiStage.RPN: + # get rois from region proposal network + images, _ = self.faster_rcnn.transform(batch) + features = self.faster_rcnn.backbone(images.tensors) + proposals, _ = self.faster_rcnn.rpn(images, features) + # post-process raw rpn predictions + all_regions = [box_ops.clip_boxes_to_image(boxes, images.tensors.shape[-2:]) for boxes in proposals] + all_regions = [scale_boxes(boxes, images.tensors.shape[-2:], batch.shape[-2:]) for boxes in all_regions] + all_scores = [torch.ones(boxes.shape[0]).to(boxes.device) for boxes in all_regions] + else: + msg = f"Unknown region extractor stage: {self.stage}" + raise ValueError(msg) + + regions = self.post_process_box_predictions(all_regions, all_scores) + + # convert from list of [N, 4] tensors to single [N, 5] tensor where each row is [index-in-batch, x1, y1, x2, y2] + indices = torch.repeat_interleave( + torch.arange(len(regions)), + torch.Tensor([rois.shape[0] for rois in regions]).int(), + ) + return torch.cat([indices.unsqueeze(1).to(batch.device), torch.cat(regions)], dim=1) + + def post_process_box_predictions(self, pred_boxes: torch.Tensor, pred_scores: torch.Tensor) -> list[torch.Tensor]: + """Post-processes the box predictions. + + The post-processing consists of removing small boxes, applying nms, and + keeping only the k boxes with the highest confidence score. + + Args: + pred_boxes (torch.Tensor): Box predictions of shape (N, 4). + pred_scores (torch.Tensor): torch.Tensor of shape () with a confidence score for each box prediction. + + Returns: + list[torch.Tensor]: Post-processed box predictions of shape (N, 4). + """ + processed_boxes_list: list[torch.Tensor] = [] + for boxes, scores in zip(pred_boxes, pred_scores, strict=True): + # remove small boxes + keep = box_ops.remove_small_boxes(boxes, min_size=self.min_size) + processed_boxes, processed_scores = boxes[keep], scores[keep] + + # non-maximum suppression, all boxes together + keep = box_ops.nms(processed_boxes, processed_scores, self.iou_threshold) + + # keep only top-k scoring predictions + keep = keep[: self.max_detections_per_image] + processed_boxes = processed_boxes[keep] + + processed_boxes_list.append(processed_boxes) + + return processed_boxes_list diff --git a/anomalib/models/image/rkde/torch_model.py b/anomalib/models/image/rkde/torch_model.py new file mode 100644 index 0000000000000000000000000000000000000000..b69cccad533e957c0f1195ac5259cc207131f0b6 --- /dev/null +++ b/anomalib/models/image/rkde/torch_model.py @@ -0,0 +1,115 @@ +"""Torch model for region-based anomaly detection.""" + +# Copyright (C) 2022-2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + + +import logging + +import torch +from torch import nn + +from anomalib.models.components.classification import FeatureScalingMethod, KDEClassifier + +from .feature_extractor import FeatureExtractor +from .region_extractor import RegionExtractor, RoiStage + +logger = logging.getLogger(__name__) + + +class RkdeModel(nn.Module): + """Torch Model for the Region-based Anomaly Detection Model. + + Args: + roi_stage (RoiStage, optional): Processing stage from which rois are extracted. + Defaults to ``RoiStage.RCNN``. + roi_score_threshold (float, optional): Mimumum confidence score for the region proposals. + Defaults to ``0.001``. + min_size (int, optional): Minimum size in pixels for the region proposals. + Defaults to ``25``. + iou_threshold (float, optional): Intersection-Over-Union threshold used during NMS. + Defaults to ``0.3``. + max_detections_per_image (int, optional): Maximum number of region proposals per image. + Defaults to ``100``. + n_pca_components (int, optional): Number of PCA components. + Defaults to ``16``. + feature_scaling_method (FeatureScalingMethod, optional): Scaling method applied to features before passing to + KDE. Options are `norm` (normalize to unit vector length) and `scale` (scale to max length observed in + training). + Defaults to ``FeatureScalingMethod.SCALE``. + max_training_points (int, optional): Maximum number of training points to fit the KDE model. + Defaults to ``40000``. + """ + + def __init__( + self, + # roi params + roi_stage: RoiStage = RoiStage.RCNN, + roi_score_threshold: float = 0.001, + min_box_size: int = 25, + iou_threshold: float = 0.3, + max_detections_per_image: int = 100, + # kde params + n_pca_components: int = 16, + feature_scaling_method: FeatureScalingMethod = FeatureScalingMethod.SCALE, + max_training_points: int = 40000, + ) -> None: + super().__init__() + + self.region_extractor = RegionExtractor( + stage=roi_stage, + score_threshold=roi_score_threshold, + min_size=min_box_size, + iou_threshold=iou_threshold, + max_detections_per_image=max_detections_per_image, + ).eval() + + self.feature_extractor = FeatureExtractor().eval() + + self.classifier = KDEClassifier( + n_pca_components=n_pca_components, + feature_scaling_method=feature_scaling_method, + max_training_points=max_training_points, + ) + + def fit(self, embeddings: torch.Tensor) -> bool: + """Fit the model using a set of collected embeddings. + + Args: + embeddings (torch.Tensor): Input embeddings to fit the model. + + Returns: + Boolean confirming whether the training is successful. + """ + return self.classifier.fit(embeddings) + + def forward(self, batch: torch.Tensor) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: + """Prediction by normality model. + + Args: + batch (torch.Tensor): Input images. + + Returns: + Tensor | tuple[torch.Tensor, torch.Tensor]: The extracted features (when in training mode), + or the predicted rois and corresponding anomaly scores. + """ + self.region_extractor.eval() + self.feature_extractor.eval() + + # 1. apply region extraction + rois = self.region_extractor(batch) + + # 2. apply feature extraction + if rois.shape[0] == 0: + # cannot extract features when no rois are retrieved + features = torch.empty((0, 4096)).to(batch.device) + else: + features = self.feature_extractor(batch, rois.clone()) + + if self.training: + return features + + # 3. apply density estimation + scores = self.classifier(features) + + return rois, scores diff --git a/anomalib/models/image/stfpm/README.md b/anomalib/models/image/stfpm/README.md new file mode 100644 index 0000000000000000000000000000000000000000..6b83dad775f4f7ff8b1769ebbf87e2ff819a4981 --- /dev/null +++ b/anomalib/models/image/stfpm/README.md @@ -0,0 +1,54 @@ +# Student-Teacher Feature Pyramid Matching for Unsupervised Anomaly Detection + +This is the implementation of the [STFPM](https://arxiv.org/pdf/2103.04257.pdf) paper. + +Model Type: Segmentation + +## Description + +STFPM algorithm consists of a pre-trained teacher network and a student network with identical architecture. The student network learns the distribution of anomaly-free images by matching the features with the counterpart features in the teacher network. Multi-scale feature matching is used to enhance robustness. This hierarchical feature matching enables the student network to receive a mixture of multi-level knowledge from the feature pyramid thus allowing for anomaly detection of various sizes. + +During inference, the feature pyramids of teacher and student networks are compared. Larger difference indicates a higher probability of anomaly occurrence. + +## Architecture + +![STFPM Architecture](/docs/source/images/stfpm/architecture.jpg "STFPM Architecture") + +## Usage + +`python tools/train.py --model stfpm` + +## Benchmark + +All results gathered with seed `42`. + +## [MVTec AD Dataset](https://www.mvtec.com/company/research/datasets/mvtec-ad) + +### Image-Level AUC + +| | Avg | Carpet | Grid | Leather | Tile | Wood | Bottle | Cable | Capsule | Hazelnut | Metal Nut | Pill | Screw | Toothbrush | Transistor | Zipper | +| -------------- | :---: | :----: | :---: | :-----: | :---: | :---: | :----: | :---: | :-----: | :------: | :-------: | :---: | :---: | :--------: | :--------: | :----: | +| ResNet-18 | 0.893 | 0.954 | 0.982 | 0.989 | 0.949 | 0.961 | 0.979 | 0.838 | 0.759 | 0.999 | 0.956 | 0.705 | 0.835 | 0.997 | 0.853 | 0.645 | +| Wide ResNet-50 | 0.876 | 0.957 | 0.977 | 0.981 | 0.976 | 0.939 | 0.987 | 0.878 | 0.732 | 0.995 | 0.973 | 0.652 | 0.825 | 0.5 | 0.875 | 0.899 | + +### Pixel-Level AUC + +| | Avg | Carpet | Grid | Leather | Tile | Wood | Bottle | Cable | Capsule | Hazelnut | Metal Nut | Pill | Screw | Toothbrush | Transistor | Zipper | +| -------------- | :---: | :----: | :---: | :-----: | :---: | :---: | :----: | :---: | :-----: | :------: | :-------: | :---: | :---: | :--------: | :--------: | :----: | +| ResNet-18 | 0.951 | 0.986 | 0.988 | 0.991 | 0.946 | 0.949 | 0.971 | 0.898 | 0.962 | 0.981 | 0.942 | 0.878 | 0.983 | 0.983 | 0.838 | 0.972 | +| Wide ResNet-50 | 0.903 | 0.987 | 0.989 | 0.980 | 0.966 | 0.956 | 0.966 | 0.913 | 0.956 | 0.974 | 0.961 | 0.946 | 0.988 | 0.178 | 0.807 | 0.980 | + +### Image F1 Score + +| | Avg | Carpet | Grid | Leather | Tile | Wood | Bottle | Cable | Capsule | Hazelnut | Metal Nut | Pill | Screw | Toothbrush | Transistor | Zipper | +| -------------- | :---: | :----: | :---: | :-----: | :---: | :---: | :----: | :---: | :-----: | :------: | :-------: | :---: | :---: | :--------: | :--------: | :----: | +| ResNet-18 | 0.932 | 0.961 | 0.982 | 0.989 | 0.930 | 0.951 | 0.984 | 0.819 | 0.918 | 0.993 | 0.973 | 0.918 | 0.887 | 0.984 | 0.790 | 0.908 | +| Wide ResNet-50 | 0.926 | 0.973 | 0.973 | 0.974 | 0.965 | 0.929 | 0.976 | 0.853 | 0.920 | 0.972 | 0.974 | 0.922 | 0.884 | 0.833 | 0.815 | 0.931 | + +### Sample Results + +![Sample Result 1](/docs/source/images/stfpm/results/0.png "Sample Result 1") + +![Sample Result 2](/docs/source/images/stfpm/results/1.png "Sample Result 2") + +![Sample Result 3](/docs/source/images/stfpm/results/2.png "Sample Result 3") diff --git a/anomalib/models/image/stfpm/__init__.py b/anomalib/models/image/stfpm/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..049695a63e5409d48c1bb414b9909576a71bf1ce --- /dev/null +++ b/anomalib/models/image/stfpm/__init__.py @@ -0,0 +1,8 @@ +"""STFPM Model.""" + +# Copyright (C) 2022-2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +from .lightning_model import Stfpm + +__all__ = ["Stfpm"] diff --git a/anomalib/models/image/stfpm/anomaly_map.py b/anomalib/models/image/stfpm/anomaly_map.py new file mode 100644 index 0000000000000000000000000000000000000000..1e44feae961e12fa2a1e02a23ab9cbb5dc9d8a1a --- /dev/null +++ b/anomalib/models/image/stfpm/anomaly_map.py @@ -0,0 +1,95 @@ +"""Anomaly Map Generator for the STFPM model implementation.""" + +# Copyright (C) 2022-2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + + +import torch +from torch import nn +from torch.nn import functional as F # noqa: N812 + + +class AnomalyMapGenerator(nn.Module): + """Generate Anomaly Heatmap.""" + + def __init__(self) -> None: + super().__init__() + self.distance = torch.nn.PairwiseDistance(p=2, keepdim=True) + + def compute_layer_map( + self, + teacher_features: torch.Tensor, + student_features: torch.Tensor, + image_size: tuple[int, int] | torch.Size, + ) -> torch.Tensor: + """Compute the layer map based on cosine similarity. + + Args: + teacher_features (torch.Tensor): Teacher features + student_features (torch.Tensor): Student features + image_size (tuple[int, int]): Image size to which the anomaly map should be resized. + + Returns: + Anomaly score based on cosine similarity. + """ + norm_teacher_features = F.normalize(teacher_features) + norm_student_features = F.normalize(student_features) + + layer_map = 0.5 * torch.norm(norm_teacher_features - norm_student_features, p=2, dim=-3, keepdim=True) ** 2 + return F.interpolate(layer_map, size=image_size, align_corners=False, mode="bilinear") + + def compute_anomaly_map( + self, + teacher_features: dict[str, torch.Tensor], + student_features: dict[str, torch.Tensor], + image_size: tuple[int, int] | torch.Size, + ) -> torch.Tensor: + """Compute the overall anomaly map via element-wise production the interpolated anomaly maps. + + Args: + teacher_features (dict[str, torch.Tensor]): Teacher features + student_features (dict[str, torch.Tensor]): Student features + image_size (tuple[int, int]): Image size to which the anomaly map should be resized. + + Returns: + Final anomaly map + """ + batch_size = next(iter(teacher_features.values())).shape[0] + anomaly_map = torch.ones(batch_size, 1, image_size[0], image_size[1]) + for layer in teacher_features: + layer_map = self.compute_layer_map(teacher_features[layer], student_features[layer], image_size) + anomaly_map = anomaly_map.to(layer_map.device) + anomaly_map *= layer_map + + return anomaly_map + + def forward(self, **kwargs: dict[str, torch.Tensor]) -> torch.Tensor: + """Return anomaly map. + + Expects `teach_features` and `student_features` keywords to be passed explicitly. + + Args: + kwargs (dict[str, torch.Tensor]): Keyword arguments + + Example: + >>> anomaly_map_generator = AnomalyMapGenerator(image_size=tuple(hparams.model.input_size)) + >>> output = self.anomaly_map_generator( + teacher_features=teacher_features, + student_features=student_features + ) + + Raises: + ValueError: `teach_features` and `student_features` keys are not found + + Returns: + torch.Tensor: anomaly map + """ + if not ("teacher_features" in kwargs and "student_features" in kwargs): + msg = f"Expected keys `teacher_features` and `student_features. Found {kwargs.keys()}" + raise ValueError(msg) + + teacher_features: dict[str, torch.Tensor] = kwargs["teacher_features"] + student_features: dict[str, torch.Tensor] = kwargs["student_features"] + image_size: tuple[int, int] | torch.Size = kwargs["image_size"] + + return self.compute_anomaly_map(teacher_features, student_features, image_size) diff --git a/anomalib/models/image/stfpm/lightning_model.py b/anomalib/models/image/stfpm/lightning_model.py new file mode 100644 index 0000000000000000000000000000000000000000..59cc5df98d403e97ef8043a7ceb54a846e830a8f --- /dev/null +++ b/anomalib/models/image/stfpm/lightning_model.py @@ -0,0 +1,114 @@ +"""STFPM: Student-Teacher Feature Pyramid Matching for Unsupervised Anomaly Detection. + +https://arxiv.org/abs/2103.04257 +""" + +# Copyright (C) 2022-2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +from collections.abc import Sequence +from typing import Any + +import torch +from lightning.pytorch.utilities.types import STEP_OUTPUT +from torch import optim + +from anomalib import LearningType +from anomalib.models.components import AnomalyModule + +from .loss import STFPMLoss +from .torch_model import STFPMModel + +__all__ = ["Stfpm"] + + +class Stfpm(AnomalyModule): + """PL Lightning Module for the STFPM algorithm. + + Args: + backbone (str): Backbone CNN network + Defaults to ``resnet18``. + layers (list[str]): Layers to extract features from the backbone CNN + Defaults to ``["layer1", "layer2", "layer3"]``. + """ + + def __init__( + self, + backbone: str = "resnet18", + layers: Sequence[str] = ("layer1", "layer2", "layer3"), + ) -> None: + super().__init__() + + self.model = STFPMModel( + backbone=backbone, + layers=layers, + ) + self.loss = STFPMLoss() + + def training_step(self, batch: dict[str, str | torch.Tensor], *args, **kwargs) -> STEP_OUTPUT: + """Perform a training step of STFPM. + + For each batch, teacher and student and teacher features are extracted from the CNN. + + Args: + batch (dict[str, str | torch.Tensor]): Input batch. + args: Additional arguments. + kwargs: Additional keyword arguments. + + Returns: + Loss value + """ + del args, kwargs # These variables are not used. + + teacher_features, student_features = self.model.forward(batch["image"]) + loss = self.loss(teacher_features, student_features) + self.log("train_loss", loss.item(), on_epoch=True, prog_bar=True, logger=True) + return {"loss": loss} + + def validation_step(self, batch: dict[str, str | torch.Tensor], *args, **kwargs) -> STEP_OUTPUT: + """Perform a validation Step of STFPM. + + Similar to the training step, student/teacher features are extracted from the CNN for each batch, and + anomaly map is computed. + + Args: + batch (dict[str, str | torch.Tensor]): Input batch + args: Additional arguments + kwargs: Additional keyword arguments + + Returns: + Dictionary containing images, anomaly maps, true labels and masks. + These are required in `validation_epoch_end` for feature concatenation. + """ + del args, kwargs # These variables are not used. + + batch["anomaly_maps"] = self.model(batch["image"]) + return batch + + @property + def trainer_arguments(self) -> dict[str, Any]: + """Required trainer arguments.""" + return {"gradient_clip_val": 0, "num_sanity_val_steps": 0} + + def configure_optimizers(self) -> torch.optim.Optimizer: + """Configure optimizers. + + Returns: + Optimizer: SGD optimizer + """ + return optim.SGD( + params=self.model.student_model.parameters(), + lr=0.4, + momentum=0.9, + dampening=0.0, + weight_decay=0.001, + ) + + @property + def learning_type(self) -> LearningType: + """Return the learning type of the model. + + Returns: + LearningType: Learning type of the model. + """ + return LearningType.ONE_CLASS diff --git a/anomalib/models/image/stfpm/loss.py b/anomalib/models/image/stfpm/loss.py new file mode 100644 index 0000000000000000000000000000000000000000..df1e5012b0d670d4ec90d79b02154dc506de13f1 --- /dev/null +++ b/anomalib/models/image/stfpm/loss.py @@ -0,0 +1,71 @@ +"""Loss function for the STFPM Model Implementation.""" + +# Copyright (C) 2022-2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + + +import torch +from torch import nn +from torch.nn import functional as F # noqa: N812 + + +class STFPMLoss(nn.Module): + """Feature Pyramid Loss This class implmenents the feature pyramid loss function proposed in STFPM paper. + + Example: + >>> from anomalib.models.components.feature_extractors import TimmFeatureExtractor + >>> from anomalib.models.stfpm.loss import STFPMLoss + >>> from torchvision.models import resnet18 + + >>> layers = ['layer1', 'layer2', 'layer3'] + >>> teacher_model = TimmFeatureExtractor(model=resnet18(pretrained=True), layers=layers) + >>> student_model = TimmFeatureExtractor(model=resnet18(pretrained=False), layers=layers) + >>> loss = Loss() + + >>> inp = torch.rand((4, 3, 256, 256)) + >>> teacher_features = teacher_model(inp) + >>> student_features = student_model(inp) + >>> loss(student_features, teacher_features) + tensor(51.2015, grad_fn=) + """ + + def __init__(self) -> None: + super().__init__() + self.mse_loss = nn.MSELoss(reduction="sum") + + def compute_layer_loss(self, teacher_feats: torch.Tensor, student_feats: torch.Tensor) -> torch.Tensor: + """Compute layer loss based on Equation (1) in Section 3.2 of the paper. + + Args: + teacher_feats (torch.Tensor): Teacher features + student_feats (torch.Tensor): Student features + + Returns: + L2 distance between teacher and student features. + """ + height, width = teacher_feats.shape[2:] + + norm_teacher_features = F.normalize(teacher_feats) + norm_student_features = F.normalize(student_feats) + return (0.5 / (width * height)) * self.mse_loss(norm_teacher_features, norm_student_features) + + def forward( + self, + teacher_features: dict[str, torch.Tensor], + student_features: dict[str, torch.Tensor], + ) -> torch.Tensor: + """Compute the overall loss via the weighted average of the layer losses computed by the cosine similarity. + + Args: + teacher_features (dict[str, torch.Tensor]): Teacher features + student_features (dict[str, torch.Tensor]): Student features + + Returns: + Total loss, which is the weighted average of the layer losses. + """ + layer_losses: list[torch.Tensor] = [] + for layer in teacher_features: + loss = self.compute_layer_loss(teacher_features[layer], student_features[layer]) + layer_losses.append(loss) + + return torch.stack(layer_losses).sum() diff --git a/anomalib/models/image/stfpm/torch_model.py b/anomalib/models/image/stfpm/torch_model.py new file mode 100644 index 0000000000000000000000000000000000000000..5b80a6ec7a346c5fe26d752d0a704c1e39fb851e --- /dev/null +++ b/anomalib/models/image/stfpm/torch_model.py @@ -0,0 +1,85 @@ +"""PyTorch model for the STFPM model implementation.""" + +# Copyright (C) 2022-2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +from collections.abc import Sequence +from typing import TYPE_CHECKING + +import torch +from torch import nn + +from anomalib.models.components import TimmFeatureExtractor + +from .anomaly_map import AnomalyMapGenerator + +if TYPE_CHECKING: + from anomalib.data.utils.tiler import Tiler + + +class STFPMModel(nn.Module): + """STFPM: Student-Teacher Feature Pyramid Matching for Unsupervised Anomaly Detection. + + Args: + layers (list[str]): Layers used for feature extraction. + backbone (str, optional): Pre-trained model backbone. + Defaults to ``resnet18``. + """ + + def __init__( + self, + layers: Sequence[str], + backbone: str = "resnet18", + ) -> None: + super().__init__() + self.tiler: Tiler | None = None + + self.backbone = backbone + self.teacher_model = TimmFeatureExtractor(backbone=self.backbone, pre_trained=True, layers=layers).eval() + self.student_model = TimmFeatureExtractor( + backbone=self.backbone, + pre_trained=False, + layers=layers, + requires_grad=True, + ) + + # teacher model is fixed + for parameters in self.teacher_model.parameters(): + parameters.requires_grad = False + + self.anomaly_map_generator = AnomalyMapGenerator() + + def forward(self, images: torch.Tensor) -> torch.Tensor | dict[str, torch.Tensor] | tuple[dict[str, torch.Tensor]]: + """Forward-pass images into the network. + + During the training mode the model extracts the features from the teacher and student networks. + During the evaluation mode, it returns the predicted anomaly map. + + Args: + images (torch.Tensor): Batch of images. + + Returns: + Teacher and student features when in training mode, otherwise the predicted anomaly maps. + """ + output_size = images.shape[-2:] + if self.tiler: + images = self.tiler.tile(images) + teacher_features: dict[str, torch.Tensor] = self.teacher_model(images) + student_features: dict[str, torch.Tensor] = self.student_model(images) + + if self.tiler: + for layer, data in teacher_features.items(): + teacher_features[layer] = self.tiler.untile(data) + for layer, data in student_features.items(): + student_features[layer] = self.tiler.untile(data) + + if self.training: + output = teacher_features, student_features + else: + output = self.anomaly_map_generator( + teacher_features=teacher_features, + student_features=student_features, + image_size=output_size, + ) + + return output diff --git a/anomalib/models/image/uflow/README.md b/anomalib/models/image/uflow/README.md new file mode 100644 index 0000000000000000000000000000000000000000..455bee1eb275c56f42fad4f7176079dfb2417e36 --- /dev/null +++ b/anomalib/models/image/uflow/README.md @@ -0,0 +1,128 @@ +# U-Flow: A U-shaped Normalizing Flow for Anomaly Detection with Unsupervised Threshold + +[//]: # "This is the implementation of the [U-Flow](https://arxiv.org/abs/2211.12353) paper, based on the [original code](https://www.github.com/mtailanian/uflow)" + +This is the implementation of the [U-Flow](https://www.researchsquare.com/article/rs-3367286/latest) paper, based on the [original code](https://www.github.com/mtailanian/uflow) + +![U-Flow Architecture](/docs/source/images/uflow/diagram.png "U-Flow Architecture") + +## Abstract + +_In this work we propose a one-class self-supervised method for anomaly segmentation in images, that benefits both from a modern machine learning approach and a more classic statistical detection theory. +The method consists of three phases. First, features are extracted using a multi-scale image Transformer architecture. Then, these features are fed into a U-shaped Normalizing Flow that lays the theoretical foundations for the last phase, which computes a pixel-level anomaly map and performs a segmentation based on the a contrario framework. +This multiple-hypothesis testing strategy permits the derivation of robust automatic detection thresholds, which are crucial in real-world applications where an operational point is needed. +The segmentation results are evaluated using the Intersection over Union (IoU) metric, and for assessing the generated anomaly maps we report the area under the Receiver Operating Characteristic curve (AUROC), and the area under the per-region-overlap curve (AUPRO). +Extensive experimentation in various datasets shows that the proposed approach produces state-of-the-art results for all metrics and all datasets, ranking first in most MvTec-AD categories, with a mean pixel-level AUROC of 98.74%._ + +![Teaser image](/docs/source/images/uflow/teaser.jpg) + +## Localization results + +### Pixel AUROC over MVTec-AD Dataset + +![Pixel-AUROC results](/docs/source/images/uflow/pixel-auroc.png "Pixel-AUROC results") + +### Pixel AUPRO over MVTec-AD Dataset + +![Pixel-AUPRO results](/docs/source/images/uflow/pixel-aupro.png "Pixel-AUPRO results") + +## Segmentation results (IoU) with threshold log(NFA)=0 + +This paper also proposes a method to automatically compute the threshold using the a contrario framework. All results below are obtained with the threshold log(NFA)=0. +In the default code here, for the sake of comparison with all the other methods of the library, the segmentation is done computing the threshold over the anomaly map at train time. +Nevertheless, the code for computing the segmentation mask with the NFA criterion is included in the `src/anomalib/models/uflow/anomaly_map.py`. + +![IoU results](/docs/source/images/uflow/iou.png "IoU results") + +## Results over other datasets + +![Results over other datasets](/docs/source/images/uflow/more-results.png "Results over other datasets") + +## Benchmarking + +Note that the proposed method uses the MCait Feature Extractor, which has an input size of 448x448. In the benchmarking, a size of 256x256 is used for all methods, and therefore the results may differ from those reported. In order to exactly reproduce all results, the reader can refer to the original code (see [here](https://www.github.com/mtailanian/uflow), where the configs used and even the trained checkpoints can be downloaded from [this release](https://github.com/mtailanian/uflow/releases/tag/trained-models-for-all-mvtec-categories). + +## Reproducing paper's results + +Using the default parameters of the config file (`src/anomalib/models/uflow/config.yaml`), the results obtained are very close to the ones reported in the paper: + +bottle: 97.98, cable: 98.17, capsule: 98.95, carpet: 99.45, grid: 98.19, hazelnut: 99.01, leather: 99.41, metal_nut: 98.19, pill: 99.15, screw: 99.25, tile: 96.93, toothbrush: 98.97, transistor: 96.70, wood: 96.87, zipper: 97.92 + +In order to obtain the same exact results, although the architecture parameters stays always the same, the following values for the learning rate and batch size should be used (please refer to the [original code](https://www.github.com/mtailanian/uflow) for more details, where the used configs are available in the source code ([here](https://github.com/mtailanian/uflow/tree/main/configs)), and trained checkpoints are available in [this release](https://github.com/mtailanian/uflow/releases/tag/trained-models-for-all-mvtec-categories)): + +## Usage + +`python tools/train.py --model uflow` + +## Download data + +### MVTec + +https://www.mvtec.com/company/research/datasets/mvtec-ad + +### Bean Tech + +https://paperswithcode.com/dataset/btad + +### LGG MRI + +https://www.kaggle.com/datasets/mateuszbuda/lgg-mri-segmentation + +### ShanghaiTech Campus + +https://svip-lab.github.io/dataset/campus_dataset.html + +## [Optional] Download pre-trained models + +Pre-trained models can be found in [this release](https://github.com/mtailanian/uflow/tree/main/configs), or can be downloaded from [Google Drive](https://drive.google.com/drive/folders/1W1rE0mu4Lv3uWHA5GZigmvVNlBVHqTv_?usp=sharing) + +For an easier way of downloading them, please refer to the `README.md` from the [original code](https://www.github.com/mtailanian/uflow) + +For reproducing the exact results from the paper, different learning rates and batch sizes are to be used for each category. You can find the exact values in the `configs` folder, following the [previous link](https://drive.google.com/drive/folders/1W1rE0mu4Lv3uWHA5GZigmvVNlBVHqTv_?usp=sharing). + +## A note on sizes at different points + +Input + +```text +- Scale 1: [3, 448, 448] +- Scale 2: [3, 224, 224] +``` + +MS-Cait outputs + +```text +- Scale 1: [768, 28, 28] +- Scale 2: [384, 14, 14] +``` + +Normalizing Flow outputs + +```text +- Scale 1: [816, 28, 28] --> 816 = 768 + 384 / 2 / 4 +- Scale 2: [192, 14, 14] --> 192 = 384 / 2 +``` + +`/ 2` corresponds to the split, and `/ 4` to the invertible upsample. + +## Example results + +### Anomalies + +#### MVTec + +![MVTec results - anomalies](/docs/source/images/uflow/results-mvtec-anomalies.jpg "MVTec results - anomalies") + +#### BeanTech, LGG MRI, STC + +![BeanTech, LGG MRI, STC results - anomalies](/docs/source/images/uflow/results-others-anomalies.jpg "BeanTech, LGG MRI, STC results - anomalies") + +### Normal images + +#### MVTec + +![MVTec results - normal](/docs/source/images/uflow/results-mvtec-good.jpg "MVTec results - normal") + +#### BeanTech, LGG MRI, STC + +![BeanTech, LGG MRI, STC results - normal](/docs/source/images/uflow/results-others-good.jpg "BeanTech, LGG MRI, STC results - normal") diff --git a/anomalib/models/image/uflow/__init__.py b/anomalib/models/image/uflow/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..653f7835fa2fd335f2c81e3b37befc45cbefac55 --- /dev/null +++ b/anomalib/models/image/uflow/__init__.py @@ -0,0 +1,8 @@ +"""U-Flow: A U-shaped Normalizing Flow for Anomaly Detection with Unsupervised Threshold.""" + +# Copyright (C) 2023-2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +from .lightning_model import Uflow + +__all__ = ["Uflow"] diff --git a/anomalib/models/image/uflow/anomaly_map.py b/anomalib/models/image/uflow/anomaly_map.py new file mode 100644 index 0000000000000000000000000000000000000000..4f445f440cca4fa6ca23de997032091b76ddeb30 --- /dev/null +++ b/anomalib/models/image/uflow/anomaly_map.py @@ -0,0 +1,169 @@ +"""UFlow Anomaly Map Generator Implementation.""" + +# Copyright (C) 2023-2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + + +import numpy as np +import scipy.stats as st +import torch +import torch.nn.functional as F # noqa: N812 +from mpmath import binomial, mp +from omegaconf import ListConfig +from scipy import integrate +from torch import Tensor, nn + +mp.dps = 15 # Set precision for NFA computation (in case of high_precision=True) + + +class AnomalyMapGenerator(nn.Module): + """Generate Anomaly Heatmap and segmentation.""" + + def __init__(self, input_size: ListConfig | tuple) -> None: + super().__init__() + self.input_size = input_size if isinstance(input_size, tuple) else tuple(input_size) + + def forward(self, latent_variables: list[Tensor]) -> Tensor: + """Return anomaly map.""" + return self.compute_anomaly_map(latent_variables) + + def compute_anomaly_map(self, latent_variables: list[Tensor]) -> Tensor: + """Generate a likelihood-based anomaly map, from latent variables. + + Args: + latent_variables: List of latent variables from the UFlow model. Each element is a tensor of shape + (N, Cl, Hl, Wl), where N is the batch size, Cl is the number of channels, and Hl and Wl are the height and + width of the latent variables, respectively, for each scale l. + + Returns: + Final Anomaly Map. Tensor of shape (N, 1, H, W), where N is the batch size, and H and W are the height and + width of the input image, respectively. + """ + likelihoods = [] + for z in latent_variables: + # Mean prob by scale. Likelihood is actually with sum instead of mean. Using mean to avoid numerical issues. + # Also, this way all scales have the same weight, and it does not depend on the number of channels + log_prob_i = -torch.mean(z**2, dim=1, keepdim=True) * 0.5 + prob_i = torch.exp(log_prob_i) + likelihoods.append( + F.interpolate( + prob_i, + size=self.input_size, + mode="bilinear", + align_corners=False, + ), + ) + return 1 - torch.mean(torch.stack(likelihoods, dim=-1), dim=-1) + + def compute_anomaly_mask( + self, + z: list[torch.Tensor], + window_size: int = 7, + binomial_probability_thr: float = 0.5, + high_precision: bool = False, + ) -> torch.Tensor: + """This method is not used in the basic functionality of training and testing. + + It is a bit slow, so we decided to + leave it as an option for the user. It is included as it is part of the U-Flow paper, and can be called + separately if an unsupervised anomaly segmentation is needed. + + Generate an anomaly mask, from latent variables. It is based on the NFA (Number of False Alarms) method, which + is a statistical method to detect anomalies. The NFA is computed as the log of the probability of the null + hypothesis, which is that all pixels are normal. First, we compute a list of candidate pixels, with + suspiciously high values of z^2, by applying a binomial test to each pixel, looking at a window around it. + Then, to compute the NFA values (actually the log-NFA), we evaluate how probable is that a pixel belongs to the + normal distribution. The null-hypothesis is that under normality assumptions, all candidate pixels are uniformly + distributed. Then, the detection is based on the concentration of candidate pixels. + + Args: + z (list[torch.Tensor]): List of latent variables from the UFlow model. Each element is a tensor of shape + (N, Cl, Hl, Wl), where N is the batch size, Cl is the number of channels, and Hl and Wl are the height + and width of the latent variables, respectively, for each scale l. + window_size (int): Window size for the binomial test. Defaults to 7. + binomial_probability_thr (float): Probability threshold for the binomial test. Defaults to 0.5 + high_precision (bool): Whether to use high precision for the binomial test. Defaults to False. + + Returns: + Anomaly mask. Tensor of shape (N, 1, H, W), where N is the batch size, and H and W are the height and + width of the input image, respectively. + """ + log_prob_l = [ + self.binomial_test(zi, window_size / (2**scale), binomial_probability_thr, high_precision) + for scale, zi in enumerate(z) + ] + + log_prob_l_up = torch.cat( + [F.interpolate(lpl, size=self.input_size, mode="bicubic", align_corners=True) for lpl in log_prob_l], + dim=1, + ) + + log_prob = torch.sum(log_prob_l_up, dim=1, keepdim=True) + + log_number_of_tests = torch.log10(torch.sum(torch.tensor([zi.shape[-2] * zi.shape[-1] for zi in z]))) + log_nfa = log_number_of_tests + log_prob + + anomaly_score = -log_nfa + + return anomaly_score < 0 + + @staticmethod + def binomial_test( + z: torch.Tensor, + window_size: int, + probability_thr: float, + high_precision: bool = False, + ) -> torch.Tensor: + """The binomial test applied to validate or reject the null hypothesis that the pixel is normal. + + The null hypothesis is that the pixel is normal, and the alternative hypothesis is that the pixel is anomalous. + The binomial test is applied to a window around the pixel, and the number of pixels in the window that ares + anomalous is compared to the number of pixels that are expected to be anomalous under the null hypothesis. + + Args: + z: Latent variable from the UFlow model. Tensor of shape (N, Cl, Hl, Wl), where N is the batch size, Cl is + the number of channels, and Hl and Wl are the height and width of the latent variables, respectively. + window_size (int): Window size for the binomial test. + probability_thr: Probability threshold for the binomial test. + high_precision: Whether to use high precision for the binomial test. + + Returns: + Log of the probability of the null hypothesis. + + """ + tau = st.chi2.ppf(probability_thr, 1) + half_win = np.max([int(window_size // 2), 1]) + + n_chann = z.shape[1] + + # Candidates + z2 = F.pad(z**2, tuple(4 * [half_win]), "reflect").detach().cpu() + z2_unfold_h = z2.unfold(-2, 2 * half_win + 1, 1) + z2_unfold_hw = z2_unfold_h.unfold(-2, 2 * half_win + 1, 1).numpy() + observed_candidates_k = np.sum(z2_unfold_hw >= tau, axis=(-2, -1)) + + # All volume together + observed_candidates = np.sum(observed_candidates_k, axis=1, keepdims=True) + x = observed_candidates / n_chann + n = int((2 * half_win + 1) ** 2) + + # Low precision + if not high_precision: + log_prob = torch.tensor(st.binom.logsf(x, n, 1 - probability_thr) / np.log(10)) + # High precision - good and slow + else: + to_mp = np.frompyfunc(mp.mpf, 1, 1) + mpn = mp.mpf(n) + mpp = probability_thr + + def binomial_density(tensor: torch.tensor) -> torch.Tensor: + return binomial(mpn, to_mp(tensor)) * (1 - mpp) ** tensor * mpp ** (mpn - tensor) + + def integral(tensor: torch.Tensor) -> torch.Tensor: + return integrate.quad(binomial_density, tensor, n)[0] + + integral_array = np.vectorize(integral) + prob = integral_array(x) + log_prob = torch.tensor(np.log10(prob)) + + return log_prob diff --git a/anomalib/models/image/uflow/feature_extraction.py b/anomalib/models/image/uflow/feature_extraction.py new file mode 100644 index 0000000000000000000000000000000000000000..1e6385fc4d09f974cb022d4f3e8ba21b7612f1a8 --- /dev/null +++ b/anomalib/models/image/uflow/feature_extraction.py @@ -0,0 +1,170 @@ +"""Feature Extractor for U-Flow model.""" + +# Copyright (C) 2023-2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +from collections.abc import Iterable + +import timm +import torch +import torch.nn.functional as F # noqa: N812 +from torch import nn + +from anomalib.models.components.feature_extractors import TimmFeatureExtractor + +AVAILABLE_EXTRACTORS = ["mcait", "resnet18", "wide_resnet50_2"] + + +def get_feature_extractor(backbone: str, input_size: tuple[int, int] = (256, 256)) -> nn.Module: + """Get feature extractor. Currently, is restricted to AVAILABLE_EXTRACTORS. + + Args: + backbone (str): Backbone name. + input_size (tuple[int, int]): Input size. + + Raises: + ValueError if unknown backbone is provided. + + Returns: + FeatureExtractorInterface: Feature extractor. + """ + if backbone not in AVAILABLE_EXTRACTORS: + msg = f"Feature extractor must be one of {AVAILABLE_EXTRACTORS}." + raise ValueError(msg) + + feature_extractor: nn.Module + if backbone in ["resnet18", "wide_resnet50_2"]: + feature_extractor = FeatureExtractor(backbone, input_size, layers=("layer1", "layer2", "layer3")).eval() + if backbone == "mcait": + feature_extractor = MCaitFeatureExtractor().eval() + + return feature_extractor + + +class FeatureExtractor(TimmFeatureExtractor): + """Feature extractor based on ResNet (or others) backbones. + + Args: + backbone (str): Backbone of the feature extractor. + input_size (tuple[int, int]): Input image size used for computing normalization layers. + layers (tuple[str], optional): Layers from which to extract features. + Defaults to ("layer1", "layer2", "layer3"). + """ + + def __init__( + self, + backbone: str, + input_size: tuple[int, int], + layers: tuple[str, ...] = ("layer1", "layer2", "layer3"), + **kwargs, # noqa: ARG002 | unused argument + ) -> None: + super().__init__(backbone, layers, pre_trained=True, requires_grad=False) + self.channels = self.feature_extractor.feature_info.channels() + self.scale_factors = self.feature_extractor.feature_info.reduction() + self.scales = range(len(self.scale_factors)) + + self.feature_normalizations = nn.ModuleList() + for in_channels, scale in zip(self.channels, self.scale_factors, strict=True): + self.feature_normalizations.append( + nn.LayerNorm( + [in_channels, int(input_size[0] / scale), int(input_size[1] / scale)], + elementwise_affine=True, + ), + ) + + for param in self.feature_extractor.parameters(): + param.requires_grad = False + + def forward(self, img: torch.Tensor) -> torch.Tensor: + """Normalized features.""" + features = self.extract_features(img) + return self.normalize_features(features) + + def extract_features(self, img: torch.Tensor) -> torch.Tensor: + """Extract features.""" + self.feature_extractor.eval() + return self.feature_extractor(img) + + def normalize_features(self, features: Iterable[torch.Tensor]) -> list[torch.Tensor]: + """Normalize features.""" + return [self.feature_normalizations[i](feature) for i, feature in enumerate(features)] + + +class MCaitFeatureExtractor(nn.Module): + """Feature extractor based on MCait backbone. + + This is the proposed feature extractor in the paper. It uses two + independently trained Cait models, at different scales, with input sizes 448 and 224, respectively. + It also includes a normalization layer for each scale. + """ + + def __init__(self) -> None: + super().__init__() + self.input_size = 448 + self.extractor1 = timm.create_model("cait_m48_448", pretrained=True) + self.extractor2 = timm.create_model("cait_s24_224", pretrained=True) + self.channels = [768, 384] + self.scale_factors = [16, 32] + self.scales = range(len(self.scale_factors)) + + for param in self.extractor1.parameters(): + param.requires_grad = False + for param in self.extractor2.parameters(): + param.requires_grad = False + + def forward(self, img: torch.Tensor, training: bool = True) -> torch.Tensor: + """Return normalized features.""" + features = self.extract_features(img) + return self.normalize_features(features, training=training) + + def extract_features(self, img: torch.Tensor, **kwargs) -> tuple[torch.Tensor, torch.Tensor]: # noqa: ARG002 | unused argument + """Extract features from ``img`` from the two extractors. + + Args: + img (torch.Tensor): Input image + kwargs: unused + + Returns: + tuple[torch.Tensor, torch.Tensor]: Features from the two extractors. + """ + self.extractor1.eval() + self.extractor2.eval() + + # Scale 1 --> Extractor 1 + x1 = self.extractor1.patch_embed(img) + x1 = x1 + self.extractor1.pos_embed + x1 = self.extractor1.pos_drop(x1) + for i in range(41): # paper Table 6. Block Index = 40 + x1 = self.extractor1.blocks[i](x1) + + # Scale 2 --> Extractor 2 + img_sub = F.interpolate(torch.Tensor(img), size=(224, 224), mode="bicubic", align_corners=True) + x2 = self.extractor2.patch_embed(img_sub) + x2 = x2 + self.extractor2.pos_embed + x2 = self.extractor2.pos_drop(x2) + for i in range(21): + x2 = self.extractor2.blocks[i](x2) + + return (x1, x2) + + def normalize_features(self, features: torch.Tensor, **kwargs) -> torch.Tensor: # noqa: ARG002 | unused argument + """Normalize features. + + Args: + features (torch.Tensor): Features to normalize. + **kwargs: unused + + Returns: + torch.Tensor: Normalized features. + """ + normalized_features = [] + for i, extractor in enumerate([self.extractor1, self.extractor2]): + batch, _, channels = features[i].shape + scale_factor = self.scale_factors[i] + + x = extractor.norm(features[i].contiguous()) + x = x.permute(0, 2, 1) + x = x.reshape(batch, channels, self.input_size // scale_factor, self.input_size // scale_factor) + normalized_features.append(x) + + return normalized_features diff --git a/anomalib/models/image/uflow/lightning_model.py b/anomalib/models/image/uflow/lightning_model.py new file mode 100644 index 0000000000000000000000000000000000000000..cbc24e27175c21e7e962db7240784ae522f26667 --- /dev/null +++ b/anomalib/models/image/uflow/lightning_model.py @@ -0,0 +1,130 @@ +"""U-Flow: A U-shaped Normalizing Flow for Anomaly Detection with Unsupervised Threshold. + +https://arxiv.org/pdf/2211.12353.pdf +""" + +# Copyright (C) 2023-2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +import logging +from typing import Any + +import torch +from lightning.pytorch.core.optimizer import LightningOptimizer +from lightning.pytorch.utilities.types import STEP_OUTPUT +from torch import Tensor +from torch.optim.lr_scheduler import LRScheduler +from torchvision.transforms.v2 import Compose, Normalize, Resize, Transform + +from anomalib import LearningType +from anomalib.models.components import AnomalyModule + +from .loss import UFlowLoss +from .torch_model import UflowModel + +logger = logging.getLogger(__name__) + +__all__ = ["Uflow"] + + +class Uflow(AnomalyModule): + """PL Lightning Module for the UFLOW algorithm.""" + + def __init__( + self, + backbone: str = "mcait", + flow_steps: int = 4, + affine_clamp: float = 2.0, + affine_subnet_channels_ratio: float = 1.0, + permute_soft: bool = False, + ) -> None: + """Uflow model. + + Args: + backbone (str): Backbone name. + flow_steps (int): Number of flow steps. + affine_clamp (float): Affine clamp. + affine_subnet_channels_ratio (float): Affine subnet channels ratio. + permute_soft (bool): Whether to use soft permutation. + """ + super().__init__() + + self.backbone = backbone + self.flow_steps = flow_steps + self.affine_clamp = affine_clamp + self.affine_subnet_channels_ratio = affine_subnet_channels_ratio + self.permute_soft = permute_soft + + self.loss = UFlowLoss() + + self.model: UflowModel + + def _setup(self) -> None: + if self.input_size is None: + msg = "Input size is required for UFlow model." + raise ValueError(msg) + + self.model = UflowModel( + input_size=self.input_size, + backbone=self.backbone, + flow_steps=self.flow_steps, + affine_clamp=self.affine_clamp, + affine_subnet_channels_ratio=self.affine_subnet_channels_ratio, + permute_soft=self.permute_soft, + ) + + def training_step(self, batch: dict[str, str | Tensor], *args, **kwargs) -> STEP_OUTPUT: # noqa: ARG002 | unused arguments + """Training step.""" + z, ljd = self.model(batch["image"]) + loss = self.loss(z, ljd) + self.log_dict({"loss": loss}, on_step=True, on_epoch=False, prog_bar=False, logger=True) + return {"loss": loss} + + def validation_step(self, batch: dict[str, str | Tensor], *args, **kwargs) -> STEP_OUTPUT: # noqa: ARG002 | unused arguments + """Validation step.""" + anomaly_maps = self.model(batch["image"]) + batch["anomaly_maps"] = anomaly_maps + return batch + + def configure_optimizers(self) -> tuple[list[LightningOptimizer], list[LRScheduler]]: + """Return optimizer and scheduler.""" + # Optimizer + # values used in paper: bottle: 0.0001128999, cable: 0.0016160391, capsule: 0.0012118892, carpet: 0.0012118892, + # grid: 0.0000362248, hazelnut: 0.0013268899, leather: 0.0006124724, metal_nut: 0.0008148858, + # pill: 0.0010756100, screw: 0.0004155987, tile: 0.0060457548, toothbrush: 0.0001287313, + # transistor: 0.0011212904, wood: 0.0002466546, zipper: 0.0000455247 + optimizer = torch.optim.Adam([{"params": self.parameters(), "initial_lr": 1e-3}], lr=1e-3, weight_decay=1e-5) + + # Scheduler for slowly reducing learning rate + scheduler = torch.optim.lr_scheduler.LinearLR( + optimizer, + start_factor=1.0, + end_factor=0.4, + total_iters=25000, + ) + return [optimizer], [scheduler] + + @property + def trainer_arguments(self) -> dict[str, Any]: + """Return EfficientAD trainer arguments.""" + return {"num_sanity_val_steps": 0} + + @property + def learning_type(self) -> LearningType: + """Return the learning type of the model. + + Returns: + LearningType: Learning type of the model. + """ + return LearningType.ONE_CLASS + + def configure_transforms(self, image_size: tuple[int, int] | None = None) -> Transform: + """Default transform for Padim.""" + if image_size is not None: + logger.warning("Image size is not used in UFlow. The input image size is determined by the model.") + return Compose( + [ + Resize((448, 448), antialias=True), + Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), + ], + ) diff --git a/anomalib/models/image/uflow/loss.py b/anomalib/models/image/uflow/loss.py new file mode 100644 index 0000000000000000000000000000000000000000..7afe8a1fc29a8caef9bef6e45a96cf9200bfa67b --- /dev/null +++ b/anomalib/models/image/uflow/loss.py @@ -0,0 +1,24 @@ +"""Loss function for the UFlow Model Implementation.""" + +# Copyright (C) 2023-2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +import torch +from torch import Tensor, nn + + +class UFlowLoss(nn.Module): + """UFlow Loss.""" + + def forward(self, hidden_variables: list[Tensor], jacobians: list[Tensor]) -> Tensor: + """Calculate the UFlow loss. + + Args: + hidden_variables (list[Tensor]): Hidden variables from the fastflow model. f: X -> Z + jacobians (list[Tensor]): Log of the jacobian determinants from the fastflow model. + + Returns: + Tensor: UFlow loss computed based on the hidden variables and the log of the Jacobians. + """ + lpz = torch.sum(torch.stack([0.5 * torch.sum(z_i**2, dim=(1, 2, 3)) for z_i in hidden_variables], dim=0)) + return torch.mean(lpz - jacobians) diff --git a/anomalib/models/image/uflow/torch_model.py b/anomalib/models/image/uflow/torch_model.py new file mode 100644 index 0000000000000000000000000000000000000000..dfbad59bec0436c089d972186e95103ea04c25f1 --- /dev/null +++ b/anomalib/models/image/uflow/torch_model.py @@ -0,0 +1,188 @@ +"""U-Flow torch model.""" + +# Copyright (C) 2023-2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +import torch +from FrEIA import framework as ff +from FrEIA import modules as fm +from torch import nn + +from anomalib.models.components.flow import AllInOneBlock + +from .anomaly_map import AnomalyMapGenerator +from .feature_extraction import get_feature_extractor + + +class AffineCouplingSubnet: + """Class for building the Affine Coupling subnet. + + It is passed as an argument to the `AllInOneBlock` module. + + Args: + kernel_size (int): Kernel size. + subnet_channels_ratio (float): Subnet channels ratio. + """ + + def __init__(self, kernel_size: int, subnet_channels_ratio: float) -> None: + self.kernel_size = kernel_size + self.subnet_channels_ratio = subnet_channels_ratio + + def __call__(self, in_channels: int, out_channels: int) -> nn.Sequential: + """Return AffineCouplingSubnet network. + + Args: + in_channels (int): Input channels. + out_channels (int): Output channels. + + Returns: + nn.Sequential: Affine Coupling subnet. + """ + mid_channels = int(in_channels * self.subnet_channels_ratio) + return nn.Sequential( + nn.Conv2d(in_channels, mid_channels, self.kernel_size, padding="same"), + nn.ReLU(), + nn.Conv2d(mid_channels, out_channels, self.kernel_size, padding="same"), + ) + + +class UflowModel(nn.Module): + """U-Flow model. + + Args: + input_size (tuple[int, int]): Input image size. + flow_steps (int): Number of flow steps. + backbone (str): Backbone name. + affine_clamp (float): Affine clamp. + affine_subnet_channels_ratio (float): Affine subnet channels ratio. + permute_soft (bool): Whether to use soft permutation. + """ + + def __init__( + self, + input_size: tuple[int, int] = (448, 448), + flow_steps: int = 4, + backbone: str = "mcait", + affine_clamp: float = 2.0, + affine_subnet_channels_ratio: float = 1.0, + permute_soft: bool = False, + ) -> None: + super().__init__() + + self.input_size = input_size + self.affine_clamp = affine_clamp + self.affine_subnet_channels_ratio = affine_subnet_channels_ratio + self.permute_soft = permute_soft + + self.feature_extractor = get_feature_extractor(backbone, input_size) + self.flow = self.build_flow(flow_steps) + self.anomaly_map_generator = AnomalyMapGenerator(input_size) + + def build_flow(self, flow_steps: int) -> ff.GraphINN: + """Build the flow model. + + First we start with the input nodes, which have to match the feature extractor output. + Then, we build the U-Shaped flow. Starting from the bottom (the coarsest scale), the flow is built as follows: + 1. Pass the input through a Flow Stage (`build_flow_stage`). + 2. Split the output of the flow stage into two parts, one that goes directly to the output, + 3. and the other is up-sampled, and will be concatenated with the output of the next flow stage (next scale) + 4. Repeat steps 1-3 for the next scale. + Finally, we build the Flow graph using the input nodes, the flow stages, and the output nodes. + + Args: + flow_steps (int): Number of flow steps. + + Returns: + ff.GraphINN: Flow model. + """ + input_nodes = [] + for channel, s_factor in zip( + self.feature_extractor.channels, + self.feature_extractor.scale_factors, + strict=True, + ): + input_nodes.append( + ff.InputNode( + channel, + self.input_size[0] // s_factor, + self.input_size[1] // s_factor, + name=f"cond_{channel}", + ), + ) + + nodes, output_nodes = [], [] + last_node = input_nodes[-1] + for i in reversed(range(1, len(input_nodes))): + flows = self.build_flow_stage(last_node, flow_steps) + volume_size = flows[-1].output_dims[0][0] + split = ff.Node( + flows[-1], + fm.Split, + {"section_sizes": (volume_size // 8 * 4, volume_size - volume_size // 8 * 4), "dim": 0}, + name=f"split_{i + 1}", + ) + output = ff.OutputNode(split.out1, name=f"output_scale_{i + 1}") + up = ff.Node(split.out0, fm.IRevNetUpsampling, {}, name=f"up_{i + 1}") + last_node = ff.Node([input_nodes[i - 1].out0, up.out0], fm.Concat, {"dim": 0}, name=f"cat_{i}") + + output_nodes.append(output) + nodes.extend([*flows, split, up, last_node]) + + flows = self.build_flow_stage(last_node, flow_steps) + output = ff.OutputNode(flows[-1], name="output_scale_1") + + output_nodes.append(output) + nodes.extend(flows) + + return ff.GraphINN(input_nodes + nodes + output_nodes[::-1]) + + def build_flow_stage(self, in_node: ff.Node, flow_steps: int, condition_node: ff.Node = None) -> list[ff.Node]: + """Build a flow stage, which is a sequence of flow steps. + + Each flow stage is essentially a sequence of `flow_steps` Glow blocks (`AllInOneBlock`). + + Args: + in_node (ff.Node): Input node. + flow_steps (int): Number of flow steps. + condition_node (ff.Node): Condition node. + + Returns: + List[ff.Node]: List of flow steps. + """ + flow_size = in_node.output_dims[0][-1] + nodes = [] + for step in range(flow_steps): + nodes.append( + ff.Node( + in_node, + AllInOneBlock, + module_args={ + "subnet_constructor": AffineCouplingSubnet( + 3 if step % 2 == 0 else 1, + self.affine_subnet_channels_ratio, + ), + "affine_clamping": self.affine_clamp, + "permute_soft": self.permute_soft, + }, + conditions=condition_node, + name=f"flow{flow_size}_step{step}", + ), + ) + in_node = nodes[-1] + return nodes + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """Return anomaly map.""" + features = self.feature_extractor(image) + z, ljd = self.encode(features) + + if self.training: + return z, ljd + return self.anomaly_map_generator(z) + + def encode(self, features: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + """Return""" + z, ljd = self.flow(features, rev=False) + if len(self.feature_extractor.scales) == 1: + z = [z] + return z, ljd diff --git a/anomalib/models/image/winclip/README.md b/anomalib/models/image/winclip/README.md new file mode 100644 index 0000000000000000000000000000000000000000..24a1fa8e4ea3dc1e54e5e4299a3b83d24a08316e --- /dev/null +++ b/anomalib/models/image/winclip/README.md @@ -0,0 +1,78 @@ +# WinCLIP: Zero-/Few-Shot Anomaly Classification and Segmentation + +This is the implementation of the [WinCLIP](https://arxiv.org/pdf/2303.14814.pdf) paper. + +Model Type: Segmentation + +## Description + +WinCLIP is a zero-shot/few-shot model for anomaly classification and segmentation. WinCLIP uses a pre-trained [CLIP](https://arxiv.org/pdf/2210.08901.pdf) model to extract image embeddings from the input images, and text embeddings from a set of pre-defined prompts describing the normal and anomalous states of the object class (e.g. "transistor without defect", "transistor with defect"). The image-level anomaly scores are obtained by computing the cosine similarity between the image embeddings and the normal and anomalous text embeddings. + +In addition, WinCLIP performs pixel-level anomaly localization by repeating the anomaly score computation for different local regions of the image. This is achieved by moving a mask over the image in a sliding window fashion. The size of the mask can be varied to include different scales in the localization predictions. The similarity scores of the masked image is assigned to all the pixels in the masked region, after which the scores are aggregated across scales and window locations using harmonic averaging. + +In few-shot mode, a reference association module is introduced, which collects and stores the (window-based) image embeddings of a selection of normal reference images. During inference, an additional association score is computed between as the cosine similarity between the embeddings of the input images and the normal reference images. The final anomaly score is the average of the zero-shot anomaly score and the few-shot association score. + +## Architecture + +![WinCLIP Architecture](/docs/source/images/winclip/architecture.png "WinCLIP Architecture") + +## Usage + +WinCLIP is a zero-shot model, which means that we can directly evaluate the model on a test set without training or fine-tuning on normal images. + +### 0-Shot + +`anomalib test --model WinClip --data MVTec` + +### 1-Shot + +`anomalib test --model WinClip --model.k_shot 1 --data MVTec` + +## Parameters + +| Parameter | Type | Description | Default | +| :--------- | :---- | :------------------------------------------------------------------------------------------------------------------------------------------------------------ | :------- | +| class_name | str | Class name used in the prompt ensemble. When left empty, the category name from the dataset will be used if available, otherwise it will default to `object`. | `null` | +| k_shot | int | Number of normal reference images used in few-shot mode. | `0` | +| scales | tuple | Scales to be included in the multiscale window-embeddings. Each scale is an integer which indicates the window size in number of patches. | `[2, 3]` | + +## Benchmark + +Coming soon... + + + + + + + + + + + +## Attribution + +The implementation of the torch model was inspired by https://github.com/zqhang/WinCLIP-pytorch and https://github.com/caoyunkang/WinClip. diff --git a/anomalib/models/image/winclip/__init__.py b/anomalib/models/image/winclip/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..8435a3c1aa2b9e108a8a8dd420d6f1e549b3362f --- /dev/null +++ b/anomalib/models/image/winclip/__init__.py @@ -0,0 +1,9 @@ +"""WinCLIP Model.""" + +# Copyright (C) 2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +from .lightning_model import WinClip +from .torch_model import WinClipModel + +__all__ = ["WinClip", "WinClipModel"] diff --git a/anomalib/models/image/winclip/lightning_model.py b/anomalib/models/image/winclip/lightning_model.py new file mode 100644 index 0000000000000000000000000000000000000000..0d86697fafc6aff700afeb0ec5ce634446144b41 --- /dev/null +++ b/anomalib/models/image/winclip/lightning_model.py @@ -0,0 +1,180 @@ +"""WinCLIP: Zero-/Few-Shot Anomaly Classification and Segmentation. + +Paper https://arxiv.org/abs/2303.14814 +""" + +# Copyright (C) 2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +import logging +from collections import OrderedDict +from pathlib import Path +from typing import Any + +import torch +from torch.utils.data import DataLoader +from torchvision.transforms.v2 import Compose, InterpolationMode, Normalize, Resize, Transform + +from anomalib import LearningType +from anomalib.data.predict import PredictDataset +from anomalib.models.components import AnomalyModule + +from .torch_model import WinClipModel + +logger = logging.getLogger(__name__) + +__all__ = ["WinClip"] + + +class WinClip(AnomalyModule): + """WinCLIP Lightning model. + + Args: + class_name (str, optional): The name of the object class used in the prompt ensemble. + Defaults to ``None``. + k_shot (int): The number of reference images for few-shot inference. + Defaults to ``0``. + scales (tuple[int], optional): The scales of the sliding windows used for multiscale anomaly detection. + Defaults to ``(2, 3)``. + few_shot_source (str | Path, optional): Path to a folder of reference images used for few-shot inference. + Defaults to ``None``. + """ + + EXCLUDE_FROM_STATE_DICT = frozenset({"model.clip"}) + + def __init__( + self, + class_name: str | None = None, + k_shot: int = 0, + scales: tuple = (2, 3), + few_shot_source: Path | str | None = None, + ) -> None: + super().__init__() + self.model = WinClipModel(scales=scales, apply_transform=False) + self.class_name = class_name + self.k_shot = k_shot + self.few_shot_source = Path(few_shot_source) if few_shot_source else None + + def _setup(self) -> None: + """Setup WinCLIP. + + - Set the class name used in the prompt ensemble. + - Collect text embeddings for zero-shot inference. + - Collect reference images for few-shot inference. + + We need to pass the device because this hook is called before the model is moved to the device. + """ + # get class name + self.class_name = self._get_class_name() + ref_images = None + + # get reference images + if self.k_shot: + if self.few_shot_source: + logger.info("Loading reference images from %s", self.few_shot_source) + reference_dataset = PredictDataset(self.few_shot_source, transform=self.model.transform) + dataloader = DataLoader(reference_dataset, batch_size=1, shuffle=False) + else: + logger.info("Collecting reference images from training dataset") + dataloader = self.trainer.datamodule.train_dataloader() + + ref_images = self.collect_reference_images(dataloader) + + # call setup to initialize the model + self.model.setup(self.class_name, ref_images) + + def _get_class_name(self) -> str: + """Set the class name used in the prompt ensemble. + + - When a class name is provided by the user, it is used. + - When the user did not provide a class name, the category name from the datamodule is used, if available. + - When the user did not provide a class name and the datamodule does not have a category name, the default + class name "object" is used. + """ + if self.class_name is not None: + logger.info("Using class name from init args: %s", self.class_name) + return self.class_name + if getattr(self, "_trainer", None) and hasattr(self.trainer.datamodule, "category"): + logger.info("No class name provided, using category from datamodule: %s", self.trainer.datamodule.category) + return self.trainer.datamodule.category + logger.info("No class name provided and no category name found in datamodule using default: object") + return "object" + + def collect_reference_images(self, dataloader: DataLoader) -> torch.Tensor: + """Collect reference images for few-shot inference. + + The reference images are collected by iterating the training dataset until the required number of images are + collected. + + Returns: + ref_images (Tensor): A tensor containing the reference images. + """ + ref_images = torch.Tensor() + for batch in dataloader: + images = batch["image"][: self.k_shot - ref_images.shape[0]] + ref_images = torch.cat((ref_images, images)) + if self.k_shot == ref_images.shape[0]: + break + return ref_images + + @staticmethod + def configure_optimizers() -> None: + """WinCLIP doesn't require optimization, therefore returns no optimizers.""" + return + + def validation_step(self, batch: dict[str, str | torch.Tensor], *args, **kwargs) -> dict: + """Validation Step of WinCLIP.""" + del args, kwargs # These variables are not used. + batch["pred_scores"], batch["anomaly_maps"] = self.model(batch["image"]) + return batch + + @property + def trainer_arguments(self) -> dict[str, int | float]: + """Set model-specific trainer arguments.""" + return {} + + @property + def learning_type(self) -> LearningType: + """The learning type of the model. + + WinCLIP is a zero-/few-shot model, depending on the user configuration. Therefore, the learning type is + set to ``LearningType.FEW_SHOT`` when ``k_shot`` is greater than zero and ``LearningType.ZERO_SHOT`` otherwise. + """ + return LearningType.FEW_SHOT if self.k_shot else LearningType.ZERO_SHOT + + def state_dict(self) -> OrderedDict[str, Any]: + """Return the state dict of the model. + + Before returning the state dict, we remove the parameters of the frozen backbone to reduce the size of the + checkpoint. + """ + state_dict = super().state_dict() + for pattern in self.EXCLUDE_FROM_STATE_DICT: + remove_keys = [key for key in state_dict if key.startswith(pattern)] + for key in remove_keys: + state_dict.pop(key) + return state_dict + + def load_state_dict(self, state_dict: OrderedDict[str, Any], strict: bool = True) -> Any: # noqa: ANN401 + """Load the state dict of the model. + + Before loading the state dict, we restore the parameters of the frozen backbone to ensure that the model + is loaded correctly. We also restore the auxiliary objects like threshold classes and normalization metrics. + """ + # restore the parameters of the excluded modules, if any + full_dict = super().state_dict() + for pattern in self.EXCLUDE_FROM_STATE_DICT: + restore_dict = {key: value for key, value in full_dict.items() if key.startswith(pattern)} + state_dict.update(restore_dict) + return super().load_state_dict(state_dict, strict) + + def configure_transforms(self, image_size: tuple[int, int] | None = None) -> Transform: + """Configure the default transforms used by the model.""" + if image_size is not None: + logger.warning("Image size is not used in WinCLIP. The input image size is determined by the model.") + return Compose( + [ + Resize((240, 240), antialias=True, interpolation=InterpolationMode.BICUBIC), + Normalize(mean=(0.48145466, 0.4578275, 0.40821073), std=(0.26862954, 0.26130258, 0.27577711)), + ], + ) diff --git a/anomalib/models/image/winclip/prompting.py b/anomalib/models/image/winclip/prompting.py new file mode 100644 index 0000000000000000000000000000000000000000..f33a63d1f42fe95b9459d2fb0a5b37e4b6e58800 --- /dev/null +++ b/anomalib/models/image/winclip/prompting.py @@ -0,0 +1,71 @@ +"""Compositional prompt ensemble for WinCLIP.""" + +# Copyright (C) 2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +NORMAL_STATES = [ + "{}", + "flawless {}", + "perfect {}", + "unblemished {}", + "{} without flaw", + "{} without defect", + "{} without damage", +] + +ANOMALOUS_STATES = [ + "damaged {}", + "{} with flaw", + "{} with defect", + "{} with damage", +] + +TEMPLATES = [ + "a cropped photo of the {}.", + "a close-up photo of a {}.", + "a close-up photo of the {}.", + "a bright photo of a {}.", + "a bright photo of the {}.", + "a dark photo of the {}.", + "a dark photo of a {}.", + "a jpeg corrupted photo of the {}.", + "a jpeg corrupted photo of the {}.", + "a blurry photo of the {}.", + "a blurry photo of a {}.", + "a photo of a {}.", + "a photo of the {}.", + "a photo of a small {}.", + "a photo of the small {}.", + "a photo of a large {}.", + "a photo of the large {}.", + "a photo of the {} for visual inspection.", + "a photo of a {} for visual inspection.", + "a photo of the {} for anomaly detection.", + "a photo of a {} for anomaly detection.", +] + + +def create_prompt_ensemble(class_name: str = "object") -> tuple[list[str], list[str]]: + """Create prompt ensemble for WinCLIP. + + All combinations of states and templates are generated for both normal and anomalous prompts. + + Args: + class_name (str): Name of the object. + + Returns: + tuple[list[str], list[str]]: Tuple containing the normal and anomalous prompts. + + Examples: + >>> normal_prompts, anomalous_prompts = create_prompt_ensemble("bottle") + >>> normal_prompts[:2] + ['a cropped photo of the bottle.', 'a close-up photo of a bottle.'] + >>> anomalous_prompts[:2] + ['a cropped photo of the damaged bottle.', 'a close-up photo of a damaged bottle.'] + """ + normal_states = [state.format(class_name) for state in NORMAL_STATES] + normal_ensemble = [template.format(state) for state in normal_states for template in TEMPLATES] + + anomalous_states = [state.format(class_name) for state in ANOMALOUS_STATES] + anomalous_ensemble = [template.format(state) for state in anomalous_states for template in TEMPLATES] + return normal_ensemble, anomalous_ensemble diff --git a/anomalib/models/image/winclip/torch_model.py b/anomalib/models/image/winclip/torch_model.py new file mode 100644 index 0000000000000000000000000000000000000000..e8dbc10b515693abda9ba74ef7a9d7bd87994bde --- /dev/null +++ b/anomalib/models/image/winclip/torch_model.py @@ -0,0 +1,398 @@ +"""PyTorch model for the WinCLIP implementation.""" + +# Copyright (C) 2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +from collections.abc import Callable +from copy import copy + +import open_clip +import torch +from open_clip.tokenizer import tokenize +from torch import nn +from torch.nn.modules.linear import Identity +from torchvision.transforms import Compose, ToPILImage + +from anomalib.models.components import BufferListMixin, DynamicBufferMixin + +from .prompting import create_prompt_ensemble +from .utils import class_scores, harmonic_aggregation, make_masks, visual_association_score + +BACKBONE = "ViT-B-16-plus-240" +PRETRAINED = "laion400m_e31" +TEMPERATURE = 0.07 # temperature hyperparameter from the clip paper + + +class WinClipModel(DynamicBufferMixin, BufferListMixin, nn.Module): + """PyTorch module that implements the WinClip model for image anomaly detection. + + Args: + class_name (str, optional): The name of the object class used in the prompt ensemble. + Defaults to ``None``. + reference_images (torch.Tensor, optional): Tensor of shape ``(K, C, H, W)`` containing the reference images. + Defaults to ``None``. + scales (tuple[int], optional): The scales of the sliding windows used for multi-scale anomaly detection. + Defaults to ``(2, 3)``. + apply_transform (bool, optional): Whether to apply the default CLIP transform to the input images. + Defaults to ``False``. + + Attributes: + clip (CLIP): The CLIP model used for image and text encoding. + grid_size (tuple[int]): The size of the feature map grid. + k_shot (int): The number of reference images used for few-shot anomaly detection. + scales (tuple[int]): The scales of the sliding windows used for multi-scale anomaly detection. + masks (list[torch.Tensor] | None): The masks representing the sliding window locations. + _text_embeddings (torch.Tensor | None): The text embeddings for the compositional prompt ensemble. + _visual_embeddings (list[torch.Tensor] | None): The multi-scale embeddings for the reference images. + _patch_embeddings (torch.Tensor | None): The patch embeddings for the reference images. + """ + + def __init__( + self, + class_name: str | None = None, + reference_images: torch.Tensor | None = None, + scales: tuple = (2, 3), + apply_transform: bool = False, + ) -> None: + super().__init__() + self.backbone = BACKBONE + self.pretrained = PRETRAINED + self.temperature = TEMPERATURE + self.class_name = class_name + self.reference_images = reference_images + self.scales = scales + self.apply_transform = apply_transform + self.k_shot = 0 + + # initialize CLIP model + self.clip, _, self._transform = open_clip.create_model_and_transforms(self.backbone, pretrained=self.pretrained) + self.clip.visual.output_tokens = True + self.grid_size = self.clip.visual.grid_size + + # register buffers + self.register_buffer_list("masks", self._generate_masks(), persistent=False) # no need to save masks + self.register_buffer("_text_embeddings", torch.empty(0)) + self.register_buffer_list("_visual_embeddings", [torch.empty(0) for _ in self.scales]) + self.register_buffer("_patch_embeddings", torch.empty(0)) + + # setup + self.setup(class_name, reference_images) + + def setup(self, class_name: str | None = None, reference_images: torch.Tensor | None = None) -> None: + """Setup WinCLIP. + + WinCLIP's setup stage consists of collecting the text and visual embeddings used during inference. The + following steps are performed, depending on the arguments passed to the model: + - Collect text embeddings for zero-shot inference. + - Collect reference images for few-shot inference. + The k_shot attribute is updated based on the number of reference images. + + The setup method is called internally by the constructor. However, it can also be called manually to update the + text and visual embeddings after the model has been initialized. + + Args: + class_name (str): The name of the object class used in the prompt ensemble. + reference_images (torch.Tensor): Tensor of shape ``(batch_size, C, H, W)`` containing the reference images. + + Examples: + >>> model = WinClipModel() + >>> model.setup("transistor") + >>> model.text_embeddings.shape + torch.Size([2, 640]) + + >>> ref_images = torch.rand(2, 3, 240, 240) + >>> model = WinClipModel() + >>> model.setup("transistor", ref_images) + >>> model.k_shot + 2 + >>> model.visual_embeddings[0].shape + torch.Size([2, 196, 640]) + + >>> model = WinClipModel("transistor") + >>> model.k_shot + 0 + >>> model.setup(reference_images=ref_images) + >>> model.k_shot + 2 + + >>> model = WinClipModel(class_name="transistor", reference_images=ref_images) + >>> model.text_embeddings.shape + torch.Size([2, 640]) + >>> model.visual_embeddings[0].shape + torch.Size([2, 196, 640]) + """ + # update class name and text embeddings + self.class_name = class_name or self.class_name + if self.class_name is not None: + self._collect_text_embeddings(self.class_name) + # update reference images, k_shot and visual embeddings + self.reference_images = reference_images if reference_images is not None else self.reference_images + if self.reference_images is not None: + self.k_shot = self.reference_images.shape[0] # update k_shot based on number of reference images + self._collect_visual_embeddings(self.reference_images) + + def encode_image(self, batch: torch.Tensor) -> tuple[torch.Tensor, list[torch.Tensor], torch.Tensor]: + """Encode the batch of images to obtain image embeddings, window embeddings, and patch embeddings. + + The image embeddings and patch embeddings are obtained by passing the batch of images through the model. The + window embeddings are obtained by masking the feature map and passing it through the transformer. A forward hook + is used to retrieve the intermediate feature map and share computation between the image and window embeddings. + + Args: + batch (torch.Tensor): Batch of input images of shape ``(N, C, H, W)``. + + Returns: + Tuple[torch.Tensor, List[torch.Tensor], torch.Tensor]: A tuple containing the image embeddings, + window embeddings, and patch embeddings respectively. + + Examples: + >>> model = WinClipModel() + >>> model.prepare_masks() + >>> batch = torch.rand((1, 3, 240, 240)) + >>> image_embeddings, window_embeddings, patch_embeddings = model.encode_image(batch) + >>> image_embeddings.shape + torch.Size([1, 640]) + >>> [embedding.shape for embedding in window_embeddings] + [torch.Size([1, 196, 640]), torch.Size([1, 169, 640])] + >>> patch_embeddings.shape + torch.Size([1, 225, 896]) + """ + # apply transform if needed + if self.apply_transform: + batch = torch.stack([self.transform(image) for image in batch]) + + # register hook to retrieve intermediate feature map + outputs = {} + + def get_feature_map(name: str) -> Callable: + def hook(_model: Identity, inputs: tuple[torch.Tensor,], _outputs: torch.Tensor) -> None: + del _model, _outputs + outputs[name] = inputs[0].detach() + + return hook + + # register hook to get the intermediate tokens of the transformer + self.clip.visual.patch_dropout.register_forward_hook(get_feature_map("feature_map")) + + # get image and patch embeddings + image_embeddings, patch_embeddings = self.clip.encode_image(batch) + + # get window embeddings + feature_map = outputs["feature_map"] + window_embeddings = [self._get_window_embeddings(feature_map, masks) for masks in self.masks] + + return ( + image_embeddings, + window_embeddings, + patch_embeddings, + ) + + def _get_window_embeddings(self, feature_map: torch.Tensor, masks: torch.Tensor) -> torch.Tensor: + """Computes the embeddings for each window in the feature map using the given masks. + + Args: + feature_map (torch.Tensor): The input feature map of shape ``(n_batches, n_patches, dimensionality)``. + masks (torch.Tensor): Masks of shape ``(kernel_size, n_masks)`` representing the sliding window locations. + + Returns: + torch.Tensor: The embeddings for each sliding window location. + """ + batch_size = feature_map.shape[0] + n_masks = masks.shape[1] + + # prepend zero index for class embeddings + class_index = torch.zeros(1, n_masks, dtype=int).to(feature_map.device) + masks = torch.cat((class_index, masks + 1)).T # +1 to account for class index + # apply masks to feature map + masked = torch.cat([torch.index_select(feature_map, 1, mask) for mask in masks]) + + # finish forward pass on masked features + masked = self.clip.visual.patch_dropout(masked) + masked = self.clip.visual.ln_pre(masked) + + masked = masked.permute(1, 0, 2) # NLD -> LND + masked = self.clip.visual.transformer(masked) + masked = masked.permute(1, 0, 2) # LND -> NLD + + masked = self.clip.visual.ln_post(masked) + pooled, _ = self.clip.visual._global_pool(masked) # noqa: SLF001 + + if self.clip.visual.proj is not None: + pooled = pooled @ self.clip.visual.proj + + return pooled.reshape((n_masks, batch_size, -1)).permute(1, 0, 2) + + def forward(self, batch: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + """Forward-pass through the model to obtain image and pixel scores. + + Args: + batch (torch.Tensor): Batch of input images of shape ``(batch_size, C, H, W)``. + + Returns: + Tuple[torch.Tensor, torch.Tensor]: Tuple containing the image scores and pixel scores. + """ + image_embeddings, window_embeddings, patch_embeddings = self.encode_image(batch) + + # get zero-shot scores + image_scores = class_scores(image_embeddings, self.text_embeddings, self.temperature, target_class=1) + multi_scale_scores = self._compute_zero_shot_scores(image_scores, window_embeddings) + + # get few-shot scores + if self.k_shot: + few_shot_scores = self._compute_few_shot_scores(patch_embeddings, window_embeddings) + multi_scale_scores = (multi_scale_scores + few_shot_scores) / 2 + image_scores = (image_scores + few_shot_scores.amax(dim=(-2, -1))) / 2 + + # reshape to image dimensions + pixel_scores = nn.functional.interpolate( + multi_scale_scores.unsqueeze(1), + size=batch.shape[-2:], + mode="bilinear", + ) + return image_scores, pixel_scores.squeeze(1) + + def _compute_zero_shot_scores( + self, + image_scores: torch.Tensor, + window_embeddings: list[torch.Tensor], + ) -> torch.Tensor: + """Compute the multi-scale anomaly score maps based on the text embeddings. + + Each window embedding is compared to the text embeddings to obtain a similarity score for each window. Harmonic + averaging is then used to aggregate the scores for each window into a single score map for each scale. Finally, + the score maps are combined into a single multi-scale score map by aggregating across scales. + + Args: + image_scores (torch.Tensor): Tensor of shape ``(batch_size)`` representing the full image scores. + window_embeddings (list[torch.Tensor]): List of tensors of shape ``(batch_size, n_windows, n_features)`` + representing the embeddings for each sliding window location. + + Returns: + torch.Tensor: Tensor of shape ``(batch_size, H, W)`` representing the 0-shot scores for each patch location. + """ + # image scores are added to represent the full image scale + multi_scale_scores = [image_scores.view(-1, 1, 1).repeat(1, self.grid_size[0], self.grid_size[1])] + # add aggregated scores for each scale + for window_embedding, mask in zip(window_embeddings, self.masks, strict=True): + scores = class_scores(window_embedding, self.text_embeddings, self.temperature, target_class=1) + multi_scale_scores.append(harmonic_aggregation(scores, self.grid_size, mask)) + # aggregate scores across scales + return (len(self.scales) + 1) / (1 / torch.stack(multi_scale_scores)).sum(dim=0) + + def _compute_few_shot_scores( + self, + patch_embeddings: torch.Tensor, + window_embeddings: list[torch.Tensor], + ) -> torch.Tensor: + """Compute the multi-scale anomaly score maps based on the reference image embeddings. + + Visual association scores are computed between the extracted embeddings and the reference image embeddings for + each scale. The window-level scores are additionally aggregated into a single score map for each scale using + harmonic averaging. The final score maps are obtained by averaging across scales. + + Args: + patch_embeddings (torch.Tensor): Full-scale patch embeddings of shape + ``(batch_size, n_patches, n_features)``. + window_embeddings (list[torch.Tensor]): List of tensors of shape ``(batch_size, n_windows, n_features)`` + representing the embeddings for each sliding window location. + + Returns: + torch.Tensor: Tensor of shape ``(batch_size, H, W)`` representing the few-shot scores for each patch + location. + """ + multi_scale_scores = [ + visual_association_score(patch_embeddings, self.patch_embeddings).reshape((-1, *self.grid_size)), + ] + for window_embedding, reference_embedding, mask in zip( + window_embeddings, + self.visual_embeddings, + self.masks, + strict=True, + ): + scores = visual_association_score(window_embedding, reference_embedding) + multi_scale_scores.append(harmonic_aggregation(scores, self.grid_size, mask)) + + return torch.stack(multi_scale_scores).mean(dim=0) + + def _collect_text_embeddings(self, class_name: str) -> None: + """Collect text embeddings for the object class using a compositional prompt ensemble. + + First, an ensemble of normal and anomalous prompts is created based on the name of the object class. The + prompt ensembles are then tokenized and encoded to obtain prompt embeddings. The prompt embeddings are + averaged to obtain a single text embedding for the object class. These final text embeddings are stored in + the model to be used during inference. + + Args: + class_name (str): The name of the object class used in the prompt ensemble. + """ + # collect prompt ensemble + normal_prompts, anomalous_prompts = create_prompt_ensemble(class_name) + # tokenize prompts + normal_tokens = tokenize(normal_prompts) + anomalous_tokens = tokenize(anomalous_prompts) + # encode tokens to obtain prompt embeddings + with torch.no_grad(): + normal_embeddings = self.clip.encode_text(normal_tokens) + anomalous_embeddings = self.clip.encode_text(anomalous_tokens) + # average prompt embeddings + normal_embeddings = torch.mean(normal_embeddings, dim=0, keepdim=True) + anomalous_embeddings = torch.mean(anomalous_embeddings, dim=0, keepdim=True) + # concatenate and store + text_embeddings = torch.cat((normal_embeddings, anomalous_embeddings)) + self._text_embeddings = text_embeddings + + def _collect_visual_embeddings(self, images: torch.Tensor) -> None: + """Collect visual embeddings based on a set of normal reference images. + + Args: + images (torch.Tensor): Tensor of shape ``(K, C, H, W)`` containing the reference images. + """ + with torch.no_grad(): + _, self._visual_embeddings, self._patch_embeddings = self.encode_image(images) + + def _generate_masks(self) -> list[torch.Tensor]: + """Prepare a set of masks that operate as multi-scale sliding windows. + + For each of the scales, a set of masks is created that select patches from the feature map. Each mask represents + a sliding window location in the pixel domain. The masks are stored in the model to be used during inference. + + Returns: + list[torch.Tensor]: A list of tensors of shape ``(n_patches_per_mask, n_masks)`` representing the sliding + window locations for each scale. + """ + return [make_masks(self.grid_size, scale, 1) for scale in self.scales] + + @property + def transform(self) -> Compose: + """The transform used by the model. + + To obtain the transforms, we retrieve the transforms from the clip backbone. Since the original transforms are + intended for PIL images, we prepend a ToPILImage transform to the list of transforms. + """ + transforms = copy(self._transform.transforms) + transforms.insert(0, ToPILImage()) + return Compose(transforms) + + @property + def text_embeddings(self) -> torch.Tensor: + """The text embeddings used by the model.""" + if self._text_embeddings.numel() == 0: + msg = "Text embeddings have not been collected. Pass a class name to the model using ``setup``." + raise RuntimeError(msg) + return self._text_embeddings + + @property + def visual_embeddings(self) -> list[torch.Tensor]: + """The visual embeddings used by the model.""" + if self._visual_embeddings[0].numel() == 0: + msg = "Visual embeddings have not been collected. Pass some reference images to the model using ``setup``." + raise RuntimeError(msg) + return self._visual_embeddings + + @property + def patch_embeddings(self) -> torch.Tensor: + """The patch embeddings used by the model.""" + if self._patch_embeddings.numel() == 0: + msg = "Patch embeddings have not been collected. Pass some reference images to the model using ``setup``." + raise RuntimeError(msg) + return self._patch_embeddings diff --git a/anomalib/models/image/winclip/utils.py b/anomalib/models/image/winclip/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..1612be077ccd366a8e3c80ef1d724a94e79602e7 --- /dev/null +++ b/anomalib/models/image/winclip/utils.py @@ -0,0 +1,211 @@ +"""WinCLIP utils.""" + +# Copyright (C) 2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +import torch +from torch import nn + + +def cosine_similarity(input1: torch.Tensor, input2: torch.Tensor) -> torch.Tensor: + """Compute pairwise cosine similarity matrix between two tensors. + + Computes the cosine similarity between all pairs of vectors in the two tensors. + + Args: + input1 (torch.Tensor): Input tensor of shape ``(N, D)`` or ``(B, N, D)``. + input2 (torch.Tensor): Input tensor of shape ``(M, D)`` or ``(B, M, D)``. + + Returns: + torch.Tensor: Cosine similarity matrix of shape ``(N, M)`` or ``(B, N, M)``. + + Examples: + >>> input1 = torch.tensor([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]]) + >>> input2 = torch.tensor([[0.0, 1.0, 0.0], [1.0, 1.0, 0.0]]) + >>> cosine_similarity(input1, input2) + tensor([[[0.0000, 0.7071], + [1.0000, 0.7071]]]) + + >>> input1 = torch.randn(100, 128) + >>> input2 = torch.randn(200, 128) + >>> cosine_similarity(input1, input2).shape + torch.Size([100, 200]) + + >>> input1 = torch.randn(10, 100, 128) + >>> input2 = torch.randn(10, 200, 128) + >>> cosine_similarity(input1, input2).shape + torch.Size([10, 100, 200]) + """ + ndim = input1.ndim + input1 = input1.unsqueeze(0) if input1.ndim == 2 else input1 + input2 = input2.repeat(input1.shape[0], 1, 1) if input2.ndim == 2 else input2 + + input1_norm = nn.functional.normalize(input1, p=2, dim=-1) + input2_norm = nn.functional.normalize(input2, p=2, dim=-1) + similarity = torch.bmm(input1_norm, input2_norm.transpose(-2, -1)) + if ndim == 2: + return similarity.squeeze(0) + return similarity + + +def class_scores( + image_embeddings: torch.Tensor, + text_embeddings: torch.Tensor, + temperature: float = 1.0, + target_class: int | None = None, +) -> torch.Tensor: + """Compute class scores between a set of N image embeddings and a set of M text embeddings. + + Each text embedding represents the embedding of a prompt for a specific class. By computing the cosine similarity + between each image embedding and each text embedding, we obtain a similarity matrix of shape (N, M). This matrix is + then used to compute the confidence scores for each class by scaling by a temperature parameter and applying the + softmax function (Equation (1) in the WinCLIP paper). + + Args: + image_embeddings (torch.Tensor): Image embedding matrix of shape ``(N, D)`` or ``(B, N, D)``. + text_embeddings (torch.Tensor): Text embedding matrix of shape ``(M, D)`` or ``(B, M, D)``. + temperature (float): Temperature hyperparameter. + target_class (int): Index of the target class. If None, the scores for all classes are returned. + + Returns: + torch.Tensor: Similarity score of shape ``(N, M)`` or ``(B, N, M)``. + + Examples: + >>> image_embeddings = torch.tensor([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]]) + >>> text_embeddings = torch.tensor([[0.0, 1.0, 0.0], [1.0, 1.0, 0.0]]) + >>> class_scores(image_embeddings, text_embeddings) + tensor([[0.3302, 0.6698], + [0.5727, 0.4273]]) + + >>> image_embeddings = torch.randn(100, 128) + >>> text_embeddings = torch.randn(200, 128) + >>> class_scores(image_embeddings, text_embeddings).shape + torch.Size([100, 200]) + + >>> image_embeddings = torch.randn(10, 100, 128) + >>> text_embeddings = torch.randn(10, 200, 128) + >>> class_scores(image_embeddings, text_embeddings).shape + torch.Size([10, 100, 200]) + + >>> image_embeddings = torch.randn(10, 100, 128) + >>> text_embeddings = torch.randn(10, 200, 128) + >>> class_scores(image_embeddings, text_embeddings, target_class=0).shape + torch.Size([10, 100]) + """ + scores = (cosine_similarity(image_embeddings, text_embeddings) / temperature).softmax(dim=-1) + if target_class is not None: + return scores[..., target_class] + return scores + + +def harmonic_aggregation(window_scores: torch.Tensor, output_size: tuple, masks: torch.Tensor) -> torch.Tensor: + """Perform harmonic aggregation on window scores. + + Computes a single score for each patch location by aggregating the scores of all windows that cover the patch. + Scores are aggregated using the harmonic mean. + + Args: + window_scores (torch.Tensor): Tensor of shape ``(batch_size, n_masks)`` representing the scores for each sliding + window location. + output_size (tuple): Tuple of integers representing the output size ``(H, W)``. + masks (torch.Tensor): Tensor of shape ``(n_patches_per_mask, n_masks)`` representing the masks. Each mask is + set of indices indicating which patches are covered by the mask. + + Returns: + torch.Tensor: Tensor of shape ``(batch_size, H, W)```` representing the aggregated scores. + + Examples: + >>> # example for a 3x3 patch grid with 4 sliding windows of size 2x2 + >>> window_scores = torch.tensor([[1.0, 0.75, 0.5, 0.25]]) + >>> output_size = (3, 3) + >>> masks = torch.Tensor([[0, 1, 3, 4], + [1, 2, 4, 5], + [3, 4, 6, 7], + [4, 5, 7, 8]]) + >>> harmonic_aggregation(window_scores, output_size, masks) + tensor([[[1.0000, 0.8571, 0.7500], + [0.6667, 0.4800, 0.3750], + [0.5000, 0.3333, 0.2500]]]) + """ + batch_size = window_scores.shape[0] + height, width = output_size + + scores = [] + for idx in range(height * width): + patch_mask = torch.any(masks == idx, dim=0) # boolean tensor indicating which masks contain the patch + scores.append(sum(patch_mask) / (1 / window_scores.T[patch_mask]).sum(dim=0)) + + return torch.stack(scores).T.reshape(batch_size, height, width).nan_to_num(posinf=0.0) + + +def visual_association_score(embeddings: torch.Tensor, reference_embeddings: torch.Tensor) -> torch.Tensor: + """Compute visual association scores between a set of embeddings and a set of reference embeddings. + + Returns a visual association score for each patch location in the inputs. The visual association score is the + minimum cosine distance between each embedding and the reference embeddings. Equation (4) in the paper. + + Args: + embeddings (torch.Tensor): Tensor of shape ``(batch_size, n_patches, dimensionality)`` representing the + embeddings. + reference_embeddings (torch.Tensor): Tensor of shape ``(n_reference_embeddings, n_patches, dimensionality)`` + representing the reference embeddings. + + Returns: + torch.Tensor: Tensor of shape ``(batch_size, n_patches)`` representing the visual association scores. + + Examples: + >>> embeddings = torch.tensor([[[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]]]) + >>> reference_embeddings = torch.tensor([[[0.0, 1.0, 0.0], [1.0, 1.0, 0.0]]]) + >>> visual_association_score(embeddings, reference_embeddings) + tensor([[0.1464, 0.0000]]) + + >>> embeddings = torch.randn(10, 100, 128) + >>> reference_embeddings = torch.randn(2, 100, 128) + >>> visual_association_score(embeddings, reference_embeddings).shape + torch.Size([10, 100]) + """ + reference_embeddings = reference_embeddings.reshape(-1, embeddings.shape[-1]) + scores = cosine_similarity(embeddings, reference_embeddings) + return (1 - scores).min(dim=-1)[0] / 2 + + +def make_masks(grid_size: tuple[int, int], kernel_size: int, stride: int = 1) -> torch.Tensor: + """Make a set of masks to select patches from a feature map in a sliding window fashion. + + Each column in the returned tensor represents a mask. Each mask is a set of indices indicating which patches are + covered by the mask. The number of masks is equal to the number of sliding windows that fit in the feature map. + + Args: + grid_size (tuple[int, int]): The shape of the feature map. + kernel_size (int): The size of the kernel in number of patches. + stride (int): The size of the stride in number of patches. + + Returns: + torch.Tensor: Set of masks of shape ``(n_patches_per_mask, n_masks)``. + + Examples: + >>> make_masks((3, 3), 2) + tensor([[0, 1, 3, 4], + [1, 2, 4, 5], + [3, 4, 6, 7], + [4, 5, 7, 8]], dtype=torch.int32) + + >>> make_masks((4, 4), 2) + tensor([[ 0, 1, 2, 4, 5, 6, 8, 9, 10], + [ 1, 2, 3, 5, 6, 7, 9, 10, 11], + [ 4, 5, 6, 8, 9, 10, 12, 13, 14], + [ 5, 6, 7, 9, 10, 11, 13, 14, 15]], dtype=torch.int32) + + >>> make_masks((4, 4), 2, stride=2) + tensor([[ 0, 2, 8, 10], + [ 1, 3, 9, 11], + [ 4, 6, 12, 14], + [ 5, 7, 13, 15]], dtype=torch.int32) + """ + if any(dim < kernel_size for dim in grid_size): + msg = "Each dimension of the grid size must be greater than or equal to the kernel size. Got grid size {} and \ + kernel size {}.".format(grid_size, kernel_size) + raise ValueError(msg) + height, width = grid_size + grid = torch.arange(height * width).reshape(1, height, width) + return nn.functional.unfold(grid.float(), kernel_size=kernel_size, stride=stride).int() diff --git a/anomalib/models/video/README.md b/anomalib/models/video/README.md new file mode 100644 index 0000000000000000000000000000000000000000..264a8de52f0ad674a69851bec222189cf1c1bcd4 --- /dev/null +++ b/anomalib/models/video/README.md @@ -0,0 +1,40 @@ +# Anomalib Video Models + +## 📝 Description + +This sub-package contains the models for handling video datasets in anomalib. + +The anomalib.models.video subpackage provides: + +- Classes and functions to define video anomaly models. +- Models for video-based anomaly classification, detection or segmentation. + +## ⚠️ Note + +The models defined here are designed specifically to handle video datasets +These models contain spatio-temporal layers that are not present in the image +models. + +## 💡 Examples + +The following example shows how to use the AiVad model to train on the Avenue dataset. + +
+Training the AiVad model on Avenue video dataset + +```python +# Import the necessary modules +from anomalib.data import Avenue +from anomalib.models import AiVad +from anomalib.engine import Engine + +# Load the avenue dataset, model and engine. +datamodule = Avenue() +model = AiVad() +engine = Engine() + +# Train the model +engine.train(model, datamodule) +``` + +
diff --git a/anomalib/models/video/__init__.py b/anomalib/models/video/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..ae952f0e3082f574257e007e416ffba61f9f8b6f --- /dev/null +++ b/anomalib/models/video/__init__.py @@ -0,0 +1,8 @@ +"""Anomalib Video Models.""" + +# Copyright (C) 2023-2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +from .ai_vad import AiVad + +__all__ = ["AiVad"] diff --git a/anomalib/models/video/ai_vad/__init__.py b/anomalib/models/video/ai_vad/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..740636009b17241e490ddf1376ca965d2817b431 --- /dev/null +++ b/anomalib/models/video/ai_vad/__init__.py @@ -0,0 +1,13 @@ +"""Implementatation of the AI-VAD Model. + +AI-VAD: Accurate and Interpretable Video Anomaly Detection + +Paper https://arxiv.org/pdf/2212.00789.pdf +""" + +# Copyright (C) 2023-2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +from .lightning_model import AiVad + +__all__ = ["AiVad"] diff --git a/anomalib/models/video/ai_vad/clip/LICENSE b/anomalib/models/video/ai_vad/clip/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..c123b69334717d178daa674c2d08e3383fe36134 --- /dev/null +++ b/anomalib/models/video/ai_vad/clip/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2021 OpenAI + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/anomalib/models/video/ai_vad/clip/__init__.py b/anomalib/models/video/ai_vad/clip/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..d32c9a1b3b3119882e96196458b4e0563f241dc4 --- /dev/null +++ b/anomalib/models/video/ai_vad/clip/__init__.py @@ -0,0 +1,4 @@ +"""CLIP Implementation.""" + +# Copyright (C) 2023-2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 diff --git a/anomalib/models/video/ai_vad/clip/clip.py b/anomalib/models/video/ai_vad/clip/clip.py new file mode 100644 index 0000000000000000000000000000000000000000..b094e990218054e07d14102dfe10a43ad9c19ee7 --- /dev/null +++ b/anomalib/models/video/ai_vad/clip/clip.py @@ -0,0 +1,226 @@ +# mypy: ignore-errors +# ruff: noqa + +# Original Code +# https://github.com/openai/CLIP. +# SPDX-License-Identifier: MIT +# +# Modified +# Copyright (C) 2023-2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +import hashlib +import logging +import os +from typing import List, Union +from urllib.parse import urlparse + +import requests +import torch +from PIL import Image +from pkg_resources import packaging +from torchvision.transforms import CenterCrop, Compose, Normalize, Resize, ToTensor +from tqdm import tqdm + +logger = logging.getLogger(__name__) +from .model import build_model + +try: + from torchvision.transforms import InterpolationMode + + BICUBIC = InterpolationMode.BICUBIC +except ImportError: + BICUBIC = Image.BICUBIC + + +if packaging.version.parse(torch.__version__) < packaging.version.parse("1.7.1"): + msg = "PyTorch version 1.7.1 or higher is recommended" + logger.warn(msg) + +__all__ = ["available_models", "load"] + +_MODELS = { + "RN50": "https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt", + "RN101": "https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt", + "RN50x4": "https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt", + "RN50x16": "https://openaipublic.azureedge.net/clip/models/52378b407f34354e150460fe41077663dd5b39c54cd0bfd2b27167a4a06ec9aa/RN50x16.pt", + "RN50x64": "https://openaipublic.azureedge.net/clip/models/be1cfb55d75a9666199fb2206c106743da0f6468c9d327f3e0d0a543a9919d9c/RN50x64.pt", + "ViT-B/32": "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt", + "ViT-B/16": "https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt", + "ViT-L/14": "https://openaipublic.azureedge.net/clip/models/b8cca3fd41ae0c99ba7e8951adf17d267cdb84cd88be6f7c2e0eca1737a03836/ViT-L-14.pt", + "ViT-L/14@336px": "https://openaipublic.azureedge.net/clip/models/3035c92b350959924f9f00213499208652fc7ea050643e8b385c2dac08641f02/ViT-L-14-336px.pt", +} + + +def _verify_checksum(file_path: str, url: str) -> bool: + expected_sha256 = url.split("/")[-2] + sha256_hash = hashlib.sha256() + + with open(file_path, "rb") as file: + for chunk in iter(lambda: file.read(4096), b""): + sha256_hash.update(chunk) + + file_hash = sha256_hash.hexdigest() + + return file_hash == expected_sha256 + + +def _download(url: str, root: str): + os.makedirs(root, exist_ok=True) + filename = os.path.basename(urlparse(url).path) + download_target = os.path.join(root, filename) + + if os.path.exists(download_target): + if not os.path.isfile(download_target): + raise FileExistsError(f"{download_target} exists and is not a regular file") + if _verify_checksum(download_target, url): + return download_target + + logger.warning("%s exists, but the checksum does not match; re-downloading the file", download_target) + os.remove(download_target) + + response = requests.get(url, stream=True, timeout=10.0) # Timeout is for bandit security linter + response.raise_for_status() + + total_size = int(response.headers.get("Content-Length", 0)) + + with open(download_target, "wb") as file, tqdm( + total=total_size, ncols=80, unit="iB", unit_scale=True, unit_divisor=1024 + ) as loop: + for chunk in response.iter_content(chunk_size=8192): + if chunk: + file.write(chunk) + loop.update(len(chunk)) + + if not _verify_checksum(download_target, url): + raise RuntimeError("Model has been downloaded but the checksum does not match") + + return download_target + + +def _convert_image_to_rgb(image): + return image.convert("RGB") + + +def _transform(n_px): + return Compose( + [ + Resize(n_px, interpolation=BICUBIC), + CenterCrop(n_px), + _convert_image_to_rgb, + ToTensor(), + Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)), + ] + ) + + +def available_models() -> List[str]: + """Returns the names of available CLIP models""" + return list(_MODELS.keys()) + + +def load( + name: str, + device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu", + jit: bool = False, + download_root: str = None, +): + """Load a CLIP model + + Args: + name : str + A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict + + device : Union[str, torch.device] + The device to put the loaded model + + jit : bool + Whether to load the optimized JIT model or more hackable non-JIT model (default). + + download_root: str + path to download the model files; by default, it uses "~/.cache/clip" + + Returns: + model : torch.nn.Module + The CLIP model + + preprocess : Callable[[PIL.Image], torch.Tensor] + A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input + """ + if name in _MODELS: + model_path = _download(_MODELS[name], download_root or os.path.expanduser("~/.cache/clip")) + elif os.path.isfile(name): + model_path = name + else: + raise RuntimeError(f"Model {name} not found; available models = {available_models()}") + + with open(model_path, "rb") as opened_file: + try: + # loading JIT archive + model = torch.jit.load(opened_file, map_location=device if jit else "cpu").eval() + state_dict = None + except RuntimeError: + # loading saved state dict + if jit: + msg = f"File {model_path} is not a JIT archive. Loading as a state dict instead" + logger.warn(msg) + jit = False + state_dict = torch.load(opened_file, map_location="cpu") + + if not jit: + model = build_model(state_dict or model.state_dict()).to(device) + if str(device) == "cpu": + model.float() + return model, _transform(model.visual.input_resolution) + + # patch the device names + device_holder = torch.jit.trace(lambda: torch.ones([]).to(torch.device(device)), example_inputs=[]) + device_node = [n for n in device_holder.graph.findAllNodes("prim::Constant") if "Device" in repr(n)][-1] + + def patch_device(module): + try: + graphs = [module.graph] if hasattr(module, "graph") else [] + except RuntimeError: + graphs = [] + + if hasattr(module, "forward1"): + graphs.append(module.forward1.graph) + + for graph in graphs: + for node in graph.findAllNodes("prim::Constant"): + if "value" in node.attributeNames() and str(node["value"]).startswith("cuda"): + node.copyAttributes(device_node) + + model.apply(patch_device) + patch_device(model.encode_image) + patch_device(model.encode_text) + + # patch dtype to float32 on CPU + if str(device) == "cpu": + float_holder = torch.jit.trace(lambda: torch.ones([]).float(), example_inputs=[]) + float_input = list(float_holder.graph.findNode("aten::to").inputs())[1] + float_node = float_input.node() + + def patch_float(module): + try: + graphs = [module.graph] if hasattr(module, "graph") else [] + except RuntimeError: + graphs = [] + + if hasattr(module, "forward1"): + graphs.append(module.forward1.graph) + + for graph in graphs: + for node in graph.findAllNodes("aten::to"): + inputs = list(node.inputs()) + for i in [1, 2]: # dtype can be the second or third argument to aten::to() + if inputs[i].node()["value"] == 5: + inputs[i].node().copyAttributes(float_node) + + model.apply(patch_float) + patch_float(model.encode_image) + patch_float(model.encode_text) + + model.float() + + return model, _transform(model.input_resolution.item()) diff --git a/anomalib/models/video/ai_vad/clip/model.py b/anomalib/models/video/ai_vad/clip/model.py new file mode 100644 index 0000000000000000000000000000000000000000..9f23afa8fc1d5a13b9eda7510ae77189b417a70d --- /dev/null +++ b/anomalib/models/video/ai_vad/clip/model.py @@ -0,0 +1,479 @@ +# mypy: ignore-errors +# ruff: noqa + +# Original Code +# https://github.com/openai/CLIP. +# SPDX-License-Identifier: MIT +# +# Modified +# Copyright (C) 2023-2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +from collections import OrderedDict +from typing import Tuple, Union + +import numpy as np +import torch +from torch import nn +from torch.nn import functional as F # noqa: N812 + + +class Bottleneck(nn.Module): + expansion = 4 + + def __init__(self, inplanes, planes, stride=1): + super().__init__() + + # all conv layers have stride 1. an avgpool is performed after the second convolution when stride > 1 + self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False) + self.bn1 = nn.BatchNorm2d(planes) + self.relu1 = nn.ReLU(inplace=True) + + self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False) + self.bn2 = nn.BatchNorm2d(planes) + self.relu2 = nn.ReLU(inplace=True) + + self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity() + + self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False) + self.bn3 = nn.BatchNorm2d(planes * self.expansion) + self.relu3 = nn.ReLU(inplace=True) + + self.downsample = None + self.stride = stride + + if stride > 1 or inplanes != planes * Bottleneck.expansion: + # downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1 + self.downsample = nn.Sequential( + OrderedDict( + [ + ("-1", nn.AvgPool2d(stride)), + ("0", nn.Conv2d(inplanes, planes * self.expansion, 1, stride=1, bias=False)), + ("1", nn.BatchNorm2d(planes * self.expansion)), + ] + ) + ) + + def forward(self, x: torch.Tensor): + identity = x + + out = self.relu1(self.bn1(self.conv1(x))) + out = self.relu2(self.bn2(self.conv2(out))) + out = self.avgpool(out) + out = self.bn3(self.conv3(out)) + + if self.downsample is not None: + identity = self.downsample(x) + + out += identity + out = self.relu3(out) + return out + + +class AttentionPool2d(nn.Module): + def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None): + super().__init__() + self.positional_embedding = nn.Parameter(torch.randn(spacial_dim**2 + 1, embed_dim) / embed_dim**0.5) + self.k_proj = nn.Linear(embed_dim, embed_dim) + self.q_proj = nn.Linear(embed_dim, embed_dim) + self.v_proj = nn.Linear(embed_dim, embed_dim) + self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim) + self.num_heads = num_heads + + def forward(self, x): + x = x.flatten(start_dim=2).permute(2, 0, 1) # NCHW -> (HW)NC + x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC + x = x + self.positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC + x, _ = F.multi_head_attention_forward( + query=x[:1], + key=x, + value=x, + embed_dim_to_check=x.shape[-1], + num_heads=self.num_heads, + q_proj_weight=self.q_proj.weight, + k_proj_weight=self.k_proj.weight, + v_proj_weight=self.v_proj.weight, + in_proj_weight=None, + in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]), + bias_k=None, + bias_v=None, + add_zero_attn=False, + dropout_p=0, + out_proj_weight=self.c_proj.weight, + out_proj_bias=self.c_proj.bias, + use_separate_proj_weight=True, + training=self.training, + need_weights=False, + ) + return x.squeeze(0) + + +class ModifiedResNet(nn.Module): + """ + A ResNet class that is similar to torchvision's but contains the following changes: + - There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool. + - Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1 + - The final pooling layer is a QKV attention instead of an average pool + """ + + def __init__(self, layers, output_dim, heads, input_resolution=224, width=64): + super().__init__() + self.output_dim = output_dim + self.input_resolution = input_resolution + + # the 3-layer stem + self.conv1 = nn.Conv2d(3, width // 2, kernel_size=3, stride=2, padding=1, bias=False) + self.bn1 = nn.BatchNorm2d(width // 2) + self.relu1 = nn.ReLU(inplace=True) + self.conv2 = nn.Conv2d(width // 2, width // 2, kernel_size=3, padding=1, bias=False) + self.bn2 = nn.BatchNorm2d(width // 2) + self.relu2 = nn.ReLU(inplace=True) + self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False) + self.bn3 = nn.BatchNorm2d(width) + self.relu3 = nn.ReLU(inplace=True) + self.avgpool = nn.AvgPool2d(2) + + # residual layers + self._inplanes = width # this is a *mutable* variable used during construction + self.layer1 = self._make_layer(width, layers[0]) + self.layer2 = self._make_layer(width * 2, layers[1], stride=2) + self.layer3 = self._make_layer(width * 4, layers[2], stride=2) + self.layer4 = self._make_layer(width * 8, layers[3], stride=2) + + embed_dim = width * 32 # the ResNet feature dimension + self.attnpool = AttentionPool2d(input_resolution // 32, embed_dim, heads, output_dim) + + def _make_layer(self, planes, blocks, stride=1): + layers = [Bottleneck(self._inplanes, planes, stride)] + + self._inplanes = planes * Bottleneck.expansion + for _ in range(1, blocks): + layers.append(Bottleneck(self._inplanes, planes)) + + return nn.Sequential(*layers) + + def forward(self, x): + def stem(x): + x = self.relu1(self.bn1(self.conv1(x))) + x = self.relu2(self.bn2(self.conv2(x))) + x = self.relu3(self.bn3(self.conv3(x))) + x = self.avgpool(x) + return x + + x = x.type(self.conv1.weight.dtype) + x = stem(x) + x = self.layer1(x) + x = self.layer2(x) + x = self.layer3(x) + x = self.layer4(x) + x = self.attnpool(x) + + return x + + +class LayerNorm(nn.LayerNorm): + """Subclass torch's LayerNorm to handle fp16.""" + + def forward(self, x: torch.Tensor): + orig_type = x.dtype + ret = super().forward(x.type(torch.float32)) + return ret.type(orig_type) + + +class QuickGELU(nn.Module): + def forward(self, x: torch.Tensor): + return x * torch.sigmoid(1.702 * x) + + +class ResidualAttentionBlock(nn.Module): + def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None): + super().__init__() + + self.attn = nn.MultiheadAttention(d_model, n_head) + self.ln_1 = LayerNorm(d_model) + self.mlp = nn.Sequential( + OrderedDict( + [ + ("c_fc", nn.Linear(d_model, d_model * 4)), + ("gelu", QuickGELU()), + ("c_proj", nn.Linear(d_model * 4, d_model)), + ] + ) + ) + self.ln_2 = LayerNorm(d_model) + self.attn_mask = attn_mask + + def attention(self, x: torch.Tensor): + self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None + return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0] + + def forward(self, x: torch.Tensor): + x = x + self.attention(self.ln_1(x)) + x = x + self.mlp(self.ln_2(x)) + return x + + +class Transformer(nn.Module): + def __init__(self, width: int, layers: int, heads: int, attn_mask: torch.Tensor = None): + super().__init__() + self.width = width + self.layers = layers + self.resblocks = nn.Sequential(*[ResidualAttentionBlock(width, heads, attn_mask) for _ in range(layers)]) + + def forward(self, x: torch.Tensor): + return self.resblocks(x) + + +class VisionTransformer(nn.Module): + def __init__(self, input_resolution: int, patch_size: int, width: int, layers: int, heads: int, output_dim: int): + super().__init__() + self.input_resolution = input_resolution + self.output_dim = output_dim + self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False) + + scale = width**-0.5 + self.class_embedding = nn.Parameter(scale * torch.randn(width)) + self.positional_embedding = nn.Parameter(scale * torch.randn((input_resolution // patch_size) ** 2 + 1, width)) + self.ln_pre = LayerNorm(width) + + self.transformer = Transformer(width, layers, heads) + + self.ln_post = LayerNorm(width) + self.proj = nn.Parameter(scale * torch.randn(width, output_dim)) + + def forward(self, x: torch.Tensor): + x = self.conv1(x) # shape = [*, width, grid, grid] + x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2] + x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width] + x = torch.cat( + [ + self.class_embedding.to(x.dtype) + + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), + x, + ], + dim=1, + ) # shape = [*, grid ** 2 + 1, width] + x = x + self.positional_embedding.to(x.dtype) + x = self.ln_pre(x) + + x = x.permute(1, 0, 2) # NLD -> LND + x = self.transformer(x) + x = x.permute(1, 0, 2) # LND -> NLD + + x = self.ln_post(x[:, 0, :]) + + if self.proj is not None: + x = x @ self.proj + + return x + + +class CLIP(nn.Module): + def __init__( + self, + embed_dim: int, + # vision + image_resolution: int, + vision_layers: Union[Tuple[int, int, int, int], int], + vision_width: int, + vision_patch_size: int, + # text + context_length: int, + vocab_size: int, + transformer_width: int, + transformer_heads: int, + transformer_layers: int, + ): + super().__init__() + + self.context_length = context_length + + if isinstance(vision_layers, (tuple, list)): + vision_heads = vision_width * 32 // 64 + self.visual = ModifiedResNet( + layers=vision_layers, + output_dim=embed_dim, + heads=vision_heads, + input_resolution=image_resolution, + width=vision_width, + ) + else: + vision_heads = vision_width // 64 + self.visual = VisionTransformer( + input_resolution=image_resolution, + patch_size=vision_patch_size, + width=vision_width, + layers=vision_layers, + heads=vision_heads, + output_dim=embed_dim, + ) + + self.transformer = Transformer( + width=transformer_width, + layers=transformer_layers, + heads=transformer_heads, + attn_mask=self.build_attention_mask(), + ) + + self.vocab_size = vocab_size + self.token_embedding = nn.Embedding(vocab_size, transformer_width) + self.positional_embedding = nn.Parameter(torch.empty(self.context_length, transformer_width)) + self.ln_final = LayerNorm(transformer_width) + + self.text_projection = nn.Parameter(torch.empty(transformer_width, embed_dim)) + self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07)) + + self.initialize_parameters() + + def initialize_parameters(self): + nn.init.normal_(self.token_embedding.weight, std=0.02) + nn.init.normal_(self.positional_embedding, std=0.01) + + if isinstance(self.visual, ModifiedResNet): + if self.visual.attnpool is not None: + std = self.visual.attnpool.c_proj.in_features**-0.5 + nn.init.normal_(self.visual.attnpool.q_proj.weight, std=std) + nn.init.normal_(self.visual.attnpool.k_proj.weight, std=std) + nn.init.normal_(self.visual.attnpool.v_proj.weight, std=std) + nn.init.normal_(self.visual.attnpool.c_proj.weight, std=std) + + for resnet_block in [self.visual.layer1, self.visual.layer2, self.visual.layer3, self.visual.layer4]: + for name, param in resnet_block.named_parameters(): + if name.endswith("bn3.weight"): + nn.init.zeros_(param) + + proj_std = (self.transformer.width**-0.5) * ((2 * self.transformer.layers) ** -0.5) + attn_std = self.transformer.width**-0.5 + fc_std = (2 * self.transformer.width) ** -0.5 + for block in self.transformer.resblocks: + nn.init.normal_(block.attn.in_proj_weight, std=attn_std) + nn.init.normal_(block.attn.out_proj.weight, std=proj_std) + nn.init.normal_(block.mlp.c_fc.weight, std=fc_std) + nn.init.normal_(block.mlp.c_proj.weight, std=proj_std) + + if self.text_projection is not None: + nn.init.normal_(self.text_projection, std=self.transformer.width**-0.5) + + def build_attention_mask(self): + # lazily create causal attention mask, with full attention between the vision tokens + # pytorch uses additive attention mask; fill with -inf + mask = torch.empty(self.context_length, self.context_length) + mask.fill_(float("-inf")) + mask.triu_(1) # zero out the lower diagonal + return mask + + @property + def dtype(self): + return self.visual.conv1.weight.dtype + + def encode_image(self, image): + return self.visual(image.type(self.dtype)) + + def encode_text(self, text): + x = self.token_embedding(text).type(self.dtype) # [batch_size, n_ctx, d_model] + + x = x + self.positional_embedding.type(self.dtype) + x = x.permute(1, 0, 2) # NLD -> LND + x = self.transformer(x) + x = x.permute(1, 0, 2) # LND -> NLD + x = self.ln_final(x).type(self.dtype) + + # x.shape = [batch_size, n_ctx, transformer.width] + # take features from the eot embedding (eot_token is the highest number in each sequence) + x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection + + return x + + def forward(self, image, text): + image_features = self.encode_image(image) + text_features = self.encode_text(text) + + # normalized features + image_features = image_features / image_features.norm(dim=1, keepdim=True) + text_features = text_features / text_features.norm(dim=1, keepdim=True) + + # cosine similarity as logits + logit_scale = self.logit_scale.exp() + logits_per_image = logit_scale * image_features @ text_features.t() + logits_per_text = logits_per_image.t() + + # shape = [global_batch_size, global_batch_size] + return logits_per_image, logits_per_text + + +def convert_weights(model: nn.Module): + """Convert applicable model parameters to fp16""" + + def _convert_weights_to_fp16(l): + if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)): + l.weight.data = l.weight.data.half() + if l.bias is not None: + l.bias.data = l.bias.data.half() + + if isinstance(l, nn.MultiheadAttention): + for attr in [*[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]], "in_proj_bias", "bias_k", "bias_v"]: + tensor = getattr(l, attr) + if tensor is not None: + tensor.data = tensor.data.half() + + for name in ["text_projection", "proj"]: + if hasattr(l, name): + attr = getattr(l, name) + if attr is not None: + attr.data = attr.data.half() + + model.apply(_convert_weights_to_fp16) + + +def build_model(state_dict: dict): + vit = "visual.proj" in state_dict + + if vit: + vision_width = state_dict["visual.conv1.weight"].shape[0] + vision_layers = len( + [k for k in state_dict.keys() if k.startswith("visual.") and k.endswith(".attn.in_proj_weight")] + ) + vision_patch_size = state_dict["visual.conv1.weight"].shape[-1] + grid_size = round((state_dict["visual.positional_embedding"].shape[0] - 1) ** 0.5) + image_resolution = vision_patch_size * grid_size + else: + counts: list = [ + len(set(k.split(".")[2] for k in state_dict if k.startswith(f"visual.layer{b}"))) for b in [1, 2, 3, 4] + ] + vision_layers = tuple(counts) + vision_width = state_dict["visual.layer1.0.conv1.weight"].shape[0] + output_width = round((state_dict["visual.attnpool.positional_embedding"].shape[0] - 1) ** 0.5) + vision_patch_size = None + if output_width**2 + 1 != state_dict["visual.attnpool.positional_embedding"].shape[0]: + msg = "Assertion failed: output_width**2 + 1 != state_dict['visual.attnpool.positional_embedding'].shape[0]" + raise ValueError(msg) + + image_resolution = output_width * 32 + + embed_dim = state_dict["text_projection"].shape[1] + context_length = state_dict["positional_embedding"].shape[0] + vocab_size = state_dict["token_embedding.weight"].shape[0] + transformer_width = state_dict["ln_final.weight"].shape[0] + transformer_heads = transformer_width // 64 + transformer_layers = len(set(k.split(".")[2] for k in state_dict if k.startswith("transformer.resblocks"))) + + model = CLIP( + embed_dim, + image_resolution, + vision_layers, + vision_width, + vision_patch_size, + context_length, + vocab_size, + transformer_width, + transformer_heads, + transformer_layers, + ) + + for key in ["input_resolution", "context_length", "vocab_size"]: + if key in state_dict: + del state_dict[key] + + convert_weights(model) + model.load_state_dict(state_dict) + return model.eval() diff --git a/anomalib/models/video/ai_vad/density.py b/anomalib/models/video/ai_vad/density.py new file mode 100644 index 0000000000000000000000000000000000000000..857c80cf6fb8c1081a9c6045437b6493392aab5d --- /dev/null +++ b/anomalib/models/video/ai_vad/density.py @@ -0,0 +1,336 @@ +"""Density estimation module for AI-VAD model implementation.""" + +# Copyright (C) 2023-2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + + +from abc import ABC, abstractmethod + +import torch +from torch import Tensor, nn + +from anomalib.metrics.min_max import MinMax +from anomalib.models.components.base import DynamicBufferMixin +from anomalib.models.components.cluster.gmm import GaussianMixture + +from .features import FeatureType + + +class BaseDensityEstimator(nn.Module, ABC): + """Base density estimator.""" + + @abstractmethod + def update(self, features: dict[FeatureType, torch.Tensor] | torch.Tensor, group: str | None = None) -> None: + """Update the density model with a new set of features.""" + raise NotImplementedError + + @abstractmethod + def predict( + self, + features: dict[FeatureType, torch.Tensor] | torch.Tensor, + ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: + """Predict the density of a set of features.""" + raise NotImplementedError + + @abstractmethod + def fit(self) -> None: + """Compose model using collected features.""" + raise NotImplementedError + + def forward( + self, + features: dict[FeatureType, torch.Tensor] | torch.Tensor, + ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor] | None: + """Update or predict depending on training status.""" + if self.training: + self.update(features) + return None + return self.predict(features) + + +class CombinedDensityEstimator(BaseDensityEstimator): + """Density estimator for AI-VAD. + + Combines density estimators for the different feature types included in the model. + + Args: + use_pose_features (bool): Flag indicating if pose features should be used. + Defaults to ``True``. + use_deep_features (bool): Flag indicating if deep features should be used. + Defaults to ``True``. + use_velocity_features (bool): Flag indicating if velocity features should be used. + Defaults to ``False``. + n_neighbors_pose (int): Number of neighbors used in KNN density estimation for pose features. + Defaults to ``1``. + n_neighbors_deep (int): Number of neighbors used in KNN density estimation for deep features. + Defaults to ``1``. + n_components_velocity (int): Number of components used by GMM density estimation for velocity features. + Defaults to ``5``. + """ + + def __init__( + self, + use_pose_features: bool = True, + use_deep_features: bool = True, + use_velocity_features: bool = False, + n_neighbors_pose: int = 1, + n_neighbors_deep: int = 1, + n_components_velocity: int = 5, + ) -> None: + super().__init__() + + self.use_pose_features = use_pose_features + self.use_deep_features = use_deep_features + self.use_velocity_features = use_velocity_features + + if self.use_velocity_features: + self.velocity_estimator = GMMEstimator(n_components=n_components_velocity) + if self.use_deep_features: + self.appearance_estimator = GroupedKNNEstimator(n_neighbors_deep) + if self.use_pose_features: + self.pose_estimator = GroupedKNNEstimator(n_neighbors=n_neighbors_pose) + if not any((use_pose_features, use_deep_features, use_velocity_features)): + msg = "At least one feature stream must be enabled." + raise ValueError(msg) + + def update(self, features: dict[FeatureType, torch.Tensor], group: str | None = None) -> None: + """Update the density estimators for the different feature types. + + Args: + features (dict[FeatureType, torch.Tensor]): Dictionary containing extracted features for a single frame. + group (str): Identifier of the video from which the frame was sampled. Used for grouped density estimation. + """ + if self.use_velocity_features: + self.velocity_estimator.update(features[FeatureType.VELOCITY]) + if self.use_deep_features: + self.appearance_estimator.update(features[FeatureType.DEEP], group=group) + if self.use_pose_features: + self.pose_estimator.update(features[FeatureType.POSE], group=group) + + def fit(self) -> None: + """Fit the density estimation models on the collected features.""" + if self.use_velocity_features: + self.velocity_estimator.fit() + if self.use_deep_features: + self.appearance_estimator.fit() + if self.use_pose_features: + self.pose_estimator.fit() + + def predict(self, features: dict[FeatureType, torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor]: + """Predict the region- and image-level anomaly scores for an image based on a set of features. + + Args: + features (dict[Tensor]): Dictionary containing extracted features for a single frame. + + Returns: + Tensor: Region-level anomaly scores for all regions withing the frame. + Tensor: Frame-level anomaly score for the frame. + """ + n_regions = next(iter(features.values())).shape[0] + device = next(iter(features.values())).device + region_scores = torch.zeros(n_regions).to(device) + image_score = 0 + if self.use_velocity_features and features[FeatureType.VELOCITY].numel(): + velocity_scores = self.velocity_estimator.predict(features[FeatureType.VELOCITY]) + region_scores += velocity_scores + image_score += velocity_scores.max() + if self.use_deep_features and features[FeatureType.DEEP].numel(): + deep_scores = self.appearance_estimator.predict(features[FeatureType.DEEP]) + region_scores += deep_scores + image_score += deep_scores.max() + if self.use_pose_features and features[FeatureType.POSE].numel(): + pose_scores = self.pose_estimator.predict(features[FeatureType.POSE]) + region_scores += pose_scores + image_score += pose_scores.max() + return region_scores, image_score + + +class GroupedKNNEstimator(DynamicBufferMixin, BaseDensityEstimator): + """Grouped KNN density estimator. + + Keeps track of the group (e.g. video id) from which the features were sampled for normalization purposes. + + Args: + n_neighbors (int): Number of neighbors used in KNN search. + """ + + def __init__(self, n_neighbors: int) -> None: + super().__init__() + + self.n_neighbors = n_neighbors + self.feature_collection: dict[str, list[torch.Tensor]] = {} + self.group_index: dict[str, int] = {} + self.normalization_statistics = MinMax() + + self.register_buffer("memory_bank", Tensor()) + self.memory_bank: torch.Tensor = Tensor() + + def update(self, features: torch.Tensor, group: str | None = None) -> None: + """Update the internal feature bank while keeping track of the group. + + Args: + features (torch.Tensor): Feature vectors extracted from a video frame. + group (str): Identifier of the group (video) from which the frame was sampled. + """ + group = group or "default" + + if group in self.feature_collection: + self.feature_collection[group].append(features) + else: + self.feature_collection[group] = [features] + + def fit(self) -> None: + """Fit the KNN model by stacking the feature vectors and computing the normalization statistics.""" + # stack the collected features group-wise + feature_collection = {key: torch.vstack(value) for key, value in self.feature_collection.items()} + # assign memory bank, group index and group names + self.memory_bank = torch.vstack(list(feature_collection.values())) + self.group_index = torch.repeat_interleave( + Tensor([features.shape[0] for features in feature_collection.values()]).int(), + ) + self.group_names = list(feature_collection.keys()) + self._compute_normalization_statistics(feature_collection) + # delete the feature collection to free up memory + del self.feature_collection + + def predict( + self, + features: torch.Tensor, + group: str | None = None, + n_neighbors: int = 1, + normalize: bool = True, + ) -> torch.Tensor: + """Predict the (normalized) density for a set of features. + + Args: + features (torch.Tensor): Input features that will be compared to the density model. + group (str, optional): Group (video id) from which the features originate. If passed, all features of the + same group in the memory bank will be excluded from the density estimation. + Defaults to ``None``. + n_neighbors (int): Number of neighbors used in the KNN search. + Defaults to ``1``. + normalize (bool): Flag indicating if the density should be normalized to min-max stats of the feature bank. + Defatuls to ``True``. + + Returns: + Tensor: Mean (normalized) distances of input feature vectors to k nearest neighbors in feature bank. + """ + n_neighbors = n_neighbors or self.n_neighbors + + if group: + group_idx = self.group_names.index(group) + mem_bank = self.memory_bank[self.group_index != group_idx] + else: + mem_bank = self.memory_bank + + distances = self._nearest_neighbors(mem_bank, features, n_neighbors=n_neighbors) + + if normalize: + distances = self._normalize(distances) + + return distances.mean(axis=1) + + @staticmethod + def _nearest_neighbors(feature_bank: torch.Tensor, features: torch.Tensor, n_neighbors: int = 1) -> torch.Tensor: + """Perform the KNN search. + + Args: + feature_bank (torch.Tensor): Feature bank used for KNN search. + features (Ternsor): Input features. + n_neighbors (int): Number of neighbors used in KNN search. + + Returns: + Tensor: Distances between the input features and their K nearest neighbors in the feature bank. + """ + distances = torch.cdist(features, feature_bank, p=2.0) # euclidean norm + if n_neighbors == 1: + # when n_neighbors is 1, speed up computation by using min instead of topk + distances, _ = distances.min(1) + return distances.unsqueeze(1) + distances, _ = distances.topk(k=n_neighbors, largest=False, dim=1) + return distances + + def _compute_normalization_statistics(self, grouped_features: dict[str, Tensor]) -> None: + """Compute min-max normalization statistics while taking the group into account.""" + for group, features in grouped_features.items(): + distances = self.predict(features, group, normalize=False) + self.normalization_statistics.update(distances) + + self.normalization_statistics.compute() + + def _normalize(self, distances: torch.Tensor) -> torch.Tensor: + """Normalize distance predictions. + + Args: + distances (torch.Tensor): Distance tensor produced by KNN search. + + Returns: + Tensor: Normalized distances. + """ + return (distances - self.normalization_statistics.min) / ( + self.normalization_statistics.max - self.normalization_statistics.min + ) + + +class GMMEstimator(BaseDensityEstimator): + """Density estimation based on Gaussian Mixture Model. + + Args: + n_components (int): Number of components used in the GMM. + Defaults to ``2``. + """ + + def __init__(self, n_components: int = 2) -> None: + super().__init__() + + self.gmm = GaussianMixture(n_components=n_components) + self.memory_bank: list[torch.Tensor] | torch.Tensor = [] + + self.normalization_statistics = MinMax() + + def update(self, features: torch.Tensor, group: str | None = None) -> None: + """Update the feature bank.""" + del group + if isinstance(self.memory_bank, list): + self.memory_bank.append(features) + + def fit(self) -> None: + """Fit the GMM and compute normalization statistics.""" + self.memory_bank = torch.vstack(self.memory_bank) + self.gmm.fit(self.memory_bank) + self._compute_normalization_statistics() + + def predict(self, features: torch.Tensor, normalize: bool = True) -> torch.Tensor: + """Predict the density of a set of feature vectors. + + Args: + features (torch.Tensor): Input feature vectors. + normalize (bool): Flag indicating if the density should be normalized to min-max stats of the feature bank. + Defaults to ``True``. + + Returns: + Tensor: Density scores of the input feature vectors. + """ + density = -self.gmm.score_samples(features) + if normalize: + density = self._normalize(density) + return density + + def _compute_normalization_statistics(self) -> None: + """Compute min-max normalization statistics over the feature bank.""" + training_scores = self.predict(self.memory_bank, normalize=False) + self.normalization_statistics.update(training_scores) + self.normalization_statistics.compute() + + def _normalize(self, density: torch.Tensor) -> torch.Tensor: + """Normalize distance predictions. + + Args: + density (torch.Tensor): Distance tensor produced by KNN search. + + Returns: + Tensor: Normalized distances. + """ + return (density - self.normalization_statistics.min) / ( + self.normalization_statistics.max - self.normalization_statistics.min + ) diff --git a/anomalib/models/video/ai_vad/features.py b/anomalib/models/video/ai_vad/features.py new file mode 100644 index 0000000000000000000000000000000000000000..296769f799df51391858fd71ad19a52fd39dec1b --- /dev/null +++ b/anomalib/models/video/ai_vad/features.py @@ -0,0 +1,259 @@ +"""Feature extraction module for AI-VAD model implementation.""" + +# Copyright (C) 2023-2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + + +from enum import Enum + +import torch +from torch import nn +from torchvision.models.detection import KeypointRCNN_ResNet50_FPN_Weights, keypointrcnn_resnet50_fpn +from torchvision.models.detection.roi_heads import keypointrcnn_inference +from torchvision.ops import roi_align +from torchvision.transforms import Normalize + +from .clip import clip + + +class FeatureType(str, Enum): + """Names of the different feature streams used in AI-VAD.""" + + POSE = "pose" + VELOCITY = "velocity" + DEEP = "deep" + + +class FeatureExtractor(nn.Module): + """Feature extractor for AI-VAD. + + Args: + n_velocity_bins (int): Number of discrete bins used for velocity histogram features. + Defaults to ``8``. + use_velocity_features (bool): Flag indicating if velocity features should be used. + Defaults to ``True``. + use_pose_features (bool): Flag indicating if pose features should be used. + Defaults to ``True``. + use_deep_features (bool): Flag indicating if deep features should be used. + Defaults to ``True``. + """ + + def __init__( + self, + n_velocity_bins: int = 8, + use_velocity_features: bool = True, + use_pose_features: bool = True, + use_deep_features: bool = True, + ) -> None: + super().__init__() + if not (use_velocity_features or use_pose_features or use_deep_features): + msg = "At least one feature stream must be enabled." + raise ValueError(msg) + + self.use_velocity_features = use_velocity_features + self.use_pose_features = use_pose_features + self.use_deep_features = use_deep_features + + self.deep_extractor = DeepExtractor() + self.velocity_extractor = VelocityExtractor(n_bins=n_velocity_bins) + self.pose_extractor = PoseExtractor() + + def forward( + self, + rgb_batch: torch.Tensor, + flow_batch: torch.Tensor, + regions: list[dict], + ) -> list[dict]: + """Forward pass through the feature extractor. + + Extract any combination of velocity, pose and deep features depending on configuration. + + Args: + rgb_batch (torch.Tensor): Batch of RGB images of shape (N, 3, H, W) + flow_batch (torch.Tensor): Batch of optical flow images of shape (N, 2, H, W) + regions (list[dict]): Region information per image in batch. + + Returns: + list[dict]: Feature dictionary per image in batch. + """ + batch_size = rgb_batch.shape[0] + + # convert from list of [N, 4] tensors to single [N, 5] tensor where each row is [index-in-batch, x1, y1, x2, y2] + boxes_list = [batch_item["boxes"] for batch_item in regions] + indices = torch.repeat_interleave( + torch.arange(len(regions)), + torch.Tensor([boxes.shape[0] for boxes in boxes_list]).int(), + ) + boxes = torch.cat([indices.unsqueeze(1).to(rgb_batch.device), torch.cat(boxes_list)], dim=1) + + # Extract features + feature_dict = {} + if self.use_velocity_features: + velocity_features = self.velocity_extractor(flow_batch, boxes) + feature_dict[FeatureType.VELOCITY] = [velocity_features[indices == i] for i in range(batch_size)] + if self.use_pose_features: + pose_features = self.pose_extractor(rgb_batch, boxes_list) + feature_dict[FeatureType.POSE] = pose_features + if self.use_deep_features: + deep_features = self.deep_extractor(rgb_batch, boxes, batch_size) + feature_dict[FeatureType.DEEP] = [deep_features[indices == i] for i in range(batch_size)] + + # dict of lists to list of dicts + return [dict(zip(feature_dict, item, strict=True)) for item in zip(*feature_dict.values(), strict=True)] + + +class DeepExtractor(nn.Module): + """Deep feature extractor. + + Extracts the deep (appearance) features from the input regions. + """ + + def __init__(self) -> None: + super().__init__() + + self.encoder, _ = clip.load("ViT-B/16") + self.transform = Normalize(mean=(0.48145466, 0.4578275, 0.40821073), std=(0.26862954, 0.26130258, 0.27577711)) + self.output_dim = self.encoder.visual.output_dim + + def forward(self, batch: torch.Tensor, boxes: torch.Tensor, batch_size: int) -> torch.Tensor: + """Extract deep features using CLIP encoder. + + Args: + batch (torch.Tensor): Batch of RGB input images of shape (N, 3, H, W) + boxes (torch.Tensor): Bounding box coordinates of shaspe (M, 5). + First column indicates batch index of the bbox. + batch_size (int): Number of images in the batch. + + Returns: + Tensor: Deep feature tensor of shape (M, 512) + """ + rgb_regions = roi_align(batch, boxes, output_size=[224, 224]) + + batched_regions = torch.split(rgb_regions, batch_size) + batched_regions = [batch for batch in batched_regions if batch.numel() != 0] + with torch.no_grad(): + features = [self.encoder.encode_image(self.transform(batch)) for batch in batched_regions] + return torch.vstack(features).float() if len(features) else torch.empty(0, self.output_dim).to(batch.device) + + +class VelocityExtractor(nn.Module): + """Velocity feature extractor. + + Extracts histograms of optical flow magnitude and direction. + + Args: + n_bins (int): Number of direction bins used for the feature histograms. + """ + + def __init__(self, n_bins: int = 8) -> None: + super().__init__() + + self.n_bins = n_bins + + def forward(self, flows: torch.Tensor, boxes: torch.Tensor) -> torch.Tensor: + """Extract velocioty features by filling a histogram. + + Args: + flows (torch.Tensor): Batch of optical flow images of shape (N, 2, H, W) + boxes (torch.Tensor): Bounding box coordinates of shaspe (M, 5). + First column indicates batch index of the bbox. + + Returns: + Tensor: Velocity feature tensor of shape (M, n_bins) + """ + flow_regions = roi_align(flows, boxes, output_size=[224, 224]) + + # cartesian to polar + mag_batch = torch.linalg.norm(flow_regions, axis=1, ord=2) + theta_batch = torch.atan2(flow_regions[:, 0, ...], flow_regions[:, 1, ...]) + + # compute velocity histogram + velocity_histograms = [] + for mag, theta in zip(mag_batch, theta_batch, strict=True): + histogram_mag = torch.histogram( + input=theta.cpu(), + bins=self.n_bins, + range=(-torch.pi, torch.pi), + weight=mag.cpu(), + ).hist + histogram_counts = torch.histogram(input=theta.cpu(), bins=self.n_bins, range=(-torch.pi, torch.pi)).hist + final_histogram = torch.zeros_like(histogram_mag) + mask = histogram_counts != 0 + final_histogram[mask] = histogram_mag[mask] / histogram_counts[mask] + velocity_histograms.append(final_histogram) + + if len(velocity_histograms) == 0: + return torch.empty(0, self.n_bins).to(flows.device) + return torch.stack(velocity_histograms).to(flows.device) + + +class PoseExtractor(nn.Module): + """Pose feature extractor. + + Extracts pose features based on estimated body landmark keypoints. + """ + + def __init__(self, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) + + weights = KeypointRCNN_ResNet50_FPN_Weights.DEFAULT + model = keypointrcnn_resnet50_fpn(weights=weights) + self.model = model + self.transform = model.transform + self.backbone = model.backbone + self.roi_heads = model.roi_heads + + @staticmethod + def _post_process(keypoint_detections: list[dict]) -> list[torch.Tensor]: + """Convert keypoint predictions to 1D feature vectors. + + Post-processing consists of flattening and normalizing to bbox coordinates. + + Args: + keypoint_detections (list[dict]): Outputs of the keypoint extractor + + Returns: + list[torch.Tensor]: List of pose feature tensors for each image + """ + poses = [] + for detection in keypoint_detections: + boxes = detection["boxes"].unsqueeze(1) + keypoints = detection["keypoints"] + normalized_keypoints = (keypoints[..., :2] - boxes[..., :2]) / (boxes[..., 2:] - boxes[..., :2]) + length = normalized_keypoints.shape[-1] * normalized_keypoints.shape[-2] + poses.append(normalized_keypoints.reshape(normalized_keypoints.shape[0], length)) + return poses + + def forward(self, batch: torch.Tensor, boxes: torch.Tensor) -> list[torch.Tensor]: + """Extract pose features using a human keypoint estimation model. + + Args: + batch (torch.Tensor): Batch of RGB input images of shape (N, 3, H, W) + boxes (torch.Tensor): Bounding box coordinates of shaspe (M, 5). + First column indicates batch index of the bbox. + + Returns: + list[torch.Tensor]: list of pose feature tensors for each image. + """ + images, _ = self.transform(batch) + features = self.backbone(images.tensors) + + image_sizes = [b.shape[-2:] for b in batch] + scales = [ + torch.Tensor(new) / torch.Tensor([orig[0], orig[1]]) + for orig, new in zip(image_sizes, images.image_sizes, strict=True) + ] + + boxes = [box * scale.repeat(2).to(box.device) for box, scale in zip(boxes, scales, strict=True)] + + keypoint_features = self.roi_heads.keypoint_roi_pool(features, boxes, images.image_sizes) + keypoint_features = self.roi_heads.keypoint_head(keypoint_features) + keypoint_logits = self.roi_heads.keypoint_predictor(keypoint_features) + keypoints_probs, _ = keypointrcnn_inference(keypoint_logits, boxes) + + keypoint_detections = self.transform.postprocess( + [{"keypoints": keypoints, "boxes": box} for keypoints, box in zip(keypoints_probs, boxes, strict=True)], + images.image_sizes, + image_sizes, + ) + return self._post_process(keypoint_detections) diff --git a/anomalib/models/video/ai_vad/flow.py b/anomalib/models/video/ai_vad/flow.py new file mode 100644 index 0000000000000000000000000000000000000000..35d3e478fda8514297dfffe993df3556ddf31510 --- /dev/null +++ b/anomalib/models/video/ai_vad/flow.py @@ -0,0 +1,60 @@ +"""Optical Flow extraction module for AI-VAD implementation.""" + +# Copyright (C) 2023-2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + + +import torch +import torchvision.transforms.functional as F # noqa: N812 +from torch import nn +from torchvision.models.optical_flow import Raft_Large_Weights, raft_large + + +class FlowExtractor(nn.Module): + """Optical Flow extractor. + + Computes the pixel displacement between 2 consecutive frames from a video clip. + """ + + def __init__(self, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) + + weights = Raft_Large_Weights.DEFAULT + self.model = raft_large(weights=weights) + self.transforms = weights.transforms() + + def pre_process(self, first_frame: torch.Tensor, last_frame: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + """Resize inputs to dimensions required by backbone. + + Args: + first_frame (torch.Tensor): Starting frame of optical flow computation. + last_frame (torch.Tensor): Last frame of optical flow computation. + + Returns: + tuple[torch.Tensor, torch.Tensor]: Preprocessed first and last frame. + """ + first_frame = F.resize(first_frame, size=[520, 960], antialias=False) + last_frame = F.resize(last_frame, size=[520, 960], antialias=False) + return self.transforms(first_frame, last_frame) + + def forward(self, first_frame: torch.Tensor, last_frame: torch.Tensor) -> torch.Tensor: + """Forward pass through the flow extractor. + + Args: + first_frame (torch.Tensor): Batch of starting frames of shape (N, 3, H, W). + last_frame (torch.Tensor): Batch of last frames of shape (N, 3, H, W). + + Returns: + Tensor: Estimated optical flow map of shape (N, 2, H, W). + """ + height, width = first_frame.shape[-2:] + + # preprocess batch + first_frame, last_frame = self.pre_process(first_frame, last_frame) + + # get flow maps + with torch.no_grad(): + flows = self.model(first_frame, last_frame)[-1] + + # convert back to original size + return F.resize(flows, [height, width], antialias=False) diff --git a/anomalib/models/video/ai_vad/lightning_model.py b/anomalib/models/video/ai_vad/lightning_model.py new file mode 100644 index 0000000000000000000000000000000000000000..55ac5a23dab9ac9025d8225f8e0542fbd88402b5 --- /dev/null +++ b/anomalib/models/video/ai_vad/lightning_model.py @@ -0,0 +1,166 @@ +"""Attribute-based Representations for Accurate and Interpretable Video Anomaly Detection. + +Paper https://arxiv.org/pdf/2212.00789.pdf +""" + +# Copyright (C) 2023-2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + + +import logging +from typing import Any + +import torch +from lightning.pytorch.utilities.types import STEP_OUTPUT +from torchvision.transforms.v2 import Transform + +from anomalib import LearningType +from anomalib.models.components import AnomalyModule, MemoryBankMixin + +from .torch_model import AiVadModel + +logger = logging.getLogger(__name__) + +__all__ = ["AiVad"] + + +class AiVad(MemoryBankMixin, AnomalyModule): + """AI-VAD: Attribute-based Representations for Accurate and Interpretable Video Anomaly Detection. + + Args: + box_score_thresh (float): Confidence threshold for bounding box predictions. + Defaults to ``0.7``. + persons_only (bool): When enabled, only regions labeled as person are included. + Defaults to ``False``. + min_bbox_area (int): Minimum bounding box area. Regions with a surface area lower than this value are excluded. + Defaults to ``100``. + max_bbox_overlap (float): Maximum allowed overlap between bounding boxes. + Defaults to ``0.65``. + enable_foreground_detections (bool): Add additional foreground detections based on pixel difference between + consecutive frames. + Defaults to ``True``. + foreground_kernel_size (int): Gaussian kernel size used in foreground detection. + Defaults to ``3``. + foreground_binary_threshold (int): Value between 0 and 255 which acts as binary threshold in foreground + detection. + Defaults to ``18``. + n_velocity_bins (int): Number of discrete bins used for velocity histogram features. + Defaults to ``1``. + use_velocity_features (bool): Flag indicating if velocity features should be used. + Defaults to ``True``. + use_pose_features (bool): Flag indicating if pose features should be used. + Defaults to ``True``. + use_deep_features (bool): Flag indicating if deep features should be used. + Defaults to ``True``. + n_components_velocity (int): Number of components used by GMM density estimation for velocity features. + Defaults to ``2``. + n_neighbors_pose (int): Number of neighbors used in KNN density estimation for pose features. + Defaults to ``1``. + n_neighbors_deep (int): Number of neighbors used in KNN density estimation for deep features. + Defaults to ``1``. + """ + + def __init__( + self, + box_score_thresh: float = 0.7, + persons_only: bool = False, + min_bbox_area: int = 100, + max_bbox_overlap: float = 0.65, + enable_foreground_detections: bool = True, + foreground_kernel_size: int = 3, + foreground_binary_threshold: int = 18, + n_velocity_bins: int = 1, + use_velocity_features: bool = True, + use_pose_features: bool = True, + use_deep_features: bool = True, + n_components_velocity: int = 2, + n_neighbors_pose: int = 1, + n_neighbors_deep: int = 1, + ) -> None: + super().__init__() + + self.model = AiVadModel( + box_score_thresh=box_score_thresh, + persons_only=persons_only, + min_bbox_area=min_bbox_area, + max_bbox_overlap=max_bbox_overlap, + enable_foreground_detections=enable_foreground_detections, + foreground_kernel_size=foreground_kernel_size, + foreground_binary_threshold=foreground_binary_threshold, + n_velocity_bins=n_velocity_bins, + use_velocity_features=use_velocity_features, + use_pose_features=use_pose_features, + use_deep_features=use_deep_features, + n_components_velocity=n_components_velocity, + n_neighbors_pose=n_neighbors_pose, + n_neighbors_deep=n_neighbors_deep, + ) + + self.total_detections = 0 + + @staticmethod + def configure_optimizers() -> None: + """AI-VAD training does not involve fine-tuning of NN weights, no optimizers needed.""" + return + + def training_step(self, batch: dict[str, str | torch.Tensor]) -> None: + """Training Step of AI-VAD. + + Extract features from the batch of clips and update the density estimators. + + Args: + batch (dict[str, str | torch.Tensor]): Batch containing image filename, image, label and mask + """ + features_per_batch = self.model(batch["image"]) + + for features, video_path in zip(features_per_batch, batch["video_path"], strict=True): + self.model.density_estimator.update(features, video_path) + self.total_detections += len(next(iter(features.values()))) + + def fit(self) -> None: + """Fit the density estimators to the extracted features from the training set.""" + if self.total_detections == 0: + msg = "No regions were extracted during training." + raise ValueError(msg) + self.model.density_estimator.fit() + + def validation_step(self, batch: dict[str, str | torch.Tensor], *args, **kwargs) -> STEP_OUTPUT: + """Perform the validation step of AI-VAD. + + Extract boxes and box scores.. + + Args: + batch (dict[str, str | torch.Tensor]): Input batch + *args: Arguments. + **kwargs: Keyword arguments. + + Returns: + Batch dictionary with added boxes and box scores. + """ + del args, kwargs # Unused arguments. + + boxes, anomaly_scores, image_scores = self.model(batch["image"]) + batch["pred_boxes"] = [box.int() for box in boxes] + batch["box_scores"] = [score.to(self.device) for score in anomaly_scores] + batch["pred_scores"] = torch.Tensor(image_scores).to(self.device) + + return batch + + @property + def trainer_arguments(self) -> dict[str, Any]: + """AI-VAD specific trainer arguments.""" + return {"gradient_clip_val": 0, "max_epochs": 1, "num_sanity_val_steps": 0} + + @property + def learning_type(self) -> LearningType: + """Return the learning type of the model. + + Returns: + LearningType: Learning type of the model. + """ + return LearningType.ONE_CLASS + + def configure_transforms(self, image_size: tuple[int, int] | None = None) -> Transform | None: + """AI-VAD does not need a transform, as the region- and feature-extractors apply their own transforms.""" + del image_size + return None diff --git a/anomalib/models/video/ai_vad/regions.py b/anomalib/models/video/ai_vad/regions.py new file mode 100644 index 0000000000000000000000000000000000000000..7ccf790eba7344a0a9b152fcb6057f8a3cb566c2 --- /dev/null +++ b/anomalib/models/video/ai_vad/regions.py @@ -0,0 +1,253 @@ +"""Regions extraction module of AI-VAD model implementation.""" + +# Copyright (C) 2023-2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + + +import torch +from torch import nn +from torchvision.models.detection import MaskRCNN_ResNet50_FPN_V2_Weights, maskrcnn_resnet50_fpn_v2 +from torchvision.ops import box_area, clip_boxes_to_image +from torchvision.transforms.functional import gaussian_blur, rgb_to_grayscale + +from anomalib.data.utils.boxes import boxes_to_masks, masks_to_boxes + +PERSON_LABEL = 1 + + +class RegionExtractor(nn.Module): + """Region extractor for AI-VAD. + + Args: + box_score_thresh (float): Confidence threshold for bounding box predictions. + Defaults to ``0.8``. + persons_only (bool): When enabled, only regions labeled as person are included. + Defaults to ``False``. + min_bbox_area (int): Minimum bounding box area. Regions with a surface area lower than this value are excluded. + Defaults to ``100``. + max_bbox_overlap (float): Maximum allowed overlap between bounding boxes. + Defaults to ``0.65``. + enable_foreground_detections (bool): Add additional foreground detections based on pixel difference between + consecutive frames. + Defaults to ``True``. + foreground_kernel_size (int): Gaussian kernel size used in foreground detection. + Defaults to ``3``. + foreground_binary_threshold (int): Value between 0 and 255 which acts as binary threshold in foreground + detection. + Defaults to ``18``. + """ + + def __init__( + self, + box_score_thresh: float = 0.8, + persons_only: bool = False, + min_bbox_area: int = 100, + max_bbox_overlap: float = 0.65, + enable_foreground_detections: bool = True, + foreground_kernel_size: int = 3, + foreground_binary_threshold: int = 18, + ) -> None: + super().__init__() + + self.persons_only = persons_only + self.min_bbox_area = min_bbox_area + self.max_bbox_overlap = max_bbox_overlap + self.enable_foreground_detections = enable_foreground_detections + self.foreground_kernel_size = foreground_kernel_size + self.foreground_binary_threshold = foreground_binary_threshold + + weights = MaskRCNN_ResNet50_FPN_V2_Weights.DEFAULT + self.backbone = maskrcnn_resnet50_fpn_v2(weights=weights, box_score_thresh=box_score_thresh, rpn_nms_thresh=0.3) + + def forward(self, first_frame: torch.Tensor, last_frame: torch.Tensor) -> list[dict]: + """Perform forward-pass through region extractor. + + Args: + first_frame (torch.Tensor): Batch of input images of shape (N, C, H, W) + forming the first frames in the clip. + last_frame (torch.Tensor): Batch of input images of shape (N, C, H, W) forming the last frame in the clip. + + Returns: + list[dict]: List of Mask RCNN predictions for each image in the batch. + """ + with torch.no_grad(): + regions = self.backbone(last_frame) + + if self.enable_foreground_detections: + regions = self.add_foreground_boxes( + regions, + first_frame, + last_frame, + self.foreground_kernel_size, + self.foreground_binary_threshold, + ) + + return self.post_process_bbox_detections(regions) + + def add_foreground_boxes( + self, + regions: list[dict[str, torch.Tensor]], + first_frame: torch.Tensor, + last_frame: torch.Tensor, + kernel_size: int, + binary_threshold: int, + ) -> list[dict[str, torch.Tensor]]: + """Add any foreground regions that were not detected by the region extractor. + + This method adds regions that likely belong to the foreground of the video scene, but were not detected by the + region extractor module. The foreground pixels are determined by taking the pixel difference between two + consecutive video frames and applying a binary threshold. The final detections consist of all connected + components in the foreground that do not fall in one of the bounding boxes predicted by the region extractor. + + Args: + regions (list[dict[str, torch.Tensor]]): Region detections for a batch of images, generated by the region + extraction module. + first_frame (torch.Tensor): video frame at time t-1 + last_frame (torch.Tensor): Video frame time t + kernel_size (int): Kernel size for Gaussian smoothing applied to input frames + binary_threshold (int): Binary threshold used in foreground detection, should be in range [0, 255] + + Returns: + list[dict[str, torch.Tensor]]: region detections with foreground regions appended + """ + # apply gaussian blur to first and last frame + first_frame = gaussian_blur(first_frame, [kernel_size, kernel_size]) + last_frame = gaussian_blur(last_frame, [kernel_size, kernel_size]) + + # take the abs diff between the blurred images and convert to grayscale + pixel_diff = torch.abs(first_frame - last_frame) + pixel_diff = rgb_to_grayscale(pixel_diff).squeeze(1) + + # apply binary threshold to the diff + foreground_map = (pixel_diff > binary_threshold / 255).int() + + # remove regions already detected by region extractor + boxes_list = [im_regions["boxes"] for im_regions in regions] + boxes_list = [ + clip_boxes_to_image(boxes + torch.Tensor([-2, -2, 2, 2]).to(boxes.device), foreground_map.shape[-2:]) + for boxes in boxes_list + ] # extend boxes by 2 in all directions to ensure full object is included + boxes_mask = boxes_to_masks(boxes_list, foreground_map.shape[-2:]).int() + foreground_map *= -boxes_mask + 1 # invert mask + + # find boxes from foreground map + batch_boxes, _ = masks_to_boxes(foreground_map) + + # append foreground detections to region extractor detections + for im_regions, boxes, pixel_mask in zip(regions, batch_boxes, foreground_map, strict=True): + if boxes.shape[0] == 0: + continue + + # append boxes, labels and scores + im_regions["boxes"] = torch.cat([im_regions["boxes"], boxes]) + im_regions["labels"] = torch.cat( + [im_regions["labels"], torch.zeros(boxes.shape[0], device=boxes.device)], + ) # set label as background, in accordance with region extractor predictions + im_regions["scores"] = torch.cat( + [im_regions["scores"], torch.ones(boxes.shape[0], device=boxes.device) * 0.5], + ) # set confidence to 0.5 + + # append masks + im_boxes_as_list = [box.unsqueeze(0) for box in boxes] # list with one box per element + boxes_mask = boxes_to_masks(im_boxes_as_list, pixel_mask.shape[-2:]).int() + new_masks = pixel_mask.repeat((len(im_boxes_as_list), 1, 1)) * boxes_mask + im_regions["masks"] = torch.cat([im_regions["masks"], new_masks.unsqueeze(1)]) + + return regions + + def post_process_bbox_detections(self, regions: list[dict[str, torch.Tensor]]) -> list[dict[str, torch.Tensor]]: + """Post-process the region detections. + + The region detections are filtered based on class label, bbox area and overlap with other regions. + + Args: + regions (list[dict[str, torch.Tensor]]): Region detections for a batch of images, generated by the region + extraction module. + + Returns: + list[dict[str, torch.Tensor]]: Filtered regions + """ + filtered_regions_list = [] + for img_regions in regions: + filtered_regions = self._keep_only_persons(img_regions) if self.persons_only else img_regions + filtered_regions = self._filter_by_area(filtered_regions, self.min_bbox_area) + filtered_regions = self._delete_overlapping_boxes(filtered_regions, self.max_bbox_overlap) + filtered_regions_list.append(filtered_regions) + return filtered_regions_list + + def _keep_only_persons(self, regions: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]: + """Remove all region detections that are not labeled as a person by the region extractor. + + Args: + regions (dict[str, torch.Tensor]): Region detections for a single image in the batch. + + Returns: + dict[str, torch.Tensor]: Region detections from which non-person objects have been removed. + """ + keep = torch.where(regions["labels"] == PERSON_LABEL) + return self.subsample_regions(regions, keep) + + def _filter_by_area(self, regions: dict[str, torch.Tensor], min_area: int) -> dict[str, torch.Tensor]: + """Remove all regions with a surface area smaller than the specified value. + + Args: + regions (dict[str, torch.Tensor]): Region detections for a single image in the batch. + min_area (int): Minimum bounding box area. Regions with a surface area lower than this value are excluded. + + Returns: + dict[str, torch.Tensor]: Region detections from which small regions have been removed. + """ + areas = box_area(regions["boxes"]) + keep = torch.where(areas > min_area) + return self.subsample_regions(regions, keep) + + def _delete_overlapping_boxes(self, regions: dict[str, torch.Tensor], threshold: float) -> dict[str, torch.Tensor]: + """Delete overlapping bounding boxes. + + For each bounding box, the overlap with all other bounding boxes relative to their own surface area is computed. + When the relative overlap with any other box is higher than the specified threshold, the box is removed. when + both boxes have a relative overlap higher than the threshold, only the smaller box is removed. + + Args: + regions (dict[str, torch.Tensor]): Region detections for a single image in the batch. + threshold (float): Maximum allowed overlap between bounding boxes. + + Returns: + dict[str, torch.Tensor]: Region detections from which overlapping regions have been removed. + """ + # sort boxes by area + areas = box_area(regions["boxes"]) + indices = areas.argsort() + + keep = [] + for idx in range(len(indices)): + overlap_coords = torch.hstack( + [ + torch.max(regions["boxes"][indices[idx], :2], regions["boxes"][indices[idx + 1 :], :2]), # x1, y1 + torch.min(regions["boxes"][indices[idx], 2:], regions["boxes"][indices[idx + 1 :], 2:]), # x2, y2 + ], + ) + mask = torch.all(overlap_coords[:, :2] < overlap_coords[:, 2:], dim=1) # filter non-overlapping + overlap = box_area(overlap_coords) * mask.int() + overlap_ratio = overlap / areas[indices[idx]] + + if not any(overlap_ratio > threshold): + keep.append(indices[idx]) + + return self.subsample_regions(regions, torch.tensor(keep, dtype=torch.int64)) + + @staticmethod + def subsample_regions(regions: dict[str, torch.Tensor], indices: torch.Tensor) -> dict[str, torch.Tensor]: + """Subsample the items in a region dictionary based on a Tensor of indices. + + Args: + regions (dict[str, torch.Tensor]): Region detections for a single image in the batch. + indices (torch.Tensor): Indices of region detections that should be kept. + + Returns: + dict[str, torch.Tensor]: Subsampled region detections. + """ + new_regions_dict = {} + for key, value in regions.items(): + new_regions_dict[key] = value[indices] + return new_regions_dict diff --git a/anomalib/models/video/ai_vad/torch_model.py b/anomalib/models/video/ai_vad/torch_model.py new file mode 100644 index 0000000000000000000000000000000000000000..138d308788ae52e5016054636fffa3c4329c16e6 --- /dev/null +++ b/anomalib/models/video/ai_vad/torch_model.py @@ -0,0 +1,148 @@ +"""PyTorch model for AI-VAD model implementation. + +Paper https://arxiv.org/pdf/2212.00789.pdf +""" + +# Copyright (C) 2023-2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + + +import torch +from torch import nn + +from .density import CombinedDensityEstimator +from .features import FeatureExtractor +from .flow import FlowExtractor +from .regions import RegionExtractor + + +class AiVadModel(nn.Module): + """AI-VAD model. + + Args: + box_score_thresh (float): Confidence threshold for region extraction stage. + Defaults to ``0.8``. + persons_only (bool): When enabled, only regions labeled as person are included. + Defaults to ``False``. + min_bbox_area (int): Minimum bounding box area. Regions with a surface area lower than this value are excluded. + Defaults to ``100``. + max_bbox_overlap (float): Maximum allowed overlap between bounding boxes. + Defaults to ``0.65``. + enable_foreground_detections (bool): Add additional foreground detections based on pixel difference between + consecutive frames. + Defaults to ``True``. + foreground_kernel_size (int): Gaussian kernel size used in foreground detection. + Defaults to ``3``. + foreground_binary_threshold (int): Value between 0 and 255 which acts as binary threshold in foreground + detection. + Defaults to ``18``. + n_velocity_bins (int): Number of discrete bins used for velocity histogram features. + Defaults to ``8``. + use_velocity_features (bool): Flag indicating if velocity features should be used. + Defaults to ``True``. + use_pose_features (bool): Flag indicating if pose features should be used. + Defaults to ``True``. + use_deep_features (bool): Flag indicating if deep features should be used. + Defaults to ``True``. + n_components_velocity (int): Number of components used by GMM density estimation for velocity features. + Defaults to ``5``. + n_neighbors_pose (int): Number of neighbors used in KNN density estimation for pose features. + Defaults to ``1``. + n_neighbors_deep (int): Number of neighbors used in KNN density estimation for deep features. + Defaults to ``1``. + """ + + def __init__( + self, + # region-extraction params + box_score_thresh: float = 0.8, + persons_only: bool = False, + min_bbox_area: int = 100, + max_bbox_overlap: float = 0.65, + enable_foreground_detections: bool = True, + foreground_kernel_size: int = 3, + foreground_binary_threshold: int = 18, + # feature-extraction params + n_velocity_bins: int = 8, + use_velocity_features: bool = True, + use_pose_features: bool = True, + use_deep_features: bool = True, + # density-estimation params + n_components_velocity: int = 5, + n_neighbors_pose: int = 1, + n_neighbors_deep: int = 1, + ) -> None: + super().__init__() + if not any((use_velocity_features, use_pose_features, use_deep_features)): + msg = "Select at least one feature type." + raise ValueError(msg) + + # initialize flow extractor + self.flow_extractor = FlowExtractor() + # initialize region extractor + self.region_extractor = RegionExtractor( + box_score_thresh=box_score_thresh, + persons_only=persons_only, + min_bbox_area=min_bbox_area, + max_bbox_overlap=max_bbox_overlap, + enable_foreground_detections=enable_foreground_detections, + foreground_kernel_size=foreground_kernel_size, + foreground_binary_threshold=foreground_binary_threshold, + ) + # initialize feature extractor + self.feature_extractor = FeatureExtractor( + n_velocity_bins=n_velocity_bins, + use_velocity_features=use_velocity_features, + use_pose_features=use_pose_features, + use_deep_features=use_deep_features, + ) + # initialize density estimator + self.density_estimator = CombinedDensityEstimator( + use_velocity_features=use_velocity_features, + use_pose_features=use_pose_features, + use_deep_features=use_deep_features, + n_components_velocity=n_components_velocity, + n_neighbors_pose=n_neighbors_pose, + n_neighbors_deep=n_neighbors_deep, + ) + + def forward(self, batch: torch.Tensor) -> tuple[list[torch.Tensor], list[torch.Tensor], list[torch.Tensor]]: + """Forward pass through AI-VAD model. + + Args: + batch (torch.Tensor): Input image of shape (N, L, C, H, W) + + Returns: + list[torch.Tensor]: List of bbox locations for each image. + list[torch.Tensor]: List of per-bbox anomaly scores for each image. + list[torch.Tensor]: List of per-image anomaly scores. + """ + self.flow_extractor.eval() + self.region_extractor.eval() + self.feature_extractor.eval() + + # 1. get first and last frame from clip + first_frame = batch[:, 0, ...] + last_frame = batch[:, -1, ...] + + # 2. extract flows and regions + with torch.no_grad(): + flows = self.flow_extractor(first_frame, last_frame) + regions = self.region_extractor(first_frame, last_frame) + + # 3. extract pose, appearance and velocity features + features_per_batch = self.feature_extractor(first_frame, flows, regions) + + if self.training: + return features_per_batch + + # 4. estimate density + box_scores = [] + image_scores = [] + for features in features_per_batch: + box, image = self.density_estimator(features) + box_scores.append(box) + image_scores.append(image) + + box_locations = [batch_item["boxes"] for batch_item in regions] + return box_locations, box_scores, image_scores diff --git a/anomalib/py.typed b/anomalib/py.typed new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/anomalib/utils/__init__.py b/anomalib/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..8ffe7654fe6a9df568952679a322221ff169872c --- /dev/null +++ b/anomalib/utils/__init__.py @@ -0,0 +1,4 @@ +"""Helpers for downloading files, calculating metrics, computing anomaly maps, and visualization.""" + +# Copyright (C) 2022 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 diff --git a/anomalib/utils/config.py b/anomalib/utils/config.py new file mode 100644 index 0000000000000000000000000000000000000000..113522819ed0d077a50093e9be87ee869dc54fa1 --- /dev/null +++ b/anomalib/utils/config.py @@ -0,0 +1,130 @@ +"""Get configurable parameters.""" + +# Copyright (C) 2022-2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + + +import logging +from collections.abc import Sequence +from pathlib import Path +from typing import Any, cast + +from jsonargparse import Namespace +from jsonargparse import Path as JSONArgparsePath +from omegaconf import DictConfig, ListConfig, OmegaConf + +logger = logging.getLogger(__name__) + + +def _convert_nested_path_to_str(config: Any) -> Any: # noqa: ANN401 + """Goes over the dictionary and converts all path values to str.""" + if isinstance(config, dict): + for key, value in config.items(): + config[key] = _convert_nested_path_to_str(value) + elif isinstance(config, list): + for i, item in enumerate(config): + config[i] = _convert_nested_path_to_str(item) + elif isinstance(config, Path | JSONArgparsePath): + config = str(config) + return config + + +def to_yaml(config: Namespace | ListConfig | DictConfig) -> str: + """Convert the config to a yaml string. + + Args: + config (Namespace | ListConfig | DictConfig): Config + + Returns: + str: YAML string + """ + _config = config.clone() if isinstance(config, Namespace) else config.copy() + if isinstance(_config, Namespace): + _config = _config.as_dict() + _config = _convert_nested_path_to_str(_config) + return OmegaConf.to_yaml(_config) + + +def to_tuple(input_size: int | ListConfig) -> tuple[int, int]: + """Convert int or list to a tuple. + + Args: + input_size (int | ListConfig): input_size + + Example: + >>> to_tuple(256) + (256, 256) + >>> to_tuple([256, 256]) + (256, 256) + + Raises: + ValueError: Unsupported value type. + + Returns: + tuple[int, int]: Tuple of input_size + """ + ret_val: tuple[int, int] + if isinstance(input_size, int): + ret_val = cast(tuple[int, int], (input_size,) * 2) + elif isinstance(input_size, ListConfig | Sequence): + if len(input_size) != 2: + msg = "Expected a single integer or tuple of length 2 for width and height." + raise ValueError(msg) + + ret_val = cast(tuple[int, int], tuple(input_size)) + else: + msg = f"Expected either int or ListConfig, got {type(input_size)}" + raise TypeError(msg) + return ret_val + + +def update_config(config: DictConfig | ListConfig | Namespace) -> DictConfig | ListConfig | Namespace: + """Update config. + + Args: + config: Configurable parameters. + + Returns: + DictConfig | ListConfig | Namespace: Updated config. + """ + _show_warnings(config) + + return _update_nncf_config(config) + + +def _update_nncf_config(config: DictConfig | ListConfig) -> DictConfig | ListConfig: + """Set the NNCF input size based on the value of the crop_size parameter in the configurable parameters object. + + Args: + config (DictConfig | ListConfig): Configurable parameters of the current run. + + Returns: + DictConfig | ListConfig: Updated configurable parameters in DictConfig object. + """ + if "optimization" in config and "nncf" in config.optimization: + if "input_info" not in config.optimization.nncf: + config.optimization.nncf["input_info"] = {"sample_size": None} + config.optimization.nncf.input_info.sample_size = [1, 3, 10, 10] + if config.optimization.nncf.apply and "update_config" in config.optimization.nncf: + return OmegaConf.merge(config, config.optimization.nncf.update_config) + return config + + +def _show_warnings(config: DictConfig | ListConfig | Namespace) -> None: + """Show warnings if any based on the configuration settings. + + Args: + config (DictConfig | ListConfig | Namespace): Configurable parameters for the current run. + """ + if "clip_length_in_frames" in config.data and config.data.init_args.clip_length_in_frames > 1: + logger.warning( + "Anomalib's models and visualizer are currently not compatible with video datasets with a clip length > 1. " + "Custom changes to these modules will be needed to prevent errors and/or unpredictable behaviour.", + ) + if ( + "devices" in config.trainer + and (config.trainer.devices is None or config.trainer.devices != 1) + and config.trainer.accelerator != "cpu" + ): + logger.warning("Anomalib currently does not support multi-gpu training. Setting devices to 1.") + config.trainer.devices = 1 diff --git a/anomalib/utils/cv/__init__.py b/anomalib/utils/cv/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..72435b61dcb1bb2d0350f23624e21d282d9820c2 --- /dev/null +++ b/anomalib/utils/cv/__init__.py @@ -0,0 +1,8 @@ +"""Anomalib computer vision utilities.""" + +# Copyright (C) 2022 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +from .connected_components import connected_components_cpu, connected_components_gpu + +__all__ = ["connected_components_cpu", "connected_components_gpu"] diff --git a/anomalib/utils/cv/connected_components.py b/anomalib/utils/cv/connected_components.py new file mode 100644 index 0000000000000000000000000000000000000000..e2fc1000df88a301b25515025032c3b836861345 --- /dev/null +++ b/anomalib/utils/cv/connected_components.py @@ -0,0 +1,50 @@ +"""Connected component labeling.""" + +# Copyright (C) 2022 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +import cv2 +import numpy as np +import torch +from kornia.contrib import connected_components + + +def connected_components_gpu(image: torch.Tensor, num_iterations: int = 1000) -> torch.Tensor: + """Perform connected component labeling on GPU and remap the labels from 0 to N. + + Args: + image (torch.Tensor): Binary input image from which we want to extract connected components (Bx1xHxW) + num_iterations (int): Number of iterations used in the connected component computation. + + Returns: + Tensor: Components labeled from 0 to N. + """ + components = connected_components(image, num_iterations=num_iterations) + + # remap component values from 0 to N + labels = components.unique() + for new_label, old_label in enumerate(labels): + components[components == old_label] = new_label + + return components.int() + + +def connected_components_cpu(image: torch.Tensor) -> torch.Tensor: + """Perform connected component labeling on CPU. + + Args: + image (torch.Tensor): Binary input data from which we want to extract connected components (Bx1xHxW) + + Returns: + Tensor: Components labeled from 0 to N. + """ + components = torch.zeros_like(image) + label_idx = 1 + for i, msk in enumerate(image): + mask = msk.squeeze().cpu().numpy().astype(np.uint8) + _, comps = cv2.connectedComponents(mask) + # remap component values to make sure every component has a unique value when outputs are concatenated + for label in np.unique(comps)[1:]: + components[i, 0, ...][np.where(comps == label)] = label_idx + label_idx += 1 + return components.int() diff --git a/anomalib/utils/exceptions/__init__.py b/anomalib/utils/exceptions/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..52d64883d1bd7594f62be1f31a337ee3bcfe9dbd --- /dev/null +++ b/anomalib/utils/exceptions/__init__.py @@ -0,0 +1,8 @@ +"""Utilities related to exception and error handling.""" + +# Copyright (C) 2023-2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +from .imports import try_import + +__all__ = ["try_import"] diff --git a/anomalib/utils/exceptions/imports.py b/anomalib/utils/exceptions/imports.py new file mode 100644 index 0000000000000000000000000000000000000000..ebf6f11c614b330b03f036db90e8bf1604187428 --- /dev/null +++ b/anomalib/utils/exceptions/imports.py @@ -0,0 +1,30 @@ +"""Import handling utilities.""" + +# Copyright (C) 2023-2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +import logging +from importlib import import_module + +logger = logging.getLogger(__name__) + + +def try_import(import_path: str) -> bool: + """Try to import a module. + + Args: + import_path (str): The import path of the module. + + Returns: + bool: True if import succeeds, False otherwise. + """ + try: + import_module(import_path) + except ImportError: + import_package = import_path.split(".")[0] + logger.warning( + f"Could not find {import_package}. To use this feature, ensure that you have {import_package} installed.", + ) + else: + return True + return False diff --git a/anomalib/utils/normalization/__init__.py b/anomalib/utils/normalization/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..ebf4493204108cc948fcf51c41cabc91ec926dc3 --- /dev/null +++ b/anomalib/utils/normalization/__init__.py @@ -0,0 +1,13 @@ +"""Tools for anomaly score normalization.""" + +# Copyright (C) 2022 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +from enum import Enum + + +class NormalizationMethod(str, Enum): + """Normalization method for normalization.""" + + MIN_MAX = "min_max" + NONE = "none" diff --git a/anomalib/utils/normalization/min_max.py b/anomalib/utils/normalization/min_max.py new file mode 100644 index 0000000000000000000000000000000000000000..1cf2f9b23eae9696215a227156b201da8c56c849 --- /dev/null +++ b/anomalib/utils/normalization/min_max.py @@ -0,0 +1,28 @@ +"""Tools for min-max normalization.""" + +# Copyright (C) 2022-2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + + +import numpy as np +import torch + + +def normalize( + targets: np.ndarray | np.float32 | torch.Tensor, + threshold: float | np.ndarray | torch.Tensor, + min_val: float | np.ndarray | torch.Tensor, + max_val: float | np.ndarray | torch.Tensor, +) -> np.ndarray | torch.Tensor: + """Apply min-max normalization and shift the values such that the threshold value is centered at 0.5.""" + normalized = ((targets - threshold) / (max_val - min_val)) + 0.5 + if isinstance(targets, np.ndarray | np.float32 | np.float64): + normalized = np.minimum(normalized, 1) + normalized = np.maximum(normalized, 0) + elif isinstance(targets, torch.Tensor): + normalized = torch.minimum(normalized, torch.tensor(1)) # pylint: disable=not-callable + normalized = torch.maximum(normalized, torch.tensor(0)) # pylint: disable=not-callable + else: + msg = f"Targets must be either Tensor or Numpy array. Received {type(targets)}" + raise TypeError(msg) + return normalized diff --git a/anomalib/utils/path.py b/anomalib/utils/path.py new file mode 100644 index 0000000000000000000000000000000000000000..47cc77652f2f4b52350f414b71f322275003800a --- /dev/null +++ b/anomalib/utils/path.py @@ -0,0 +1,97 @@ +"""Anomalib Path Utils.""" + +# Copyright (C) 2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +import re +from pathlib import Path + + +def create_versioned_dir(root_dir: str | Path) -> Path: + """Create a new version directory and update the ``latest`` symbolic link. + + Args: + root_dir (Path): The root directory where the version directories are stored. + + Returns: + latest_link_path (Path): The path to the ``latest`` symbolic link. + + Examples: + >>> version_dir = create_version_dir(Path('path/to/experiments/')) + PosixPath('/path/to/experiments/latest') + + >>> version_dir.resolve().name + v1 + + Calling the function again will create a new version directory and + update the ``latest`` symbolic link: + + >>> version_dir = create_version_dir('path/to/experiments/') + PosixPath('/path/to/experiments/latest') + + >>> version_dir.resolve().name + v2 + + """ + # Compile a regular expression to match version directories + version_pattern = re.compile(r"^v(\d+)$") + + # Resolve the path + root_dir = Path(root_dir).resolve() + root_dir.mkdir(parents=True, exist_ok=True) + + # Find the highest existing version number + highest_version = -1 + for version_dir in root_dir.iterdir(): + if version_dir.is_dir(): + match = version_pattern.match(version_dir.name) + if match: + version_number = int(match.group(1)) + highest_version = max(highest_version, version_number) + + # The new directory will have the next highest version number + new_version_number = highest_version + 1 + new_version_dir = root_dir / f"v{new_version_number}" + + # Create the new version directory + new_version_dir.mkdir() + + # Update the 'latest' symbolic link to point to the new version directory + latest_link_path = root_dir / "latest" + if latest_link_path.is_symlink() or latest_link_path.exists(): + latest_link_path.unlink() + latest_link_path.symlink_to(new_version_dir, target_is_directory=True) + + return latest_link_path + + +def convert_to_snake_case(s: str) -> str: + """Converts a string to snake case. + + Args: + s (str): The input string to be converted. + + Returns: + str: The converted string in snake case. + + Examples: + >>> convert_to_snake_case("Snake Case") + 'snake_case' + + >>> convert_to_snake_case("snakeCase") + 'snake_case' + + >>> convert_to_snake_case("snake_case") + 'snake_case' + """ + # Replace whitespace, hyphens, periods, and apostrophes with underscores + s = re.sub(r"\s+|[-.\']", "_", s) + + # Insert underscores before capital letters (except at the beginning of the string) + s = re.sub(r"(? np.ndarray: + """Add a label to an image. + + Args: + image (np.ndarray): Input image. + label_name (str): Name of the label that will be displayed on the image. + color (tuple[int, int, int]): RGB values for background color of label. + confidence (float | None): confidence score of the label. + font_scale (float): scale of the font size relative to image size. Increase for bigger font. + thickness_scale (float): scale of the font thickness. Increase for thicker font. + + Returns: + np.ndarray: Image with label. + """ + image = image.copy() + img_height, img_width, _ = image.shape + + font = cv2.FONT_HERSHEY_PLAIN + text = label_name if confidence is None else f"{label_name} ({confidence*100:.0f}%)" + + # get font sizing + font_scale = min(img_width, img_height) * font_scale + thickness = math.ceil(min(img_width, img_height) * thickness_scale) + (width, height), baseline = cv2.getTextSize(text, font, fontScale=font_scale, thickness=thickness) + + # create label + label_patch = np.zeros((height + baseline, width + baseline, 3), dtype=np.uint8) + label_patch[:, :] = color + cv2.putText( + label_patch, + text, + (0, baseline // 2 + height), + font, + fontScale=font_scale, + thickness=thickness, + color=0, + lineType=cv2.LINE_AA, + ) + + # add label to image + image[: baseline + height, : baseline + width] = label_patch + return image + + +def add_normal_label(image: np.ndarray, confidence: float | None = None) -> np.ndarray: + """Add the normal label to the image.""" + return add_label(image, "normal", (225, 252, 134), confidence) + + +def add_anomalous_label(image: np.ndarray, confidence: float | None = None) -> np.ndarray: + """Add the anomalous label to the image.""" + return add_label(image, "anomalous", (255, 100, 100), confidence) + + +def anomaly_map_to_color_map(anomaly_map: np.ndarray, normalize: bool = True) -> np.ndarray: + """Compute anomaly color heatmap. + + Args: + anomaly_map (np.ndarray): Final anomaly map computed by the distance metric. + normalize (bool, optional): Bool to normalize the anomaly map prior to applying + the color map. Defaults to True. + + Returns: + np.ndarray: [description] + """ + if normalize: + anomaly_map = (anomaly_map - anomaly_map.min()) / np.ptp(anomaly_map) + anomaly_map = anomaly_map * 255 + anomaly_map = anomaly_map.astype(np.uint8) + + anomaly_map = cv2.applyColorMap(anomaly_map, cv2.COLORMAP_JET) + return cv2.cvtColor(anomaly_map, cv2.COLOR_BGR2RGB) + + +def superimpose_anomaly_map( + anomaly_map: np.ndarray, + image: np.ndarray, + alpha: float = 0.4, + gamma: int = 0, + normalize: bool = False, +) -> np.ndarray: + """Superimpose anomaly map on top of in the input image. + + Args: + anomaly_map (np.ndarray): Anomaly map + image (np.ndarray): Input image + alpha (float, optional): Weight to overlay anomaly map + on the input image. Defaults to 0.4. + gamma (int, optional): Value to add to the blended image + to smooth the processing. Defaults to 0. Overall, + the formula to compute the blended image is + I' = (alpha*I1 + (1-alpha)*I2) + gamma + normalize: whether or not the anomaly maps should + be normalized to image min-max at image level + + + Returns: + np.ndarray: Image with anomaly map superimposed on top of it. + """ + anomaly_map = anomaly_map_to_color_map(anomaly_map.squeeze(), normalize=normalize) + return cv2.addWeighted(anomaly_map, alpha, image, (1 - alpha), gamma) + + +def compute_mask(anomaly_map: np.ndarray, threshold: float, kernel_size: int = 4) -> np.ndarray: + """Compute anomaly mask via thresholding the predicted anomaly map. + + Args: + anomaly_map (np.ndarray): Anomaly map predicted via the model + threshold (float): Value to threshold anomaly scores into 0-1 range. + kernel_size (int): Value to apply morphological operations to the predicted mask. Defaults to 4. + + Returns: + Predicted anomaly mask + """ + anomaly_map = anomaly_map.squeeze() + mask: np.ndarray = np.zeros_like(anomaly_map).astype(np.uint8) + mask[anomaly_map > threshold] = 1 + + kernel = morphology.disk(kernel_size) + mask = morphology.opening(mask, kernel) + + mask *= 255 + + return mask + + +def draw_boxes(image: np.ndarray, boxes: np.ndarray, color: tuple[int, int, int]) -> np.ndarray: + """Draw bounding boxes on an image. + + Args: + image (np.ndarray): Source image. + boxes (np.nparray): 2D array of shape (N, 4) where each row contains the xyxy coordinates of a bounding box. + color (tuple[int, int, int]): Color of the drawn boxes in RGB format. + + Returns: + np.ndarray: Image showing the bounding boxes drawn on top of the source image. + """ + for box in boxes: + x_1, y_1, x_2, y_2 = box.astype(int) + image = cv2.rectangle(image, (x_1, y_1), (x_2, y_2), color=color, thickness=2) + return image diff --git a/anomalib/utils/types/__init__.py b/anomalib/utils/types/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..a706626a40b9d5dc5c3a160f5530e68ee2a57286 --- /dev/null +++ b/anomalib/utils/types/__init__.py @@ -0,0 +1,17 @@ +"""Typing aliases for Anomalib.""" + +# Copyright (C) 2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +from typing import TypeAlias + +from lightning.pytorch import Callback +from omegaconf import DictConfig, ListConfig + +from anomalib.metrics.threshold import BaseThreshold +from anomalib.utils.normalization import NormalizationMethod + +NORMALIZATION: TypeAlias = NormalizationMethod | DictConfig | Callback | str +THRESHOLD: TypeAlias = ( + BaseThreshold | tuple[BaseThreshold, BaseThreshold] | DictConfig | ListConfig | list[dict[str, str | float]] | str +) diff --git a/dinov2/dinov2/__init__.py b/dinov2/dinov2/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..ae847e46898077fe3d8701b8a181d7b4e3d41cd9 --- /dev/null +++ b/dinov2/dinov2/__init__.py @@ -0,0 +1,6 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +__version__ = "0.0.1" diff --git a/dinov2/dinov2/configs/__init__.py b/dinov2/dinov2/configs/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..68e0830c62ea19649b6cd2361995f6df309d7640 --- /dev/null +++ b/dinov2/dinov2/configs/__init__.py @@ -0,0 +1,22 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +import pathlib + +from omegaconf import OmegaConf + + +def load_config(config_name: str): + config_filename = config_name + ".yaml" + return OmegaConf.load(pathlib.Path(__file__).parent.resolve() / config_filename) + + +dinov2_default_config = load_config("ssl_default_config") + + +def load_and_merge_config(config_name: str): + default_config = OmegaConf.create(dinov2_default_config) + loaded_config = load_config(config_name) + return OmegaConf.merge(default_config, loaded_config) diff --git a/dinov2/dinov2/configs/eval/vitb14_pretrain.yaml b/dinov2/dinov2/configs/eval/vitb14_pretrain.yaml new file mode 100644 index 0000000000000000000000000000000000000000..117d0f027ca26cd8ce6c010bb78d5a8fac42c70e --- /dev/null +++ b/dinov2/dinov2/configs/eval/vitb14_pretrain.yaml @@ -0,0 +1,6 @@ +student: + arch: vit_base + patch_size: 14 +crops: + global_crops_size: 518 # this is to set up the position embeddings properly + local_crops_size: 98 \ No newline at end of file diff --git a/dinov2/dinov2/configs/eval/vitb14_reg4_pretrain.yaml b/dinov2/dinov2/configs/eval/vitb14_reg4_pretrain.yaml new file mode 100644 index 0000000000000000000000000000000000000000..d53edc04a0761b4b35c147d63e04d55c90092c8f --- /dev/null +++ b/dinov2/dinov2/configs/eval/vitb14_reg4_pretrain.yaml @@ -0,0 +1,9 @@ +student: + arch: vit_base + patch_size: 14 + num_register_tokens: 4 + interpolate_antialias: true + interpolate_offset: 0.0 +crops: + global_crops_size: 518 # this is to set up the position embeddings properly + local_crops_size: 98 \ No newline at end of file diff --git a/dinov2/dinov2/configs/eval/vitg14_pretrain.yaml b/dinov2/dinov2/configs/eval/vitg14_pretrain.yaml new file mode 100644 index 0000000000000000000000000000000000000000..a96dd5b117b4d59ee210b65037821f1b3e3f16e3 --- /dev/null +++ b/dinov2/dinov2/configs/eval/vitg14_pretrain.yaml @@ -0,0 +1,7 @@ +student: + arch: vit_giant2 + patch_size: 14 + ffn_layer: swiglufused +crops: + global_crops_size: 518 # this is to set up the position embeddings properly + local_crops_size: 98 \ No newline at end of file diff --git a/dinov2/dinov2/configs/eval/vitg14_reg4_pretrain.yaml b/dinov2/dinov2/configs/eval/vitg14_reg4_pretrain.yaml new file mode 100644 index 0000000000000000000000000000000000000000..15948f8589ea0a6e04717453eb88c18388e7f1b2 --- /dev/null +++ b/dinov2/dinov2/configs/eval/vitg14_reg4_pretrain.yaml @@ -0,0 +1,10 @@ +student: + arch: vit_giant2 + patch_size: 14 + ffn_layer: swiglufused + num_register_tokens: 4 + interpolate_antialias: true + interpolate_offset: 0.0 +crops: + global_crops_size: 518 # this is to set up the position embeddings properly + local_crops_size: 98 \ No newline at end of file diff --git a/dinov2/dinov2/configs/eval/vitl14_pretrain.yaml b/dinov2/dinov2/configs/eval/vitl14_pretrain.yaml new file mode 100644 index 0000000000000000000000000000000000000000..7a984548bd034f762d455419d7193917fa462dd8 --- /dev/null +++ b/dinov2/dinov2/configs/eval/vitl14_pretrain.yaml @@ -0,0 +1,6 @@ +student: + arch: vit_large + patch_size: 14 +crops: + global_crops_size: 518 # this is to set up the position embeddings properly + local_crops_size: 98 \ No newline at end of file diff --git a/dinov2/dinov2/configs/eval/vitl14_reg4_pretrain.yaml b/dinov2/dinov2/configs/eval/vitl14_reg4_pretrain.yaml new file mode 100644 index 0000000000000000000000000000000000000000..0e2bc4e7b24b1a64d0369a24927996d0f184e283 --- /dev/null +++ b/dinov2/dinov2/configs/eval/vitl14_reg4_pretrain.yaml @@ -0,0 +1,9 @@ +student: + arch: vit_large + patch_size: 14 + num_register_tokens: 4 + interpolate_antialias: true + interpolate_offset: 0.0 +crops: + global_crops_size: 518 # this is to set up the position embeddings properly + local_crops_size: 98 \ No newline at end of file diff --git a/dinov2/dinov2/configs/eval/vits14_pretrain.yaml b/dinov2/dinov2/configs/eval/vits14_pretrain.yaml new file mode 100644 index 0000000000000000000000000000000000000000..afbdb4ba14f1c97130a25b579360f4d817cda495 --- /dev/null +++ b/dinov2/dinov2/configs/eval/vits14_pretrain.yaml @@ -0,0 +1,6 @@ +student: + arch: vit_small + patch_size: 14 +crops: + global_crops_size: 518 # this is to set up the position embeddings properly + local_crops_size: 98 \ No newline at end of file diff --git a/dinov2/dinov2/configs/eval/vits14_reg4_pretrain.yaml b/dinov2/dinov2/configs/eval/vits14_reg4_pretrain.yaml new file mode 100644 index 0000000000000000000000000000000000000000..d25fd638389bfba9220792302dc9dbf5d9a2406a --- /dev/null +++ b/dinov2/dinov2/configs/eval/vits14_reg4_pretrain.yaml @@ -0,0 +1,9 @@ +student: + arch: vit_small + patch_size: 14 + num_register_tokens: 4 + interpolate_antialias: true + interpolate_offset: 0.0 +crops: + global_crops_size: 518 # this is to set up the position embeddings properly + local_crops_size: 98 \ No newline at end of file diff --git a/dinov2/dinov2/configs/ssl_default_config.yaml b/dinov2/dinov2/configs/ssl_default_config.yaml new file mode 100644 index 0000000000000000000000000000000000000000..ccaae1c3174b21bcaf6e803dc861492261e5abe1 --- /dev/null +++ b/dinov2/dinov2/configs/ssl_default_config.yaml @@ -0,0 +1,118 @@ +MODEL: + WEIGHTS: '' +compute_precision: + grad_scaler: true + teacher: + backbone: + sharding_strategy: SHARD_GRAD_OP + mixed_precision: + param_dtype: fp16 + reduce_dtype: fp16 + buffer_dtype: fp32 + dino_head: + sharding_strategy: SHARD_GRAD_OP + mixed_precision: + param_dtype: fp16 + reduce_dtype: fp16 + buffer_dtype: fp32 + ibot_head: + sharding_strategy: SHARD_GRAD_OP + mixed_precision: + param_dtype: fp16 + reduce_dtype: fp16 + buffer_dtype: fp32 + student: + backbone: + sharding_strategy: SHARD_GRAD_OP + mixed_precision: + param_dtype: fp16 + reduce_dtype: fp16 + buffer_dtype: fp32 + dino_head: + sharding_strategy: SHARD_GRAD_OP + mixed_precision: + param_dtype: fp16 + reduce_dtype: fp32 + buffer_dtype: fp32 + ibot_head: + sharding_strategy: SHARD_GRAD_OP + mixed_precision: + param_dtype: fp16 + reduce_dtype: fp32 + buffer_dtype: fp32 +dino: + loss_weight: 1.0 + head_n_prototypes: 65536 + head_bottleneck_dim: 256 + head_nlayers: 3 + head_hidden_dim: 2048 + koleo_loss_weight: 0.1 +ibot: + loss_weight: 1.0 + mask_sample_probability: 0.5 + mask_ratio_min_max: + - 0.1 + - 0.5 + separate_head: false + head_n_prototypes: 65536 + head_bottleneck_dim: 256 + head_nlayers: 3 + head_hidden_dim: 2048 +train: + batch_size_per_gpu: 64 + dataset_path: ImageNet:split=TRAIN + output_dir: . + saveckp_freq: 20 + seed: 0 + num_workers: 10 + OFFICIAL_EPOCH_LENGTH: 1250 + cache_dataset: true + centering: "centering" # or "sinkhorn_knopp" +student: + arch: vit_large + patch_size: 16 + drop_path_rate: 0.3 + layerscale: 1.0e-05 + drop_path_uniform: true + pretrained_weights: '' + ffn_layer: "mlp" + block_chunks: 0 + qkv_bias: true + proj_bias: true + ffn_bias: true + num_register_tokens: 0 + interpolate_antialias: false + interpolate_offset: 0.1 +teacher: + momentum_teacher: 0.992 + final_momentum_teacher: 1 + warmup_teacher_temp: 0.04 + teacher_temp: 0.07 + warmup_teacher_temp_epochs: 30 +optim: + epochs: 100 + weight_decay: 0.04 + weight_decay_end: 0.4 + base_lr: 0.004 # learning rate for a batch size of 1024 + lr: 0. # will be set after applying scaling rule + warmup_epochs: 10 + min_lr: 1.0e-06 + clip_grad: 3.0 + freeze_last_layer_epochs: 1 + scaling_rule: sqrt_wrt_1024 + patch_embed_lr_mult: 0.2 + layerwise_decay: 0.9 + adamw_beta1: 0.9 + adamw_beta2: 0.999 +crops: + global_crops_scale: + - 0.32 + - 1.0 + local_crops_number: 8 + local_crops_scale: + - 0.05 + - 0.32 + global_crops_size: 224 + local_crops_size: 96 +evaluation: + eval_period_iterations: 12500 diff --git a/dinov2/dinov2/configs/train/vitg14.yaml b/dinov2/dinov2/configs/train/vitg14.yaml new file mode 100644 index 0000000000000000000000000000000000000000..d05cf0d59e07ac6e4a2b0f9bdcb6131d7c508962 --- /dev/null +++ b/dinov2/dinov2/configs/train/vitg14.yaml @@ -0,0 +1,26 @@ +dino: + head_n_prototypes: 131072 + head_bottleneck_dim: 384 +ibot: + separate_head: true + head_n_prototypes: 131072 +train: + batch_size_per_gpu: 12 + dataset_path: ImageNet22k + centering: sinkhorn_knopp +student: + arch: vit_giant2 + patch_size: 14 + drop_path_rate: 0.4 + ffn_layer: swiglufused + block_chunks: 4 +teacher: + momentum_teacher: 0.994 +optim: + epochs: 500 + weight_decay_end: 0.2 + base_lr: 2.0e-04 # learning rate for a batch size of 1024 + warmup_epochs: 80 + layerwise_decay: 1.0 +crops: + local_crops_size: 98 \ No newline at end of file diff --git a/dinov2/dinov2/configs/train/vitl14.yaml b/dinov2/dinov2/configs/train/vitl14.yaml new file mode 100644 index 0000000000000000000000000000000000000000..d9b491dcc6a522c71328fc2933dd0501123c8f6b --- /dev/null +++ b/dinov2/dinov2/configs/train/vitl14.yaml @@ -0,0 +1,26 @@ +dino: + head_n_prototypes: 131072 + head_bottleneck_dim: 384 +ibot: + separate_head: true + head_n_prototypes: 131072 +train: + batch_size_per_gpu: 32 + dataset_path: ImageNet22k + centering: sinkhorn_knopp +student: + arch: vit_large + patch_size: 14 + drop_path_rate: 0.4 + ffn_layer: swiglufused + block_chunks: 4 +teacher: + momentum_teacher: 0.994 +optim: + epochs: 500 + weight_decay_end: 0.2 + base_lr: 2.0e-04 # learning rate for a batch size of 1024 + warmup_epochs: 80 + layerwise_decay: 1.0 +crops: + local_crops_size: 98 \ No newline at end of file diff --git a/dinov2/dinov2/configs/train/vitl16_short.yaml b/dinov2/dinov2/configs/train/vitl16_short.yaml new file mode 100644 index 0000000000000000000000000000000000000000..3e7e72864c92175a1354142ac1d64da8070d1e5e --- /dev/null +++ b/dinov2/dinov2/configs/train/vitl16_short.yaml @@ -0,0 +1,6 @@ +# this corresponds to the default config +train: + dataset_path: ImageNet:split=TRAIN + batch_size_per_gpu: 64 +student: + block_chunks: 4 diff --git a/dinov2/dinov2/data/__init__.py b/dinov2/dinov2/data/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..2ded47ea63a7b184ff74a040e2c2c514cda273ef --- /dev/null +++ b/dinov2/dinov2/data/__init__.py @@ -0,0 +1,10 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +from .adapters import DatasetWithEnumeratedTargets +from .loaders import make_data_loader, make_dataset, SamplerType +from .collate import collate_data_and_cast +from .masking import MaskingGenerator +from .augmentations import DataAugmentationDINO diff --git a/dinov2/dinov2/data/adapters.py b/dinov2/dinov2/data/adapters.py new file mode 100644 index 0000000000000000000000000000000000000000..2097bad046fb1052267d5f2bb99c798045f00c92 --- /dev/null +++ b/dinov2/dinov2/data/adapters.py @@ -0,0 +1,28 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +from typing import Any, Tuple + +from torch.utils.data import Dataset + + +class DatasetWithEnumeratedTargets(Dataset): + def __init__(self, dataset): + self._dataset = dataset + + def get_image_data(self, index: int) -> bytes: + return self._dataset.get_image_data(index) + + def get_target(self, index: int) -> Tuple[Any, int]: + target = self._dataset.get_target(index) + return (index, target) + + def __getitem__(self, index: int) -> Tuple[Any, Tuple[Any, int]]: + image, target = self._dataset[index] + target = index if target is None else target + return image, (index, target) + + def __len__(self) -> int: + return len(self._dataset) diff --git a/dinov2/dinov2/data/augmentations.py b/dinov2/dinov2/data/augmentations.py new file mode 100644 index 0000000000000000000000000000000000000000..05b1eaa942c14f75b88d9e14732e141e8909b0a1 --- /dev/null +++ b/dinov2/dinov2/data/augmentations.py @@ -0,0 +1,118 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +import logging + +from torchvision import transforms + +from .transforms import ( + GaussianBlur, + make_normalize_transform, +) + + +logger = logging.getLogger("dinov2") + + +class DataAugmentationDINO(object): + def __init__( + self, + global_crops_scale, + local_crops_scale, + local_crops_number, + global_crops_size=224, + local_crops_size=96, + ): + self.global_crops_scale = global_crops_scale + self.local_crops_scale = local_crops_scale + self.local_crops_number = local_crops_number + self.global_crops_size = global_crops_size + self.local_crops_size = local_crops_size + + logger.info("###################################") + logger.info("Using data augmentation parameters:") + logger.info(f"global_crops_scale: {global_crops_scale}") + logger.info(f"local_crops_scale: {local_crops_scale}") + logger.info(f"local_crops_number: {local_crops_number}") + logger.info(f"global_crops_size: {global_crops_size}") + logger.info(f"local_crops_size: {local_crops_size}") + logger.info("###################################") + + # random resized crop and flip + self.geometric_augmentation_global = transforms.Compose( + [ + transforms.RandomResizedCrop( + global_crops_size, scale=global_crops_scale, interpolation=transforms.InterpolationMode.BICUBIC + ), + transforms.RandomHorizontalFlip(p=0.5), + ] + ) + + self.geometric_augmentation_local = transforms.Compose( + [ + transforms.RandomResizedCrop( + local_crops_size, scale=local_crops_scale, interpolation=transforms.InterpolationMode.BICUBIC + ), + transforms.RandomHorizontalFlip(p=0.5), + ] + ) + + # color distorsions / blurring + color_jittering = transforms.Compose( + [ + transforms.RandomApply( + [transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.2, hue=0.1)], + p=0.8, + ), + transforms.RandomGrayscale(p=0.2), + ] + ) + + global_transfo1_extra = GaussianBlur(p=1.0) + + global_transfo2_extra = transforms.Compose( + [ + GaussianBlur(p=0.1), + transforms.RandomSolarize(threshold=128, p=0.2), + ] + ) + + local_transfo_extra = GaussianBlur(p=0.5) + + # normalization + self.normalize = transforms.Compose( + [ + transforms.ToTensor(), + make_normalize_transform(), + ] + ) + + self.global_transfo1 = transforms.Compose([color_jittering, global_transfo1_extra, self.normalize]) + self.global_transfo2 = transforms.Compose([color_jittering, global_transfo2_extra, self.normalize]) + self.local_transfo = transforms.Compose([color_jittering, local_transfo_extra, self.normalize]) + + def __call__(self, image): + output = {} + + # global crops: + im1_base = self.geometric_augmentation_global(image) + global_crop_1 = self.global_transfo1(im1_base) + + im2_base = self.geometric_augmentation_global(image) + global_crop_2 = self.global_transfo2(im2_base) + + output["global_crops"] = [global_crop_1, global_crop_2] + + # global crops for teacher: + output["global_crops_teacher"] = [global_crop_1, global_crop_2] + + # local crops: + local_crops = [ + self.local_transfo(self.geometric_augmentation_local(image)) for _ in range(self.local_crops_number) + ] + output["local_crops"] = local_crops + output["offsets"] = () + + return output diff --git a/dinov2/dinov2/data/collate.py b/dinov2/dinov2/data/collate.py new file mode 100644 index 0000000000000000000000000000000000000000..b3e32f357a76e6f32162cee14cb6ae1665a4827a --- /dev/null +++ b/dinov2/dinov2/data/collate.py @@ -0,0 +1,49 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +import torch +import random + + +def collate_data_and_cast(samples_list, mask_ratio_tuple, mask_probability, dtype, n_tokens=None, mask_generator=None): + # dtype = torch.half # TODO: Remove + + n_global_crops = len(samples_list[0][0]["global_crops"]) + n_local_crops = len(samples_list[0][0]["local_crops"]) + + collated_global_crops = torch.stack([s[0]["global_crops"][i] for i in range(n_global_crops) for s in samples_list]) + + collated_local_crops = torch.stack([s[0]["local_crops"][i] for i in range(n_local_crops) for s in samples_list]) + + B = len(collated_global_crops) + N = n_tokens + n_samples_masked = int(B * mask_probability) + probs = torch.linspace(*mask_ratio_tuple, n_samples_masked + 1) + upperbound = 0 + masks_list = [] + for i in range(0, n_samples_masked): + prob_min = probs[i] + prob_max = probs[i + 1] + masks_list.append(torch.BoolTensor(mask_generator(int(N * random.uniform(prob_min, prob_max))))) + upperbound += int(N * prob_max) + for i in range(n_samples_masked, B): + masks_list.append(torch.BoolTensor(mask_generator(0))) + + random.shuffle(masks_list) + + collated_masks = torch.stack(masks_list).flatten(1) + mask_indices_list = collated_masks.flatten().nonzero().flatten() + + masks_weight = (1 / collated_masks.sum(-1).clamp(min=1.0)).unsqueeze(-1).expand_as(collated_masks)[collated_masks] + + return { + "collated_global_crops": collated_global_crops.to(dtype), + "collated_local_crops": collated_local_crops.to(dtype), + "collated_masks": collated_masks, + "mask_indices_list": mask_indices_list, + "masks_weight": masks_weight, + "upperbound": upperbound, + "n_masked_patches": torch.full((1,), fill_value=mask_indices_list.shape[0], dtype=torch.long), + } diff --git a/dinov2/dinov2/data/datasets/__init__.py b/dinov2/dinov2/data/datasets/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..5550fdc5ce16269bc0c28795a389f0182e8bc6c8 --- /dev/null +++ b/dinov2/dinov2/data/datasets/__init__.py @@ -0,0 +1,7 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +from .image_net import ImageNet +from .image_net_22k import ImageNet22k diff --git a/dinov2/dinov2/data/datasets/decoders.py b/dinov2/dinov2/data/datasets/decoders.py new file mode 100644 index 0000000000000000000000000000000000000000..3769f7750d94f7e0f7bce281ef3ff186970fc9cd --- /dev/null +++ b/dinov2/dinov2/data/datasets/decoders.py @@ -0,0 +1,31 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +from io import BytesIO +from typing import Any + +from PIL import Image + + +class Decoder: + def decode(self) -> Any: + raise NotImplementedError + + +class ImageDataDecoder(Decoder): + def __init__(self, image_data: bytes) -> None: + self._image_data = image_data + + def decode(self) -> Image: + f = BytesIO(self._image_data) + return Image.open(f).convert(mode="RGB") + + +class TargetDecoder(Decoder): + def __init__(self, target: Any): + self._target = target + + def decode(self) -> Any: + return self._target diff --git a/dinov2/dinov2/data/datasets/extended.py b/dinov2/dinov2/data/datasets/extended.py new file mode 100644 index 0000000000000000000000000000000000000000..f60b619a3c797823cccfc89e262cdb230f9188f0 --- /dev/null +++ b/dinov2/dinov2/data/datasets/extended.py @@ -0,0 +1,38 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +from typing import Any, Tuple + +from torchvision.datasets import VisionDataset + +from .decoders import TargetDecoder, ImageDataDecoder + + +class ExtendedVisionDataset(VisionDataset): + def __init__(self, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) # type: ignore + + def get_image_data(self, index: int) -> bytes: + raise NotImplementedError + + def get_target(self, index: int) -> Any: + raise NotImplementedError + + def __getitem__(self, index: int) -> Tuple[Any, Any]: + try: + image_data = self.get_image_data(index) + image = ImageDataDecoder(image_data).decode() + except Exception as e: + raise RuntimeError(f"can not read image for sample {index}") from e + target = self.get_target(index) + target = TargetDecoder(target).decode() + + if self.transforms is not None: + image, target = self.transforms(image, target) + + return image, target + + def __len__(self) -> int: + raise NotImplementedError diff --git a/dinov2/dinov2/data/datasets/image_net.py b/dinov2/dinov2/data/datasets/image_net.py new file mode 100644 index 0000000000000000000000000000000000000000..8d08446147986c58360163e468896e994197c657 --- /dev/null +++ b/dinov2/dinov2/data/datasets/image_net.py @@ -0,0 +1,290 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +import csv +from enum import Enum +import logging +import os +from typing import Callable, List, Optional, Tuple, Union + +import numpy as np + +from .extended import ExtendedVisionDataset + + +logger = logging.getLogger("dinov2") +_Target = int + + +class _Split(Enum): + TRAIN = "train" + VAL = "val" + TEST = "test" # NOTE: torchvision does not support the test split + + @property + def length(self) -> int: + split_lengths = { + _Split.TRAIN: 1_281_167, + _Split.VAL: 50_000, + _Split.TEST: 100_000, + } + return split_lengths[self] + + def get_dirname(self, class_id: Optional[str] = None) -> str: + return self.value if class_id is None else os.path.join(self.value, class_id) + + def get_image_relpath(self, actual_index: int, class_id: Optional[str] = None) -> str: + dirname = self.get_dirname(class_id) + if self == _Split.TRAIN: + basename = f"{class_id}_{actual_index}" + else: # self in (_Split.VAL, _Split.TEST): + basename = f"ILSVRC2012_{self.value}_{actual_index:08d}" + return os.path.join(dirname, basename + ".JPEG") + + def parse_image_relpath(self, image_relpath: str) -> Tuple[str, int]: + assert self != _Split.TEST + dirname, filename = os.path.split(image_relpath) + class_id = os.path.split(dirname)[-1] + basename, _ = os.path.splitext(filename) + actual_index = int(basename.split("_")[-1]) + return class_id, actual_index + + +class ImageNet(ExtendedVisionDataset): + Target = Union[_Target] + Split = Union[_Split] + + def __init__( + self, + *, + split: "ImageNet.Split", + root: str, + extra: str, + transforms: Optional[Callable] = None, + transform: Optional[Callable] = None, + target_transform: Optional[Callable] = None, + ) -> None: + super().__init__(root, transforms, transform, target_transform) + self._extra_root = extra + self._split = split + + self._entries = None + self._class_ids = None + self._class_names = None + + @property + def split(self) -> "ImageNet.Split": + return self._split + + def _get_extra_full_path(self, extra_path: str) -> str: + return os.path.join(self._extra_root, extra_path) + + def _load_extra(self, extra_path: str) -> np.ndarray: + extra_full_path = self._get_extra_full_path(extra_path) + return np.load(extra_full_path, mmap_mode="r") + + def _save_extra(self, extra_array: np.ndarray, extra_path: str) -> None: + extra_full_path = self._get_extra_full_path(extra_path) + os.makedirs(self._extra_root, exist_ok=True) + np.save(extra_full_path, extra_array) + + @property + def _entries_path(self) -> str: + return f"entries-{self._split.value.upper()}.npy" + + @property + def _class_ids_path(self) -> str: + return f"class-ids-{self._split.value.upper()}.npy" + + @property + def _class_names_path(self) -> str: + return f"class-names-{self._split.value.upper()}.npy" + + def _get_entries(self) -> np.ndarray: + if self._entries is None: + self._entries = self._load_extra(self._entries_path) + assert self._entries is not None + return self._entries + + def _get_class_ids(self) -> np.ndarray: + if self._split == _Split.TEST: + assert False, "Class IDs are not available in TEST split" + if self._class_ids is None: + self._class_ids = self._load_extra(self._class_ids_path) + assert self._class_ids is not None + return self._class_ids + + def _get_class_names(self) -> np.ndarray: + if self._split == _Split.TEST: + assert False, "Class names are not available in TEST split" + if self._class_names is None: + self._class_names = self._load_extra(self._class_names_path) + assert self._class_names is not None + return self._class_names + + def find_class_id(self, class_index: int) -> str: + class_ids = self._get_class_ids() + return str(class_ids[class_index]) + + def find_class_name(self, class_index: int) -> str: + class_names = self._get_class_names() + return str(class_names[class_index]) + + def get_image_data(self, index: int) -> bytes: + entries = self._get_entries() + actual_index = entries[index]["actual_index"] + + class_id = self.get_class_id(index) + + image_relpath = self.split.get_image_relpath(actual_index, class_id) + image_full_path = os.path.join(self.root, image_relpath) + with open(image_full_path, mode="rb") as f: + image_data = f.read() + return image_data + + def get_target(self, index: int) -> Optional[Target]: + entries = self._get_entries() + class_index = entries[index]["class_index"] + return None if self.split == _Split.TEST else int(class_index) + + def get_targets(self) -> Optional[np.ndarray]: + entries = self._get_entries() + return None if self.split == _Split.TEST else entries["class_index"] + + def get_class_id(self, index: int) -> Optional[str]: + entries = self._get_entries() + class_id = entries[index]["class_id"] + return None if self.split == _Split.TEST else str(class_id) + + def get_class_name(self, index: int) -> Optional[str]: + entries = self._get_entries() + class_name = entries[index]["class_name"] + return None if self.split == _Split.TEST else str(class_name) + + def __len__(self) -> int: + entries = self._get_entries() + assert len(entries) == self.split.length + return len(entries) + + def _load_labels(self, labels_path: str) -> List[Tuple[str, str]]: + labels_full_path = os.path.join(self.root, labels_path) + labels = [] + + try: + with open(labels_full_path, "r") as f: + reader = csv.reader(f) + for row in reader: + class_id, class_name = row + labels.append((class_id, class_name)) + except OSError as e: + raise RuntimeError(f'can not read labels file "{labels_full_path}"') from e + + return labels + + def _dump_entries(self) -> None: + split = self.split + if split == ImageNet.Split.TEST: + dataset = None + sample_count = split.length + max_class_id_length, max_class_name_length = 0, 0 + else: + labels_path = "labels.txt" + logger.info(f'loading labels from "{labels_path}"') + labels = self._load_labels(labels_path) + + # NOTE: Using torchvision ImageFolder for consistency + from torchvision.datasets import ImageFolder + + dataset_root = os.path.join(self.root, split.get_dirname()) + dataset = ImageFolder(dataset_root) + sample_count = len(dataset) + max_class_id_length, max_class_name_length = -1, -1 + for sample in dataset.samples: + _, class_index = sample + class_id, class_name = labels[class_index] + max_class_id_length = max(len(class_id), max_class_id_length) + max_class_name_length = max(len(class_name), max_class_name_length) + + dtype = np.dtype( + [ + ("actual_index", " old_percent: + logger.info(f"creating entries: {percent}%") + old_percent = percent + + actual_index = index + 1 + class_index = np.uint32(-1) + class_id, class_name = "", "" + entries_array[index] = (actual_index, class_index, class_id, class_name) + else: + class_names = {class_id: class_name for class_id, class_name in labels} + + assert dataset + old_percent = -1 + for index in range(sample_count): + percent = 100 * (index + 1) // sample_count + if percent > old_percent: + logger.info(f"creating entries: {percent}%") + old_percent = percent + + image_full_path, class_index = dataset.samples[index] + image_relpath = os.path.relpath(image_full_path, self.root) + class_id, actual_index = split.parse_image_relpath(image_relpath) + class_name = class_names[class_id] + entries_array[index] = (actual_index, class_index, class_id, class_name) + + logger.info(f'saving entries to "{self._entries_path}"') + self._save_extra(entries_array, self._entries_path) + + def _dump_class_ids_and_names(self) -> None: + split = self.split + if split == ImageNet.Split.TEST: + return + + entries_array = self._load_extra(self._entries_path) + + max_class_id_length, max_class_name_length, max_class_index = -1, -1, -1 + for entry in entries_array: + class_index, class_id, class_name = ( + entry["class_index"], + entry["class_id"], + entry["class_name"], + ) + max_class_index = max(int(class_index), max_class_index) + max_class_id_length = max(len(str(class_id)), max_class_id_length) + max_class_name_length = max(len(str(class_name)), max_class_name_length) + + class_count = max_class_index + 1 + class_ids_array = np.empty(class_count, dtype=f"U{max_class_id_length}") + class_names_array = np.empty(class_count, dtype=f"U{max_class_name_length}") + for entry in entries_array: + class_index, class_id, class_name = ( + entry["class_index"], + entry["class_id"], + entry["class_name"], + ) + class_ids_array[class_index] = class_id + class_names_array[class_index] = class_name + + logger.info(f'saving class IDs to "{self._class_ids_path}"') + self._save_extra(class_ids_array, self._class_ids_path) + + logger.info(f'saving class names to "{self._class_names_path}"') + self._save_extra(class_names_array, self._class_names_path) + + def dump_extra(self) -> None: + self._dump_entries() + self._dump_class_ids_and_names() diff --git a/dinov2/dinov2/data/datasets/image_net_22k.py b/dinov2/dinov2/data/datasets/image_net_22k.py new file mode 100644 index 0000000000000000000000000000000000000000..52b36a2c664a7b72e30173b03b4e2aef1cd2fcd9 --- /dev/null +++ b/dinov2/dinov2/data/datasets/image_net_22k.py @@ -0,0 +1,302 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +from dataclasses import dataclass +from enum import Enum +from functools import lru_cache +from gzip import GzipFile +from io import BytesIO +from mmap import ACCESS_READ, mmap +import os +from typing import Any, Callable, List, Optional, Set, Tuple +import warnings + +import numpy as np + +from .extended import ExtendedVisionDataset + + +_Labels = int + +_DEFAULT_MMAP_CACHE_SIZE = 16 # Warning: This can exhaust file descriptors + + +@dataclass +class _ClassEntry: + block_offset: int + maybe_filename: Optional[str] = None + + +@dataclass +class _Entry: + class_index: int # noqa: E701 + start_offset: int + end_offset: int + filename: str + + +class _Split(Enum): + TRAIN = "train" + VAL = "val" + + @property + def length(self) -> int: + return { + _Split.TRAIN: 11_797_647, + _Split.VAL: 561_050, + }[self] + + def entries_path(self): + return f"imagenet21kp_{self.value}.txt" + + +def _get_tarball_path(class_id: str) -> str: + return f"{class_id}.tar" + + +def _make_mmap_tarball(tarballs_root: str, mmap_cache_size: int): + @lru_cache(maxsize=mmap_cache_size) + def _mmap_tarball(class_id: str) -> mmap: + tarball_path = _get_tarball_path(class_id) + tarball_full_path = os.path.join(tarballs_root, tarball_path) + with open(tarball_full_path) as f: + return mmap(fileno=f.fileno(), length=0, access=ACCESS_READ) + + return _mmap_tarball + + +class ImageNet22k(ExtendedVisionDataset): + _GZIPPED_INDICES: Set[int] = { + 841_545, + 1_304_131, + 2_437_921, + 2_672_079, + 2_795_676, + 2_969_786, + 6_902_965, + 6_903_550, + 6_903_628, + 7_432_557, + 7_432_589, + 7_813_809, + 8_329_633, + 10_296_990, + 10_417_652, + 10_492_265, + 10_598_078, + 10_782_398, + 10_902_612, + 11_203_736, + 11_342_890, + 11_397_596, + 11_589_762, + 11_705_103, + 12_936_875, + 13_289_782, + } + Labels = _Labels + + def __init__( + self, + *, + root: str, + extra: str, + transforms: Optional[Callable] = None, + transform: Optional[Callable] = None, + target_transform: Optional[Callable] = None, + mmap_cache_size: int = _DEFAULT_MMAP_CACHE_SIZE, + ) -> None: + super().__init__(root, transforms, transform, target_transform) + self._extra_root = extra + + entries_path = self._get_entries_path(root) + self._entries = self._load_extra(entries_path) + + class_ids_path = self._get_class_ids_path(root) + self._class_ids = self._load_extra(class_ids_path) + + self._gzipped_indices = ImageNet22k._GZIPPED_INDICES + self._mmap_tarball = _make_mmap_tarball(self._tarballs_root, mmap_cache_size) + + def _get_entries_path(self, root: Optional[str] = None) -> str: + return "entries.npy" + + def _get_class_ids_path(self, root: Optional[str] = None) -> str: + return "class-ids.npy" + + def _find_class_ids(self, path: str) -> List[str]: + class_ids = [] + + with os.scandir(path) as entries: + for entry in entries: + root, ext = os.path.splitext(entry.name) + if ext != ".tar": + continue + class_ids.append(root) + + return sorted(class_ids) + + def _load_entries_class_ids(self, root: Optional[str] = None) -> Tuple[List[_Entry], List[str]]: + root = self.get_root(root) + entries: List[_Entry] = [] + class_ids = self._find_class_ids(root) + + for class_index, class_id in enumerate(class_ids): + path = os.path.join(root, "blocks", f"{class_id}.log") + class_entries = [] + + try: + with open(path) as f: + for line in f: + line = line.rstrip() + block, filename = line.split(":") + block_offset = int(block[6:]) + filename = filename[1:] + + maybe_filename = None + if filename != "** Block of NULs **": + maybe_filename = filename + _, ext = os.path.splitext(filename) + # assert ext == ".JPEG" + + class_entry = _ClassEntry(block_offset, maybe_filename) + class_entries.append(class_entry) + except OSError as e: + raise RuntimeError(f'can not read blocks file "{path}"') from e + + assert class_entries[-1].maybe_filename is None + + for class_entry1, class_entry2 in zip(class_entries, class_entries[1:]): + assert class_entry1.block_offset <= class_entry2.block_offset + start_offset = 512 * class_entry1.block_offset + end_offset = 512 * class_entry2.block_offset + assert class_entry1.maybe_filename is not None + filename = class_entry1.maybe_filename + entry = _Entry(class_index, start_offset, end_offset, filename) + # Skip invalid image files (PIL throws UnidentifiedImageError) + if filename == "n06470073_47249.JPEG": + continue + entries.append(entry) + + return entries, class_ids + + def _load_extra(self, extra_path: str) -> np.ndarray: + extra_root = self._extra_root + extra_full_path = os.path.join(extra_root, extra_path) + return np.load(extra_full_path, mmap_mode="r") + + def _save_extra(self, extra_array: np.ndarray, extra_path: str) -> None: + extra_root = self._extra_root + extra_full_path = os.path.join(extra_root, extra_path) + os.makedirs(extra_root, exist_ok=True) + np.save(extra_full_path, extra_array) + + @property + def _tarballs_root(self) -> str: + return self.root + + def find_class_id(self, class_index: int) -> str: + return str(self._class_ids[class_index]) + + def get_image_data(self, index: int) -> bytes: + entry = self._entries[index] + class_id = entry["class_id"] + class_mmap = self._mmap_tarball(class_id) + + start_offset, end_offset = entry["start_offset"], entry["end_offset"] + try: + mapped_data = class_mmap[start_offset:end_offset] + data = mapped_data[512:] # Skip entry header block + + if len(data) >= 2 and tuple(data[:2]) == (0x1F, 0x8B): + assert index in self._gzipped_indices, f"unexpected gzip header for sample {index}" + with GzipFile(fileobj=BytesIO(data)) as g: + data = g.read() + except Exception as e: + raise RuntimeError(f"can not retrieve image data for sample {index} " f'from "{class_id}" tarball') from e + + return data + + def get_target(self, index: int) -> Any: + return int(self._entries[index]["class_index"]) + + def get_targets(self) -> np.ndarray: + return self._entries["class_index"] + + def get_class_id(self, index: int) -> str: + return str(self._entries[index]["class_id"]) + + def get_class_ids(self) -> np.ndarray: + return self._entries["class_id"] + + def __getitem__(self, index: int) -> Tuple[Any, Any]: + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + return super().__getitem__(index) + + def __len__(self) -> int: + return len(self._entries) + + def _dump_entries(self, *args, **kwargs) -> None: + entries, class_ids = self._load_entries_class_ids(*args, **kwargs) + + max_class_id_length, max_filename_length, max_class_index = -1, -1, -1 + for entry in entries: + class_id = class_ids[entry.class_index] + max_class_index = max(entry.class_index, max_class_index) + max_class_id_length = max(len(class_id), max_class_id_length) + max_filename_length = max(len(entry.filename), max_filename_length) + + dtype = np.dtype( + [ + ("class_index", " None: + entries_path = self._get_entries_path(*args, **kwargs) + entries_array = self._load_extra(entries_path) + + max_class_id_length, max_class_index = -1, -1 + for entry in entries_array: + class_index, class_id = entry["class_index"], entry["class_id"] + max_class_index = max(int(class_index), max_class_index) + max_class_id_length = max(len(str(class_id)), max_class_id_length) + + class_ids_array = np.empty(max_class_index + 1, dtype=f"U{max_class_id_length}") + for entry in entries_array: + class_index, class_id = entry["class_index"], entry["class_id"] + class_ids_array[class_index] = class_id + class_ids_path = self._get_class_ids_path(*args, **kwargs) + self._save_extra(class_ids_array, class_ids_path) + + def _dump_extra(self, *args, **kwargs) -> None: + self._dump_entries(*args, *kwargs) + self._dump_class_ids(*args, *kwargs) + + def dump_extra(self, root: Optional[str] = None) -> None: + return self._dump_extra(root) diff --git a/dinov2/dinov2/data/loaders.py b/dinov2/dinov2/data/loaders.py new file mode 100644 index 0000000000000000000000000000000000000000..d6a2f0210efa0fa96be764665b5d6792191b1e72 --- /dev/null +++ b/dinov2/dinov2/data/loaders.py @@ -0,0 +1,222 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +import logging +from enum import Enum +from typing import Any, Callable, List, Optional, TypeVar + +import torch +from torch.utils.data import Sampler + +from .datasets import ImageNet, ImageNet22k +from .samplers import EpochSampler, InfiniteSampler, ShardedInfiniteSampler + + +logger = logging.getLogger("dinov2") + + +class SamplerType(Enum): + DISTRIBUTED = 0 + EPOCH = 1 + INFINITE = 2 + SHARDED_INFINITE = 3 + SHARDED_INFINITE_NEW = 4 + + +def _make_bool_str(b: bool) -> str: + return "yes" if b else "no" + + +def _make_sample_transform(image_transform: Optional[Callable] = None, target_transform: Optional[Callable] = None): + def transform(sample): + image, target = sample + if image_transform is not None: + image = image_transform(image) + if target_transform is not None: + target = target_transform(target) + return image, target + + return transform + + +def _parse_dataset_str(dataset_str: str): + tokens = dataset_str.split(":") + + name = tokens[0] + kwargs = {} + + for token in tokens[1:]: + key, value = token.split("=") + assert key in ("root", "extra", "split") + kwargs[key] = value + + if name == "ImageNet": + class_ = ImageNet + if "split" in kwargs: + kwargs["split"] = ImageNet.Split[kwargs["split"]] + elif name == "ImageNet22k": + class_ = ImageNet22k + else: + raise ValueError(f'Unsupported dataset "{name}"') + + return class_, kwargs + + +def make_dataset( + *, + dataset_str: str, + transform: Optional[Callable] = None, + target_transform: Optional[Callable] = None, +): + """ + Creates a dataset with the specified parameters. + + Args: + dataset_str: A dataset string description (e.g. ImageNet:split=TRAIN). + transform: A transform to apply to images. + target_transform: A transform to apply to targets. + + Returns: + The created dataset. + """ + logger.info(f'using dataset: "{dataset_str}"') + + class_, kwargs = _parse_dataset_str(dataset_str) + dataset = class_(transform=transform, target_transform=target_transform, **kwargs) + + logger.info(f"# of dataset samples: {len(dataset):,d}") + + # Aggregated datasets do not expose (yet) these attributes, so add them. + if not hasattr(dataset, "transform"): + setattr(dataset, "transform", transform) + if not hasattr(dataset, "target_transform"): + setattr(dataset, "target_transform", target_transform) + + return dataset + + +def _make_sampler( + *, + dataset, + type: Optional[SamplerType] = None, + shuffle: bool = False, + seed: int = 0, + size: int = -1, + advance: int = 0, +) -> Optional[Sampler]: + sample_count = len(dataset) + + if type == SamplerType.INFINITE: + logger.info("sampler: infinite") + if size > 0: + raise ValueError("sampler size > 0 is invalid") + return InfiniteSampler( + sample_count=sample_count, + shuffle=shuffle, + seed=seed, + advance=advance, + ) + elif type in (SamplerType.SHARDED_INFINITE, SamplerType.SHARDED_INFINITE_NEW): + logger.info("sampler: sharded infinite") + if size > 0: + raise ValueError("sampler size > 0 is invalid") + # TODO: Remove support for old shuffling + use_new_shuffle_tensor_slice = type == SamplerType.SHARDED_INFINITE_NEW + return ShardedInfiniteSampler( + sample_count=sample_count, + shuffle=shuffle, + seed=seed, + advance=advance, + use_new_shuffle_tensor_slice=use_new_shuffle_tensor_slice, + ) + elif type == SamplerType.EPOCH: + logger.info("sampler: epoch") + if advance > 0: + raise NotImplementedError("sampler advance > 0 is not supported") + size = size if size > 0 else sample_count + logger.info(f"# of samples / epoch: {size:,d}") + return EpochSampler( + size=size, + sample_count=sample_count, + shuffle=shuffle, + seed=seed, + ) + elif type == SamplerType.DISTRIBUTED: + logger.info("sampler: distributed") + if size > 0: + raise ValueError("sampler size > 0 is invalid") + if advance > 0: + raise ValueError("sampler advance > 0 is invalid") + return torch.utils.data.DistributedSampler( + dataset=dataset, + shuffle=shuffle, + seed=seed, + drop_last=False, + ) + + logger.info("sampler: none") + return None + + +T = TypeVar("T") + + +def make_data_loader( + *, + dataset, + batch_size: int, + num_workers: int, + shuffle: bool = True, + seed: int = 0, + sampler_type: Optional[SamplerType] = SamplerType.INFINITE, + sampler_size: int = -1, + sampler_advance: int = 0, + drop_last: bool = True, + persistent_workers: bool = False, + collate_fn: Optional[Callable[[List[T]], Any]] = None, +): + """ + Creates a data loader with the specified parameters. + + Args: + dataset: A dataset (third party, LaViDa or WebDataset). + batch_size: The size of batches to generate. + num_workers: The number of workers to use. + shuffle: Whether to shuffle samples. + seed: The random seed to use. + sampler_type: Which sampler to use: EPOCH, INFINITE, SHARDED_INFINITE, SHARDED_INFINITE_NEW, DISTRIBUTED or None. + sampler_size: The number of images per epoch (when applicable) or -1 for the entire dataset. + sampler_advance: How many samples to skip (when applicable). + drop_last: Whether the last non-full batch of data should be dropped. + persistent_workers: maintain the workers Dataset instances alive after a dataset has been consumed once. + collate_fn: Function that performs batch collation + """ + + sampler = _make_sampler( + dataset=dataset, + type=sampler_type, + shuffle=shuffle, + seed=seed, + size=sampler_size, + advance=sampler_advance, + ) + + logger.info("using PyTorch data loader") + data_loader = torch.utils.data.DataLoader( + dataset, + sampler=sampler, + batch_size=batch_size, + num_workers=num_workers, + pin_memory=True, + drop_last=drop_last, + persistent_workers=persistent_workers, + collate_fn=collate_fn, + ) + + try: + logger.info(f"# of batches: {len(data_loader):,d}") + except TypeError: # data loader has no length + logger.info("infinite data loader") + return data_loader diff --git a/dinov2/dinov2/data/masking.py b/dinov2/dinov2/data/masking.py new file mode 100644 index 0000000000000000000000000000000000000000..ab12aa7bf138b916b16a9a2ed1a628a2759dbec6 --- /dev/null +++ b/dinov2/dinov2/data/masking.py @@ -0,0 +1,86 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +import random +import math +import numpy as np + + +class MaskingGenerator: + def __init__( + self, + input_size, + num_masking_patches=None, + min_num_patches=4, + max_num_patches=None, + min_aspect=0.3, + max_aspect=None, + ): + if not isinstance(input_size, tuple): + input_size = (input_size,) * 2 + self.height, self.width = input_size + + self.num_patches = self.height * self.width + self.num_masking_patches = num_masking_patches + + self.min_num_patches = min_num_patches + self.max_num_patches = num_masking_patches if max_num_patches is None else max_num_patches + + max_aspect = max_aspect or 1 / min_aspect + self.log_aspect_ratio = (math.log(min_aspect), math.log(max_aspect)) + + def __repr__(self): + repr_str = "Generator(%d, %d -> [%d ~ %d], max = %d, %.3f ~ %.3f)" % ( + self.height, + self.width, + self.min_num_patches, + self.max_num_patches, + self.num_masking_patches, + self.log_aspect_ratio[0], + self.log_aspect_ratio[1], + ) + return repr_str + + def get_shape(self): + return self.height, self.width + + def _mask(self, mask, max_mask_patches): + delta = 0 + for _ in range(10): + target_area = random.uniform(self.min_num_patches, max_mask_patches) + aspect_ratio = math.exp(random.uniform(*self.log_aspect_ratio)) + h = int(round(math.sqrt(target_area * aspect_ratio))) + w = int(round(math.sqrt(target_area / aspect_ratio))) + if w < self.width and h < self.height: + top = random.randint(0, self.height - h) + left = random.randint(0, self.width - w) + + num_masked = mask[top : top + h, left : left + w].sum() + # Overlap + if 0 < h * w - num_masked <= max_mask_patches: + for i in range(top, top + h): + for j in range(left, left + w): + if mask[i, j] == 0: + mask[i, j] = 1 + delta += 1 + + if delta > 0: + break + return delta + + def __call__(self, num_masking_patches=0): + mask = np.zeros(shape=self.get_shape(), dtype=bool) + mask_count = 0 + while mask_count < num_masking_patches: + max_mask_patches = num_masking_patches - mask_count + max_mask_patches = min(max_mask_patches, self.max_num_patches) + + delta = self._mask(mask, max_mask_patches) + if delta == 0: + break + else: + mask_count += delta + + return mask diff --git a/dinov2/dinov2/data/samplers.py b/dinov2/dinov2/data/samplers.py new file mode 100644 index 0000000000000000000000000000000000000000..6562197d94652bb9a75a5fc722fcb2c65ca161be --- /dev/null +++ b/dinov2/dinov2/data/samplers.py @@ -0,0 +1,229 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +import itertools +from typing import Any, Optional +import warnings + +import numpy as np +import torch +from torch.utils.data.sampler import Sampler + +import dinov2.distributed as distributed + + +class EpochSampler(Sampler): + def __init__( + self, + *, + size: int, + sample_count: int, + shuffle: bool = False, + seed: int = 0, + start: Optional[int] = None, + step: Optional[int] = None, + ): + self._size = size + self._sample_count = sample_count + self._shuffle = shuffle + self._seed = seed + self._start = distributed.get_global_rank() if start is None else start + self._step = distributed.get_global_size() if step is None else step + self._epoch = 0 + + def __iter__(self): + count = (self._size + self._sample_count - 1) // self._sample_count + tiled_indices = np.tile(np.arange(self._sample_count), count) + if self._shuffle: + seed = self._seed * self._epoch if self._seed != 0 else self._epoch + rng = np.random.default_rng(seed) + iterable = rng.choice(tiled_indices, self._size, replace=False) + else: + iterable = tiled_indices[: self._size] + + yield from itertools.islice(iterable, self._start, None, self._step) + + def __len__(self): + return (self._size - self._start + self._step - 1) // self._step + + def set_epoch(self, epoch): + self._epoch = epoch + + +def _get_numpy_dtype(size: int) -> Any: + return np.int32 if size <= 2**31 else np.int64 + + +def _get_torch_dtype(size: int) -> Any: + return torch.int32 if size <= 2**31 else torch.int64 + + +def _generate_randperm_indices(*, size: int, generator: torch.Generator): + """Generate the indices of a random permutation.""" + dtype = _get_torch_dtype(size) + # This is actually matching PyTorch's CPU implementation, see: https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/TensorFactories.cpp#L900-L921 + perm = torch.arange(size, dtype=dtype) + for i in range(size): + j = torch.randint(i, size, size=(1,), generator=generator).item() + + # Always swap even if no-op + value = perm[j].item() + perm[j] = perm[i].item() + perm[i] = value + yield value + + +class InfiniteSampler(Sampler): + def __init__( + self, + *, + sample_count: int, + shuffle: bool = False, + seed: int = 0, + start: Optional[int] = None, + step: Optional[int] = None, + advance: int = 0, + ): + self._sample_count = sample_count + self._seed = seed + self._shuffle = shuffle + self._start = distributed.get_global_rank() if start is None else start + self._step = distributed.get_global_size() if step is None else step + self._advance = advance + + def __iter__(self): + if self._shuffle: + iterator = self._shuffled_iterator() + else: + iterator = self._iterator() + + yield from itertools.islice(iterator, self._advance, None) + + def _iterator(self): + assert not self._shuffle + + while True: + iterable = range(self._sample_count) + yield from itertools.islice(iterable, self._start, None, self._step) + + def _shuffled_iterator(self): + assert self._shuffle + + # Instantiate a generator here (rather than in the ctor) to keep the class + # picklable (requirement of mp.spawn) + generator = torch.Generator().manual_seed(self._seed) + + while True: + iterable = _generate_randperm_indices(size=self._sample_count, generator=generator) + yield from itertools.islice(iterable, self._start, None, self._step) + + +# The following function is somewhat equivalent to _new_shuffle_tensor_slice below, +# but avoids a full in-place random permutation generation. +def _shuffle_tensor_slice( + *, tensor: torch.Tensor, start: int = 0, step: int = 1, generator: torch.Generator +) -> np.ndarray: + stop = len(tensor) + count = stop // step + drop_count = stop - step * count + if drop_count: + warnings.warn(f"# of dropped samples: {drop_count}") + + dtype = _get_numpy_dtype(stop) + result = np.empty(count, dtype=dtype) + + for i in range(count): + j = torch.randint(0, i + 1, size=(1,), generator=generator).item() if i > 0 else 0 + + result[i] = result[j] + result[j] = tensor[start + i * step].item() + + return result + + +def _new_shuffle_tensor_slice( + *, tensor: torch.Tensor, start: int = 0, step: int = 1, generator: torch.Generator +) -> np.ndarray: + stop = len(tensor) + count = stop // step + dtype = torch.int64 # Needed for using randperm result as indices + count = stop // step + drop_count = stop - step * count + if drop_count: + warnings.warn(f"# of dropped samples: {drop_count}") + indices = torch.randperm(count, dtype=dtype, generator=generator) + return tensor[start::step][indices].numpy() + + +def _make_seed(seed: int, start: int, iter_count: int) -> int: + # NOTE: Tried a few variants (including iter_count << 32), this one worked best. + return seed + start + (iter_count << 24) + + +class ShardedInfiniteSampler(Sampler): + def __init__( + self, + *, + sample_count: int, + shuffle: bool = False, + seed: int = 0, + start: Optional[int] = None, + step: Optional[int] = None, + advance: int = 0, + use_new_shuffle_tensor_slice: bool = False, + ): + self._sample_count = sample_count + self._seed = seed + self._shuffle = shuffle + self._start = distributed.get_global_rank() if start is None else start + self._step = distributed.get_global_size() if step is None else step + self._advance = advance + self._iter_count = 0 + self._shuffle_tensor_slice_fn = ( + _new_shuffle_tensor_slice if use_new_shuffle_tensor_slice else _shuffle_tensor_slice + ) + + def __iter__(self): + iter_count = self._advance // self._sample_count + if iter_count > 0: + self._advance -= iter_count * self._sample_count + self._iter_count += iter_count + + if self._shuffle: + iterator = self._shuffled_iterator() + else: + iterator = self._iterator() + + yield from itertools.islice(iterator, self._advance, None) + + def _iterator(self): + assert not self._shuffle + + while True: + iterable = range(self._sample_count) + yield from itertools.islice(iterable, self._start, None, self._step) + + def _shuffled_iterator(self): + assert self._shuffle + + # Instantiate a generator here (rather than in the ctor) to be keep the class + # picklable (requirement of mp.spawn) + generator = torch.Generator() + + # Always shuffle everything first + generator.manual_seed(self._seed) + dtype = _get_torch_dtype(self._sample_count) + perm = torch.randperm(self._sample_count, dtype=dtype, generator=generator) + + while True: + # Re-seed on each iteration to allow skipping whole permutations + seed = _make_seed(self._seed, self._start, self._iter_count) + generator.manual_seed(seed) + + iterable = self._shuffle_tensor_slice_fn( + tensor=perm, start=self._start, step=self._step, generator=generator + ) + yield from iterable + self._iter_count += 1 diff --git a/dinov2/dinov2/data/transforms.py b/dinov2/dinov2/data/transforms.py new file mode 100644 index 0000000000000000000000000000000000000000..eb5f252b50c54d58f160528c9f2b00fad47103c7 --- /dev/null +++ b/dinov2/dinov2/data/transforms.py @@ -0,0 +1,91 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +from typing import Sequence + +import torch +from torchvision import transforms + + +class GaussianBlur(transforms.RandomApply): + """ + Apply Gaussian Blur to the PIL image. + """ + + def __init__(self, *, p: float = 0.5, radius_min: float = 0.1, radius_max: float = 2.0): + # NOTE: torchvision is applying 1 - probability to return the original image + keep_p = 1 - p + transform = transforms.GaussianBlur(kernel_size=9, sigma=(radius_min, radius_max)) + super().__init__(transforms=[transform], p=keep_p) + + +class MaybeToTensor(transforms.ToTensor): + """ + Convert a ``PIL Image`` or ``numpy.ndarray`` to tensor, or keep as is if already a tensor. + """ + + def __call__(self, pic): + """ + Args: + pic (PIL Image, numpy.ndarray or torch.tensor): Image to be converted to tensor. + Returns: + Tensor: Converted image. + """ + if isinstance(pic, torch.Tensor): + return pic + return super().__call__(pic) + + +# Use timm's names +IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406) +IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225) + + +def make_normalize_transform( + mean: Sequence[float] = IMAGENET_DEFAULT_MEAN, + std: Sequence[float] = IMAGENET_DEFAULT_STD, +) -> transforms.Normalize: + return transforms.Normalize(mean=mean, std=std) + + +# This roughly matches torchvision's preset for classification training: +# https://github.com/pytorch/vision/blob/main/references/classification/presets.py#L6-L44 +def make_classification_train_transform( + *, + crop_size: int = 224, + interpolation=transforms.InterpolationMode.BICUBIC, + hflip_prob: float = 0.5, + mean: Sequence[float] = IMAGENET_DEFAULT_MEAN, + std: Sequence[float] = IMAGENET_DEFAULT_STD, +): + transforms_list = [transforms.RandomResizedCrop(crop_size, interpolation=interpolation)] + if hflip_prob > 0.0: + transforms_list.append(transforms.RandomHorizontalFlip(hflip_prob)) + transforms_list.extend( + [ + MaybeToTensor(), + make_normalize_transform(mean=mean, std=std), + ] + ) + return transforms.Compose(transforms_list) + + +# This matches (roughly) torchvision's preset for classification evaluation: +# https://github.com/pytorch/vision/blob/main/references/classification/presets.py#L47-L69 +def make_classification_eval_transform( + *, + resize_size: int = 256, + interpolation=transforms.InterpolationMode.BICUBIC, + crop_size: int = 224, + mean: Sequence[float] = IMAGENET_DEFAULT_MEAN, + std: Sequence[float] = IMAGENET_DEFAULT_STD, +) -> transforms.Compose: + transforms_list = [ + transforms.Resize(resize_size, interpolation=interpolation), + transforms.CenterCrop(crop_size), + MaybeToTensor(), + make_normalize_transform(mean=mean, std=std), + ] + return transforms.Compose(transforms_list) diff --git a/dinov2/dinov2/distributed/__init__.py b/dinov2/dinov2/distributed/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..23226f4536bf5acf4ffac242e9903d92863b246d --- /dev/null +++ b/dinov2/dinov2/distributed/__init__.py @@ -0,0 +1,270 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +import os +import random +import re +import socket +from typing import Dict, List + +import torch +import torch.distributed as dist + +_LOCAL_RANK = -1 +_LOCAL_WORLD_SIZE = -1 + + +def is_enabled() -> bool: + """ + Returns: + True if distributed training is enabled + """ + return dist.is_available() and dist.is_initialized() + + +def get_global_size() -> int: + """ + Returns: + The number of processes in the process group + """ + return dist.get_world_size() if is_enabled() else 1 + + +def get_global_rank() -> int: + """ + Returns: + The rank of the current process within the global process group. + """ + return dist.get_rank() if is_enabled() else 0 + + +def get_local_rank() -> int: + """ + Returns: + The rank of the current process within the local (per-machine) process group. + """ + if not is_enabled(): + return 0 + assert 0 <= _LOCAL_RANK < _LOCAL_WORLD_SIZE + return _LOCAL_RANK + + +def get_local_size() -> int: + """ + Returns: + The size of the per-machine process group, + i.e. the number of processes per machine. + """ + if not is_enabled(): + return 1 + assert 0 <= _LOCAL_RANK < _LOCAL_WORLD_SIZE + return _LOCAL_WORLD_SIZE + + +def is_main_process() -> bool: + """ + Returns: + True if the current process is the main one. + """ + return get_global_rank() == 0 + + +def _restrict_print_to_main_process() -> None: + """ + This function disables printing when not in the main process + """ + import builtins as __builtin__ + + builtin_print = __builtin__.print + + def print(*args, **kwargs): + force = kwargs.pop("force", False) + if is_main_process() or force: + builtin_print(*args, **kwargs) + + __builtin__.print = print + + +def _get_master_port(seed: int = 0) -> int: + MIN_MASTER_PORT, MAX_MASTER_PORT = (20_000, 60_000) + + master_port_str = os.environ.get("MASTER_PORT") + if master_port_str is None: + rng = random.Random(seed) + return rng.randint(MIN_MASTER_PORT, MAX_MASTER_PORT) + + return int(master_port_str) + + +def _get_available_port() -> int: + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + # A "" host address means INADDR_ANY i.e. binding to all interfaces. + # Note this is not compatible with IPv6. + s.bind(("", 0)) + port = s.getsockname()[1] + return port + + +_TORCH_DISTRIBUTED_ENV_VARS = ( + "MASTER_ADDR", + "MASTER_PORT", + "RANK", + "WORLD_SIZE", + "LOCAL_RANK", + "LOCAL_WORLD_SIZE", +) + + +def _collect_env_vars() -> Dict[str, str]: + return {env_var: os.environ[env_var] for env_var in _TORCH_DISTRIBUTED_ENV_VARS if env_var in os.environ} + + +def _is_slurm_job_process() -> bool: + return "SLURM_JOB_ID" in os.environ + + +def _parse_slurm_node_list(s: str) -> List[str]: + nodes = [] + # Extract "hostname", "hostname[1-2,3,4-5]," substrings + p = re.compile(r"(([^\[]+)(?:\[([^\]]+)\])?),?") + for m in p.finditer(s): + prefix, suffixes = s[m.start(2) : m.end(2)], s[m.start(3) : m.end(3)] + for suffix in suffixes.split(","): + span = suffix.split("-") + if len(span) == 1: + nodes.append(prefix + suffix) + else: + width = len(span[0]) + start, end = int(span[0]), int(span[1]) + 1 + nodes.extend([prefix + f"{i:0{width}}" for i in range(start, end)]) + return nodes + + +def _check_env_variable(key: str, new_value: str): + # Only check for difference with preset environment variables + if key in os.environ and os.environ[key] != new_value: + raise RuntimeError(f"Cannot export environment variables as {key} is already set") + + +class _TorchDistributedEnvironment: + def __init__(self): + self.master_addr = "127.0.0.1" + self.master_port = 0 + self.rank = -1 + self.world_size = -1 + self.local_rank = -1 + self.local_world_size = -1 + + if _is_slurm_job_process(): + return self._set_from_slurm_env() + + env_vars = _collect_env_vars() + if not env_vars: + # Environment is not set + pass + elif len(env_vars) == len(_TORCH_DISTRIBUTED_ENV_VARS): + # Environment is fully set + return self._set_from_preset_env() + else: + # Environment is partially set + collected_env_vars = ", ".join(env_vars.keys()) + raise RuntimeError(f"Partially set environment: {collected_env_vars}") + + if torch.cuda.device_count() > 0: + return self._set_from_local() + + raise RuntimeError("Can't initialize PyTorch distributed environment") + + # Slurm job created with sbatch, submitit, etc... + def _set_from_slurm_env(self): + # logger.info("Initialization from Slurm environment") + job_id = int(os.environ["SLURM_JOB_ID"]) + node_count = int(os.environ["SLURM_JOB_NUM_NODES"]) + nodes = _parse_slurm_node_list(os.environ["SLURM_JOB_NODELIST"]) + assert len(nodes) == node_count + + self.master_addr = nodes[0] + self.master_port = _get_master_port(seed=job_id) + self.rank = int(os.environ["SLURM_PROCID"]) + self.world_size = int(os.environ["SLURM_NTASKS"]) + assert self.rank < self.world_size + self.local_rank = int(os.environ["SLURM_LOCALID"]) + self.local_world_size = self.world_size // node_count + assert self.local_rank < self.local_world_size + + # Single node job with preset environment (i.e. torchrun) + def _set_from_preset_env(self): + # logger.info("Initialization from preset environment") + self.master_addr = os.environ["MASTER_ADDR"] + self.master_port = os.environ["MASTER_PORT"] + self.rank = int(os.environ["RANK"]) + self.world_size = int(os.environ["WORLD_SIZE"]) + assert self.rank < self.world_size + self.local_rank = int(os.environ["LOCAL_RANK"]) + self.local_world_size = int(os.environ["LOCAL_WORLD_SIZE"]) + assert self.local_rank < self.local_world_size + + # Single node and GPU job (i.e. local script run) + def _set_from_local(self): + # logger.info("Initialization from local") + self.master_addr = "127.0.0.1" + self.master_port = _get_available_port() + self.rank = 0 + self.world_size = 1 + self.local_rank = 0 + self.local_world_size = 1 + + def export(self, *, overwrite: bool) -> "_TorchDistributedEnvironment": + # See the "Environment variable initialization" section from + # https://pytorch.org/docs/stable/distributed.html for the complete list of + # environment variables required for the env:// initialization method. + env_vars = { + "MASTER_ADDR": self.master_addr, + "MASTER_PORT": str(self.master_port), + "RANK": str(self.rank), + "WORLD_SIZE": str(self.world_size), + "LOCAL_RANK": str(self.local_rank), + "LOCAL_WORLD_SIZE": str(self.local_world_size), + } + if not overwrite: + for k, v in env_vars.items(): + _check_env_variable(k, v) + + os.environ.update(env_vars) + return self + + +def enable(*, set_cuda_current_device: bool = True, overwrite: bool = False, allow_nccl_timeout: bool = False): + """Enable distributed mode + + Args: + set_cuda_current_device: If True, call torch.cuda.set_device() to set the + current PyTorch CUDA device to the one matching the local rank. + overwrite: If True, overwrites already set variables. Else fails. + """ + + global _LOCAL_RANK, _LOCAL_WORLD_SIZE + if _LOCAL_RANK >= 0 or _LOCAL_WORLD_SIZE >= 0: + raise RuntimeError("Distributed mode has already been enabled") + torch_env = _TorchDistributedEnvironment() + torch_env.export(overwrite=overwrite) + + if set_cuda_current_device: + torch.cuda.set_device(torch_env.local_rank) + + if allow_nccl_timeout: + # This allows to use torch distributed timeout in a NCCL backend + key, value = "NCCL_ASYNC_ERROR_HANDLING", "1" + if not overwrite: + _check_env_variable(key, value) + os.environ[key] = value + + dist.init_process_group(backend="nccl") + dist.barrier() + + # Finalize setup + _LOCAL_RANK = torch_env.local_rank + _LOCAL_WORLD_SIZE = torch_env.local_world_size + _restrict_print_to_main_process() diff --git a/dinov2/dinov2/eval/__init__.py b/dinov2/dinov2/eval/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..b88da6bf80be92af00b72dfdb0a806fa64a7a2d9 --- /dev/null +++ b/dinov2/dinov2/eval/__init__.py @@ -0,0 +1,4 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. diff --git a/dinov2/dinov2/eval/depth/__init__.py b/dinov2/dinov2/eval/depth/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..b88da6bf80be92af00b72dfdb0a806fa64a7a2d9 --- /dev/null +++ b/dinov2/dinov2/eval/depth/__init__.py @@ -0,0 +1,4 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. diff --git a/dinov2/dinov2/eval/depth/models/__init__.py b/dinov2/dinov2/eval/depth/models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..9a5825181dc2189424b5c58d245b36919cbc5b2e --- /dev/null +++ b/dinov2/dinov2/eval/depth/models/__init__.py @@ -0,0 +1,10 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +from .backbones import * # noqa: F403 +from .builder import BACKBONES, DEPTHER, HEADS, LOSSES, build_backbone, build_depther, build_head, build_loss +from .decode_heads import * # noqa: F403 +from .depther import * # noqa: F403 +from .losses import * # noqa: F403 diff --git a/dinov2/dinov2/eval/depth/models/backbones/__init__.py b/dinov2/dinov2/eval/depth/models/backbones/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..520d75bc6e064b9d64487293604ac1bda6e2b6f7 --- /dev/null +++ b/dinov2/dinov2/eval/depth/models/backbones/__init__.py @@ -0,0 +1,6 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +from .vision_transformer import DinoVisionTransformer diff --git a/dinov2/dinov2/eval/depth/models/backbones/vision_transformer.py b/dinov2/dinov2/eval/depth/models/backbones/vision_transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..69bda46fd69eb7dabb8f5b60e6fa459fdc21aeab --- /dev/null +++ b/dinov2/dinov2/eval/depth/models/backbones/vision_transformer.py @@ -0,0 +1,16 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +from mmcv.runner import BaseModule + +from ..builder import BACKBONES + + +@BACKBONES.register_module() +class DinoVisionTransformer(BaseModule): + """Vision Transformer.""" + + def __init__(self, *args, **kwargs): + super().__init__() diff --git a/dinov2/dinov2/eval/depth/models/builder.py b/dinov2/dinov2/eval/depth/models/builder.py new file mode 100644 index 0000000000000000000000000000000000000000..c152643435308afcff60b07cd68ea979fe1d90cb --- /dev/null +++ b/dinov2/dinov2/eval/depth/models/builder.py @@ -0,0 +1,49 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +import warnings + +from mmcv.cnn import MODELS as MMCV_MODELS +from mmcv.cnn.bricks.registry import ATTENTION as MMCV_ATTENTION +from mmcv.utils import Registry + +MODELS = Registry("models", parent=MMCV_MODELS) +ATTENTION = Registry("attention", parent=MMCV_ATTENTION) + + +BACKBONES = MODELS +NECKS = MODELS +HEADS = MODELS +LOSSES = MODELS +DEPTHER = MODELS + + +def build_backbone(cfg): + """Build backbone.""" + return BACKBONES.build(cfg) + + +def build_neck(cfg): + """Build neck.""" + return NECKS.build(cfg) + + +def build_head(cfg): + """Build head.""" + return HEADS.build(cfg) + + +def build_loss(cfg): + """Build loss.""" + return LOSSES.build(cfg) + + +def build_depther(cfg, train_cfg=None, test_cfg=None): + """Build depther.""" + if train_cfg is not None or test_cfg is not None: + warnings.warn("train_cfg and test_cfg is deprecated, " "please specify them in model", UserWarning) + assert cfg.get("train_cfg") is None or train_cfg is None, "train_cfg specified in both outer field and model field " + assert cfg.get("test_cfg") is None or test_cfg is None, "test_cfg specified in both outer field and model field " + return DEPTHER.build(cfg, default_args=dict(train_cfg=train_cfg, test_cfg=test_cfg)) diff --git a/dinov2/dinov2/eval/depth/models/decode_heads/__init__.py b/dinov2/dinov2/eval/depth/models/decode_heads/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..bd0f0754a5b01d7622c1f26bf3f60daea19da4e8 --- /dev/null +++ b/dinov2/dinov2/eval/depth/models/decode_heads/__init__.py @@ -0,0 +1,7 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +from .dpt_head import DPTHead +from .linear_head import BNHead diff --git a/dinov2/dinov2/eval/depth/models/decode_heads/decode_head.py b/dinov2/dinov2/eval/depth/models/decode_heads/decode_head.py new file mode 100644 index 0000000000000000000000000000000000000000..f8c867a3ec687090b280d90bb86aee435320acda --- /dev/null +++ b/dinov2/dinov2/eval/depth/models/decode_heads/decode_head.py @@ -0,0 +1,225 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +import copy +from abc import ABCMeta, abstractmethod + +import mmcv +import numpy as np +import torch +import torch.nn as nn +from mmcv.runner import BaseModule, auto_fp16, force_fp32 + +from ...ops import resize +from ..builder import build_loss + + +class DepthBaseDecodeHead(BaseModule, metaclass=ABCMeta): + """Base class for BaseDecodeHead. + + Args: + in_channels (List): Input channels. + channels (int): Channels after modules, before conv_depth. + conv_cfg (dict|None): Config of conv layers. Default: None. + act_cfg (dict): Config of activation layers. + Default: dict(type='ReLU') + loss_decode (dict): Config of decode loss. + Default: dict(type='SigLoss'). + sampler (dict|None): The config of depth map sampler. + Default: None. + align_corners (bool): align_corners argument of F.interpolate. + Default: False. + min_depth (int): Min depth in dataset setting. + Default: 1e-3. + max_depth (int): Max depth in dataset setting. + Default: None. + norm_cfg (dict|None): Config of norm layers. + Default: None. + classify (bool): Whether predict depth in a cls.-reg. manner. + Default: False. + n_bins (int): The number of bins used in cls. step. + Default: 256. + bins_strategy (str): The discrete strategy used in cls. step. + Default: 'UD'. + norm_strategy (str): The norm strategy on cls. probability + distribution. Default: 'linear' + scale_up (str): Whether predict depth in a scale-up manner. + Default: False. + """ + + def __init__( + self, + in_channels, + channels=96, + conv_cfg=None, + act_cfg=dict(type="ReLU"), + loss_decode=dict(type="SigLoss", valid_mask=True, loss_weight=10), + sampler=None, + align_corners=False, + min_depth=1e-3, + max_depth=None, + norm_cfg=None, + classify=False, + n_bins=256, + bins_strategy="UD", + norm_strategy="linear", + scale_up=False, + ): + super(DepthBaseDecodeHead, self).__init__() + + self.in_channels = in_channels + self.channels = channels + self.conv_cfg = conv_cfg + self.act_cfg = act_cfg + if isinstance(loss_decode, dict): + self.loss_decode = build_loss(loss_decode) + elif isinstance(loss_decode, (list, tuple)): + self.loss_decode = nn.ModuleList() + for loss in loss_decode: + self.loss_decode.append(build_loss(loss)) + self.align_corners = align_corners + self.min_depth = min_depth + self.max_depth = max_depth + self.norm_cfg = norm_cfg + self.classify = classify + self.n_bins = n_bins + self.scale_up = scale_up + + if self.classify: + assert bins_strategy in ["UD", "SID"], "Support bins_strategy: UD, SID" + assert norm_strategy in ["linear", "softmax", "sigmoid"], "Support norm_strategy: linear, softmax, sigmoid" + + self.bins_strategy = bins_strategy + self.norm_strategy = norm_strategy + self.softmax = nn.Softmax(dim=1) + self.conv_depth = nn.Conv2d(channels, n_bins, kernel_size=3, padding=1, stride=1) + else: + self.conv_depth = nn.Conv2d(channels, 1, kernel_size=3, padding=1, stride=1) + + self.fp16_enabled = False + self.relu = nn.ReLU() + self.sigmoid = nn.Sigmoid() + + def extra_repr(self): + """Extra repr.""" + s = f"align_corners={self.align_corners}" + return s + + @auto_fp16() + @abstractmethod + def forward(self, inputs, img_metas): + """Placeholder of forward function.""" + pass + + def forward_train(self, img, inputs, img_metas, depth_gt, train_cfg): + """Forward function for training. + Args: + inputs (list[Tensor]): List of multi-level img features. + img_metas (list[dict]): List of image info dict where each dict + has: 'img_shape', 'scale_factor', 'flip', and may also contain + 'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'. + For details on the values of these keys see + `depth/datasets/pipelines/formatting.py:Collect`. + depth_gt (Tensor): GT depth + train_cfg (dict): The training config. + + Returns: + dict[str, Tensor]: a dictionary of loss components + """ + depth_pred = self.forward(inputs, img_metas) + losses = self.losses(depth_pred, depth_gt) + + log_imgs = self.log_images(img[0], depth_pred[0], depth_gt[0], img_metas[0]) + losses.update(**log_imgs) + + return losses + + def forward_test(self, inputs, img_metas, test_cfg): + """Forward function for testing. + Args: + inputs (list[Tensor]): List of multi-level img features. + img_metas (list[dict]): List of image info dict where each dict + has: 'img_shape', 'scale_factor', 'flip', and may also contain + 'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'. + For details on the values of these keys see + `depth/datasets/pipelines/formatting.py:Collect`. + test_cfg (dict): The testing config. + + Returns: + Tensor: Output depth map. + """ + return self.forward(inputs, img_metas) + + def depth_pred(self, feat): + """Prediction each pixel.""" + if self.classify: + logit = self.conv_depth(feat) + + if self.bins_strategy == "UD": + bins = torch.linspace(self.min_depth, self.max_depth, self.n_bins, device=feat.device) + elif self.bins_strategy == "SID": + bins = torch.logspace(self.min_depth, self.max_depth, self.n_bins, device=feat.device) + + # following Adabins, default linear + if self.norm_strategy == "linear": + logit = torch.relu(logit) + eps = 0.1 + logit = logit + eps + logit = logit / logit.sum(dim=1, keepdim=True) + elif self.norm_strategy == "softmax": + logit = torch.softmax(logit, dim=1) + elif self.norm_strategy == "sigmoid": + logit = torch.sigmoid(logit) + logit = logit / logit.sum(dim=1, keepdim=True) + + output = torch.einsum("ikmn,k->imn", [logit, bins]).unsqueeze(dim=1) + + else: + if self.scale_up: + output = self.sigmoid(self.conv_depth(feat)) * self.max_depth + else: + output = self.relu(self.conv_depth(feat)) + self.min_depth + return output + + @force_fp32(apply_to=("depth_pred",)) + def losses(self, depth_pred, depth_gt): + """Compute depth loss.""" + loss = dict() + depth_pred = resize( + input=depth_pred, size=depth_gt.shape[2:], mode="bilinear", align_corners=self.align_corners, warning=False + ) + if not isinstance(self.loss_decode, nn.ModuleList): + losses_decode = [self.loss_decode] + else: + losses_decode = self.loss_decode + for loss_decode in losses_decode: + if loss_decode.loss_name not in loss: + loss[loss_decode.loss_name] = loss_decode(depth_pred, depth_gt) + else: + loss[loss_decode.loss_name] += loss_decode(depth_pred, depth_gt) + return loss + + def log_images(self, img_path, depth_pred, depth_gt, img_meta): + show_img = copy.deepcopy(img_path.detach().cpu().permute(1, 2, 0)) + show_img = show_img.numpy().astype(np.float32) + show_img = mmcv.imdenormalize( + show_img, + img_meta["img_norm_cfg"]["mean"], + img_meta["img_norm_cfg"]["std"], + img_meta["img_norm_cfg"]["to_rgb"], + ) + show_img = np.clip(show_img, 0, 255) + show_img = show_img.astype(np.uint8) + show_img = show_img[:, :, ::-1] + show_img = show_img.transpose(0, 2, 1) + show_img = show_img.transpose(1, 0, 2) + + depth_pred = depth_pred / torch.max(depth_pred) + depth_gt = depth_gt / torch.max(depth_gt) + + depth_pred_color = copy.deepcopy(depth_pred.detach().cpu()) + depth_gt_color = copy.deepcopy(depth_gt.detach().cpu()) + + return {"img_rgb": show_img, "img_depth_pred": depth_pred_color, "img_depth_gt": depth_gt_color} diff --git a/dinov2/dinov2/eval/depth/models/decode_heads/dpt_head.py b/dinov2/dinov2/eval/depth/models/decode_heads/dpt_head.py new file mode 100644 index 0000000000000000000000000000000000000000..c6c6d9470d78e1d944cc505f97865f026a9458d3 --- /dev/null +++ b/dinov2/dinov2/eval/depth/models/decode_heads/dpt_head.py @@ -0,0 +1,270 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +import math + +import torch +import torch.nn as nn +from mmcv.cnn import ConvModule, Linear, build_activation_layer +from mmcv.runner import BaseModule + +from ...ops import resize +from ..builder import HEADS +from .decode_head import DepthBaseDecodeHead + + +class Interpolate(nn.Module): + def __init__(self, scale_factor, mode, align_corners=False): + super(Interpolate, self).__init__() + self.interp = nn.functional.interpolate + self.scale_factor = scale_factor + self.mode = mode + self.align_corners = align_corners + + def forward(self, x): + x = self.interp(x, scale_factor=self.scale_factor, mode=self.mode, align_corners=self.align_corners) + return x + + +class HeadDepth(nn.Module): + def __init__(self, features): + super(HeadDepth, self).__init__() + self.head = nn.Sequential( + nn.Conv2d(features, features // 2, kernel_size=3, stride=1, padding=1), + Interpolate(scale_factor=2, mode="bilinear", align_corners=True), + nn.Conv2d(features // 2, 32, kernel_size=3, stride=1, padding=1), + nn.ReLU(), + nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0), + ) + + def forward(self, x): + x = self.head(x) + return x + + +class ReassembleBlocks(BaseModule): + """ViTPostProcessBlock, process cls_token in ViT backbone output and + rearrange the feature vector to feature map. + Args: + in_channels (int): ViT feature channels. Default: 768. + out_channels (List): output channels of each stage. + Default: [96, 192, 384, 768]. + readout_type (str): Type of readout operation. Default: 'ignore'. + patch_size (int): The patch size. Default: 16. + init_cfg (dict, optional): Initialization config dict. Default: None. + """ + + def __init__( + self, in_channels=768, out_channels=[96, 192, 384, 768], readout_type="ignore", patch_size=16, init_cfg=None + ): + super(ReassembleBlocks, self).__init__(init_cfg) + + assert readout_type in ["ignore", "add", "project"] + self.readout_type = readout_type + self.patch_size = patch_size + + self.projects = nn.ModuleList( + [ + ConvModule( + in_channels=in_channels, + out_channels=out_channel, + kernel_size=1, + act_cfg=None, + ) + for out_channel in out_channels + ] + ) + + self.resize_layers = nn.ModuleList( + [ + nn.ConvTranspose2d( + in_channels=out_channels[0], out_channels=out_channels[0], kernel_size=4, stride=4, padding=0 + ), + nn.ConvTranspose2d( + in_channels=out_channels[1], out_channels=out_channels[1], kernel_size=2, stride=2, padding=0 + ), + nn.Identity(), + nn.Conv2d( + in_channels=out_channels[3], out_channels=out_channels[3], kernel_size=3, stride=2, padding=1 + ), + ] + ) + if self.readout_type == "project": + self.readout_projects = nn.ModuleList() + for _ in range(len(self.projects)): + self.readout_projects.append( + nn.Sequential(Linear(2 * in_channels, in_channels), build_activation_layer(dict(type="GELU"))) + ) + + def forward(self, inputs): + assert isinstance(inputs, list) + out = [] + for i, x in enumerate(inputs): + assert len(x) == 2 + x, cls_token = x[0], x[1] + feature_shape = x.shape + if self.readout_type == "project": + x = x.flatten(2).permute((0, 2, 1)) + readout = cls_token.unsqueeze(1).expand_as(x) + x = self.readout_projects[i](torch.cat((x, readout), -1)) + x = x.permute(0, 2, 1).reshape(feature_shape) + elif self.readout_type == "add": + x = x.flatten(2) + cls_token.unsqueeze(-1) + x = x.reshape(feature_shape) + else: + pass + x = self.projects[i](x) + x = self.resize_layers[i](x) + out.append(x) + return out + + +class PreActResidualConvUnit(BaseModule): + """ResidualConvUnit, pre-activate residual unit. + Args: + in_channels (int): number of channels in the input feature map. + act_cfg (dict): dictionary to construct and config activation layer. + norm_cfg (dict): dictionary to construct and config norm layer. + stride (int): stride of the first block. Default: 1 + dilation (int): dilation rate for convs layers. Default: 1. + init_cfg (dict, optional): Initialization config dict. Default: None. + """ + + def __init__(self, in_channels, act_cfg, norm_cfg, stride=1, dilation=1, init_cfg=None): + super(PreActResidualConvUnit, self).__init__(init_cfg) + + self.conv1 = ConvModule( + in_channels, + in_channels, + 3, + stride=stride, + padding=dilation, + dilation=dilation, + norm_cfg=norm_cfg, + act_cfg=act_cfg, + bias=False, + order=("act", "conv", "norm"), + ) + + self.conv2 = ConvModule( + in_channels, + in_channels, + 3, + padding=1, + norm_cfg=norm_cfg, + act_cfg=act_cfg, + bias=False, + order=("act", "conv", "norm"), + ) + + def forward(self, inputs): + inputs_ = inputs.clone() + x = self.conv1(inputs) + x = self.conv2(x) + return x + inputs_ + + +class FeatureFusionBlock(BaseModule): + """FeatureFusionBlock, merge feature map from different stages. + Args: + in_channels (int): Input channels. + act_cfg (dict): The activation config for ResidualConvUnit. + norm_cfg (dict): Config dict for normalization layer. + expand (bool): Whether expand the channels in post process block. + Default: False. + align_corners (bool): align_corner setting for bilinear upsample. + Default: True. + init_cfg (dict, optional): Initialization config dict. Default: None. + """ + + def __init__(self, in_channels, act_cfg, norm_cfg, expand=False, align_corners=True, init_cfg=None): + super(FeatureFusionBlock, self).__init__(init_cfg) + + self.in_channels = in_channels + self.expand = expand + self.align_corners = align_corners + + self.out_channels = in_channels + if self.expand: + self.out_channels = in_channels // 2 + + self.project = ConvModule(self.in_channels, self.out_channels, kernel_size=1, act_cfg=None, bias=True) + + self.res_conv_unit1 = PreActResidualConvUnit(in_channels=self.in_channels, act_cfg=act_cfg, norm_cfg=norm_cfg) + self.res_conv_unit2 = PreActResidualConvUnit(in_channels=self.in_channels, act_cfg=act_cfg, norm_cfg=norm_cfg) + + def forward(self, *inputs): + x = inputs[0] + if len(inputs) == 2: + if x.shape != inputs[1].shape: + res = resize(inputs[1], size=(x.shape[2], x.shape[3]), mode="bilinear", align_corners=False) + else: + res = inputs[1] + x = x + self.res_conv_unit1(res) + x = self.res_conv_unit2(x) + x = resize(x, scale_factor=2, mode="bilinear", align_corners=self.align_corners) + x = self.project(x) + return x + + +@HEADS.register_module() +class DPTHead(DepthBaseDecodeHead): + """Vision Transformers for Dense Prediction. + This head is implemented of `DPT `_. + Args: + embed_dims (int): The embed dimension of the ViT backbone. + Default: 768. + post_process_channels (List): Out channels of post process conv + layers. Default: [96, 192, 384, 768]. + readout_type (str): Type of readout operation. Default: 'ignore'. + patch_size (int): The patch size. Default: 16. + expand_channels (bool): Whether expand the channels in post process + block. Default: False. + """ + + def __init__( + self, + embed_dims=768, + post_process_channels=[96, 192, 384, 768], + readout_type="ignore", + patch_size=16, + expand_channels=False, + **kwargs + ): + super(DPTHead, self).__init__(**kwargs) + + self.in_channels = self.in_channels + self.expand_channels = expand_channels + self.reassemble_blocks = ReassembleBlocks(embed_dims, post_process_channels, readout_type, patch_size) + + self.post_process_channels = [ + channel * math.pow(2, i) if expand_channels else channel for i, channel in enumerate(post_process_channels) + ] + self.convs = nn.ModuleList() + for channel in self.post_process_channels: + self.convs.append(ConvModule(channel, self.channels, kernel_size=3, padding=1, act_cfg=None, bias=False)) + self.fusion_blocks = nn.ModuleList() + for _ in range(len(self.convs)): + self.fusion_blocks.append(FeatureFusionBlock(self.channels, self.act_cfg, self.norm_cfg)) + self.fusion_blocks[0].res_conv_unit1 = None + self.project = ConvModule(self.channels, self.channels, kernel_size=3, padding=1, norm_cfg=self.norm_cfg) + self.num_fusion_blocks = len(self.fusion_blocks) + self.num_reassemble_blocks = len(self.reassemble_blocks.resize_layers) + self.num_post_process_channels = len(self.post_process_channels) + assert self.num_fusion_blocks == self.num_reassemble_blocks + assert self.num_reassemble_blocks == self.num_post_process_channels + self.conv_depth = HeadDepth(self.channels) + + def forward(self, inputs, img_metas): + assert len(inputs) == self.num_reassemble_blocks + x = [inp for inp in inputs] + x = self.reassemble_blocks(x) + x = [self.convs[i](feature) for i, feature in enumerate(x)] + out = self.fusion_blocks[0](x[-1]) + for i in range(1, len(self.fusion_blocks)): + out = self.fusion_blocks[i](out, x[-(i + 1)]) + out = self.project(out) + out = self.depth_pred(out) + return out diff --git a/dinov2/dinov2/eval/depth/models/decode_heads/linear_head.py b/dinov2/dinov2/eval/depth/models/decode_heads/linear_head.py new file mode 100644 index 0000000000000000000000000000000000000000..3da1436f6a3f0bcc389d74ed86d44d455d2f7a87 --- /dev/null +++ b/dinov2/dinov2/eval/depth/models/decode_heads/linear_head.py @@ -0,0 +1,89 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +import torch +import torch.nn as nn + +from ...ops import resize +from ..builder import HEADS +from .decode_head import DepthBaseDecodeHead + + +@HEADS.register_module() +class BNHead(DepthBaseDecodeHead): + """Just a batchnorm.""" + + def __init__(self, input_transform="resize_concat", in_index=(0, 1, 2, 3), upsample=1, **kwargs): + super().__init__(**kwargs) + self.input_transform = input_transform + self.in_index = in_index + self.upsample = upsample + # self.bn = nn.SyncBatchNorm(self.in_channels) + if self.classify: + self.conv_depth = nn.Conv2d(self.channels, self.n_bins, kernel_size=1, padding=0, stride=1) + else: + self.conv_depth = nn.Conv2d(self.channels, 1, kernel_size=1, padding=0, stride=1) + + def _transform_inputs(self, inputs): + """Transform inputs for decoder. + Args: + inputs (list[Tensor]): List of multi-level img features. + Returns: + Tensor: The transformed inputs + """ + + if "concat" in self.input_transform: + inputs = [inputs[i] for i in self.in_index] + if "resize" in self.input_transform: + inputs = [ + resize( + input=x, + size=[s * self.upsample for s in inputs[0].shape[2:]], + mode="bilinear", + align_corners=self.align_corners, + ) + for x in inputs + ] + inputs = torch.cat(inputs, dim=1) + elif self.input_transform == "multiple_select": + inputs = [inputs[i] for i in self.in_index] + else: + inputs = inputs[self.in_index] + + return inputs + + def _forward_feature(self, inputs, img_metas=None, **kwargs): + """Forward function for feature maps before classifying each pixel with + ``self.cls_seg`` fc. + Args: + inputs (list[Tensor]): List of multi-level img features. + Returns: + feats (Tensor): A tensor of shape (batch_size, self.channels, + H, W) which is feature map for last layer of decoder head. + """ + # accept lists (for cls token) + inputs = list(inputs) + for i, x in enumerate(inputs): + if len(x) == 2: + x, cls_token = x[0], x[1] + if len(x.shape) == 2: + x = x[:, :, None, None] + cls_token = cls_token[:, :, None, None].expand_as(x) + inputs[i] = torch.cat((x, cls_token), 1) + else: + x = x[0] + if len(x.shape) == 2: + x = x[:, :, None, None] + inputs[i] = x + x = self._transform_inputs(inputs) + # feats = self.bn(x) + return x + + def forward(self, inputs, img_metas=None, **kwargs): + """Forward function.""" + output = self._forward_feature(inputs, img_metas=img_metas, **kwargs) + output = self.depth_pred(output) + + return output diff --git a/dinov2/dinov2/eval/depth/models/depther/__init__.py b/dinov2/dinov2/eval/depth/models/depther/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..be99743bf6c773d05f2b74524116e368c0cfcba0 --- /dev/null +++ b/dinov2/dinov2/eval/depth/models/depther/__init__.py @@ -0,0 +1,7 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +from .base import BaseDepther +from .encoder_decoder import DepthEncoderDecoder diff --git a/dinov2/dinov2/eval/depth/models/depther/base.py b/dinov2/dinov2/eval/depth/models/depther/base.py new file mode 100644 index 0000000000000000000000000000000000000000..e133a825a888167f90d95d67803609d6cac7ff55 --- /dev/null +++ b/dinov2/dinov2/eval/depth/models/depther/base.py @@ -0,0 +1,194 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +from abc import ABCMeta, abstractmethod +from collections import OrderedDict + +import torch +import torch.distributed as dist +from mmcv.runner import BaseModule, auto_fp16 + + +class BaseDepther(BaseModule, metaclass=ABCMeta): + """Base class for depther.""" + + def __init__(self, init_cfg=None): + super(BaseDepther, self).__init__(init_cfg) + self.fp16_enabled = False + + @property + def with_neck(self): + """bool: whether the depther has neck""" + return hasattr(self, "neck") and self.neck is not None + + @property + def with_auxiliary_head(self): + """bool: whether the depther has auxiliary head""" + return hasattr(self, "auxiliary_head") and self.auxiliary_head is not None + + @property + def with_decode_head(self): + """bool: whether the depther has decode head""" + return hasattr(self, "decode_head") and self.decode_head is not None + + @abstractmethod + def extract_feat(self, imgs): + """Placeholder for extract features from images.""" + pass + + @abstractmethod + def encode_decode(self, img, img_metas): + """Placeholder for encode images with backbone and decode into a + semantic depth map of the same size as input.""" + pass + + @abstractmethod + def forward_train(self, imgs, img_metas, **kwargs): + """Placeholder for Forward function for training.""" + pass + + @abstractmethod + def simple_test(self, img, img_meta, **kwargs): + """Placeholder for single image test.""" + pass + + @abstractmethod + def aug_test(self, imgs, img_metas, **kwargs): + """Placeholder for augmentation test.""" + pass + + def forward_test(self, imgs, img_metas, **kwargs): + """ + Args: + imgs (List[Tensor]): the outer list indicates test-time + augmentations and inner Tensor should have a shape NxCxHxW, + which contains all images in the batch. + img_metas (List[List[dict]]): the outer list indicates test-time + augs (multiscale, flip, etc.) and the inner list indicates + images in a batch. + """ + for var, name in [(imgs, "imgs"), (img_metas, "img_metas")]: + if not isinstance(var, list): + raise TypeError(f"{name} must be a list, but got " f"{type(var)}") + num_augs = len(imgs) + if num_augs != len(img_metas): + raise ValueError(f"num of augmentations ({len(imgs)}) != " f"num of image meta ({len(img_metas)})") + # all images in the same aug batch all of the same ori_shape and pad + # shape + for img_meta in img_metas: + ori_shapes = [_["ori_shape"] for _ in img_meta] + assert all(shape == ori_shapes[0] for shape in ori_shapes) + img_shapes = [_["img_shape"] for _ in img_meta] + assert all(shape == img_shapes[0] for shape in img_shapes) + pad_shapes = [_["pad_shape"] for _ in img_meta] + assert all(shape == pad_shapes[0] for shape in pad_shapes) + + if num_augs == 1: + return self.simple_test(imgs[0], img_metas[0], **kwargs) + else: + return self.aug_test(imgs, img_metas, **kwargs) + + @auto_fp16(apply_to=("img",)) + def forward(self, img, img_metas, return_loss=True, **kwargs): + """Calls either :func:`forward_train` or :func:`forward_test` depending + on whether ``return_loss`` is ``True``. + + Note this setting will change the expected inputs. When + ``return_loss=True``, img and img_meta are single-nested (i.e. Tensor + and List[dict]), and when ``resturn_loss=False``, img and img_meta + should be double nested (i.e. List[Tensor], List[List[dict]]), with + the outer list indicating test time augmentations. + """ + if return_loss: + return self.forward_train(img, img_metas, **kwargs) + else: + return self.forward_test(img, img_metas, **kwargs) + + def train_step(self, data_batch, optimizer, **kwargs): + """The iteration step during training. + + This method defines an iteration step during training, except for the + back propagation and optimizer updating, which are done in an optimizer + hook. Note that in some complicated cases or models, the whole process + including back propagation and optimizer updating is also defined in + this method, such as GAN. + + Args: + data (dict): The output of dataloader. + optimizer (:obj:`torch.optim.Optimizer` | dict): The optimizer of + runner is passed to ``train_step()``. This argument is unused + and reserved. + + Returns: + dict: It should contain at least 3 keys: ``loss``, ``log_vars``, + ``num_samples``. + ``loss`` is a tensor for back propagation, which can be a + weighted sum of multiple losses. + ``log_vars`` contains all the variables to be sent to the + logger. + ``num_samples`` indicates the batch size (when the model is + DDP, it means the batch size on each GPU), which is used for + averaging the logs. + """ + losses = self(**data_batch) + + # split losses and images + real_losses = {} + log_imgs = {} + for k, v in losses.items(): + if "img" in k: + log_imgs[k] = v + else: + real_losses[k] = v + + loss, log_vars = self._parse_losses(real_losses) + + outputs = dict(loss=loss, log_vars=log_vars, num_samples=len(data_batch["img_metas"]), log_imgs=log_imgs) + + return outputs + + def val_step(self, data_batch, **kwargs): + """The iteration step during validation. + + This method shares the same signature as :func:`train_step`, but used + during val epochs. Note that the evaluation after training epochs is + not implemented with this method, but an evaluation hook. + """ + output = self(**data_batch, **kwargs) + return output + + @staticmethod + def _parse_losses(losses): + """Parse the raw outputs (losses) of the network. + + Args: + losses (dict): Raw output of the network, which usually contain + losses and other necessary information. + + Returns: + tuple[Tensor, dict]: (loss, log_vars), loss is the loss tensor + which may be a weighted sum of all losses, log_vars contains + all the variables to be sent to the logger. + """ + log_vars = OrderedDict() + for loss_name, loss_value in losses.items(): + if isinstance(loss_value, torch.Tensor): + log_vars[loss_name] = loss_value.mean() + elif isinstance(loss_value, list): + log_vars[loss_name] = sum(_loss.mean() for _loss in loss_value) + else: + raise TypeError(f"{loss_name} is not a tensor or list of tensors") + + loss = sum(_value for _key, _value in log_vars.items() if "loss" in _key) + + log_vars["loss"] = loss + for loss_name, loss_value in log_vars.items(): + # reduce loss when distributed training + if dist.is_available() and dist.is_initialized(): + loss_value = loss_value.data.clone() + dist.all_reduce(loss_value.div_(dist.get_world_size())) + log_vars[loss_name] = loss_value.item() + + return loss, log_vars diff --git a/dinov2/dinov2/eval/depth/models/depther/encoder_decoder.py b/dinov2/dinov2/eval/depth/models/depther/encoder_decoder.py new file mode 100644 index 0000000000000000000000000000000000000000..6b0ec2dd314fdf8ccf4414d81afb95326b7dc0c9 --- /dev/null +++ b/dinov2/dinov2/eval/depth/models/depther/encoder_decoder.py @@ -0,0 +1,236 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +import torch +import torch.nn.functional as F + +from ...models import builder +from ...models.builder import DEPTHER +from ...ops import resize +from .base import BaseDepther + + +def add_prefix(inputs, prefix): + """Add prefix for dict. + + Args: + inputs (dict): The input dict with str keys. + prefix (str): The prefix to add. + + Returns: + + dict: The dict with keys updated with ``prefix``. + """ + + outputs = dict() + for name, value in inputs.items(): + outputs[f"{prefix}.{name}"] = value + + return outputs + + +@DEPTHER.register_module() +class DepthEncoderDecoder(BaseDepther): + """Encoder Decoder depther. + + EncoderDecoder typically consists of backbone, (neck) and decode_head. + """ + + def __init__(self, backbone, decode_head, neck=None, train_cfg=None, test_cfg=None, pretrained=None, init_cfg=None): + super(DepthEncoderDecoder, self).__init__(init_cfg) + if pretrained is not None: + assert backbone.get("pretrained") is None, "both backbone and depther set pretrained weight" + backbone.pretrained = pretrained + self.backbone = builder.build_backbone(backbone) + self._init_decode_head(decode_head) + + if neck is not None: + self.neck = builder.build_neck(neck) + + self.train_cfg = train_cfg + self.test_cfg = test_cfg + + assert self.with_decode_head + + def _init_decode_head(self, decode_head): + """Initialize ``decode_head``""" + self.decode_head = builder.build_head(decode_head) + self.align_corners = self.decode_head.align_corners + + def extract_feat(self, img): + """Extract features from images.""" + x = self.backbone(img) + if self.with_neck: + x = self.neck(x) + return x + + def encode_decode(self, img, img_metas, rescale=True, size=None): + """Encode images with backbone and decode into a depth estimation + map of the same size as input.""" + x = self.extract_feat(img) + out = self._decode_head_forward_test(x, img_metas) + # crop the pred depth to the certain range. + out = torch.clamp(out, min=self.decode_head.min_depth, max=self.decode_head.max_depth) + if rescale: + if size is None: + if img_metas is not None: + size = img_metas[0]["ori_shape"][:2] + else: + size = img.shape[2:] + out = resize(input=out, size=size, mode="bilinear", align_corners=self.align_corners) + return out + + def _decode_head_forward_train(self, img, x, img_metas, depth_gt, **kwargs): + """Run forward function and calculate loss for decode head in + training.""" + losses = dict() + loss_decode = self.decode_head.forward_train(img, x, img_metas, depth_gt, self.train_cfg, **kwargs) + losses.update(add_prefix(loss_decode, "decode")) + return losses + + def _decode_head_forward_test(self, x, img_metas): + """Run forward function and calculate loss for decode head in + inference.""" + depth_pred = self.decode_head.forward_test(x, img_metas, self.test_cfg) + return depth_pred + + def forward_dummy(self, img): + """Dummy forward function.""" + depth = self.encode_decode(img, None) + + return depth + + def forward_train(self, img, img_metas, depth_gt, **kwargs): + """Forward function for training. + + Args: + img (Tensor): Input images. + img_metas (list[dict]): List of image info dict where each dict + has: 'img_shape', 'scale_factor', 'flip', and may also contain + 'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'. + For details on the values of these keys see + `depth/datasets/pipelines/formatting.py:Collect`. + depth_gt (Tensor): Depth gt + used if the architecture supports depth estimation task. + + Returns: + dict[str, Tensor]: a dictionary of loss components + """ + + x = self.extract_feat(img) + + losses = dict() + + # the last of x saves the info from neck + loss_decode = self._decode_head_forward_train(img, x, img_metas, depth_gt, **kwargs) + + losses.update(loss_decode) + + return losses + + def whole_inference(self, img, img_meta, rescale, size=None): + """Inference with full image.""" + depth_pred = self.encode_decode(img, img_meta, rescale, size=size) + + return depth_pred + + def slide_inference(self, img, img_meta, rescale): + """Inference by sliding-window with overlap. + + If h_crop > h_img or w_crop > w_img, the small patch will be used to + decode without padding. + """ + + h_stride, w_stride = self.test_cfg.stride + h_crop, w_crop = self.test_cfg.crop_size + batch_size, _, h_img, w_img = img.size() + h_grids = max(h_img - h_crop + h_stride - 1, 0) // h_stride + 1 + w_grids = max(w_img - w_crop + w_stride - 1, 0) // w_stride + 1 + preds = img.new_zeros((batch_size, 1, h_img, w_img)) + count_mat = img.new_zeros((batch_size, 1, h_img, w_img)) + for h_idx in range(h_grids): + for w_idx in range(w_grids): + y1 = h_idx * h_stride + x1 = w_idx * w_stride + y2 = min(y1 + h_crop, h_img) + x2 = min(x1 + w_crop, w_img) + y1 = max(y2 - h_crop, 0) + x1 = max(x2 - w_crop, 0) + crop_img = img[:, :, y1:y2, x1:x2] + depth_pred = self.encode_decode(crop_img, img_meta, rescale) + preds += F.pad(depth_pred, (int(x1), int(preds.shape[3] - x2), int(y1), int(preds.shape[2] - y2))) + + count_mat[:, :, y1:y2, x1:x2] += 1 + assert (count_mat == 0).sum() == 0 + if torch.onnx.is_in_onnx_export(): + # cast count_mat to constant while exporting to ONNX + count_mat = torch.from_numpy(count_mat.cpu().detach().numpy()).to(device=img.device) + preds = preds / count_mat + return preds + + def inference(self, img, img_meta, rescale, size=None): + """Inference with slide/whole style. + + Args: + img (Tensor): The input image of shape (N, 3, H, W). + img_meta (dict): Image info dict where each dict has: 'img_shape', + 'scale_factor', 'flip', and may also contain + 'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'. + For details on the values of these keys see + `depth/datasets/pipelines/formatting.py:Collect`. + rescale (bool): Whether rescale back to original shape. + + Returns: + Tensor: The output depth map. + """ + + assert self.test_cfg.mode in ["slide", "whole"] + ori_shape = img_meta[0]["ori_shape"] + assert all(_["ori_shape"] == ori_shape for _ in img_meta) + if self.test_cfg.mode == "slide": + depth_pred = self.slide_inference(img, img_meta, rescale) + else: + depth_pred = self.whole_inference(img, img_meta, rescale, size=size) + output = depth_pred + flip = img_meta[0]["flip"] + if flip: + flip_direction = img_meta[0]["flip_direction"] + assert flip_direction in ["horizontal", "vertical"] + if flip_direction == "horizontal": + output = output.flip(dims=(3,)) + elif flip_direction == "vertical": + output = output.flip(dims=(2,)) + + return output + + def simple_test(self, img, img_meta, rescale=True): + """Simple test with single image.""" + depth_pred = self.inference(img, img_meta, rescale) + if torch.onnx.is_in_onnx_export(): + # our inference backend only support 4D output + depth_pred = depth_pred.unsqueeze(0) + return depth_pred + depth_pred = depth_pred.cpu().numpy() + # unravel batch dim + depth_pred = list(depth_pred) + return depth_pred + + def aug_test(self, imgs, img_metas, rescale=True): + """Test with augmentations. + + Only rescale=True is supported. + """ + # aug_test rescale all imgs back to ori_shape for now + assert rescale + # to save memory, we get augmented depth logit inplace + depth_pred = self.inference(imgs[0], img_metas[0], rescale) + for i in range(1, len(imgs)): + cur_depth_pred = self.inference(imgs[i], img_metas[i], rescale, size=depth_pred.shape[-2:]) + depth_pred += cur_depth_pred + depth_pred /= len(imgs) + depth_pred = depth_pred.cpu().numpy() + # unravel batch dim + depth_pred = list(depth_pred) + return depth_pred diff --git a/dinov2/dinov2/eval/depth/models/losses/__init__.py b/dinov2/dinov2/eval/depth/models/losses/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..2f86242e342776da2e0acc61150d15a8d58ff1e0 --- /dev/null +++ b/dinov2/dinov2/eval/depth/models/losses/__init__.py @@ -0,0 +1,7 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +from .gradientloss import GradientLoss +from .sigloss import SigLoss diff --git a/dinov2/dinov2/eval/depth/models/losses/gradientloss.py b/dinov2/dinov2/eval/depth/models/losses/gradientloss.py new file mode 100644 index 0000000000000000000000000000000000000000..1599878a6b70cdff4f8467e1e875f0d13ea89eca --- /dev/null +++ b/dinov2/dinov2/eval/depth/models/losses/gradientloss.py @@ -0,0 +1,69 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +import torch +import torch.nn as nn + +from ...models.builder import LOSSES + + +@LOSSES.register_module() +class GradientLoss(nn.Module): + """GradientLoss. + + Adapted from https://www.cs.cornell.edu/projects/megadepth/ + + Args: + valid_mask (bool): Whether filter invalid gt (gt > 0). Default: True. + loss_weight (float): Weight of the loss. Default: 1.0. + max_depth (int): When filtering invalid gt, set a max threshold. Default: None. + """ + + def __init__(self, valid_mask=True, loss_weight=1.0, max_depth=None, loss_name="loss_grad"): + super(GradientLoss, self).__init__() + self.valid_mask = valid_mask + self.loss_weight = loss_weight + self.max_depth = max_depth + self.loss_name = loss_name + + self.eps = 0.001 # avoid grad explode + + def gradientloss(self, input, target): + input_downscaled = [input] + [input[:: 2 * i, :: 2 * i] for i in range(1, 4)] + target_downscaled = [target] + [target[:: 2 * i, :: 2 * i] for i in range(1, 4)] + + gradient_loss = 0 + for input, target in zip(input_downscaled, target_downscaled): + if self.valid_mask: + mask = target > 0 + if self.max_depth is not None: + mask = torch.logical_and(target > 0, target <= self.max_depth) + N = torch.sum(mask) + else: + mask = torch.ones_like(target) + N = input.numel() + input_log = torch.log(input + self.eps) + target_log = torch.log(target + self.eps) + log_d_diff = input_log - target_log + + log_d_diff = torch.mul(log_d_diff, mask) + + v_gradient = torch.abs(log_d_diff[0:-2, :] - log_d_diff[2:, :]) + v_mask = torch.mul(mask[0:-2, :], mask[2:, :]) + v_gradient = torch.mul(v_gradient, v_mask) + + h_gradient = torch.abs(log_d_diff[:, 0:-2] - log_d_diff[:, 2:]) + h_mask = torch.mul(mask[:, 0:-2], mask[:, 2:]) + h_gradient = torch.mul(h_gradient, h_mask) + + gradient_loss += (torch.sum(h_gradient) + torch.sum(v_gradient)) / N + + return gradient_loss + + def forward(self, depth_pred, depth_gt): + """Forward function.""" + + gradient_loss = self.loss_weight * self.gradientloss(depth_pred, depth_gt) + return gradient_loss diff --git a/dinov2/dinov2/eval/depth/models/losses/sigloss.py b/dinov2/dinov2/eval/depth/models/losses/sigloss.py new file mode 100644 index 0000000000000000000000000000000000000000..e12fad3e6151e4b975dd055193fdaec0206d4a14 --- /dev/null +++ b/dinov2/dinov2/eval/depth/models/losses/sigloss.py @@ -0,0 +1,65 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +import torch +import torch.nn as nn + +from ...models.builder import LOSSES + + +@LOSSES.register_module() +class SigLoss(nn.Module): + """SigLoss. + + This follows `AdaBins `_. + + Args: + valid_mask (bool): Whether filter invalid gt (gt > 0). Default: True. + loss_weight (float): Weight of the loss. Default: 1.0. + max_depth (int): When filtering invalid gt, set a max threshold. Default: None. + warm_up (bool): A simple warm up stage to help convergence. Default: False. + warm_iter (int): The number of warm up stage. Default: 100. + """ + + def __init__( + self, valid_mask=True, loss_weight=1.0, max_depth=None, warm_up=False, warm_iter=100, loss_name="sigloss" + ): + super(SigLoss, self).__init__() + self.valid_mask = valid_mask + self.loss_weight = loss_weight + self.max_depth = max_depth + self.loss_name = loss_name + + self.eps = 0.001 # avoid grad explode + + # HACK: a hack implementation for warmup sigloss + self.warm_up = warm_up + self.warm_iter = warm_iter + self.warm_up_counter = 0 + + def sigloss(self, input, target): + if self.valid_mask: + valid_mask = target > 0 + if self.max_depth is not None: + valid_mask = torch.logical_and(target > 0, target <= self.max_depth) + input = input[valid_mask] + target = target[valid_mask] + + if self.warm_up: + if self.warm_up_counter < self.warm_iter: + g = torch.log(input + self.eps) - torch.log(target + self.eps) + g = 0.15 * torch.pow(torch.mean(g), 2) + self.warm_up_counter += 1 + return torch.sqrt(g) + + g = torch.log(input + self.eps) - torch.log(target + self.eps) + Dg = torch.var(g) + 0.15 * torch.pow(torch.mean(g), 2) + return torch.sqrt(Dg) + + def forward(self, depth_pred, depth_gt): + """Forward function.""" + + loss_depth = self.loss_weight * self.sigloss(depth_pred, depth_gt) + return loss_depth diff --git a/dinov2/dinov2/eval/depth/ops/__init__.py b/dinov2/dinov2/eval/depth/ops/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..78181c29581a281b5f42cf12078636aaeb43b5a5 --- /dev/null +++ b/dinov2/dinov2/eval/depth/ops/__init__.py @@ -0,0 +1,6 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +from .wrappers import resize diff --git a/dinov2/dinov2/eval/depth/ops/wrappers.py b/dinov2/dinov2/eval/depth/ops/wrappers.py new file mode 100644 index 0000000000000000000000000000000000000000..15880ee0cb7652d4b41c489b927bf6a156b40e5e --- /dev/null +++ b/dinov2/dinov2/eval/depth/ops/wrappers.py @@ -0,0 +1,28 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +import warnings + +import torch.nn.functional as F + + +def resize(input, size=None, scale_factor=None, mode="nearest", align_corners=None, warning=False): + if warning: + if size is not None and align_corners: + input_h, input_w = tuple(int(x) for x in input.shape[2:]) + output_h, output_w = tuple(int(x) for x in size) + if output_h > input_h or output_w > output_h: + if ( + (output_h > 1 and output_w > 1 and input_h > 1 and input_w > 1) + and (output_h - 1) % (input_h - 1) + and (output_w - 1) % (input_w - 1) + ): + warnings.warn( + f"When align_corners={align_corners}, " + "the output would more aligned if " + f"input size {(input_h, input_w)} is `x+1` and " + f"out size {(output_h, output_w)} is `nx+1`" + ) + return F.interpolate(input, size, scale_factor, mode, align_corners) diff --git a/dinov2/dinov2/eval/knn.py b/dinov2/dinov2/eval/knn.py new file mode 100644 index 0000000000000000000000000000000000000000..f3a4845da1313a6db6b8345bb9a98230fcd24acf --- /dev/null +++ b/dinov2/dinov2/eval/knn.py @@ -0,0 +1,404 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +import argparse +from functools import partial +import json +import logging +import os +import sys +from typing import List, Optional + +import torch +from torch.nn.functional import one_hot, softmax + +import dinov2.distributed as distributed +from dinov2.data import SamplerType, make_data_loader, make_dataset +from dinov2.data.transforms import make_classification_eval_transform +from dinov2.eval.metrics import AccuracyAveraging, build_topk_accuracy_metric +from dinov2.eval.setup import get_args_parser as get_setup_args_parser +from dinov2.eval.setup import setup_and_build_model +from dinov2.eval.utils import ModelWithNormalize, evaluate, extract_features + + +logger = logging.getLogger("dinov2") + + +def get_args_parser( + description: Optional[str] = None, + parents: Optional[List[argparse.ArgumentParser]] = None, + add_help: bool = True, +): + parents = parents or [] + setup_args_parser = get_setup_args_parser(parents=parents, add_help=False) + parents = [setup_args_parser] + parser = argparse.ArgumentParser( + description=description, + parents=parents, + add_help=add_help, + ) + parser.add_argument( + "--train-dataset", + dest="train_dataset_str", + type=str, + help="Training dataset", + ) + parser.add_argument( + "--val-dataset", + dest="val_dataset_str", + type=str, + help="Validation dataset", + ) + parser.add_argument( + "--nb_knn", + nargs="+", + type=int, + help="Number of NN to use. 20 is usually working the best.", + ) + parser.add_argument( + "--temperature", + type=float, + help="Temperature used in the voting coefficient", + ) + parser.add_argument( + "--gather-on-cpu", + action="store_true", + help="Whether to gather the train features on cpu, slower" + "but useful to avoid OOM for large datasets (e.g. ImageNet22k).", + ) + parser.add_argument( + "--batch-size", + type=int, + help="Batch size.", + ) + parser.add_argument( + "--n-per-class-list", + nargs="+", + type=int, + help="Number to take per class", + ) + parser.add_argument( + "--n-tries", + type=int, + help="Number of tries", + ) + parser.set_defaults( + train_dataset_str="ImageNet:split=TRAIN", + val_dataset_str="ImageNet:split=VAL", + nb_knn=[10, 20, 100, 200], + temperature=0.07, + batch_size=256, + n_per_class_list=[-1], + n_tries=1, + ) + return parser + + +class KnnModule(torch.nn.Module): + """ + Gets knn of test features from all processes on a chunk of the train features + + Each rank gets a chunk of the train features as well as a chunk of the test features. + In `compute_neighbors`, for each rank one after the other, its chunk of test features + is sent to all devices, partial knns are computed with each chunk of train features + then collated back on the original device. + """ + + def __init__(self, train_features, train_labels, nb_knn, T, device, num_classes=1000): + super().__init__() + + self.global_rank = distributed.get_global_rank() + self.global_size = distributed.get_global_size() + + self.device = device + self.train_features_rank_T = train_features.chunk(self.global_size)[self.global_rank].T.to(self.device) + self.candidates = train_labels.chunk(self.global_size)[self.global_rank].view(1, -1).to(self.device) + + self.nb_knn = nb_knn + self.max_k = max(self.nb_knn) + self.T = T + self.num_classes = num_classes + + def _get_knn_sims_and_labels(self, similarity, train_labels): + topk_sims, indices = similarity.topk(self.max_k, largest=True, sorted=True) + neighbors_labels = torch.gather(train_labels, 1, indices) + return topk_sims, neighbors_labels + + def _similarity_for_rank(self, features_rank, source_rank): + # Send the features from `source_rank` to all ranks + broadcast_shape = torch.tensor(features_rank.shape).to(self.device) + torch.distributed.broadcast(broadcast_shape, source_rank) + + broadcasted = features_rank + if self.global_rank != source_rank: + broadcasted = torch.zeros(*broadcast_shape, dtype=features_rank.dtype, device=self.device) + torch.distributed.broadcast(broadcasted, source_rank) + + # Compute the neighbors for `source_rank` among `train_features_rank_T` + similarity_rank = torch.mm(broadcasted, self.train_features_rank_T) + candidate_labels = self.candidates.expand(len(similarity_rank), -1) + return self._get_knn_sims_and_labels(similarity_rank, candidate_labels) + + def _gather_all_knn_for_rank(self, topk_sims, neighbors_labels, target_rank): + # Gather all neighbors for `target_rank` + topk_sims_rank = retrieved_rank = None + if self.global_rank == target_rank: + topk_sims_rank = [torch.zeros_like(topk_sims) for _ in range(self.global_size)] + retrieved_rank = [torch.zeros_like(neighbors_labels) for _ in range(self.global_size)] + + torch.distributed.gather(topk_sims, topk_sims_rank, dst=target_rank) + torch.distributed.gather(neighbors_labels, retrieved_rank, dst=target_rank) + + if self.global_rank == target_rank: + # Perform a second top-k on the k * global_size retrieved neighbors + topk_sims_rank = torch.cat(topk_sims_rank, dim=1) + retrieved_rank = torch.cat(retrieved_rank, dim=1) + results = self._get_knn_sims_and_labels(topk_sims_rank, retrieved_rank) + return results + return None + + def compute_neighbors(self, features_rank): + for rank in range(self.global_size): + topk_sims, neighbors_labels = self._similarity_for_rank(features_rank, rank) + results = self._gather_all_knn_for_rank(topk_sims, neighbors_labels, rank) + if results is not None: + topk_sims_rank, neighbors_labels_rank = results + return topk_sims_rank, neighbors_labels_rank + + def forward(self, features_rank): + """ + Compute the results on all values of `self.nb_knn` neighbors from the full `self.max_k` + """ + assert all(k <= self.max_k for k in self.nb_knn) + + topk_sims, neighbors_labels = self.compute_neighbors(features_rank) + batch_size = neighbors_labels.shape[0] + topk_sims_transform = softmax(topk_sims / self.T, 1) + matmul = torch.mul( + one_hot(neighbors_labels, num_classes=self.num_classes), + topk_sims_transform.view(batch_size, -1, 1), + ) + probas_for_k = {k: torch.sum(matmul[:, :k, :], 1) for k in self.nb_knn} + return probas_for_k + + +class DictKeysModule(torch.nn.Module): + def __init__(self, keys): + super().__init__() + self.keys = keys + + def forward(self, features_dict, targets): + for k in self.keys: + features_dict = features_dict[k] + return {"preds": features_dict, "target": targets} + + +def create_module_dict(*, module, n_per_class_list, n_tries, nb_knn, train_features, train_labels): + modules = {} + mapping = create_class_indices_mapping(train_labels) + for npc in n_per_class_list: + if npc < 0: # Only one try needed when using the full data + full_module = module( + train_features=train_features, + train_labels=train_labels, + nb_knn=nb_knn, + ) + modules["full"] = ModuleDictWithForward({"1": full_module}) + continue + all_tries = {} + for t in range(n_tries): + final_indices = filter_train(mapping, npc, seed=t) + k_list = list(set(nb_knn + [npc])) + k_list = sorted([el for el in k_list if el <= npc]) + all_tries[str(t)] = module( + train_features=train_features[final_indices], + train_labels=train_labels[final_indices], + nb_knn=k_list, + ) + modules[f"{npc} per class"] = ModuleDictWithForward(all_tries) + + return ModuleDictWithForward(modules) + + +def filter_train(mapping, n_per_class, seed): + torch.manual_seed(seed) + final_indices = [] + for k in mapping.keys(): + index = torch.randperm(len(mapping[k]))[:n_per_class] + final_indices.append(mapping[k][index]) + return torch.cat(final_indices).squeeze() + + +def create_class_indices_mapping(labels): + unique_labels, inverse = torch.unique(labels, return_inverse=True) + mapping = {unique_labels[i]: (inverse == i).nonzero() for i in range(len(unique_labels))} + return mapping + + +class ModuleDictWithForward(torch.nn.ModuleDict): + def forward(self, *args, **kwargs): + return {k: module(*args, **kwargs) for k, module in self._modules.items()} + + +def eval_knn( + model, + train_dataset, + val_dataset, + accuracy_averaging, + nb_knn, + temperature, + batch_size, + num_workers, + gather_on_cpu, + n_per_class_list=[-1], + n_tries=1, +): + model = ModelWithNormalize(model) + + logger.info("Extracting features for train set...") + train_features, train_labels = extract_features( + model, train_dataset, batch_size, num_workers, gather_on_cpu=gather_on_cpu + ) + logger.info(f"Train features created, shape {train_features.shape}.") + + val_dataloader = make_data_loader( + dataset=val_dataset, + batch_size=batch_size, + num_workers=num_workers, + sampler_type=SamplerType.DISTRIBUTED, + drop_last=False, + shuffle=False, + persistent_workers=True, + ) + num_classes = train_labels.max() + 1 + metric_collection = build_topk_accuracy_metric(accuracy_averaging, num_classes=num_classes) + + device = torch.cuda.current_device() + partial_module = partial(KnnModule, T=temperature, device=device, num_classes=num_classes) + knn_module_dict = create_module_dict( + module=partial_module, + n_per_class_list=n_per_class_list, + n_tries=n_tries, + nb_knn=nb_knn, + train_features=train_features, + train_labels=train_labels, + ) + postprocessors, metrics = {}, {} + for n_per_class, knn_module in knn_module_dict.items(): + for t, knn_try in knn_module.items(): + postprocessors = { + **postprocessors, + **{(n_per_class, t, k): DictKeysModule([n_per_class, t, k]) for k in knn_try.nb_knn}, + } + metrics = {**metrics, **{(n_per_class, t, k): metric_collection.clone() for k in knn_try.nb_knn}} + model_with_knn = torch.nn.Sequential(model, knn_module_dict) + + # ============ evaluation ... ============ + logger.info("Start the k-NN classification.") + _, results_dict = evaluate(model_with_knn, val_dataloader, postprocessors, metrics, device) + + # Averaging the results over the n tries for each value of n_per_class + for n_per_class, knn_module in knn_module_dict.items(): + first_try = list(knn_module.keys())[0] + k_list = knn_module[first_try].nb_knn + for k in k_list: + keys = results_dict[(n_per_class, first_try, k)].keys() # keys are e.g. `top-1` and `top-5` + results_dict[(n_per_class, k)] = { + key: torch.mean(torch.stack([results_dict[(n_per_class, t, k)][key] for t in knn_module.keys()])) + for key in keys + } + for t in knn_module.keys(): + del results_dict[(n_per_class, t, k)] + + return results_dict + + +def eval_knn_with_model( + model, + output_dir, + train_dataset_str="ImageNet:split=TRAIN", + val_dataset_str="ImageNet:split=VAL", + nb_knn=(10, 20, 100, 200), + temperature=0.07, + autocast_dtype=torch.float, + accuracy_averaging=AccuracyAveraging.MEAN_ACCURACY, + transform=None, + gather_on_cpu=False, + batch_size=256, + num_workers=5, + n_per_class_list=[-1], + n_tries=1, +): + transform = transform or make_classification_eval_transform() + + train_dataset = make_dataset( + dataset_str=train_dataset_str, + transform=transform, + ) + val_dataset = make_dataset( + dataset_str=val_dataset_str, + transform=transform, + ) + + with torch.cuda.amp.autocast(dtype=autocast_dtype): + results_dict_knn = eval_knn( + model=model, + train_dataset=train_dataset, + val_dataset=val_dataset, + accuracy_averaging=accuracy_averaging, + nb_knn=nb_knn, + temperature=temperature, + batch_size=batch_size, + num_workers=num_workers, + gather_on_cpu=gather_on_cpu, + n_per_class_list=n_per_class_list, + n_tries=n_tries, + ) + + results_dict = {} + if distributed.is_main_process(): + for knn_ in results_dict_knn.keys(): + top1 = results_dict_knn[knn_]["top-1"].item() * 100.0 + top5 = results_dict_knn[knn_]["top-5"].item() * 100.0 + results_dict[f"{knn_} Top 1"] = top1 + results_dict[f"{knn_} Top 5"] = top5 + logger.info(f"{knn_} classifier result: Top1: {top1:.2f} Top5: {top5:.2f}") + + metrics_file_path = os.path.join(output_dir, "results_eval_knn.json") + with open(metrics_file_path, "a") as f: + for k, v in results_dict.items(): + f.write(json.dumps({k: v}) + "\n") + + if distributed.is_enabled(): + torch.distributed.barrier() + return results_dict + + +def main(args): + model, autocast_dtype = setup_and_build_model(args) + eval_knn_with_model( + model=model, + output_dir=args.output_dir, + train_dataset_str=args.train_dataset_str, + val_dataset_str=args.val_dataset_str, + nb_knn=args.nb_knn, + temperature=args.temperature, + autocast_dtype=autocast_dtype, + accuracy_averaging=AccuracyAveraging.MEAN_ACCURACY, + transform=None, + gather_on_cpu=args.gather_on_cpu, + batch_size=args.batch_size, + num_workers=5, + n_per_class_list=args.n_per_class_list, + n_tries=args.n_tries, + ) + return 0 + + +if __name__ == "__main__": + description = "DINOv2 k-NN evaluation" + args_parser = get_args_parser(description=description) + args = args_parser.parse_args() + sys.exit(main(args)) diff --git a/dinov2/dinov2/eval/linear.py b/dinov2/dinov2/eval/linear.py new file mode 100644 index 0000000000000000000000000000000000000000..1bd4c5de5a041be8a188f007257d1e91b6d6921e --- /dev/null +++ b/dinov2/dinov2/eval/linear.py @@ -0,0 +1,625 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +import argparse +from functools import partial +import json +import logging +import os +import sys +from typing import List, Optional + +import numpy as np +import torch +import torch.nn as nn +from torch.nn.parallel import DistributedDataParallel +from fvcore.common.checkpoint import Checkpointer, PeriodicCheckpointer + +from dinov2.data import SamplerType, make_data_loader, make_dataset +from dinov2.data.transforms import make_classification_eval_transform, make_classification_train_transform +import dinov2.distributed as distributed +from dinov2.eval.metrics import MetricType, build_metric +from dinov2.eval.setup import get_args_parser as get_setup_args_parser +from dinov2.eval.setup import setup_and_build_model +from dinov2.eval.utils import ModelWithIntermediateLayers, evaluate +from dinov2.logging import MetricLogger + + +logger = logging.getLogger("dinov2") + + +def get_args_parser( + description: Optional[str] = None, + parents: Optional[List[argparse.ArgumentParser]] = None, + add_help: bool = True, +): + parents = parents or [] + setup_args_parser = get_setup_args_parser(parents=parents, add_help=False) + parents = [setup_args_parser] + parser = argparse.ArgumentParser( + description=description, + parents=parents, + add_help=add_help, + ) + parser.add_argument( + "--train-dataset", + dest="train_dataset_str", + type=str, + help="Training dataset", + ) + parser.add_argument( + "--val-dataset", + dest="val_dataset_str", + type=str, + help="Validation dataset", + ) + parser.add_argument( + "--test-datasets", + dest="test_dataset_strs", + type=str, + nargs="+", + help="Test datasets, none to reuse the validation dataset", + ) + parser.add_argument( + "--epochs", + type=int, + help="Number of training epochs", + ) + parser.add_argument( + "--batch-size", + type=int, + help="Batch Size (per GPU)", + ) + parser.add_argument( + "--num-workers", + type=int, + help="Number de Workers", + ) + parser.add_argument( + "--epoch-length", + type=int, + help="Length of an epoch in number of iterations", + ) + parser.add_argument( + "--save-checkpoint-frequency", + type=int, + help="Number of epochs between two named checkpoint saves.", + ) + parser.add_argument( + "--eval-period-iterations", + type=int, + help="Number of iterations between two evaluations.", + ) + parser.add_argument( + "--learning-rates", + nargs="+", + type=float, + help="Learning rates to grid search.", + ) + parser.add_argument( + "--no-resume", + action="store_true", + help="Whether to not resume from existing checkpoints", + ) + parser.add_argument( + "--val-metric-type", + type=MetricType, + choices=list(MetricType), + help="Validation metric", + ) + parser.add_argument( + "--test-metric-types", + type=MetricType, + choices=list(MetricType), + nargs="+", + help="Evaluation metric", + ) + parser.add_argument( + "--classifier-fpath", + type=str, + help="Path to a file containing pretrained linear classifiers", + ) + parser.add_argument( + "--val-class-mapping-fpath", + type=str, + help="Path to a file containing a mapping to adjust classifier outputs", + ) + parser.add_argument( + "--test-class-mapping-fpaths", + nargs="+", + type=str, + help="Path to a file containing a mapping to adjust classifier outputs", + ) + parser.set_defaults( + train_dataset_str="ImageNet:split=TRAIN", + val_dataset_str="ImageNet:split=VAL", + test_dataset_strs=None, + epochs=10, + batch_size=128, + num_workers=8, + epoch_length=1250, + save_checkpoint_frequency=20, + eval_period_iterations=1250, + learning_rates=[1e-5, 2e-5, 5e-5, 1e-4, 2e-4, 5e-4, 1e-3, 2e-3, 5e-3, 1e-2, 2e-2, 5e-2, 0.1], + val_metric_type=MetricType.MEAN_ACCURACY, + test_metric_types=None, + classifier_fpath=None, + val_class_mapping_fpath=None, + test_class_mapping_fpaths=[None], + ) + return parser + + +def has_ddp_wrapper(m: nn.Module) -> bool: + return isinstance(m, DistributedDataParallel) + + +def remove_ddp_wrapper(m: nn.Module) -> nn.Module: + return m.module if has_ddp_wrapper(m) else m + + +def _pad_and_collate(batch): + maxlen = max(len(targets) for image, targets in batch) + padded_batch = [ + (image, np.pad(targets, (0, maxlen - len(targets)), constant_values=-1)) for image, targets in batch + ] + return torch.utils.data.default_collate(padded_batch) + + +def create_linear_input(x_tokens_list, use_n_blocks, use_avgpool): + intermediate_output = x_tokens_list[-use_n_blocks:] + output = torch.cat([class_token for _, class_token in intermediate_output], dim=-1) + if use_avgpool: + output = torch.cat( + ( + output, + torch.mean(intermediate_output[-1][0], dim=1), # patch tokens + ), + dim=-1, + ) + output = output.reshape(output.shape[0], -1) + return output.float() + + +class LinearClassifier(nn.Module): + """Linear layer to train on top of frozen features""" + + def __init__(self, out_dim, use_n_blocks, use_avgpool, num_classes=1000): + super().__init__() + self.out_dim = out_dim + self.use_n_blocks = use_n_blocks + self.use_avgpool = use_avgpool + self.num_classes = num_classes + self.linear = nn.Linear(out_dim, num_classes) + self.linear.weight.data.normal_(mean=0.0, std=0.01) + self.linear.bias.data.zero_() + + def forward(self, x_tokens_list): + output = create_linear_input(x_tokens_list, self.use_n_blocks, self.use_avgpool) + return self.linear(output) + + +class AllClassifiers(nn.Module): + def __init__(self, classifiers_dict): + super().__init__() + self.classifiers_dict = nn.ModuleDict() + self.classifiers_dict.update(classifiers_dict) + + def forward(self, inputs): + return {k: v.forward(inputs) for k, v in self.classifiers_dict.items()} + + def __len__(self): + return len(self.classifiers_dict) + + +class LinearPostprocessor(nn.Module): + def __init__(self, linear_classifier, class_mapping=None): + super().__init__() + self.linear_classifier = linear_classifier + self.register_buffer("class_mapping", None if class_mapping is None else torch.LongTensor(class_mapping)) + + def forward(self, samples, targets): + preds = self.linear_classifier(samples) + return { + "preds": preds[:, self.class_mapping] if self.class_mapping is not None else preds, + "target": targets, + } + + +def scale_lr(learning_rates, batch_size): + return learning_rates * (batch_size * distributed.get_global_size()) / 256.0 + + +def setup_linear_classifiers(sample_output, n_last_blocks_list, learning_rates, batch_size, num_classes=1000): + linear_classifiers_dict = nn.ModuleDict() + optim_param_groups = [] + for n in n_last_blocks_list: + for avgpool in [False, True]: + for _lr in learning_rates: + lr = scale_lr(_lr, batch_size) + out_dim = create_linear_input(sample_output, use_n_blocks=n, use_avgpool=avgpool).shape[1] + linear_classifier = LinearClassifier( + out_dim, use_n_blocks=n, use_avgpool=avgpool, num_classes=num_classes + ) + linear_classifier = linear_classifier.cuda() + linear_classifiers_dict[ + f"classifier_{n}_blocks_avgpool_{avgpool}_lr_{lr:.5f}".replace(".", "_") + ] = linear_classifier + optim_param_groups.append({"params": linear_classifier.parameters(), "lr": lr}) + + linear_classifiers = AllClassifiers(linear_classifiers_dict) + if distributed.is_enabled(): + linear_classifiers = nn.parallel.DistributedDataParallel(linear_classifiers) + + return linear_classifiers, optim_param_groups + + +@torch.no_grad() +def evaluate_linear_classifiers( + feature_model, + linear_classifiers, + data_loader, + metric_type, + metrics_file_path, + training_num_classes, + iteration, + prefixstring="", + class_mapping=None, + best_classifier_on_val=None, +): + logger.info("running validation !") + + num_classes = len(class_mapping) if class_mapping is not None else training_num_classes + metric = build_metric(metric_type, num_classes=num_classes) + postprocessors = {k: LinearPostprocessor(v, class_mapping) for k, v in linear_classifiers.classifiers_dict.items()} + metrics = {k: metric.clone() for k in linear_classifiers.classifiers_dict} + + _, results_dict_temp = evaluate( + feature_model, + data_loader, + postprocessors, + metrics, + torch.cuda.current_device(), + ) + + logger.info("") + results_dict = {} + max_accuracy = 0 + best_classifier = "" + for i, (classifier_string, metric) in enumerate(results_dict_temp.items()): + logger.info(f"{prefixstring} -- Classifier: {classifier_string} * {metric}") + if ( + best_classifier_on_val is None and metric["top-1"].item() > max_accuracy + ) or classifier_string == best_classifier_on_val: + max_accuracy = metric["top-1"].item() + best_classifier = classifier_string + + results_dict["best_classifier"] = {"name": best_classifier, "accuracy": max_accuracy} + + logger.info(f"best classifier: {results_dict['best_classifier']}") + + if distributed.is_main_process(): + with open(metrics_file_path, "a") as f: + f.write(f"iter: {iteration}\n") + for k, v in results_dict.items(): + f.write(json.dumps({k: v}) + "\n") + f.write("\n") + + return results_dict + + +def eval_linear( + *, + feature_model, + linear_classifiers, + train_data_loader, + val_data_loader, + metrics_file_path, + optimizer, + scheduler, + output_dir, + max_iter, + checkpoint_period, # In number of iter, creates a new file every period + running_checkpoint_period, # Period to update main checkpoint file + eval_period, + metric_type, + training_num_classes, + resume=True, + classifier_fpath=None, + val_class_mapping=None, +): + checkpointer = Checkpointer(linear_classifiers, output_dir, optimizer=optimizer, scheduler=scheduler) + start_iter = checkpointer.resume_or_load(classifier_fpath or "", resume=resume).get("iteration", -1) + 1 + + periodic_checkpointer = PeriodicCheckpointer(checkpointer, checkpoint_period, max_iter=max_iter) + iteration = start_iter + logger.info("Starting training from iteration {}".format(start_iter)) + metric_logger = MetricLogger(delimiter=" ") + header = "Training" + + for data, labels in metric_logger.log_every( + train_data_loader, + 10, + header, + max_iter, + start_iter, + ): + data = data.cuda(non_blocking=True) + labels = labels.cuda(non_blocking=True) + + features = feature_model(data) + outputs = linear_classifiers(features) + + losses = {f"loss_{k}": nn.CrossEntropyLoss()(v, labels) for k, v in outputs.items()} + loss = sum(losses.values()) + + # compute the gradients + optimizer.zero_grad() + loss.backward() + + # step + optimizer.step() + scheduler.step() + + # log + if iteration % 10 == 0: + torch.cuda.synchronize() + metric_logger.update(loss=loss.item()) + metric_logger.update(lr=optimizer.param_groups[0]["lr"]) + print("lr", optimizer.param_groups[0]["lr"]) + + if iteration - start_iter > 5: + if iteration % running_checkpoint_period == 0: + torch.cuda.synchronize() + if distributed.is_main_process(): + logger.info("Checkpointing running_checkpoint") + periodic_checkpointer.save("running_checkpoint_linear_eval", iteration=iteration) + torch.cuda.synchronize() + periodic_checkpointer.step(iteration) + + if eval_period > 0 and (iteration + 1) % eval_period == 0 and iteration != max_iter - 1: + _ = evaluate_linear_classifiers( + feature_model=feature_model, + linear_classifiers=remove_ddp_wrapper(linear_classifiers), + data_loader=val_data_loader, + metrics_file_path=metrics_file_path, + prefixstring=f"ITER: {iteration}", + metric_type=metric_type, + training_num_classes=training_num_classes, + iteration=iteration, + class_mapping=val_class_mapping, + ) + torch.cuda.synchronize() + + iteration = iteration + 1 + + val_results_dict = evaluate_linear_classifiers( + feature_model=feature_model, + linear_classifiers=remove_ddp_wrapper(linear_classifiers), + data_loader=val_data_loader, + metrics_file_path=metrics_file_path, + metric_type=metric_type, + training_num_classes=training_num_classes, + iteration=iteration, + class_mapping=val_class_mapping, + ) + return val_results_dict, feature_model, linear_classifiers, iteration + + +def make_eval_data_loader(test_dataset_str, batch_size, num_workers, metric_type): + test_dataset = make_dataset( + dataset_str=test_dataset_str, + transform=make_classification_eval_transform(), + ) + test_data_loader = make_data_loader( + dataset=test_dataset, + batch_size=batch_size, + num_workers=num_workers, + sampler_type=SamplerType.DISTRIBUTED, + drop_last=False, + shuffle=False, + persistent_workers=False, + collate_fn=_pad_and_collate if metric_type == MetricType.IMAGENET_REAL_ACCURACY else None, + ) + return test_data_loader + + +def test_on_datasets( + feature_model, + linear_classifiers, + test_dataset_strs, + batch_size, + num_workers, + test_metric_types, + metrics_file_path, + training_num_classes, + iteration, + best_classifier_on_val, + prefixstring="", + test_class_mappings=[None], +): + results_dict = {} + for test_dataset_str, class_mapping, metric_type in zip(test_dataset_strs, test_class_mappings, test_metric_types): + logger.info(f"Testing on {test_dataset_str}") + test_data_loader = make_eval_data_loader(test_dataset_str, batch_size, num_workers, metric_type) + dataset_results_dict = evaluate_linear_classifiers( + feature_model, + remove_ddp_wrapper(linear_classifiers), + test_data_loader, + metric_type, + metrics_file_path, + training_num_classes, + iteration, + prefixstring="", + class_mapping=class_mapping, + best_classifier_on_val=best_classifier_on_val, + ) + results_dict[f"{test_dataset_str}_accuracy"] = 100.0 * dataset_results_dict["best_classifier"]["accuracy"] + return results_dict + + +def run_eval_linear( + model, + output_dir, + train_dataset_str, + val_dataset_str, + batch_size, + epochs, + epoch_length, + num_workers, + save_checkpoint_frequency, + eval_period_iterations, + learning_rates, + autocast_dtype, + test_dataset_strs=None, + resume=True, + classifier_fpath=None, + val_class_mapping_fpath=None, + test_class_mapping_fpaths=[None], + val_metric_type=MetricType.MEAN_ACCURACY, + test_metric_types=None, +): + seed = 0 + + if test_dataset_strs is None: + test_dataset_strs = [val_dataset_str] + if test_metric_types is None: + test_metric_types = [val_metric_type] * len(test_dataset_strs) + else: + assert len(test_metric_types) == len(test_dataset_strs) + assert len(test_dataset_strs) == len(test_class_mapping_fpaths) + + train_transform = make_classification_train_transform() + train_dataset = make_dataset( + dataset_str=train_dataset_str, + transform=train_transform, + ) + training_num_classes = len(torch.unique(torch.Tensor(train_dataset.get_targets().astype(int)))) + sampler_type = SamplerType.SHARDED_INFINITE + # sampler_type = SamplerType.INFINITE + + n_last_blocks_list = [1, 4] + n_last_blocks = max(n_last_blocks_list) + autocast_ctx = partial(torch.cuda.amp.autocast, enabled=True, dtype=autocast_dtype) + feature_model = ModelWithIntermediateLayers(model, n_last_blocks, autocast_ctx) + sample_output = feature_model(train_dataset[0][0].unsqueeze(0).cuda()) + + linear_classifiers, optim_param_groups = setup_linear_classifiers( + sample_output, + n_last_blocks_list, + learning_rates, + batch_size, + training_num_classes, + ) + + optimizer = torch.optim.SGD(optim_param_groups, momentum=0.9, weight_decay=0) + max_iter = epochs * epoch_length + scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, max_iter, eta_min=0) + checkpointer = Checkpointer(linear_classifiers, output_dir, optimizer=optimizer, scheduler=scheduler) + start_iter = checkpointer.resume_or_load(classifier_fpath or "", resume=resume).get("iteration", -1) + 1 + train_data_loader = make_data_loader( + dataset=train_dataset, + batch_size=batch_size, + num_workers=num_workers, + shuffle=True, + seed=seed, + sampler_type=sampler_type, + sampler_advance=start_iter, + drop_last=True, + persistent_workers=True, + ) + val_data_loader = make_eval_data_loader(val_dataset_str, batch_size, num_workers, val_metric_type) + + checkpoint_period = save_checkpoint_frequency * epoch_length + + if val_class_mapping_fpath is not None: + logger.info(f"Using class mapping from {val_class_mapping_fpath}") + val_class_mapping = np.load(val_class_mapping_fpath) + else: + val_class_mapping = None + + test_class_mappings = [] + for class_mapping_fpath in test_class_mapping_fpaths: + if class_mapping_fpath is not None and class_mapping_fpath != "None": + logger.info(f"Using class mapping from {class_mapping_fpath}") + class_mapping = np.load(class_mapping_fpath) + else: + class_mapping = None + test_class_mappings.append(class_mapping) + + metrics_file_path = os.path.join(output_dir, "results_eval_linear.json") + val_results_dict, feature_model, linear_classifiers, iteration = eval_linear( + feature_model=feature_model, + linear_classifiers=linear_classifiers, + train_data_loader=train_data_loader, + val_data_loader=val_data_loader, + metrics_file_path=metrics_file_path, + optimizer=optimizer, + scheduler=scheduler, + output_dir=output_dir, + max_iter=max_iter, + checkpoint_period=checkpoint_period, + running_checkpoint_period=epoch_length, + eval_period=eval_period_iterations, + metric_type=val_metric_type, + training_num_classes=training_num_classes, + resume=resume, + val_class_mapping=val_class_mapping, + classifier_fpath=classifier_fpath, + ) + results_dict = {} + if len(test_dataset_strs) > 1 or test_dataset_strs[0] != val_dataset_str: + results_dict = test_on_datasets( + feature_model, + linear_classifiers, + test_dataset_strs, + batch_size, + 0, # num_workers, + test_metric_types, + metrics_file_path, + training_num_classes, + iteration, + val_results_dict["best_classifier"]["name"], + prefixstring="", + test_class_mappings=test_class_mappings, + ) + results_dict["best_classifier"] = val_results_dict["best_classifier"]["name"] + results_dict[f"{val_dataset_str}_accuracy"] = 100.0 * val_results_dict["best_classifier"]["accuracy"] + logger.info("Test Results Dict " + str(results_dict)) + + return results_dict + + +def main(args): + model, autocast_dtype = setup_and_build_model(args) + run_eval_linear( + model=model, + output_dir=args.output_dir, + train_dataset_str=args.train_dataset_str, + val_dataset_str=args.val_dataset_str, + test_dataset_strs=args.test_dataset_strs, + batch_size=args.batch_size, + epochs=args.epochs, + epoch_length=args.epoch_length, + num_workers=args.num_workers, + save_checkpoint_frequency=args.save_checkpoint_frequency, + eval_period_iterations=args.eval_period_iterations, + learning_rates=args.learning_rates, + autocast_dtype=autocast_dtype, + resume=not args.no_resume, + classifier_fpath=args.classifier_fpath, + val_metric_type=args.val_metric_type, + test_metric_types=args.test_metric_types, + val_class_mapping_fpath=args.val_class_mapping_fpath, + test_class_mapping_fpaths=args.test_class_mapping_fpaths, + ) + return 0 + + +if __name__ == "__main__": + description = "DINOv2 linear evaluation" + args_parser = get_args_parser(description=description) + args = args_parser.parse_args() + sys.exit(main(args)) diff --git a/dinov2/dinov2/eval/log_regression.py b/dinov2/dinov2/eval/log_regression.py new file mode 100644 index 0000000000000000000000000000000000000000..5f36ec134e0ce25697428a0b3f21cdc2f0145645 --- /dev/null +++ b/dinov2/dinov2/eval/log_regression.py @@ -0,0 +1,444 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +import argparse +import gc +import logging +import sys +import time +from typing import List, Optional + +from cuml.linear_model import LogisticRegression +import torch +import torch.backends.cudnn as cudnn +import torch.distributed +from torch import nn +from torch.utils.data import TensorDataset +from torchmetrics import MetricTracker + +from dinov2.data import make_dataset +from dinov2.data.transforms import make_classification_eval_transform +from dinov2.distributed import get_global_rank, get_global_size +from dinov2.eval.metrics import MetricType, build_metric +from dinov2.eval.setup import get_args_parser as get_setup_args_parser +from dinov2.eval.setup import setup_and_build_model +from dinov2.eval.utils import evaluate, extract_features +from dinov2.utils.dtype import as_torch_dtype + + +logger = logging.getLogger("dinov2") + +DEFAULT_MAX_ITER = 1_000 +C_POWER_RANGE = torch.linspace(-6, 5, 45) +_CPU_DEVICE = torch.device("cpu") + + +def get_args_parser( + description: Optional[str] = None, + parents: Optional[List[argparse.ArgumentParser]] = None, + add_help: bool = True, +): + parents = parents or [] + setup_args_parser = get_setup_args_parser(parents=parents, add_help=False) + parents = [setup_args_parser] + parser = argparse.ArgumentParser( + description=description, + parents=parents, + add_help=add_help, + ) + parser.add_argument( + "--train-dataset", + dest="train_dataset_str", + type=str, + help="Training dataset", + ) + parser.add_argument( + "--val-dataset", + dest="val_dataset_str", + type=str, + help="Validation dataset", + ) + parser.add_argument( + "--finetune-dataset-str", + dest="finetune_dataset_str", + type=str, + help="Fine-tuning dataset", + ) + parser.add_argument( + "--finetune-on-val", + action="store_true", + help="If there is no finetune dataset, whether to choose the " + "hyperparameters on the val set instead of 10%% of the train dataset", + ) + parser.add_argument( + "--metric-type", + type=MetricType, + choices=list(MetricType), + help="Metric type", + ) + parser.add_argument( + "--train-features-device", + type=str, + help="Device to gather train features (cpu, cuda, cuda:0, etc.), default: %(default)s", + ) + parser.add_argument( + "--train-dtype", + type=str, + help="Data type to convert the train features to (default: %(default)s)", + ) + parser.add_argument( + "--max-train-iters", + type=int, + help="Maximum number of train iterations (default: %(default)s)", + ) + parser.set_defaults( + train_dataset_str="ImageNet:split=TRAIN", + val_dataset_str="ImageNet:split=VAL", + finetune_dataset_str=None, + metric_type=MetricType.MEAN_ACCURACY, + train_features_device="cpu", + train_dtype="float64", + max_train_iters=DEFAULT_MAX_ITER, + finetune_on_val=False, + ) + return parser + + +class LogRegModule(nn.Module): + def __init__( + self, + C, + max_iter=DEFAULT_MAX_ITER, + dtype=torch.float64, + device=_CPU_DEVICE, + ): + super().__init__() + self.dtype = dtype + self.device = device + self.estimator = LogisticRegression( + penalty="l2", + C=C, + max_iter=max_iter, + output_type="numpy", + tol=1e-12, + linesearch_max_iter=50, + ) + + def forward(self, samples, targets): + samples_device = samples.device + samples = samples.to(dtype=self.dtype, device=self.device) + if self.device == _CPU_DEVICE: + samples = samples.numpy() + probas = self.estimator.predict_proba(samples) + return {"preds": torch.from_numpy(probas).to(samples_device), "target": targets} + + def fit(self, train_features, train_labels): + train_features = train_features.to(dtype=self.dtype, device=self.device) + train_labels = train_labels.to(dtype=self.dtype, device=self.device) + if self.device == _CPU_DEVICE: + # both cuML and sklearn only work with numpy arrays on CPU + train_features = train_features.numpy() + train_labels = train_labels.numpy() + self.estimator.fit(train_features, train_labels) + + +def evaluate_model(*, logreg_model, logreg_metric, test_data_loader, device): + postprocessors = {"metrics": logreg_model} + metrics = {"metrics": logreg_metric} + return evaluate(nn.Identity(), test_data_loader, postprocessors, metrics, device) + + +def train_for_C(*, C, max_iter, train_features, train_labels, dtype=torch.float64, device=_CPU_DEVICE): + logreg_model = LogRegModule(C, max_iter=max_iter, dtype=dtype, device=device) + logreg_model.fit(train_features, train_labels) + return logreg_model + + +def train_and_evaluate( + *, + C, + max_iter, + train_features, + train_labels, + logreg_metric, + test_data_loader, + train_dtype=torch.float64, + train_features_device, + eval_device, +): + logreg_model = train_for_C( + C=C, + max_iter=max_iter, + train_features=train_features, + train_labels=train_labels, + dtype=train_dtype, + device=train_features_device, + ) + return evaluate_model( + logreg_model=logreg_model, + logreg_metric=logreg_metric, + test_data_loader=test_data_loader, + device=eval_device, + ) + + +def sweep_C_values( + *, + train_features, + train_labels, + test_data_loader, + metric_type, + num_classes, + train_dtype=torch.float64, + train_features_device=_CPU_DEVICE, + max_train_iters=DEFAULT_MAX_ITER, +): + if metric_type == MetricType.PER_CLASS_ACCURACY: + # If we want to output per-class accuracy, we select the hyperparameters with mean per class + metric_type = MetricType.MEAN_PER_CLASS_ACCURACY + logreg_metric = build_metric(metric_type, num_classes=num_classes) + metric_tracker = MetricTracker(logreg_metric, maximize=True) + ALL_C = 10**C_POWER_RANGE + logreg_models = {} + + train_features = train_features.to(dtype=train_dtype, device=train_features_device) + train_labels = train_labels.to(device=train_features_device) + + for i in range(get_global_rank(), len(ALL_C), get_global_size()): + C = ALL_C[i].item() + logger.info( + f"Training for C = {C:.5f}, dtype={train_dtype}, " + f"features: {train_features.shape}, {train_features.dtype}, " + f"labels: {train_labels.shape}, {train_labels.dtype}" + ) + logreg_models[C] = train_for_C( + C=C, + max_iter=max_train_iters, + train_features=train_features, + train_labels=train_labels, + dtype=train_dtype, + device=train_features_device, + ) + + gather_list = [None for _ in range(get_global_size())] + torch.distributed.all_gather_object(gather_list, logreg_models) + + logreg_models_gathered = {} + for logreg_dict in gather_list: + logreg_models_gathered.update(logreg_dict) + + for i in range(len(ALL_C)): + metric_tracker.increment() + C = ALL_C[i].item() + evals = evaluate_model( + logreg_model=logreg_models_gathered[C], + logreg_metric=metric_tracker, + test_data_loader=test_data_loader, + device=torch.cuda.current_device(), + ) + logger.info(f"Trained for C = {C:.5f}, accuracies = {evals}") + + best_stats, which_epoch = metric_tracker.best_metric(return_step=True) + best_stats_100 = {k: 100.0 * v for k, v in best_stats.items()} + if which_epoch["top-1"] == i: + best_C = C + logger.info(f"Sweep best {best_stats_100}, best C = {best_C:.6f}") + + return best_stats, best_C + + +def eval_log_regression( + *, + model, + train_dataset, + val_dataset, + finetune_dataset, + metric_type, + batch_size, + num_workers, + finetune_on_val=False, + train_dtype=torch.float64, + train_features_device=_CPU_DEVICE, + max_train_iters=DEFAULT_MAX_ITER, +): + """ + Implements the "standard" process for log regression evaluation: + The value of C is chosen by training on train_dataset and evaluating on + finetune_dataset. Then, the final model is trained on a concatenation of + train_dataset and finetune_dataset, and is evaluated on val_dataset. + If there is no finetune_dataset, the value of C is the one that yields + the best results on a random 10% subset of the train dataset + """ + + start = time.time() + + train_features, train_labels = extract_features( + model, train_dataset, batch_size, num_workers, gather_on_cpu=(train_features_device == _CPU_DEVICE) + ) + val_features, val_labels = extract_features( + model, val_dataset, batch_size, num_workers, gather_on_cpu=(train_features_device == _CPU_DEVICE) + ) + val_data_loader = torch.utils.data.DataLoader( + TensorDataset(val_features, val_labels), + batch_size=batch_size, + drop_last=False, + num_workers=0, + persistent_workers=False, + ) + + if finetune_dataset is None and finetune_on_val: + logger.info("Choosing hyperparameters on the val dataset") + finetune_features, finetune_labels = val_features, val_labels + elif finetune_dataset is None and not finetune_on_val: + logger.info("Choosing hyperparameters on 10% of the train dataset") + torch.manual_seed(0) + indices = torch.randperm(len(train_features), device=train_features.device) + finetune_index = indices[: len(train_features) // 10] + train_index = indices[len(train_features) // 10 :] + finetune_features, finetune_labels = train_features[finetune_index], train_labels[finetune_index] + train_features, train_labels = train_features[train_index], train_labels[train_index] + else: + logger.info("Choosing hyperparameters on the finetune dataset") + finetune_features, finetune_labels = extract_features( + model, finetune_dataset, batch_size, num_workers, gather_on_cpu=(train_features_device == _CPU_DEVICE) + ) + # release the model - free GPU memory + del model + gc.collect() + torch.cuda.empty_cache() + finetune_data_loader = torch.utils.data.DataLoader( + TensorDataset(finetune_features, finetune_labels), + batch_size=batch_size, + drop_last=False, + ) + + if len(train_labels.shape) > 1: + num_classes = train_labels.shape[1] + else: + num_classes = train_labels.max() + 1 + + logger.info("Using cuML for logistic regression") + + best_stats, best_C = sweep_C_values( + train_features=train_features, + train_labels=train_labels, + test_data_loader=finetune_data_loader, + metric_type=metric_type, + num_classes=num_classes, + train_dtype=train_dtype, + train_features_device=train_features_device, + max_train_iters=max_train_iters, + ) + + if not finetune_on_val: + logger.info("Best parameter found, concatenating features") + train_features = torch.cat((train_features, finetune_features)) + train_labels = torch.cat((train_labels, finetune_labels)) + + logger.info("Training final model") + logreg_metric = build_metric(metric_type, num_classes=num_classes) + evals = train_and_evaluate( + C=best_C, + max_iter=max_train_iters, + train_features=train_features, + train_labels=train_labels, + logreg_metric=logreg_metric.clone(), + test_data_loader=val_data_loader, + eval_device=torch.cuda.current_device(), + train_dtype=train_dtype, + train_features_device=train_features_device, + ) + + best_stats = evals[1]["metrics"] + + best_stats["best_C"] = best_C + + logger.info(f"Log regression evaluation done in {int(time.time() - start)}s") + return best_stats + + +def eval_log_regression_with_model( + model, + train_dataset_str="ImageNet:split=TRAIN", + val_dataset_str="ImageNet:split=VAL", + finetune_dataset_str=None, + autocast_dtype=torch.float, + finetune_on_val=False, + metric_type=MetricType.MEAN_ACCURACY, + train_dtype=torch.float64, + train_features_device=_CPU_DEVICE, + max_train_iters=DEFAULT_MAX_ITER, +): + cudnn.benchmark = True + + transform = make_classification_eval_transform(resize_size=224) + target_transform = None + + train_dataset = make_dataset(dataset_str=train_dataset_str, transform=transform, target_transform=target_transform) + val_dataset = make_dataset(dataset_str=val_dataset_str, transform=transform, target_transform=target_transform) + if finetune_dataset_str is not None: + finetune_dataset = make_dataset( + dataset_str=finetune_dataset_str, transform=transform, target_transform=target_transform + ) + else: + finetune_dataset = None + + with torch.cuda.amp.autocast(dtype=autocast_dtype): + results_dict_logreg = eval_log_regression( + model=model, + train_dataset=train_dataset, + val_dataset=val_dataset, + finetune_dataset=finetune_dataset, + metric_type=metric_type, + batch_size=256, + num_workers=0, # 5, + finetune_on_val=finetune_on_val, + train_dtype=train_dtype, + train_features_device=train_features_device, + max_train_iters=max_train_iters, + ) + + results_dict = { + "top-1": results_dict_logreg["top-1"].cpu().numpy() * 100.0, + "top-5": results_dict_logreg.get("top-5", torch.tensor(0.0)).cpu().numpy() * 100.0, + "best_C": results_dict_logreg["best_C"], + } + logger.info( + "\n".join( + [ + "Training of the supervised logistic regression on frozen features completed.\n" + "Top-1 test accuracy: {acc:.1f}".format(acc=results_dict["top-1"]), + "Top-5 test accuracy: {acc:.1f}".format(acc=results_dict["top-5"]), + "obtained for C = {c:.6f}".format(c=results_dict["best_C"]), + ] + ) + ) + + torch.distributed.barrier() + return results_dict + + +def main(args): + model, autocast_dtype = setup_and_build_model(args) + eval_log_regression_with_model( + model=model, + train_dataset_str=args.train_dataset_str, + val_dataset_str=args.val_dataset_str, + finetune_dataset_str=args.finetune_dataset_str, + autocast_dtype=autocast_dtype, + finetune_on_val=args.finetune_on_val, + metric_type=args.metric_type, + train_dtype=as_torch_dtype(args.train_dtype), + train_features_device=torch.device(args.train_features_device), + max_train_iters=args.max_train_iters, + ) + return 0 + + +if __name__ == "__main__": + description = "DINOv2 logistic regression evaluation" + args_parser = get_args_parser(description=description) + args = args_parser.parse_args() + sys.exit(main(args)) diff --git a/dinov2/dinov2/eval/metrics.py b/dinov2/dinov2/eval/metrics.py new file mode 100644 index 0000000000000000000000000000000000000000..52be81a859dddde82da93c3657c35352d2bb0a48 --- /dev/null +++ b/dinov2/dinov2/eval/metrics.py @@ -0,0 +1,113 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +from enum import Enum +import logging +from typing import Any, Dict, Optional + +import torch +from torch import Tensor +from torchmetrics import Metric, MetricCollection +from torchmetrics.classification import MulticlassAccuracy +from torchmetrics.utilities.data import dim_zero_cat, select_topk + + +logger = logging.getLogger("dinov2") + + +class MetricType(Enum): + MEAN_ACCURACY = "mean_accuracy" + MEAN_PER_CLASS_ACCURACY = "mean_per_class_accuracy" + PER_CLASS_ACCURACY = "per_class_accuracy" + IMAGENET_REAL_ACCURACY = "imagenet_real_accuracy" + + @property + def accuracy_averaging(self): + return getattr(AccuracyAveraging, self.name, None) + + def __str__(self): + return self.value + + +class AccuracyAveraging(Enum): + MEAN_ACCURACY = "micro" + MEAN_PER_CLASS_ACCURACY = "macro" + PER_CLASS_ACCURACY = "none" + + def __str__(self): + return self.value + + +def build_metric(metric_type: MetricType, *, num_classes: int, ks: Optional[tuple] = None): + if metric_type.accuracy_averaging is not None: + return build_topk_accuracy_metric( + average_type=metric_type.accuracy_averaging, + num_classes=num_classes, + ks=(1, 5) if ks is None else ks, + ) + elif metric_type == MetricType.IMAGENET_REAL_ACCURACY: + return build_topk_imagenet_real_accuracy_metric( + num_classes=num_classes, + ks=(1, 5) if ks is None else ks, + ) + + raise ValueError(f"Unknown metric type {metric_type}") + + +def build_topk_accuracy_metric(average_type: AccuracyAveraging, num_classes: int, ks: tuple = (1, 5)): + metrics: Dict[str, Metric] = { + f"top-{k}": MulticlassAccuracy(top_k=k, num_classes=int(num_classes), average=average_type.value) for k in ks + } + return MetricCollection(metrics) + + +def build_topk_imagenet_real_accuracy_metric(num_classes: int, ks: tuple = (1, 5)): + metrics: Dict[str, Metric] = {f"top-{k}": ImageNetReaLAccuracy(top_k=k, num_classes=int(num_classes)) for k in ks} + return MetricCollection(metrics) + + +class ImageNetReaLAccuracy(Metric): + is_differentiable: bool = False + higher_is_better: Optional[bool] = None + full_state_update: bool = False + + def __init__( + self, + num_classes: int, + top_k: int = 1, + **kwargs: Any, + ) -> None: + super().__init__(**kwargs) + self.num_classes = num_classes + self.top_k = top_k + self.add_state("tp", [], dist_reduce_fx="cat") + + def update(self, preds: Tensor, target: Tensor) -> None: # type: ignore + # preds [B, D] + # target [B, A] + # preds_oh [B, D] with 0 and 1 + # select top K highest probabilities, use one hot representation + preds_oh = select_topk(preds, self.top_k) + # target_oh [B, D + 1] with 0 and 1 + target_oh = torch.zeros((preds_oh.shape[0], preds_oh.shape[1] + 1), device=target.device, dtype=torch.int32) + target = target.long() + # for undefined targets (-1) use a fake value `num_classes` + target[target == -1] = self.num_classes + # fill targets, use one hot representation + target_oh.scatter_(1, target, 1) + # target_oh [B, D] (remove the fake target at index `num_classes`) + target_oh = target_oh[:, :-1] + # tp [B] with 0 and 1 + tp = (preds_oh * target_oh == 1).sum(dim=1) + # at least one match between prediction and target + tp.clip_(max=1) + # ignore instances where no targets are defined + mask = target_oh.sum(dim=1) > 0 + tp = tp[mask] + self.tp.append(tp) # type: ignore + + def compute(self) -> Tensor: + tp = dim_zero_cat(self.tp) # type: ignore + return tp.float().mean() diff --git a/dinov2/dinov2/eval/segmentation/__init__.py b/dinov2/dinov2/eval/segmentation/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..b88da6bf80be92af00b72dfdb0a806fa64a7a2d9 --- /dev/null +++ b/dinov2/dinov2/eval/segmentation/__init__.py @@ -0,0 +1,4 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. diff --git a/dinov2/dinov2/eval/segmentation/hooks/__init__.py b/dinov2/dinov2/eval/segmentation/hooks/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..738cc2d2069521ea0353acd0cb0a03e3ddf1fa51 --- /dev/null +++ b/dinov2/dinov2/eval/segmentation/hooks/__init__.py @@ -0,0 +1,6 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +from .optimizer import DistOptimizerHook diff --git a/dinov2/dinov2/eval/segmentation/hooks/optimizer.py b/dinov2/dinov2/eval/segmentation/hooks/optimizer.py new file mode 100644 index 0000000000000000000000000000000000000000..f593f26a84475bbf7ebda9607a4d10914b13a443 --- /dev/null +++ b/dinov2/dinov2/eval/segmentation/hooks/optimizer.py @@ -0,0 +1,40 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +try: + import apex +except ImportError: + print("apex is not installed") + +from mmcv.runner import OptimizerHook, HOOKS + + +@HOOKS.register_module() +class DistOptimizerHook(OptimizerHook): + """Optimizer hook for distributed training.""" + + def __init__(self, update_interval=1, grad_clip=None, coalesce=True, bucket_size_mb=-1, use_fp16=False): + self.grad_clip = grad_clip + self.coalesce = coalesce + self.bucket_size_mb = bucket_size_mb + self.update_interval = update_interval + self.use_fp16 = use_fp16 + + def before_run(self, runner): + runner.optimizer.zero_grad() + + def after_train_iter(self, runner): + runner.outputs["loss"] /= self.update_interval + if self.use_fp16: + # runner.outputs['loss'].backward() + with apex.amp.scale_loss(runner.outputs["loss"], runner.optimizer) as scaled_loss: + scaled_loss.backward() + else: + runner.outputs["loss"].backward() + if self.every_n_iters(runner, self.update_interval): + if self.grad_clip is not None: + self.clip_grads(runner.model.parameters()) + runner.optimizer.step() + runner.optimizer.zero_grad() diff --git a/dinov2/dinov2/eval/segmentation/models/__init__.py b/dinov2/dinov2/eval/segmentation/models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..88e4563d4c162d67e7900955a06bd9248d4c9a48 --- /dev/null +++ b/dinov2/dinov2/eval/segmentation/models/__init__.py @@ -0,0 +1,7 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +from .backbones import * # noqa: F403 +from .decode_heads import * # noqa: F403 diff --git a/dinov2/dinov2/eval/segmentation/models/backbones/__init__.py b/dinov2/dinov2/eval/segmentation/models/backbones/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..520d75bc6e064b9d64487293604ac1bda6e2b6f7 --- /dev/null +++ b/dinov2/dinov2/eval/segmentation/models/backbones/__init__.py @@ -0,0 +1,6 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +from .vision_transformer import DinoVisionTransformer diff --git a/dinov2/dinov2/eval/segmentation/models/backbones/vision_transformer.py b/dinov2/dinov2/eval/segmentation/models/backbones/vision_transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..c3e9753ae92a36be52f100e3004cbeeff777d14a --- /dev/null +++ b/dinov2/dinov2/eval/segmentation/models/backbones/vision_transformer.py @@ -0,0 +1,19 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +from mmcv.runner import BaseModule +from mmseg.models.builder import BACKBONES + + +@BACKBONES.register_module() +class DinoVisionTransformer(BaseModule): + """Vision Transformer.""" + + def __init__( + self, + *args, + **kwargs, + ): + super().__init__() diff --git a/dinov2/dinov2/eval/segmentation/models/decode_heads/__init__.py b/dinov2/dinov2/eval/segmentation/models/decode_heads/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..c55317875262dadf8970c2b3882f016b8d4731ac --- /dev/null +++ b/dinov2/dinov2/eval/segmentation/models/decode_heads/__init__.py @@ -0,0 +1,6 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +from .linear_head import BNHead diff --git a/dinov2/dinov2/eval/segmentation/models/decode_heads/linear_head.py b/dinov2/dinov2/eval/segmentation/models/decode_heads/linear_head.py new file mode 100644 index 0000000000000000000000000000000000000000..d1f39c68fb136f84d1aa5284da5b69581bb177cc --- /dev/null +++ b/dinov2/dinov2/eval/segmentation/models/decode_heads/linear_head.py @@ -0,0 +1,90 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +import torch +import torch.nn as nn + +from mmseg.models.builder import HEADS +from mmseg.models.decode_heads.decode_head import BaseDecodeHead +from mmseg.ops import resize + + +@HEADS.register_module() +class BNHead(BaseDecodeHead): + """Just a batchnorm.""" + + def __init__(self, resize_factors=None, **kwargs): + super().__init__(**kwargs) + assert self.in_channels == self.channels + self.bn = nn.SyncBatchNorm(self.in_channels) + self.resize_factors = resize_factors + + def _forward_feature(self, inputs): + """Forward function for feature maps before classifying each pixel with + ``self.cls_seg`` fc. + + Args: + inputs (list[Tensor]): List of multi-level img features. + + Returns: + feats (Tensor): A tensor of shape (batch_size, self.channels, + H, W) which is feature map for last layer of decoder head. + """ + # print("inputs", [i.shape for i in inputs]) + x = self._transform_inputs(inputs) + # print("x", x.shape) + feats = self.bn(x) + # print("feats", feats.shape) + return feats + + def _transform_inputs(self, inputs): + """Transform inputs for decoder. + Args: + inputs (list[Tensor]): List of multi-level img features. + Returns: + Tensor: The transformed inputs + """ + + if self.input_transform == "resize_concat": + # accept lists (for cls token) + input_list = [] + for x in inputs: + if isinstance(x, list): + input_list.extend(x) + else: + input_list.append(x) + inputs = input_list + # an image descriptor can be a local descriptor with resolution 1x1 + for i, x in enumerate(inputs): + if len(x.shape) == 2: + inputs[i] = x[:, :, None, None] + # select indices + inputs = [inputs[i] for i in self.in_index] + # Resizing shenanigans + # print("before", *(x.shape for x in inputs)) + if self.resize_factors is not None: + assert len(self.resize_factors) == len(inputs), (len(self.resize_factors), len(inputs)) + inputs = [ + resize(input=x, scale_factor=f, mode="bilinear" if f >= 1 else "area") + for x, f in zip(inputs, self.resize_factors) + ] + # print("after", *(x.shape for x in inputs)) + upsampled_inputs = [ + resize(input=x, size=inputs[0].shape[2:], mode="bilinear", align_corners=self.align_corners) + for x in inputs + ] + inputs = torch.cat(upsampled_inputs, dim=1) + elif self.input_transform == "multiple_select": + inputs = [inputs[i] for i in self.in_index] + else: + inputs = inputs[self.in_index] + + return inputs + + def forward(self, inputs): + """Forward function.""" + output = self._forward_feature(inputs) + output = self.cls_seg(output) + return output diff --git a/dinov2/dinov2/eval/segmentation/utils/__init__.py b/dinov2/dinov2/eval/segmentation/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..b88da6bf80be92af00b72dfdb0a806fa64a7a2d9 --- /dev/null +++ b/dinov2/dinov2/eval/segmentation/utils/__init__.py @@ -0,0 +1,4 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. diff --git a/dinov2/dinov2/eval/segmentation/utils/colormaps.py b/dinov2/dinov2/eval/segmentation/utils/colormaps.py new file mode 100644 index 0000000000000000000000000000000000000000..e6ef604b2c75792e95e438abfd51ab03d40de340 --- /dev/null +++ b/dinov2/dinov2/eval/segmentation/utils/colormaps.py @@ -0,0 +1,362 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +ADE20K_COLORMAP = [ + (0, 0, 0), + (120, 120, 120), + (180, 120, 120), + (6, 230, 230), + (80, 50, 50), + (4, 200, 3), + (120, 120, 80), + (140, 140, 140), + (204, 5, 255), + (230, 230, 230), + (4, 250, 7), + (224, 5, 255), + (235, 255, 7), + (150, 5, 61), + (120, 120, 70), + (8, 255, 51), + (255, 6, 82), + (143, 255, 140), + (204, 255, 4), + (255, 51, 7), + (204, 70, 3), + (0, 102, 200), + (61, 230, 250), + (255, 6, 51), + (11, 102, 255), + (255, 7, 71), + (255, 9, 224), + (9, 7, 230), + (220, 220, 220), + (255, 9, 92), + (112, 9, 255), + (8, 255, 214), + (7, 255, 224), + (255, 184, 6), + (10, 255, 71), + (255, 41, 10), + (7, 255, 255), + (224, 255, 8), + (102, 8, 255), + (255, 61, 6), + (255, 194, 7), + (255, 122, 8), + (0, 255, 20), + (255, 8, 41), + (255, 5, 153), + (6, 51, 255), + (235, 12, 255), + (160, 150, 20), + (0, 163, 255), + (140, 140, 140), + (250, 10, 15), + (20, 255, 0), + (31, 255, 0), + (255, 31, 0), + (255, 224, 0), + (153, 255, 0), + (0, 0, 255), + (255, 71, 0), + (0, 235, 255), + (0, 173, 255), + (31, 0, 255), + (11, 200, 200), + (255, 82, 0), + (0, 255, 245), + (0, 61, 255), + (0, 255, 112), + (0, 255, 133), + (255, 0, 0), + (255, 163, 0), + (255, 102, 0), + (194, 255, 0), + (0, 143, 255), + (51, 255, 0), + (0, 82, 255), + (0, 255, 41), + (0, 255, 173), + (10, 0, 255), + (173, 255, 0), + (0, 255, 153), + (255, 92, 0), + (255, 0, 255), + (255, 0, 245), + (255, 0, 102), + (255, 173, 0), + (255, 0, 20), + (255, 184, 184), + (0, 31, 255), + (0, 255, 61), + (0, 71, 255), + (255, 0, 204), + (0, 255, 194), + (0, 255, 82), + (0, 10, 255), + (0, 112, 255), + (51, 0, 255), + (0, 194, 255), + (0, 122, 255), + (0, 255, 163), + (255, 153, 0), + (0, 255, 10), + (255, 112, 0), + (143, 255, 0), + (82, 0, 255), + (163, 255, 0), + (255, 235, 0), + (8, 184, 170), + (133, 0, 255), + (0, 255, 92), + (184, 0, 255), + (255, 0, 31), + (0, 184, 255), + (0, 214, 255), + (255, 0, 112), + (92, 255, 0), + (0, 224, 255), + (112, 224, 255), + (70, 184, 160), + (163, 0, 255), + (153, 0, 255), + (71, 255, 0), + (255, 0, 163), + (255, 204, 0), + (255, 0, 143), + (0, 255, 235), + (133, 255, 0), + (255, 0, 235), + (245, 0, 255), + (255, 0, 122), + (255, 245, 0), + (10, 190, 212), + (214, 255, 0), + (0, 204, 255), + (20, 0, 255), + (255, 255, 0), + (0, 153, 255), + (0, 41, 255), + (0, 255, 204), + (41, 0, 255), + (41, 255, 0), + (173, 0, 255), + (0, 245, 255), + (71, 0, 255), + (122, 0, 255), + (0, 255, 184), + (0, 92, 255), + (184, 255, 0), + (0, 133, 255), + (255, 214, 0), + (25, 194, 194), + (102, 255, 0), + (92, 0, 255), +] + +ADE20K_CLASS_NAMES = [ + "", + "wall", + "building;edifice", + "sky", + "floor;flooring", + "tree", + "ceiling", + "road;route", + "bed", + "windowpane;window", + "grass", + "cabinet", + "sidewalk;pavement", + "person;individual;someone;somebody;mortal;soul", + "earth;ground", + "door;double;door", + "table", + "mountain;mount", + "plant;flora;plant;life", + "curtain;drape;drapery;mantle;pall", + "chair", + "car;auto;automobile;machine;motorcar", + "water", + "painting;picture", + "sofa;couch;lounge", + "shelf", + "house", + "sea", + "mirror", + "rug;carpet;carpeting", + "field", + "armchair", + "seat", + "fence;fencing", + "desk", + "rock;stone", + "wardrobe;closet;press", + "lamp", + "bathtub;bathing;tub;bath;tub", + "railing;rail", + "cushion", + "base;pedestal;stand", + "box", + "column;pillar", + "signboard;sign", + "chest;of;drawers;chest;bureau;dresser", + "counter", + "sand", + "sink", + "skyscraper", + "fireplace;hearth;open;fireplace", + "refrigerator;icebox", + "grandstand;covered;stand", + "path", + "stairs;steps", + "runway", + "case;display;case;showcase;vitrine", + "pool;table;billiard;table;snooker;table", + "pillow", + "screen;door;screen", + "stairway;staircase", + "river", + "bridge;span", + "bookcase", + "blind;screen", + "coffee;table;cocktail;table", + "toilet;can;commode;crapper;pot;potty;stool;throne", + "flower", + "book", + "hill", + "bench", + "countertop", + "stove;kitchen;stove;range;kitchen;range;cooking;stove", + "palm;palm;tree", + "kitchen;island", + "computer;computing;machine;computing;device;data;processor;electronic;computer;information;processing;system", + "swivel;chair", + "boat", + "bar", + "arcade;machine", + "hovel;hut;hutch;shack;shanty", + "bus;autobus;coach;charabanc;double-decker;jitney;motorbus;motorcoach;omnibus;passenger;vehicle", + "towel", + "light;light;source", + "truck;motortruck", + "tower", + "chandelier;pendant;pendent", + "awning;sunshade;sunblind", + "streetlight;street;lamp", + "booth;cubicle;stall;kiosk", + "television;television;receiver;television;set;tv;tv;set;idiot;box;boob;tube;telly;goggle;box", + "airplane;aeroplane;plane", + "dirt;track", + "apparel;wearing;apparel;dress;clothes", + "pole", + "land;ground;soil", + "bannister;banister;balustrade;balusters;handrail", + "escalator;moving;staircase;moving;stairway", + "ottoman;pouf;pouffe;puff;hassock", + "bottle", + "buffet;counter;sideboard", + "poster;posting;placard;notice;bill;card", + "stage", + "van", + "ship", + "fountain", + "conveyer;belt;conveyor;belt;conveyer;conveyor;transporter", + "canopy", + "washer;automatic;washer;washing;machine", + "plaything;toy", + "swimming;pool;swimming;bath;natatorium", + "stool", + "barrel;cask", + "basket;handbasket", + "waterfall;falls", + "tent;collapsible;shelter", + "bag", + "minibike;motorbike", + "cradle", + "oven", + "ball", + "food;solid;food", + "step;stair", + "tank;storage;tank", + "trade;name;brand;name;brand;marque", + "microwave;microwave;oven", + "pot;flowerpot", + "animal;animate;being;beast;brute;creature;fauna", + "bicycle;bike;wheel;cycle", + "lake", + "dishwasher;dish;washer;dishwashing;machine", + "screen;silver;screen;projection;screen", + "blanket;cover", + "sculpture", + "hood;exhaust;hood", + "sconce", + "vase", + "traffic;light;traffic;signal;stoplight", + "tray", + "ashcan;trash;can;garbage;can;wastebin;ash;bin;ash-bin;ashbin;dustbin;trash;barrel;trash;bin", + "fan", + "pier;wharf;wharfage;dock", + "crt;screen", + "plate", + "monitor;monitoring;device", + "bulletin;board;notice;board", + "shower", + "radiator", + "glass;drinking;glass", + "clock", + "flag", +] + + +VOC2012_COLORMAP = [ + (0, 0, 0), + (128, 0, 0), + (0, 128, 0), + (128, 128, 0), + (0, 0, 128), + (128, 0, 128), + (0, 128, 128), + (128, 128, 128), + (64, 0, 0), + (192, 0, 0), + (64, 128, 0), + (192, 128, 0), + (64, 0, 128), + (192, 0, 128), + (64, 128, 128), + (192, 128, 128), + (0, 64, 0), + (128, 64, 0), + (0, 192, 0), + (128, 192, 0), + (0, 64, 128), +] + + +VOC2012_CLASS_NAMES = [ + "", + "aeroplane", + "bicycle", + "bird", + "boat", + "bottle", + "bus", + "car", + "cat", + "chair", + "cow", + "diningtable", + "dog", + "horse", + "motorbike", + "person", + "pottedplant", + "sheep", + "sofa", + "train", + "tvmonitor", +] diff --git a/dinov2/dinov2/eval/segmentation_m2f/__init__.py b/dinov2/dinov2/eval/segmentation_m2f/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..6c678fdf8f1dee14d7cf9be70af14e6f9a1441c3 --- /dev/null +++ b/dinov2/dinov2/eval/segmentation_m2f/__init__.py @@ -0,0 +1,8 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +from .core import * # noqa: F403 +from .models import * # noqa: F403 +from .ops import * # noqa: F403 diff --git a/dinov2/dinov2/eval/segmentation_m2f/core/__init__.py b/dinov2/dinov2/eval/segmentation_m2f/core/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..92599806fbd221c1418d179892a0f46dc0b7d4db --- /dev/null +++ b/dinov2/dinov2/eval/segmentation_m2f/core/__init__.py @@ -0,0 +1,11 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +from mmseg.core.evaluation import * # noqa: F403 +from mmseg.core.seg import * # noqa: F403 + +from .anchor import * # noqa: F403 +from .box import * # noqa: F403 +from .utils import * # noqa: F403 diff --git a/dinov2/dinov2/eval/segmentation_m2f/core/anchor/__init__.py b/dinov2/dinov2/eval/segmentation_m2f/core/anchor/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e71ac4d6e01462221ae01aa16d0e1231cda7e2e7 --- /dev/null +++ b/dinov2/dinov2/eval/segmentation_m2f/core/anchor/__init__.py @@ -0,0 +1,6 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +from .point_generator import MlvlPointGenerator # noqa: F403 diff --git a/dinov2/dinov2/eval/segmentation_m2f/core/anchor/builder.py b/dinov2/dinov2/eval/segmentation_m2f/core/anchor/builder.py new file mode 100644 index 0000000000000000000000000000000000000000..6dba90e22de76d2f23a86d3c057f196d55a99690 --- /dev/null +++ b/dinov2/dinov2/eval/segmentation_m2f/core/anchor/builder.py @@ -0,0 +1,21 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +import warnings + +from mmcv.utils import Registry, build_from_cfg + +PRIOR_GENERATORS = Registry("Generator for anchors and points") + +ANCHOR_GENERATORS = PRIOR_GENERATORS + + +def build_prior_generator(cfg, default_args=None): + return build_from_cfg(cfg, PRIOR_GENERATORS, default_args) + + +def build_anchor_generator(cfg, default_args=None): + warnings.warn("``build_anchor_generator`` would be deprecated soon, please use " "``build_prior_generator`` ") + return build_prior_generator(cfg, default_args=default_args) diff --git a/dinov2/dinov2/eval/segmentation_m2f/core/anchor/point_generator.py b/dinov2/dinov2/eval/segmentation_m2f/core/anchor/point_generator.py new file mode 100644 index 0000000000000000000000000000000000000000..574d71939080e22284fe99087fb2e7336657bd97 --- /dev/null +++ b/dinov2/dinov2/eval/segmentation_m2f/core/anchor/point_generator.py @@ -0,0 +1,205 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +import numpy as np +import torch +from torch.nn.modules.utils import _pair + +from .builder import PRIOR_GENERATORS + + +@PRIOR_GENERATORS.register_module() +class MlvlPointGenerator: + """Standard points generator for multi-level (Mlvl) feature maps in 2D + points-based detectors. + + Args: + strides (list[int] | list[tuple[int, int]]): Strides of anchors + in multiple feature levels in order (w, h). + offset (float): The offset of points, the value is normalized with + corresponding stride. Defaults to 0.5. + """ + + def __init__(self, strides, offset=0.5): + self.strides = [_pair(stride) for stride in strides] + self.offset = offset + + @property + def num_levels(self): + """int: number of feature levels that the generator will be applied""" + return len(self.strides) + + @property + def num_base_priors(self): + """list[int]: The number of priors (points) at a point + on the feature grid""" + return [1 for _ in range(len(self.strides))] + + def _meshgrid(self, x, y, row_major=True): + yy, xx = torch.meshgrid(y, x) + if row_major: + # warning .flatten() would cause error in ONNX exporting + # have to use reshape here + return xx.reshape(-1), yy.reshape(-1) + + else: + return yy.reshape(-1), xx.reshape(-1) + + def grid_priors(self, featmap_sizes, dtype=torch.float32, device="cuda", with_stride=False): + """Generate grid points of multiple feature levels. + + Args: + featmap_sizes (list[tuple]): List of feature map sizes in + multiple feature levels, each size arrange as + as (h, w). + dtype (:obj:`dtype`): Dtype of priors. Default: torch.float32. + device (str): The device where the anchors will be put on. + with_stride (bool): Whether to concatenate the stride to + the last dimension of points. + + Return: + list[torch.Tensor]: Points of multiple feature levels. + The sizes of each tensor should be (N, 2) when with stride is + ``False``, where N = width * height, width and height + are the sizes of the corresponding feature level, + and the last dimension 2 represent (coord_x, coord_y), + otherwise the shape should be (N, 4), + and the last dimension 4 represent + (coord_x, coord_y, stride_w, stride_h). + """ + + assert self.num_levels == len(featmap_sizes) + multi_level_priors = [] + for i in range(self.num_levels): + priors = self.single_level_grid_priors( + featmap_sizes[i], level_idx=i, dtype=dtype, device=device, with_stride=with_stride + ) + multi_level_priors.append(priors) + return multi_level_priors + + def single_level_grid_priors(self, featmap_size, level_idx, dtype=torch.float32, device="cuda", with_stride=False): + """Generate grid Points of a single level. + + Note: + This function is usually called by method ``self.grid_priors``. + + Args: + featmap_size (tuple[int]): Size of the feature maps, arrange as + (h, w). + level_idx (int): The index of corresponding feature map level. + dtype (:obj:`dtype`): Dtype of priors. Default: torch.float32. + device (str, optional): The device the tensor will be put on. + Defaults to 'cuda'. + with_stride (bool): Concatenate the stride to the last dimension + of points. + + Return: + Tensor: Points of single feature levels. + The shape of tensor should be (N, 2) when with stride is + ``False``, where N = width * height, width and height + are the sizes of the corresponding feature level, + and the last dimension 2 represent (coord_x, coord_y), + otherwise the shape should be (N, 4), + and the last dimension 4 represent + (coord_x, coord_y, stride_w, stride_h). + """ + feat_h, feat_w = featmap_size + stride_w, stride_h = self.strides[level_idx] + shift_x = (torch.arange(0, feat_w, device=device) + self.offset) * stride_w + # keep featmap_size as Tensor instead of int, so that we + # can convert to ONNX correctly + shift_x = shift_x.to(dtype) + + shift_y = (torch.arange(0, feat_h, device=device) + self.offset) * stride_h + # keep featmap_size as Tensor instead of int, so that we + # can convert to ONNX correctly + shift_y = shift_y.to(dtype) + shift_xx, shift_yy = self._meshgrid(shift_x, shift_y) + if not with_stride: + shifts = torch.stack([shift_xx, shift_yy], dim=-1) + else: + # use `shape[0]` instead of `len(shift_xx)` for ONNX export + stride_w = shift_xx.new_full((shift_xx.shape[0],), stride_w).to(dtype) + stride_h = shift_xx.new_full((shift_yy.shape[0],), stride_h).to(dtype) + shifts = torch.stack([shift_xx, shift_yy, stride_w, stride_h], dim=-1) + all_points = shifts.to(device) + return all_points + + def valid_flags(self, featmap_sizes, pad_shape, device="cuda"): + """Generate valid flags of points of multiple feature levels. + + Args: + featmap_sizes (list(tuple)): List of feature map sizes in + multiple feature levels, each size arrange as + as (h, w). + pad_shape (tuple(int)): The padded shape of the image, + arrange as (h, w). + device (str): The device where the anchors will be put on. + + Return: + list(torch.Tensor): Valid flags of points of multiple levels. + """ + assert self.num_levels == len(featmap_sizes) + multi_level_flags = [] + for i in range(self.num_levels): + point_stride = self.strides[i] + feat_h, feat_w = featmap_sizes[i] + h, w = pad_shape[:2] + valid_feat_h = min(int(np.ceil(h / point_stride[1])), feat_h) + valid_feat_w = min(int(np.ceil(w / point_stride[0])), feat_w) + flags = self.single_level_valid_flags((feat_h, feat_w), (valid_feat_h, valid_feat_w), device=device) + multi_level_flags.append(flags) + return multi_level_flags + + def single_level_valid_flags(self, featmap_size, valid_size, device="cuda"): + """Generate the valid flags of points of a single feature map. + + Args: + featmap_size (tuple[int]): The size of feature maps, arrange as + as (h, w). + valid_size (tuple[int]): The valid size of the feature maps. + The size arrange as as (h, w). + device (str, optional): The device where the flags will be put on. + Defaults to 'cuda'. + + Returns: + torch.Tensor: The valid flags of each points in a single level \ + feature map. + """ + feat_h, feat_w = featmap_size + valid_h, valid_w = valid_size + assert valid_h <= feat_h and valid_w <= feat_w + valid_x = torch.zeros(feat_w, dtype=torch.bool, device=device) + valid_y = torch.zeros(feat_h, dtype=torch.bool, device=device) + valid_x[:valid_w] = 1 + valid_y[:valid_h] = 1 + valid_xx, valid_yy = self._meshgrid(valid_x, valid_y) + valid = valid_xx & valid_yy + return valid + + def sparse_priors(self, prior_idxs, featmap_size, level_idx, dtype=torch.float32, device="cuda"): + """Generate sparse points according to the ``prior_idxs``. + + Args: + prior_idxs (Tensor): The index of corresponding anchors + in the feature map. + featmap_size (tuple[int]): feature map size arrange as (w, h). + level_idx (int): The level index of corresponding feature + map. + dtype (obj:`torch.dtype`): Date type of points. Defaults to + ``torch.float32``. + device (obj:`torch.device`): The device where the points is + located. + Returns: + Tensor: Anchor with shape (N, 2), N should be equal to + the length of ``prior_idxs``. And last dimension + 2 represent (coord_x, coord_y). + """ + height, width = featmap_size + x = (prior_idxs % width + self.offset) * self.strides[level_idx][0] + y = ((prior_idxs // width) % height + self.offset) * self.strides[level_idx][1] + prioris = torch.stack([x, y], 1).to(dtype) + prioris = prioris.to(device) + return prioris diff --git a/dinov2/dinov2/eval/segmentation_m2f/core/box/__init__.py b/dinov2/dinov2/eval/segmentation_m2f/core/box/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..bf35a613f81acd77ecab2dfb75a722fa8e5c0787 --- /dev/null +++ b/dinov2/dinov2/eval/segmentation_m2f/core/box/__init__.py @@ -0,0 +1,7 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +from .builder import * # noqa: F403 +from .samplers import MaskPseudoSampler # noqa: F403 diff --git a/dinov2/dinov2/eval/segmentation_m2f/core/box/builder.py b/dinov2/dinov2/eval/segmentation_m2f/core/box/builder.py new file mode 100644 index 0000000000000000000000000000000000000000..9538c0de3db682c2b111b085a8a1ce321c76a9ff --- /dev/null +++ b/dinov2/dinov2/eval/segmentation_m2f/core/box/builder.py @@ -0,0 +1,19 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +from mmcv.utils import Registry, build_from_cfg + +BBOX_SAMPLERS = Registry("bbox_sampler") +BBOX_CODERS = Registry("bbox_coder") + + +def build_sampler(cfg, **default_args): + """Builder of box sampler.""" + return build_from_cfg(cfg, BBOX_SAMPLERS, default_args) + + +def build_bbox_coder(cfg, **default_args): + """Builder of box coder.""" + return build_from_cfg(cfg, BBOX_CODERS, default_args) diff --git a/dinov2/dinov2/eval/segmentation_m2f/core/box/samplers/__init__.py b/dinov2/dinov2/eval/segmentation_m2f/core/box/samplers/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..19c363e3fabc365d92aeaf1e78189d710db279e9 --- /dev/null +++ b/dinov2/dinov2/eval/segmentation_m2f/core/box/samplers/__init__.py @@ -0,0 +1,6 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +from .mask_pseudo_sampler import MaskPseudoSampler # noqa: F403 diff --git a/dinov2/dinov2/eval/segmentation_m2f/core/box/samplers/base_sampler.py b/dinov2/dinov2/eval/segmentation_m2f/core/box/samplers/base_sampler.py new file mode 100644 index 0000000000000000000000000000000000000000..c45cec3ed7af5b49bb54b92d6e6bcf59b06b4c99 --- /dev/null +++ b/dinov2/dinov2/eval/segmentation_m2f/core/box/samplers/base_sampler.py @@ -0,0 +1,92 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +from abc import ABCMeta, abstractmethod + +import torch + +from .sampling_result import SamplingResult + + +class BaseSampler(metaclass=ABCMeta): + """Base class of samplers.""" + + def __init__(self, num, pos_fraction, neg_pos_ub=-1, add_gt_as_proposals=True, **kwargs): + self.num = num + self.pos_fraction = pos_fraction + self.neg_pos_ub = neg_pos_ub + self.add_gt_as_proposals = add_gt_as_proposals + self.pos_sampler = self + self.neg_sampler = self + + @abstractmethod + def _sample_pos(self, assign_result, num_expected, **kwargs): + """Sample positive samples.""" + pass + + @abstractmethod + def _sample_neg(self, assign_result, num_expected, **kwargs): + """Sample negative samples.""" + pass + + def sample(self, assign_result, bboxes, gt_bboxes, gt_labels=None, **kwargs): + """Sample positive and negative bboxes. + + This is a simple implementation of bbox sampling given candidates, + assigning results and ground truth bboxes. + + Args: + assign_result (:obj:`AssignResult`): Bbox assigning results. + bboxes (Tensor): Boxes to be sampled from. + gt_bboxes (Tensor): Ground truth bboxes. + gt_labels (Tensor, optional): Class labels of ground truth bboxes. + + Returns: + :obj:`SamplingResult`: Sampling result. + + Example: + >>> from mmdet.core.bbox import RandomSampler + >>> from mmdet.core.bbox import AssignResult + >>> from mmdet.core.bbox.demodata import ensure_rng, random_boxes + >>> rng = ensure_rng(None) + >>> assign_result = AssignResult.random(rng=rng) + >>> bboxes = random_boxes(assign_result.num_preds, rng=rng) + >>> gt_bboxes = random_boxes(assign_result.num_gts, rng=rng) + >>> gt_labels = None + >>> self = RandomSampler(num=32, pos_fraction=0.5, neg_pos_ub=-1, + >>> add_gt_as_proposals=False) + >>> self = self.sample(assign_result, bboxes, gt_bboxes, gt_labels) + """ + if len(bboxes.shape) < 2: + bboxes = bboxes[None, :] + + bboxes = bboxes[:, :4] + + gt_flags = bboxes.new_zeros((bboxes.shape[0],), dtype=torch.uint8) + if self.add_gt_as_proposals and len(gt_bboxes) > 0: + if gt_labels is None: + raise ValueError("gt_labels must be given when add_gt_as_proposals is True") + bboxes = torch.cat([gt_bboxes, bboxes], dim=0) + assign_result.add_gt_(gt_labels) + gt_ones = bboxes.new_ones(gt_bboxes.shape[0], dtype=torch.uint8) + gt_flags = torch.cat([gt_ones, gt_flags]) + + num_expected_pos = int(self.num * self.pos_fraction) + pos_inds = self.pos_sampler._sample_pos(assign_result, num_expected_pos, bboxes=bboxes, **kwargs) + # We found that sampled indices have duplicated items occasionally. + # (may be a bug of PyTorch) + pos_inds = pos_inds.unique() + num_sampled_pos = pos_inds.numel() + num_expected_neg = self.num - num_sampled_pos + if self.neg_pos_ub >= 0: + _pos = max(1, num_sampled_pos) + neg_upper_bound = int(self.neg_pos_ub * _pos) + if num_expected_neg > neg_upper_bound: + num_expected_neg = neg_upper_bound + neg_inds = self.neg_sampler._sample_neg(assign_result, num_expected_neg, bboxes=bboxes, **kwargs) + neg_inds = neg_inds.unique() + + sampling_result = SamplingResult(pos_inds, neg_inds, bboxes, gt_bboxes, assign_result, gt_flags) + return sampling_result diff --git a/dinov2/dinov2/eval/segmentation_m2f/core/box/samplers/mask_pseudo_sampler.py b/dinov2/dinov2/eval/segmentation_m2f/core/box/samplers/mask_pseudo_sampler.py new file mode 100644 index 0000000000000000000000000000000000000000..3e67ea61ed0fd65cca0addde1893a3c1e176bf15 --- /dev/null +++ b/dinov2/dinov2/eval/segmentation_m2f/core/box/samplers/mask_pseudo_sampler.py @@ -0,0 +1,45 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +# References: +# https://github.com/ZwwWayne/K-Net/blob/main/knet/det/mask_pseudo_sampler.py + +import torch + +from ..builder import BBOX_SAMPLERS +from .base_sampler import BaseSampler +from .mask_sampling_result import MaskSamplingResult + + +@BBOX_SAMPLERS.register_module() +class MaskPseudoSampler(BaseSampler): + """A pseudo sampler that does not do sampling actually.""" + + def __init__(self, **kwargs): + pass + + def _sample_pos(self, **kwargs): + """Sample positive samples.""" + raise NotImplementedError + + def _sample_neg(self, **kwargs): + """Sample negative samples.""" + raise NotImplementedError + + def sample(self, assign_result, masks, gt_masks, **kwargs): + """Directly returns the positive and negative indices of samples. + + Args: + assign_result (:obj:`AssignResult`): Assigned results + masks (torch.Tensor): Bounding boxes + gt_masks (torch.Tensor): Ground truth boxes + Returns: + :obj:`SamplingResult`: sampler results + """ + pos_inds = torch.nonzero(assign_result.gt_inds > 0, as_tuple=False).squeeze(-1).unique() + neg_inds = torch.nonzero(assign_result.gt_inds == 0, as_tuple=False).squeeze(-1).unique() + gt_flags = masks.new_zeros(masks.shape[0], dtype=torch.uint8) + sampling_result = MaskSamplingResult(pos_inds, neg_inds, masks, gt_masks, assign_result, gt_flags) + return sampling_result diff --git a/dinov2/dinov2/eval/segmentation_m2f/core/box/samplers/mask_sampling_result.py b/dinov2/dinov2/eval/segmentation_m2f/core/box/samplers/mask_sampling_result.py new file mode 100644 index 0000000000000000000000000000000000000000..270ffd35a5f120dd0560a7fea7fe83ef0bab66bb --- /dev/null +++ b/dinov2/dinov2/eval/segmentation_m2f/core/box/samplers/mask_sampling_result.py @@ -0,0 +1,63 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +# References: +# https://github.com/ZwwWayne/K-Net/blob/main/knet/det/mask_pseudo_sampler.py + +import torch + +from .sampling_result import SamplingResult + + +class MaskSamplingResult(SamplingResult): + """Mask sampling result.""" + + def __init__(self, pos_inds, neg_inds, masks, gt_masks, assign_result, gt_flags): + self.pos_inds = pos_inds + self.neg_inds = neg_inds + self.pos_masks = masks[pos_inds] + self.neg_masks = masks[neg_inds] + self.pos_is_gt = gt_flags[pos_inds] + + self.num_gts = gt_masks.shape[0] + self.pos_assigned_gt_inds = assign_result.gt_inds[pos_inds] - 1 + + if gt_masks.numel() == 0: + # hack for index error case + assert self.pos_assigned_gt_inds.numel() == 0 + self.pos_gt_masks = torch.empty_like(gt_masks) + else: + self.pos_gt_masks = gt_masks[self.pos_assigned_gt_inds, :] + + if assign_result.labels is not None: + self.pos_gt_labels = assign_result.labels[pos_inds] + else: + self.pos_gt_labels = None + + @property + def masks(self): + """torch.Tensor: concatenated positive and negative boxes""" + return torch.cat([self.pos_masks, self.neg_masks]) + + def __nice__(self): + data = self.info.copy() + data["pos_masks"] = data.pop("pos_masks").shape + data["neg_masks"] = data.pop("neg_masks").shape + parts = [f"'{k}': {v!r}" for k, v in sorted(data.items())] + body = " " + ",\n ".join(parts) + return "{\n" + body + "\n}" + + @property + def info(self): + """Returns a dictionary of info about the object.""" + return { + "pos_inds": self.pos_inds, + "neg_inds": self.neg_inds, + "pos_masks": self.pos_masks, + "neg_masks": self.neg_masks, + "pos_is_gt": self.pos_is_gt, + "num_gts": self.num_gts, + "pos_assigned_gt_inds": self.pos_assigned_gt_inds, + } diff --git a/dinov2/dinov2/eval/segmentation_m2f/core/box/samplers/sampling_result.py b/dinov2/dinov2/eval/segmentation_m2f/core/box/samplers/sampling_result.py new file mode 100644 index 0000000000000000000000000000000000000000..aaee3fe55aeb8c6da7edefbbd382d94b67b6a6b4 --- /dev/null +++ b/dinov2/dinov2/eval/segmentation_m2f/core/box/samplers/sampling_result.py @@ -0,0 +1,152 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +import torch + + +class SamplingResult: + """Bbox sampling result. + + Example: + >>> # xdoctest: +IGNORE_WANT + >>> from mmdet.core.bbox.samplers.sampling_result import * # NOQA + >>> self = SamplingResult.random(rng=10) + >>> print(f'self = {self}') + self = + """ + + def __init__(self, pos_inds, neg_inds, bboxes, gt_bboxes, assign_result, gt_flags): + self.pos_inds = pos_inds + self.neg_inds = neg_inds + self.pos_bboxes = bboxes[pos_inds] + self.neg_bboxes = bboxes[neg_inds] + self.pos_is_gt = gt_flags[pos_inds] + + self.num_gts = gt_bboxes.shape[0] + self.pos_assigned_gt_inds = assign_result.gt_inds[pos_inds] - 1 + + if gt_bboxes.numel() == 0: + # hack for index error case + assert self.pos_assigned_gt_inds.numel() == 0 + self.pos_gt_bboxes = torch.empty_like(gt_bboxes).view(-1, 4) + else: + if len(gt_bboxes.shape) < 2: + gt_bboxes = gt_bboxes.view(-1, 4) + + self.pos_gt_bboxes = gt_bboxes[self.pos_assigned_gt_inds.long(), :] + + if assign_result.labels is not None: + self.pos_gt_labels = assign_result.labels[pos_inds] + else: + self.pos_gt_labels = None + + @property + def bboxes(self): + """torch.Tensor: concatenated positive and negative boxes""" + return torch.cat([self.pos_bboxes, self.neg_bboxes]) + + def to(self, device): + """Change the device of the data inplace. + + Example: + >>> self = SamplingResult.random() + >>> print(f'self = {self.to(None)}') + >>> # xdoctest: +REQUIRES(--gpu) + >>> print(f'self = {self.to(0)}') + """ + _dict = self.__dict__ + for key, value in _dict.items(): + if isinstance(value, torch.Tensor): + _dict[key] = value.to(device) + return self + + def __nice__(self): + data = self.info.copy() + data["pos_bboxes"] = data.pop("pos_bboxes").shape + data["neg_bboxes"] = data.pop("neg_bboxes").shape + parts = [f"'{k}': {v!r}" for k, v in sorted(data.items())] + body = " " + ",\n ".join(parts) + return "{\n" + body + "\n}" + + @property + def info(self): + """Returns a dictionary of info about the object.""" + return { + "pos_inds": self.pos_inds, + "neg_inds": self.neg_inds, + "pos_bboxes": self.pos_bboxes, + "neg_bboxes": self.neg_bboxes, + "pos_is_gt": self.pos_is_gt, + "num_gts": self.num_gts, + "pos_assigned_gt_inds": self.pos_assigned_gt_inds, + } + + @classmethod + def random(cls, rng=None, **kwargs): + """ + Args: + rng (None | int | numpy.random.RandomState): seed or state. + kwargs (keyword arguments): + - num_preds: number of predicted boxes + - num_gts: number of true boxes + - p_ignore (float): probability of a predicted box assigned to \ + an ignored truth. + - p_assigned (float): probability of a predicted box not being \ + assigned. + - p_use_label (float | bool): with labels or not. + + Returns: + :obj:`SamplingResult`: Randomly generated sampling result. + + Example: + >>> from mmdet.core.bbox.samplers.sampling_result import * # NOQA + >>> self = SamplingResult.random() + >>> print(self.__dict__) + """ + from mmdet.core.bbox import demodata + from mmdet.core.bbox.assigners.assign_result import AssignResult + from mmdet.core.bbox.samplers.random_sampler import RandomSampler + + rng = demodata.ensure_rng(rng) + + # make probabalistic? + num = 32 + pos_fraction = 0.5 + neg_pos_ub = -1 + + assign_result = AssignResult.random(rng=rng, **kwargs) + + # Note we could just compute an assignment + bboxes = demodata.random_boxes(assign_result.num_preds, rng=rng) + gt_bboxes = demodata.random_boxes(assign_result.num_gts, rng=rng) + + if rng.rand() > 0.2: + # sometimes algorithms squeeze their data, be robust to that + gt_bboxes = gt_bboxes.squeeze() + bboxes = bboxes.squeeze() + + if assign_result.labels is None: + gt_labels = None + else: + gt_labels = None + + if gt_labels is None: + add_gt_as_proposals = False + else: + add_gt_as_proposals = True # make probabalistic? + + sampler = RandomSampler( + num, pos_fraction, neg_pos_ub=neg_pos_ub, add_gt_as_proposals=add_gt_as_proposals, rng=rng + ) + self = sampler.sample(assign_result, bboxes, gt_bboxes, gt_labels) + return self diff --git a/dinov2/dinov2/eval/segmentation_m2f/core/utils/__init__.py b/dinov2/dinov2/eval/segmentation_m2f/core/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..6cdc9e19352f50bc2d5433c412ff71186c5df019 --- /dev/null +++ b/dinov2/dinov2/eval/segmentation_m2f/core/utils/__init__.py @@ -0,0 +1,7 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +from .dist_utils import reduce_mean +from .misc import add_prefix, multi_apply diff --git a/dinov2/dinov2/eval/segmentation_m2f/core/utils/dist_utils.py b/dinov2/dinov2/eval/segmentation_m2f/core/utils/dist_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..7dfed42da821cd94e31b663d86b20b8f09799b30 --- /dev/null +++ b/dinov2/dinov2/eval/segmentation_m2f/core/utils/dist_utils.py @@ -0,0 +1,15 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +import torch.distributed as dist + + +def reduce_mean(tensor): + """ "Obtain the mean of tensor on different GPUs.""" + if not (dist.is_available() and dist.is_initialized()): + return tensor + tensor = tensor.clone() + dist.all_reduce(tensor.div_(dist.get_world_size()), op=dist.ReduceOp.SUM) + return tensor diff --git a/dinov2/dinov2/eval/segmentation_m2f/core/utils/misc.py b/dinov2/dinov2/eval/segmentation_m2f/core/utils/misc.py new file mode 100644 index 0000000000000000000000000000000000000000..e07579e7b182b62153e81fe637ffd0f3081ef2a3 --- /dev/null +++ b/dinov2/dinov2/eval/segmentation_m2f/core/utils/misc.py @@ -0,0 +1,47 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +from functools import partial + + +def multi_apply(func, *args, **kwargs): + """Apply function to a list of arguments. + + Note: + This function applies the ``func`` to multiple inputs and + map the multiple outputs of the ``func`` into different + list. Each list contains the same type of outputs corresponding + to different inputs. + + Args: + func (Function): A function that will be applied to a list of + arguments + + Returns: + tuple(list): A tuple containing multiple list, each list contains \ + a kind of returned results by the function + """ + pfunc = partial(func, **kwargs) if kwargs else func + map_results = map(pfunc, *args) + return tuple(map(list, zip(*map_results))) + + +def add_prefix(inputs, prefix): + """Add prefix for dict. + + Args: + inputs (dict): The input dict with str keys. + prefix (str): The prefix to add. + + Returns: + + dict: The dict with keys updated with ``prefix``. + """ + + outputs = dict() + for name, value in inputs.items(): + outputs[f"{prefix}.{name}"] = value + + return outputs diff --git a/dinov2/dinov2/eval/segmentation_m2f/models/__init__.py b/dinov2/dinov2/eval/segmentation_m2f/models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..ed89bb0064d82b4360af020798eab3d2f5a47937 --- /dev/null +++ b/dinov2/dinov2/eval/segmentation_m2f/models/__init__.py @@ -0,0 +1,11 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +from .backbones import * # noqa: F403 +from .builder import MASK_ASSIGNERS, MATCH_COST, TRANSFORMER, build_assigner, build_match_cost +from .decode_heads import * # noqa: F403 +from .losses import * # noqa: F403 +from .plugins import * # noqa: F403 +from .segmentors import * # noqa: F403 diff --git a/dinov2/dinov2/eval/segmentation_m2f/models/backbones/__init__.py b/dinov2/dinov2/eval/segmentation_m2f/models/backbones/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..c4bf73bcbcee710676f81cb6517ae787f4d61cc6 --- /dev/null +++ b/dinov2/dinov2/eval/segmentation_m2f/models/backbones/__init__.py @@ -0,0 +1,6 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +from .vit_adapter import ViTAdapter diff --git a/dinov2/dinov2/eval/segmentation_m2f/models/backbones/adapter_modules.py b/dinov2/dinov2/eval/segmentation_m2f/models/backbones/adapter_modules.py new file mode 100644 index 0000000000000000000000000000000000000000..26bfdf8f6ae6c107d22d61985cce34d4b5ce275f --- /dev/null +++ b/dinov2/dinov2/eval/segmentation_m2f/models/backbones/adapter_modules.py @@ -0,0 +1,442 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +from functools import partial + +import torch +import torch.nn as nn +import torch.utils.checkpoint as cp + +from ...ops.modules import MSDeformAttn +from .drop_path import DropPath + + +def get_reference_points(spatial_shapes, device): + reference_points_list = [] + for lvl, (H_, W_) in enumerate(spatial_shapes): + ref_y, ref_x = torch.meshgrid( + torch.linspace(0.5, H_ - 0.5, H_, dtype=torch.float32, device=device), + torch.linspace(0.5, W_ - 0.5, W_, dtype=torch.float32, device=device), + ) + ref_y = ref_y.reshape(-1)[None] / H_ + ref_x = ref_x.reshape(-1)[None] / W_ + ref = torch.stack((ref_x, ref_y), -1) + reference_points_list.append(ref) + reference_points = torch.cat(reference_points_list, 1) + reference_points = reference_points[:, :, None] + return reference_points + + +def deform_inputs(x, patch_size): + bs, c, h, w = x.shape + spatial_shapes = torch.as_tensor( + [(h // 8, w // 8), (h // 16, w // 16), (h // 32, w // 32)], dtype=torch.long, device=x.device + ) + level_start_index = torch.cat((spatial_shapes.new_zeros((1,)), spatial_shapes.prod(1).cumsum(0)[:-1])) + reference_points = get_reference_points([(h // patch_size, w // patch_size)], x.device) + deform_inputs1 = [reference_points, spatial_shapes, level_start_index] + + spatial_shapes = torch.as_tensor([(h // patch_size, w // patch_size)], dtype=torch.long, device=x.device) + level_start_index = torch.cat((spatial_shapes.new_zeros((1,)), spatial_shapes.prod(1).cumsum(0)[:-1])) + reference_points = get_reference_points([(h // 8, w // 8), (h // 16, w // 16), (h // 32, w // 32)], x.device) + deform_inputs2 = [reference_points, spatial_shapes, level_start_index] + + return deform_inputs1, deform_inputs2 + + +class ConvFFN(nn.Module): + def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.0): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.dwconv = DWConv(hidden_features) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) + + def forward(self, x, H, W): + x = self.fc1(x) + x = self.dwconv(x, H, W) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +class DWConv(nn.Module): + def __init__(self, dim=768): + super().__init__() + self.dwconv = nn.Conv2d(dim, dim, 3, 1, 1, bias=True, groups=dim) + + def forward(self, x, H, W): + B, N, C = x.shape + n = N // 21 + x1 = x[:, 0 : 16 * n, :].transpose(1, 2).view(B, C, H * 2, W * 2).contiguous() + x2 = x[:, 16 * n : 20 * n, :].transpose(1, 2).view(B, C, H, W).contiguous() + x3 = x[:, 20 * n :, :].transpose(1, 2).view(B, C, H // 2, W // 2).contiguous() + x1 = self.dwconv(x1).flatten(2).transpose(1, 2) + x2 = self.dwconv(x2).flatten(2).transpose(1, 2) + x3 = self.dwconv(x3).flatten(2).transpose(1, 2) + x = torch.cat([x1, x2, x3], dim=1) + return x + + +class Extractor(nn.Module): + def __init__( + self, + dim, + num_heads=6, + n_points=4, + n_levels=1, + deform_ratio=1.0, + with_cffn=True, + cffn_ratio=0.25, + drop=0.0, + drop_path=0.0, + norm_layer=partial(nn.LayerNorm, eps=1e-6), + with_cp=False, + ): + super().__init__() + self.query_norm = norm_layer(dim) + self.feat_norm = norm_layer(dim) + self.attn = MSDeformAttn( + d_model=dim, n_levels=n_levels, n_heads=num_heads, n_points=n_points, ratio=deform_ratio + ) + self.with_cffn = with_cffn + self.with_cp = with_cp + if with_cffn: + self.ffn = ConvFFN(in_features=dim, hidden_features=int(dim * cffn_ratio), drop=drop) + self.ffn_norm = norm_layer(dim) + self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() + + def forward(self, query, reference_points, feat, spatial_shapes, level_start_index, H, W): + def _inner_forward(query, feat): + + attn = self.attn( + self.query_norm(query), reference_points, self.feat_norm(feat), spatial_shapes, level_start_index, None + ) + query = query + attn + + if self.with_cffn: + query = query + self.drop_path(self.ffn(self.ffn_norm(query), H, W)) + return query + + if self.with_cp and query.requires_grad: + query = cp.checkpoint(_inner_forward, query, feat) + else: + query = _inner_forward(query, feat) + + return query + + +class Injector(nn.Module): + def __init__( + self, + dim, + num_heads=6, + n_points=4, + n_levels=1, + deform_ratio=1.0, + norm_layer=partial(nn.LayerNorm, eps=1e-6), + init_values=0.0, + with_cp=False, + ): + super().__init__() + self.with_cp = with_cp + self.query_norm = norm_layer(dim) + self.feat_norm = norm_layer(dim) + self.attn = MSDeformAttn( + d_model=dim, n_levels=n_levels, n_heads=num_heads, n_points=n_points, ratio=deform_ratio + ) + self.gamma = nn.Parameter(init_values * torch.ones((dim)), requires_grad=True) + + def forward(self, query, reference_points, feat, spatial_shapes, level_start_index): + def _inner_forward(query, feat): + + attn = self.attn( + self.query_norm(query), reference_points, self.feat_norm(feat), spatial_shapes, level_start_index, None + ) + return query + self.gamma * attn + + if self.with_cp and query.requires_grad: + query = cp.checkpoint(_inner_forward, query, feat) + else: + query = _inner_forward(query, feat) + + return query + + +class InteractionBlock(nn.Module): + def __init__( + self, + dim, + num_heads=6, + n_points=4, + norm_layer=partial(nn.LayerNorm, eps=1e-6), + drop=0.0, + drop_path=0.0, + with_cffn=True, + cffn_ratio=0.25, + init_values=0.0, + deform_ratio=1.0, + extra_extractor=False, + with_cp=False, + ): + super().__init__() + + self.injector = Injector( + dim=dim, + n_levels=3, + num_heads=num_heads, + init_values=init_values, + n_points=n_points, + norm_layer=norm_layer, + deform_ratio=deform_ratio, + with_cp=with_cp, + ) + self.extractor = Extractor( + dim=dim, + n_levels=1, + num_heads=num_heads, + n_points=n_points, + norm_layer=norm_layer, + deform_ratio=deform_ratio, + with_cffn=with_cffn, + cffn_ratio=cffn_ratio, + drop=drop, + drop_path=drop_path, + with_cp=with_cp, + ) + if extra_extractor: + self.extra_extractors = nn.Sequential( + *[ + Extractor( + dim=dim, + num_heads=num_heads, + n_points=n_points, + norm_layer=norm_layer, + with_cffn=with_cffn, + cffn_ratio=cffn_ratio, + deform_ratio=deform_ratio, + drop=drop, + drop_path=drop_path, + with_cp=with_cp, + ) + for _ in range(2) + ] + ) + else: + self.extra_extractors = None + + def forward(self, x, c, blocks, deform_inputs1, deform_inputs2, H_c, W_c, H_toks, W_toks): + x = self.injector( + query=x, + reference_points=deform_inputs1[0], + feat=c, + spatial_shapes=deform_inputs1[1], + level_start_index=deform_inputs1[2], + ) + for idx, blk in enumerate(blocks): + x = blk(x, H_toks, W_toks) + c = self.extractor( + query=c, + reference_points=deform_inputs2[0], + feat=x, + spatial_shapes=deform_inputs2[1], + level_start_index=deform_inputs2[2], + H=H_c, + W=W_c, + ) + if self.extra_extractors is not None: + for extractor in self.extra_extractors: + c = extractor( + query=c, + reference_points=deform_inputs2[0], + feat=x, + spatial_shapes=deform_inputs2[1], + level_start_index=deform_inputs2[2], + H=H_c, + W=W_c, + ) + return x, c + + +class InteractionBlockWithCls(nn.Module): + def __init__( + self, + dim, + num_heads=6, + n_points=4, + norm_layer=partial(nn.LayerNorm, eps=1e-6), + drop=0.0, + drop_path=0.0, + with_cffn=True, + cffn_ratio=0.25, + init_values=0.0, + deform_ratio=1.0, + extra_extractor=False, + with_cp=False, + ): + super().__init__() + + self.injector = Injector( + dim=dim, + n_levels=3, + num_heads=num_heads, + init_values=init_values, + n_points=n_points, + norm_layer=norm_layer, + deform_ratio=deform_ratio, + with_cp=with_cp, + ) + self.extractor = Extractor( + dim=dim, + n_levels=1, + num_heads=num_heads, + n_points=n_points, + norm_layer=norm_layer, + deform_ratio=deform_ratio, + with_cffn=with_cffn, + cffn_ratio=cffn_ratio, + drop=drop, + drop_path=drop_path, + with_cp=with_cp, + ) + if extra_extractor: + self.extra_extractors = nn.Sequential( + *[ + Extractor( + dim=dim, + num_heads=num_heads, + n_points=n_points, + norm_layer=norm_layer, + with_cffn=with_cffn, + cffn_ratio=cffn_ratio, + deform_ratio=deform_ratio, + drop=drop, + drop_path=drop_path, + with_cp=with_cp, + ) + for _ in range(2) + ] + ) + else: + self.extra_extractors = None + + def forward(self, x, c, cls, blocks, deform_inputs1, deform_inputs2, H_c, W_c, H_toks, W_toks): + x = self.injector( + query=x, + reference_points=deform_inputs1[0], + feat=c, + spatial_shapes=deform_inputs1[1], + level_start_index=deform_inputs1[2], + ) + x = torch.cat((cls, x), dim=1) + for idx, blk in enumerate(blocks): + x = blk(x, H_toks, W_toks) + cls, x = ( + x[ + :, + :1, + ], + x[ + :, + 1:, + ], + ) + c = self.extractor( + query=c, + reference_points=deform_inputs2[0], + feat=x, + spatial_shapes=deform_inputs2[1], + level_start_index=deform_inputs2[2], + H=H_c, + W=W_c, + ) + if self.extra_extractors is not None: + for extractor in self.extra_extractors: + c = extractor( + query=c, + reference_points=deform_inputs2[0], + feat=x, + spatial_shapes=deform_inputs2[1], + level_start_index=deform_inputs2[2], + H=H_c, + W=W_c, + ) + return x, c, cls + + +class SpatialPriorModule(nn.Module): + def __init__(self, inplanes=64, embed_dim=384, with_cp=False): + super().__init__() + self.with_cp = with_cp + + self.stem = nn.Sequential( + *[ + nn.Conv2d(3, inplanes, kernel_size=3, stride=2, padding=1, bias=False), + nn.SyncBatchNorm(inplanes), + nn.ReLU(inplace=True), + nn.Conv2d(inplanes, inplanes, kernel_size=3, stride=1, padding=1, bias=False), + nn.SyncBatchNorm(inplanes), + nn.ReLU(inplace=True), + nn.Conv2d(inplanes, inplanes, kernel_size=3, stride=1, padding=1, bias=False), + nn.SyncBatchNorm(inplanes), + nn.ReLU(inplace=True), + nn.MaxPool2d(kernel_size=3, stride=2, padding=1), + ] + ) + self.conv2 = nn.Sequential( + *[ + nn.Conv2d(inplanes, 2 * inplanes, kernel_size=3, stride=2, padding=1, bias=False), + nn.SyncBatchNorm(2 * inplanes), + nn.ReLU(inplace=True), + ] + ) + self.conv3 = nn.Sequential( + *[ + nn.Conv2d(2 * inplanes, 4 * inplanes, kernel_size=3, stride=2, padding=1, bias=False), + nn.SyncBatchNorm(4 * inplanes), + nn.ReLU(inplace=True), + ] + ) + self.conv4 = nn.Sequential( + *[ + nn.Conv2d(4 * inplanes, 4 * inplanes, kernel_size=3, stride=2, padding=1, bias=False), + nn.SyncBatchNorm(4 * inplanes), + nn.ReLU(inplace=True), + ] + ) + self.fc1 = nn.Conv2d(inplanes, embed_dim, kernel_size=1, stride=1, padding=0, bias=True) + self.fc2 = nn.Conv2d(2 * inplanes, embed_dim, kernel_size=1, stride=1, padding=0, bias=True) + self.fc3 = nn.Conv2d(4 * inplanes, embed_dim, kernel_size=1, stride=1, padding=0, bias=True) + self.fc4 = nn.Conv2d(4 * inplanes, embed_dim, kernel_size=1, stride=1, padding=0, bias=True) + + def forward(self, x): + def _inner_forward(x): + c1 = self.stem(x) + c2 = self.conv2(c1) + c3 = self.conv3(c2) + c4 = self.conv4(c3) + c1 = self.fc1(c1) + c2 = self.fc2(c2) + c3 = self.fc3(c3) + c4 = self.fc4(c4) + + bs, dim, _, _ = c1.shape + # c1 = c1.view(bs, dim, -1).transpose(1, 2) # 4s + c2 = c2.view(bs, dim, -1).transpose(1, 2) # 8s + c3 = c3.view(bs, dim, -1).transpose(1, 2) # 16s + c4 = c4.view(bs, dim, -1).transpose(1, 2) # 32s + + return c1, c2, c3, c4 + + if self.with_cp and x.requires_grad: + outs = cp.checkpoint(_inner_forward, x) + else: + outs = _inner_forward(x) + return outs diff --git a/dinov2/dinov2/eval/segmentation_m2f/models/backbones/drop_path.py b/dinov2/dinov2/eval/segmentation_m2f/models/backbones/drop_path.py new file mode 100644 index 0000000000000000000000000000000000000000..864eb8738c44652d12b979fc811503f21cbb00dd --- /dev/null +++ b/dinov2/dinov2/eval/segmentation_m2f/models/backbones/drop_path.py @@ -0,0 +1,32 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +# References: +# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py +# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/drop.py + +from torch import nn + + +def drop_path(x, drop_prob: float = 0.0, training: bool = False): + if drop_prob == 0.0 or not training: + return x + keep_prob = 1 - drop_prob + shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets + random_tensor = x.new_empty(shape).bernoulli_(keep_prob) + if keep_prob > 0.0: + random_tensor.div_(keep_prob) + return x * random_tensor + + +class DropPath(nn.Module): + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).""" + + def __init__(self, drop_prob: float = 0.0): + super(DropPath, self).__init__() + self.drop_prob = drop_prob + + def forward(self, x): + return drop_path(x, self.drop_prob, self.training) diff --git a/dinov2/dinov2/eval/segmentation_m2f/models/backbones/vit.py b/dinov2/dinov2/eval/segmentation_m2f/models/backbones/vit.py new file mode 100644 index 0000000000000000000000000000000000000000..8a147570451bd2fbd016ddfafbbfa33035cbd4f8 --- /dev/null +++ b/dinov2/dinov2/eval/segmentation_m2f/models/backbones/vit.py @@ -0,0 +1,552 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +"""Vision Transformer (ViT) in PyTorch. + +A PyTorch implement of Vision Transformers as described in: + +'An Image Is Worth 16 x 16 Words: Transformers for Image Recognition at Scale' + - https://arxiv.org/abs/2010.11929 + +`How to train your ViT? Data, Augmentation, and Regularization in Vision Transformers` + - https://arxiv.org/abs/2106.10270 + +The official jax code is released and available at https://github.com/google-research/vision_transformer + +DeiT model defs and weights from https://github.com/facebookresearch/deit, +paper `DeiT: Data-efficient Image Transformers` - https://arxiv.org/abs/2012.12877 + +Acknowledgments: +* The paper authors for releasing code and weights, thanks! +* I fixed my class token impl based on Phil Wang's https://github.com/lucidrains/vit-pytorch ... check it out +for some einops/einsum fun +* Simple transformer style inspired by Andrej Karpathy's https://github.com/karpathy/minGPT +* Bert reference code checks against Huggingface Transformers and Tensorflow Bert + +Hacked together by / Copyright 2021 Ross Wightman +""" +import logging +import math +from functools import partial +from itertools import repeat +from typing import Callable, Optional + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.utils.checkpoint as cp +from mmcv.runner import BaseModule, load_checkpoint +from mmseg.ops import resize +from mmseg.utils import get_root_logger +from torch import Tensor + +from .drop_path import DropPath + + +def to_2tuple(x): + return tuple(repeat(x, 2)) + + +class Mlp(nn.Module): + def __init__( + self, + in_features: int, + hidden_features: Optional[int] = None, + out_features: Optional[int] = None, + act_layer: Callable[..., nn.Module] = nn.GELU, + drop: float = 0.0, + bias: bool = True, + ) -> None: + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features, bias=bias) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features, bias=bias) + self.drop = nn.Dropout(drop) + + def forward(self, x: Tensor) -> Tensor: + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +class SwiGLUFFN(nn.Module): + def __init__( + self, + in_features: int, + hidden_features: Optional[int] = None, + out_features: Optional[int] = None, + act_layer: Callable[..., nn.Module] = None, + drop: float = 0.0, + ) -> None: + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + swiglu_hidden_features = int(2 * hidden_features / 3) + align_as = 8 + swiglu_hidden_features = (swiglu_hidden_features + align_as - 1) // align_as * align_as + self.w1 = nn.Linear(in_features, swiglu_hidden_features) + self.w2 = nn.Linear(in_features, swiglu_hidden_features) + self.w3 = nn.Linear(swiglu_hidden_features, out_features) + + def forward(self, x: Tensor) -> Tensor: + x1 = self.w1(x) + x2 = self.w2(x) + hidden = F.silu(x1) * x2 + return self.w3(hidden) + + +class PatchEmbed(nn.Module): + """2D Image to Patch Embedding.""" + + def __init__( + self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, norm_layer=None, flatten=True, bias=True + ): + super().__init__() + img_size = to_2tuple(img_size) + patch_size = to_2tuple(patch_size) + self.img_size = img_size + self.patch_size = patch_size + self.grid_size = (img_size[0] // patch_size[0], img_size[1] // patch_size[1]) + self.num_patches = self.grid_size[0] * self.grid_size[1] + self.flatten = flatten + + self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, bias=bias) + self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity() + + def forward(self, x): + x = self.proj(x) + _, _, H, W = x.shape + if self.flatten: + x = x.flatten(2).transpose(1, 2) # BCHW -> BNC + x = self.norm(x) + return x, H, W + + +class Attention(nn.Module): + def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0.0, proj_drop=0.0): + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = head_dim**-0.5 + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + def forward(self, x, H, W): + B, N, C = x.shape + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + q, k, v = qkv.unbind(0) # make torchscript happy (cannot use tensor as tuple) + + attn = (q @ k.transpose(-2, -1)) * self.scale + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + +class MemEffAttention(nn.Module): + def __init__( + self, + dim: int, + num_heads: int = 8, + qkv_bias: bool = False, + attn_drop: float = 0.0, + proj_drop: float = 0.0, + ) -> None: + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = head_dim**-0.5 + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + def forward(self, x: Tensor, H, W) -> Tensor: + from xformers.ops import memory_efficient_attention, unbind + + B, N, C = x.shape + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads) + + q, k, v = unbind(qkv, 2) + + x = memory_efficient_attention(q, k, v) + x = x.reshape([B, N, C]) + + x = self.proj(x) + x = self.proj_drop(x) + return x + + +def window_partition(x, window_size): + """ + Args: + x: (B, H, W, C) + window_size (int): window size + Returns: + windows: (num_windows*B, window_size, window_size, C) + """ + B, H, W, C = x.shape + x = x.view(B, H // window_size, window_size, W // window_size, window_size, C) + windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C) + return windows + + +def window_reverse(windows, window_size, H, W): + """ + Args: + windows: (num_windows*B, window_size, window_size, C) + window_size (int): Window size + H (int): Height of image + W (int): Width of image + Returns: + x: (B, H, W, C) + """ + B = int(windows.shape[0] / (H * W / window_size / window_size)) + x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1) + x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1) + return x + + +class WindowedAttention(nn.Module): + def __init__( + self, dim, num_heads=8, qkv_bias=False, attn_drop=0.0, proj_drop=0.0, window_size=14, pad_mode="constant" + ): + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = head_dim**-0.5 + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + self.window_size = window_size + self.pad_mode = pad_mode + + def forward(self, x, H, W): + B, N, C = x.shape + N_ = self.window_size * self.window_size + H_ = math.ceil(H / self.window_size) * self.window_size + W_ = math.ceil(W / self.window_size) * self.window_size + + qkv = self.qkv(x) # [B, N, C] + qkv = qkv.transpose(1, 2).reshape(B, C * 3, H, W) # [B, C, H, W] + qkv = F.pad(qkv, [0, W_ - W, 0, H_ - H], mode=self.pad_mode) + + qkv = F.unfold( + qkv, kernel_size=(self.window_size, self.window_size), stride=(self.window_size, self.window_size) + ) + B, C_kw_kw, L = qkv.shape # L - the num of windows + qkv = qkv.reshape(B, C * 3, N_, L).permute(0, 3, 2, 1) # [B, L, N_, C] + qkv = qkv.reshape(B, L, N_, 3, self.num_heads, C // self.num_heads).permute(3, 0, 1, 4, 2, 5) + q, k, v = qkv.unbind(0) # make torchscript happy (cannot use tensor as tuple) + + # q,k,v [B, L, num_head, N_, C/num_head] + attn = (q @ k.transpose(-2, -1)) * self.scale # [B, L, num_head, N_, N_] + # if self.mask: + # attn = attn * mask + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) # [B, L, num_head, N_, N_] + # attn @ v = [B, L, num_head, N_, C/num_head] + x = (attn @ v).permute(0, 2, 4, 3, 1).reshape(B, C_kw_kw // 3, L) + + x = F.fold( + x, + output_size=(H_, W_), + kernel_size=(self.window_size, self.window_size), + stride=(self.window_size, self.window_size), + ) # [B, C, H_, W_] + x = x[:, :, :H, :W].reshape(B, C, N).transpose(-1, -2) + x = self.proj(x) + x = self.proj_drop(x) + return x + + +# class WindowedAttention(nn.Module): +# def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0., window_size=14, pad_mode="constant"): +# super().__init__() +# self.num_heads = num_heads +# head_dim = dim // num_heads +# self.scale = head_dim ** -0.5 +# +# self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) +# self.attn_drop = nn.Dropout(attn_drop) +# self.proj = nn.Linear(dim, dim) +# self.proj_drop = nn.Dropout(proj_drop) +# self.window_size = window_size +# self.pad_mode = pad_mode +# +# def forward(self, x, H, W): +# B, N, C = x.shape +# +# N_ = self.window_size * self.window_size +# H_ = math.ceil(H / self.window_size) * self.window_size +# W_ = math.ceil(W / self.window_size) * self.window_size +# x = x.view(B, H, W, C) +# x = F.pad(x, [0, 0, 0, W_ - W, 0, H_- H], mode=self.pad_mode) +# +# x = window_partition(x, window_size=self.window_size)# nW*B, window_size, window_size, C +# x = x.view(-1, N_, C) +# +# qkv = self.qkv(x).view(-1, N_, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) +# q, k, v = qkv.unbind(0) # make torchscript happy (cannot use tensor as tuple) +# attn = (q @ k.transpose(-2, -1)) * self.scale # [B, L, num_head, N_, N_] +# attn = attn.softmax(dim=-1) +# attn = self.attn_drop(attn) # [B, L, num_head, N_, N_] +# x = (attn @ v).transpose(1, 2).reshape(-1, self.window_size, self.window_size, C) +# +# x = window_reverse(x, self.window_size, H_, W_) +# x = x[:, :H, :W, :].reshape(B, N, C).contiguous() +# x = self.proj(x) +# x = self.proj_drop(x) +# return x + + +class Block(nn.Module): + def __init__( + self, + dim, + num_heads, + mlp_ratio=4.0, + qkv_bias=False, + drop=0.0, + attn_drop=0.0, + drop_path=0.0, + act_layer=nn.GELU, + norm_layer=nn.LayerNorm, + windowed=False, + window_size=14, + pad_mode="constant", + layer_scale=False, + with_cp=False, + ffn_layer=Mlp, + memeff=False, + ): + super().__init__() + self.with_cp = with_cp + self.norm1 = norm_layer(dim) + if windowed: + self.attn = WindowedAttention( + dim, + num_heads=num_heads, + qkv_bias=qkv_bias, + attn_drop=attn_drop, + proj_drop=drop, + window_size=window_size, + pad_mode=pad_mode, + ) + elif memeff: + self.attn = MemEffAttention( + dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop + ) + else: + self.attn = Attention(dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop) + # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here + self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = ffn_layer(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + self.layer_scale = layer_scale + if layer_scale: + self.gamma1 = nn.Parameter(torch.ones((dim)), requires_grad=True) + self.gamma2 = nn.Parameter(torch.ones((dim)), requires_grad=True) + + def forward(self, x, H, W): + def _inner_forward(x): + if self.layer_scale: + x = x + self.drop_path(self.gamma1 * self.attn(self.norm1(x), H, W)) + x = x + self.drop_path(self.gamma2 * self.mlp(self.norm2(x))) + else: + x = x + self.drop_path(self.attn(self.norm1(x), H, W)) + x = x + self.drop_path(self.mlp(self.norm2(x))) + return x + + if self.with_cp and x.requires_grad: + x = cp.checkpoint(_inner_forward, x) + else: + x = _inner_forward(x) + + return x + + +class TIMMVisionTransformer(BaseModule): + """Vision Transformer. + + A PyTorch impl of : `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale` + - https://arxiv.org/abs/2010.11929 + + Includes distillation token & head support for `DeiT: Data-efficient Image Transformers` + - https://arxiv.org/abs/2012.12877 + """ + + def __init__( + self, + img_size=224, + patch_size=16, + in_chans=3, + num_classes=1000, + embed_dim=768, + depth=12, + num_heads=12, + mlp_ratio=4.0, + qkv_bias=True, + drop_rate=0.0, + attn_drop_rate=0.0, + drop_path_rate=0.0, + layer_scale=True, + embed_layer=PatchEmbed, + norm_layer=partial(nn.LayerNorm, eps=1e-6), + act_layer=nn.GELU, + window_attn=False, + window_size=14, + pretrained=None, + with_cp=False, + pre_norm=False, + ffn_type="mlp", + memeff=False, + ): + """ + Args: + img_size (int, tuple): input image size + patch_size (int, tuple): patch size + in_chans (int): number of input channels + num_classes (int): number of classes for classification head + embed_dim (int): embedding dimension + depth (int): depth of transformer + num_heads (int): number of attention heads + mlp_ratio (int): ratio of mlp hidden dim to embedding dim + qkv_bias (bool): enable bias for qkv if True + drop_rate (float): dropout rate + attn_drop_rate (float): attention dropout rate + drop_path_rate (float): stochastic depth rate + embed_layer (nn.Module): patch embedding layer + norm_layer: (nn.Module): normalization layer + pretrained: (str): pretrained path + """ + super().__init__() + self.num_classes = num_classes + self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models + self.num_tokens = 1 + norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6) + act_layer = act_layer or nn.GELU + self.norm_layer = norm_layer + self.act_layer = act_layer + self.pretrain_size = img_size + self.drop_path_rate = drop_path_rate + self.drop_rate = drop_rate + self.patch_size = patch_size + + window_attn = [window_attn] * depth if not isinstance(window_attn, list) else window_attn + window_size = [window_size] * depth if not isinstance(window_size, list) else window_size + logging.info("window attention:", window_attn) + logging.info("window size:", window_size) + logging.info("layer scale:", layer_scale) + + self.patch_embed = embed_layer( + img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim, bias=not pre_norm + ) + num_patches = self.patch_embed.num_patches + + self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim)) + self.pos_drop = nn.Dropout(p=drop_rate) + + ffn_types = {"mlp": Mlp, "swiglu": SwiGLUFFN} + + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule + self.blocks = nn.Sequential( + *[ + Block( + dim=embed_dim, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + drop=drop_rate, + attn_drop=attn_drop_rate, + drop_path=dpr[i], + norm_layer=norm_layer, + act_layer=act_layer, + windowed=window_attn[i], + window_size=window_size[i], + layer_scale=layer_scale, + with_cp=with_cp, + ffn_layer=ffn_types[ffn_type], + memeff=memeff, + ) + for i in range(depth) + ] + ) + + # self.norm = norm_layer(embed_dim) + self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) + # For CLIP + if pre_norm: + norm_pre = norm_layer(embed_dim) + self.norm_pre = norm_pre + else: + self.norm_pre = nn.Identity() + self.init_weights(pretrained) + + def init_weights(self, pretrained=None): + if isinstance(pretrained, str): + logger = get_root_logger() + load_checkpoint(self, pretrained, map_location="cpu", strict=False, logger=logger) + + def forward_features(self, x): + x, H, W = self.patch_embed(x) + cls_token = self.cls_token.expand(x.shape[0], -1, -1) # stole cls_tokens impl from Phil Wang, thanks + x = torch.cat((cls_token, x), dim=1) + x = self.pos_drop(x + self.pos_embed) + + # For CLIP + x = self.norm_pre(x) + + for blk in self.blocks: + x = blk(x, H, W) + x = self.norm(x) + return x + + def forward(self, x): + x = self.forward_features(x) + return x + + @staticmethod + def resize_pos_embed(pos_embed, input_shpae, pos_shape, mode): + """Resize pos_embed weights. + + Resize pos_embed using bicubic interpolate method. + Args: + pos_embed (torch.Tensor): Position embedding weights. + input_shpae (tuple): Tuple for (downsampled input image height, + downsampled input image width). + pos_shape (tuple): The resolution of downsampled origin training + image. + mode (str): Algorithm used for upsampling: + ``'nearest'`` | ``'linear'`` | ``'bilinear'`` | ``'bicubic'`` | + ``'trilinear'``. Default: ``'nearest'`` + Return: + torch.Tensor: The resized pos_embed of shape [B, L_new, C] + """ + assert pos_embed.ndim == 3, "shape of pos_embed must be [B, L, C]" + pos_h, pos_w = pos_shape + # keep dim for easy deployment + cls_token_weight = pos_embed[:, 0:1] + pos_embed_weight = pos_embed[:, (-1 * pos_h * pos_w) :] + pos_embed_weight = pos_embed_weight.reshape(1, pos_h, pos_w, pos_embed.shape[2]).permute(0, 3, 1, 2) + pos_embed_weight = resize(pos_embed_weight, size=input_shpae, align_corners=False, mode=mode) + pos_embed_weight = torch.flatten(pos_embed_weight, 2).transpose(1, 2) + pos_embed = torch.cat((cls_token_weight, pos_embed_weight), dim=1) + return pos_embed diff --git a/dinov2/dinov2/eval/segmentation_m2f/models/backbones/vit_adapter.py b/dinov2/dinov2/eval/segmentation_m2f/models/backbones/vit_adapter.py new file mode 100644 index 0000000000000000000000000000000000000000..ebc4f0f65e04ed764464d141607b3b2073220f6b --- /dev/null +++ b/dinov2/dinov2/eval/segmentation_m2f/models/backbones/vit_adapter.py @@ -0,0 +1,217 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +import math + +import torch +import torch.nn as nn +import torch.nn.functional as F +from mmseg.models.builder import BACKBONES +from torch.nn.init import normal_ + +from ...ops.modules import MSDeformAttn +from .adapter_modules import InteractionBlock, InteractionBlockWithCls, SpatialPriorModule, deform_inputs +from .vit import TIMMVisionTransformer + + +@BACKBONES.register_module() +class ViTAdapter(TIMMVisionTransformer): + def __init__( + self, + pretrain_size=224, + num_heads=12, + conv_inplane=64, + n_points=4, + deform_num_heads=6, + init_values=0.0, + interaction_indexes=None, + with_cffn=True, + cffn_ratio=0.25, + deform_ratio=1.0, + add_vit_feature=True, + pretrained=None, + use_extra_extractor=True, + freeze_vit=False, + use_cls=True, + with_cp=False, + *args, + **kwargs + ): + + super().__init__(num_heads=num_heads, pretrained=pretrained, with_cp=with_cp, *args, **kwargs) + if freeze_vit: + for param in self.parameters(): + param.requires_grad = False + + # self.num_classes = 80 + self.use_cls = use_cls + if not self.use_cls: + self.cls_token = None + self.num_block = len(self.blocks) + self.pretrain_size = (pretrain_size, pretrain_size) + self.interaction_indexes = interaction_indexes + self.add_vit_feature = add_vit_feature + embed_dim = self.embed_dim + + block_fn = InteractionBlockWithCls if use_cls else InteractionBlock + + self.level_embed = nn.Parameter(torch.zeros(3, embed_dim)) + self.spm = SpatialPriorModule(inplanes=conv_inplane, embed_dim=embed_dim, with_cp=False) + self.interactions = nn.Sequential( + *[ + block_fn( + dim=embed_dim, + num_heads=deform_num_heads, + n_points=n_points, + init_values=init_values, + drop_path=self.drop_path_rate, + norm_layer=self.norm_layer, + with_cffn=with_cffn, + cffn_ratio=cffn_ratio, + deform_ratio=deform_ratio, + extra_extractor=((True if i == len(interaction_indexes) - 1 else False) and use_extra_extractor), + with_cp=with_cp, + ) + for i in range(len(interaction_indexes)) + ] + ) + self.up = nn.ConvTranspose2d(embed_dim, embed_dim, 2, 2) + self.norm1 = nn.SyncBatchNorm(embed_dim) + self.norm2 = nn.SyncBatchNorm(embed_dim) + self.norm3 = nn.SyncBatchNorm(embed_dim) + self.norm4 = nn.SyncBatchNorm(embed_dim) + + self.up.apply(self._init_weights) + self.spm.apply(self._init_weights) + self.interactions.apply(self._init_weights) + self.apply(self._init_deform_weights) + normal_(self.level_embed) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + torch.nn.init.trunc_normal_(m.weight, std=0.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm) or isinstance(m, nn.BatchNorm2d): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + elif isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d): + fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + fan_out //= m.groups + m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) + if m.bias is not None: + m.bias.data.zero_() + + def _get_pos_embed(self, pos_embed, H, W): + pos_embed = pos_embed.reshape( + 1, self.pretrain_size[0] // self.patch_size, self.pretrain_size[1] // self.patch_size, -1 + ).permute(0, 3, 1, 2) + pos_embed = ( + F.interpolate(pos_embed, size=(H, W), mode="bicubic", align_corners=False) + .reshape(1, -1, H * W) + .permute(0, 2, 1) + ) + return pos_embed + + def _init_deform_weights(self, m): + if isinstance(m, MSDeformAttn): + m._reset_parameters() + + def _add_level_embed(self, c2, c3, c4): + c2 = c2 + self.level_embed[0] + c3 = c3 + self.level_embed[1] + c4 = c4 + self.level_embed[2] + return c2, c3, c4 + + def forward(self, x): + deform_inputs1, deform_inputs2 = deform_inputs(x, self.patch_size) + + # SPM forward + c1, c2, c3, c4 = self.spm(x) + c2, c3, c4 = self._add_level_embed(c2, c3, c4) + c = torch.cat([c2, c3, c4], dim=1) + + # Patch Embedding forward + H_c, W_c = x.shape[2] // 16, x.shape[3] // 16 + x, H_toks, W_toks = self.patch_embed(x) + # print("H_toks, W_toks =", H_toks, W_toks) + bs, n, dim = x.shape + pos_embed = self._get_pos_embed(self.pos_embed[:, 1:], H_toks, W_toks) + if self.use_cls: + cls_token = self.cls_token.expand(x.shape[0], -1, -1) # stole cls_tokens impl from Phil Wang, thanks + x = torch.cat((cls_token, x), dim=1) + pos_embed = torch.cat((self.pos_embed[:, :1], pos_embed), dim=1) + x = self.pos_drop(x + pos_embed) + # For CLIP + x = self.norm_pre(x) + + # Interaction + if self.use_cls: + cls, x = ( + x[ + :, + :1, + ], + x[ + :, + 1:, + ], + ) + outs = list() + for i, layer in enumerate(self.interactions): + indexes = self.interaction_indexes[i] + if self.use_cls: + x, c, cls = layer( + x, + c, + cls, + self.blocks[indexes[0] : indexes[-1] + 1], + deform_inputs1, + deform_inputs2, + H_c, + W_c, + H_toks, + W_toks, + ) + else: + x, c = layer( + x, + c, + self.blocks[indexes[0] : indexes[-1] + 1], + deform_inputs1, + deform_inputs2, + H_c, + W_c, + H_toks, + W_toks, + ) + outs.append(x.transpose(1, 2).view(bs, dim, H_toks, W_toks).contiguous()) + + # Split & Reshape + c2 = c[:, 0 : c2.size(1), :] + c3 = c[:, c2.size(1) : c2.size(1) + c3.size(1), :] + c4 = c[:, c2.size(1) + c3.size(1) :, :] + + c2 = c2.transpose(1, 2).view(bs, dim, H_c * 2, W_c * 2).contiguous() + c3 = c3.transpose(1, 2).view(bs, dim, H_c, W_c).contiguous() + c4 = c4.transpose(1, 2).view(bs, dim, H_c // 2, W_c // 2).contiguous() + c1 = self.up(c2) + c1 + + if self.add_vit_feature: + x1, x2, x3, x4 = outs + + x1 = F.interpolate(x1, size=(4 * H_c, 4 * W_c), mode="bilinear", align_corners=False) + x2 = F.interpolate(x2, size=(2 * H_c, 2 * W_c), mode="bilinear", align_corners=False) + x3 = F.interpolate(x3, size=(1 * H_c, 1 * W_c), mode="bilinear", align_corners=False) + x4 = F.interpolate(x4, size=(H_c // 2, W_c // 2), mode="bilinear", align_corners=False) + # print(c1.shape, c2.shape, c3.shape, c4.shape, x1.shape, x2.shape, x3.shape, x4.shape, H_c, H_toks) + c1, c2, c3, c4 = c1 + x1, c2 + x2, c3 + x3, c4 + x4 + + # Final Norm + f1 = self.norm1(c1) + f2 = self.norm2(c2) + f3 = self.norm3(c3) + f4 = self.norm4(c4) + return [f1, f2, f3, f4] diff --git a/dinov2/dinov2/eval/segmentation_m2f/models/builder.py b/dinov2/dinov2/eval/segmentation_m2f/models/builder.py new file mode 100644 index 0000000000000000000000000000000000000000..d7cf7b919f6b0e8e00bde45bc244d9c29a36fed6 --- /dev/null +++ b/dinov2/dinov2/eval/segmentation_m2f/models/builder.py @@ -0,0 +1,25 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +from mmcv.utils import Registry + +TRANSFORMER = Registry("Transformer") +MASK_ASSIGNERS = Registry("mask_assigner") +MATCH_COST = Registry("match_cost") + + +def build_match_cost(cfg): + """Build Match Cost.""" + return MATCH_COST.build(cfg) + + +def build_assigner(cfg): + """Build Assigner.""" + return MASK_ASSIGNERS.build(cfg) + + +def build_transformer(cfg): + """Build Transformer.""" + return TRANSFORMER.build(cfg) diff --git a/dinov2/dinov2/eval/segmentation_m2f/models/decode_heads/__init__.py b/dinov2/dinov2/eval/segmentation_m2f/models/decode_heads/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..01f08b88950750337781fc671adfea2a935ea8fe --- /dev/null +++ b/dinov2/dinov2/eval/segmentation_m2f/models/decode_heads/__init__.py @@ -0,0 +1,6 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +from .mask2former_head import Mask2FormerHead diff --git a/dinov2/dinov2/eval/segmentation_m2f/models/decode_heads/mask2former_head.py b/dinov2/dinov2/eval/segmentation_m2f/models/decode_heads/mask2former_head.py new file mode 100644 index 0000000000000000000000000000000000000000..d1705fc444fa8d1583d88fca36d7fe1e060db9e7 --- /dev/null +++ b/dinov2/dinov2/eval/segmentation_m2f/models/decode_heads/mask2former_head.py @@ -0,0 +1,544 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +import copy + +import torch +import torch.nn as nn +import torch.nn.functional as F +from mmcv.cnn import Conv2d, build_plugin_layer, caffe2_xavier_init +from mmcv.cnn.bricks.transformer import build_positional_encoding, build_transformer_layer_sequence +from mmcv.ops import point_sample +from mmcv.runner import ModuleList, force_fp32 +from mmseg.models.builder import HEADS, build_loss +from mmseg.models.decode_heads.decode_head import BaseDecodeHead + +from ...core import build_sampler, multi_apply, reduce_mean +from ..builder import build_assigner +from ..utils import get_uncertain_point_coords_with_randomness + + +@HEADS.register_module() +class Mask2FormerHead(BaseDecodeHead): + """Implements the Mask2Former head. + + See `Masked-attention Mask Transformer for Universal Image + Segmentation `_ for details. + + Args: + in_channels (list[int]): Number of channels in the input feature map. + feat_channels (int): Number of channels for features. + out_channels (int): Number of channels for output. + num_things_classes (int): Number of things. + num_stuff_classes (int): Number of stuff. + num_queries (int): Number of query in Transformer decoder. + pixel_decoder (:obj:`mmcv.ConfigDict` | dict): Config for pixel + decoder. Defaults to None. + enforce_decoder_input_project (bool, optional): Whether to add + a layer to change the embed_dim of tranformer encoder in + pixel decoder to the embed_dim of transformer decoder. + Defaults to False. + transformer_decoder (:obj:`mmcv.ConfigDict` | dict): Config for + transformer decoder. Defaults to None. + positional_encoding (:obj:`mmcv.ConfigDict` | dict): Config for + transformer decoder position encoding. Defaults to None. + loss_cls (:obj:`mmcv.ConfigDict` | dict): Config of the classification + loss. Defaults to None. + loss_mask (:obj:`mmcv.ConfigDict` | dict): Config of the mask loss. + Defaults to None. + loss_dice (:obj:`mmcv.ConfigDict` | dict): Config of the dice loss. + Defaults to None. + train_cfg (:obj:`mmcv.ConfigDict` | dict): Training config of + Mask2Former head. + test_cfg (:obj:`mmcv.ConfigDict` | dict): Testing config of + Mask2Former head. + init_cfg (dict or list[dict], optional): Initialization config dict. + Defaults to None. + """ + + def __init__( + self, + in_channels, + feat_channels, + out_channels, + num_things_classes=80, + num_stuff_classes=53, + num_queries=100, + num_transformer_feat_level=3, + pixel_decoder=None, + enforce_decoder_input_project=False, + transformer_decoder=None, + positional_encoding=None, + loss_cls=None, + loss_mask=None, + loss_dice=None, + train_cfg=None, + test_cfg=None, + init_cfg=None, + **kwargs, + ): + super(Mask2FormerHead, self).__init__( + in_channels=in_channels, + channels=feat_channels, + num_classes=(num_things_classes + num_stuff_classes), + init_cfg=init_cfg, + input_transform="multiple_select", + **kwargs, + ) + self.num_things_classes = num_things_classes + self.num_stuff_classes = num_stuff_classes + self.num_classes = self.num_things_classes + self.num_stuff_classes + self.num_queries = num_queries + self.num_transformer_feat_level = num_transformer_feat_level + self.num_heads = transformer_decoder.transformerlayers.attn_cfgs.num_heads + self.num_transformer_decoder_layers = transformer_decoder.num_layers + assert pixel_decoder.encoder.transformerlayers.attn_cfgs.num_levels == num_transformer_feat_level + pixel_decoder_ = copy.deepcopy(pixel_decoder) + pixel_decoder_.update(in_channels=in_channels, feat_channels=feat_channels, out_channels=out_channels) + self.pixel_decoder = build_plugin_layer(pixel_decoder_)[1] + self.transformer_decoder = build_transformer_layer_sequence(transformer_decoder) + self.decoder_embed_dims = self.transformer_decoder.embed_dims + + self.decoder_input_projs = ModuleList() + # from low resolution to high resolution + for _ in range(num_transformer_feat_level): + if self.decoder_embed_dims != feat_channels or enforce_decoder_input_project: + self.decoder_input_projs.append(Conv2d(feat_channels, self.decoder_embed_dims, kernel_size=1)) + else: + self.decoder_input_projs.append(nn.Identity()) + self.decoder_positional_encoding = build_positional_encoding(positional_encoding) + self.query_embed = nn.Embedding(self.num_queries, feat_channels) + self.query_feat = nn.Embedding(self.num_queries, feat_channels) + # from low resolution to high resolution + self.level_embed = nn.Embedding(self.num_transformer_feat_level, feat_channels) + + self.cls_embed = nn.Linear(feat_channels, self.num_classes + 1) + self.mask_embed = nn.Sequential( + nn.Linear(feat_channels, feat_channels), + nn.ReLU(inplace=True), + nn.Linear(feat_channels, feat_channels), + nn.ReLU(inplace=True), + nn.Linear(feat_channels, out_channels), + ) + self.conv_seg = None # fix a bug here (conv_seg is not used) + + self.test_cfg = test_cfg + self.train_cfg = train_cfg + if train_cfg: + self.assigner = build_assigner(self.train_cfg.assigner) + self.sampler = build_sampler(self.train_cfg.sampler, context=self) + self.num_points = self.train_cfg.get("num_points", 12544) + self.oversample_ratio = self.train_cfg.get("oversample_ratio", 3.0) + self.importance_sample_ratio = self.train_cfg.get("importance_sample_ratio", 0.75) + + self.class_weight = loss_cls.class_weight + self.loss_cls = build_loss(loss_cls) + self.loss_mask = build_loss(loss_mask) + self.loss_dice = build_loss(loss_dice) + + def init_weights(self): + for m in self.decoder_input_projs: + if isinstance(m, Conv2d): + caffe2_xavier_init(m, bias=0) + + self.pixel_decoder.init_weights() + + for p in self.transformer_decoder.parameters(): + if p.dim() > 1: + nn.init.xavier_normal_(p) + + def get_targets(self, cls_scores_list, mask_preds_list, gt_labels_list, gt_masks_list, img_metas): + """Compute classification and mask targets for all images for a decoder + layer. + + Args: + cls_scores_list (list[Tensor]): Mask score logits from a single + decoder layer for all images. Each with shape [num_queries, + cls_out_channels]. + mask_preds_list (list[Tensor]): Mask logits from a single decoder + layer for all images. Each with shape [num_queries, h, w]. + gt_labels_list (list[Tensor]): Ground truth class indices for all + images. Each with shape (n, ), n is the sum of number of stuff + type and number of instance in a image. + gt_masks_list (list[Tensor]): Ground truth mask for each image, + each with shape (n, h, w). + img_metas (list[dict]): List of image meta information. + + Returns: + tuple[list[Tensor]]: a tuple containing the following targets. + + - labels_list (list[Tensor]): Labels of all images. + Each with shape [num_queries, ]. + - label_weights_list (list[Tensor]): Label weights of all + images.Each with shape [num_queries, ]. + - mask_targets_list (list[Tensor]): Mask targets of all images. + Each with shape [num_queries, h, w]. + - mask_weights_list (list[Tensor]): Mask weights of all images. + Each with shape [num_queries, ]. + - num_total_pos (int): Number of positive samples in all + images. + - num_total_neg (int): Number of negative samples in all + images. + """ + ( + labels_list, + label_weights_list, + mask_targets_list, + mask_weights_list, + pos_inds_list, + neg_inds_list, + ) = multi_apply( + self._get_target_single, cls_scores_list, mask_preds_list, gt_labels_list, gt_masks_list, img_metas + ) + + num_total_pos = sum((inds.numel() for inds in pos_inds_list)) + num_total_neg = sum((inds.numel() for inds in neg_inds_list)) + return (labels_list, label_weights_list, mask_targets_list, mask_weights_list, num_total_pos, num_total_neg) + + def _get_target_single(self, cls_score, mask_pred, gt_labels, gt_masks, img_metas): + """Compute classification and mask targets for one image. + + Args: + cls_score (Tensor): Mask score logits from a single decoder layer + for one image. Shape (num_queries, cls_out_channels). + mask_pred (Tensor): Mask logits for a single decoder layer for one + image. Shape (num_queries, h, w). + gt_labels (Tensor): Ground truth class indices for one image with + shape (num_gts, ). + gt_masks (Tensor): Ground truth mask for each image, each with + shape (num_gts, h, w). + img_metas (dict): Image informtation. + + Returns: + tuple[Tensor]: A tuple containing the following for one image. + + - labels (Tensor): Labels of each image. \ + shape (num_queries, ). + - label_weights (Tensor): Label weights of each image. \ + shape (num_queries, ). + - mask_targets (Tensor): Mask targets of each image. \ + shape (num_queries, h, w). + - mask_weights (Tensor): Mask weights of each image. \ + shape (num_queries, ). + - pos_inds (Tensor): Sampled positive indices for each \ + image. + - neg_inds (Tensor): Sampled negative indices for each \ + image. + """ + # sample points + num_queries = cls_score.shape[0] + num_gts = gt_labels.shape[0] + + point_coords = torch.rand((1, self.num_points, 2), device=cls_score.device) + # shape (num_queries, num_points) + mask_points_pred = point_sample(mask_pred.unsqueeze(1), point_coords.repeat(num_queries, 1, 1)).squeeze(1) + # shape (num_gts, num_points) + gt_points_masks = point_sample(gt_masks.unsqueeze(1).float(), point_coords.repeat(num_gts, 1, 1)).squeeze(1) + + # assign and sample + assign_result = self.assigner.assign(cls_score, mask_points_pred, gt_labels, gt_points_masks, img_metas) + sampling_result = self.sampler.sample(assign_result, mask_pred, gt_masks) + pos_inds = sampling_result.pos_inds + neg_inds = sampling_result.neg_inds + + # label target + labels = gt_labels.new_full((self.num_queries,), self.num_classes, dtype=torch.long) + labels[pos_inds] = gt_labels[sampling_result.pos_assigned_gt_inds] + label_weights = gt_labels.new_ones((self.num_queries,)) + + # mask target + mask_targets = gt_masks[sampling_result.pos_assigned_gt_inds] + mask_weights = mask_pred.new_zeros((self.num_queries,)) + mask_weights[pos_inds] = 1.0 + + return (labels, label_weights, mask_targets, mask_weights, pos_inds, neg_inds) + + def loss_single(self, cls_scores, mask_preds, gt_labels_list, gt_masks_list, img_metas): + """Loss function for outputs from a single decoder layer. + + Args: + cls_scores (Tensor): Mask score logits from a single decoder layer + for all images. Shape (batch_size, num_queries, + cls_out_channels). Note `cls_out_channels` should includes + background. + mask_preds (Tensor): Mask logits for a pixel decoder for all + images. Shape (batch_size, num_queries, h, w). + gt_labels_list (list[Tensor]): Ground truth class indices for each + image, each with shape (num_gts, ). + gt_masks_list (list[Tensor]): Ground truth mask for each image, + each with shape (num_gts, h, w). + img_metas (list[dict]): List of image meta information. + + Returns: + tuple[Tensor]: Loss components for outputs from a single \ + decoder layer. + """ + num_imgs = cls_scores.size(0) + cls_scores_list = [cls_scores[i] for i in range(num_imgs)] + mask_preds_list = [mask_preds[i] for i in range(num_imgs)] + ( + labels_list, + label_weights_list, + mask_targets_list, + mask_weights_list, + num_total_pos, + num_total_neg, + ) = self.get_targets(cls_scores_list, mask_preds_list, gt_labels_list, gt_masks_list, img_metas) + # shape (batch_size, num_queries) + labels = torch.stack(labels_list, dim=0) + # shape (batch_size, num_queries) + label_weights = torch.stack(label_weights_list, dim=0) + # shape (num_total_gts, h, w) + mask_targets = torch.cat(mask_targets_list, dim=0) + # shape (batch_size, num_queries) + mask_weights = torch.stack(mask_weights_list, dim=0) + + # classfication loss + # shape (batch_size * num_queries, ) + cls_scores = cls_scores.flatten(0, 1) + labels = labels.flatten(0, 1) + label_weights = label_weights.flatten(0, 1) + + class_weight = cls_scores.new_tensor(self.class_weight) + loss_cls = self.loss_cls(cls_scores, labels, label_weights, avg_factor=class_weight[labels].sum()) + + num_total_masks = reduce_mean(cls_scores.new_tensor([num_total_pos])) + num_total_masks = max(num_total_masks, 1) + + # extract positive ones + # shape (batch_size, num_queries, h, w) -> (num_total_gts, h, w) + mask_preds = mask_preds[mask_weights > 0] + + if mask_targets.shape[0] == 0: + # zero match + loss_dice = mask_preds.sum() + loss_mask = mask_preds.sum() + return loss_cls, loss_mask, loss_dice + + with torch.no_grad(): + points_coords = get_uncertain_point_coords_with_randomness( + mask_preds.unsqueeze(1), None, self.num_points, self.oversample_ratio, self.importance_sample_ratio + ) + # shape (num_total_gts, h, w) -> (num_total_gts, num_points) + mask_point_targets = point_sample(mask_targets.unsqueeze(1).float(), points_coords).squeeze(1) + # shape (num_queries, h, w) -> (num_queries, num_points) + mask_point_preds = point_sample(mask_preds.unsqueeze(1), points_coords).squeeze(1) + + # dice loss + loss_dice = self.loss_dice(mask_point_preds, mask_point_targets, avg_factor=num_total_masks) + + # mask loss + # shape (num_queries, num_points) -> (num_queries * num_points, ) + mask_point_preds = mask_point_preds.reshape(-1, 1) + # shape (num_total_gts, num_points) -> (num_total_gts * num_points, ) + mask_point_targets = mask_point_targets.reshape(-1) + loss_mask = self.loss_mask(mask_point_preds, mask_point_targets, avg_factor=num_total_masks * self.num_points) + + return loss_cls, loss_mask, loss_dice + + @force_fp32(apply_to=("all_cls_scores", "all_mask_preds")) + def loss(self, all_cls_scores, all_mask_preds, gt_labels_list, gt_masks_list, img_metas): + """Loss function. + + Args: + all_cls_scores (Tensor): Classification scores for all decoder + layers with shape [num_decoder, batch_size, num_queries, + cls_out_channels]. + all_mask_preds (Tensor): Mask scores for all decoder layers with + shape [num_decoder, batch_size, num_queries, h, w]. + gt_labels_list (list[Tensor]): Ground truth class indices for each + image with shape (n, ). n is the sum of number of stuff type + and number of instance in a image. + gt_masks_list (list[Tensor]): Ground truth mask for each image with + shape (n, h, w). + img_metas (list[dict]): List of image meta information. + + Returns: + dict[str, Tensor]: A dictionary of loss components. + """ + num_dec_layers = len(all_cls_scores) + all_gt_labels_list = [gt_labels_list for _ in range(num_dec_layers)] + all_gt_masks_list = [gt_masks_list for _ in range(num_dec_layers)] + img_metas_list = [img_metas for _ in range(num_dec_layers)] + losses_cls, losses_mask, losses_dice = multi_apply( + self.loss_single, all_cls_scores, all_mask_preds, all_gt_labels_list, all_gt_masks_list, img_metas_list + ) + + loss_dict = dict() + # loss from the last decoder layer + loss_dict["loss_cls"] = losses_cls[-1] + loss_dict["loss_mask"] = losses_mask[-1] + loss_dict["loss_dice"] = losses_dice[-1] + # loss from other decoder layers + num_dec_layer = 0 + for loss_cls_i, loss_mask_i, loss_dice_i in zip(losses_cls[:-1], losses_mask[:-1], losses_dice[:-1]): + loss_dict[f"d{num_dec_layer}.loss_cls"] = loss_cls_i + loss_dict[f"d{num_dec_layer}.loss_mask"] = loss_mask_i + loss_dict[f"d{num_dec_layer}.loss_dice"] = loss_dice_i + num_dec_layer += 1 + return loss_dict + + def forward_head(self, decoder_out, mask_feature, attn_mask_target_size): + """Forward for head part which is called after every decoder layer. + + Args: + decoder_out (Tensor): in shape (num_queries, batch_size, c). + mask_feature (Tensor): in shape (batch_size, c, h, w). + attn_mask_target_size (tuple[int, int]): target attention + mask size. + + Returns: + tuple: A tuple contain three elements. + + - cls_pred (Tensor): Classification scores in shape \ + (batch_size, num_queries, cls_out_channels). \ + Note `cls_out_channels` should includes background. + - mask_pred (Tensor): Mask scores in shape \ + (batch_size, num_queries,h, w). + - attn_mask (Tensor): Attention mask in shape \ + (batch_size * num_heads, num_queries, h, w). + """ + decoder_out = self.transformer_decoder.post_norm(decoder_out) + decoder_out = decoder_out.transpose(0, 1) + # shape (num_queries, batch_size, c) + cls_pred = self.cls_embed(decoder_out) + # shape (num_queries, batch_size, c) + mask_embed = self.mask_embed(decoder_out) + # shape (num_queries, batch_size, h, w) + mask_pred = torch.einsum("bqc,bchw->bqhw", mask_embed, mask_feature) + attn_mask = F.interpolate(mask_pred, attn_mask_target_size, mode="bilinear", align_corners=False) + # shape (num_queries, batch_size, h, w) -> + # (batch_size * num_head, num_queries, h, w) + attn_mask = attn_mask.flatten(2).unsqueeze(1).repeat((1, self.num_heads, 1, 1)).flatten(0, 1) + attn_mask = attn_mask.sigmoid() < 0.5 + attn_mask = attn_mask.detach() + + return cls_pred, mask_pred, attn_mask + + def forward(self, feats, img_metas): + """Forward function. + + Args: + feats (list[Tensor]): Multi scale Features from the + upstream network, each is a 4D-tensor. + img_metas (list[dict]): List of image information. + + Returns: + tuple: A tuple contains two elements. + + - cls_pred_list (list[Tensor)]: Classification logits \ + for each decoder layer. Each is a 3D-tensor with shape \ + (batch_size, num_queries, cls_out_channels). \ + Note `cls_out_channels` should includes background. + - mask_pred_list (list[Tensor]): Mask logits for each \ + decoder layer. Each with shape (batch_size, num_queries, \ + h, w). + """ + batch_size = len(img_metas) + mask_features, multi_scale_memorys = self.pixel_decoder(feats) + # multi_scale_memorys (from low resolution to high resolution) + decoder_inputs = [] + decoder_positional_encodings = [] + for i in range(self.num_transformer_feat_level): + decoder_input = self.decoder_input_projs[i](multi_scale_memorys[i]) + # shape (batch_size, c, h, w) -> (h*w, batch_size, c) + decoder_input = decoder_input.flatten(2).permute(2, 0, 1) + level_embed = self.level_embed.weight[i].view(1, 1, -1) + decoder_input = decoder_input + level_embed + # shape (batch_size, c, h, w) -> (h*w, batch_size, c) + mask = decoder_input.new_zeros((batch_size,) + multi_scale_memorys[i].shape[-2:], dtype=torch.bool) + decoder_positional_encoding = self.decoder_positional_encoding(mask) + decoder_positional_encoding = decoder_positional_encoding.flatten(2).permute(2, 0, 1) + decoder_inputs.append(decoder_input) + decoder_positional_encodings.append(decoder_positional_encoding) + # shape (num_queries, c) -> (num_queries, batch_size, c) + query_feat = self.query_feat.weight.unsqueeze(1).repeat((1, batch_size, 1)) + query_embed = self.query_embed.weight.unsqueeze(1).repeat((1, batch_size, 1)) + + cls_pred_list = [] + mask_pred_list = [] + cls_pred, mask_pred, attn_mask = self.forward_head(query_feat, mask_features, multi_scale_memorys[0].shape[-2:]) + cls_pred_list.append(cls_pred) + mask_pred_list.append(mask_pred) + + for i in range(self.num_transformer_decoder_layers): + level_idx = i % self.num_transformer_feat_level + # if a mask is all True(all background), then set it all False. + attn_mask[torch.where(attn_mask.sum(-1) == attn_mask.shape[-1])] = False + + # cross_attn + self_attn + layer = self.transformer_decoder.layers[i] + attn_masks = [attn_mask, None] + query_feat = layer( + query=query_feat, + key=decoder_inputs[level_idx], + value=decoder_inputs[level_idx], + query_pos=query_embed, + key_pos=decoder_positional_encodings[level_idx], + attn_masks=attn_masks, + query_key_padding_mask=None, + # here we do not apply masking on padded region + key_padding_mask=None, + ) + cls_pred, mask_pred, attn_mask = self.forward_head( + query_feat, mask_features, multi_scale_memorys[(i + 1) % self.num_transformer_feat_level].shape[-2:] + ) + + cls_pred_list.append(cls_pred) + mask_pred_list.append(mask_pred) + + return cls_pred_list, mask_pred_list + + def forward_train(self, x, img_metas, gt_semantic_seg, gt_labels, gt_masks): + """Forward function for training mode. + + Args: + x (list[Tensor]): Multi-level features from the upstream network, + each is a 4D-tensor. + img_metas (list[Dict]): List of image information. + gt_semantic_seg (list[tensor]):Each element is the ground truth + of semantic segmentation with the shape (N, H, W). + train_cfg (dict): The training config, which not been used in + maskformer. + gt_labels (list[Tensor]): Each element is ground truth labels of + each box, shape (num_gts,). + gt_masks (list[BitmapMasks]): Each element is masks of instances + of a image, shape (num_gts, h, w). + + Returns: + losses (dict[str, Tensor]): a dictionary of loss components + """ + + # forward + all_cls_scores, all_mask_preds = self(x, img_metas) + + # loss + losses = self.loss(all_cls_scores, all_mask_preds, gt_labels, gt_masks, img_metas) + + return losses + + def forward_test(self, inputs, img_metas, test_cfg): + """Test segment without test-time aumengtation. + + Only the output of last decoder layers was used. + + Args: + inputs (list[Tensor]): Multi-level features from the + upstream network, each is a 4D-tensor. + img_metas (list[dict]): List of image information. + test_cfg (dict): Testing config. + + Returns: + seg_mask (Tensor): Predicted semantic segmentation logits. + """ + all_cls_scores, all_mask_preds = self(inputs, img_metas) + cls_score, mask_pred = all_cls_scores[-1], all_mask_preds[-1] + ori_h, ori_w, _ = img_metas[0]["ori_shape"] + + # semantic inference + cls_score = F.softmax(cls_score, dim=-1)[..., :-1] + mask_pred = mask_pred.sigmoid() + seg_mask = torch.einsum("bqc,bqhw->bchw", cls_score, mask_pred) + return seg_mask diff --git a/dinov2/dinov2/eval/segmentation_m2f/models/losses/__init__.py b/dinov2/dinov2/eval/segmentation_m2f/models/losses/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..229a887817372f4991b32354180592cfb236d728 --- /dev/null +++ b/dinov2/dinov2/eval/segmentation_m2f/models/losses/__init__.py @@ -0,0 +1,8 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +from .cross_entropy_loss import CrossEntropyLoss, binary_cross_entropy, cross_entropy, mask_cross_entropy +from .dice_loss import DiceLoss +from .match_costs import ClassificationCost, CrossEntropyLossCost, DiceCost diff --git a/dinov2/dinov2/eval/segmentation_m2f/models/losses/cross_entropy_loss.py b/dinov2/dinov2/eval/segmentation_m2f/models/losses/cross_entropy_loss.py new file mode 100644 index 0000000000000000000000000000000000000000..0a1f9dd4aa52ebe94cc527db36b1c7fa2f53813e --- /dev/null +++ b/dinov2/dinov2/eval/segmentation_m2f/models/losses/cross_entropy_loss.py @@ -0,0 +1,279 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +import warnings + +import torch +import torch.nn as nn +import torch.nn.functional as F +from mmseg.models.builder import LOSSES +from mmseg.models.losses.utils import get_class_weight, weight_reduce_loss + + +def cross_entropy( + pred, + label, + weight=None, + class_weight=None, + reduction="mean", + avg_factor=None, + ignore_index=-100, + avg_non_ignore=False, +): + """cross_entropy. The wrapper function for :func:`F.cross_entropy` + + Args: + pred (torch.Tensor): The prediction with shape (N, 1). + label (torch.Tensor): The learning label of the prediction. + weight (torch.Tensor, optional): Sample-wise loss weight. + Default: None. + class_weight (list[float], optional): The weight for each class. + Default: None. + reduction (str, optional): The method used to reduce the loss. + Options are 'none', 'mean' and 'sum'. Default: 'mean'. + avg_factor (int, optional): Average factor that is used to average + the loss. Default: None. + ignore_index (int): Specifies a target value that is ignored and + does not contribute to the input gradients. When + ``avg_non_ignore `` is ``True``, and the ``reduction`` is + ``''mean''``, the loss is averaged over non-ignored targets. + Defaults: -100. + avg_non_ignore (bool): The flag decides to whether the loss is + only averaged over non-ignored targets. Default: False. + `New in version 0.23.0.` + """ + + # class_weight is a manual rescaling weight given to each class. + # If given, has to be a Tensor of size C element-wise losses + loss = F.cross_entropy(pred, label, weight=class_weight, reduction="none", ignore_index=ignore_index) + + # apply weights and do the reduction + # average loss over non-ignored elements + # pytorch's official cross_entropy average loss over non-ignored elements + # refer to https://github.com/pytorch/pytorch/blob/56b43f4fec1f76953f15a627694d4bba34588969/torch/nn/functional.py#L2660 # noqa + if (avg_factor is None) and avg_non_ignore and reduction == "mean": + avg_factor = label.numel() - (label == ignore_index).sum().item() + if weight is not None: + weight = weight.float() + loss = weight_reduce_loss(loss, weight=weight, reduction=reduction, avg_factor=avg_factor) + + return loss + + +def _expand_onehot_labels(labels, label_weights, target_shape, ignore_index): + """Expand onehot labels to match the size of prediction.""" + bin_labels = labels.new_zeros(target_shape) + valid_mask = (labels >= 0) & (labels != ignore_index) + inds = torch.nonzero(valid_mask, as_tuple=True) + + if inds[0].numel() > 0: + if labels.dim() == 3: + bin_labels[inds[0], labels[valid_mask], inds[1], inds[2]] = 1 + else: + bin_labels[inds[0], labels[valid_mask]] = 1 + + valid_mask = valid_mask.unsqueeze(1).expand(target_shape).float() + + if label_weights is None: + bin_label_weights = valid_mask + else: + bin_label_weights = label_weights.unsqueeze(1).expand(target_shape) + bin_label_weights = bin_label_weights * valid_mask + + return bin_labels, bin_label_weights, valid_mask + + +def binary_cross_entropy( + pred, + label, + weight=None, + reduction="mean", + avg_factor=None, + class_weight=None, + ignore_index=-100, + avg_non_ignore=False, + **kwargs, +): + """Calculate the binary CrossEntropy loss. + + Args: + pred (torch.Tensor): The prediction with shape (N, 1). + label (torch.Tensor): The learning label of the prediction. + Note: In bce loss, label < 0 is invalid. + weight (torch.Tensor, optional): Sample-wise loss weight. + reduction (str, optional): The method used to reduce the loss. + Options are "none", "mean" and "sum". + avg_factor (int, optional): Average factor that is used to average + the loss. Defaults to None. + class_weight (list[float], optional): The weight for each class. + ignore_index (int): The label index to be ignored. Default: -100. + avg_non_ignore (bool): The flag decides to whether the loss is + only averaged over non-ignored targets. Default: False. + `New in version 0.23.0.` + + Returns: + torch.Tensor: The calculated loss + """ + if pred.size(1) == 1: + # For binary class segmentation, the shape of pred is + # [N, 1, H, W] and that of label is [N, H, W]. + assert label.max() <= 1, "For pred with shape [N, 1, H, W], its label must have at " "most 2 classes" + pred = pred.squeeze() + if pred.dim() != label.dim(): + assert (pred.dim() == 2 and label.dim() == 1) or (pred.dim() == 4 and label.dim() == 3), ( + "Only pred shape [N, C], label shape [N] or pred shape [N, C, " "H, W], label shape [N, H, W] are supported" + ) + # `weight` returned from `_expand_onehot_labels` + # has been treated for valid (non-ignore) pixels + label, weight, valid_mask = _expand_onehot_labels(label, weight, pred.shape, ignore_index) + else: + # should mask out the ignored elements + valid_mask = ((label >= 0) & (label != ignore_index)).float() + if weight is not None: + weight = weight * valid_mask + else: + weight = valid_mask + # average loss over non-ignored and valid elements + if reduction == "mean" and avg_factor is None and avg_non_ignore: + avg_factor = valid_mask.sum().item() + + loss = F.binary_cross_entropy_with_logits(pred, label.float(), pos_weight=class_weight, reduction="none") + # do the reduction for the weighted loss + loss = weight_reduce_loss(loss, weight, reduction=reduction, avg_factor=avg_factor) + + return loss + + +def mask_cross_entropy( + pred, target, label, reduction="mean", avg_factor=None, class_weight=None, ignore_index=None, **kwargs +): + """Calculate the CrossEntropy loss for masks. + + Args: + pred (torch.Tensor): The prediction with shape (N, C), C is the number + of classes. + target (torch.Tensor): The learning label of the prediction. + label (torch.Tensor): ``label`` indicates the class label of the mask' + corresponding object. This will be used to select the mask in the + of the class which the object belongs to when the mask prediction + if not class-agnostic. + reduction (str, optional): The method used to reduce the loss. + Options are "none", "mean" and "sum". + avg_factor (int, optional): Average factor that is used to average + the loss. Defaults to None. + class_weight (list[float], optional): The weight for each class. + ignore_index (None): Placeholder, to be consistent with other loss. + Default: None. + + Returns: + torch.Tensor: The calculated loss + """ + assert ignore_index is None, "BCE loss does not support ignore_index" + assert reduction == "mean" and avg_factor is None + num_rois = pred.size()[0] + inds = torch.arange(0, num_rois, dtype=torch.long, device=pred.device) + pred_slice = pred[inds, label].squeeze(1) + return F.binary_cross_entropy_with_logits(pred_slice, target, weight=class_weight, reduction="mean")[None] + + +@LOSSES.register_module(force=True) +class CrossEntropyLoss(nn.Module): + """CrossEntropyLoss. + + Args: + use_sigmoid (bool, optional): Whether the prediction uses sigmoid + of softmax. Defaults to False. + use_mask (bool, optional): Whether to use mask cross entropy loss. + Defaults to False. + reduction (str, optional): . Defaults to 'mean'. + Options are "none", "mean" and "sum". + class_weight (list[float] | str, optional): Weight of each class. If in + str format, read them from a file. Defaults to None. + loss_weight (float, optional): Weight of the loss. Defaults to 1.0. + loss_name (str, optional): Name of the loss item. If you want this loss + item to be included into the backward graph, `loss_` must be the + prefix of the name. Defaults to 'loss_ce'. + avg_non_ignore (bool): The flag decides to whether the loss is + only averaged over non-ignored targets. Default: False. + `New in version 0.23.0.` + """ + + def __init__( + self, + use_sigmoid=False, + use_mask=False, + reduction="mean", + class_weight=None, + loss_weight=1.0, + loss_name="loss_ce", + avg_non_ignore=False, + ): + super(CrossEntropyLoss, self).__init__() + assert (use_sigmoid is False) or (use_mask is False) + self.use_sigmoid = use_sigmoid + self.use_mask = use_mask + self.reduction = reduction + self.loss_weight = loss_weight + self.class_weight = get_class_weight(class_weight) + self.avg_non_ignore = avg_non_ignore + if not self.avg_non_ignore and self.reduction == "mean": + warnings.warn( + "Default ``avg_non_ignore`` is False, if you would like to " + "ignore the certain label and average loss over non-ignore " + "labels, which is the same with PyTorch official " + "cross_entropy, set ``avg_non_ignore=True``." + ) + + if self.use_sigmoid: + self.cls_criterion = binary_cross_entropy + elif self.use_mask: + self.cls_criterion = mask_cross_entropy + else: + self.cls_criterion = cross_entropy + self._loss_name = loss_name + + def extra_repr(self): + """Extra repr.""" + s = f"avg_non_ignore={self.avg_non_ignore}" + return s + + def forward( + self, cls_score, label, weight=None, avg_factor=None, reduction_override=None, ignore_index=-100, **kwargs + ): + """Forward function.""" + assert reduction_override in (None, "none", "mean", "sum") + reduction = reduction_override if reduction_override else self.reduction + if self.class_weight is not None: + class_weight = cls_score.new_tensor(self.class_weight) + else: + class_weight = None + # Note: for BCE loss, label < 0 is invalid. + loss_cls = self.loss_weight * self.cls_criterion( + cls_score, + label, + weight, + class_weight=class_weight, + reduction=reduction, + avg_factor=avg_factor, + avg_non_ignore=self.avg_non_ignore, + ignore_index=ignore_index, + **kwargs, + ) + return loss_cls + + @property + def loss_name(self): + """Loss Name. + + This function must be implemented and will return the name of this + loss function. This name will be used to combine different loss items + by simple sum operation. In addition, if you want this loss item to be + included into the backward graph, `loss_` must be the prefix of the + name. + + Returns: + str: The name of this loss item. + """ + return self._loss_name diff --git a/dinov2/dinov2/eval/segmentation_m2f/models/losses/dice_loss.py b/dinov2/dinov2/eval/segmentation_m2f/models/losses/dice_loss.py new file mode 100644 index 0000000000000000000000000000000000000000..1bc5ba893c502861032ed531283f225e183eb693 --- /dev/null +++ b/dinov2/dinov2/eval/segmentation_m2f/models/losses/dice_loss.py @@ -0,0 +1,153 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +import torch +import torch.nn as nn +from mmseg.models.builder import LOSSES +from mmseg.models.losses.utils import weight_reduce_loss + + +def dice_loss(pred, target, weight=None, eps=1e-3, reduction="mean", avg_factor=None): + """Calculate dice loss, which is proposed in + `V-Net: Fully Convolutional Neural Networks for Volumetric + Medical Image Segmentation `_. + + Args: + pred (torch.Tensor): The prediction, has a shape (n, *) + target (torch.Tensor): The learning label of the prediction, + shape (n, *), same shape of pred. + weight (torch.Tensor, optional): The weight of loss for each + prediction, has a shape (n,). Defaults to None. + eps (float): Avoid dividing by zero. Default: 1e-3. + reduction (str, optional): The method used to reduce the loss into + a scalar. Defaults to 'mean'. + Options are "none", "mean" and "sum". + avg_factor (int, optional): Average factor that is used to average + the loss. Defaults to None. + """ + + input = pred.flatten(1) + target = target.flatten(1).float() + + a = torch.sum(input * target, 1) + b = torch.sum(input * input, 1) + eps + c = torch.sum(target * target, 1) + eps + d = (2 * a) / (b + c) + loss = 1 - d + if weight is not None: + assert weight.ndim == loss.ndim + assert len(weight) == len(pred) + loss = weight_reduce_loss(loss, weight, reduction, avg_factor) + return loss + + +def naive_dice_loss(pred, target, weight=None, eps=1e-3, reduction="mean", avg_factor=None): + """Calculate naive dice loss, the coefficient in the denominator is the + first power instead of the second power. + + Args: + pred (torch.Tensor): The prediction, has a shape (n, *) + target (torch.Tensor): The learning label of the prediction, + shape (n, *), same shape of pred. + weight (torch.Tensor, optional): The weight of loss for each + prediction, has a shape (n,). Defaults to None. + eps (float): Avoid dividing by zero. Default: 1e-3. + reduction (str, optional): The method used to reduce the loss into + a scalar. Defaults to 'mean'. + Options are "none", "mean" and "sum". + avg_factor (int, optional): Average factor that is used to average + the loss. Defaults to None. + """ + input = pred.flatten(1) + target = target.flatten(1).float() + + a = torch.sum(input * target, 1) + b = torch.sum(input, 1) + c = torch.sum(target, 1) + d = (2 * a + eps) / (b + c + eps) + loss = 1 - d + if weight is not None: + assert weight.ndim == loss.ndim + assert len(weight) == len(pred) + loss = weight_reduce_loss(loss, weight, reduction, avg_factor) + return loss + + +@LOSSES.register_module(force=True) +class DiceLoss(nn.Module): + def __init__(self, use_sigmoid=True, activate=True, reduction="mean", naive_dice=False, loss_weight=1.0, eps=1e-3): + """Dice Loss, there are two forms of dice loss is supported: + + - the one proposed in `V-Net: Fully Convolutional Neural + Networks for Volumetric Medical Image Segmentation + `_. + - the dice loss in which the power of the number in the + denominator is the first power instead of the second + power. + + Args: + use_sigmoid (bool, optional): Whether to the prediction is + used for sigmoid or softmax. Defaults to True. + activate (bool): Whether to activate the predictions inside, + this will disable the inside sigmoid operation. + Defaults to True. + reduction (str, optional): The method used + to reduce the loss. Options are "none", + "mean" and "sum". Defaults to 'mean'. + naive_dice (bool, optional): If false, use the dice + loss defined in the V-Net paper, otherwise, use the + naive dice loss in which the power of the number in the + denominator is the first power instead of the second + power.Defaults to False. + loss_weight (float, optional): Weight of loss. Defaults to 1.0. + eps (float): Avoid dividing by zero. Defaults to 1e-3. + """ + + super(DiceLoss, self).__init__() + self.use_sigmoid = use_sigmoid + self.reduction = reduction + self.naive_dice = naive_dice + self.loss_weight = loss_weight + self.eps = eps + self.activate = activate + + def forward(self, pred, target, weight=None, reduction_override=None, avg_factor=None): + """Forward function. + + Args: + pred (torch.Tensor): The prediction, has a shape (n, *). + target (torch.Tensor): The label of the prediction, + shape (n, *), same shape of pred. + weight (torch.Tensor, optional): The weight of loss for each + prediction, has a shape (n,). Defaults to None. + avg_factor (int, optional): Average factor that is used to average + the loss. Defaults to None. + reduction_override (str, optional): The reduction method used to + override the original reduction method of the loss. + Options are "none", "mean" and "sum". + + Returns: + torch.Tensor: The calculated loss + """ + + assert reduction_override in (None, "none", "mean", "sum") + reduction = reduction_override if reduction_override else self.reduction + + if self.activate: + if self.use_sigmoid: + pred = pred.sigmoid() + else: + raise NotImplementedError + + if self.naive_dice: + loss = self.loss_weight * naive_dice_loss( + pred, target, weight, eps=self.eps, reduction=reduction, avg_factor=avg_factor + ) + else: + loss = self.loss_weight * dice_loss( + pred, target, weight, eps=self.eps, reduction=reduction, avg_factor=avg_factor + ) + + return loss diff --git a/dinov2/dinov2/eval/segmentation_m2f/models/losses/match_costs.py b/dinov2/dinov2/eval/segmentation_m2f/models/losses/match_costs.py new file mode 100644 index 0000000000000000000000000000000000000000..4917d2a939c01398dd49c0d90b06f4c37d283ce0 --- /dev/null +++ b/dinov2/dinov2/eval/segmentation_m2f/models/losses/match_costs.py @@ -0,0 +1,153 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +import torch +import torch.nn.functional as F + +from ..builder import MATCH_COST + + +@MATCH_COST.register_module() +class ClassificationCost: + """ClsSoftmaxCost.Borrow from + mmdet.core.bbox.match_costs.match_cost.ClassificationCost. + + Args: + weight (int | float, optional): loss_weight + + Examples: + >>> import torch + >>> self = ClassificationCost() + >>> cls_pred = torch.rand(4, 3) + >>> gt_labels = torch.tensor([0, 1, 2]) + >>> factor = torch.tensor([10, 8, 10, 8]) + >>> self(cls_pred, gt_labels) + tensor([[-0.3430, -0.3525, -0.3045], + [-0.3077, -0.2931, -0.3992], + [-0.3664, -0.3455, -0.2881], + [-0.3343, -0.2701, -0.3956]]) + """ + + def __init__(self, weight=1.0): + self.weight = weight + + def __call__(self, cls_pred, gt_labels): + """ + Args: + cls_pred (Tensor): Predicted classification logits, shape + [num_query, num_class]. + gt_labels (Tensor): Label of `gt_bboxes`, shape (num_gt,). + + Returns: + torch.Tensor: cls_cost value with weight + """ + # Following the official DETR repo, contrary to the loss that + # NLL is used, we approximate it in 1 - cls_score[gt_label]. + # The 1 is a constant that doesn't change the matching, + # so it can be omitted. + cls_score = cls_pred.softmax(-1) + cls_cost = -cls_score[:, gt_labels] + return cls_cost * self.weight + + +@MATCH_COST.register_module() +class DiceCost: + """Cost of mask assignments based on dice losses. + + Args: + weight (int | float, optional): loss_weight. Defaults to 1. + pred_act (bool, optional): Whether to apply sigmoid to mask_pred. + Defaults to False. + eps (float, optional): default 1e-12. + """ + + def __init__(self, weight=1.0, pred_act=False, eps=1e-3): + self.weight = weight + self.pred_act = pred_act + self.eps = eps + + def binary_mask_dice_loss(self, mask_preds, gt_masks): + """ + Args: + mask_preds (Tensor): Mask prediction in shape (N1, H, W). + gt_masks (Tensor): Ground truth in shape (N2, H, W) + store 0 or 1, 0 for negative class and 1 for + positive class. + + Returns: + Tensor: Dice cost matrix in shape (N1, N2). + """ + mask_preds = mask_preds.reshape((mask_preds.shape[0], -1)) + gt_masks = gt_masks.reshape((gt_masks.shape[0], -1)).float() + numerator = 2 * torch.einsum("nc,mc->nm", mask_preds, gt_masks) + denominator = mask_preds.sum(-1)[:, None] + gt_masks.sum(-1)[None, :] + loss = 1 - (numerator + self.eps) / (denominator + self.eps) + return loss + + def __call__(self, mask_preds, gt_masks): + """ + Args: + mask_preds (Tensor): Mask prediction logits in shape (N1, H, W). + gt_masks (Tensor): Ground truth in shape (N2, H, W). + + Returns: + Tensor: Dice cost matrix in shape (N1, N2). + """ + if self.pred_act: + mask_preds = mask_preds.sigmoid() + dice_cost = self.binary_mask_dice_loss(mask_preds, gt_masks) + return dice_cost * self.weight + + +@MATCH_COST.register_module() +class CrossEntropyLossCost: + """CrossEntropyLossCost. + + Args: + weight (int | float, optional): loss weight. Defaults to 1. + use_sigmoid (bool, optional): Whether the prediction uses sigmoid + of softmax. Defaults to True. + """ + + def __init__(self, weight=1.0, use_sigmoid=True): + assert use_sigmoid, "use_sigmoid = False is not supported yet." + self.weight = weight + self.use_sigmoid = use_sigmoid + + def _binary_cross_entropy(self, cls_pred, gt_labels): + """ + Args: + cls_pred (Tensor): The prediction with shape (num_query, 1, *) or + (num_query, *). + gt_labels (Tensor): The learning label of prediction with + shape (num_gt, *). + Returns: + Tensor: Cross entropy cost matrix in shape (num_query, num_gt). + """ + cls_pred = cls_pred.flatten(1).float() + gt_labels = gt_labels.flatten(1).float() + n = cls_pred.shape[1] + pos = F.binary_cross_entropy_with_logits(cls_pred, torch.ones_like(cls_pred), reduction="none") + neg = F.binary_cross_entropy_with_logits(cls_pred, torch.zeros_like(cls_pred), reduction="none") + cls_cost = torch.einsum("nc,mc->nm", pos, gt_labels) + torch.einsum("nc,mc->nm", neg, 1 - gt_labels) + cls_cost = cls_cost / n + + return cls_cost + + def __call__(self, cls_pred, gt_labels): + """ + Args: + cls_pred (Tensor): Predicted classification logits. + gt_labels (Tensor): Labels. + Returns: + Tensor: Cross entropy cost matrix with weight in + shape (num_query, num_gt). + """ + if self.use_sigmoid: + cls_cost = self._binary_cross_entropy(cls_pred, gt_labels) + else: + raise NotImplementedError + + return cls_cost * self.weight diff --git a/dinov2/dinov2/eval/segmentation_m2f/models/plugins/__init__.py b/dinov2/dinov2/eval/segmentation_m2f/models/plugins/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..81a60db4de31238cb38e078683e5ca265839fe60 --- /dev/null +++ b/dinov2/dinov2/eval/segmentation_m2f/models/plugins/__init__.py @@ -0,0 +1,6 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +from .msdeformattn_pixel_decoder import MSDeformAttnPixelDecoder diff --git a/dinov2/dinov2/eval/segmentation_m2f/models/plugins/msdeformattn_pixel_decoder.py b/dinov2/dinov2/eval/segmentation_m2f/models/plugins/msdeformattn_pixel_decoder.py new file mode 100644 index 0000000000000000000000000000000000000000..db1947175917f73f3f24184cb09c78e092d46ef8 --- /dev/null +++ b/dinov2/dinov2/eval/segmentation_m2f/models/plugins/msdeformattn_pixel_decoder.py @@ -0,0 +1,242 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +import torch +import torch.nn as nn +import torch.nn.functional as F +from mmcv.cnn import PLUGIN_LAYERS, Conv2d, ConvModule, caffe2_xavier_init, normal_init, xavier_init +from mmcv.cnn.bricks.transformer import build_positional_encoding, build_transformer_layer_sequence +from mmcv.runner import BaseModule, ModuleList + +from ...core.anchor import MlvlPointGenerator +from ..utils.transformer import MultiScaleDeformableAttention + + +@PLUGIN_LAYERS.register_module() +class MSDeformAttnPixelDecoder(BaseModule): + """Pixel decoder with multi-scale deformable attention. + + Args: + in_channels (list[int] | tuple[int]): Number of channels in the + input feature maps. + strides (list[int] | tuple[int]): Output strides of feature from + backbone. + feat_channels (int): Number of channels for feature. + out_channels (int): Number of channels for output. + num_outs (int): Number of output scales. + norm_cfg (:obj:`mmcv.ConfigDict` | dict): Config for normalization. + Defaults to dict(type='GN', num_groups=32). + act_cfg (:obj:`mmcv.ConfigDict` | dict): Config for activation. + Defaults to dict(type='ReLU'). + encoder (:obj:`mmcv.ConfigDict` | dict): Config for transformer + encoder. Defaults to `DetrTransformerEncoder`. + positional_encoding (:obj:`mmcv.ConfigDict` | dict): Config for + transformer encoder position encoding. Defaults to + dict(type='SinePositionalEncoding', num_feats=128, + normalize=True). + init_cfg (:obj:`mmcv.ConfigDict` | dict): Initialization config dict. + """ + + def __init__( + self, + in_channels=[256, 512, 1024, 2048], + strides=[4, 8, 16, 32], + feat_channels=256, + out_channels=256, + num_outs=3, + norm_cfg=dict(type="GN", num_groups=32), + act_cfg=dict(type="ReLU"), + encoder=dict( + type="DetrTransformerEncoder", + num_layers=6, + transformerlayers=dict( + type="BaseTransformerLayer", + attn_cfgs=dict( + type="MultiScaleDeformableAttention", + embed_dims=256, + num_heads=8, + num_levels=3, + num_points=4, + im2col_step=64, + dropout=0.0, + batch_first=False, + norm_cfg=None, + init_cfg=None, + ), + feedforward_channels=1024, + ffn_dropout=0.0, + operation_order=("self_attn", "norm", "ffn", "norm"), + ), + init_cfg=None, + ), + positional_encoding=dict(type="SinePositionalEncoding", num_feats=128, normalize=True), + init_cfg=None, + ): + super().__init__(init_cfg=init_cfg) + self.strides = strides + self.num_input_levels = len(in_channels) + self.num_encoder_levels = encoder.transformerlayers.attn_cfgs.num_levels + assert self.num_encoder_levels >= 1, "num_levels in attn_cfgs must be at least one" + input_conv_list = [] + # from top to down (low to high resolution) + for i in range(self.num_input_levels - 1, self.num_input_levels - self.num_encoder_levels - 1, -1): + input_conv = ConvModule( + in_channels[i], feat_channels, kernel_size=1, norm_cfg=norm_cfg, act_cfg=None, bias=True + ) + input_conv_list.append(input_conv) + self.input_convs = ModuleList(input_conv_list) + + self.encoder = build_transformer_layer_sequence(encoder) + self.postional_encoding = build_positional_encoding(positional_encoding) + # high resolution to low resolution + self.level_encoding = nn.Embedding(self.num_encoder_levels, feat_channels) + + # fpn-like structure + self.lateral_convs = ModuleList() + self.output_convs = ModuleList() + self.use_bias = norm_cfg is None + # from top to down (low to high resolution) + # fpn for the rest features that didn't pass in encoder + for i in range(self.num_input_levels - self.num_encoder_levels - 1, -1, -1): + lateral_conv = ConvModule( + in_channels[i], feat_channels, kernel_size=1, bias=self.use_bias, norm_cfg=norm_cfg, act_cfg=None + ) + output_conv = ConvModule( + feat_channels, + feat_channels, + kernel_size=3, + stride=1, + padding=1, + bias=self.use_bias, + norm_cfg=norm_cfg, + act_cfg=act_cfg, + ) + self.lateral_convs.append(lateral_conv) + self.output_convs.append(output_conv) + + self.mask_feature = Conv2d(feat_channels, out_channels, kernel_size=1, stride=1, padding=0) + + self.num_outs = num_outs + self.point_generator = MlvlPointGenerator(strides) + + def init_weights(self): + """Initialize weights.""" + for i in range(0, self.num_encoder_levels): + xavier_init(self.input_convs[i].conv, gain=1, bias=0, distribution="uniform") + + for i in range(0, self.num_input_levels - self.num_encoder_levels): + caffe2_xavier_init(self.lateral_convs[i].conv, bias=0) + caffe2_xavier_init(self.output_convs[i].conv, bias=0) + + caffe2_xavier_init(self.mask_feature, bias=0) + + normal_init(self.level_encoding, mean=0, std=1) + for p in self.encoder.parameters(): + if p.dim() > 1: + nn.init.xavier_normal_(p) + + # init_weights defined in MultiScaleDeformableAttention + for layer in self.encoder.layers: + for attn in layer.attentions: + if isinstance(attn, MultiScaleDeformableAttention): + attn.init_weights() + + def forward(self, feats): + """ + Args: + feats (list[Tensor]): Feature maps of each level. Each has + shape of (batch_size, c, h, w). + + Returns: + tuple: A tuple containing the following: + + - mask_feature (Tensor): shape (batch_size, c, h, w). + - multi_scale_features (list[Tensor]): Multi scale \ + features, each in shape (batch_size, c, h, w). + """ + # generate padding mask for each level, for each image + batch_size = feats[0].shape[0] + encoder_input_list = [] + padding_mask_list = [] + level_positional_encoding_list = [] + spatial_shapes = [] + reference_points_list = [] + for i in range(self.num_encoder_levels): + level_idx = self.num_input_levels - i - 1 + feat = feats[level_idx] + feat_projected = self.input_convs[i](feat) + h, w = feat.shape[-2:] + + # no padding + padding_mask_resized = feat.new_zeros((batch_size,) + feat.shape[-2:], dtype=torch.bool) + pos_embed = self.postional_encoding(padding_mask_resized) + level_embed = self.level_encoding.weight[i] + level_pos_embed = level_embed.view(1, -1, 1, 1) + pos_embed + # (h_i * w_i, 2) + reference_points = self.point_generator.single_level_grid_priors( + feat.shape[-2:], level_idx, device=feat.device + ) + # normalize + factor = feat.new_tensor([[w, h]]) * self.strides[level_idx] + reference_points = reference_points / factor + + # shape (batch_size, c, h_i, w_i) -> (h_i * w_i, batch_size, c) + feat_projected = feat_projected.flatten(2).permute(2, 0, 1) + level_pos_embed = level_pos_embed.flatten(2).permute(2, 0, 1) + padding_mask_resized = padding_mask_resized.flatten(1) + + encoder_input_list.append(feat_projected) + padding_mask_list.append(padding_mask_resized) + level_positional_encoding_list.append(level_pos_embed) + spatial_shapes.append(feat.shape[-2:]) + reference_points_list.append(reference_points) + # shape (batch_size, total_num_query), + # total_num_query=sum([., h_i * w_i,.]) + padding_masks = torch.cat(padding_mask_list, dim=1) + # shape (total_num_query, batch_size, c) + encoder_inputs = torch.cat(encoder_input_list, dim=0) + level_positional_encodings = torch.cat(level_positional_encoding_list, dim=0) + device = encoder_inputs.device + # shape (num_encoder_levels, 2), from low + # resolution to high resolution + spatial_shapes = torch.as_tensor(spatial_shapes, dtype=torch.long, device=device) + # shape (0, h_0*w_0, h_0*w_0+h_1*w_1, ...) + level_start_index = torch.cat((spatial_shapes.new_zeros((1,)), spatial_shapes.prod(1).cumsum(0)[:-1])) + reference_points = torch.cat(reference_points_list, dim=0) + reference_points = reference_points[None, :, None].repeat(batch_size, 1, self.num_encoder_levels, 1) + valid_radios = reference_points.new_ones((batch_size, self.num_encoder_levels, 2)) + # shape (num_total_query, batch_size, c) + memory = self.encoder( + query=encoder_inputs, + key=None, + value=None, + query_pos=level_positional_encodings, + key_pos=None, + attn_masks=None, + key_padding_mask=None, + query_key_padding_mask=padding_masks, + spatial_shapes=spatial_shapes, + reference_points=reference_points, + level_start_index=level_start_index, + valid_radios=valid_radios, + ) + # (num_total_query, batch_size, c) -> (batch_size, c, num_total_query) + memory = memory.permute(1, 2, 0) + + # from low resolution to high resolution + num_query_per_level = [e[0] * e[1] for e in spatial_shapes] + outs = torch.split(memory, num_query_per_level, dim=-1) + outs = [x.reshape(batch_size, -1, spatial_shapes[i][0], spatial_shapes[i][1]) for i, x in enumerate(outs)] + + for i in range(self.num_input_levels - self.num_encoder_levels - 1, -1, -1): + x = feats[i] + cur_feat = self.lateral_convs[i](x) + y = cur_feat + F.interpolate(outs[-1], size=cur_feat.shape[-2:], mode="bilinear", align_corners=False) + y = self.output_convs[i](y) + outs.append(y) + multi_scale_features = outs[: self.num_outs] + + mask_feature = self.mask_feature(outs[-1]) + return mask_feature, multi_scale_features diff --git a/dinov2/dinov2/eval/segmentation_m2f/models/segmentors/__init__.py b/dinov2/dinov2/eval/segmentation_m2f/models/segmentors/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..adf0062691e4889612e118f28ced853cd0bc33db --- /dev/null +++ b/dinov2/dinov2/eval/segmentation_m2f/models/segmentors/__init__.py @@ -0,0 +1,6 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +from .encoder_decoder_mask2former import EncoderDecoderMask2Former diff --git a/dinov2/dinov2/eval/segmentation_m2f/models/segmentors/encoder_decoder_mask2former.py b/dinov2/dinov2/eval/segmentation_m2f/models/segmentors/encoder_decoder_mask2former.py new file mode 100644 index 0000000000000000000000000000000000000000..cfe572c9d317303bff8d51b85217d144906ebfe7 --- /dev/null +++ b/dinov2/dinov2/eval/segmentation_m2f/models/segmentors/encoder_decoder_mask2former.py @@ -0,0 +1,271 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +import torch +import torch.nn as nn +import torch.nn.functional as F +from mmseg.core import add_prefix +from mmseg.models import builder +from mmseg.models.builder import SEGMENTORS +from mmseg.models.segmentors.base import BaseSegmentor +from mmseg.ops import resize + + +@SEGMENTORS.register_module() +class EncoderDecoderMask2Former(BaseSegmentor): + """Encoder Decoder segmentors. + + EncoderDecoder typically consists of backbone, decode_head, auxiliary_head. + Note that auxiliary_head is only used for deep supervision during training, + which could be dumped during inference. + """ + + def __init__( + self, + backbone, + decode_head, + neck=None, + auxiliary_head=None, + train_cfg=None, + test_cfg=None, + pretrained=None, + init_cfg=None, + ): + super(EncoderDecoderMask2Former, self).__init__(init_cfg) + if pretrained is not None: + assert backbone.get("pretrained") is None, "both backbone and segmentor set pretrained weight" + backbone.pretrained = pretrained + self.backbone = builder.build_backbone(backbone) + if neck is not None: + self.neck = builder.build_neck(neck) + decode_head.update(train_cfg=train_cfg) + decode_head.update(test_cfg=test_cfg) + self._init_decode_head(decode_head) + self._init_auxiliary_head(auxiliary_head) + + self.train_cfg = train_cfg + self.test_cfg = test_cfg + + assert self.with_decode_head + + def _init_decode_head(self, decode_head): + """Initialize ``decode_head``""" + self.decode_head = builder.build_head(decode_head) + self.align_corners = self.decode_head.align_corners + self.num_classes = self.decode_head.num_classes + + def _init_auxiliary_head(self, auxiliary_head): + """Initialize ``auxiliary_head``""" + if auxiliary_head is not None: + if isinstance(auxiliary_head, list): + self.auxiliary_head = nn.ModuleList() + for head_cfg in auxiliary_head: + self.auxiliary_head.append(builder.build_head(head_cfg)) + else: + self.auxiliary_head = builder.build_head(auxiliary_head) + + def extract_feat(self, img): + """Extract features from images.""" + x = self.backbone(img) + if self.with_neck: + x = self.neck(x) + return x + + def encode_decode(self, img, img_metas): + """Encode images with backbone and decode into a semantic segmentation + map of the same size as input.""" + x = self.extract_feat(img) + out = self._decode_head_forward_test(x, img_metas) + out = resize(input=out, size=img.shape[2:], mode="bilinear", align_corners=self.align_corners) + return out + + def _decode_head_forward_train(self, x, img_metas, gt_semantic_seg, **kwargs): + """Run forward function and calculate loss for decode head in + training.""" + losses = dict() + loss_decode = self.decode_head.forward_train(x, img_metas, gt_semantic_seg, **kwargs) + + losses.update(add_prefix(loss_decode, "decode")) + return losses + + def _decode_head_forward_test(self, x, img_metas): + """Run forward function and calculate loss for decode head in + inference.""" + seg_logits = self.decode_head.forward_test(x, img_metas, self.test_cfg) + return seg_logits + + def _auxiliary_head_forward_train(self, x, img_metas, gt_semantic_seg): + """Run forward function and calculate loss for auxiliary head in + training.""" + losses = dict() + if isinstance(self.auxiliary_head, nn.ModuleList): + for idx, aux_head in enumerate(self.auxiliary_head): + loss_aux = aux_head.forward_train(x, img_metas, gt_semantic_seg, self.train_cfg) + losses.update(add_prefix(loss_aux, f"aux_{idx}")) + else: + loss_aux = self.auxiliary_head.forward_train(x, img_metas, gt_semantic_seg, self.train_cfg) + losses.update(add_prefix(loss_aux, "aux")) + + return losses + + def forward_dummy(self, img): + """Dummy forward function.""" + seg_logit = self.encode_decode(img, None) + + return seg_logit + + def forward_train(self, img, img_metas, gt_semantic_seg, **kwargs): + """Forward function for training. + + Args: + img (Tensor): Input images. + img_metas (list[dict]): List of image info dict where each dict + has: 'img_shape', 'scale_factor', 'flip', and may also contain + 'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'. + For details on the values of these keys see + `mmseg/datasets/pipelines/formatting.py:Collect`. + gt_semantic_seg (Tensor): Semantic segmentation masks + used if the architecture supports semantic segmentation task. + + Returns: + dict[str, Tensor]: a dictionary of loss components + """ + + x = self.extract_feat(img) + + losses = dict() + + loss_decode = self._decode_head_forward_train(x, img_metas, gt_semantic_seg, **kwargs) + losses.update(loss_decode) + + if self.with_auxiliary_head: + loss_aux = self._auxiliary_head_forward_train(x, img_metas, gt_semantic_seg) + losses.update(loss_aux) + + return losses + + def slide_inference(self, img, img_meta, rescale): + """Inference by sliding-window with overlap. + + If h_crop > h_img or w_crop > w_img, the small patch will be used to + decode without padding. + """ + + h_stride, w_stride = self.test_cfg.stride + h_crop, w_crop = self.test_cfg.crop_size + batch_size, _, h_img, w_img = img.size() + num_classes = self.num_classes + h_grids = max(h_img - h_crop + h_stride - 1, 0) // h_stride + 1 + w_grids = max(w_img - w_crop + w_stride - 1, 0) // w_stride + 1 + preds = img.new_zeros((batch_size, num_classes, h_img, w_img)) + count_mat = img.new_zeros((batch_size, 1, h_img, w_img)) + for h_idx in range(h_grids): + for w_idx in range(w_grids): + y1 = h_idx * h_stride + x1 = w_idx * w_stride + y2 = min(y1 + h_crop, h_img) + x2 = min(x1 + w_crop, w_img) + y1 = max(y2 - h_crop, 0) + x1 = max(x2 - w_crop, 0) + crop_img = img[:, :, y1:y2, x1:x2] + crop_seg_logit = self.encode_decode(crop_img, img_meta) + preds += F.pad(crop_seg_logit, (int(x1), int(preds.shape[3] - x2), int(y1), int(preds.shape[2] - y2))) + + count_mat[:, :, y1:y2, x1:x2] += 1 + assert (count_mat == 0).sum() == 0 + if torch.onnx.is_in_onnx_export(): + # cast count_mat to constant while exporting to ONNX + count_mat = torch.from_numpy(count_mat.cpu().detach().numpy()).to(device=img.device) + preds = preds / count_mat + if rescale: + preds = resize( + preds, + size=img_meta[0]["ori_shape"][:2], + mode="bilinear", + align_corners=self.align_corners, + warning=False, + ) + return preds + + def whole_inference(self, img, img_meta, rescale): + """Inference with full image.""" + + seg_logit = self.encode_decode(img, img_meta) + if rescale: + # support dynamic shape for onnx + if torch.onnx.is_in_onnx_export(): + size = img.shape[2:] + else: + size = img_meta[0]["ori_shape"][:2] + seg_logit = resize(seg_logit, size=size, mode="bilinear", align_corners=self.align_corners, warning=False) + + return seg_logit + + def inference(self, img, img_meta, rescale): + """Inference with slide/whole style. + + Args: + img (Tensor): The input image of shape (N, 3, H, W). + img_meta (dict): Image info dict where each dict has: 'img_shape', + 'scale_factor', 'flip', and may also contain + 'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'. + For details on the values of these keys see + `mmseg/datasets/pipelines/formatting.py:Collect`. + rescale (bool): Whether rescale back to original shape. + + Returns: + Tensor: The output segmentation map. + """ + + assert self.test_cfg.mode in ["slide", "whole"] + ori_shape = img_meta[0]["ori_shape"] + assert all(_["ori_shape"] == ori_shape for _ in img_meta) + if self.test_cfg.mode == "slide": + seg_logit = self.slide_inference(img, img_meta, rescale) + else: + seg_logit = self.whole_inference(img, img_meta, rescale) + output = F.softmax(seg_logit, dim=1) + flip = img_meta[0]["flip"] + if flip: + flip_direction = img_meta[0]["flip_direction"] + assert flip_direction in ["horizontal", "vertical"] + if flip_direction == "horizontal": + output = output.flip(dims=(3,)) + elif flip_direction == "vertical": + output = output.flip(dims=(2,)) + + return output + + def simple_test(self, img, img_meta, rescale=True): + """Simple test with single image.""" + seg_logit = self.inference(img, img_meta, rescale) + seg_pred = seg_logit.argmax(dim=1) + if torch.onnx.is_in_onnx_export(): + # our inference backend only support 4D output + seg_pred = seg_pred.unsqueeze(0) + return seg_pred + seg_pred = seg_pred.cpu().numpy() + # unravel batch dim + seg_pred = list(seg_pred) + return seg_pred + + def aug_test(self, imgs, img_metas, rescale=True): + """Test with augmentations. + + Only rescale=True is supported. + """ + # aug_test rescale all imgs back to ori_shape for now + assert rescale + # to save memory, we get augmented seg logit inplace + seg_logit = self.inference(imgs[0], img_metas[0], rescale) + for i in range(1, len(imgs)): + cur_seg_logit = self.inference(imgs[i], img_metas[i], rescale) + seg_logit += cur_seg_logit + seg_logit /= len(imgs) + seg_pred = seg_logit.argmax(dim=1) + seg_pred = seg_pred.cpu().numpy() + # unravel batch dim + seg_pred = list(seg_pred) + return seg_pred diff --git a/dinov2/dinov2/eval/segmentation_m2f/models/utils/__init__.py b/dinov2/dinov2/eval/segmentation_m2f/models/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e7fdc1668b1015c8feea8fa1a4691bc0ebdbd936 --- /dev/null +++ b/dinov2/dinov2/eval/segmentation_m2f/models/utils/__init__.py @@ -0,0 +1,9 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +from .assigner import MaskHungarianAssigner +from .point_sample import get_uncertain_point_coords_with_randomness +from .positional_encoding import LearnedPositionalEncoding, SinePositionalEncoding +from .transformer import DetrTransformerDecoder, DetrTransformerDecoderLayer, DynamicConv, Transformer diff --git a/dinov2/dinov2/eval/segmentation_m2f/models/utils/assigner.py b/dinov2/dinov2/eval/segmentation_m2f/models/utils/assigner.py new file mode 100644 index 0000000000000000000000000000000000000000..3cb08fc1bb2e36336989b45a1d3850f260c05963 --- /dev/null +++ b/dinov2/dinov2/eval/segmentation_m2f/models/utils/assigner.py @@ -0,0 +1,157 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +from abc import ABCMeta, abstractmethod + +import torch + +from ..builder import MASK_ASSIGNERS, build_match_cost + +try: + from scipy.optimize import linear_sum_assignment +except ImportError: + linear_sum_assignment = None + + +class AssignResult(metaclass=ABCMeta): + """Collection of assign results.""" + + def __init__(self, num_gts, gt_inds, labels): + self.num_gts = num_gts + self.gt_inds = gt_inds + self.labels = labels + + @property + def info(self): + info = { + "num_gts": self.num_gts, + "gt_inds": self.gt_inds, + "labels": self.labels, + } + return info + + +class BaseAssigner(metaclass=ABCMeta): + """Base assigner that assigns boxes to ground truth boxes.""" + + @abstractmethod + def assign(self, masks, gt_masks, gt_masks_ignore=None, gt_labels=None): + """Assign boxes to either a ground truth boxes or a negative boxes.""" + pass + + +@MASK_ASSIGNERS.register_module() +class MaskHungarianAssigner(BaseAssigner): + """Computes one-to-one matching between predictions and ground truth for + mask. + + This class computes an assignment between the targets and the predictions + based on the costs. The costs are weighted sum of three components: + classification cost, regression L1 cost and regression iou cost. The + targets don't include the no_object, so generally there are more + predictions than targets. After the one-to-one matching, the un-matched + are treated as backgrounds. Thus each query prediction will be assigned + with `0` or a positive integer indicating the ground truth index: + + - 0: negative sample, no assigned gt + - positive integer: positive sample, index (1-based) of assigned gt + + Args: + cls_cost (obj:`mmcv.ConfigDict`|dict): Classification cost config. + mask_cost (obj:`mmcv.ConfigDict`|dict): Mask cost config. + dice_cost (obj:`mmcv.ConfigDict`|dict): Dice cost config. + """ + + def __init__( + self, + cls_cost=dict(type="ClassificationCost", weight=1.0), + dice_cost=dict(type="DiceCost", weight=1.0), + mask_cost=dict(type="MaskFocalCost", weight=1.0), + ): + self.cls_cost = build_match_cost(cls_cost) + self.dice_cost = build_match_cost(dice_cost) + self.mask_cost = build_match_cost(mask_cost) + + def assign(self, cls_pred, mask_pred, gt_labels, gt_masks, img_meta, gt_masks_ignore=None, eps=1e-7): + """Computes one-to-one matching based on the weighted costs. + + This method assign each query prediction to a ground truth or + background. The `assigned_gt_inds` with -1 means don't care, + 0 means negative sample, and positive number is the index (1-based) + of assigned gt. + The assignment is done in the following steps, the order matters. + + 1. assign every prediction to -1 + 2. compute the weighted costs + 3. do Hungarian matching on CPU based on the costs + 4. assign all to 0 (background) first, then for each matched pair + between predictions and gts, treat this prediction as foreground + and assign the corresponding gt index (plus 1) to it. + + Args: + mask_pred (Tensor): Predicted mask, shape [num_query, h, w] + cls_pred (Tensor): Predicted classification logits, shape + [num_query, num_class]. + gt_masks (Tensor): Ground truth mask, shape [num_gt, h, w]. + gt_labels (Tensor): Label of `gt_masks`, shape (num_gt,). + img_meta (dict): Meta information for current image. + gt_masks_ignore (Tensor, optional): Ground truth masks that are + labelled as `ignored`. Default None. + eps (int | float, optional): A value added to the denominator for + numerical stability. Default 1e-7. + + Returns: + :obj:`AssignResult`: The assigned result. + """ + assert gt_masks_ignore is None, "Only case when gt_masks_ignore is None is supported." + num_gts, num_queries = gt_labels.shape[0], cls_pred.shape[0] + + # 1. assign -1 by default + assigned_gt_inds = cls_pred.new_full((num_queries,), -1, dtype=torch.long) + assigned_labels = cls_pred.new_full((num_queries,), -1, dtype=torch.long) + if num_gts == 0 or num_queries == 0: + # No ground truth or boxes, return empty assignment + if num_gts == 0: + # No ground truth, assign all to background + assigned_gt_inds[:] = 0 + return AssignResult(num_gts, assigned_gt_inds, labels=assigned_labels) + + # 2. compute the weighted costs + # classification and maskcost. + if self.cls_cost.weight != 0 and cls_pred is not None: + cls_cost = self.cls_cost(cls_pred, gt_labels) + else: + cls_cost = 0 + + if self.mask_cost.weight != 0: + # mask_pred shape = [nq, h, w] + # gt_mask shape = [ng, h, w] + # mask_cost shape = [nq, ng] + mask_cost = self.mask_cost(mask_pred, gt_masks) + else: + mask_cost = 0 + + if self.dice_cost.weight != 0: + dice_cost = self.dice_cost(mask_pred, gt_masks) + else: + dice_cost = 0 + cost = cls_cost + mask_cost + dice_cost + + # 3. do Hungarian matching on CPU using linear_sum_assignment + cost = cost.detach().cpu() + if linear_sum_assignment is None: + raise ImportError('Please run "pip install scipy" ' "to install scipy first.") + + matched_row_inds, matched_col_inds = linear_sum_assignment(cost) + matched_row_inds = torch.from_numpy(matched_row_inds).to(cls_pred.device) + matched_col_inds = torch.from_numpy(matched_col_inds).to(cls_pred.device) + + # 4. assign backgrounds and foregrounds + # assign all indices to backgrounds first + assigned_gt_inds[:] = 0 + # assign foregrounds based on matching results + assigned_gt_inds[matched_row_inds] = matched_col_inds + 1 + assigned_labels[matched_row_inds] = gt_labels[matched_col_inds] + return AssignResult(num_gts, assigned_gt_inds, labels=assigned_labels) diff --git a/dinov2/dinov2/eval/segmentation_m2f/models/utils/point_sample.py b/dinov2/dinov2/eval/segmentation_m2f/models/utils/point_sample.py new file mode 100644 index 0000000000000000000000000000000000000000..9f1134082bafb51432618a9632592db070f87284 --- /dev/null +++ b/dinov2/dinov2/eval/segmentation_m2f/models/utils/point_sample.py @@ -0,0 +1,86 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +import torch +from mmcv.ops import point_sample + + +def get_uncertainty(mask_pred, labels): + """Estimate uncertainty based on pred logits. + + We estimate uncertainty as L1 distance between 0.0 and the logits + prediction in 'mask_pred' for the foreground class in `classes`. + + Args: + mask_pred (Tensor): mask predication logits, shape (num_rois, + num_classes, mask_height, mask_width). + + labels (list[Tensor]): Either predicted or ground truth label for + each predicted mask, of length num_rois. + + Returns: + scores (Tensor): Uncertainty scores with the most uncertain + locations having the highest uncertainty score, + shape (num_rois, 1, mask_height, mask_width) + """ + if mask_pred.shape[1] == 1: + gt_class_logits = mask_pred.clone() + else: + inds = torch.arange(mask_pred.shape[0], device=mask_pred.device) + gt_class_logits = mask_pred[inds, labels].unsqueeze(1) + return -torch.abs(gt_class_logits) + + +def get_uncertain_point_coords_with_randomness( + mask_pred, labels, num_points, oversample_ratio, importance_sample_ratio +): + """Get ``num_points`` most uncertain points with random points during + train. + + Sample points in [0, 1] x [0, 1] coordinate space based on their + uncertainty. The uncertainties are calculated for each point using + 'get_uncertainty()' function that takes point's logit prediction as + input. + + Args: + mask_pred (Tensor): A tensor of shape (num_rois, num_classes, + mask_height, mask_width) for class-specific or class-agnostic + prediction. + labels (list): The ground truth class for each instance. + num_points (int): The number of points to sample. + oversample_ratio (int): Oversampling parameter. + importance_sample_ratio (float): Ratio of points that are sampled + via importnace sampling. + + Returns: + point_coords (Tensor): A tensor of shape (num_rois, num_points, 2) + that contains the coordinates sampled points. + """ + assert oversample_ratio >= 1 + assert 0 <= importance_sample_ratio <= 1 + batch_size = mask_pred.shape[0] + num_sampled = int(num_points * oversample_ratio) + point_coords = torch.rand(batch_size, num_sampled, 2, device=mask_pred.device) + point_logits = point_sample(mask_pred, point_coords) + # It is crucial to calculate uncertainty based on the sampled + # prediction value for the points. Calculating uncertainties of the + # coarse predictions first and sampling them for points leads to + # incorrect results. To illustrate this: assume uncertainty func( + # logits)=-abs(logits), a sampled point between two coarse + # predictions with -1 and 1 logits has 0 logits, and therefore 0 + # uncertainty value. However, if we calculate uncertainties for the + # coarse predictions first, both will have -1 uncertainty, + # and sampled point will get -1 uncertainty. + point_uncertainties = get_uncertainty(point_logits, labels) + num_uncertain_points = int(importance_sample_ratio * num_points) + num_random_points = num_points - num_uncertain_points + idx = torch.topk(point_uncertainties[:, 0, :], k=num_uncertain_points, dim=1)[1] + shift = num_sampled * torch.arange(batch_size, dtype=torch.long, device=mask_pred.device) + idx += shift[:, None] + point_coords = point_coords.view(-1, 2)[idx.view(-1), :].view(batch_size, num_uncertain_points, 2) + if num_random_points > 0: + rand_roi_coords = torch.rand(batch_size, num_random_points, 2, device=mask_pred.device) + point_coords = torch.cat((point_coords, rand_roi_coords), dim=1) + return point_coords diff --git a/dinov2/dinov2/eval/segmentation_m2f/models/utils/positional_encoding.py b/dinov2/dinov2/eval/segmentation_m2f/models/utils/positional_encoding.py new file mode 100644 index 0000000000000000000000000000000000000000..bf5d6fabe946d06fe97cc799da47bae93758b34e --- /dev/null +++ b/dinov2/dinov2/eval/segmentation_m2f/models/utils/positional_encoding.py @@ -0,0 +1,152 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +import math + +import torch +import torch.nn as nn +from mmcv.cnn.bricks.transformer import POSITIONAL_ENCODING +from mmcv.runner import BaseModule + + +@POSITIONAL_ENCODING.register_module() +class SinePositionalEncoding(BaseModule): + """Position encoding with sine and cosine functions. + + See `End-to-End Object Detection with Transformers + `_ for details. + + Args: + num_feats (int): The feature dimension for each position + along x-axis or y-axis. Note the final returned dimension + for each position is 2 times of this value. + temperature (int, optional): The temperature used for scaling + the position embedding. Defaults to 10000. + normalize (bool, optional): Whether to normalize the position + embedding. Defaults to False. + scale (float, optional): A scale factor that scales the position + embedding. The scale will be used only when `normalize` is True. + Defaults to 2*pi. + eps (float, optional): A value added to the denominator for + numerical stability. Defaults to 1e-6. + offset (float): offset add to embed when do the normalization. + Defaults to 0. + init_cfg (dict or list[dict], optional): Initialization config dict. + Default: None + """ + + def __init__( + self, num_feats, temperature=10000, normalize=False, scale=2 * math.pi, eps=1e-6, offset=0.0, init_cfg=None + ): + super(SinePositionalEncoding, self).__init__(init_cfg) + if normalize: + assert isinstance(scale, (float, int)), ( + "when normalize is set," "scale should be provided and in float or int type, " f"found {type(scale)}" + ) + self.num_feats = num_feats + self.temperature = temperature + self.normalize = normalize + self.scale = scale + self.eps = eps + self.offset = offset + + def forward(self, mask): + """Forward function for `SinePositionalEncoding`. + + Args: + mask (Tensor): ByteTensor mask. Non-zero values representing + ignored positions, while zero values means valid positions + for this image. Shape [bs, h, w]. + + Returns: + pos (Tensor): Returned position embedding with shape + [bs, num_feats*2, h, w]. + """ + # For convenience of exporting to ONNX, it's required to convert + # `masks` from bool to int. + mask = mask.to(torch.int) + not_mask = 1 - mask # logical_not + y_embed = not_mask.cumsum(1, dtype=torch.float32) + x_embed = not_mask.cumsum(2, dtype=torch.float32) + if self.normalize: + y_embed = (y_embed + self.offset) / (y_embed[:, -1:, :] + self.eps) * self.scale + x_embed = (x_embed + self.offset) / (x_embed[:, :, -1:] + self.eps) * self.scale + dim_t = torch.arange(self.num_feats, dtype=torch.float32, device=mask.device) + dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_feats) + pos_x = x_embed[:, :, :, None] / dim_t + pos_y = y_embed[:, :, :, None] / dim_t + # use `view` instead of `flatten` for dynamically exporting to ONNX + B, H, W = mask.size() + pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).view(B, H, W, -1) + pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).view(B, H, W, -1) + pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2) + return pos + + def __repr__(self): + """str: a string that describes the module""" + repr_str = self.__class__.__name__ + repr_str += f"(num_feats={self.num_feats}, " + repr_str += f"temperature={self.temperature}, " + repr_str += f"normalize={self.normalize}, " + repr_str += f"scale={self.scale}, " + repr_str += f"eps={self.eps})" + return repr_str + + +@POSITIONAL_ENCODING.register_module() +class LearnedPositionalEncoding(BaseModule): + """Position embedding with learnable embedding weights. + + Args: + num_feats (int): The feature dimension for each position + along x-axis or y-axis. The final returned dimension for + each position is 2 times of this value. + row_num_embed (int, optional): The dictionary size of row embeddings. + Default 50. + col_num_embed (int, optional): The dictionary size of col embeddings. + Default 50. + init_cfg (dict or list[dict], optional): Initialization config dict. + """ + + def __init__(self, num_feats, row_num_embed=50, col_num_embed=50, init_cfg=dict(type="Uniform", layer="Embedding")): + super(LearnedPositionalEncoding, self).__init__(init_cfg) + self.row_embed = nn.Embedding(row_num_embed, num_feats) + self.col_embed = nn.Embedding(col_num_embed, num_feats) + self.num_feats = num_feats + self.row_num_embed = row_num_embed + self.col_num_embed = col_num_embed + + def forward(self, mask): + """Forward function for `LearnedPositionalEncoding`. + + Args: + mask (Tensor): ByteTensor mask. Non-zero values representing + ignored positions, while zero values means valid positions + for this image. Shape [bs, h, w]. + + Returns: + pos (Tensor): Returned position embedding with shape + [bs, num_feats*2, h, w]. + """ + h, w = mask.shape[-2:] + x = torch.arange(w, device=mask.device) + y = torch.arange(h, device=mask.device) + x_embed = self.col_embed(x) + y_embed = self.row_embed(y) + pos = ( + torch.cat((x_embed.unsqueeze(0).repeat(h, 1, 1), y_embed.unsqueeze(1).repeat(1, w, 1)), dim=-1) + .permute(2, 0, 1) + .unsqueeze(0) + .repeat(mask.shape[0], 1, 1, 1) + ) + return pos + + def __repr__(self): + """str: a string that describes the module""" + repr_str = self.__class__.__name__ + repr_str += f"(num_feats={self.num_feats}, " + repr_str += f"row_num_embed={self.row_num_embed}, " + repr_str += f"col_num_embed={self.col_num_embed})" + return repr_str diff --git a/dinov2/dinov2/eval/segmentation_m2f/models/utils/transformer.py b/dinov2/dinov2/eval/segmentation_m2f/models/utils/transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..8befe6011a34d5ccecb82c8b17b61e19f732f96b --- /dev/null +++ b/dinov2/dinov2/eval/segmentation_m2f/models/utils/transformer.py @@ -0,0 +1,989 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +import math +import warnings +from typing import Sequence + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.utils.checkpoint as cp +from mmcv.cnn import Linear, build_activation_layer, build_norm_layer, xavier_init +from mmcv.cnn.bricks.drop import build_dropout +from mmcv.cnn.bricks.registry import FEEDFORWARD_NETWORK, TRANSFORMER_LAYER, TRANSFORMER_LAYER_SEQUENCE +from mmcv.cnn.bricks.transformer import BaseTransformerLayer, TransformerLayerSequence, build_transformer_layer_sequence +from mmcv.runner.base_module import BaseModule, Sequential +from mmcv.utils import deprecated_api_warning, to_2tuple +from torch.nn.init import normal_ + +from ..builder import TRANSFORMER + +try: + from mmcv.ops.multi_scale_deform_attn import MultiScaleDeformableAttention + +except ImportError: + warnings.warn( + "`MultiScaleDeformableAttention` in MMCV has been moved to " + "`mmcv.ops.multi_scale_deform_attn`, please update your MMCV" + ) + from mmcv.cnn.bricks.transformer import MultiScaleDeformableAttention + + +class AdaptivePadding(nn.Module): + """Applies padding to input (if needed) so that input can get fully covered + by filter you specified. It support two modes "same" and "corner". The + "same" mode is same with "SAME" padding mode in TensorFlow, pad zero around + input. The "corner" mode would pad zero to bottom right. + + Args: + kernel_size (int | tuple): Size of the kernel: + stride (int | tuple): Stride of the filter. Default: 1: + dilation (int | tuple): Spacing between kernel elements. + Default: 1 + padding (str): Support "same" and "corner", "corner" mode + would pad zero to bottom right, and "same" mode would + pad zero around input. Default: "corner". + Example: + >>> kernel_size = 16 + >>> stride = 16 + >>> dilation = 1 + >>> input = torch.rand(1, 1, 15, 17) + >>> adap_pad = AdaptivePadding( + >>> kernel_size=kernel_size, + >>> stride=stride, + >>> dilation=dilation, + >>> padding="corner") + >>> out = adap_pad(input) + >>> assert (out.shape[2], out.shape[3]) == (16, 32) + >>> input = torch.rand(1, 1, 16, 17) + >>> out = adap_pad(input) + >>> assert (out.shape[2], out.shape[3]) == (16, 32) + """ + + def __init__(self, kernel_size=1, stride=1, dilation=1, padding="corner"): + + super(AdaptivePadding, self).__init__() + + assert padding in ("same", "corner") + + kernel_size = to_2tuple(kernel_size) + stride = to_2tuple(stride) + padding = to_2tuple(padding) + dilation = to_2tuple(dilation) + + self.padding = padding + self.kernel_size = kernel_size + self.stride = stride + self.dilation = dilation + + def get_pad_shape(self, input_shape): + input_h, input_w = input_shape + kernel_h, kernel_w = self.kernel_size + stride_h, stride_w = self.stride + output_h = math.ceil(input_h / stride_h) + output_w = math.ceil(input_w / stride_w) + pad_h = max((output_h - 1) * stride_h + (kernel_h - 1) * self.dilation[0] + 1 - input_h, 0) + pad_w = max((output_w - 1) * stride_w + (kernel_w - 1) * self.dilation[1] + 1 - input_w, 0) + return pad_h, pad_w + + def forward(self, x): + pad_h, pad_w = self.get_pad_shape(x.size()[-2:]) + if pad_h > 0 or pad_w > 0: + if self.padding == "corner": + x = F.pad(x, [0, pad_w, 0, pad_h]) + elif self.padding == "same": + x = F.pad(x, [pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2]) + return x + + +class PatchMerging(BaseModule): + """Merge patch feature map. + + This layer groups feature map by kernel_size, and applies norm and linear + layers to the grouped feature map. Our implementation uses `nn.Unfold` to + merge patch, which is about 25% faster than original implementation. + Instead, we need to modify pretrained models for compatibility. + + Args: + in_channels (int): The num of input channels. + to gets fully covered by filter and stride you specified.. + Default: True. + out_channels (int): The num of output channels. + kernel_size (int | tuple, optional): the kernel size in the unfold + layer. Defaults to 2. + stride (int | tuple, optional): the stride of the sliding blocks in the + unfold layer. Default: None. (Would be set as `kernel_size`) + padding (int | tuple | string ): The padding length of + embedding conv. When it is a string, it means the mode + of adaptive padding, support "same" and "corner" now. + Default: "corner". + dilation (int | tuple, optional): dilation parameter in the unfold + layer. Default: 1. + bias (bool, optional): Whether to add bias in linear layer or not. + Defaults: False. + norm_cfg (dict, optional): Config dict for normalization layer. + Default: dict(type='LN'). + init_cfg (dict, optional): The extra config for initialization. + Default: None. + """ + + def __init__( + self, + in_channels, + out_channels, + kernel_size=2, + stride=None, + padding="corner", + dilation=1, + bias=False, + norm_cfg=dict(type="LN"), + init_cfg=None, + ): + super().__init__(init_cfg=init_cfg) + self.in_channels = in_channels + self.out_channels = out_channels + if stride: + stride = stride + else: + stride = kernel_size + + kernel_size = to_2tuple(kernel_size) + stride = to_2tuple(stride) + dilation = to_2tuple(dilation) + + if isinstance(padding, str): + self.adap_padding = AdaptivePadding( + kernel_size=kernel_size, stride=stride, dilation=dilation, padding=padding + ) + # disable the padding of unfold + padding = 0 + else: + self.adap_padding = None + + padding = to_2tuple(padding) + self.sampler = nn.Unfold(kernel_size=kernel_size, dilation=dilation, padding=padding, stride=stride) + + sample_dim = kernel_size[0] * kernel_size[1] * in_channels + + if norm_cfg is not None: + self.norm = build_norm_layer(norm_cfg, sample_dim)[1] + else: + self.norm = None + + self.reduction = nn.Linear(sample_dim, out_channels, bias=bias) + + def forward(self, x, input_size): + """ + Args: + x (Tensor): Has shape (B, H*W, C_in). + input_size (tuple[int]): The spatial shape of x, arrange as (H, W). + Default: None. + + Returns: + tuple: Contains merged results and its spatial shape. + + - x (Tensor): Has shape (B, Merged_H * Merged_W, C_out) + - out_size (tuple[int]): Spatial shape of x, arrange as + (Merged_H, Merged_W). + """ + B, L, C = x.shape + assert isinstance(input_size, Sequence), f"Expect " f"input_size is " f"`Sequence` " f"but get {input_size}" + + H, W = input_size + assert L == H * W, "input feature has wrong size" + + x = x.view(B, H, W, C).permute([0, 3, 1, 2]) # B, C, H, W + # Use nn.Unfold to merge patch. About 25% faster than original method, + # but need to modify pretrained model for compatibility + + if self.adap_padding: + x = self.adap_padding(x) + H, W = x.shape[-2:] + + x = self.sampler(x) + # if kernel_size=2 and stride=2, x should has shape (B, 4*C, H/2*W/2) + + out_h = ( + H + 2 * self.sampler.padding[0] - self.sampler.dilation[0] * (self.sampler.kernel_size[0] - 1) - 1 + ) // self.sampler.stride[0] + 1 + out_w = ( + W + 2 * self.sampler.padding[1] - self.sampler.dilation[1] * (self.sampler.kernel_size[1] - 1) - 1 + ) // self.sampler.stride[1] + 1 + + output_size = (out_h, out_w) + x = x.transpose(1, 2) # B, H/2*W/2, 4*C + x = self.norm(x) if self.norm else x + x = self.reduction(x) + return x, output_size + + +def inverse_sigmoid(x, eps=1e-5): + """Inverse function of sigmoid. + + Args: + x (Tensor): The tensor to do the + inverse. + eps (float): EPS avoid numerical + overflow. Defaults 1e-5. + Returns: + Tensor: The x has passed the inverse + function of sigmoid, has same + shape with input. + """ + x = x.clamp(min=0, max=1) + x1 = x.clamp(min=eps) + x2 = (1 - x).clamp(min=eps) + return torch.log(x1 / x2) + + +@FEEDFORWARD_NETWORK.register_module(force=True) +class FFN(BaseModule): + """Implements feed-forward networks (FFNs) with identity connection. + Args: + embed_dims (int): The feature dimension. Same as + `MultiheadAttention`. Defaults: 256. + feedforward_channels (int): The hidden dimension of FFNs. + Defaults: 1024. + num_fcs (int, optional): The number of fully-connected layers in + FFNs. Default: 2. + act_cfg (dict, optional): The activation config for FFNs. + Default: dict(type='ReLU') + ffn_drop (float, optional): Probability of an element to be + zeroed in FFN. Default 0.0. + add_identity (bool, optional): Whether to add the + identity connection. Default: `True`. + dropout_layer (obj:`ConfigDict`): The dropout_layer used + when adding the shortcut. + init_cfg (obj:`mmcv.ConfigDict`): The Config for initialization. + Default: None. + """ + + @deprecated_api_warning({"dropout": "ffn_drop", "add_residual": "add_identity"}, cls_name="FFN") + def __init__( + self, + embed_dims=256, + feedforward_channels=1024, + num_fcs=2, + act_cfg=dict(type="ReLU", inplace=True), + ffn_drop=0.0, + dropout_layer=None, + add_identity=True, + init_cfg=None, + with_cp=False, + **kwargs, + ): + super().__init__(init_cfg) + assert num_fcs >= 2, "num_fcs should be no less " f"than 2. got {num_fcs}." + self.embed_dims = embed_dims + self.feedforward_channels = feedforward_channels + self.num_fcs = num_fcs + self.act_cfg = act_cfg + self.activate = build_activation_layer(act_cfg) + self.with_cp = with_cp + layers = [] + in_channels = embed_dims + for _ in range(num_fcs - 1): + layers.append(Sequential(Linear(in_channels, feedforward_channels), self.activate, nn.Dropout(ffn_drop))) + in_channels = feedforward_channels + layers.append(Linear(feedforward_channels, embed_dims)) + layers.append(nn.Dropout(ffn_drop)) + self.layers = Sequential(*layers) + self.dropout_layer = build_dropout(dropout_layer) if dropout_layer else torch.nn.Identity() + self.add_identity = add_identity + + @deprecated_api_warning({"residual": "identity"}, cls_name="FFN") + def forward(self, x, identity=None): + """Forward function for `FFN`. + The function would add x to the output tensor if residue is None. + """ + + if self.with_cp and x.requires_grad: + out = cp.checkpoint(self.layers, x) + else: + out = self.layers(x) + + if not self.add_identity: + return self.dropout_layer(out) + if identity is None: + identity = x + return identity + self.dropout_layer(out) + + +@TRANSFORMER_LAYER.register_module() +class DetrTransformerDecoderLayer(BaseTransformerLayer): + """Implements decoder layer in DETR transformer. + + Args: + attn_cfgs (list[`mmcv.ConfigDict`] | list[dict] | dict )): + Configs for self_attention or cross_attention, the order + should be consistent with it in `operation_order`. If it is + a dict, it would be expand to the number of attention in + `operation_order`. + feedforward_channels (int): The hidden dimension for FFNs. + ffn_dropout (float): Probability of an element to be zeroed + in ffn. Default 0.0. + operation_order (tuple[str]): The execution order of operation + in transformer. Such as ('self_attn', 'norm', 'ffn', 'norm'). + Default:None + act_cfg (dict): The activation config for FFNs. Default: `LN` + norm_cfg (dict): Config dict for normalization layer. + Default: `LN`. + ffn_num_fcs (int): The number of fully-connected layers in FFNs. + Default:2. + """ + + def __init__( + self, + attn_cfgs, + feedforward_channels, + ffn_dropout=0.0, + operation_order=None, + act_cfg=dict(type="ReLU", inplace=True), + norm_cfg=dict(type="LN"), + ffn_num_fcs=2, + **kwargs, + ): + super(DetrTransformerDecoderLayer, self).__init__( + attn_cfgs=attn_cfgs, + feedforward_channels=feedforward_channels, + ffn_dropout=ffn_dropout, + operation_order=operation_order, + act_cfg=act_cfg, + norm_cfg=norm_cfg, + ffn_num_fcs=ffn_num_fcs, + **kwargs, + ) + assert len(operation_order) == 6 + assert set(operation_order) == set(["self_attn", "norm", "cross_attn", "ffn"]) + + +@TRANSFORMER_LAYER_SEQUENCE.register_module() +class DetrTransformerEncoder(TransformerLayerSequence): + """TransformerEncoder of DETR. + + Args: + post_norm_cfg (dict): Config of last normalization layer. Default: + `LN`. Only used when `self.pre_norm` is `True` + """ + + def __init__(self, *args, post_norm_cfg=dict(type="LN"), **kwargs): + super(DetrTransformerEncoder, self).__init__(*args, **kwargs) + if post_norm_cfg is not None: + self.post_norm = build_norm_layer(post_norm_cfg, self.embed_dims)[1] if self.pre_norm else None + else: + assert not self.pre_norm, f"Use prenorm in " f"{self.__class__.__name__}," f"Please specify post_norm_cfg" + self.post_norm = None + + def forward(self, *args, **kwargs): + """Forward function for `TransformerCoder`. + + Returns: + Tensor: forwarded results with shape [num_query, bs, embed_dims]. + """ + x = super(DetrTransformerEncoder, self).forward(*args, **kwargs) + if self.post_norm is not None: + x = self.post_norm(x) + return x + + +@TRANSFORMER_LAYER_SEQUENCE.register_module() +class DetrTransformerDecoder(TransformerLayerSequence): + """Implements the decoder in DETR transformer. + + Args: + return_intermediate (bool): Whether to return intermediate outputs. + post_norm_cfg (dict): Config of last normalization layer. Default: + `LN`. + """ + + def __init__(self, *args, post_norm_cfg=dict(type="LN"), return_intermediate=False, **kwargs): + + super(DetrTransformerDecoder, self).__init__(*args, **kwargs) + self.return_intermediate = return_intermediate + if post_norm_cfg is not None: + self.post_norm = build_norm_layer(post_norm_cfg, self.embed_dims)[1] + else: + self.post_norm = None + + def forward(self, query, *args, **kwargs): + """Forward function for `TransformerDecoder`. + + Args: + query (Tensor): Input query with shape + `(num_query, bs, embed_dims)`. + + Returns: + Tensor: Results with shape [1, num_query, bs, embed_dims] when + return_intermediate is `False`, otherwise it has shape + [num_layers, num_query, bs, embed_dims]. + """ + if not self.return_intermediate: + x = super().forward(query, *args, **kwargs) + if self.post_norm: + x = self.post_norm(x)[None] + return x + + intermediate = [] + for layer in self.layers: + query = layer(query, *args, **kwargs) + if self.return_intermediate: + if self.post_norm is not None: + intermediate.append(self.post_norm(query)) + else: + intermediate.append(query) + return torch.stack(intermediate) + + +@TRANSFORMER.register_module() +class Transformer(BaseModule): + """Implements the DETR transformer. + + Following the official DETR implementation, this module copy-paste + from torch.nn.Transformer with modifications: + + * positional encodings are passed in MultiheadAttention + * extra LN at the end of encoder is removed + * decoder returns a stack of activations from all decoding layers + + See `paper: End-to-End Object Detection with Transformers + `_ for details. + + Args: + encoder (`mmcv.ConfigDict` | Dict): Config of + TransformerEncoder. Defaults to None. + decoder ((`mmcv.ConfigDict` | Dict)): Config of + TransformerDecoder. Defaults to None + init_cfg (obj:`mmcv.ConfigDict`): The Config for initialization. + Defaults to None. + """ + + def __init__(self, encoder=None, decoder=None, init_cfg=None): + super(Transformer, self).__init__(init_cfg=init_cfg) + self.encoder = build_transformer_layer_sequence(encoder) + self.decoder = build_transformer_layer_sequence(decoder) + self.embed_dims = self.encoder.embed_dims + + def init_weights(self): + # follow the official DETR to init parameters + for m in self.modules(): + if hasattr(m, "weight") and m.weight.dim() > 1: + xavier_init(m, distribution="uniform") + self._is_init = True + + def forward(self, x, mask, query_embed, pos_embed): + """Forward function for `Transformer`. + + Args: + x (Tensor): Input query with shape [bs, c, h, w] where + c = embed_dims. + mask (Tensor): The key_padding_mask used for encoder and decoder, + with shape [bs, h, w]. + query_embed (Tensor): The query embedding for decoder, with shape + [num_query, c]. + pos_embed (Tensor): The positional encoding for encoder and + decoder, with the same shape as `x`. + + Returns: + tuple[Tensor]: results of decoder containing the following tensor. + + - out_dec: Output from decoder. If return_intermediate_dec \ + is True output has shape [num_dec_layers, bs, + num_query, embed_dims], else has shape [1, bs, \ + num_query, embed_dims]. + - memory: Output results from encoder, with shape \ + [bs, embed_dims, h, w]. + """ + bs, c, h, w = x.shape + # use `view` instead of `flatten` for dynamically exporting to ONNX + x = x.view(bs, c, -1).permute(2, 0, 1) # [bs, c, h, w] -> [h*w, bs, c] + pos_embed = pos_embed.view(bs, c, -1).permute(2, 0, 1) + query_embed = query_embed.unsqueeze(1).repeat(1, bs, 1) # [num_query, dim] -> [num_query, bs, dim] + mask = mask.view(bs, -1) # [bs, h, w] -> [bs, h*w] + memory = self.encoder(query=x, key=None, value=None, query_pos=pos_embed, query_key_padding_mask=mask) + target = torch.zeros_like(query_embed) + # out_dec: [num_layers, num_query, bs, dim] + out_dec = self.decoder( + query=target, key=memory, value=memory, key_pos=pos_embed, query_pos=query_embed, key_padding_mask=mask + ) + out_dec = out_dec.transpose(1, 2) + memory = memory.permute(1, 2, 0).reshape(bs, c, h, w) + return out_dec, memory + + +@TRANSFORMER_LAYER_SEQUENCE.register_module() +class DeformableDetrTransformerDecoder(TransformerLayerSequence): + """Implements the decoder in DETR transformer. + + Args: + return_intermediate (bool): Whether to return intermediate outputs. + coder_norm_cfg (dict): Config of last normalization layer. Default: + `LN`. + """ + + def __init__(self, *args, return_intermediate=False, **kwargs): + + super(DeformableDetrTransformerDecoder, self).__init__(*args, **kwargs) + self.return_intermediate = return_intermediate + + def forward(self, query, *args, reference_points=None, valid_ratios=None, reg_branches=None, **kwargs): + """Forward function for `TransformerDecoder`. + + Args: + query (Tensor): Input query with shape + `(num_query, bs, embed_dims)`. + reference_points (Tensor): The reference + points of offset. has shape + (bs, num_query, 4) when as_two_stage, + otherwise has shape ((bs, num_query, 2). + valid_ratios (Tensor): The radios of valid + points on the feature map, has shape + (bs, num_levels, 2) + reg_branch: (obj:`nn.ModuleList`): Used for + refining the regression results. Only would + be passed when with_box_refine is True, + otherwise would be passed a `None`. + + Returns: + Tensor: Results with shape [1, num_query, bs, embed_dims] when + return_intermediate is `False`, otherwise it has shape + [num_layers, num_query, bs, embed_dims]. + """ + output = query + intermediate = [] + intermediate_reference_points = [] + for lid, layer in enumerate(self.layers): + if reference_points.shape[-1] == 4: + reference_points_input = ( + reference_points[:, :, None] * torch.cat([valid_ratios, valid_ratios], -1)[:, None] + ) + else: + assert reference_points.shape[-1] == 2 + reference_points_input = reference_points[:, :, None] * valid_ratios[:, None] + output = layer(output, *args, reference_points=reference_points_input, **kwargs) + output = output.permute(1, 0, 2) + + if reg_branches is not None: + tmp = reg_branches[lid](output) + if reference_points.shape[-1] == 4: + new_reference_points = tmp + inverse_sigmoid(reference_points) + new_reference_points = new_reference_points.sigmoid() + else: + assert reference_points.shape[-1] == 2 + new_reference_points = tmp + new_reference_points[..., :2] = tmp[..., :2] + inverse_sigmoid(reference_points) + new_reference_points = new_reference_points.sigmoid() + reference_points = new_reference_points.detach() + + output = output.permute(1, 0, 2) + if self.return_intermediate: + intermediate.append(output) + intermediate_reference_points.append(reference_points) + + if self.return_intermediate: + return torch.stack(intermediate), torch.stack(intermediate_reference_points) + + return output, reference_points + + +@TRANSFORMER.register_module() +class DeformableDetrTransformer(Transformer): + """Implements the DeformableDETR transformer. + + Args: + as_two_stage (bool): Generate query from encoder features. + Default: False. + num_feature_levels (int): Number of feature maps from FPN: + Default: 4. + two_stage_num_proposals (int): Number of proposals when set + `as_two_stage` as True. Default: 300. + """ + + def __init__(self, as_two_stage=False, num_feature_levels=4, two_stage_num_proposals=300, **kwargs): + super(DeformableDetrTransformer, self).__init__(**kwargs) + self.as_two_stage = as_two_stage + self.num_feature_levels = num_feature_levels + self.two_stage_num_proposals = two_stage_num_proposals + self.embed_dims = self.encoder.embed_dims + self.init_layers() + + def init_layers(self): + """Initialize layers of the DeformableDetrTransformer.""" + self.level_embeds = nn.Parameter(torch.Tensor(self.num_feature_levels, self.embed_dims)) + + if self.as_two_stage: + self.enc_output = nn.Linear(self.embed_dims, self.embed_dims) + self.enc_output_norm = nn.LayerNorm(self.embed_dims) + self.pos_trans = nn.Linear(self.embed_dims * 2, self.embed_dims * 2) + self.pos_trans_norm = nn.LayerNorm(self.embed_dims * 2) + else: + self.reference_points = nn.Linear(self.embed_dims, 2) + + def init_weights(self): + """Initialize the transformer weights.""" + for p in self.parameters(): + if p.dim() > 1: + nn.init.xavier_uniform_(p) + for m in self.modules(): + if isinstance(m, MultiScaleDeformableAttention): + m.init_weights() + if not self.as_two_stage: + xavier_init(self.reference_points, distribution="uniform", bias=0.0) + normal_(self.level_embeds) + + def gen_encoder_output_proposals(self, memory, memory_padding_mask, spatial_shapes): + """Generate proposals from encoded memory. + + Args: + memory (Tensor) : The output of encoder, + has shape (bs, num_key, embed_dim). num_key is + equal the number of points on feature map from + all level. + memory_padding_mask (Tensor): Padding mask for memory. + has shape (bs, num_key). + spatial_shapes (Tensor): The shape of all feature maps. + has shape (num_level, 2). + + Returns: + tuple: A tuple of feature map and bbox prediction. + + - output_memory (Tensor): The input of decoder, \ + has shape (bs, num_key, embed_dim). num_key is \ + equal the number of points on feature map from \ + all levels. + - output_proposals (Tensor): The normalized proposal \ + after a inverse sigmoid, has shape \ + (bs, num_keys, 4). + """ + + N, S, C = memory.shape + proposals = [] + _cur = 0 + for lvl, (H, W) in enumerate(spatial_shapes): + mask_flatten_ = memory_padding_mask[:, _cur : (_cur + H * W)].view(N, H, W, 1) + valid_H = torch.sum(~mask_flatten_[:, :, 0, 0], 1) + valid_W = torch.sum(~mask_flatten_[:, 0, :, 0], 1) + + grid_y, grid_x = torch.meshgrid( + torch.linspace(0, H - 1, H, dtype=torch.float32, device=memory.device), + torch.linspace(0, W - 1, W, dtype=torch.float32, device=memory.device), + ) + grid = torch.cat([grid_x.unsqueeze(-1), grid_y.unsqueeze(-1)], -1) + + scale = torch.cat([valid_W.unsqueeze(-1), valid_H.unsqueeze(-1)], 1).view(N, 1, 1, 2) + grid = (grid.unsqueeze(0).expand(N, -1, -1, -1) + 0.5) / scale + wh = torch.ones_like(grid) * 0.05 * (2.0**lvl) + proposal = torch.cat((grid, wh), -1).view(N, -1, 4) + proposals.append(proposal) + _cur += H * W + output_proposals = torch.cat(proposals, 1) + output_proposals_valid = ((output_proposals > 0.01) & (output_proposals < 0.99)).all(-1, keepdim=True) + output_proposals = torch.log(output_proposals / (1 - output_proposals)) + output_proposals = output_proposals.masked_fill(memory_padding_mask.unsqueeze(-1), float("inf")) + output_proposals = output_proposals.masked_fill(~output_proposals_valid, float("inf")) + + output_memory = memory + output_memory = output_memory.masked_fill(memory_padding_mask.unsqueeze(-1), float(0)) + output_memory = output_memory.masked_fill(~output_proposals_valid, float(0)) + output_memory = self.enc_output_norm(self.enc_output(output_memory)) + return output_memory, output_proposals + + @staticmethod + def get_reference_points(spatial_shapes, valid_ratios, device): + """Get the reference points used in decoder. + + Args: + spatial_shapes (Tensor): The shape of all + feature maps, has shape (num_level, 2). + valid_ratios (Tensor): The radios of valid + points on the feature map, has shape + (bs, num_levels, 2) + device (obj:`device`): The device where + reference_points should be. + + Returns: + Tensor: reference points used in decoder, has \ + shape (bs, num_keys, num_levels, 2). + """ + reference_points_list = [] + for lvl, (H, W) in enumerate(spatial_shapes): + ref_y, ref_x = torch.meshgrid( + torch.linspace(0.5, H - 0.5, H, dtype=torch.float32, device=device), + torch.linspace(0.5, W - 0.5, W, dtype=torch.float32, device=device), + ) + ref_y = ref_y.reshape(-1)[None] / (valid_ratios[:, None, lvl, 1] * H) + ref_x = ref_x.reshape(-1)[None] / (valid_ratios[:, None, lvl, 0] * W) + ref = torch.stack((ref_x, ref_y), -1) + reference_points_list.append(ref) + reference_points = torch.cat(reference_points_list, 1) + reference_points = reference_points[:, :, None] * valid_ratios[:, None] + return reference_points + + def get_valid_ratio(self, mask): + """Get the valid radios of feature maps of all level.""" + _, H, W = mask.shape + valid_H = torch.sum(~mask[:, :, 0], 1) + valid_W = torch.sum(~mask[:, 0, :], 1) + valid_ratio_h = valid_H.float() / H + valid_ratio_w = valid_W.float() / W + valid_ratio = torch.stack([valid_ratio_w, valid_ratio_h], -1) + return valid_ratio + + def get_proposal_pos_embed(self, proposals, num_pos_feats=128, temperature=10000): + """Get the position embedding of proposal.""" + scale = 2 * math.pi + dim_t = torch.arange(num_pos_feats, dtype=torch.float32, device=proposals.device) + dim_t = temperature ** (2 * (dim_t // 2) / num_pos_feats) + # N, L, 4 + proposals = proposals.sigmoid() * scale + # N, L, 4, 128 + pos = proposals[:, :, :, None] / dim_t + # N, L, 4, 64, 2 + pos = torch.stack((pos[:, :, :, 0::2].sin(), pos[:, :, :, 1::2].cos()), dim=4).flatten(2) + return pos + + def forward( + self, mlvl_feats, mlvl_masks, query_embed, mlvl_pos_embeds, reg_branches=None, cls_branches=None, **kwargs + ): + """Forward function for `Transformer`. + + Args: + mlvl_feats (list(Tensor)): Input queries from + different level. Each element has shape + [bs, embed_dims, h, w]. + mlvl_masks (list(Tensor)): The key_padding_mask from + different level used for encoder and decoder, + each element has shape [bs, h, w]. + query_embed (Tensor): The query embedding for decoder, + with shape [num_query, c]. + mlvl_pos_embeds (list(Tensor)): The positional encoding + of feats from different level, has the shape + [bs, embed_dims, h, w]. + reg_branches (obj:`nn.ModuleList`): Regression heads for + feature maps from each decoder layer. Only would + be passed when + `with_box_refine` is True. Default to None. + cls_branches (obj:`nn.ModuleList`): Classification heads + for feature maps from each decoder layer. Only would + be passed when `as_two_stage` + is True. Default to None. + + + Returns: + tuple[Tensor]: results of decoder containing the following tensor. + + - inter_states: Outputs from decoder. If + return_intermediate_dec is True output has shape \ + (num_dec_layers, bs, num_query, embed_dims), else has \ + shape (1, bs, num_query, embed_dims). + - init_reference_out: The initial value of reference \ + points, has shape (bs, num_queries, 4). + - inter_references_out: The internal value of reference \ + points in decoder, has shape \ + (num_dec_layers, bs,num_query, embed_dims) + - enc_outputs_class: The classification score of \ + proposals generated from \ + encoder's feature maps, has shape \ + (batch, h*w, num_classes). \ + Only would be returned when `as_two_stage` is True, \ + otherwise None. + - enc_outputs_coord_unact: The regression results \ + generated from encoder's feature maps., has shape \ + (batch, h*w, 4). Only would \ + be returned when `as_two_stage` is True, \ + otherwise None. + """ + assert self.as_two_stage or query_embed is not None + + feat_flatten = [] + mask_flatten = [] + lvl_pos_embed_flatten = [] + spatial_shapes = [] + for lvl, (feat, mask, pos_embed) in enumerate(zip(mlvl_feats, mlvl_masks, mlvl_pos_embeds)): + bs, c, h, w = feat.shape + spatial_shape = (h, w) + spatial_shapes.append(spatial_shape) + feat = feat.flatten(2).transpose(1, 2) + mask = mask.flatten(1) + pos_embed = pos_embed.flatten(2).transpose(1, 2) + lvl_pos_embed = pos_embed + self.level_embeds[lvl].view(1, 1, -1) + lvl_pos_embed_flatten.append(lvl_pos_embed) + feat_flatten.append(feat) + mask_flatten.append(mask) + feat_flatten = torch.cat(feat_flatten, 1) + mask_flatten = torch.cat(mask_flatten, 1) + lvl_pos_embed_flatten = torch.cat(lvl_pos_embed_flatten, 1) + spatial_shapes = torch.as_tensor(spatial_shapes, dtype=torch.long, device=feat_flatten.device) + level_start_index = torch.cat((spatial_shapes.new_zeros((1,)), spatial_shapes.prod(1).cumsum(0)[:-1])) + valid_ratios = torch.stack([self.get_valid_ratio(m) for m in mlvl_masks], 1) + + reference_points = self.get_reference_points(spatial_shapes, valid_ratios, device=feat.device) + + feat_flatten = feat_flatten.permute(1, 0, 2) # (H*W, bs, embed_dims) + lvl_pos_embed_flatten = lvl_pos_embed_flatten.permute(1, 0, 2) # (H*W, bs, embed_dims) + memory = self.encoder( + query=feat_flatten, + key=None, + value=None, + query_pos=lvl_pos_embed_flatten, + query_key_padding_mask=mask_flatten, + spatial_shapes=spatial_shapes, + reference_points=reference_points, + level_start_index=level_start_index, + valid_ratios=valid_ratios, + **kwargs, + ) + + memory = memory.permute(1, 0, 2) + bs, _, c = memory.shape + if self.as_two_stage: + output_memory, output_proposals = self.gen_encoder_output_proposals(memory, mask_flatten, spatial_shapes) + enc_outputs_class = cls_branches[self.decoder.num_layers](output_memory) + enc_outputs_coord_unact = reg_branches[self.decoder.num_layers](output_memory) + output_proposals + + topk = self.two_stage_num_proposals + topk_proposals = torch.topk(enc_outputs_class[..., 0], topk, dim=1)[1] + topk_coords_unact = torch.gather(enc_outputs_coord_unact, 1, topk_proposals.unsqueeze(-1).repeat(1, 1, 4)) + topk_coords_unact = topk_coords_unact.detach() + reference_points = topk_coords_unact.sigmoid() + init_reference_out = reference_points + pos_trans_out = self.pos_trans_norm(self.pos_trans(self.get_proposal_pos_embed(topk_coords_unact))) + query_pos, query = torch.split(pos_trans_out, c, dim=2) + else: + query_pos, query = torch.split(query_embed, c, dim=1) + query_pos = query_pos.unsqueeze(0).expand(bs, -1, -1) + query = query.unsqueeze(0).expand(bs, -1, -1) + reference_points = self.reference_points(query_pos).sigmoid() + init_reference_out = reference_points + + # decoder + query = query.permute(1, 0, 2) + memory = memory.permute(1, 0, 2) + query_pos = query_pos.permute(1, 0, 2) + inter_states, inter_references = self.decoder( + query=query, + key=None, + value=memory, + query_pos=query_pos, + key_padding_mask=mask_flatten, + reference_points=reference_points, + spatial_shapes=spatial_shapes, + level_start_index=level_start_index, + valid_ratios=valid_ratios, + reg_branches=reg_branches, + **kwargs, + ) + + inter_references_out = inter_references + if self.as_two_stage: + return inter_states, init_reference_out, inter_references_out, enc_outputs_class, enc_outputs_coord_unact + return inter_states, init_reference_out, inter_references_out, None, None + + +@TRANSFORMER.register_module() +class DynamicConv(BaseModule): + """Implements Dynamic Convolution. + + This module generate parameters for each sample and + use bmm to implement 1*1 convolution. Code is modified + from the `official github repo `_ . + + Args: + in_channels (int): The input feature channel. + Defaults to 256. + feat_channels (int): The inner feature channel. + Defaults to 64. + out_channels (int, optional): The output feature channel. + When not specified, it will be set to `in_channels` + by default + input_feat_shape (int): The shape of input feature. + Defaults to 7. + with_proj (bool): Project two-dimentional feature to + one-dimentional feature. Default to True. + act_cfg (dict): The activation config for DynamicConv. + norm_cfg (dict): Config dict for normalization layer. Default + layer normalization. + init_cfg (obj:`mmcv.ConfigDict`): The Config for initialization. + Default: None. + """ + + def __init__( + self, + in_channels=256, + feat_channels=64, + out_channels=None, + input_feat_shape=7, + with_proj=True, + act_cfg=dict(type="ReLU", inplace=True), + norm_cfg=dict(type="LN"), + init_cfg=None, + ): + super(DynamicConv, self).__init__(init_cfg) + self.in_channels = in_channels + self.feat_channels = feat_channels + self.out_channels_raw = out_channels + self.input_feat_shape = input_feat_shape + self.with_proj = with_proj + self.act_cfg = act_cfg + self.norm_cfg = norm_cfg + self.out_channels = out_channels if out_channels else in_channels + + self.num_params_in = self.in_channels * self.feat_channels + self.num_params_out = self.out_channels * self.feat_channels + self.dynamic_layer = nn.Linear(self.in_channels, self.num_params_in + self.num_params_out) + + self.norm_in = build_norm_layer(norm_cfg, self.feat_channels)[1] + self.norm_out = build_norm_layer(norm_cfg, self.out_channels)[1] + + self.activation = build_activation_layer(act_cfg) + + num_output = self.out_channels * input_feat_shape**2 + if self.with_proj: + self.fc_layer = nn.Linear(num_output, self.out_channels) + self.fc_norm = build_norm_layer(norm_cfg, self.out_channels)[1] + + def forward(self, param_feature, input_feature): + """Forward function for `DynamicConv`. + + Args: + param_feature (Tensor): The feature can be used + to generate the parameter, has shape + (num_all_proposals, in_channels). + input_feature (Tensor): Feature that + interact with parameters, has shape + (num_all_proposals, in_channels, H, W). + + Returns: + Tensor: The output feature has shape + (num_all_proposals, out_channels). + """ + input_feature = input_feature.flatten(2).permute(2, 0, 1) + + input_feature = input_feature.permute(1, 0, 2) + parameters = self.dynamic_layer(param_feature) + + param_in = parameters[:, : self.num_params_in].view(-1, self.in_channels, self.feat_channels) + param_out = parameters[:, -self.num_params_out :].view(-1, self.feat_channels, self.out_channels) + + # input_feature has shape (num_all_proposals, H*W, in_channels) + # param_in has shape (num_all_proposals, in_channels, feat_channels) + # feature has shape (num_all_proposals, H*W, feat_channels) + features = torch.bmm(input_feature, param_in) + features = self.norm_in(features) + features = self.activation(features) + + # param_out has shape (batch_size, feat_channels, out_channels) + features = torch.bmm(features, param_out) + features = self.norm_out(features) + features = self.activation(features) + + if self.with_proj: + features = features.flatten(1) + features = self.fc_layer(features) + features = self.fc_norm(features) + features = self.activation(features) + + return features diff --git a/dinov2/dinov2/eval/segmentation_m2f/ops/modules/__init__.py b/dinov2/dinov2/eval/segmentation_m2f/ops/modules/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..49aa8fe612fd4c088e294707c5ee16bd1cb5b5e7 --- /dev/null +++ b/dinov2/dinov2/eval/segmentation_m2f/ops/modules/__init__.py @@ -0,0 +1,10 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +# References: +# https://github.com/fundamentalvision/Deformable-DETR/tree/main/models/ops/modules +# https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 + +from .ms_deform_attn import MSDeformAttn diff --git a/dinov2/dinov2/eval/segmentation_m2f/ops/modules/ms_deform_attn.py b/dinov2/dinov2/eval/segmentation_m2f/ops/modules/ms_deform_attn.py new file mode 100644 index 0000000000000000000000000000000000000000..d8b4fa23712e87d1a2682b57e71ee37fe8524cff --- /dev/null +++ b/dinov2/dinov2/eval/segmentation_m2f/ops/modules/ms_deform_attn.py @@ -0,0 +1,185 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +import math +import warnings + +import torch +import torch.nn.functional as F +from torch import nn +from torch.autograd import Function +from torch.cuda.amp import custom_fwd +from torch.nn.init import constant_, xavier_uniform_ + + +class MSDeformAttnFunction(Function): + @staticmethod + @custom_fwd(cast_inputs=torch.float32) + def forward( + ctx, value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights, im2col_step + ): + output = ms_deform_attn_core_pytorch( + value, + value_spatial_shapes, + # value_level_start_index, + sampling_locations, + attention_weights, + ) + return output + + +def ms_deform_attn_core_pytorch(value, value_spatial_shapes, sampling_locations, attention_weights): + # for debug and test only, + # need to use cuda version instead + N_, S_, M_, D_ = value.shape + _, Lq_, M_, L_, P_, _ = sampling_locations.shape + value_list = value.split([H_ * W_ for H_, W_ in value_spatial_shapes], dim=1) + sampling_grids = 2 * sampling_locations - 1 + sampling_value_list = [] + for lid_, (H_, W_) in enumerate(value_spatial_shapes): + # N_, H_*W_, M_, D_ -> N_, H_*W_, M_*D_ -> N_, M_*D_, H_*W_ -> N_*M_, D_, H_, W_ + value_l_ = value_list[lid_].flatten(2).transpose(1, 2).reshape(N_ * M_, D_, H_, W_) + # N_, Lq_, M_, P_, 2 -> N_, M_, Lq_, P_, 2 -> N_*M_, Lq_, P_, 2 + sampling_grid_l_ = sampling_grids[:, :, :, lid_].transpose(1, 2).flatten(0, 1) + # N_*M_, D_, Lq_, P_ + sampling_value_l_ = F.grid_sample( + value_l_, sampling_grid_l_, mode="bilinear", padding_mode="zeros", align_corners=False + ) + sampling_value_list.append(sampling_value_l_) + # (N_, Lq_, M_, L_, P_) -> (N_, M_, Lq_, L_, P_) -> (N_, M_, 1, Lq_, L_*P_) + attention_weights = attention_weights.transpose(1, 2).reshape(N_ * M_, 1, Lq_, L_ * P_) + output = (torch.stack(sampling_value_list, dim=-2).flatten(-2) * attention_weights).sum(-1).view(N_, M_ * D_, Lq_) + return output.transpose(1, 2).contiguous() + + +def _is_power_of_2(n): + if (not isinstance(n, int)) or (n < 0): + raise ValueError("invalid input for _is_power_of_2: {} (type: {})".format(n, type(n))) + return (n & (n - 1) == 0) and n != 0 + + +class MSDeformAttn(nn.Module): + def __init__(self, d_model=256, n_levels=4, n_heads=8, n_points=4, ratio=1.0): + """Multi-Scale Deformable Attention Module. + + :param d_model hidden dimension + :param n_levels number of feature levels + :param n_heads number of attention heads + :param n_points number of sampling points per attention head per feature level + """ + super().__init__() + if d_model % n_heads != 0: + raise ValueError("d_model must be divisible by n_heads, " "but got {} and {}".format(d_model, n_heads)) + _d_per_head = d_model // n_heads + # you'd better set _d_per_head to a power of 2 + # which is more efficient in our CUDA implementation + if not _is_power_of_2(_d_per_head): + warnings.warn( + "You'd better set d_model in MSDeformAttn to make " + "the dimension of each attention head a power of 2 " + "which is more efficient in our CUDA implementation." + ) + + self.im2col_step = 64 + + self.d_model = d_model + self.n_levels = n_levels + self.n_heads = n_heads + self.n_points = n_points + self.ratio = ratio + self.sampling_offsets = nn.Linear(d_model, n_heads * n_levels * n_points * 2) + self.attention_weights = nn.Linear(d_model, n_heads * n_levels * n_points) + self.value_proj = nn.Linear(d_model, int(d_model * ratio)) + self.output_proj = nn.Linear(int(d_model * ratio), d_model) + + self._reset_parameters() + + def _reset_parameters(self): + constant_(self.sampling_offsets.weight.data, 0.0) + thetas = torch.arange(self.n_heads, dtype=torch.float32) * (2.0 * math.pi / self.n_heads) + grid_init = torch.stack([thetas.cos(), thetas.sin()], -1) + grid_init = ( + (grid_init / grid_init.abs().max(-1, keepdim=True)[0]) + .view(self.n_heads, 1, 1, 2) + .repeat(1, self.n_levels, self.n_points, 1) + ) + for i in range(self.n_points): + grid_init[:, :, i, :] *= i + 1 + + with torch.no_grad(): + self.sampling_offsets.bias = nn.Parameter(grid_init.view(-1)) + constant_(self.attention_weights.weight.data, 0.0) + constant_(self.attention_weights.bias.data, 0.0) + xavier_uniform_(self.value_proj.weight.data) + constant_(self.value_proj.bias.data, 0.0) + xavier_uniform_(self.output_proj.weight.data) + constant_(self.output_proj.bias.data, 0.0) + + def forward( + self, + query, + reference_points, + input_flatten, + input_spatial_shapes, + input_level_start_index, + input_padding_mask=None, + ): + """ + :param query (N, Length_{query}, C) + :param reference_points (N, Length_{query}, n_levels, 2), range in [0, 1], top-left (0,0), bottom-right (1, 1), including padding area + or (N, Length_{query}, n_levels, 4), add additional (w, h) to form reference boxes + :param input_flatten (N, \\sum_{l=0}^{L-1} H_l \\cdot W_l, C) + :param input_spatial_shapes (n_levels, 2), [(H_0, W_0), (H_1, W_1), ..., (H_{L-1}, W_{L-1})] + :param input_level_start_index (n_levels, ), [0, H_0*W_0, H_0*W_0+H_1*W_1, H_0*W_0+H_1*W_1+H_2*W_2, ..., H_0*W_0+H_1*W_1+...+H_{L-1}*W_{L-1}] + :param input_padding_mask (N, \\sum_{l=0}^{L-1} H_l \\cdot W_l), True for padding elements, False for non-padding elements + + :return output (N, Length_{query}, C) + """ + # print(query.shape) + # print(reference_points.shape) + # print(input_flatten.shape) + # print(input_spatial_shapes.shape) + # print(input_level_start_index.shape) + # print(input_spatial_shapes) + # print(input_level_start_index) + + N, Len_q, _ = query.shape + N, Len_in, _ = input_flatten.shape + assert (input_spatial_shapes[:, 0] * input_spatial_shapes[:, 1]).sum() == Len_in + + value = self.value_proj(input_flatten) + if input_padding_mask is not None: + value = value.masked_fill(input_padding_mask[..., None], float(0)) + + value = value.view(N, Len_in, self.n_heads, int(self.ratio * self.d_model) // self.n_heads) + sampling_offsets = self.sampling_offsets(query).view(N, Len_q, self.n_heads, self.n_levels, self.n_points, 2) + attention_weights = self.attention_weights(query).view(N, Len_q, self.n_heads, self.n_levels * self.n_points) + attention_weights = F.softmax(attention_weights, -1).view(N, Len_q, self.n_heads, self.n_levels, self.n_points) + + if reference_points.shape[-1] == 2: + offset_normalizer = torch.stack([input_spatial_shapes[..., 1], input_spatial_shapes[..., 0]], -1) + sampling_locations = ( + reference_points[:, :, None, :, None, :] + + sampling_offsets / offset_normalizer[None, None, None, :, None, :] + ) + elif reference_points.shape[-1] == 4: + sampling_locations = ( + reference_points[:, :, None, :, None, :2] + + sampling_offsets / self.n_points * reference_points[:, :, None, :, None, 2:] * 0.5 + ) + else: + raise ValueError( + "Last dim of reference_points must be 2 or 4, but get {} instead.".format(reference_points.shape[-1]) + ) + output = MSDeformAttnFunction.apply( + value, + input_spatial_shapes, + input_level_start_index, + sampling_locations, + attention_weights, + self.im2col_step, + ) + output = self.output_proj(output) + return output diff --git a/dinov2/dinov2/eval/setup.py b/dinov2/dinov2/eval/setup.py new file mode 100644 index 0000000000000000000000000000000000000000..959128c0673cc51036dbf17dcc4ee68a037988fb --- /dev/null +++ b/dinov2/dinov2/eval/setup.py @@ -0,0 +1,75 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +import argparse +from typing import Any, List, Optional, Tuple + +import torch +import torch.backends.cudnn as cudnn + +from dinov2.models import build_model_from_cfg +from dinov2.utils.config import setup +import dinov2.utils.utils as dinov2_utils + + +def get_args_parser( + description: Optional[str] = None, + parents: Optional[List[argparse.ArgumentParser]] = None, + add_help: bool = True, +): + parser = argparse.ArgumentParser( + description=description, + parents=parents or [], + add_help=add_help, + ) + parser.add_argument( + "--config-file", + type=str, + help="Model configuration file", + ) + parser.add_argument( + "--pretrained-weights", + type=str, + help="Pretrained model weights", + ) + parser.add_argument( + "--output-dir", + default="", + type=str, + help="Output directory to write results and logs", + ) + parser.add_argument( + "--opts", + help="Extra configuration options", + default=[], + nargs="+", + ) + return parser + + +def get_autocast_dtype(config): + teacher_dtype_str = config.compute_precision.teacher.backbone.mixed_precision.param_dtype + if teacher_dtype_str == "fp16": + return torch.half + elif teacher_dtype_str == "bf16": + return torch.bfloat16 + else: + return torch.float + + +def build_model_for_eval(config, pretrained_weights): + model, _ = build_model_from_cfg(config, only_teacher=True) + dinov2_utils.load_pretrained_weights(model, pretrained_weights, "teacher") + model.eval() + model.cuda() + return model + + +def setup_and_build_model(args) -> Tuple[Any, torch.dtype]: + cudnn.benchmark = True + config = setup(args) + model = build_model_for_eval(config, args.pretrained_weights) + autocast_dtype = get_autocast_dtype(config) + return model, autocast_dtype diff --git a/dinov2/dinov2/eval/utils.py b/dinov2/dinov2/eval/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..c50576b1940587ee64b7a422e2e96b475d60fd39 --- /dev/null +++ b/dinov2/dinov2/eval/utils.py @@ -0,0 +1,146 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +import logging +from typing import Dict, Optional + +import torch +from torch import nn +from torchmetrics import MetricCollection + +from dinov2.data import DatasetWithEnumeratedTargets, SamplerType, make_data_loader +import dinov2.distributed as distributed +from dinov2.logging import MetricLogger + + +logger = logging.getLogger("dinov2") + + +class ModelWithNormalize(torch.nn.Module): + def __init__(self, model): + super().__init__() + self.model = model + + def forward(self, samples): + return nn.functional.normalize(self.model(samples), dim=1, p=2) + + +class ModelWithIntermediateLayers(nn.Module): + def __init__(self, feature_model, n_last_blocks, autocast_ctx): + super().__init__() + self.feature_model = feature_model + self.feature_model.eval() + self.n_last_blocks = n_last_blocks + self.autocast_ctx = autocast_ctx + + def forward(self, images): + with torch.inference_mode(): + with self.autocast_ctx(): + features = self.feature_model.get_intermediate_layers( + images, self.n_last_blocks, return_class_token=True + ) + return features + + +@torch.inference_mode() +def evaluate( + model: nn.Module, + data_loader, + postprocessors: Dict[str, nn.Module], + metrics: Dict[str, MetricCollection], + device: torch.device, + criterion: Optional[nn.Module] = None, +): + model.eval() + if criterion is not None: + criterion.eval() + + for metric in metrics.values(): + metric = metric.to(device) + + metric_logger = MetricLogger(delimiter=" ") + header = "Test:" + + for samples, targets, *_ in metric_logger.log_every(data_loader, 10, header): + outputs = model(samples.to(device)) + targets = targets.to(device) + + if criterion is not None: + loss = criterion(outputs, targets) + metric_logger.update(loss=loss.item()) + + for k, metric in metrics.items(): + metric_inputs = postprocessors[k](outputs, targets) + metric.update(**metric_inputs) + + metric_logger.synchronize_between_processes() + logger.info(f"Averaged stats: {metric_logger}") + + stats = {k: metric.compute() for k, metric in metrics.items()} + metric_logger_stats = {k: meter.global_avg for k, meter in metric_logger.meters.items()} + return metric_logger_stats, stats + + +def all_gather_and_flatten(tensor_rank): + tensor_all_ranks = torch.empty( + distributed.get_global_size(), + *tensor_rank.shape, + dtype=tensor_rank.dtype, + device=tensor_rank.device, + ) + tensor_list = list(tensor_all_ranks.unbind(0)) + torch.distributed.all_gather(tensor_list, tensor_rank.contiguous()) + return tensor_all_ranks.flatten(end_dim=1) + + +def extract_features(model, dataset, batch_size, num_workers, gather_on_cpu=False): + dataset_with_enumerated_targets = DatasetWithEnumeratedTargets(dataset) + sample_count = len(dataset_with_enumerated_targets) + data_loader = make_data_loader( + dataset=dataset_with_enumerated_targets, + batch_size=batch_size, + num_workers=num_workers, + sampler_type=SamplerType.DISTRIBUTED, + drop_last=False, + shuffle=False, + ) + return extract_features_with_dataloader(model, data_loader, sample_count, gather_on_cpu) + + +@torch.inference_mode() +def extract_features_with_dataloader(model, data_loader, sample_count, gather_on_cpu=False): + gather_device = torch.device("cpu") if gather_on_cpu else torch.device("cuda") + metric_logger = MetricLogger(delimiter=" ") + features, all_labels = None, None + for samples, (index, labels_rank) in metric_logger.log_every(data_loader, 10): + samples = samples.cuda(non_blocking=True) + labels_rank = labels_rank.cuda(non_blocking=True) + index = index.cuda(non_blocking=True) + features_rank = model(samples).float() + + # init storage feature matrix + if features is None: + features = torch.zeros(sample_count, features_rank.shape[-1], device=gather_device) + labels_shape = list(labels_rank.shape) + labels_shape[0] = sample_count + all_labels = torch.full(labels_shape, fill_value=-1, device=gather_device) + logger.info(f"Storing features into tensor of shape {features.shape}") + + # share indexes, features and labels between processes + index_all = all_gather_and_flatten(index).to(gather_device) + features_all_ranks = all_gather_and_flatten(features_rank).to(gather_device) + labels_all_ranks = all_gather_and_flatten(labels_rank).to(gather_device) + + # update storage feature matrix + if len(index_all) > 0: + features.index_copy_(0, index_all, features_all_ranks) + all_labels.index_copy_(0, index_all, labels_all_ranks) + + logger.info(f"Features shape: {tuple(features.shape)}") + logger.info(f"Labels shape: {tuple(all_labels.shape)}") + + assert torch.all(all_labels > -1) + + return features, all_labels diff --git a/dinov2/dinov2/fsdp/__init__.py b/dinov2/dinov2/fsdp/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..ed454480e0b76e761d657cc40fd097bd339d15a2 --- /dev/null +++ b/dinov2/dinov2/fsdp/__init__.py @@ -0,0 +1,157 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +import os +from typing import Any + +import torch +import dinov2.distributed as distributed +from functools import partial +from fvcore.common.checkpoint import Checkpointer +from torch.distributed.fsdp import FullyShardedDataParallel as FSDP +from torch.distributed.fsdp import ShardingStrategy +from torch.distributed.fsdp import MixedPrecision +from torch.distributed.fsdp import StateDictType +from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler +from torch.distributed.fsdp.wrap import ModuleWrapPolicy +from torch.distributed.fsdp._runtime_utils import _reshard + + +def get_fsdp_wrapper(model_cfg, modules_to_wrap=set()): + sharding_strategy_dict = { + "NO_SHARD": ShardingStrategy.NO_SHARD, + "SHARD_GRAD_OP": ShardingStrategy.SHARD_GRAD_OP, + "FULL_SHARD": ShardingStrategy.FULL_SHARD, + } + + dtype_dict = { + "fp32": torch.float32, + "fp16": torch.float16, + "bf16": torch.bfloat16, + } + + mixed_precision_config = MixedPrecision( + param_dtype=dtype_dict[model_cfg.mixed_precision.param_dtype], + reduce_dtype=dtype_dict[model_cfg.mixed_precision.reduce_dtype], + buffer_dtype=dtype_dict[model_cfg.mixed_precision.buffer_dtype], + ) + + sharding_strategy_config = sharding_strategy_dict[model_cfg.sharding_strategy] + + local_rank = distributed.get_local_rank() + + fsdp_wrapper = partial( + FSDP, + sharding_strategy=sharding_strategy_config, + mixed_precision=mixed_precision_config, + device_id=local_rank, + sync_module_states=True, + use_orig_params=True, + auto_wrap_policy=ModuleWrapPolicy(modules_to_wrap), + ) + return fsdp_wrapper + + +def is_fsdp(x): + return isinstance(x, FSDP) + + +def is_sharded_fsdp(x): + return is_fsdp(x) and x.sharding_strategy is not ShardingStrategy.NO_SHARD + + +def free_if_fsdp(x): + if is_sharded_fsdp(x): + handles = x._handles + true_list = [True for h in handles] + _reshard(x, handles, true_list) + + +def get_fsdp_modules(x): + return FSDP.fsdp_modules(x) + + +def reshard_fsdp_model(x): + for m in get_fsdp_modules(x): + free_if_fsdp(m) + + +def rankstr(): + return f"rank_{distributed.get_global_rank()}" + + +class FSDPCheckpointer(Checkpointer): + def save(self, name: str, **kwargs: Any) -> None: + """ + Dump model and checkpointables to a file. + + Args: + name (str): name of the file. + kwargs (dict): extra arbitrary data to save. + """ + if not self.save_dir or not self.save_to_disk: + return + + data = {} + with FSDP.state_dict_type(self.model, StateDictType.LOCAL_STATE_DICT): + data["model"] = self.model.state_dict() + + # data["model"] = self.model.state_dict() + for key, obj in self.checkpointables.items(): + data[key] = obj.state_dict() + data.update(kwargs) + + basename = f"{name}.{rankstr()}.pth" + save_file = os.path.join(self.save_dir, basename) + assert os.path.basename(save_file) == basename, basename + self.logger.info("Saving checkpoint to {}".format(save_file)) + with self.path_manager.open(save_file, "wb") as f: + torch.save(data, f) + self.tag_last_checkpoint(basename) + + def load(self, *args, **kwargs): + with FSDP.state_dict_type(self.model, StateDictType.LOCAL_STATE_DICT): + return super().load(*args, **kwargs) + + def has_checkpoint(self) -> bool: + """ + Returns: + bool: whether a checkpoint exists in the target directory. + """ + save_file = os.path.join(self.save_dir, f"last_checkpoint.{rankstr()}") + return self.path_manager.exists(save_file) + + def get_checkpoint_file(self) -> str: + """ + Returns: + str: The latest checkpoint file in target directory. + """ + save_file = os.path.join(self.save_dir, f"last_checkpoint.{rankstr()}") + try: + with self.path_manager.open(save_file, "r") as f: + last_saved = f.read().strip() + except IOError: + # if file doesn't exist, maybe because it has just been + # deleted by a separate process + return "" + # pyre-fixme[6]: For 2nd param expected `Union[PathLike[str], str]` but got + # `Union[bytes, str]`. + return os.path.join(self.save_dir, last_saved) + + def tag_last_checkpoint(self, last_filename_basename: str) -> None: + """ + Tag the last checkpoint. + + Args: + last_filename_basename (str): the basename of the last filename. + """ + if distributed.is_enabled(): + torch.distributed.barrier() + save_file = os.path.join(self.save_dir, f"last_checkpoint.{rankstr()}") + with self.path_manager.open(save_file, "w") as f: + f.write(last_filename_basename) # pyre-ignore + + +ShardedGradScaler = ShardedGradScaler diff --git a/dinov2/dinov2/hub/__init__.py b/dinov2/dinov2/hub/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..b88da6bf80be92af00b72dfdb0a806fa64a7a2d9 --- /dev/null +++ b/dinov2/dinov2/hub/__init__.py @@ -0,0 +1,4 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. diff --git a/dinov2/dinov2/hub/backbones.py b/dinov2/dinov2/hub/backbones.py new file mode 100644 index 0000000000000000000000000000000000000000..53fe83719d5107eb77a8f25ef1814c3d73446002 --- /dev/null +++ b/dinov2/dinov2/hub/backbones.py @@ -0,0 +1,156 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +from enum import Enum +from typing import Union + +import torch + +from .utils import _DINOV2_BASE_URL, _make_dinov2_model_name + + +class Weights(Enum): + LVD142M = "LVD142M" + + +def _make_dinov2_model( + *, + arch_name: str = "vit_large", + img_size: int = 518, + patch_size: int = 14, + init_values: float = 1.0, + ffn_layer: str = "mlp", + block_chunks: int = 0, + num_register_tokens: int = 0, + interpolate_antialias: bool = False, + interpolate_offset: float = 0.1, + pretrained: bool = True, + weights: Union[Weights, str] = Weights.LVD142M, + **kwargs, +): + from ..models import vision_transformer as vits + + if isinstance(weights, str): + try: + weights = Weights[weights] + except KeyError: + raise AssertionError(f"Unsupported weights: {weights}") + + model_base_name = _make_dinov2_model_name(arch_name, patch_size) + vit_kwargs = dict( + img_size=img_size, + patch_size=patch_size, + init_values=init_values, + ffn_layer=ffn_layer, + block_chunks=block_chunks, + num_register_tokens=num_register_tokens, + interpolate_antialias=interpolate_antialias, + interpolate_offset=interpolate_offset, + ) + vit_kwargs.update(**kwargs) + model = vits.__dict__[arch_name](**vit_kwargs) + + if pretrained: + model_full_name = _make_dinov2_model_name(arch_name, patch_size, num_register_tokens) + url = _DINOV2_BASE_URL + f"/{model_base_name}/{model_full_name}_pretrain.pth" + state_dict = torch.hub.load_state_dict_from_url(url, map_location="cpu") + model.load_state_dict(state_dict, strict=True) + + return model + + +def dinov2_vits14(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs): + """ + DINOv2 ViT-S/14 model (optionally) pretrained on the LVD-142M dataset. + """ + return _make_dinov2_model(arch_name="vit_small", pretrained=pretrained, weights=weights, **kwargs) + + +def dinov2_vitb14(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs): + """ + DINOv2 ViT-B/14 model (optionally) pretrained on the LVD-142M dataset. + """ + return _make_dinov2_model(arch_name="vit_base", pretrained=pretrained, weights=weights, **kwargs) + + +def dinov2_vitl14(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs): + """ + DINOv2 ViT-L/14 model (optionally) pretrained on the LVD-142M dataset. + """ + return _make_dinov2_model(arch_name="vit_large", pretrained=pretrained, weights=weights, **kwargs) + + +def dinov2_vitg14(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs): + """ + DINOv2 ViT-g/14 model (optionally) pretrained on the LVD-142M dataset. + """ + return _make_dinov2_model( + arch_name="vit_giant2", + ffn_layer="swiglufused", + weights=weights, + pretrained=pretrained, + **kwargs, + ) + + +def dinov2_vits14_reg(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs): + """ + DINOv2 ViT-S/14 model with registers (optionally) pretrained on the LVD-142M dataset. + """ + return _make_dinov2_model( + arch_name="vit_small", + pretrained=pretrained, + weights=weights, + num_register_tokens=4, + interpolate_antialias=True, + interpolate_offset=0.0, + **kwargs, + ) + + +def dinov2_vitb14_reg(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs): + """ + DINOv2 ViT-B/14 model with registers (optionally) pretrained on the LVD-142M dataset. + """ + return _make_dinov2_model( + arch_name="vit_base", + pretrained=pretrained, + weights=weights, + num_register_tokens=4, + interpolate_antialias=True, + interpolate_offset=0.0, + **kwargs, + ) + + +def dinov2_vitl14_reg(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs): + """ + DINOv2 ViT-L/14 model with registers (optionally) pretrained on the LVD-142M dataset. + """ + return _make_dinov2_model( + arch_name="vit_large", + pretrained=pretrained, + weights=weights, + num_register_tokens=4, + interpolate_antialias=True, + interpolate_offset=0.0, + **kwargs, + ) + + +def dinov2_vitg14_reg(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs): + """ + DINOv2 ViT-g/14 model with registers (optionally) pretrained on the LVD-142M dataset. + """ + return _make_dinov2_model( + arch_name="vit_giant2", + ffn_layer="swiglufused", + weights=weights, + pretrained=pretrained, + num_register_tokens=4, + interpolate_antialias=True, + interpolate_offset=0.0, + **kwargs, + ) diff --git a/dinov2/dinov2/hub/classifiers.py b/dinov2/dinov2/hub/classifiers.py new file mode 100644 index 0000000000000000000000000000000000000000..3f0841efa80ab3d564cd320d61da254af182606b --- /dev/null +++ b/dinov2/dinov2/hub/classifiers.py @@ -0,0 +1,268 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +from enum import Enum +from typing import Union + +import torch +import torch.nn as nn + +from .backbones import _make_dinov2_model +from .utils import _DINOV2_BASE_URL, _make_dinov2_model_name + + +class Weights(Enum): + IMAGENET1K = "IMAGENET1K" + + +def _make_dinov2_linear_classification_head( + *, + arch_name: str = "vit_large", + patch_size: int = 14, + embed_dim: int = 1024, + layers: int = 4, + pretrained: bool = True, + weights: Union[Weights, str] = Weights.IMAGENET1K, + num_register_tokens: int = 0, + **kwargs, +): + if layers not in (1, 4): + raise AssertionError(f"Unsupported number of layers: {layers}") + if isinstance(weights, str): + try: + weights = Weights[weights] + except KeyError: + raise AssertionError(f"Unsupported weights: {weights}") + + linear_head = nn.Linear((1 + layers) * embed_dim, 1_000) + + if pretrained: + model_base_name = _make_dinov2_model_name(arch_name, patch_size) + model_full_name = _make_dinov2_model_name(arch_name, patch_size, num_register_tokens) + layers_str = str(layers) if layers == 4 else "" + url = _DINOV2_BASE_URL + f"/{model_base_name}/{model_full_name}_linear{layers_str}_head.pth" + state_dict = torch.hub.load_state_dict_from_url(url, map_location="cpu") + linear_head.load_state_dict(state_dict, strict=True) + + return linear_head + + +class _LinearClassifierWrapper(nn.Module): + def __init__(self, *, backbone: nn.Module, linear_head: nn.Module, layers: int = 4): + super().__init__() + self.backbone = backbone + self.linear_head = linear_head + self.layers = layers + + def forward(self, x): + if self.layers == 1: + x = self.backbone.forward_features(x) + cls_token = x["x_norm_clstoken"] + patch_tokens = x["x_norm_patchtokens"] + # fmt: off + linear_input = torch.cat([ + cls_token, + patch_tokens.mean(dim=1), + ], dim=1) + # fmt: on + elif self.layers == 4: + x = self.backbone.get_intermediate_layers(x, n=4, return_class_token=True) + # fmt: off + linear_input = torch.cat([ + x[0][1], + x[1][1], + x[2][1], + x[3][1], + x[3][0].mean(dim=1), + ], dim=1) + # fmt: on + else: + assert False, f"Unsupported number of layers: {self.layers}" + return self.linear_head(linear_input) + + +def _make_dinov2_linear_classifier( + *, + arch_name: str = "vit_large", + layers: int = 4, + pretrained: bool = True, + weights: Union[Weights, str] = Weights.IMAGENET1K, + num_register_tokens: int = 0, + interpolate_antialias: bool = False, + interpolate_offset: float = 0.1, + **kwargs, +): + backbone = _make_dinov2_model( + arch_name=arch_name, + pretrained=pretrained, + num_register_tokens=num_register_tokens, + interpolate_antialias=interpolate_antialias, + interpolate_offset=interpolate_offset, + **kwargs, + ) + + embed_dim = backbone.embed_dim + patch_size = backbone.patch_size + linear_head = _make_dinov2_linear_classification_head( + arch_name=arch_name, + patch_size=patch_size, + embed_dim=embed_dim, + layers=layers, + pretrained=pretrained, + weights=weights, + num_register_tokens=num_register_tokens, + ) + + return _LinearClassifierWrapper(backbone=backbone, linear_head=linear_head, layers=layers) + + +def dinov2_vits14_lc( + *, + layers: int = 4, + pretrained: bool = True, + weights: Union[Weights, str] = Weights.IMAGENET1K, + **kwargs, +): + """ + Linear classifier (1 or 4 layers) on top of a DINOv2 ViT-S/14 backbone (optionally) pretrained on the LVD-142M dataset and trained on ImageNet-1k. + """ + return _make_dinov2_linear_classifier( + arch_name="vit_small", + layers=layers, + pretrained=pretrained, + weights=weights, + **kwargs, + ) + + +def dinov2_vitb14_lc( + *, + layers: int = 4, + pretrained: bool = True, + weights: Union[Weights, str] = Weights.IMAGENET1K, + **kwargs, +): + """ + Linear classifier (1 or 4 layers) on top of a DINOv2 ViT-B/14 backbone (optionally) pretrained on the LVD-142M dataset and trained on ImageNet-1k. + """ + return _make_dinov2_linear_classifier( + arch_name="vit_base", + layers=layers, + pretrained=pretrained, + weights=weights, + **kwargs, + ) + + +def dinov2_vitl14_lc( + *, + layers: int = 4, + pretrained: bool = True, + weights: Union[Weights, str] = Weights.IMAGENET1K, + **kwargs, +): + """ + Linear classifier (1 or 4 layers) on top of a DINOv2 ViT-L/14 backbone (optionally) pretrained on the LVD-142M dataset and trained on ImageNet-1k. + """ + return _make_dinov2_linear_classifier( + arch_name="vit_large", + layers=layers, + pretrained=pretrained, + weights=weights, + **kwargs, + ) + + +def dinov2_vitg14_lc( + *, + layers: int = 4, + pretrained: bool = True, + weights: Union[Weights, str] = Weights.IMAGENET1K, + **kwargs, +): + """ + Linear classifier (1 or 4 layers) on top of a DINOv2 ViT-g/14 backbone (optionally) pretrained on the LVD-142M dataset and trained on ImageNet-1k. + """ + return _make_dinov2_linear_classifier( + arch_name="vit_giant2", + layers=layers, + ffn_layer="swiglufused", + pretrained=pretrained, + weights=weights, + **kwargs, + ) + + +def dinov2_vits14_reg_lc( + *, layers: int = 4, pretrained: bool = True, weights: Union[Weights, str] = Weights.IMAGENET1K, **kwargs +): + """ + Linear classifier (1 or 4 layers) on top of a DINOv2 ViT-S/14 backbone with registers (optionally) pretrained on the LVD-142M dataset and trained on ImageNet-1k. + """ + return _make_dinov2_linear_classifier( + arch_name="vit_small", + layers=layers, + pretrained=pretrained, + weights=weights, + num_register_tokens=4, + interpolate_antialias=True, + interpolate_offset=0.0, + **kwargs, + ) + + +def dinov2_vitb14_reg_lc( + *, layers: int = 4, pretrained: bool = True, weights: Union[Weights, str] = Weights.IMAGENET1K, **kwargs +): + """ + Linear classifier (1 or 4 layers) on top of a DINOv2 ViT-B/14 backbone with registers (optionally) pretrained on the LVD-142M dataset and trained on ImageNet-1k. + """ + return _make_dinov2_linear_classifier( + arch_name="vit_base", + layers=layers, + pretrained=pretrained, + weights=weights, + num_register_tokens=4, + interpolate_antialias=True, + interpolate_offset=0.0, + **kwargs, + ) + + +def dinov2_vitl14_reg_lc( + *, layers: int = 4, pretrained: bool = True, weights: Union[Weights, str] = Weights.IMAGENET1K, **kwargs +): + """ + Linear classifier (1 or 4 layers) on top of a DINOv2 ViT-L/14 backbone with registers (optionally) pretrained on the LVD-142M dataset and trained on ImageNet-1k. + """ + return _make_dinov2_linear_classifier( + arch_name="vit_large", + layers=layers, + pretrained=pretrained, + weights=weights, + num_register_tokens=4, + interpolate_antialias=True, + interpolate_offset=0.0, + **kwargs, + ) + + +def dinov2_vitg14_reg_lc( + *, layers: int = 4, pretrained: bool = True, weights: Union[Weights, str] = Weights.IMAGENET1K, **kwargs +): + """ + Linear classifier (1 or 4 layers) on top of a DINOv2 ViT-g/14 backbone with registers (optionally) pretrained on the LVD-142M dataset and trained on ImageNet-1k. + """ + return _make_dinov2_linear_classifier( + arch_name="vit_giant2", + layers=layers, + ffn_layer="swiglufused", + pretrained=pretrained, + weights=weights, + num_register_tokens=4, + interpolate_antialias=True, + interpolate_offset=0.0, + **kwargs, + ) diff --git a/dinov2/dinov2/hub/depth/__init__.py b/dinov2/dinov2/hub/depth/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..91716e58ab6158d814df8c653644d9af4c7be65c --- /dev/null +++ b/dinov2/dinov2/hub/depth/__init__.py @@ -0,0 +1,7 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +from .decode_heads import BNHead, DPTHead +from .encoder_decoder import DepthEncoderDecoder diff --git a/dinov2/dinov2/hub/depth/decode_heads.py b/dinov2/dinov2/hub/depth/decode_heads.py new file mode 100644 index 0000000000000000000000000000000000000000..f455accad38fec6ecdd53460233a564c34f434da --- /dev/null +++ b/dinov2/dinov2/hub/depth/decode_heads.py @@ -0,0 +1,747 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +import copy +from functools import partial +import math +import warnings + +import torch +import torch.nn as nn + +from .ops import resize + + +# XXX: (Untested) replacement for mmcv.imdenormalize() +def _imdenormalize(img, mean, std, to_bgr=True): + import numpy as np + + mean = mean.reshape(1, -1).astype(np.float64) + std = std.reshape(1, -1).astype(np.float64) + img = (img * std) + mean + if to_bgr: + img = img[::-1] + return img + + +class DepthBaseDecodeHead(nn.Module): + """Base class for BaseDecodeHead. + + Args: + in_channels (List): Input channels. + channels (int): Channels after modules, before conv_depth. + conv_layer (nn.Module): Conv layers. Default: None. + act_layer (nn.Module): Activation layers. Default: nn.ReLU. + loss_decode (dict): Config of decode loss. + Default: (). + sampler (dict|None): The config of depth map sampler. + Default: None. + align_corners (bool): align_corners argument of F.interpolate. + Default: False. + min_depth (int): Min depth in dataset setting. + Default: 1e-3. + max_depth (int): Max depth in dataset setting. + Default: None. + norm_layer (dict|None): Norm layers. + Default: None. + classify (bool): Whether predict depth in a cls.-reg. manner. + Default: False. + n_bins (int): The number of bins used in cls. step. + Default: 256. + bins_strategy (str): The discrete strategy used in cls. step. + Default: 'UD'. + norm_strategy (str): The norm strategy on cls. probability + distribution. Default: 'linear' + scale_up (str): Whether predict depth in a scale-up manner. + Default: False. + """ + + def __init__( + self, + in_channels, + conv_layer=None, + act_layer=nn.ReLU, + channels=96, + loss_decode=(), + sampler=None, + align_corners=False, + min_depth=1e-3, + max_depth=None, + norm_layer=None, + classify=False, + n_bins=256, + bins_strategy="UD", + norm_strategy="linear", + scale_up=False, + ): + super(DepthBaseDecodeHead, self).__init__() + + self.in_channels = in_channels + self.channels = channels + self.conf_layer = conv_layer + self.act_layer = act_layer + self.loss_decode = loss_decode + self.align_corners = align_corners + self.min_depth = min_depth + self.max_depth = max_depth + self.norm_layer = norm_layer + self.classify = classify + self.n_bins = n_bins + self.scale_up = scale_up + + if self.classify: + assert bins_strategy in ["UD", "SID"], "Support bins_strategy: UD, SID" + assert norm_strategy in ["linear", "softmax", "sigmoid"], "Support norm_strategy: linear, softmax, sigmoid" + + self.bins_strategy = bins_strategy + self.norm_strategy = norm_strategy + self.softmax = nn.Softmax(dim=1) + self.conv_depth = nn.Conv2d(channels, n_bins, kernel_size=3, padding=1, stride=1) + else: + self.conv_depth = nn.Conv2d(channels, 1, kernel_size=3, padding=1, stride=1) + + self.relu = nn.ReLU() + self.sigmoid = nn.Sigmoid() + + def forward(self, inputs, img_metas): + """Placeholder of forward function.""" + pass + + def forward_train(self, img, inputs, img_metas, depth_gt): + """Forward function for training. + Args: + inputs (list[Tensor]): List of multi-level img features. + img_metas (list[dict]): List of image info dict where each dict + has: 'img_shape', 'scale_factor', 'flip', and may also contain + 'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'. + For details on the values of these keys see + `depth/datasets/pipelines/formatting.py:Collect`. + depth_gt (Tensor): GT depth + + Returns: + dict[str, Tensor]: a dictionary of loss components + """ + depth_pred = self.forward(inputs, img_metas) + losses = self.losses(depth_pred, depth_gt) + + log_imgs = self.log_images(img[0], depth_pred[0], depth_gt[0], img_metas[0]) + losses.update(**log_imgs) + + return losses + + def forward_test(self, inputs, img_metas): + """Forward function for testing. + Args: + inputs (list[Tensor]): List of multi-level img features. + img_metas (list[dict]): List of image info dict where each dict + has: 'img_shape', 'scale_factor', 'flip', and may also contain + 'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'. + For details on the values of these keys see + `depth/datasets/pipelines/formatting.py:Collect`. + + Returns: + Tensor: Output depth map. + """ + return self.forward(inputs, img_metas) + + def depth_pred(self, feat): + """Prediction each pixel.""" + if self.classify: + logit = self.conv_depth(feat) + + if self.bins_strategy == "UD": + bins = torch.linspace(self.min_depth, self.max_depth, self.n_bins, device=feat.device) + elif self.bins_strategy == "SID": + bins = torch.logspace(self.min_depth, self.max_depth, self.n_bins, device=feat.device) + + # following Adabins, default linear + if self.norm_strategy == "linear": + logit = torch.relu(logit) + eps = 0.1 + logit = logit + eps + logit = logit / logit.sum(dim=1, keepdim=True) + elif self.norm_strategy == "softmax": + logit = torch.softmax(logit, dim=1) + elif self.norm_strategy == "sigmoid": + logit = torch.sigmoid(logit) + logit = logit / logit.sum(dim=1, keepdim=True) + + output = torch.einsum("ikmn,k->imn", [logit, bins]).unsqueeze(dim=1) + + else: + if self.scale_up: + output = self.sigmoid(self.conv_depth(feat)) * self.max_depth + else: + output = self.relu(self.conv_depth(feat)) + self.min_depth + return output + + def losses(self, depth_pred, depth_gt): + """Compute depth loss.""" + loss = dict() + depth_pred = resize( + input=depth_pred, size=depth_gt.shape[2:], mode="bilinear", align_corners=self.align_corners, warning=False + ) + if not isinstance(self.loss_decode, nn.ModuleList): + losses_decode = [self.loss_decode] + else: + losses_decode = self.loss_decode + for loss_decode in losses_decode: + if loss_decode.loss_name not in loss: + loss[loss_decode.loss_name] = loss_decode(depth_pred, depth_gt) + else: + loss[loss_decode.loss_name] += loss_decode(depth_pred, depth_gt) + return loss + + def log_images(self, img_path, depth_pred, depth_gt, img_meta): + import numpy as np + + show_img = copy.deepcopy(img_path.detach().cpu().permute(1, 2, 0)) + show_img = show_img.numpy().astype(np.float32) + show_img = _imdenormalize( + show_img, + img_meta["img_norm_cfg"]["mean"], + img_meta["img_norm_cfg"]["std"], + img_meta["img_norm_cfg"]["to_rgb"], + ) + show_img = np.clip(show_img, 0, 255) + show_img = show_img.astype(np.uint8) + show_img = show_img[:, :, ::-1] + show_img = show_img.transpose(0, 2, 1) + show_img = show_img.transpose(1, 0, 2) + + depth_pred = depth_pred / torch.max(depth_pred) + depth_gt = depth_gt / torch.max(depth_gt) + + depth_pred_color = copy.deepcopy(depth_pred.detach().cpu()) + depth_gt_color = copy.deepcopy(depth_gt.detach().cpu()) + + return {"img_rgb": show_img, "img_depth_pred": depth_pred_color, "img_depth_gt": depth_gt_color} + + +class BNHead(DepthBaseDecodeHead): + """Just a batchnorm.""" + + def __init__(self, input_transform="resize_concat", in_index=(0, 1, 2, 3), upsample=1, **kwargs): + super().__init__(**kwargs) + self.input_transform = input_transform + self.in_index = in_index + self.upsample = upsample + # self.bn = nn.SyncBatchNorm(self.in_channels) + if self.classify: + self.conv_depth = nn.Conv2d(self.channels, self.n_bins, kernel_size=1, padding=0, stride=1) + else: + self.conv_depth = nn.Conv2d(self.channels, 1, kernel_size=1, padding=0, stride=1) + + def _transform_inputs(self, inputs): + """Transform inputs for decoder. + Args: + inputs (list[Tensor]): List of multi-level img features. + Returns: + Tensor: The transformed inputs + """ + + if "concat" in self.input_transform: + inputs = [inputs[i] for i in self.in_index] + if "resize" in self.input_transform: + inputs = [ + resize( + input=x, + size=[s * self.upsample for s in inputs[0].shape[2:]], + mode="bilinear", + align_corners=self.align_corners, + ) + for x in inputs + ] + inputs = torch.cat(inputs, dim=1) + elif self.input_transform == "multiple_select": + inputs = [inputs[i] for i in self.in_index] + else: + inputs = inputs[self.in_index] + + return inputs + + def _forward_feature(self, inputs, img_metas=None, **kwargs): + """Forward function for feature maps before classifying each pixel with + ``self.cls_seg`` fc. + Args: + inputs (list[Tensor]): List of multi-level img features. + Returns: + feats (Tensor): A tensor of shape (batch_size, self.channels, + H, W) which is feature map for last layer of decoder head. + """ + # accept lists (for cls token) + inputs = list(inputs) + for i, x in enumerate(inputs): + if len(x) == 2: + x, cls_token = x[0], x[1] + if len(x.shape) == 2: + x = x[:, :, None, None] + cls_token = cls_token[:, :, None, None].expand_as(x) + inputs[i] = torch.cat((x, cls_token), 1) + else: + x = x[0] + if len(x.shape) == 2: + x = x[:, :, None, None] + inputs[i] = x + x = self._transform_inputs(inputs) + # feats = self.bn(x) + return x + + def forward(self, inputs, img_metas=None, **kwargs): + """Forward function.""" + output = self._forward_feature(inputs, img_metas=img_metas, **kwargs) + output = self.depth_pred(output) + return output + + +class ConvModule(nn.Module): + """A conv block that bundles conv/norm/activation layers. + + This block simplifies the usage of convolution layers, which are commonly + used with a norm layer (e.g., BatchNorm) and activation layer (e.g., ReLU). + It is based upon three build methods: `build_conv_layer()`, + `build_norm_layer()` and `build_activation_layer()`. + + Besides, we add some additional features in this module. + 1. Automatically set `bias` of the conv layer. + 2. Spectral norm is supported. + 3. More padding modes are supported. Before PyTorch 1.5, nn.Conv2d only + supports zero and circular padding, and we add "reflect" padding mode. + + Args: + in_channels (int): Number of channels in the input feature map. + Same as that in ``nn._ConvNd``. + out_channels (int): Number of channels produced by the convolution. + Same as that in ``nn._ConvNd``. + kernel_size (int | tuple[int]): Size of the convolving kernel. + Same as that in ``nn._ConvNd``. + stride (int | tuple[int]): Stride of the convolution. + Same as that in ``nn._ConvNd``. + padding (int | tuple[int]): Zero-padding added to both sides of + the input. Same as that in ``nn._ConvNd``. + dilation (int | tuple[int]): Spacing between kernel elements. + Same as that in ``nn._ConvNd``. + groups (int): Number of blocked connections from input channels to + output channels. Same as that in ``nn._ConvNd``. + bias (bool | str): If specified as `auto`, it will be decided by the + norm_layer. Bias will be set as True if `norm_layer` is None, otherwise + False. Default: "auto". + conv_layer (nn.Module): Convolution layer. Default: None, + which means using conv2d. + norm_layer (nn.Module): Normalization layer. Default: None. + act_layer (nn.Module): Activation layer. Default: nn.ReLU. + inplace (bool): Whether to use inplace mode for activation. + Default: True. + with_spectral_norm (bool): Whether use spectral norm in conv module. + Default: False. + padding_mode (str): If the `padding_mode` has not been supported by + current `Conv2d` in PyTorch, we will use our own padding layer + instead. Currently, we support ['zeros', 'circular'] with official + implementation and ['reflect'] with our own implementation. + Default: 'zeros'. + order (tuple[str]): The order of conv/norm/activation layers. It is a + sequence of "conv", "norm" and "act". Common examples are + ("conv", "norm", "act") and ("act", "conv", "norm"). + Default: ('conv', 'norm', 'act'). + """ + + _abbr_ = "conv_block" + + def __init__( + self, + in_channels, + out_channels, + kernel_size, + stride=1, + padding=0, + dilation=1, + groups=1, + bias="auto", + conv_layer=nn.Conv2d, + norm_layer=None, + act_layer=nn.ReLU, + inplace=True, + with_spectral_norm=False, + padding_mode="zeros", + order=("conv", "norm", "act"), + ): + super(ConvModule, self).__init__() + official_padding_mode = ["zeros", "circular"] + self.conv_layer = conv_layer + self.norm_layer = norm_layer + self.act_layer = act_layer + self.inplace = inplace + self.with_spectral_norm = with_spectral_norm + self.with_explicit_padding = padding_mode not in official_padding_mode + self.order = order + assert isinstance(self.order, tuple) and len(self.order) == 3 + assert set(order) == set(["conv", "norm", "act"]) + + self.with_norm = norm_layer is not None + self.with_activation = act_layer is not None + # if the conv layer is before a norm layer, bias is unnecessary. + if bias == "auto": + bias = not self.with_norm + self.with_bias = bias + + if self.with_explicit_padding: + if padding_mode == "zeros": + padding_layer = nn.ZeroPad2d + else: + raise AssertionError(f"Unsupported padding mode: {padding_mode}") + self.pad = padding_layer(padding) + + # reset padding to 0 for conv module + conv_padding = 0 if self.with_explicit_padding else padding + # build convolution layer + self.conv = self.conv_layer( + in_channels, + out_channels, + kernel_size, + stride=stride, + padding=conv_padding, + dilation=dilation, + groups=groups, + bias=bias, + ) + # export the attributes of self.conv to a higher level for convenience + self.in_channels = self.conv.in_channels + self.out_channels = self.conv.out_channels + self.kernel_size = self.conv.kernel_size + self.stride = self.conv.stride + self.padding = padding + self.dilation = self.conv.dilation + self.transposed = self.conv.transposed + self.output_padding = self.conv.output_padding + self.groups = self.conv.groups + + if self.with_spectral_norm: + self.conv = nn.utils.spectral_norm(self.conv) + + # build normalization layers + if self.with_norm: + # norm layer is after conv layer + if order.index("norm") > order.index("conv"): + norm_channels = out_channels + else: + norm_channels = in_channels + norm = partial(norm_layer, num_features=norm_channels) + self.add_module("norm", norm) + if self.with_bias: + from torch.nnModules.batchnorm import _BatchNorm + from torch.nnModules.instancenorm import _InstanceNorm + + if isinstance(norm, (_BatchNorm, _InstanceNorm)): + warnings.warn("Unnecessary conv bias before batch/instance norm") + else: + self.norm_name = None + + # build activation layer + if self.with_activation: + # nn.Tanh has no 'inplace' argument + # (nn.Tanh, nn.PReLU, nn.Sigmoid, nn.HSigmoid, nn.Swish, nn.GELU) + if not isinstance(act_layer, (nn.Tanh, nn.PReLU, nn.Sigmoid, nn.GELU)): + act_layer = partial(act_layer, inplace=inplace) + self.activate = act_layer() + + # Use msra init by default + self.init_weights() + + @property + def norm(self): + if self.norm_name: + return getattr(self, self.norm_name) + else: + return None + + def init_weights(self): + # 1. It is mainly for customized conv layers with their own + # initialization manners by calling their own ``init_weights()``, + # and we do not want ConvModule to override the initialization. + # 2. For customized conv layers without their own initialization + # manners (that is, they don't have their own ``init_weights()``) + # and PyTorch's conv layers, they will be initialized by + # this method with default ``kaiming_init``. + # Note: For PyTorch's conv layers, they will be overwritten by our + # initialization implementation using default ``kaiming_init``. + if not hasattr(self.conv, "init_weights"): + if self.with_activation and isinstance(self.act_layer, nn.LeakyReLU): + nonlinearity = "leaky_relu" + a = 0.01 # XXX: default negative_slope + else: + nonlinearity = "relu" + a = 0 + if hasattr(self.conv, "weight") and self.conv.weight is not None: + nn.init.kaiming_normal_(self.conv.weight, a=a, mode="fan_out", nonlinearity=nonlinearity) + if hasattr(self.conv, "bias") and self.conv.bias is not None: + nn.init.constant_(self.conv.bias, 0) + if self.with_norm: + if hasattr(self.norm, "weight") and self.norm.weight is not None: + nn.init.constant_(self.norm.weight, 1) + if hasattr(self.norm, "bias") and self.norm.bias is not None: + nn.init.constant_(self.norm.bias, 0) + + def forward(self, x, activate=True, norm=True): + for layer in self.order: + if layer == "conv": + if self.with_explicit_padding: + x = self.pad(x) + x = self.conv(x) + elif layer == "norm" and norm and self.with_norm: + x = self.norm(x) + elif layer == "act" and activate and self.with_activation: + x = self.activate(x) + return x + + +class Interpolate(nn.Module): + def __init__(self, scale_factor, mode, align_corners=False): + super(Interpolate, self).__init__() + self.interp = nn.functional.interpolate + self.scale_factor = scale_factor + self.mode = mode + self.align_corners = align_corners + + def forward(self, x): + x = self.interp(x, scale_factor=self.scale_factor, mode=self.mode, align_corners=self.align_corners) + return x + + +class HeadDepth(nn.Module): + def __init__(self, features): + super(HeadDepth, self).__init__() + self.head = nn.Sequential( + nn.Conv2d(features, features // 2, kernel_size=3, stride=1, padding=1), + Interpolate(scale_factor=2, mode="bilinear", align_corners=True), + nn.Conv2d(features // 2, 32, kernel_size=3, stride=1, padding=1), + nn.ReLU(), + nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0), + ) + + def forward(self, x): + x = self.head(x) + return x + + +class ReassembleBlocks(nn.Module): + """ViTPostProcessBlock, process cls_token in ViT backbone output and + rearrange the feature vector to feature map. + Args: + in_channels (int): ViT feature channels. Default: 768. + out_channels (List): output channels of each stage. + Default: [96, 192, 384, 768]. + readout_type (str): Type of readout operation. Default: 'ignore'. + patch_size (int): The patch size. Default: 16. + """ + + def __init__(self, in_channels=768, out_channels=[96, 192, 384, 768], readout_type="ignore", patch_size=16): + super(ReassembleBlocks, self).__init__() + + assert readout_type in ["ignore", "add", "project"] + self.readout_type = readout_type + self.patch_size = patch_size + + self.projects = nn.ModuleList( + [ + ConvModule( + in_channels=in_channels, + out_channels=out_channel, + kernel_size=1, + act_layer=None, + ) + for out_channel in out_channels + ] + ) + + self.resize_layers = nn.ModuleList( + [ + nn.ConvTranspose2d( + in_channels=out_channels[0], out_channels=out_channels[0], kernel_size=4, stride=4, padding=0 + ), + nn.ConvTranspose2d( + in_channels=out_channels[1], out_channels=out_channels[1], kernel_size=2, stride=2, padding=0 + ), + nn.Identity(), + nn.Conv2d( + in_channels=out_channels[3], out_channels=out_channels[3], kernel_size=3, stride=2, padding=1 + ), + ] + ) + if self.readout_type == "project": + self.readout_projects = nn.ModuleList() + for _ in range(len(self.projects)): + self.readout_projects.append(nn.Sequential(nn.Linear(2 * in_channels, in_channels), nn.GELU())) + + def forward(self, inputs): + assert isinstance(inputs, list) + out = [] + for i, x in enumerate(inputs): + assert len(x) == 2 + x, cls_token = x[0], x[1] + feature_shape = x.shape + if self.readout_type == "project": + x = x.flatten(2).permute((0, 2, 1)) + readout = cls_token.unsqueeze(1).expand_as(x) + x = self.readout_projects[i](torch.cat((x, readout), -1)) + x = x.permute(0, 2, 1).reshape(feature_shape) + elif self.readout_type == "add": + x = x.flatten(2) + cls_token.unsqueeze(-1) + x = x.reshape(feature_shape) + else: + pass + x = self.projects[i](x) + x = self.resize_layers[i](x) + out.append(x) + return out + + +class PreActResidualConvUnit(nn.Module): + """ResidualConvUnit, pre-activate residual unit. + Args: + in_channels (int): number of channels in the input feature map. + act_layer (nn.Module): activation layer. + norm_layer (nn.Module): norm layer. + stride (int): stride of the first block. Default: 1 + dilation (int): dilation rate for convs layers. Default: 1. + """ + + def __init__(self, in_channels, act_layer, norm_layer, stride=1, dilation=1): + super(PreActResidualConvUnit, self).__init__() + + self.conv1 = ConvModule( + in_channels, + in_channels, + 3, + stride=stride, + padding=dilation, + dilation=dilation, + norm_layer=norm_layer, + act_layer=act_layer, + bias=False, + order=("act", "conv", "norm"), + ) + + self.conv2 = ConvModule( + in_channels, + in_channels, + 3, + padding=1, + norm_layer=norm_layer, + act_layer=act_layer, + bias=False, + order=("act", "conv", "norm"), + ) + + def forward(self, inputs): + inputs_ = inputs.clone() + x = self.conv1(inputs) + x = self.conv2(x) + return x + inputs_ + + +class FeatureFusionBlock(nn.Module): + """FeatureFusionBlock, merge feature map from different stages. + Args: + in_channels (int): Input channels. + act_layer (nn.Module): activation layer for ResidualConvUnit. + norm_layer (nn.Module): normalization layer. + expand (bool): Whether expand the channels in post process block. + Default: False. + align_corners (bool): align_corner setting for bilinear upsample. + Default: True. + """ + + def __init__(self, in_channels, act_layer, norm_layer, expand=False, align_corners=True): + super(FeatureFusionBlock, self).__init__() + + self.in_channels = in_channels + self.expand = expand + self.align_corners = align_corners + + self.out_channels = in_channels + if self.expand: + self.out_channels = in_channels // 2 + + self.project = ConvModule(self.in_channels, self.out_channels, kernel_size=1, act_layer=None, bias=True) + + self.res_conv_unit1 = PreActResidualConvUnit( + in_channels=self.in_channels, act_layer=act_layer, norm_layer=norm_layer + ) + self.res_conv_unit2 = PreActResidualConvUnit( + in_channels=self.in_channels, act_layer=act_layer, norm_layer=norm_layer + ) + + def forward(self, *inputs): + x = inputs[0] + if len(inputs) == 2: + if x.shape != inputs[1].shape: + res = resize(inputs[1], size=(x.shape[2], x.shape[3]), mode="bilinear", align_corners=False) + else: + res = inputs[1] + x = x + self.res_conv_unit1(res) + x = self.res_conv_unit2(x) + x = resize(x, scale_factor=2, mode="bilinear", align_corners=self.align_corners) + x = self.project(x) + return x + + +class DPTHead(DepthBaseDecodeHead): + """Vision Transformers for Dense Prediction. + This head is implemented of `DPT `_. + Args: + embed_dims (int): The embed dimension of the ViT backbone. + Default: 768. + post_process_channels (List): Out channels of post process conv + layers. Default: [96, 192, 384, 768]. + readout_type (str): Type of readout operation. Default: 'ignore'. + patch_size (int): The patch size. Default: 16. + expand_channels (bool): Whether expand the channels in post process + block. Default: False. + """ + + def __init__( + self, + embed_dims=768, + post_process_channels=[96, 192, 384, 768], + readout_type="ignore", + patch_size=16, + expand_channels=False, + **kwargs, + ): + super(DPTHead, self).__init__(**kwargs) + + self.in_channels = self.in_channels + self.expand_channels = expand_channels + self.reassemble_blocks = ReassembleBlocks(embed_dims, post_process_channels, readout_type, patch_size) + + self.post_process_channels = [ + channel * math.pow(2, i) if expand_channels else channel for i, channel in enumerate(post_process_channels) + ] + self.convs = nn.ModuleList() + for channel in self.post_process_channels: + self.convs.append(ConvModule(channel, self.channels, kernel_size=3, padding=1, act_layer=None, bias=False)) + self.fusion_blocks = nn.ModuleList() + for _ in range(len(self.convs)): + self.fusion_blocks.append(FeatureFusionBlock(self.channels, self.act_layer, self.norm_layer)) + self.fusion_blocks[0].res_conv_unit1 = None + self.project = ConvModule(self.channels, self.channels, kernel_size=3, padding=1, norm_layer=self.norm_layer) + self.num_fusion_blocks = len(self.fusion_blocks) + self.num_reassemble_blocks = len(self.reassemble_blocks.resize_layers) + self.num_post_process_channels = len(self.post_process_channels) + assert self.num_fusion_blocks == self.num_reassemble_blocks + assert self.num_reassemble_blocks == self.num_post_process_channels + self.conv_depth = HeadDepth(self.channels) + + def forward(self, inputs, img_metas): + assert len(inputs) == self.num_reassemble_blocks + x = [inp for inp in inputs] + x = self.reassemble_blocks(x) + x = [self.convs[i](feature) for i, feature in enumerate(x)] + out = self.fusion_blocks[0](x[-1]) + for i in range(1, len(self.fusion_blocks)): + out = self.fusion_blocks[i](out, x[-(i + 1)]) + out = self.project(out) + out = self.depth_pred(out) + return out diff --git a/dinov2/dinov2/hub/depth/encoder_decoder.py b/dinov2/dinov2/hub/depth/encoder_decoder.py new file mode 100644 index 0000000000000000000000000000000000000000..eb29ced67957a336e763b0e7c90c0eeaea36fea8 --- /dev/null +++ b/dinov2/dinov2/hub/depth/encoder_decoder.py @@ -0,0 +1,351 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +from collections import OrderedDict + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from .ops import resize + + +def add_prefix(inputs, prefix): + """Add prefix for dict. + + Args: + inputs (dict): The input dict with str keys. + prefix (str): The prefix to add. + + Returns: + + dict: The dict with keys updated with ``prefix``. + """ + + outputs = dict() + for name, value in inputs.items(): + outputs[f"{prefix}.{name}"] = value + + return outputs + + +class DepthEncoderDecoder(nn.Module): + """Encoder Decoder depther. + + EncoderDecoder typically consists of backbone and decode_head. + """ + + def __init__(self, backbone, decode_head): + super(DepthEncoderDecoder, self).__init__() + + self.backbone = backbone + self.decode_head = decode_head + self.align_corners = self.decode_head.align_corners + + def extract_feat(self, img): + """Extract features from images.""" + return self.backbone(img) + + def encode_decode(self, img, img_metas, rescale=True, size=None): + """Encode images with backbone and decode into a depth estimation + map of the same size as input.""" + x = self.extract_feat(img) + out = self._decode_head_forward_test(x, img_metas) + # crop the pred depth to the certain range. + out = torch.clamp(out, min=self.decode_head.min_depth, max=self.decode_head.max_depth) + if rescale: + if size is None: + if img_metas is not None: + size = img_metas[0]["ori_shape"][:2] + else: + size = img.shape[2:] + out = resize(input=out, size=size, mode="bilinear", align_corners=self.align_corners) + return out + + def _decode_head_forward_train(self, img, x, img_metas, depth_gt, **kwargs): + """Run forward function and calculate loss for decode head in + training.""" + losses = dict() + loss_decode = self.decode_head.forward_train(img, x, img_metas, depth_gt, **kwargs) + losses.update(add_prefix(loss_decode, "decode")) + return losses + + def _decode_head_forward_test(self, x, img_metas): + """Run forward function and calculate loss for decode head in + inference.""" + depth_pred = self.decode_head.forward_test(x, img_metas) + return depth_pred + + def forward_dummy(self, img): + """Dummy forward function.""" + depth = self.encode_decode(img, None) + + return depth + + def forward_train(self, img, img_metas, depth_gt, **kwargs): + """Forward function for training. + + Args: + img (Tensor): Input images. + img_metas (list[dict]): List of image info dict where each dict + has: 'img_shape', 'scale_factor', 'flip', and may also contain + 'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'. + For details on the values of these keys see + `depth/datasets/pipelines/formatting.py:Collect`. + depth_gt (Tensor): Depth gt + used if the architecture supports depth estimation task. + + Returns: + dict[str, Tensor]: a dictionary of loss components + """ + + x = self.extract_feat(img) + + losses = dict() + + # the last of x saves the info from neck + loss_decode = self._decode_head_forward_train(img, x, img_metas, depth_gt, **kwargs) + + losses.update(loss_decode) + + return losses + + def whole_inference(self, img, img_meta, rescale, size=None): + """Inference with full image.""" + return self.encode_decode(img, img_meta, rescale, size=size) + + def slide_inference(self, img, img_meta, rescale, stride, crop_size): + """Inference by sliding-window with overlap. + + If h_crop > h_img or w_crop > w_img, the small patch will be used to + decode without padding. + """ + + h_stride, w_stride = stride + h_crop, w_crop = crop_size + batch_size, _, h_img, w_img = img.size() + h_grids = max(h_img - h_crop + h_stride - 1, 0) // h_stride + 1 + w_grids = max(w_img - w_crop + w_stride - 1, 0) // w_stride + 1 + preds = img.new_zeros((batch_size, 1, h_img, w_img)) + count_mat = img.new_zeros((batch_size, 1, h_img, w_img)) + for h_idx in range(h_grids): + for w_idx in range(w_grids): + y1 = h_idx * h_stride + x1 = w_idx * w_stride + y2 = min(y1 + h_crop, h_img) + x2 = min(x1 + w_crop, w_img) + y1 = max(y2 - h_crop, 0) + x1 = max(x2 - w_crop, 0) + crop_img = img[:, :, y1:y2, x1:x2] + depth_pred = self.encode_decode(crop_img, img_meta, rescale) + preds += F.pad(depth_pred, (int(x1), int(preds.shape[3] - x2), int(y1), int(preds.shape[2] - y2))) + + count_mat[:, :, y1:y2, x1:x2] += 1 + assert (count_mat == 0).sum() == 0 + if torch.onnx.is_in_onnx_export(): + # cast count_mat to constant while exporting to ONNX + count_mat = torch.from_numpy(count_mat.cpu().detach().numpy()).to(device=img.device) + preds = preds / count_mat + return preds + + def inference(self, img, img_meta, rescale, size=None, mode="whole"): + """Inference with slide/whole style. + + Args: + img (Tensor): The input image of shape (N, 3, H, W). + img_meta (dict): Image info dict where each dict has: 'img_shape', + 'scale_factor', 'flip', and may also contain + 'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'. + For details on the values of these keys see + `depth/datasets/pipelines/formatting.py:Collect`. + rescale (bool): Whether rescale back to original shape. + + Returns: + Tensor: The output depth map. + """ + + assert mode in ["slide", "whole"] + ori_shape = img_meta[0]["ori_shape"] + assert all(_["ori_shape"] == ori_shape for _ in img_meta) + if mode == "slide": + depth_pred = self.slide_inference(img, img_meta, rescale) + else: + depth_pred = self.whole_inference(img, img_meta, rescale, size=size) + output = depth_pred + flip = img_meta[0]["flip"] + if flip: + flip_direction = img_meta[0]["flip_direction"] + assert flip_direction in ["horizontal", "vertical"] + if flip_direction == "horizontal": + output = output.flip(dims=(3,)) + elif flip_direction == "vertical": + output = output.flip(dims=(2,)) + + return output + + def simple_test(self, img, img_meta, rescale=True): + """Simple test with single image.""" + depth_pred = self.inference(img, img_meta, rescale) + if torch.onnx.is_in_onnx_export(): + # our inference backend only support 4D output + depth_pred = depth_pred.unsqueeze(0) + return depth_pred + depth_pred = depth_pred.cpu().numpy() + # unravel batch dim + depth_pred = list(depth_pred) + return depth_pred + + def aug_test(self, imgs, img_metas, rescale=True): + """Test with augmentations. + + Only rescale=True is supported. + """ + # aug_test rescale all imgs back to ori_shape for now + assert rescale + # to save memory, we get augmented depth logit inplace + depth_pred = self.inference(imgs[0], img_metas[0], rescale) + for i in range(1, len(imgs)): + cur_depth_pred = self.inference(imgs[i], img_metas[i], rescale, size=depth_pred.shape[-2:]) + depth_pred += cur_depth_pred + depth_pred /= len(imgs) + depth_pred = depth_pred.cpu().numpy() + # unravel batch dim + depth_pred = list(depth_pred) + return depth_pred + + def forward_test(self, imgs, img_metas, **kwargs): + """ + Args: + imgs (List[Tensor]): the outer list indicates test-time + augmentations and inner Tensor should have a shape NxCxHxW, + which contains all images in the batch. + img_metas (List[List[dict]]): the outer list indicates test-time + augs (multiscale, flip, etc.) and the inner list indicates + images in a batch. + """ + for var, name in [(imgs, "imgs"), (img_metas, "img_metas")]: + if not isinstance(var, list): + raise TypeError(f"{name} must be a list, but got " f"{type(var)}") + num_augs = len(imgs) + if num_augs != len(img_metas): + raise ValueError(f"num of augmentations ({len(imgs)}) != " f"num of image meta ({len(img_metas)})") + # all images in the same aug batch all of the same ori_shape and pad + # shape + for img_meta in img_metas: + ori_shapes = [_["ori_shape"] for _ in img_meta] + assert all(shape == ori_shapes[0] for shape in ori_shapes) + img_shapes = [_["img_shape"] for _ in img_meta] + assert all(shape == img_shapes[0] for shape in img_shapes) + pad_shapes = [_["pad_shape"] for _ in img_meta] + assert all(shape == pad_shapes[0] for shape in pad_shapes) + + if num_augs == 1: + return self.simple_test(imgs[0], img_metas[0], **kwargs) + else: + return self.aug_test(imgs, img_metas, **kwargs) + + def forward(self, img, img_metas, return_loss=True, **kwargs): + """Calls either :func:`forward_train` or :func:`forward_test` depending + on whether ``return_loss`` is ``True``. + + Note this setting will change the expected inputs. When + ``return_loss=True``, img and img_meta are single-nested (i.e. Tensor + and List[dict]), and when ``resturn_loss=False``, img and img_meta + should be double nested (i.e. List[Tensor], List[List[dict]]), with + the outer list indicating test time augmentations. + """ + if return_loss: + return self.forward_train(img, img_metas, **kwargs) + else: + return self.forward_test(img, img_metas, **kwargs) + + def train_step(self, data_batch, optimizer, **kwargs): + """The iteration step during training. + + This method defines an iteration step during training, except for the + back propagation and optimizer updating, which are done in an optimizer + hook. Note that in some complicated cases or models, the whole process + including back propagation and optimizer updating is also defined in + this method, such as GAN. + + Args: + data (dict): The output of dataloader. + optimizer (:obj:`torch.optim.Optimizer` | dict): The optimizer of + runner is passed to ``train_step()``. This argument is unused + and reserved. + + Returns: + dict: It should contain at least 3 keys: ``loss``, ``log_vars``, + ``num_samples``. + ``loss`` is a tensor for back propagation, which can be a + weighted sum of multiple losses. + ``log_vars`` contains all the variables to be sent to the + logger. + ``num_samples`` indicates the batch size (when the model is + DDP, it means the batch size on each GPU), which is used for + averaging the logs. + """ + losses = self(**data_batch) + + # split losses and images + real_losses = {} + log_imgs = {} + for k, v in losses.items(): + if "img" in k: + log_imgs[k] = v + else: + real_losses[k] = v + + loss, log_vars = self._parse_losses(real_losses) + + outputs = dict(loss=loss, log_vars=log_vars, num_samples=len(data_batch["img_metas"]), log_imgs=log_imgs) + + return outputs + + def val_step(self, data_batch, **kwargs): + """The iteration step during validation. + + This method shares the same signature as :func:`train_step`, but used + during val epochs. Note that the evaluation after training epochs is + not implemented with this method, but an evaluation hook. + """ + output = self(**data_batch, **kwargs) + return output + + @staticmethod + def _parse_losses(losses): + import torch.distributed as dist + + """Parse the raw outputs (losses) of the network. + + Args: + losses (dict): Raw output of the network, which usually contain + losses and other necessary information. + + Returns: + tuple[Tensor, dict]: (loss, log_vars), loss is the loss tensor + which may be a weighted sum of all losses, log_vars contains + all the variables to be sent to the logger. + """ + log_vars = OrderedDict() + for loss_name, loss_value in losses.items(): + if isinstance(loss_value, torch.Tensor): + log_vars[loss_name] = loss_value.mean() + elif isinstance(loss_value, list): + log_vars[loss_name] = sum(_loss.mean() for _loss in loss_value) + else: + raise TypeError(f"{loss_name} is not a tensor or list of tensors") + + loss = sum(_value for _key, _value in log_vars.items() if "loss" in _key) + + log_vars["loss"] = loss + for loss_name, loss_value in log_vars.items(): + # reduce loss when distributed training + if dist.is_available() and dist.is_initialized(): + loss_value = loss_value.data.clone() + dist.all_reduce(loss_value.div_(dist.get_world_size())) + log_vars[loss_name] = loss_value.item() + + return loss, log_vars diff --git a/dinov2/dinov2/hub/depth/ops.py b/dinov2/dinov2/hub/depth/ops.py new file mode 100644 index 0000000000000000000000000000000000000000..15880ee0cb7652d4b41c489b927bf6a156b40e5e --- /dev/null +++ b/dinov2/dinov2/hub/depth/ops.py @@ -0,0 +1,28 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +import warnings + +import torch.nn.functional as F + + +def resize(input, size=None, scale_factor=None, mode="nearest", align_corners=None, warning=False): + if warning: + if size is not None and align_corners: + input_h, input_w = tuple(int(x) for x in input.shape[2:]) + output_h, output_w = tuple(int(x) for x in size) + if output_h > input_h or output_w > output_h: + if ( + (output_h > 1 and output_w > 1 and input_h > 1 and input_w > 1) + and (output_h - 1) % (input_h - 1) + and (output_w - 1) % (input_w - 1) + ): + warnings.warn( + f"When align_corners={align_corners}, " + "the output would more aligned if " + f"input size {(input_h, input_w)} is `x+1` and " + f"out size {(output_h, output_w)} is `nx+1`" + ) + return F.interpolate(input, size, scale_factor, mode, align_corners) diff --git a/dinov2/dinov2/hub/depthers.py b/dinov2/dinov2/hub/depthers.py new file mode 100644 index 0000000000000000000000000000000000000000..f88b7e9a41056594e3b3e66107feee98bffab820 --- /dev/null +++ b/dinov2/dinov2/hub/depthers.py @@ -0,0 +1,246 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +from enum import Enum +from functools import partial +from typing import Optional, Tuple, Union + +import torch + +from .backbones import _make_dinov2_model +from .depth import BNHead, DepthEncoderDecoder, DPTHead +from .utils import _DINOV2_BASE_URL, _make_dinov2_model_name, CenterPadding + + +class Weights(Enum): + NYU = "NYU" + KITTI = "KITTI" + + +def _get_depth_range(pretrained: bool, weights: Weights = Weights.NYU) -> Tuple[float, float]: + if not pretrained: # Default + return (0.001, 10.0) + + # Pretrained, set according to the training dataset for the provided weights + if weights == Weights.KITTI: + return (0.001, 80.0) + + if weights == Weights.NYU: + return (0.001, 10.0) + + return (0.001, 10.0) + + +def _make_dinov2_linear_depth_head( + *, + embed_dim: int, + layers: int, + min_depth: float, + max_depth: float, + **kwargs, +): + if layers not in (1, 4): + raise AssertionError(f"Unsupported number of layers: {layers}") + + if layers == 1: + in_index = [0] + else: + assert layers == 4 + in_index = [0, 1, 2, 3] + + return BNHead( + classify=True, + n_bins=256, + bins_strategy="UD", + norm_strategy="linear", + upsample=4, + in_channels=[embed_dim] * len(in_index), + in_index=in_index, + input_transform="resize_concat", + channels=embed_dim * len(in_index) * 2, + align_corners=False, + min_depth=0.001, + max_depth=80, + loss_decode=(), + ) + + +def _make_dinov2_linear_depther( + *, + arch_name: str = "vit_large", + layers: int = 4, + pretrained: bool = True, + weights: Union[Weights, str] = Weights.NYU, + depth_range: Optional[Tuple[float, float]] = None, + **kwargs, +): + if layers not in (1, 4): + raise AssertionError(f"Unsupported number of layers: {layers}") + if isinstance(weights, str): + try: + weights = Weights[weights] + except KeyError: + raise AssertionError(f"Unsupported weights: {weights}") + + if depth_range is None: + depth_range = _get_depth_range(pretrained, weights) + min_depth, max_depth = depth_range + + backbone = _make_dinov2_model(arch_name=arch_name, pretrained=pretrained, **kwargs) + + embed_dim = backbone.embed_dim + patch_size = backbone.patch_size + model_name = _make_dinov2_model_name(arch_name, patch_size) + linear_depth_head = _make_dinov2_linear_depth_head( + embed_dim=embed_dim, + layers=layers, + min_depth=min_depth, + max_depth=max_depth, + ) + + layer_count = { + "vit_small": 12, + "vit_base": 12, + "vit_large": 24, + "vit_giant2": 40, + }[arch_name] + + if layers == 4: + out_index = { + "vit_small": [2, 5, 8, 11], + "vit_base": [2, 5, 8, 11], + "vit_large": [4, 11, 17, 23], + "vit_giant2": [9, 19, 29, 39], + }[arch_name] + else: + assert layers == 1 + out_index = [layer_count - 1] + + model = DepthEncoderDecoder(backbone=backbone, decode_head=linear_depth_head) + model.backbone.forward = partial( + backbone.get_intermediate_layers, + n=out_index, + reshape=True, + return_class_token=True, + norm=False, + ) + model.backbone.register_forward_pre_hook(lambda _, x: CenterPadding(patch_size)(x[0])) + + if pretrained: + layers_str = str(layers) if layers == 4 else "" + weights_str = weights.value.lower() + url = _DINOV2_BASE_URL + f"/{model_name}/{model_name}_{weights_str}_linear{layers_str}_head.pth" + checkpoint = torch.hub.load_state_dict_from_url(url, map_location="cpu") + if "state_dict" in checkpoint: + state_dict = checkpoint["state_dict"] + model.load_state_dict(state_dict, strict=False) + + return model + + +def dinov2_vits14_ld(*, layers: int = 4, pretrained: bool = True, weights: Union[Weights, str] = Weights.NYU, **kwargs): + return _make_dinov2_linear_depther( + arch_name="vit_small", layers=layers, pretrained=pretrained, weights=weights, **kwargs + ) + + +def dinov2_vitb14_ld(*, layers: int = 4, pretrained: bool = True, weights: Union[Weights, str] = Weights.NYU, **kwargs): + return _make_dinov2_linear_depther( + arch_name="vit_base", layers=layers, pretrained=pretrained, weights=weights, **kwargs + ) + + +def dinov2_vitl14_ld(*, layers: int = 4, pretrained: bool = True, weights: Union[Weights, str] = Weights.NYU, **kwargs): + return _make_dinov2_linear_depther( + arch_name="vit_large", layers=layers, pretrained=pretrained, weights=weights, **kwargs + ) + + +def dinov2_vitg14_ld(*, layers: int = 4, pretrained: bool = True, weights: Union[Weights, str] = Weights.NYU, **kwargs): + return _make_dinov2_linear_depther( + arch_name="vit_giant2", layers=layers, ffn_layer="swiglufused", pretrained=pretrained, weights=weights, **kwargs + ) + + +def _make_dinov2_dpt_depth_head(*, embed_dim: int, min_depth: float, max_depth: float): + return DPTHead( + in_channels=[embed_dim] * 4, + channels=256, + embed_dims=embed_dim, + post_process_channels=[embed_dim // 2 ** (3 - i) for i in range(4)], + readout_type="project", + min_depth=min_depth, + max_depth=max_depth, + loss_decode=(), + ) + + +def _make_dinov2_dpt_depther( + *, + arch_name: str = "vit_large", + pretrained: bool = True, + weights: Union[Weights, str] = Weights.NYU, + depth_range: Optional[Tuple[float, float]] = None, + **kwargs, +): + if isinstance(weights, str): + try: + weights = Weights[weights] + except KeyError: + raise AssertionError(f"Unsupported weights: {weights}") + + if depth_range is None: + depth_range = _get_depth_range(pretrained, weights) + min_depth, max_depth = depth_range + + backbone = _make_dinov2_model(arch_name=arch_name, pretrained=pretrained, **kwargs) + + model_name = _make_dinov2_model_name(arch_name, backbone.patch_size) + dpt_depth_head = _make_dinov2_dpt_depth_head(embed_dim=backbone.embed_dim, min_depth=min_depth, max_depth=max_depth) + + out_index = { + "vit_small": [2, 5, 8, 11], + "vit_base": [2, 5, 8, 11], + "vit_large": [4, 11, 17, 23], + "vit_giant2": [9, 19, 29, 39], + }[arch_name] + + model = DepthEncoderDecoder(backbone=backbone, decode_head=dpt_depth_head) + model.backbone.forward = partial( + backbone.get_intermediate_layers, + n=out_index, + reshape=True, + return_class_token=True, + norm=False, + ) + model.backbone.register_forward_pre_hook(lambda _, x: CenterPadding(backbone.patch_size)(x[0])) + + if pretrained: + weights_str = weights.value.lower() + url = _DINOV2_BASE_URL + f"/{model_name}/{model_name}_{weights_str}_dpt_head.pth" + checkpoint = torch.hub.load_state_dict_from_url(url, map_location="cpu") + if "state_dict" in checkpoint: + state_dict = checkpoint["state_dict"] + model.load_state_dict(state_dict, strict=False) + + return model + + +def dinov2_vits14_dd(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.NYU, **kwargs): + return _make_dinov2_dpt_depther(arch_name="vit_small", pretrained=pretrained, weights=weights, **kwargs) + + +def dinov2_vitb14_dd(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.NYU, **kwargs): + return _make_dinov2_dpt_depther(arch_name="vit_base", pretrained=pretrained, weights=weights, **kwargs) + + +def dinov2_vitl14_dd(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.NYU, **kwargs): + return _make_dinov2_dpt_depther(arch_name="vit_large", pretrained=pretrained, weights=weights, **kwargs) + + +def dinov2_vitg14_dd(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.NYU, **kwargs): + return _make_dinov2_dpt_depther( + arch_name="vit_giant2", ffn_layer="swiglufused", pretrained=pretrained, weights=weights, **kwargs + ) diff --git a/dinov2/dinov2/hub/utils.py b/dinov2/dinov2/hub/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..9c6641404093652d5a2f19b4cf283d976ec39e64 --- /dev/null +++ b/dinov2/dinov2/hub/utils.py @@ -0,0 +1,39 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +import itertools +import math + +import torch +import torch.nn as nn +import torch.nn.functional as F + + +_DINOV2_BASE_URL = "https://dl.fbaipublicfiles.com/dinov2" + + +def _make_dinov2_model_name(arch_name: str, patch_size: int, num_register_tokens: int = 0) -> str: + compact_arch_name = arch_name.replace("_", "")[:4] + registers_suffix = f"_reg{num_register_tokens}" if num_register_tokens else "" + return f"dinov2_{compact_arch_name}{patch_size}{registers_suffix}" + + +class CenterPadding(nn.Module): + def __init__(self, multiple): + super().__init__() + self.multiple = multiple + + def _get_pad(self, size): + new_size = math.ceil(size / self.multiple) * self.multiple + pad_size = new_size - size + pad_size_left = pad_size // 2 + pad_size_right = pad_size - pad_size_left + return pad_size_left, pad_size_right + + @torch.inference_mode() + def forward(self, x): + pads = list(itertools.chain.from_iterable(self._get_pad(m) for m in x.shape[:1:-1])) + output = F.pad(x, pads) + return output diff --git a/dinov2/dinov2/layers/__init__.py b/dinov2/dinov2/layers/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..05a0b61868e43abb821ca05a813bab2b8b43629e --- /dev/null +++ b/dinov2/dinov2/layers/__init__.py @@ -0,0 +1,11 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +from .dino_head import DINOHead +from .mlp import Mlp +from .patch_embed import PatchEmbed +from .swiglu_ffn import SwiGLUFFN, SwiGLUFFNFused +from .block import NestedTensorBlock +from .attention import MemEffAttention diff --git a/dinov2/dinov2/layers/attention.py b/dinov2/dinov2/layers/attention.py new file mode 100644 index 0000000000000000000000000000000000000000..0fb76ef2816164729a58cceb18d0f000cfb18777 --- /dev/null +++ b/dinov2/dinov2/layers/attention.py @@ -0,0 +1,89 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +# References: +# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py +# https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py + +import logging +import os +import warnings + +from torch import Tensor +from torch import nn + + +logger = logging.getLogger("dinov2") + + +XFORMERS_ENABLED = os.environ.get("XFORMERS_DISABLED") is None +try: + if XFORMERS_ENABLED: + from xformers.ops import memory_efficient_attention, unbind + + XFORMERS_AVAILABLE = True + warnings.warn("xFormers is available (Attention)") + else: + warnings.warn("xFormers is disabled (Attention)") + raise ImportError +except ImportError: + XFORMERS_AVAILABLE = False + warnings.warn("xFormers is not available (Attention)") + + +class Attention(nn.Module): + def __init__( + self, + dim: int, + num_heads: int = 8, + qkv_bias: bool = False, + proj_bias: bool = True, + attn_drop: float = 0.0, + proj_drop: float = 0.0, + ) -> None: + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = head_dim**-0.5 + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim, bias=proj_bias) + self.proj_drop = nn.Dropout(proj_drop) + + def forward(self, x: Tensor) -> Tensor: + B, N, C = x.shape + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + + q, k, v = qkv[0] * self.scale, qkv[1], qkv[2] + attn = q @ k.transpose(-2, -1) + + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + +class MemEffAttention(Attention): + def forward(self, x: Tensor, attn_bias=None) -> Tensor: + if not XFORMERS_AVAILABLE: + if attn_bias is not None: + raise AssertionError("xFormers is required for using nested tensors") + return super().forward(x) + + B, N, C = x.shape + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads) + + q, k, v = unbind(qkv, 2) + + x = memory_efficient_attention(q, k, v, attn_bias=attn_bias) + x = x.reshape([B, N, C]) + + x = self.proj(x) + x = self.proj_drop(x) + return x diff --git a/dinov2/dinov2/layers/block.py b/dinov2/dinov2/layers/block.py new file mode 100644 index 0000000000000000000000000000000000000000..930787b262faac4f2264797496faff75ac56b7cc --- /dev/null +++ b/dinov2/dinov2/layers/block.py @@ -0,0 +1,260 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +# References: +# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py +# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py + +import logging +import os +from typing import Callable, List, Any, Tuple, Dict +import warnings + +import torch +from torch import nn, Tensor + +from .attention import Attention, MemEffAttention +from .drop_path import DropPath +from .layer_scale import LayerScale +from .mlp import Mlp + + +logger = logging.getLogger("dinov2") + + +XFORMERS_ENABLED = os.environ.get("XFORMERS_DISABLED") is None +try: + if XFORMERS_ENABLED: + from xformers.ops import fmha, scaled_index_add, index_select_cat + + XFORMERS_AVAILABLE = True + warnings.warn("xFormers is available (Block)") + else: + warnings.warn("xFormers is disabled (Block)") + raise ImportError +except ImportError: + XFORMERS_AVAILABLE = False + + warnings.warn("xFormers is not available (Block)") + + +class Block(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + mlp_ratio: float = 4.0, + qkv_bias: bool = False, + proj_bias: bool = True, + ffn_bias: bool = True, + drop: float = 0.0, + attn_drop: float = 0.0, + init_values=None, + drop_path: float = 0.0, + act_layer: Callable[..., nn.Module] = nn.GELU, + norm_layer: Callable[..., nn.Module] = nn.LayerNorm, + attn_class: Callable[..., nn.Module] = Attention, + ffn_layer: Callable[..., nn.Module] = Mlp, + ) -> None: + super().__init__() + # print(f"biases: qkv: {qkv_bias}, proj: {proj_bias}, ffn: {ffn_bias}") + self.norm1 = norm_layer(dim) + self.attn = attn_class( + dim, + num_heads=num_heads, + qkv_bias=qkv_bias, + proj_bias=proj_bias, + attn_drop=attn_drop, + proj_drop=drop, + ) + self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity() + self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() + + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = ffn_layer( + in_features=dim, + hidden_features=mlp_hidden_dim, + act_layer=act_layer, + drop=drop, + bias=ffn_bias, + ) + self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity() + self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() + + self.sample_drop_ratio = drop_path + + def forward(self, x: Tensor) -> Tensor: + def attn_residual_func(x: Tensor) -> Tensor: + return self.ls1(self.attn(self.norm1(x))) + + def ffn_residual_func(x: Tensor) -> Tensor: + return self.ls2(self.mlp(self.norm2(x))) + + if self.training and self.sample_drop_ratio > 0.1: + # the overhead is compensated only for a drop path rate larger than 0.1 + x = drop_add_residual_stochastic_depth( + x, + residual_func=attn_residual_func, + sample_drop_ratio=self.sample_drop_ratio, + ) + x = drop_add_residual_stochastic_depth( + x, + residual_func=ffn_residual_func, + sample_drop_ratio=self.sample_drop_ratio, + ) + elif self.training and self.sample_drop_ratio > 0.0: + x = x + self.drop_path1(attn_residual_func(x)) + x = x + self.drop_path1(ffn_residual_func(x)) # FIXME: drop_path2 + else: + x = x + attn_residual_func(x) + x = x + ffn_residual_func(x) + return x + + +def drop_add_residual_stochastic_depth( + x: Tensor, + residual_func: Callable[[Tensor], Tensor], + sample_drop_ratio: float = 0.0, +) -> Tensor: + # 1) extract subset using permutation + b, n, d = x.shape + sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1) + brange = (torch.randperm(b, device=x.device))[:sample_subset_size] + x_subset = x[brange] + + # 2) apply residual_func to get residual + residual = residual_func(x_subset) + + x_flat = x.flatten(1) + residual = residual.flatten(1) + + residual_scale_factor = b / sample_subset_size + + # 3) add the residual + x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor) + return x_plus_residual.view_as(x) + + +def get_branges_scales(x, sample_drop_ratio=0.0): + b, n, d = x.shape + sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1) + brange = (torch.randperm(b, device=x.device))[:sample_subset_size] + residual_scale_factor = b / sample_subset_size + return brange, residual_scale_factor + + +def add_residual(x, brange, residual, residual_scale_factor, scaling_vector=None): + if scaling_vector is None: + x_flat = x.flatten(1) + residual = residual.flatten(1) + x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor) + else: + x_plus_residual = scaled_index_add( + x, brange, residual.to(dtype=x.dtype), scaling=scaling_vector, alpha=residual_scale_factor + ) + return x_plus_residual + + +attn_bias_cache: Dict[Tuple, Any] = {} + + +def get_attn_bias_and_cat(x_list, branges=None): + """ + this will perform the index select, cat the tensors, and provide the attn_bias from cache + """ + batch_sizes = [b.shape[0] for b in branges] if branges is not None else [x.shape[0] for x in x_list] + all_shapes = tuple((b, x.shape[1]) for b, x in zip(batch_sizes, x_list)) + if all_shapes not in attn_bias_cache.keys(): + seqlens = [] + for b, x in zip(batch_sizes, x_list): + for _ in range(b): + seqlens.append(x.shape[1]) + attn_bias = fmha.BlockDiagonalMask.from_seqlens(seqlens) + attn_bias._batch_sizes = batch_sizes + attn_bias_cache[all_shapes] = attn_bias + + if branges is not None: + cat_tensors = index_select_cat([x.flatten(1) for x in x_list], branges).view(1, -1, x_list[0].shape[-1]) + else: + tensors_bs1 = tuple(x.reshape([1, -1, *x.shape[2:]]) for x in x_list) + cat_tensors = torch.cat(tensors_bs1, dim=1) + + return attn_bias_cache[all_shapes], cat_tensors + + +def drop_add_residual_stochastic_depth_list( + x_list: List[Tensor], + residual_func: Callable[[Tensor, Any], Tensor], + sample_drop_ratio: float = 0.0, + scaling_vector=None, +) -> Tensor: + # 1) generate random set of indices for dropping samples in the batch + branges_scales = [get_branges_scales(x, sample_drop_ratio=sample_drop_ratio) for x in x_list] + branges = [s[0] for s in branges_scales] + residual_scale_factors = [s[1] for s in branges_scales] + + # 2) get attention bias and index+concat the tensors + attn_bias, x_cat = get_attn_bias_and_cat(x_list, branges) + + # 3) apply residual_func to get residual, and split the result + residual_list = attn_bias.split(residual_func(x_cat, attn_bias=attn_bias)) # type: ignore + + outputs = [] + for x, brange, residual, residual_scale_factor in zip(x_list, branges, residual_list, residual_scale_factors): + outputs.append(add_residual(x, brange, residual, residual_scale_factor, scaling_vector).view_as(x)) + return outputs + + +class NestedTensorBlock(Block): + def forward_nested(self, x_list: List[Tensor]) -> List[Tensor]: + """ + x_list contains a list of tensors to nest together and run + """ + assert isinstance(self.attn, MemEffAttention) + + if self.training and self.sample_drop_ratio > 0.0: + + def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor: + return self.attn(self.norm1(x), attn_bias=attn_bias) + + def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor: + return self.mlp(self.norm2(x)) + + x_list = drop_add_residual_stochastic_depth_list( + x_list, + residual_func=attn_residual_func, + sample_drop_ratio=self.sample_drop_ratio, + scaling_vector=self.ls1.gamma if isinstance(self.ls1, LayerScale) else None, + ) + x_list = drop_add_residual_stochastic_depth_list( + x_list, + residual_func=ffn_residual_func, + sample_drop_ratio=self.sample_drop_ratio, + scaling_vector=self.ls2.gamma if isinstance(self.ls1, LayerScale) else None, + ) + return x_list + else: + + def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor: + return self.ls1(self.attn(self.norm1(x), attn_bias=attn_bias)) + + def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor: + return self.ls2(self.mlp(self.norm2(x))) + + attn_bias, x = get_attn_bias_and_cat(x_list) + x = x + attn_residual_func(x, attn_bias=attn_bias) + x = x + ffn_residual_func(x) + return attn_bias.split(x) + + def forward(self, x_or_x_list): + if isinstance(x_or_x_list, Tensor): + return super().forward(x_or_x_list) + elif isinstance(x_or_x_list, list): + if not XFORMERS_AVAILABLE: + raise AssertionError("xFormers is required for using nested tensors") + return self.forward_nested(x_or_x_list) + else: + raise AssertionError diff --git a/dinov2/dinov2/layers/dino_head.py b/dinov2/dinov2/layers/dino_head.py new file mode 100644 index 0000000000000000000000000000000000000000..0ace8ffd6297a1dd480b19db407b662a6ea0f565 --- /dev/null +++ b/dinov2/dinov2/layers/dino_head.py @@ -0,0 +1,58 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +import torch +import torch.nn as nn +from torch.nn.init import trunc_normal_ +from torch.nn.utils import weight_norm + + +class DINOHead(nn.Module): + def __init__( + self, + in_dim, + out_dim, + use_bn=False, + nlayers=3, + hidden_dim=2048, + bottleneck_dim=256, + mlp_bias=True, + ): + super().__init__() + nlayers = max(nlayers, 1) + self.mlp = _build_mlp(nlayers, in_dim, bottleneck_dim, hidden_dim=hidden_dim, use_bn=use_bn, bias=mlp_bias) + self.apply(self._init_weights) + self.last_layer = weight_norm(nn.Linear(bottleneck_dim, out_dim, bias=False)) + self.last_layer.weight_g.data.fill_(1) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=0.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + + def forward(self, x): + x = self.mlp(x) + eps = 1e-6 if x.dtype == torch.float16 else 1e-12 + x = nn.functional.normalize(x, dim=-1, p=2, eps=eps) + x = self.last_layer(x) + return x + + +def _build_mlp(nlayers, in_dim, bottleneck_dim, hidden_dim=None, use_bn=False, bias=True): + if nlayers == 1: + return nn.Linear(in_dim, bottleneck_dim, bias=bias) + else: + layers = [nn.Linear(in_dim, hidden_dim, bias=bias)] + if use_bn: + layers.append(nn.BatchNorm1d(hidden_dim)) + layers.append(nn.GELU()) + for _ in range(nlayers - 2): + layers.append(nn.Linear(hidden_dim, hidden_dim, bias=bias)) + if use_bn: + layers.append(nn.BatchNorm1d(hidden_dim)) + layers.append(nn.GELU()) + layers.append(nn.Linear(hidden_dim, bottleneck_dim, bias=bias)) + return nn.Sequential(*layers) diff --git a/dinov2/dinov2/layers/drop_path.py b/dinov2/dinov2/layers/drop_path.py new file mode 100644 index 0000000000000000000000000000000000000000..1d640e0b969b8dcba96260243473700b4e5b24b5 --- /dev/null +++ b/dinov2/dinov2/layers/drop_path.py @@ -0,0 +1,34 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +# References: +# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py +# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/drop.py + + +from torch import nn + + +def drop_path(x, drop_prob: float = 0.0, training: bool = False): + if drop_prob == 0.0 or not training: + return x + keep_prob = 1 - drop_prob + shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets + random_tensor = x.new_empty(shape).bernoulli_(keep_prob) + if keep_prob > 0.0: + random_tensor.div_(keep_prob) + output = x * random_tensor + return output + + +class DropPath(nn.Module): + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).""" + + def __init__(self, drop_prob=None): + super(DropPath, self).__init__() + self.drop_prob = drop_prob + + def forward(self, x): + return drop_path(x, self.drop_prob, self.training) diff --git a/dinov2/dinov2/layers/layer_scale.py b/dinov2/dinov2/layers/layer_scale.py new file mode 100644 index 0000000000000000000000000000000000000000..51df0d7ce61f2b41fa9e6369f52391dd7fe7d386 --- /dev/null +++ b/dinov2/dinov2/layers/layer_scale.py @@ -0,0 +1,27 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +# Modified from: https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py#L103-L110 + +from typing import Union + +import torch +from torch import Tensor +from torch import nn + + +class LayerScale(nn.Module): + def __init__( + self, + dim: int, + init_values: Union[float, Tensor] = 1e-5, + inplace: bool = False, + ) -> None: + super().__init__() + self.inplace = inplace + self.gamma = nn.Parameter(init_values * torch.ones(dim)) + + def forward(self, x: Tensor) -> Tensor: + return x.mul_(self.gamma) if self.inplace else x * self.gamma diff --git a/dinov2/dinov2/layers/mlp.py b/dinov2/dinov2/layers/mlp.py new file mode 100644 index 0000000000000000000000000000000000000000..bbf9432aae9258612caeae910a7bde17999e328e --- /dev/null +++ b/dinov2/dinov2/layers/mlp.py @@ -0,0 +1,40 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +# References: +# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py +# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/mlp.py + + +from typing import Callable, Optional + +from torch import Tensor, nn + + +class Mlp(nn.Module): + def __init__( + self, + in_features: int, + hidden_features: Optional[int] = None, + out_features: Optional[int] = None, + act_layer: Callable[..., nn.Module] = nn.GELU, + drop: float = 0.0, + bias: bool = True, + ) -> None: + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features, bias=bias) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features, bias=bias) + self.drop = nn.Dropout(drop) + + def forward(self, x: Tensor) -> Tensor: + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x diff --git a/dinov2/dinov2/layers/patch_embed.py b/dinov2/dinov2/layers/patch_embed.py new file mode 100644 index 0000000000000000000000000000000000000000..8b7c0804784a42cf80c0297d110dcc68cc85b339 --- /dev/null +++ b/dinov2/dinov2/layers/patch_embed.py @@ -0,0 +1,88 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +# References: +# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py +# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py + +from typing import Callable, Optional, Tuple, Union + +from torch import Tensor +import torch.nn as nn + + +def make_2tuple(x): + if isinstance(x, tuple): + assert len(x) == 2 + return x + + assert isinstance(x, int) + return (x, x) + + +class PatchEmbed(nn.Module): + """ + 2D image to patch embedding: (B,C,H,W) -> (B,N,D) + + Args: + img_size: Image size. + patch_size: Patch token size. + in_chans: Number of input image channels. + embed_dim: Number of linear projection output channels. + norm_layer: Normalization layer. + """ + + def __init__( + self, + img_size: Union[int, Tuple[int, int]] = 224, + patch_size: Union[int, Tuple[int, int]] = 16, + in_chans: int = 3, + embed_dim: int = 768, + norm_layer: Optional[Callable] = None, + flatten_embedding: bool = True, + ) -> None: + super().__init__() + + image_HW = make_2tuple(img_size) + patch_HW = make_2tuple(patch_size) + patch_grid_size = ( + image_HW[0] // patch_HW[0], + image_HW[1] // patch_HW[1], + ) + + self.img_size = image_HW + self.patch_size = patch_HW + self.patches_resolution = patch_grid_size + self.num_patches = patch_grid_size[0] * patch_grid_size[1] + + self.in_chans = in_chans + self.embed_dim = embed_dim + + self.flatten_embedding = flatten_embedding + + self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_HW, stride=patch_HW) + self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity() + + def forward(self, x: Tensor) -> Tensor: + _, _, H, W = x.shape + patch_H, patch_W = self.patch_size + + assert H % patch_H == 0, f"Input image height {H} is not a multiple of patch height {patch_H}" + assert W % patch_W == 0, f"Input image width {W} is not a multiple of patch width: {patch_W}" + + x = self.proj(x) # B C H W + H, W = x.size(2), x.size(3) + x = x.flatten(2).transpose(1, 2) # B HW C + x = self.norm(x) + if not self.flatten_embedding: + x = x.reshape(-1, H, W, self.embed_dim) # B H W C + return x + + def flops(self) -> float: + Ho, Wo = self.patches_resolution + flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1]) + if self.norm is not None: + flops += Ho * Wo * self.embed_dim + return flops diff --git a/dinov2/dinov2/layers/swiglu_ffn.py b/dinov2/dinov2/layers/swiglu_ffn.py new file mode 100644 index 0000000000000000000000000000000000000000..5e9dafa4592a408f6874d54853e8f60db5c41f74 --- /dev/null +++ b/dinov2/dinov2/layers/swiglu_ffn.py @@ -0,0 +1,72 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +import os +from typing import Callable, Optional +import warnings + +from torch import Tensor, nn +import torch.nn.functional as F + + +class SwiGLUFFN(nn.Module): + def __init__( + self, + in_features: int, + hidden_features: Optional[int] = None, + out_features: Optional[int] = None, + act_layer: Callable[..., nn.Module] = None, + drop: float = 0.0, + bias: bool = True, + ) -> None: + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.w12 = nn.Linear(in_features, 2 * hidden_features, bias=bias) + self.w3 = nn.Linear(hidden_features, out_features, bias=bias) + + def forward(self, x: Tensor) -> Tensor: + x12 = self.w12(x) + x1, x2 = x12.chunk(2, dim=-1) + hidden = F.silu(x1) * x2 + return self.w3(hidden) + + +XFORMERS_ENABLED = os.environ.get("XFORMERS_DISABLED") is None +try: + if XFORMERS_ENABLED: + from xformers.ops import SwiGLU + + XFORMERS_AVAILABLE = True + warnings.warn("xFormers is available (SwiGLU)") + else: + warnings.warn("xFormers is disabled (SwiGLU)") + raise ImportError +except ImportError: + SwiGLU = SwiGLUFFN + XFORMERS_AVAILABLE = False + + warnings.warn("xFormers is not available (SwiGLU)") + + +class SwiGLUFFNFused(SwiGLU): + def __init__( + self, + in_features: int, + hidden_features: Optional[int] = None, + out_features: Optional[int] = None, + act_layer: Callable[..., nn.Module] = None, + drop: float = 0.0, + bias: bool = True, + ) -> None: + out_features = out_features or in_features + hidden_features = hidden_features or in_features + hidden_features = (int(hidden_features * 2 / 3) + 7) // 8 * 8 + super().__init__( + in_features=in_features, + hidden_features=hidden_features, + out_features=out_features, + bias=bias, + ) diff --git a/dinov2/dinov2/logging/__init__.py b/dinov2/dinov2/logging/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..04a7f02204316d4d1ef38bf6080dae3d66241c25 --- /dev/null +++ b/dinov2/dinov2/logging/__init__.py @@ -0,0 +1,102 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +import functools +import logging +import os +import sys +from typing import Optional + +import dinov2.distributed as distributed +from .helpers import MetricLogger, SmoothedValue + + +# So that calling _configure_logger multiple times won't add many handlers +@functools.lru_cache() +def _configure_logger( + name: Optional[str] = None, + *, + level: int = logging.DEBUG, + output: Optional[str] = None, +): + """ + Configure a logger. + + Adapted from Detectron2. + + Args: + name: The name of the logger to configure. + level: The logging level to use. + output: A file name or a directory to save log. If None, will not save log file. + If ends with ".txt" or ".log", assumed to be a file name. + Otherwise, logs will be saved to `output/log.txt`. + + Returns: + The configured logger. + """ + + logger = logging.getLogger(name) + logger.setLevel(level) + logger.propagate = False + + # Loosely match Google glog format: + # [IWEF]yyyymmdd hh:mm:ss.uuuuuu threadid file:line] msg + # but use a shorter timestamp and include the logger name: + # [IWEF]yyyymmdd hh:mm:ss logger threadid file:line] msg + fmt_prefix = "%(levelname).1s%(asctime)s %(process)s %(name)s %(filename)s:%(lineno)s] " + fmt_message = "%(message)s" + fmt = fmt_prefix + fmt_message + datefmt = "%Y%m%d %H:%M:%S" + formatter = logging.Formatter(fmt=fmt, datefmt=datefmt) + + # stdout logging for main worker only + if distributed.is_main_process(): + handler = logging.StreamHandler(stream=sys.stdout) + handler.setLevel(logging.DEBUG) + handler.setFormatter(formatter) + logger.addHandler(handler) + + # file logging for all workers + if output: + if os.path.splitext(output)[-1] in (".txt", ".log"): + filename = output + else: + filename = os.path.join(output, "logs", "log.txt") + + if not distributed.is_main_process(): + global_rank = distributed.get_global_rank() + filename = filename + ".rank{}".format(global_rank) + + os.makedirs(os.path.dirname(filename), exist_ok=True) + + handler = logging.StreamHandler(open(filename, "a")) + handler.setLevel(logging.DEBUG) + handler.setFormatter(formatter) + logger.addHandler(handler) + + return logger + + +def setup_logging( + output: Optional[str] = None, + *, + name: Optional[str] = None, + level: int = logging.DEBUG, + capture_warnings: bool = True, +) -> None: + """ + Setup logging. + + Args: + output: A file name or a directory to save log files. If None, log + files will not be saved. If output ends with ".txt" or ".log", it + is assumed to be a file name. + Otherwise, logs will be saved to `output/log.txt`. + name: The name of the logger to configure, by default the root logger. + level: The logging level to use. + capture_warnings: Whether warnings should be captured as logs. + """ + logging.captureWarnings(capture_warnings) + _configure_logger(name, level=level, output=output) diff --git a/dinov2/dinov2/logging/helpers.py b/dinov2/dinov2/logging/helpers.py new file mode 100644 index 0000000000000000000000000000000000000000..c6e70bb15505cbbc4c4732b069ee919bf921a74f --- /dev/null +++ b/dinov2/dinov2/logging/helpers.py @@ -0,0 +1,194 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +from collections import defaultdict, deque +import datetime +import json +import logging +import time + +import torch + +import dinov2.distributed as distributed + + +logger = logging.getLogger("dinov2") + + +class MetricLogger(object): + def __init__(self, delimiter="\t", output_file=None): + self.meters = defaultdict(SmoothedValue) + self.delimiter = delimiter + self.output_file = output_file + + def update(self, **kwargs): + for k, v in kwargs.items(): + if isinstance(v, torch.Tensor): + v = v.item() + assert isinstance(v, (float, int)) + self.meters[k].update(v) + + def __getattr__(self, attr): + if attr in self.meters: + return self.meters[attr] + if attr in self.__dict__: + return self.__dict__[attr] + raise AttributeError("'{}' object has no attribute '{}'".format(type(self).__name__, attr)) + + def __str__(self): + loss_str = [] + for name, meter in self.meters.items(): + loss_str.append("{}: {}".format(name, str(meter))) + return self.delimiter.join(loss_str) + + def synchronize_between_processes(self): + for meter in self.meters.values(): + meter.synchronize_between_processes() + + def add_meter(self, name, meter): + self.meters[name] = meter + + def dump_in_output_file(self, iteration, iter_time, data_time): + if self.output_file is None or not distributed.is_main_process(): + return + dict_to_dump = dict( + iteration=iteration, + iter_time=iter_time, + data_time=data_time, + ) + dict_to_dump.update({k: v.median for k, v in self.meters.items()}) + with open(self.output_file, "a") as f: + f.write(json.dumps(dict_to_dump) + "\n") + pass + + def log_every(self, iterable, print_freq, header=None, n_iterations=None, start_iteration=0): + i = start_iteration + if not header: + header = "" + start_time = time.time() + end = time.time() + iter_time = SmoothedValue(fmt="{avg:.6f}") + data_time = SmoothedValue(fmt="{avg:.6f}") + + if n_iterations is None: + n_iterations = len(iterable) + + space_fmt = ":" + str(len(str(n_iterations))) + "d" + + log_list = [ + header, + "[{0" + space_fmt + "}/{1}]", + "eta: {eta}", + "{meters}", + "time: {time}", + "data: {data}", + ] + if torch.cuda.is_available(): + log_list += ["max mem: {memory:.0f}"] + + log_msg = self.delimiter.join(log_list) + MB = 1024.0 * 1024.0 + for obj in iterable: + data_time.update(time.time() - end) + yield obj + iter_time.update(time.time() - end) + if i % print_freq == 0 or i == n_iterations - 1: + self.dump_in_output_file(iteration=i, iter_time=iter_time.avg, data_time=data_time.avg) + eta_seconds = iter_time.global_avg * (n_iterations - i) + eta_string = str(datetime.timedelta(seconds=int(eta_seconds))) + if torch.cuda.is_available(): + logger.info( + log_msg.format( + i, + n_iterations, + eta=eta_string, + meters=str(self), + time=str(iter_time), + data=str(data_time), + memory=torch.cuda.max_memory_allocated() / MB, + ) + ) + else: + logger.info( + log_msg.format( + i, + n_iterations, + eta=eta_string, + meters=str(self), + time=str(iter_time), + data=str(data_time), + ) + ) + i += 1 + end = time.time() + if i >= n_iterations: + break + total_time = time.time() - start_time + total_time_str = str(datetime.timedelta(seconds=int(total_time))) + logger.info("{} Total time: {} ({:.6f} s / it)".format(header, total_time_str, total_time / n_iterations)) + + +class SmoothedValue: + """Track a series of values and provide access to smoothed values over a + window or the global series average. + """ + + def __init__(self, window_size=20, fmt=None): + if fmt is None: + fmt = "{median:.4f} ({global_avg:.4f})" + self.deque = deque(maxlen=window_size) + self.total = 0.0 + self.count = 0 + self.fmt = fmt + + def update(self, value, num=1): + self.deque.append(value) + self.count += num + self.total += value * num + + def synchronize_between_processes(self): + """ + Distributed synchronization of the metric + Warning: does not synchronize the deque! + """ + if not distributed.is_enabled(): + return + t = torch.tensor([self.count, self.total], dtype=torch.float64, device="cuda") + torch.distributed.barrier() + torch.distributed.all_reduce(t) + t = t.tolist() + self.count = int(t[0]) + self.total = t[1] + + @property + def median(self): + d = torch.tensor(list(self.deque)) + return d.median().item() + + @property + def avg(self): + d = torch.tensor(list(self.deque), dtype=torch.float32) + return d.mean().item() + + @property + def global_avg(self): + return self.total / self.count + + @property + def max(self): + return max(self.deque) + + @property + def value(self): + return self.deque[-1] + + def __str__(self): + return self.fmt.format( + median=self.median, + avg=self.avg, + global_avg=self.global_avg, + max=self.max, + value=self.value, + ) diff --git a/dinov2/dinov2/loss/__init__.py b/dinov2/dinov2/loss/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..d6b0115b74edbd74b324c9056a57fade363c58fd --- /dev/null +++ b/dinov2/dinov2/loss/__init__.py @@ -0,0 +1,8 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +from .dino_clstoken_loss import DINOLoss +from .ibot_patch_loss import iBOTPatchLoss +from .koleo_loss import KoLeoLoss diff --git a/dinov2/dinov2/loss/dino_clstoken_loss.py b/dinov2/dinov2/loss/dino_clstoken_loss.py new file mode 100644 index 0000000000000000000000000000000000000000..c31808e36e6c38ee6dae13ba0443bf1946242117 --- /dev/null +++ b/dinov2/dinov2/loss/dino_clstoken_loss.py @@ -0,0 +1,99 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +import torch +import torch.distributed as dist +import torch.nn.functional as F +from torch import nn + + +class DINOLoss(nn.Module): + def __init__( + self, + out_dim, + student_temp=0.1, + center_momentum=0.9, + ): + super().__init__() + self.student_temp = student_temp + self.center_momentum = center_momentum + self.register_buffer("center", torch.zeros(1, out_dim)) + self.updated = True + self.reduce_handle = None + self.len_teacher_output = None + self.async_batch_center = None + + @torch.no_grad() + def softmax_center_teacher(self, teacher_output, teacher_temp): + self.apply_center_update() + # teacher centering and sharpening + return F.softmax((teacher_output - self.center) / teacher_temp, dim=-1) + + @torch.no_grad() + def sinkhorn_knopp_teacher(self, teacher_output, teacher_temp, n_iterations=3): + teacher_output = teacher_output.float() + world_size = dist.get_world_size() if dist.is_initialized() else 1 + Q = torch.exp(teacher_output / teacher_temp).t() # Q is K-by-B for consistency with notations from our paper + B = Q.shape[1] * world_size # number of samples to assign + K = Q.shape[0] # how many prototypes + + # make the matrix sums to 1 + sum_Q = torch.sum(Q) + if dist.is_initialized(): + dist.all_reduce(sum_Q) + Q /= sum_Q + + for it in range(n_iterations): + # normalize each row: total weight per prototype must be 1/K + sum_of_rows = torch.sum(Q, dim=1, keepdim=True) + if dist.is_initialized(): + dist.all_reduce(sum_of_rows) + Q /= sum_of_rows + Q /= K + + # normalize each column: total weight per sample must be 1/B + Q /= torch.sum(Q, dim=0, keepdim=True) + Q /= B + + Q *= B # the columns must sum to 1 so that Q is an assignment + return Q.t() + + def forward(self, student_output_list, teacher_out_softmaxed_centered_list): + """ + Cross-entropy between softmax outputs of the teacher and student networks. + """ + # TODO: Use cross_entropy_distribution here + total_loss = 0 + for s in student_output_list: + lsm = F.log_softmax(s / self.student_temp, dim=-1) + for t in teacher_out_softmaxed_centered_list: + loss = torch.sum(t * lsm, dim=-1) + total_loss -= loss.mean() + return total_loss + + @torch.no_grad() + def update_center(self, teacher_output): + self.reduce_center_update(teacher_output) + + @torch.no_grad() + def reduce_center_update(self, teacher_output): + self.updated = False + self.len_teacher_output = len(teacher_output) + self.async_batch_center = torch.sum(teacher_output, dim=0, keepdim=True) + if dist.is_initialized(): + self.reduce_handle = dist.all_reduce(self.async_batch_center, async_op=True) + + @torch.no_grad() + def apply_center_update(self): + if self.updated is False: + world_size = dist.get_world_size() if dist.is_initialized() else 1 + + if self.reduce_handle is not None: + self.reduce_handle.wait() + _t = self.async_batch_center / (self.len_teacher_output * world_size) + + self.center = self.center * self.center_momentum + _t * (1 - self.center_momentum) + + self.updated = True diff --git a/dinov2/dinov2/loss/ibot_patch_loss.py b/dinov2/dinov2/loss/ibot_patch_loss.py new file mode 100644 index 0000000000000000000000000000000000000000..6732cda0c311c69f193669ebc950fc8665871442 --- /dev/null +++ b/dinov2/dinov2/loss/ibot_patch_loss.py @@ -0,0 +1,151 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +import torch +import torch.distributed as dist +import torch.nn.functional as F +from torch import nn + +import logging + + +logger = logging.getLogger("dinov2") + + +try: + from xformers.ops import cross_entropy + + def lossfunc(t, s, temp): + s = s.float() + t = t.float() + if s.ndim == 2: + return -cross_entropy(s.unsqueeze(0), t.unsqueeze(0), temp, bw_inplace=True).squeeze(0) + elif s.ndim == 3: + return -cross_entropy(s, t, temp, bw_inplace=True) + +except ImportError: + + def lossfunc(t, s, temp): + return torch.sum(t * F.log_softmax(s / temp, dim=-1), dim=-1) + + +class iBOTPatchLoss(nn.Module): + def __init__(self, patch_out_dim, student_temp=0.1, center_momentum=0.9): + super().__init__() + self.student_temp = student_temp + self.center_momentum = center_momentum + self.register_buffer("center", torch.zeros(1, 1, patch_out_dim)) + self.updated = True + self.reduce_handle = None + self.len_teacher_patch_tokens = None + self.async_batch_center = None + + @torch.no_grad() + def softmax_center_teacher(self, teacher_patch_tokens, teacher_temp): + self.apply_center_update() + # teacher centering and sharpening + # + # WARNING: + # as self.center is a float32, everything gets casted to float32 afterwards + # + # teacher_patch_tokens = teacher_patch_tokens.float() + # return F.softmax((teacher_patch_tokens.sub_(self.center.to(teacher_patch_tokens.dtype))).mul_(1 / teacher_temp), dim=-1) + + return F.softmax((teacher_patch_tokens - self.center) / teacher_temp, dim=-1) + + # this is experimental, keep everything in float16 and let's see what happens: + # return F.softmax((teacher_patch_tokens.sub_(self.center)) / teacher_temp, dim=-1) + + @torch.no_grad() + def sinkhorn_knopp_teacher(self, teacher_output, teacher_temp, n_masked_patches_tensor, n_iterations=3): + teacher_output = teacher_output.float() + # world_size = dist.get_world_size() if dist.is_initialized() else 1 + Q = torch.exp(teacher_output / teacher_temp).t() # Q is K-by-B for consistency with notations from our paper + # B = Q.shape[1] * world_size # number of samples to assign + B = n_masked_patches_tensor + dist.all_reduce(B) + K = Q.shape[0] # how many prototypes + + # make the matrix sums to 1 + sum_Q = torch.sum(Q) + if dist.is_initialized(): + dist.all_reduce(sum_Q) + Q /= sum_Q + + for it in range(n_iterations): + # normalize each row: total weight per prototype must be 1/K + sum_of_rows = torch.sum(Q, dim=1, keepdim=True) + if dist.is_initialized(): + dist.all_reduce(sum_of_rows) + Q /= sum_of_rows + Q /= K + + # normalize each column: total weight per sample must be 1/B + Q /= torch.sum(Q, dim=0, keepdim=True) + Q /= B + + Q *= B # the columns must sum to 1 so that Q is an assignment + return Q.t() + + def forward(self, student_patch_tokens, teacher_patch_tokens, student_masks_flat): + """ + Cross-entropy between softmax outputs of the teacher and student networks. + student_patch_tokens: (B, N, D) tensor + teacher_patch_tokens: (B, N, D) tensor + student_masks_flat: (B, N) tensor + """ + t = teacher_patch_tokens + s = student_patch_tokens + loss = torch.sum(t * F.log_softmax(s / self.student_temp, dim=-1), dim=-1) + loss = torch.sum(loss * student_masks_flat.float(), dim=-1) / student_masks_flat.sum(dim=-1).clamp(min=1.0) + return -loss.mean() + + def forward_masked( + self, + student_patch_tokens_masked, + teacher_patch_tokens_masked, + student_masks_flat, + n_masked_patches=None, + masks_weight=None, + ): + t = teacher_patch_tokens_masked + s = student_patch_tokens_masked + # loss = torch.sum(t * F.log_softmax(s / self.student_temp, dim=-1), dim=-1) + loss = lossfunc(t, s, self.student_temp) + if masks_weight is None: + masks_weight = ( + (1 / student_masks_flat.sum(-1).clamp(min=1.0)) + .unsqueeze(-1) + .expand_as(student_masks_flat)[student_masks_flat] + ) + if n_masked_patches is not None: + loss = loss[:n_masked_patches] + loss = loss * masks_weight + return -loss.sum() / student_masks_flat.shape[0] + + @torch.no_grad() + def update_center(self, teacher_patch_tokens): + self.reduce_center_update(teacher_patch_tokens) + + @torch.no_grad() + def reduce_center_update(self, teacher_patch_tokens): + self.updated = False + self.len_teacher_patch_tokens = len(teacher_patch_tokens) + self.async_batch_center = torch.sum(teacher_patch_tokens.mean(1), dim=0, keepdim=True) + if dist.is_initialized(): + self.reduce_handle = dist.all_reduce(self.async_batch_center, async_op=True) + + @torch.no_grad() + def apply_center_update(self): + if self.updated is False: + world_size = dist.get_world_size() if dist.is_initialized() else 1 + + if self.reduce_handle is not None: + self.reduce_handle.wait() + _t = self.async_batch_center / (self.len_teacher_patch_tokens * world_size) + + self.center = self.center * self.center_momentum + _t * (1 - self.center_momentum) + + self.updated = True diff --git a/dinov2/dinov2/loss/koleo_loss.py b/dinov2/dinov2/loss/koleo_loss.py new file mode 100644 index 0000000000000000000000000000000000000000..b5cbcd91e0fc0b857f477b0910f957f02a6c4335 --- /dev/null +++ b/dinov2/dinov2/loss/koleo_loss.py @@ -0,0 +1,48 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +import logging + +import torch +import torch.nn as nn +import torch.nn.functional as F + +# import torch.distributed as dist + + +logger = logging.getLogger("dinov2") + + +class KoLeoLoss(nn.Module): + """Kozachenko-Leonenko entropic loss regularizer from Sablayrolles et al. - 2018 - Spreading vectors for similarity search""" + + def __init__(self): + super().__init__() + self.pdist = nn.PairwiseDistance(2, eps=1e-8) + + def pairwise_NNs_inner(self, x): + """ + Pairwise nearest neighbors for L2-normalized vectors. + Uses Torch rather than Faiss to remain on GPU. + """ + # parwise dot products (= inverse distance) + dots = torch.mm(x, x.t()) + n = x.shape[0] + dots.view(-1)[:: (n + 1)].fill_(-1) # Trick to fill diagonal with -1 + # max inner prod -> min distance + _, I = torch.max(dots, dim=1) # noqa: E741 + return I + + def forward(self, student_output, eps=1e-8): + """ + Args: + student_output (BxD): backbone output of student + """ + with torch.cuda.amp.autocast(enabled=False): + student_output = F.normalize(student_output, eps=eps, p=2, dim=-1) + I = self.pairwise_NNs_inner(student_output) # noqa: E741 + distances = self.pdist(student_output, student_output[I]) # BxD, BxD -> B + loss = -torch.log(distances + eps).mean() + return loss diff --git a/dinov2/dinov2/models/__init__.py b/dinov2/dinov2/models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..3fdff20badbd5244bf79f16bf18dd2cb73982265 --- /dev/null +++ b/dinov2/dinov2/models/__init__.py @@ -0,0 +1,43 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +import logging + +from . import vision_transformer as vits + + +logger = logging.getLogger("dinov2") + + +def build_model(args, only_teacher=False, img_size=224): + args.arch = args.arch.removesuffix("_memeff") + if "vit" in args.arch: + vit_kwargs = dict( + img_size=img_size, + patch_size=args.patch_size, + init_values=args.layerscale, + ffn_layer=args.ffn_layer, + block_chunks=args.block_chunks, + qkv_bias=args.qkv_bias, + proj_bias=args.proj_bias, + ffn_bias=args.ffn_bias, + num_register_tokens=args.num_register_tokens, + interpolate_offset=args.interpolate_offset, + interpolate_antialias=args.interpolate_antialias, + ) + teacher = vits.__dict__[args.arch](**vit_kwargs) + if only_teacher: + return teacher, teacher.embed_dim + student = vits.__dict__[args.arch]( + **vit_kwargs, + drop_path_rate=args.drop_path_rate, + drop_path_uniform=args.drop_path_uniform, + ) + embed_dim = student.embed_dim + return student, teacher, embed_dim + + +def build_model_from_cfg(cfg, only_teacher=False): + return build_model(cfg.student, only_teacher=only_teacher, img_size=cfg.crops.global_crops_size) diff --git a/dinov2/dinov2/models/vision_transformer.py b/dinov2/dinov2/models/vision_transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..939ffd9c7da62986ccea21824fb59b7a03dec309 --- /dev/null +++ b/dinov2/dinov2/models/vision_transformer.py @@ -0,0 +1,409 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +# References: +# https://github.com/facebookresearch/dino/blob/main/vision_transformer.py +# https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py + +from functools import partial +import math +import logging +from typing import Sequence, Tuple, Union, Callable + +import torch +import torch.nn as nn +import torch.utils.checkpoint +from torch.nn.init import trunc_normal_ + +# from dinov2.layers import Mlp, PatchEmbed, SwiGLUFFNFused, MemEffAttention, NestedTensorBlock as Block +from ..layers import Mlp, PatchEmbed, SwiGLUFFNFused, MemEffAttention, NestedTensorBlock as Block + + +logger = logging.getLogger("dinov2") + + +def named_apply(fn: Callable, module: nn.Module, name="", depth_first=True, include_root=False) -> nn.Module: + if not depth_first and include_root: + fn(module=module, name=name) + for child_name, child_module in module.named_children(): + child_name = ".".join((name, child_name)) if name else child_name + named_apply(fn=fn, module=child_module, name=child_name, depth_first=depth_first, include_root=True) + if depth_first and include_root: + fn(module=module, name=name) + return module + + +class BlockChunk(nn.ModuleList): + def forward(self, x): + for b in self: + x = b(x) + return x + + +class DinoVisionTransformer(nn.Module): + def __init__( + self, + img_size=224, + patch_size=16, + in_chans=3, + embed_dim=768, + depth=12, + num_heads=12, + mlp_ratio=4.0, + qkv_bias=True, + ffn_bias=True, + proj_bias=True, + drop_path_rate=0.0, + drop_path_uniform=False, + init_values=None, # for layerscale: None or 0 => no layerscale + embed_layer=PatchEmbed, + act_layer=nn.GELU, + block_fn=Block, + ffn_layer="mlp", + block_chunks=1, + num_register_tokens=0, + interpolate_antialias=False, + interpolate_offset=0.1, + ): + """ + Args: + img_size (int, tuple): input image size + patch_size (int, tuple): patch size + in_chans (int): number of input channels + embed_dim (int): embedding dimension + depth (int): depth of transformer + num_heads (int): number of attention heads + mlp_ratio (int): ratio of mlp hidden dim to embedding dim + qkv_bias (bool): enable bias for qkv if True + proj_bias (bool): enable bias for proj in attn if True + ffn_bias (bool): enable bias for ffn if True + drop_path_rate (float): stochastic depth rate + drop_path_uniform (bool): apply uniform drop rate across blocks + weight_init (str): weight init scheme + init_values (float): layer-scale init values + embed_layer (nn.Module): patch embedding layer + act_layer (nn.Module): MLP activation layer + block_fn (nn.Module): transformer block class + ffn_layer (str): "mlp", "swiglu", "swiglufused" or "identity" + block_chunks: (int) split block sequence into block_chunks units for FSDP wrap + num_register_tokens: (int) number of extra cls tokens (so-called "registers") + interpolate_antialias: (str) flag to apply anti-aliasing when interpolating positional embeddings + interpolate_offset: (float) work-around offset to apply when interpolating positional embeddings + """ + super().__init__() + norm_layer = partial(nn.LayerNorm, eps=1e-6) + + self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models + self.num_tokens = 1 + self.n_blocks = depth + self.num_heads = num_heads + self.patch_size = patch_size + self.num_register_tokens = num_register_tokens + self.interpolate_antialias = interpolate_antialias + self.interpolate_offset = interpolate_offset + + self.patch_embed = embed_layer(img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim) + num_patches = self.patch_embed.num_patches + + self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) + self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim)) + assert num_register_tokens >= 0 + self.register_tokens = ( + nn.Parameter(torch.zeros(1, num_register_tokens, embed_dim)) if num_register_tokens else None + ) + + if drop_path_uniform is True: + dpr = [drop_path_rate] * depth + else: + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule + + if ffn_layer == "mlp": + logger.info("using MLP layer as FFN") + ffn_layer = Mlp + elif ffn_layer == "swiglufused" or ffn_layer == "swiglu": + logger.info("using SwiGLU layer as FFN") + ffn_layer = SwiGLUFFNFused + elif ffn_layer == "identity": + logger.info("using Identity layer as FFN") + + def f(*args, **kwargs): + return nn.Identity() + + ffn_layer = f + else: + raise NotImplementedError + + blocks_list = [ + block_fn( + dim=embed_dim, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + proj_bias=proj_bias, + ffn_bias=ffn_bias, + drop_path=dpr[i], + norm_layer=norm_layer, + act_layer=act_layer, + ffn_layer=ffn_layer, + init_values=init_values, + ) + for i in range(depth) + ] + if block_chunks > 0: + self.chunked_blocks = True + chunked_blocks = [] + chunksize = depth // block_chunks + for i in range(0, depth, chunksize): + # this is to keep the block index consistent if we chunk the block list + chunked_blocks.append([nn.Identity()] * i + blocks_list[i : i + chunksize]) + self.blocks = nn.ModuleList([BlockChunk(p) for p in chunked_blocks]) + else: + self.chunked_blocks = False + self.blocks = nn.ModuleList(blocks_list) + + self.norm = norm_layer(embed_dim) + self.head = nn.Identity() + + self.mask_token = nn.Parameter(torch.zeros(1, embed_dim)) + + self.init_weights() + + def init_weights(self): + trunc_normal_(self.pos_embed, std=0.02) + nn.init.normal_(self.cls_token, std=1e-6) + if self.register_tokens is not None: + nn.init.normal_(self.register_tokens, std=1e-6) + named_apply(init_weights_vit_timm, self) + + def interpolate_pos_encoding(self, x, w, h): + previous_dtype = x.dtype + npatch = x.shape[1] - 1 + N = self.pos_embed.shape[1] - 1 + if npatch == N and w == h: + return self.pos_embed + pos_embed = self.pos_embed.float() + class_pos_embed = pos_embed[:, 0] + patch_pos_embed = pos_embed[:, 1:] + dim = x.shape[-1] + w0 = w // self.patch_size + h0 = h // self.patch_size + M = int(math.sqrt(N)) # Recover the number of patches in each dimension + assert N == M * M + kwargs = {} + if self.interpolate_offset: + # Historical kludge: add a small number to avoid floating point error in the interpolation, see https://github.com/facebookresearch/dino/issues/8 + # Note: still needed for backward-compatibility, the underlying operators are using both output size and scale factors + sx = float(w0 + self.interpolate_offset) / M + sy = float(h0 + self.interpolate_offset) / M + kwargs["scale_factor"] = (sx, sy) + else: + # Simply specify an output size instead of a scale factor + kwargs["size"] = (w0, h0) + patch_pos_embed = nn.functional.interpolate( + patch_pos_embed.reshape(1, M, M, dim).permute(0, 3, 1, 2), + mode="bicubic", + antialias=self.interpolate_antialias, + **kwargs, + ) + assert (w0, h0) == patch_pos_embed.shape[-2:] + patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim) + return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1).to(previous_dtype) + + def prepare_tokens_with_masks(self, x, masks=None): + B, nc, w, h = x.shape + x = self.patch_embed(x) + if masks is not None: + x = torch.where(masks.unsqueeze(-1), self.mask_token.to(x.dtype).unsqueeze(0), x) + + x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1) + x = x + self.interpolate_pos_encoding(x, w, h) + + if self.register_tokens is not None: + x = torch.cat( + ( + x[:, :1], + self.register_tokens.expand(x.shape[0], -1, -1), + x[:, 1:], + ), + dim=1, + ) + + return x + + def forward_features_list(self, x_list, masks_list): + x = [self.prepare_tokens_with_masks(x, masks) for x, masks in zip(x_list, masks_list)] + for blk in self.blocks: + x = blk(x) + + all_x = x + output = [] + for x, masks in zip(all_x, masks_list): + x_norm = self.norm(x) + output.append( + { + "x_norm_clstoken": x_norm[:, 0], + "x_norm_regtokens": x_norm[:, 1 : self.num_register_tokens + 1], + "x_norm_patchtokens": x_norm[:, self.num_register_tokens + 1 :], + "x_prenorm": x, + "masks": masks, + } + ) + return output + + def forward_features(self, x, masks=None, out_layer_list=None): + if isinstance(x, list): + return self.forward_features_list(x, masks) + + x = self.prepare_tokens_with_masks(x, masks) + + if out_layer_list is not None: + x_norm_patchtokens_list = [] + + for i,blk in enumerate(self.blocks): + x = blk(x) + if i + 1 in out_layer_list: + x_norm_patchtokens_list.append(self.norm(x)[:, self.num_register_tokens + 1 :]) + + else: + for blk in self.blocks: + x = blk(x) + + x_norm = self.norm(x) + if out_layer_list is not None: + return x_norm_patchtokens_list + else: + return { + "x_norm_clstoken": x_norm[:, 0], + "x_norm_regtokens": x_norm[:, 1 : self.num_register_tokens + 1], + "x_norm_patchtokens": x_norm[:, self.num_register_tokens + 1 :], + "x_prenorm": x, + "masks": masks, + } + + def _get_intermediate_layers_not_chunked(self, x, n=1): + x = self.prepare_tokens_with_masks(x) + # If n is an int, take the n last blocks. If it's a list, take them + output, total_block_len = [], len(self.blocks) + blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n + for i, blk in enumerate(self.blocks): + x = blk(x) + if i in blocks_to_take: + output.append(x) + assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found" + return output + + def _get_intermediate_layers_chunked(self, x, n=1): + x = self.prepare_tokens_with_masks(x) + output, i, total_block_len = [], 0, len(self.blocks[-1]) + # If n is an int, take the n last blocks. If it's a list, take them + blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n + for block_chunk in self.blocks: + for blk in block_chunk[i:]: # Passing the nn.Identity() + x = blk(x) + if i in blocks_to_take: + output.append(x) + i += 1 + assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found" + return output + + def get_intermediate_layers( + self, + x: torch.Tensor, + n: Union[int, Sequence] = 1, # Layers or n last layers to take + reshape: bool = False, + return_class_token: bool = False, + norm=True, + ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]]]: + if self.chunked_blocks: + outputs = self._get_intermediate_layers_chunked(x, n) + else: + outputs = self._get_intermediate_layers_not_chunked(x, n) + if norm: + outputs = [self.norm(out) for out in outputs] + class_tokens = [out[:, 0] for out in outputs] + outputs = [out[:, 1 + self.num_register_tokens :] for out in outputs] + if reshape: + B, _, w, h = x.shape + outputs = [ + out.reshape(B, w // self.patch_size, h // self.patch_size, -1).permute(0, 3, 1, 2).contiguous() + for out in outputs + ] + if return_class_token: + return tuple(zip(outputs, class_tokens)) + return tuple(outputs) + + def forward(self, *args, is_training=False, **kwargs): + ret = self.forward_features(*args, **kwargs) + if is_training: + return ret + else: + return self.head(ret["x_norm_clstoken"]) + + +def init_weights_vit_timm(module: nn.Module, name: str = ""): + """ViT weight initialization, original timm impl (for reproducibility)""" + if isinstance(module, nn.Linear): + trunc_normal_(module.weight, std=0.02) + if module.bias is not None: + nn.init.zeros_(module.bias) + + +def vit_small(patch_size=16, num_register_tokens=0, **kwargs): + model = DinoVisionTransformer( + patch_size=patch_size, + embed_dim=384, + depth=12, + num_heads=6, + mlp_ratio=4, + block_fn=partial(Block, attn_class=MemEffAttention), + num_register_tokens=num_register_tokens, + **kwargs, + ) + return model + + +def vit_base(patch_size=16, num_register_tokens=0, **kwargs): + model = DinoVisionTransformer( + patch_size=patch_size, + embed_dim=768, + depth=12, + num_heads=12, + mlp_ratio=4, + block_fn=partial(Block, attn_class=MemEffAttention), + num_register_tokens=num_register_tokens, + **kwargs, + ) + return model + + +def vit_large(patch_size=16, num_register_tokens=0, **kwargs): + model = DinoVisionTransformer( + patch_size=patch_size, + embed_dim=1024, + depth=24, + num_heads=16, + mlp_ratio=4, + block_fn=partial(Block, attn_class=MemEffAttention), + num_register_tokens=num_register_tokens, + **kwargs, + ) + return model + + +def vit_giant2(patch_size=16, num_register_tokens=0, **kwargs): + """ + Close to ViT-giant, with embed-dim 1536 and 24 heads => embed-dim per head 64 + """ + model = DinoVisionTransformer( + patch_size=patch_size, + embed_dim=1536, + depth=40, + num_heads=24, + mlp_ratio=4, + block_fn=partial(Block, attn_class=MemEffAttention), + num_register_tokens=num_register_tokens, + **kwargs, + ) + return model diff --git a/dinov2/dinov2/run/__init__.py b/dinov2/dinov2/run/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..b88da6bf80be92af00b72dfdb0a806fa64a7a2d9 --- /dev/null +++ b/dinov2/dinov2/run/__init__.py @@ -0,0 +1,4 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. diff --git a/dinov2/dinov2/run/eval/knn.py b/dinov2/dinov2/run/eval/knn.py new file mode 100644 index 0000000000000000000000000000000000000000..d11918445cdfe415fe58ac8b3ad0bf29702e3457 --- /dev/null +++ b/dinov2/dinov2/run/eval/knn.py @@ -0,0 +1,59 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +import logging +import os +import sys + +from dinov2.eval.knn import get_args_parser as get_knn_args_parser +from dinov2.logging import setup_logging +from dinov2.run.submit import get_args_parser, submit_jobs + + +logger = logging.getLogger("dinov2") + + +class Evaluator: + def __init__(self, args): + self.args = args + + def __call__(self): + from dinov2.eval.knn import main as knn_main + + self._setup_args() + knn_main(self.args) + + def checkpoint(self): + import submitit + + logger.info(f"Requeuing {self.args}") + empty = type(self)(self.args) + return submitit.helpers.DelayedSubmission(empty) + + def _setup_args(self): + import submitit + + job_env = submitit.JobEnvironment() + self.args.output_dir = self.args.output_dir.replace("%j", str(job_env.job_id)) + logger.info(f"Process group: {job_env.num_tasks} tasks, rank: {job_env.global_rank}") + logger.info(f"Args: {self.args}") + + +def main(): + description = "Submitit launcher for DINOv2 k-NN evaluation" + knn_args_parser = get_knn_args_parser(add_help=False) + parents = [knn_args_parser] + args_parser = get_args_parser(description=description, parents=parents) + args = args_parser.parse_args() + + setup_logging() + + assert os.path.exists(args.config_file), "Configuration file does not exist!" + submit_jobs(Evaluator, args, name="dinov2:knn") + return 0 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/dinov2/dinov2/run/eval/linear.py b/dinov2/dinov2/run/eval/linear.py new file mode 100644 index 0000000000000000000000000000000000000000..e1dc3293e88512a5cf885ab775dc08e01aed6724 --- /dev/null +++ b/dinov2/dinov2/run/eval/linear.py @@ -0,0 +1,59 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +import logging +import os +import sys + +from dinov2.eval.linear import get_args_parser as get_linear_args_parser +from dinov2.logging import setup_logging +from dinov2.run.submit import get_args_parser, submit_jobs + + +logger = logging.getLogger("dinov2") + + +class Evaluator: + def __init__(self, args): + self.args = args + + def __call__(self): + from dinov2.eval.linear import main as linear_main + + self._setup_args() + linear_main(self.args) + + def checkpoint(self): + import submitit + + logger.info(f"Requeuing {self.args}") + empty = type(self)(self.args) + return submitit.helpers.DelayedSubmission(empty) + + def _setup_args(self): + import submitit + + job_env = submitit.JobEnvironment() + self.args.output_dir = self.args.output_dir.replace("%j", str(job_env.job_id)) + logger.info(f"Process group: {job_env.num_tasks} tasks, rank: {job_env.global_rank}") + logger.info(f"Args: {self.args}") + + +def main(): + description = "Submitit launcher for DINOv2 linear evaluation" + linear_args_parser = get_linear_args_parser(add_help=False) + parents = [linear_args_parser] + args_parser = get_args_parser(description=description, parents=parents) + args = args_parser.parse_args() + + setup_logging() + + assert os.path.exists(args.config_file), "Configuration file does not exist!" + submit_jobs(Evaluator, args, name="dinov2:linear") + return 0 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/dinov2/dinov2/run/eval/log_regression.py b/dinov2/dinov2/run/eval/log_regression.py new file mode 100644 index 0000000000000000000000000000000000000000..cdf02181122de72cfa463ef38494967219df9cf3 --- /dev/null +++ b/dinov2/dinov2/run/eval/log_regression.py @@ -0,0 +1,59 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +import logging +import os +import sys + +from dinov2.eval.log_regression import get_args_parser as get_log_regression_args_parser +from dinov2.logging import setup_logging +from dinov2.run.submit import get_args_parser, submit_jobs + + +logger = logging.getLogger("dinov2") + + +class Evaluator: + def __init__(self, args): + self.args = args + + def __call__(self): + from dinov2.eval.log_regression import main as log_regression_main + + self._setup_args() + log_regression_main(self.args) + + def checkpoint(self): + import submitit + + logger.info(f"Requeuing {self.args}") + empty = type(self)(self.args) + return submitit.helpers.DelayedSubmission(empty) + + def _setup_args(self): + import submitit + + job_env = submitit.JobEnvironment() + self.args.output_dir = self.args.output_dir.replace("%j", str(job_env.job_id)) + logger.info(f"Process group: {job_env.num_tasks} tasks, rank: {job_env.global_rank}") + logger.info(f"Args: {self.args}") + + +def main(): + description = "Submitit launcher for DINOv2 logistic evaluation" + log_regression_args_parser = get_log_regression_args_parser(add_help=False) + parents = [log_regression_args_parser] + args_parser = get_args_parser(description=description, parents=parents) + args = args_parser.parse_args() + + setup_logging() + + assert os.path.exists(args.config_file), "Configuration file does not exist!" + submit_jobs(Evaluator, args, name="dinov2:logreg") + return 0 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/dinov2/dinov2/run/submit.py b/dinov2/dinov2/run/submit.py new file mode 100644 index 0000000000000000000000000000000000000000..4d1f718e704cf9a48913422404c25a7fcc50e738 --- /dev/null +++ b/dinov2/dinov2/run/submit.py @@ -0,0 +1,122 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +import argparse +import logging +import os +from pathlib import Path +from typing import List, Optional + +import submitit + +from dinov2.utils.cluster import ( + get_slurm_executor_parameters, + get_slurm_partition, + get_user_checkpoint_path, +) + + +logger = logging.getLogger("dinov2") + + +def get_args_parser( + description: Optional[str] = None, + parents: Optional[List[argparse.ArgumentParser]] = None, + add_help: bool = True, +) -> argparse.ArgumentParser: + parents = parents or [] + slurm_partition = get_slurm_partition() + parser = argparse.ArgumentParser( + description=description, + parents=parents, + add_help=add_help, + ) + parser.add_argument( + "--ngpus", + "--gpus", + "--gpus-per-node", + default=8, + type=int, + help="Number of GPUs to request on each node", + ) + parser.add_argument( + "--nodes", + "--nnodes", + default=1, + type=int, + help="Number of nodes to request", + ) + parser.add_argument( + "--timeout", + default=2800, + type=int, + help="Duration of the job", + ) + parser.add_argument( + "--partition", + default=slurm_partition, + type=str, + help="Partition where to submit", + ) + parser.add_argument( + "--use-volta32", + action="store_true", + help="Request V100-32GB GPUs", + ) + parser.add_argument( + "--comment", + default="", + type=str, + help="Comment to pass to scheduler, e.g. priority message", + ) + parser.add_argument( + "--exclude", + default="", + type=str, + help="Nodes to exclude", + ) + return parser + + +def get_shared_folder() -> Path: + user_checkpoint_path = get_user_checkpoint_path() + if user_checkpoint_path is None: + raise RuntimeError("Path to user checkpoint cannot be determined") + path = user_checkpoint_path / "experiments" + path.mkdir(exist_ok=True) + return path + + +def submit_jobs(task_class, args, name: str): + if not args.output_dir: + args.output_dir = str(get_shared_folder() / "%j") + + Path(args.output_dir).mkdir(parents=True, exist_ok=True) + executor = submitit.AutoExecutor(folder=args.output_dir, slurm_max_num_timeout=30) + + kwargs = {} + if args.use_volta32: + kwargs["slurm_constraint"] = "volta32gb" + if args.comment: + kwargs["slurm_comment"] = args.comment + if args.exclude: + kwargs["slurm_exclude"] = args.exclude + + executor_params = get_slurm_executor_parameters( + nodes=args.nodes, + num_gpus_per_node=args.ngpus, + timeout_min=args.timeout, # max is 60 * 72 + slurm_signal_delay_s=120, + slurm_partition=args.partition, + **kwargs, + ) + executor.update_parameters(name=name, **executor_params) + + task = task_class(args) + job = executor.submit(task) + + logger.info(f"Submitted job_id: {job.job_id}") + str_output_dir = os.path.abspath(args.output_dir).replace("%j", str(job.job_id)) + logger.info(f"Logs and checkpoints will be saved at: {str_output_dir}") diff --git a/dinov2/dinov2/run/train/train.py b/dinov2/dinov2/run/train/train.py new file mode 100644 index 0000000000000000000000000000000000000000..c2366e9bf79765e6abcd70dda6b43f31cb7093eb --- /dev/null +++ b/dinov2/dinov2/run/train/train.py @@ -0,0 +1,59 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +import logging +import os +import sys + +from dinov2.logging import setup_logging +from dinov2.train import get_args_parser as get_train_args_parser +from dinov2.run.submit import get_args_parser, submit_jobs + + +logger = logging.getLogger("dinov2") + + +class Trainer(object): + def __init__(self, args): + self.args = args + + def __call__(self): + from dinov2.train import main as train_main + + self._setup_args() + train_main(self.args) + + def checkpoint(self): + import submitit + + logger.info(f"Requeuing {self.args}") + empty = type(self)(self.args) + return submitit.helpers.DelayedSubmission(empty) + + def _setup_args(self): + import submitit + + job_env = submitit.JobEnvironment() + self.args.output_dir = self.args.output_dir.replace("%j", str(job_env.job_id)) + logger.info(f"Process group: {job_env.num_tasks} tasks, rank: {job_env.global_rank}") + logger.info(f"Args: {self.args}") + + +def main(): + description = "Submitit launcher for DINOv2 training" + train_args_parser = get_train_args_parser(add_help=False) + parents = [train_args_parser] + args_parser = get_args_parser(description=description, parents=parents) + args = args_parser.parse_args() + + setup_logging() + + assert os.path.exists(args.config_file), "Configuration file does not exist!" + submit_jobs(Trainer, args, name="dinov2:train") + return 0 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/dinov2/dinov2/train/__init__.py b/dinov2/dinov2/train/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..5f1752922d04fff0112eb7796be28ff6b68c6073 --- /dev/null +++ b/dinov2/dinov2/train/__init__.py @@ -0,0 +1,7 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +from .train import get_args_parser, main +from .ssl_meta_arch import SSLMetaArch diff --git a/dinov2/dinov2/train/ssl_meta_arch.py b/dinov2/dinov2/train/ssl_meta_arch.py new file mode 100644 index 0000000000000000000000000000000000000000..3ccf15e904ebeb6134dfb4f5c99da4fc8d41b8e4 --- /dev/null +++ b/dinov2/dinov2/train/ssl_meta_arch.py @@ -0,0 +1,400 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +from functools import partial +import logging + +import torch +from torch import nn + +from dinov2.loss import DINOLoss, iBOTPatchLoss, KoLeoLoss +from dinov2.models import build_model_from_cfg +from dinov2.layers import DINOHead +from dinov2.utils.utils import has_batchnorms +from dinov2.utils.param_groups import get_params_groups_with_decay, fuse_params_groups +from dinov2.fsdp import get_fsdp_wrapper, ShardedGradScaler, get_fsdp_modules, reshard_fsdp_model + +from dinov2.models.vision_transformer import BlockChunk + + +try: + from xformers.ops import fmha +except ImportError: + raise AssertionError("xFormers is required for training") + + +logger = logging.getLogger("dinov2") + + +class SSLMetaArch(nn.Module): + def __init__(self, cfg): + super().__init__() + self.cfg = cfg + self.fp16_scaler = ShardedGradScaler() if cfg.compute_precision.grad_scaler else None + + student_model_dict = dict() + teacher_model_dict = dict() + + student_backbone, teacher_backbone, embed_dim = build_model_from_cfg(cfg) + student_model_dict["backbone"] = student_backbone + teacher_model_dict["backbone"] = teacher_backbone + logger.info(f"OPTIONS -- architecture : embed_dim: {embed_dim}") + + if cfg.student.pretrained_weights: + chkpt = torch.load(cfg.student.pretrained_weights) + logger.info(f"OPTIONS -- pretrained weights: loading from {cfg.student.pretrained_weights}") + student_backbone.load_state_dict(chkpt["model"], strict=False) + + self.embed_dim = embed_dim + self.dino_out_dim = cfg.dino.head_n_prototypes + + self.do_dino = cfg.dino.loss_weight > 0 + self.do_koleo = cfg.dino.koleo_loss_weight > 0 + self.do_ibot = cfg.ibot.loss_weight > 0 + self.ibot_separate_head = cfg.ibot.separate_head + + logger.info("OPTIONS -- DINO") + if self.do_dino: + logger.info(f"OPTIONS -- DINO -- loss_weight: {cfg.dino.loss_weight}") + logger.info(f"OPTIONS -- DINO -- head_n_prototypes: {cfg.dino.head_n_prototypes}") + logger.info(f"OPTIONS -- DINO -- head_bottleneck_dim: {cfg.dino.head_bottleneck_dim}") + logger.info(f"OPTIONS -- DINO -- head_hidden_dim: {cfg.dino.head_hidden_dim}") + self.dino_loss_weight = cfg.dino.loss_weight + dino_head = partial( + DINOHead, + in_dim=embed_dim, + out_dim=cfg.dino.head_n_prototypes, + hidden_dim=cfg.dino.head_hidden_dim, + bottleneck_dim=cfg.dino.head_bottleneck_dim, + nlayers=cfg.dino.head_nlayers, + ) + self.dino_loss = DINOLoss(self.dino_out_dim) + if self.do_koleo: + logger.info("OPTIONS -- DINO -- applying KOLEO regularization") + self.koleo_loss = KoLeoLoss() + + else: + logger.info("OPTIONS -- DINO -- not using DINO") + + if self.do_dino or self.do_ibot: + student_model_dict["dino_head"] = dino_head() + teacher_model_dict["dino_head"] = dino_head() + + logger.info("OPTIONS -- IBOT") + logger.info(f"OPTIONS -- IBOT -- loss_weight: {cfg.ibot.loss_weight}") + logger.info(f"OPTIONS -- IBOT masking -- ibot_mask_ratio_tuple: {cfg.ibot.mask_ratio_min_max}") + logger.info(f"OPTIONS -- IBOT masking -- ibot_mask_sample_probability: {cfg.ibot.mask_sample_probability}") + if self.do_ibot: + self.ibot_loss_weight = cfg.ibot.loss_weight + assert max(cfg.ibot.mask_ratio_min_max) > 0, "please provide a positive mask ratio tuple for ibot" + assert cfg.ibot.mask_sample_probability > 0, "please provide a positive mask probability for ibot" + self.ibot_out_dim = cfg.ibot.head_n_prototypes if self.ibot_separate_head else cfg.dino.head_n_prototypes + self.ibot_patch_loss = iBOTPatchLoss(self.ibot_out_dim) + if self.ibot_separate_head: + logger.info(f"OPTIONS -- IBOT -- loss_weight: {cfg.ibot.loss_weight}") + logger.info(f"OPTIONS -- IBOT -- head_n_prototypes: {cfg.ibot.head_n_prototypes}") + logger.info(f"OPTIONS -- IBOT -- head_bottleneck_dim: {cfg.ibot.head_bottleneck_dim}") + logger.info(f"OPTIONS -- IBOT -- head_hidden_dim: {cfg.ibot.head_hidden_dim}") + ibot_head = partial( + DINOHead, + in_dim=embed_dim, + out_dim=cfg.ibot.head_n_prototypes, + hidden_dim=cfg.ibot.head_hidden_dim, + bottleneck_dim=cfg.ibot.head_bottleneck_dim, + nlayers=cfg.ibot.head_nlayers, + ) + student_model_dict["ibot_head"] = ibot_head() + teacher_model_dict["ibot_head"] = ibot_head() + else: + logger.info("OPTIONS -- IBOT -- head shared with DINO") + + self.need_to_synchronize_fsdp_streams = True + + self.student = nn.ModuleDict(student_model_dict) + self.teacher = nn.ModuleDict(teacher_model_dict) + + # there is no backpropagation through the teacher, so no need for gradients + for p in self.teacher.parameters(): + p.requires_grad = False + logger.info(f"Student and Teacher are built: they are both {cfg.student.arch} network.") + + def forward(self, inputs): + raise NotImplementedError + + def backprop_loss(self, loss): + if self.fp16_scaler is not None: + self.fp16_scaler.scale(loss).backward() + else: + loss.backward() + + def forward_backward(self, images, teacher_temp): + n_global_crops = 2 + assert n_global_crops == 2 + n_local_crops = self.cfg.crops.local_crops_number + + global_crops = images["collated_global_crops"].cuda(non_blocking=True) + local_crops = images["collated_local_crops"].cuda(non_blocking=True) + + masks = images["collated_masks"].cuda(non_blocking=True) + mask_indices_list = images["mask_indices_list"].cuda(non_blocking=True) + n_masked_patches_tensor = images["n_masked_patches"].cuda(non_blocking=True) + n_masked_patches = mask_indices_list.shape[0] + upperbound = images["upperbound"] + masks_weight = images["masks_weight"].cuda(non_blocking=True) + + n_local_crops_loss_terms = max(n_local_crops * n_global_crops, 1) + n_global_crops_loss_terms = (n_global_crops - 1) * n_global_crops + + do_dino = self.do_dino + do_ibot = self.do_ibot + + # loss scales + ibot_loss_scale = 1.0 / n_global_crops + + # teacher output + @torch.no_grad() + def get_teacher_output(): + x, n_global_crops_teacher = global_crops, n_global_crops + teacher_backbone_output_dict = self.teacher.backbone(x, is_training=True) + teacher_cls_tokens = teacher_backbone_output_dict["x_norm_clstoken"] + teacher_cls_tokens = teacher_cls_tokens.chunk(n_global_crops_teacher) + # watch out: these are chunked and cat'd in reverse so A is matched to B in the global crops dino loss + teacher_cls_tokens = torch.cat((teacher_cls_tokens[1], teacher_cls_tokens[0])) + ibot_teacher_patch_tokens = teacher_backbone_output_dict["x_norm_patchtokens"] + _dim = ibot_teacher_patch_tokens.shape[-1] + n_cls_tokens = teacher_cls_tokens.shape[0] + + if do_ibot and not self.ibot_separate_head: + buffer_tensor_teacher = ibot_teacher_patch_tokens.new_zeros(upperbound + n_cls_tokens, _dim) + buffer_tensor_teacher[:n_cls_tokens].copy_(teacher_cls_tokens) + torch.index_select( + ibot_teacher_patch_tokens.flatten(0, 1), + dim=0, + index=mask_indices_list, + out=buffer_tensor_teacher[n_cls_tokens : n_cls_tokens + n_masked_patches], + ) + tokens_after_head = self.teacher.dino_head(buffer_tensor_teacher) + teacher_cls_tokens_after_head = tokens_after_head[:n_cls_tokens] + masked_teacher_patch_tokens_after_head = tokens_after_head[ + n_cls_tokens : n_cls_tokens + n_masked_patches + ] + elif do_ibot and self.ibot_separate_head: + buffer_tensor_teacher = ibot_teacher_patch_tokens.new_zeros(upperbound, _dim) + torch.index_select( + ibot_teacher_patch_tokens.flatten(0, 1), + dim=0, + index=mask_indices_list, + out=buffer_tensor_teacher[:n_masked_patches], + ) + teacher_cls_tokens_after_head = self.teacher.dino_head(teacher_cls_tokens) + masked_teacher_patch_tokens_after_head = self.teacher.ibot_head(buffer_tensor_teacher)[ + :n_masked_patches + ] + else: + teacher_cls_tokens_after_head = self.teacher.dino_head(teacher_cls_tokens) + masked_teacher_ibot_softmaxed_centered = None + + if self.cfg.train.centering == "centering": + teacher_dino_softmaxed_centered_list = self.dino_loss.softmax_center_teacher( + teacher_cls_tokens_after_head, teacher_temp=teacher_temp + ).view(n_global_crops_teacher, -1, *teacher_cls_tokens_after_head.shape[1:]) + self.dino_loss.update_center(teacher_cls_tokens_after_head) + if do_ibot: + masked_teacher_patch_tokens_after_head = masked_teacher_patch_tokens_after_head.unsqueeze(0) + masked_teacher_ibot_softmaxed_centered = self.ibot_patch_loss.softmax_center_teacher( + masked_teacher_patch_tokens_after_head[:, :n_masked_patches], teacher_temp=teacher_temp + ) + masked_teacher_ibot_softmaxed_centered = masked_teacher_ibot_softmaxed_centered.squeeze(0) + self.ibot_patch_loss.update_center(masked_teacher_patch_tokens_after_head[:n_masked_patches]) + + elif self.cfg.train.centering == "sinkhorn_knopp": + teacher_dino_softmaxed_centered_list = self.dino_loss.sinkhorn_knopp_teacher( + teacher_cls_tokens_after_head, teacher_temp=teacher_temp + ).view(n_global_crops_teacher, -1, *teacher_cls_tokens_after_head.shape[1:]) + + if do_ibot: + masked_teacher_ibot_softmaxed_centered = self.ibot_patch_loss.sinkhorn_knopp_teacher( + masked_teacher_patch_tokens_after_head, + teacher_temp=teacher_temp, + n_masked_patches_tensor=n_masked_patches_tensor, + ) + + else: + raise NotImplementedError + + return teacher_dino_softmaxed_centered_list, masked_teacher_ibot_softmaxed_centered + + teacher_dino_softmaxed_centered_list, masked_teacher_ibot_softmaxed_centered = get_teacher_output() + reshard_fsdp_model(self.teacher) + + loss_dict = {} + + loss_accumulator = 0 # for backprop + student_global_backbone_output_dict, student_local_backbone_output_dict = self.student.backbone( + [global_crops, local_crops], masks=[masks, None], is_training=True + ) + + inputs_for_student_head_list = [] + + # 1a: local crops cls tokens + student_local_cls_tokens = student_local_backbone_output_dict["x_norm_clstoken"] + inputs_for_student_head_list.append(student_local_cls_tokens.unsqueeze(0)) + + # 1b: global crops cls tokens + student_global_cls_tokens = student_global_backbone_output_dict["x_norm_clstoken"] + inputs_for_student_head_list.append(student_global_cls_tokens.unsqueeze(0)) + + # 1c: global crops patch tokens + if do_ibot: + _dim = student_global_backbone_output_dict["x_norm_clstoken"].shape[-1] + ibot_student_patch_tokens = student_global_backbone_output_dict["x_norm_patchtokens"] + buffer_tensor_patch_tokens = ibot_student_patch_tokens.new_zeros(upperbound, _dim) + buffer_tensor_patch_tokens[:n_masked_patches].copy_( + torch.index_select(ibot_student_patch_tokens.flatten(0, 1), dim=0, index=mask_indices_list) + ) + if not self.ibot_separate_head: + inputs_for_student_head_list.append(buffer_tensor_patch_tokens.unsqueeze(0)) + else: + student_global_masked_patch_tokens_after_head = self.student.ibot_head(buffer_tensor_patch_tokens)[ + :n_masked_patches + ] + + # 2: run + _attn_bias, cat_inputs = fmha.BlockDiagonalMask.from_tensor_list(inputs_for_student_head_list) + outputs_list = _attn_bias.split(self.student.dino_head(cat_inputs)) + + # 3a: local crops cls tokens + student_local_cls_tokens_after_head = outputs_list.pop(0).squeeze(0) + + # 3b: global crops cls tokens + student_global_cls_tokens_after_head = outputs_list.pop(0).squeeze(0) + + # 3c: global crops patch tokens + if do_ibot and not self.ibot_separate_head: + student_global_masked_patch_tokens_after_head = outputs_list.pop(0).squeeze(0)[:n_masked_patches] + + if n_local_crops > 0: + dino_local_crops_loss = self.dino_loss( + student_output_list=student_local_cls_tokens_after_head.chunk(n_local_crops), + teacher_out_softmaxed_centered_list=teacher_dino_softmaxed_centered_list, + ) / (n_global_crops_loss_terms + n_local_crops_loss_terms) + + # store for display + loss_dict["dino_local_crops_loss"] = dino_local_crops_loss + + # accumulate loss + loss_accumulator += self.dino_loss_weight * dino_local_crops_loss + + # process global crops + loss_scales = 2 # this is here since we process global crops together + + if do_dino: + # compute loss + dino_global_crops_loss = ( + self.dino_loss( + student_output_list=[student_global_cls_tokens_after_head], + teacher_out_softmaxed_centered_list=[ + teacher_dino_softmaxed_centered_list.flatten(0, 1) + ], # these were chunked and stacked in reverse so A is matched to B + ) + * loss_scales + / (n_global_crops_loss_terms + n_local_crops_loss_terms) + ) + + loss_dict["dino_global_crops_loss"] = dino_global_crops_loss + + # accumulate loss + loss_accumulator += self.dino_loss_weight * dino_global_crops_loss + + student_cls_tokens = student_global_cls_tokens + + if self.do_koleo: + koleo_loss = self.cfg.dino.koleo_loss_weight * sum( + self.koleo_loss(p) for p in student_cls_tokens.chunk(2) + ) # we don't apply koleo loss between cls tokens of a same image + loss_accumulator += koleo_loss + loss_dict["koleo_loss"] = ( + koleo_loss / loss_scales + ) # this is to display the same losses as before but we can remove eventually + + if do_ibot: + # compute loss + ibot_patch_loss = ( + self.ibot_patch_loss.forward_masked( + student_global_masked_patch_tokens_after_head, + masked_teacher_ibot_softmaxed_centered, + student_masks_flat=masks, + n_masked_patches=n_masked_patches, + masks_weight=masks_weight, + ) + * loss_scales + * ibot_loss_scale + ) + + # store for display + loss_dict["ibot_loss"] = ibot_patch_loss / 2 + + # accumulate loss + loss_accumulator += self.ibot_loss_weight * ibot_patch_loss + + self.backprop_loss(loss_accumulator) + + self.fsdp_synchronize_streams() + + return loss_dict + + def fsdp_synchronize_streams(self): + if self.need_to_synchronize_fsdp_streams: + torch.cuda.synchronize() + self.student.dino_head._streams = ( + self.teacher.dino_head._streams + ) = self.student.backbone._streams = self.teacher.backbone._streams + self.need_to_synchronize_fsdp_streams = False + + def update_teacher(self, m): + student_param_list = [] + teacher_param_list = [] + with torch.no_grad(): + for k in self.student.keys(): + for ms, mt in zip(get_fsdp_modules(self.student[k]), get_fsdp_modules(self.teacher[k])): + student_param_list += ms.params + teacher_param_list += mt.params + torch._foreach_mul_(teacher_param_list, m) + torch._foreach_add_(teacher_param_list, student_param_list, alpha=1 - m) + + def train(self): + super().train() + self.teacher.eval() + + def get_maybe_fused_params_for_submodel(self, m): + params_groups = get_params_groups_with_decay( + model=m, + lr_decay_rate=self.cfg.optim.layerwise_decay, + patch_embed_lr_mult=self.cfg.optim.patch_embed_lr_mult, + ) + fused_params_groups = fuse_params_groups(params_groups) + logger.info("fusing param groups") + + for g in fused_params_groups: + g["foreach"] = True + return fused_params_groups + + def get_params_groups(self): + all_params_groups = [] + for m in self.student.values(): + all_params_groups += self.get_maybe_fused_params_for_submodel(m) + return all_params_groups + + def prepare_for_distributed_training(self): + logger.info("DISTRIBUTED FSDP -- preparing model for distributed training") + if has_batchnorms(self.student): + raise NotImplementedError + # below will synchronize all student subnetworks across gpus: + for k, v in self.student.items(): + self.teacher[k].load_state_dict(self.student[k].state_dict()) + student_model_cfg = self.cfg.compute_precision.student[k] + self.student[k] = get_fsdp_wrapper(student_model_cfg, modules_to_wrap={BlockChunk})(self.student[k]) + teacher_model_cfg = self.cfg.compute_precision.teacher[k] + self.teacher[k] = get_fsdp_wrapper(teacher_model_cfg, modules_to_wrap={BlockChunk})(self.teacher[k]) diff --git a/dinov2/dinov2/train/train.py b/dinov2/dinov2/train/train.py new file mode 100644 index 0000000000000000000000000000000000000000..473b8d01473654182de9f91c94a2d8720fe096a5 --- /dev/null +++ b/dinov2/dinov2/train/train.py @@ -0,0 +1,318 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +import argparse +import logging +import math +import os +from functools import partial + +from fvcore.common.checkpoint import PeriodicCheckpointer +import torch + +from dinov2.data import SamplerType, make_data_loader, make_dataset +from dinov2.data import collate_data_and_cast, DataAugmentationDINO, MaskingGenerator +import dinov2.distributed as distributed +from dinov2.fsdp import FSDPCheckpointer +from dinov2.logging import MetricLogger +from dinov2.utils.config import setup +from dinov2.utils.utils import CosineScheduler + +from dinov2.train.ssl_meta_arch import SSLMetaArch + + +torch.backends.cuda.matmul.allow_tf32 = True # PyTorch 1.12 sets this to False by default +logger = logging.getLogger("dinov2") + + +def get_args_parser(add_help: bool = True): + parser = argparse.ArgumentParser("DINOv2 training", add_help=add_help) + parser.add_argument("--config-file", default="", metavar="FILE", help="path to config file") + parser.add_argument( + "--no-resume", + action="store_true", + help="Whether to not attempt to resume from the checkpoint directory. ", + ) + parser.add_argument("--eval-only", action="store_true", help="perform evaluation only") + parser.add_argument("--eval", type=str, default="", help="Eval type to perform") + parser.add_argument( + "opts", + help=""" +Modify config options at the end of the command. For Yacs configs, use +space-separated "PATH.KEY VALUE" pairs. +For python-based LazyConfig, use "path.key=value". + """.strip(), + default=None, + nargs=argparse.REMAINDER, + ) + parser.add_argument( + "--output-dir", + "--output_dir", + default="", + type=str, + help="Output directory to save logs and checkpoints", + ) + + return parser + + +def build_optimizer(cfg, params_groups): + return torch.optim.AdamW(params_groups, betas=(cfg.optim.adamw_beta1, cfg.optim.adamw_beta2)) + + +def build_schedulers(cfg): + OFFICIAL_EPOCH_LENGTH = cfg.train.OFFICIAL_EPOCH_LENGTH + lr = dict( + base_value=cfg.optim["lr"], + final_value=cfg.optim["min_lr"], + total_iters=cfg.optim["epochs"] * OFFICIAL_EPOCH_LENGTH, + warmup_iters=cfg.optim["warmup_epochs"] * OFFICIAL_EPOCH_LENGTH, + start_warmup_value=0, + ) + wd = dict( + base_value=cfg.optim["weight_decay"], + final_value=cfg.optim["weight_decay_end"], + total_iters=cfg.optim["epochs"] * OFFICIAL_EPOCH_LENGTH, + ) + momentum = dict( + base_value=cfg.teacher["momentum_teacher"], + final_value=cfg.teacher["final_momentum_teacher"], + total_iters=cfg.optim["epochs"] * OFFICIAL_EPOCH_LENGTH, + ) + teacher_temp = dict( + base_value=cfg.teacher["teacher_temp"], + final_value=cfg.teacher["teacher_temp"], + total_iters=cfg.teacher["warmup_teacher_temp_epochs"] * OFFICIAL_EPOCH_LENGTH, + warmup_iters=cfg.teacher["warmup_teacher_temp_epochs"] * OFFICIAL_EPOCH_LENGTH, + start_warmup_value=cfg.teacher["warmup_teacher_temp"], + ) + + lr_schedule = CosineScheduler(**lr) + wd_schedule = CosineScheduler(**wd) + momentum_schedule = CosineScheduler(**momentum) + teacher_temp_schedule = CosineScheduler(**teacher_temp) + last_layer_lr_schedule = CosineScheduler(**lr) + + last_layer_lr_schedule.schedule[ + : cfg.optim["freeze_last_layer_epochs"] * OFFICIAL_EPOCH_LENGTH + ] = 0 # mimicking the original schedules + + logger.info("Schedulers ready.") + + return ( + lr_schedule, + wd_schedule, + momentum_schedule, + teacher_temp_schedule, + last_layer_lr_schedule, + ) + + +def apply_optim_scheduler(optimizer, lr, wd, last_layer_lr): + for param_group in optimizer.param_groups: + is_last_layer = param_group["is_last_layer"] + lr_multiplier = param_group["lr_multiplier"] + wd_multiplier = param_group["wd_multiplier"] + param_group["weight_decay"] = wd * wd_multiplier + param_group["lr"] = (last_layer_lr if is_last_layer else lr) * lr_multiplier + + +def do_test(cfg, model, iteration): + new_state_dict = model.teacher.state_dict() + + if distributed.is_main_process(): + iterstring = str(iteration) + eval_dir = os.path.join(cfg.train.output_dir, "eval", iterstring) + os.makedirs(eval_dir, exist_ok=True) + # save teacher checkpoint + teacher_ckp_path = os.path.join(eval_dir, "teacher_checkpoint.pth") + torch.save({"teacher": new_state_dict}, teacher_ckp_path) + + +def do_train(cfg, model, resume=False): + model.train() + inputs_dtype = torch.half + fp16_scaler = model.fp16_scaler # for mixed precision training + + # setup optimizer + + optimizer = build_optimizer(cfg, model.get_params_groups()) + ( + lr_schedule, + wd_schedule, + momentum_schedule, + teacher_temp_schedule, + last_layer_lr_schedule, + ) = build_schedulers(cfg) + + # checkpointer + checkpointer = FSDPCheckpointer(model, cfg.train.output_dir, optimizer=optimizer, save_to_disk=True) + + start_iter = checkpointer.resume_or_load(cfg.MODEL.WEIGHTS, resume=resume).get("iteration", -1) + 1 + + OFFICIAL_EPOCH_LENGTH = cfg.train.OFFICIAL_EPOCH_LENGTH + max_iter = cfg.optim.epochs * OFFICIAL_EPOCH_LENGTH + + periodic_checkpointer = PeriodicCheckpointer( + checkpointer, + period=3 * OFFICIAL_EPOCH_LENGTH, + max_iter=max_iter, + max_to_keep=3, + ) + + # setup data preprocessing + + img_size = cfg.crops.global_crops_size + patch_size = cfg.student.patch_size + n_tokens = (img_size // patch_size) ** 2 + mask_generator = MaskingGenerator( + input_size=(img_size // patch_size, img_size // patch_size), + max_num_patches=0.5 * img_size // patch_size * img_size // patch_size, + ) + + data_transform = DataAugmentationDINO( + cfg.crops.global_crops_scale, + cfg.crops.local_crops_scale, + cfg.crops.local_crops_number, + global_crops_size=cfg.crops.global_crops_size, + local_crops_size=cfg.crops.local_crops_size, + ) + + collate_fn = partial( + collate_data_and_cast, + mask_ratio_tuple=cfg.ibot.mask_ratio_min_max, + mask_probability=cfg.ibot.mask_sample_probability, + n_tokens=n_tokens, + mask_generator=mask_generator, + dtype=inputs_dtype, + ) + + # setup data loader + + dataset = make_dataset( + dataset_str=cfg.train.dataset_path, + transform=data_transform, + target_transform=lambda _: (), + ) + # sampler_type = SamplerType.INFINITE + sampler_type = SamplerType.SHARDED_INFINITE + data_loader = make_data_loader( + dataset=dataset, + batch_size=cfg.train.batch_size_per_gpu, + num_workers=cfg.train.num_workers, + shuffle=True, + seed=start_iter, # TODO: Fix this -- cfg.train.seed + sampler_type=sampler_type, + sampler_advance=0, # TODO(qas): fix this -- start_iter * cfg.train.batch_size_per_gpu, + drop_last=True, + collate_fn=collate_fn, + ) + + # training loop + + iteration = start_iter + + logger.info("Starting training from iteration {}".format(start_iter)) + metrics_file = os.path.join(cfg.train.output_dir, "training_metrics.json") + metric_logger = MetricLogger(delimiter=" ", output_file=metrics_file) + header = "Training" + + for data in metric_logger.log_every( + data_loader, + 10, + header, + max_iter, + start_iter, + ): + current_batch_size = data["collated_global_crops"].shape[0] / 2 + if iteration > max_iter: + return + + # apply schedules + + lr = lr_schedule[iteration] + wd = wd_schedule[iteration] + mom = momentum_schedule[iteration] + teacher_temp = teacher_temp_schedule[iteration] + last_layer_lr = last_layer_lr_schedule[iteration] + apply_optim_scheduler(optimizer, lr, wd, last_layer_lr) + + # compute losses + + optimizer.zero_grad(set_to_none=True) + loss_dict = model.forward_backward(data, teacher_temp=teacher_temp) + + # clip gradients + + if fp16_scaler is not None: + if cfg.optim.clip_grad: + fp16_scaler.unscale_(optimizer) + for v in model.student.values(): + v.clip_grad_norm_(cfg.optim.clip_grad) + fp16_scaler.step(optimizer) + fp16_scaler.update() + else: + if cfg.optim.clip_grad: + for v in model.student.values(): + v.clip_grad_norm_(cfg.optim.clip_grad) + optimizer.step() + + # perform teacher EMA update + + model.update_teacher(mom) + + # logging + + if distributed.get_global_size() > 1: + for v in loss_dict.values(): + torch.distributed.all_reduce(v) + loss_dict_reduced = {k: v.item() / distributed.get_global_size() for k, v in loss_dict.items()} + + if math.isnan(sum(loss_dict_reduced.values())): + logger.info("NaN detected") + raise AssertionError + losses_reduced = sum(loss for loss in loss_dict_reduced.values()) + + metric_logger.update(lr=lr) + metric_logger.update(wd=wd) + metric_logger.update(mom=mom) + metric_logger.update(last_layer_lr=last_layer_lr) + metric_logger.update(current_batch_size=current_batch_size) + metric_logger.update(total_loss=losses_reduced, **loss_dict_reduced) + + # checkpointing and testing + + if cfg.evaluation.eval_period_iterations > 0 and (iteration + 1) % cfg.evaluation.eval_period_iterations == 0: + do_test(cfg, model, f"training_{iteration}") + torch.cuda.synchronize() + periodic_checkpointer.step(iteration) + + iteration = iteration + 1 + metric_logger.synchronize_between_processes() + return {k: meter.global_avg for k, meter in metric_logger.meters.items()} + + +def main(args): + cfg = setup(args) + + model = SSLMetaArch(cfg).to(torch.device("cuda")) + model.prepare_for_distributed_training() + + logger.info("Model:\n{}".format(model)) + if args.eval_only: + iteration = ( + FSDPCheckpointer(model, save_dir=cfg.train.output_dir) + .resume_or_load(cfg.MODEL.WEIGHTS, resume=not args.no_resume) + .get("iteration", -1) + + 1 + ) + return do_test(cfg, model, f"manual_{iteration}") + + do_train(cfg, model, resume=not args.no_resume) + + +if __name__ == "__main__": + args = get_args_parser(add_help=True).parse_args() + main(args) diff --git a/dinov2/dinov2/utils/__init__.py b/dinov2/dinov2/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..b88da6bf80be92af00b72dfdb0a806fa64a7a2d9 --- /dev/null +++ b/dinov2/dinov2/utils/__init__.py @@ -0,0 +1,4 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. diff --git a/dinov2/dinov2/utils/cluster.py b/dinov2/dinov2/utils/cluster.py new file mode 100644 index 0000000000000000000000000000000000000000..3df87dc3e1eb4f0f8a280dc3137cfef031886314 --- /dev/null +++ b/dinov2/dinov2/utils/cluster.py @@ -0,0 +1,95 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +from enum import Enum +import os +from pathlib import Path +from typing import Any, Dict, Optional + + +class ClusterType(Enum): + AWS = "aws" + FAIR = "fair" + RSC = "rsc" + + +def _guess_cluster_type() -> ClusterType: + uname = os.uname() + if uname.sysname == "Linux": + if uname.release.endswith("-aws"): + # Linux kernel versions on AWS instances are of the form "5.4.0-1051-aws" + return ClusterType.AWS + elif uname.nodename.startswith("rsc"): + # Linux kernel versions on RSC instances are standard ones but hostnames start with "rsc" + return ClusterType.RSC + + return ClusterType.FAIR + + +def get_cluster_type(cluster_type: Optional[ClusterType] = None) -> Optional[ClusterType]: + if cluster_type is None: + return _guess_cluster_type() + + return cluster_type + + +def get_checkpoint_path(cluster_type: Optional[ClusterType] = None) -> Optional[Path]: + cluster_type = get_cluster_type(cluster_type) + if cluster_type is None: + return None + + CHECKPOINT_DIRNAMES = { + ClusterType.AWS: "checkpoints", + ClusterType.FAIR: "checkpoint", + ClusterType.RSC: "checkpoint/dino", + } + return Path("/") / CHECKPOINT_DIRNAMES[cluster_type] + + +def get_user_checkpoint_path(cluster_type: Optional[ClusterType] = None) -> Optional[Path]: + checkpoint_path = get_checkpoint_path(cluster_type) + if checkpoint_path is None: + return None + + username = os.environ.get("USER") + assert username is not None + return checkpoint_path / username + + +def get_slurm_partition(cluster_type: Optional[ClusterType] = None) -> Optional[str]: + cluster_type = get_cluster_type(cluster_type) + if cluster_type is None: + return None + + SLURM_PARTITIONS = { + ClusterType.AWS: "learnlab", + ClusterType.FAIR: "learnlab", + ClusterType.RSC: "learn", + } + return SLURM_PARTITIONS[cluster_type] + + +def get_slurm_executor_parameters( + nodes: int, num_gpus_per_node: int, cluster_type: Optional[ClusterType] = None, **kwargs +) -> Dict[str, Any]: + # create default parameters + params = { + "mem_gb": 0, # Requests all memory on a node, see https://slurm.schedmd.com/sbatch.html + "gpus_per_node": num_gpus_per_node, + "tasks_per_node": num_gpus_per_node, # one task per GPU + "cpus_per_task": 10, + "nodes": nodes, + "slurm_partition": get_slurm_partition(cluster_type), + } + # apply cluster-specific adjustments + cluster_type = get_cluster_type(cluster_type) + if cluster_type == ClusterType.AWS: + params["cpus_per_task"] = 12 + del params["mem_gb"] + elif cluster_type == ClusterType.RSC: + params["cpus_per_task"] = 12 + # set additional parameters / apply overrides + params.update(kwargs) + return params diff --git a/dinov2/dinov2/utils/config.py b/dinov2/dinov2/utils/config.py new file mode 100644 index 0000000000000000000000000000000000000000..c9de578787bbcb376f8bd5a782206d0eb7ec1f52 --- /dev/null +++ b/dinov2/dinov2/utils/config.py @@ -0,0 +1,72 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +import math +import logging +import os + +from omegaconf import OmegaConf + +import dinov2.distributed as distributed +from dinov2.logging import setup_logging +from dinov2.utils import utils +from dinov2.configs import dinov2_default_config + + +logger = logging.getLogger("dinov2") + + +def apply_scaling_rules_to_cfg(cfg): # to fix + if cfg.optim.scaling_rule == "sqrt_wrt_1024": + base_lr = cfg.optim.base_lr + cfg.optim.lr = base_lr + cfg.optim.lr *= math.sqrt(cfg.train.batch_size_per_gpu * distributed.get_global_size() / 1024.0) + logger.info(f"sqrt scaling learning rate; base: {base_lr}, new: {cfg.optim.lr}") + else: + raise NotImplementedError + return cfg + + +def write_config(cfg, output_dir, name="config.yaml"): + logger.info(OmegaConf.to_yaml(cfg)) + saved_cfg_path = os.path.join(output_dir, name) + with open(saved_cfg_path, "w") as f: + OmegaConf.save(config=cfg, f=f) + return saved_cfg_path + + +def get_cfg_from_args(args): + args.output_dir = os.path.abspath(args.output_dir) + args.opts += [f"train.output_dir={args.output_dir}"] + default_cfg = OmegaConf.create(dinov2_default_config) + cfg = OmegaConf.load(args.config_file) + cfg = OmegaConf.merge(default_cfg, cfg, OmegaConf.from_cli(args.opts)) + return cfg + + +def default_setup(args): + distributed.enable(overwrite=True) + seed = getattr(args, "seed", 0) + rank = distributed.get_global_rank() + + global logger + setup_logging(output=args.output_dir, level=logging.INFO) + logger = logging.getLogger("dinov2") + + utils.fix_random_seeds(seed + rank) + logger.info("git:\n {}\n".format(utils.get_sha())) + logger.info("\n".join("%s: %s" % (k, str(v)) for k, v in sorted(dict(vars(args)).items()))) + + +def setup(args): + """ + Create configs and perform basic setups. + """ + cfg = get_cfg_from_args(args) + os.makedirs(args.output_dir, exist_ok=True) + default_setup(args) + apply_scaling_rules_to_cfg(cfg) + write_config(cfg, args.output_dir) + return cfg diff --git a/dinov2/dinov2/utils/dtype.py b/dinov2/dinov2/utils/dtype.py new file mode 100644 index 0000000000000000000000000000000000000000..80f4cd74d99faa2731dbe9f8d3a13d71b3f8e3a8 --- /dev/null +++ b/dinov2/dinov2/utils/dtype.py @@ -0,0 +1,37 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + + +from typing import Dict, Union + +import numpy as np +import torch + + +TypeSpec = Union[str, np.dtype, torch.dtype] + + +_NUMPY_TO_TORCH_DTYPE: Dict[np.dtype, torch.dtype] = { + np.dtype("bool"): torch.bool, + np.dtype("uint8"): torch.uint8, + np.dtype("int8"): torch.int8, + np.dtype("int16"): torch.int16, + np.dtype("int32"): torch.int32, + np.dtype("int64"): torch.int64, + np.dtype("float16"): torch.float16, + np.dtype("float32"): torch.float32, + np.dtype("float64"): torch.float64, + np.dtype("complex64"): torch.complex64, + np.dtype("complex128"): torch.complex128, +} + + +def as_torch_dtype(dtype: TypeSpec) -> torch.dtype: + if isinstance(dtype, torch.dtype): + return dtype + if isinstance(dtype, str): + dtype = np.dtype(dtype) + assert isinstance(dtype, np.dtype), f"Expected an instance of nunpy dtype, got {type(dtype)}" + return _NUMPY_TO_TORCH_DTYPE[dtype] diff --git a/dinov2/dinov2/utils/param_groups.py b/dinov2/dinov2/utils/param_groups.py new file mode 100644 index 0000000000000000000000000000000000000000..9a5d2ff627cddadc222e5f836864ee39c865208f --- /dev/null +++ b/dinov2/dinov2/utils/param_groups.py @@ -0,0 +1,103 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +from collections import defaultdict +import logging + + +logger = logging.getLogger("dinov2") + + +def get_vit_lr_decay_rate(name, lr_decay_rate=1.0, num_layers=12, force_is_backbone=False, chunked_blocks=False): + """ + Calculate lr decay rate for different ViT blocks. + Args: + name (string): parameter name. + lr_decay_rate (float): base lr decay rate. + num_layers (int): number of ViT blocks. + Returns: + lr decay rate for the given parameter. + """ + layer_id = num_layers + 1 + if name.startswith("backbone") or force_is_backbone: + if ( + ".pos_embed" in name + or ".patch_embed" in name + or ".mask_token" in name + or ".cls_token" in name + or ".register_tokens" in name + ): + layer_id = 0 + elif force_is_backbone and ( + "pos_embed" in name + or "patch_embed" in name + or "mask_token" in name + or "cls_token" in name + or "register_tokens" in name + ): + layer_id = 0 + elif ".blocks." in name and ".residual." not in name: + layer_id = int(name[name.find(".blocks.") :].split(".")[2]) + 1 + elif chunked_blocks and "blocks." in name and "residual." not in name: + layer_id = int(name[name.find("blocks.") :].split(".")[2]) + 1 + elif "blocks." in name and "residual." not in name: + layer_id = int(name[name.find("blocks.") :].split(".")[1]) + 1 + + return lr_decay_rate ** (num_layers + 1 - layer_id) + + +def get_params_groups_with_decay(model, lr_decay_rate=1.0, patch_embed_lr_mult=1.0): + chunked_blocks = False + if hasattr(model, "n_blocks"): + logger.info("chunked fsdp") + n_blocks = model.n_blocks + chunked_blocks = model.chunked_blocks + elif hasattr(model, "blocks"): + logger.info("first code branch") + n_blocks = len(model.blocks) + elif hasattr(model, "backbone"): + logger.info("second code branch") + n_blocks = len(model.backbone.blocks) + else: + logger.info("else code branch") + n_blocks = 0 + all_param_groups = [] + + for name, param in model.named_parameters(): + name = name.replace("_fsdp_wrapped_module.", "") + if not param.requires_grad: + continue + decay_rate = get_vit_lr_decay_rate( + name, lr_decay_rate, num_layers=n_blocks, force_is_backbone=n_blocks > 0, chunked_blocks=chunked_blocks + ) + d = {"params": param, "is_last_layer": False, "lr_multiplier": decay_rate, "wd_multiplier": 1.0, "name": name} + + if "last_layer" in name: + d.update({"is_last_layer": True}) + + if name.endswith(".bias") or "norm" in name or "gamma" in name: + d.update({"wd_multiplier": 0.0}) + + if "patch_embed" in name: + d.update({"lr_multiplier": d["lr_multiplier"] * patch_embed_lr_mult}) + + all_param_groups.append(d) + logger.info(f"""{name}: lr_multiplier: {d["lr_multiplier"]}, wd_multiplier: {d["wd_multiplier"]}""") + + return all_param_groups + + +def fuse_params_groups(all_params_groups, keys=("lr_multiplier", "wd_multiplier", "is_last_layer")): + fused_params_groups = defaultdict(lambda: {"params": []}) + for d in all_params_groups: + identifier = "" + for k in keys: + identifier += k + str(d[k]) + "_" + + for k in keys: + fused_params_groups[identifier][k] = d[k] + fused_params_groups[identifier]["params"].append(d["params"]) + + return fused_params_groups.values() diff --git a/dinov2/dinov2/utils/utils.py b/dinov2/dinov2/utils/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..68f8e2c3be5f780bbb7e00359b5ac4fd0ba0785f --- /dev/null +++ b/dinov2/dinov2/utils/utils.py @@ -0,0 +1,95 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +import logging +import os +import random +import subprocess +from urllib.parse import urlparse + +import numpy as np +import torch +from torch import nn + + +logger = logging.getLogger("dinov2") + + +def load_pretrained_weights(model, pretrained_weights, checkpoint_key): + if urlparse(pretrained_weights).scheme: # If it looks like an URL + state_dict = torch.hub.load_state_dict_from_url(pretrained_weights, map_location="cpu") + else: + state_dict = torch.load(pretrained_weights, map_location="cpu") + if checkpoint_key is not None and checkpoint_key in state_dict: + logger.info(f"Take key {checkpoint_key} in provided checkpoint dict") + state_dict = state_dict[checkpoint_key] + # remove `module.` prefix + state_dict = {k.replace("module.", ""): v for k, v in state_dict.items()} + # remove `backbone.` prefix induced by multicrop wrapper + state_dict = {k.replace("backbone.", ""): v for k, v in state_dict.items()} + msg = model.load_state_dict(state_dict, strict=False) + logger.info("Pretrained weights found at {} and loaded with msg: {}".format(pretrained_weights, msg)) + + +def fix_random_seeds(seed=31): + """ + Fix random seeds. + """ + torch.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + np.random.seed(seed) + random.seed(seed) + + +def get_sha(): + cwd = os.path.dirname(os.path.abspath(__file__)) + + def _run(command): + return subprocess.check_output(command, cwd=cwd).decode("ascii").strip() + + sha = "N/A" + diff = "clean" + branch = "N/A" + try: + sha = _run(["git", "rev-parse", "HEAD"]) + subprocess.check_output(["git", "diff"], cwd=cwd) + diff = _run(["git", "diff-index", "HEAD"]) + diff = "has uncommitted changes" if diff else "clean" + branch = _run(["git", "rev-parse", "--abbrev-ref", "HEAD"]) + except Exception: + pass + message = f"sha: {sha}, status: {diff}, branch: {branch}" + return message + + +class CosineScheduler(object): + def __init__(self, base_value, final_value, total_iters, warmup_iters=0, start_warmup_value=0, freeze_iters=0): + super().__init__() + self.final_value = final_value + self.total_iters = total_iters + + freeze_schedule = np.zeros((freeze_iters)) + + warmup_schedule = np.linspace(start_warmup_value, base_value, warmup_iters) + + iters = np.arange(total_iters - warmup_iters - freeze_iters) + schedule = final_value + 0.5 * (base_value - final_value) * (1 + np.cos(np.pi * iters / len(iters))) + self.schedule = np.concatenate((freeze_schedule, warmup_schedule, schedule)) + + assert len(self.schedule) == self.total_iters + + def __getitem__(self, it): + if it >= self.total_iters: + return self.final_value + else: + return self.schedule[it] + + +def has_batchnorms(model): + bn_types = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d, nn.SyncBatchNorm) + for name, module in model.named_modules(): + if isinstance(module, bn_types): + return True + return False diff --git a/open_clip_local/__init__.py b/open_clip_local/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..23856a3f13d8ae592b343131345108b3432e43a3 --- /dev/null +++ b/open_clip_local/__init__.py @@ -0,0 +1,16 @@ +from .coca_model import CoCa +from .constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD +from .factory import create_model, create_model_and_transforms, create_model_from_pretrained, get_tokenizer, create_loss +from .factory import list_models, add_model_config, get_model_config, load_checkpoint +from .loss import ClipLoss, DistillClipLoss, CoCaLoss +from .model import CLIP, CustomTextCLIP, CLIPTextCfg, CLIPVisionCfg, \ + convert_weights_to_lp, convert_weights_to_fp16, trace_model, get_cast_dtype, get_input_dtype, \ + get_model_tokenize_cfg, get_model_preprocess_cfg, set_model_preprocess_cfg +from .openai import load_openai_model, list_openai_models +from .pretrained import list_pretrained, list_pretrained_models_by_tag, list_pretrained_tags_by_model, \ + get_pretrained_url, download_pretrained_from_url, is_pretrained_cfg, get_pretrained_cfg, download_pretrained +from .push_to_hf_hub import push_pretrained_to_hf_hub, push_to_hf_hub +from .tokenizer import SimpleTokenizer, tokenize, decode +from .transform import image_transform, AugmentationCfg +from .zero_shot_classifier import build_zero_shot_classifier, build_zero_shot_classifier_legacy +from .zero_shot_metadata import OPENAI_IMAGENET_TEMPLATES, SIMPLE_IMAGENET_TEMPLATES, IMAGENET_CLASSNAMES diff --git a/open_clip_local/big_vision.py b/open_clip_local/big_vision.py new file mode 100644 index 0000000000000000000000000000000000000000..0d7eaf3fa543dba7d7517ac566c6364a5a893796 --- /dev/null +++ b/open_clip_local/big_vision.py @@ -0,0 +1,136 @@ +import torch +import numpy as np + +from .model import CustomTextCLIP +from .transformer import TextTransformer, Transformer + + +@torch.no_grad() +def load_big_vision_weights(model: CustomTextCLIP, checkpoint_path: str): + """ Load weights from .npz checkpoints for official Google big_vision image-text models + + Currently the SigLIP source models are supported and a CustomTextCLIP destination model + w/ timm image encoder. + """ + from timm.layers import resample_patch_embed, resample_abs_pos_embed + + def _n2p(w, t=True): + if w.ndim == 4 and w.shape[0] == w.shape[1] == w.shape[2] == 1: + w = w.flatten() + if t: + if w.ndim == 4: + w = w.transpose([3, 2, 0, 1]) + elif w.ndim == 3: + w = w.transpose([2, 0, 1]) + elif w.ndim == 2: + w = w.transpose([1, 0]) + return torch.from_numpy(w) + + w = np.load(checkpoint_path) + interpolation = 'bilinear' + antialias = False + + def _convert_timm_img(module, prefix): + embed_conv_w = _n2p(w[f'{prefix}embedding/kernel']) + if embed_conv_w.shape[-2:] != module.patch_embed.proj.weight.shape[-2:]: + embed_conv_w = resample_patch_embed( + embed_conv_w, + module.patch_embed.proj.weight.shape[-2:], + interpolation=interpolation, + antialias=antialias, + verbose=True, + ) + module.patch_embed.proj.weight.copy_(embed_conv_w) + module.patch_embed.proj.bias.copy_(_n2p(w[f'{prefix}embedding/bias'])) + + if module.cls_token is not None: + module.cls_token.copy_(_n2p(w[f'{prefix}cls'], t=False)) + + pos_embed_w = _n2p(w[f'{prefix}pos_embedding'], t=False) + if pos_embed_w.shape != module.pos_embed.shape: + assert False, f'{pos_embed_w.shape}, {module.pos_embed.shape}' + num_prefix_tokens = 0 if getattr(module, 'no_embed_class', False) else getattr(module, 'num_prefix_tokens', 1) + pos_embed_w = resample_abs_pos_embed( # resize pos embedding when different size from pretrained weights + pos_embed_w, + new_size=module.patch_embed.grid_size, + num_prefix_tokens=num_prefix_tokens, + interpolation=interpolation, + antialias=antialias, + verbose=True, + ) + module.pos_embed.copy_(pos_embed_w) + + mha_sub, b_sub, ln1_sub = (0, 0, 1) + for i, block in enumerate(module.blocks.children()): + block_prefix = f'{prefix}Transformer/encoderblock_{i}/' + mha_prefix = block_prefix + f'MultiHeadDotProductAttention_{mha_sub}/' + block.norm1.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/scale'])) + block.norm1.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/bias'])) + block.attn.qkv.weight.copy_(torch.cat([ + _n2p(w[f'{mha_prefix}{n}/kernel'], t=False).flatten(1).T for n in ('query', 'key', 'value')])) + block.attn.qkv.bias.copy_(torch.cat([ + _n2p(w[f'{mha_prefix}{n}/bias'], t=False).reshape(-1) for n in ('query', 'key', 'value')])) + block.attn.proj.weight.copy_(_n2p(w[f'{mha_prefix}out/kernel']).flatten(1)) + block.attn.proj.bias.copy_(_n2p(w[f'{mha_prefix}out/bias'])) + for r in range(2): + getattr(block.mlp, f'fc{r + 1}').weight.copy_(_n2p(w[f'{block_prefix}MlpBlock_{b_sub}/Dense_{r}/kernel'])) + getattr(block.mlp, f'fc{r + 1}').bias.copy_(_n2p(w[f'{block_prefix}MlpBlock_{b_sub}/Dense_{r}/bias'])) + block.norm2.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_{ln1_sub}/scale'])) + block.norm2.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_{ln1_sub}/bias'])) + + module.norm.weight.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/scale'])) + module.norm.bias.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/bias'])) + + if module.attn_pool is not None: + block_prefix = f'{prefix}MAPHead_0/' + mha_prefix = block_prefix + f'MultiHeadDotProductAttention_0/' + module.attn_pool.latent.copy_(_n2p(w[f'{block_prefix}probe'], t=False)) + module.attn_pool.q.weight.copy_(_n2p(w[f'{mha_prefix}query/kernel'], t=False).flatten(1).T) + module.attn_pool.q.bias.copy_(_n2p(w[f'{mha_prefix}query/bias'], t=False).reshape(-1)) + module.attn_pool.kv.weight.copy_(torch.cat([ + _n2p(w[f'{mha_prefix}{n}/kernel'], t=False).flatten(1).T for n in ('key', 'value')])) + module.attn_pool.kv.bias.copy_(torch.cat([ + _n2p(w[f'{mha_prefix}{n}/bias'], t=False).reshape(-1) for n in ('key', 'value')])) + module.attn_pool.proj.weight.copy_(_n2p(w[f'{mha_prefix}out/kernel']).flatten(1)) + module.attn_pool.proj.bias.copy_(_n2p(w[f'{mha_prefix}out/bias'])) + module.attn_pool.norm.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/scale'])) + module.attn_pool.norm.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/bias'])) + for r in range(2): + getattr(module.attn_pool.mlp, f'fc{r + 1}').weight.copy_(_n2p(w[f'{block_prefix}MlpBlock_0/Dense_{r}/kernel'])) + getattr(module.attn_pool.mlp, f'fc{r + 1}').bias.copy_(_n2p(w[f'{block_prefix}MlpBlock_0/Dense_{r}/bias'])) + + def _convert_openclip_transformer(module: Transformer, prefix): + for i, block in enumerate(module.resblocks.children()): + block_prefix = f'{prefix}encoderblock_{i}/' + mha_prefix = block_prefix + f'MultiHeadDotProductAttention_0/' + block.ln_1.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/scale'])) + block.ln_1.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/bias'])) + block.attn.in_proj_weight.copy_(torch.cat([ + _n2p(w[f'{mha_prefix}{n}/kernel'], t=False).flatten(1).T for n in ('query', 'key', 'value')])) + block.attn.in_proj_bias.copy_(torch.cat([ + _n2p(w[f'{mha_prefix}{n}/bias'], t=False).reshape(-1) for n in ('query', 'key', 'value')])) + block.attn.out_proj.weight.copy_(_n2p(w[f'{mha_prefix}out/kernel']).flatten(1)) + block.attn.out_proj.bias.copy_(_n2p(w[f'{mha_prefix}out/bias'])) + block.ln_2.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_1/scale'])) + block.ln_2.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_1/bias'])) + block.mlp.c_fc.weight.copy_(_n2p(w[f'{block_prefix}MlpBlock_0/Dense_0/kernel'])) + block.mlp.c_fc.bias.copy_(_n2p(w[f'{block_prefix}MlpBlock_0/Dense_0/bias'])) + block.mlp.c_proj.weight.copy_(_n2p(w[f'{block_prefix}MlpBlock_0/Dense_1/kernel'])) + block.mlp.c_proj.bias.copy_(_n2p(w[f'{block_prefix}MlpBlock_0/Dense_1/bias'])) + + def _convert_openclip_txt(module: TextTransformer, prefix): + module.token_embedding.weight.copy_(_n2p(w[f'{prefix}Embed_0/embedding'], t=False)) + pos_embed_w = _n2p(w[f'{prefix}pos_embedding'], t=False).squeeze(0) + module.positional_embedding.copy_(pos_embed_w) + _convert_openclip_transformer(module.transformer, prefix=prefix + 'Encoder_0/') + module.ln_final.weight.copy_(_n2p(w[f'{prefix}Encoder_0/encoder_norm/scale'])) + module.ln_final.bias.copy_(_n2p(w[f'{prefix}Encoder_0/encoder_norm/bias'])) + module.text_projection.weight.copy_(_n2p(w[f'{prefix}head/kernel'])) + module.text_projection.bias.copy_(_n2p(w[f'{prefix}head/bias'])) + + _convert_timm_img(model.visual.trunk, 'params/img/') + _convert_openclip_txt(model.text, 'params/txt/') + model.logit_bias.copy_(_n2p(w['params/b'])[0]) + model.logit_scale.copy_(_n2p(w['params/t'])[0]) + + diff --git a/open_clip_local/bpe_simple_vocab_16e6.txt.gz b/open_clip_local/bpe_simple_vocab_16e6.txt.gz new file mode 100644 index 0000000000000000000000000000000000000000..36a15856e00a06a9fbed8cdd34d2393fea4a3113 --- /dev/null +++ b/open_clip_local/bpe_simple_vocab_16e6.txt.gz @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:924691ac288e54409236115652ad4aa250f48203de50a9e4722a6ecd48d6804a +size 1356917 diff --git a/open_clip_local/coca_model.py b/open_clip_local/coca_model.py new file mode 100644 index 0000000000000000000000000000000000000000..b4faa061deabc770e1f4006b51a6adaa287a8e90 --- /dev/null +++ b/open_clip_local/coca_model.py @@ -0,0 +1,492 @@ +from typing import Optional + +import torch +from torch import nn +from torch.nn import functional as F +import numpy as np +from dataclasses import dataclass + +from .transformer import ( + LayerNormFp32, + LayerNorm, + QuickGELU, + MultimodalTransformer, +) +from .model import CLIPTextCfg, CLIPVisionCfg, _build_vision_tower, _build_text_tower + +try: + from transformers import ( + BeamSearchScorer, + LogitsProcessorList, + TopPLogitsWarper, + TopKLogitsWarper, + RepetitionPenaltyLogitsProcessor, + MinLengthLogitsProcessor, + MaxLengthCriteria, + StoppingCriteriaList + ) + + GENERATION_TYPES = { + "top_k": TopKLogitsWarper, + "top_p": TopPLogitsWarper, + "beam_search": "beam_search" + } + _has_transformers = True +except ImportError as e: + GENERATION_TYPES = { + "top_k": None, + "top_p": None, + "beam_search": "beam_search" + } + _has_transformers = False + + +@dataclass +class MultimodalCfg(CLIPTextCfg): + mlp_ratio: int = 4 + dim_head: int = 64 + heads: int = 8 + n_queries: int = 256 + attn_pooler_heads: int = 8 + + +def _build_text_decoder_tower( + embed_dim, + multimodal_cfg, + quick_gelu: bool = False, + cast_dtype: Optional[torch.dtype] = None, +): + multimodal_cfg = MultimodalCfg(**multimodal_cfg) if isinstance(multimodal_cfg, dict) else multimodal_cfg + act_layer = QuickGELU if quick_gelu else nn.GELU + norm_layer = ( + LayerNormFp32 if cast_dtype in (torch.float16, torch.bfloat16) else LayerNorm + ) + + decoder = MultimodalTransformer( + context_length=multimodal_cfg.context_length, + width=multimodal_cfg.width, + heads=multimodal_cfg.heads, + layers=multimodal_cfg.layers, + ls_init_value=multimodal_cfg.ls_init_value, + output_dim=embed_dim, + act_layer=act_layer, + norm_layer=norm_layer, + ) + + return decoder + + +class CoCa(nn.Module): + def __init__( + self, + embed_dim, + multimodal_cfg: MultimodalCfg, + text_cfg: CLIPTextCfg, + vision_cfg: CLIPVisionCfg, + quick_gelu: bool = False, + init_logit_scale: float = np.log(1 / 0.07), + init_logit_bias: Optional[float] = None, + cast_dtype: Optional[torch.dtype] = None, + pad_id: int = 0, + ): + super().__init__() + multimodal_cfg = MultimodalCfg(**multimodal_cfg) if isinstance(multimodal_cfg, dict) else multimodal_cfg + text_cfg = CLIPTextCfg(**text_cfg) if isinstance(text_cfg, dict) else text_cfg + vision_cfg = CLIPVisionCfg(**vision_cfg) if isinstance(vision_cfg, dict) else vision_cfg + + self.text = _build_text_tower( + embed_dim=embed_dim, + text_cfg=text_cfg, + quick_gelu=quick_gelu, + cast_dtype=cast_dtype, + ) + + vocab_size = ( + text_cfg.vocab_size # for hf models + if hasattr(text_cfg, "hf_model_name") and text_cfg.hf_model_name is not None + else text_cfg.vocab_size + ) + + self.visual = _build_vision_tower( + embed_dim=embed_dim, + vision_cfg=vision_cfg, + quick_gelu=quick_gelu, + cast_dtype=cast_dtype, + ) + + self.text_decoder = _build_text_decoder_tower( + vocab_size, + multimodal_cfg=multimodal_cfg, + quick_gelu=quick_gelu, + cast_dtype=cast_dtype, + ) + + self.logit_scale = nn.Parameter(torch.ones([]) * init_logit_scale) + if init_logit_bias is not None: + self.logit_bias = nn.Parameter(torch.ones([]) * init_logit_bias) + else: + self.logit_bias = None + self.pad_id = pad_id + + self.context_length = multimodal_cfg.context_length + + @torch.jit.ignore + def set_grad_checkpointing(self, enable: bool = True): + self.visual.set_grad_checkpointing(enable) + self.text.set_grad_checkpointing(enable) + self.text_decoder.set_grad_checkpointing(enable) + + def _encode_image(self, images, normalize: bool = True): + image_latent, tokens_embs = self.visual(images) + image_latent = F.normalize(image_latent, dim=-1) if normalize else image_latent + return image_latent, tokens_embs + + def _encode_text(self, text, normalize: bool = True): + text_latent, token_emb = self.text(text) + text_latent = F.normalize(text_latent, dim=-1) if normalize else text_latent + return text_latent, token_emb + + def encode_image(self, images, normalize: bool = True): + image_latent, _ = self._encode_image(images, normalize=normalize) + return image_latent + + def encode_text(self, text, normalize: bool = True): + text_latent, _ = self._encode_text(text, normalize=normalize) + return text_latent + + def forward( + self, + image, + text: Optional[torch.Tensor] = None, + image_latent: Optional[torch.Tensor] = None, + image_embs: Optional[torch.Tensor] = None, + output_labels: bool = True, + ): + if image_latent is None or image_embs is None: + image_latent, image_embs = self._encode_image(image) + + if text is None: + return {"image_features": image_latent, "image_embs": image_embs} + + text_latent, token_embs = self._encode_text(text) + + # FIXME this isn't an ideal solution, would like to improve -RW + labels: Optional[torch.Tensor] = text[:, 1:] if output_labels else None + if output_labels: + # align text_embs and thus logits with labels for teacher-forcing caption loss + token_embs = token_embs[:, :-1] + + logits = self.text_decoder(image_embs, token_embs) + out_dict = { + "image_features": image_latent, + "text_features": text_latent, + "logits": logits, + "logit_scale": self.logit_scale.exp() + } + if labels is not None: + out_dict["labels"] = labels + if self.logit_bias is not None: + out_dict["logit_bias"] = self.logit_bias + return out_dict + + def generate( + self, + image, + text=None, + seq_len=30, + max_seq_len=77, + temperature=1., + generation_type="beam_search", + top_p=0.1, # keep tokens in the 1 - top_p quantile + top_k=1, # keeps the top_k most probable tokens + pad_token_id=None, + eos_token_id=None, + sot_token_id=None, + num_beams=6, + num_beam_groups=3, + min_seq_len=5, + stopping_criteria=None, + repetition_penalty=1.0, + fixed_output_length=False # if True output.shape == (batch_size, seq_len) + ): + # taking many ideas and components from HuggingFace GenerationMixin + # https://huggingface.co/docs/transformers/main/en/main_classes/text_generation + assert _has_transformers, "Please install transformers for generate functionality. `pip install transformers`." + assert seq_len > min_seq_len, "seq_len must be larger than min_seq_len" + + with torch.no_grad(): + sot_token_id = 49406 if sot_token_id is None else sot_token_id + eos_token_id = 49407 if eos_token_id is None else eos_token_id + pad_token_id = self.pad_id if pad_token_id is None else pad_token_id + logit_processor = LogitsProcessorList( + [ + MinLengthLogitsProcessor(min_seq_len, eos_token_id), + RepetitionPenaltyLogitsProcessor(repetition_penalty), + ] + ) + + if stopping_criteria is None: + stopping_criteria = [MaxLengthCriteria(max_length=seq_len)] + + stopping_criteria = StoppingCriteriaList( + stopping_criteria + ) + + device = image.device + + if generation_type == "beam_search": + output = self._generate_beamsearch( + image_inputs=image, + pad_token_id=pad_token_id, + eos_token_id=eos_token_id, + sot_token_id=sot_token_id, + num_beams=num_beams, + num_beam_groups=num_beam_groups, + min_seq_len=min_seq_len, + stopping_criteria=stopping_criteria, + logit_processor=logit_processor, + ) + if fixed_output_length and output.shape[1] < seq_len: + pad_len = seq_len - output.shape[1] + return torch.cat(( + output, + torch.ones(output.shape[0], pad_len, device=device, dtype=output.dtype) * self.pad_id + ), + dim=1 + ) + return output + + elif generation_type == "top_p": + logit_warper = GENERATION_TYPES[generation_type](top_p) + elif generation_type == "top_k": + logit_warper = GENERATION_TYPES[generation_type](top_k) + else: + raise ValueError( + f"generation_type has to be one of " + f"{'| ' + ' | '.join(list(GENERATION_TYPES.keys())) + ' |'}." + ) + + image_latent, image_embs = self._encode_image(image) + + if text is None: + text = torch.ones((image.shape[0], 1), device=device, dtype=torch.long) * sot_token_id + + was_training = self.training + num_dims = len(text.shape) + + if num_dims == 1: + text = text[None, :] + + self.eval() + out = text + + while True: + x = out[:, -max_seq_len:] + cur_len = x.shape[1] + logits = self( + image, + x, + image_latent=image_latent, + image_embs=image_embs, + output_labels=False, + )["logits"][:, -1] + mask = (out[:, -1] == eos_token_id) | (out[:, -1] == pad_token_id) + sample = torch.ones((out.shape[0], 1), device=device, dtype=torch.long) * pad_token_id + + if mask.all(): + if not fixed_output_length: + break + else: + logits = logits[~mask, :] + filtered_logits = logit_processor(x[~mask, :], logits) + filtered_logits = logit_warper(x[~mask, :], filtered_logits) + probs = F.softmax(filtered_logits / temperature, dim=-1) + + if (cur_len + 1 == seq_len): + sample[~mask, :] = torch.ones((sum(~mask), 1), device=device, dtype=torch.long) * eos_token_id + else: + sample[~mask, :] = torch.multinomial(probs, 1) + + out = torch.cat((out, sample), dim=-1) + + cur_len += 1 + + if stopping_criteria(out, None): + break + + if num_dims == 1: + out = out.squeeze(0) + + self.train(was_training) + return out + + def _generate_beamsearch( + self, + image_inputs, + pad_token_id=None, + eos_token_id=None, + sot_token_id=None, + num_beams=6, + num_beam_groups=3, + min_seq_len=5, + stopping_criteria=None, + logit_processor=None, + logit_warper=None, + ): + device = image_inputs.device + batch_size = image_inputs.shape[0] + image_inputs = torch.repeat_interleave(image_inputs, num_beams, dim=0) + image_latent, image_embs = self._encode_image(image_inputs) + + input_ids = torch.ones((batch_size * num_beams, 1), device=device, dtype=torch.long) + input_ids = input_ids * sot_token_id + beam_scorer = BeamSearchScorer( + batch_size=batch_size, + num_beams=num_beams, + device=device, + num_beam_groups=num_beam_groups, + ) + # instantiate logits processors + logits_processor = ( + LogitsProcessorList([MinLengthLogitsProcessor(min_seq_len, eos_token_id=eos_token_id)]) + if logit_processor is None + else logit_processor + ) + + num_beams = beam_scorer.num_beams + num_beam_groups = beam_scorer.num_beam_groups + num_sub_beams = num_beams // num_beam_groups + batch_size = len(beam_scorer._beam_hyps) // num_beam_groups + batch_beam_size, cur_len = input_ids.shape + beam_indices = None + + if num_beams * batch_size != batch_beam_size: + raise ValueError( + f"Batch dimension of `input_ids` should be {num_beams * batch_size}, but is {batch_beam_size}." + ) + + beam_scores = torch.full((batch_size, num_beams), -1e9, dtype=torch.float, device=device) + # initialise score of first beam of each group with 0 and the rest with 1e-9. This ensures that the beams in + # the same group don't produce same tokens everytime. + beam_scores[:, ::num_sub_beams] = 0 + beam_scores = beam_scores.view((batch_size * num_beams,)) + + while True: + + # predicted tokens in cur_len step + current_tokens = torch.zeros(batch_size * num_beams, dtype=input_ids.dtype, device=device) + + # indices which will form the beams in the next time step + reordering_indices = torch.zeros(batch_size * num_beams, dtype=torch.long, device=device) + + # do one decoder step on all beams of all sentences in batch + model_inputs = prepare_inputs_for_generation(input_ids=input_ids, image_inputs=image_inputs) + outputs = self( + model_inputs['images'], + model_inputs['text'], + image_latent=image_latent, + image_embs=image_embs, + output_labels=False, + ) + + for beam_group_idx in range(num_beam_groups): + group_start_idx = beam_group_idx * num_sub_beams + group_end_idx = min(group_start_idx + num_sub_beams, num_beams) + group_size = group_end_idx - group_start_idx + + # indices of beams of current group among all sentences in batch + batch_group_indices = [] + + for batch_idx in range(batch_size): + batch_group_indices.extend( + [batch_idx * num_beams + idx for idx in range(group_start_idx, group_end_idx)] + ) + group_input_ids = input_ids[batch_group_indices] + + # select outputs of beams of currentg group only + next_token_logits = outputs['logits'][batch_group_indices, -1, :] + vocab_size = next_token_logits.shape[-1] + + next_token_scores_processed = logits_processor( + group_input_ids, next_token_logits, current_tokens=current_tokens, beam_group_idx=beam_group_idx + ) + next_token_scores = next_token_scores_processed + beam_scores[batch_group_indices].unsqueeze(-1) + next_token_scores = next_token_scores.expand_as(next_token_scores_processed) + + # reshape for beam search + next_token_scores = next_token_scores.view(batch_size, group_size * vocab_size) + + next_token_scores, next_tokens = torch.topk( + next_token_scores, 2 * group_size, dim=1, largest=True, sorted=True + ) + + next_indices = torch.div(next_tokens, vocab_size, rounding_mode="floor") + next_tokens = next_tokens % vocab_size + + # stateless + process_beam_indices = sum(beam_indices, ()) if beam_indices is not None else None + beam_outputs = beam_scorer.process( + group_input_ids, + next_token_scores, + next_tokens, + next_indices, + pad_token_id=pad_token_id, + eos_token_id=eos_token_id, + beam_indices=process_beam_indices, + group_index=beam_group_idx, + ) + beam_scores[batch_group_indices] = beam_outputs["next_beam_scores"] + beam_next_tokens = beam_outputs["next_beam_tokens"] + beam_idx = beam_outputs["next_beam_indices"] + + input_ids[batch_group_indices] = group_input_ids[beam_idx] + group_input_ids = torch.cat([group_input_ids[beam_idx, :], beam_next_tokens.unsqueeze(-1)], dim=-1) + current_tokens[batch_group_indices] = group_input_ids[:, -1] + + # (beam_idx // group_size) -> batch_idx + # (beam_idx % group_size) -> offset of idx inside the group + reordering_indices[batch_group_indices] = ( + num_beams * torch.div(beam_idx, group_size, rounding_mode="floor") + group_start_idx + (beam_idx % group_size) + ) + + input_ids = torch.cat([input_ids, current_tokens.unsqueeze(-1)], dim=-1) + + # increase cur_len + cur_len = cur_len + 1 + if beam_scorer.is_done or stopping_criteria(input_ids, None): + break + + final_beam_indices = sum(beam_indices, ()) if beam_indices is not None else None + sequence_outputs = beam_scorer.finalize( + input_ids, + beam_scores, + next_tokens, + next_indices, + pad_token_id=pad_token_id, + eos_token_id=eos_token_id, + max_length=stopping_criteria.max_length, + beam_indices=final_beam_indices, + ) + return sequence_outputs['sequences'] + + +def prepare_inputs_for_generation(input_ids, image_inputs, past=None, **kwargs): + if past: + input_ids = input_ids[:, -1].unsqueeze(-1) + + attention_mask = kwargs.get("attention_mask", None) + position_ids = kwargs.get("position_ids", None) + + if attention_mask is not None and position_ids is None: + # create position_ids on the fly for batch generation + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) + else: + position_ids = None + return { + "text": input_ids, + "images": image_inputs, + "past_key_values": past, + "position_ids": position_ids, + "attention_mask": attention_mask, + } diff --git a/open_clip_local/constants.py b/open_clip_local/constants.py new file mode 100644 index 0000000000000000000000000000000000000000..599c48c03f7a1ed97af20cbc482db27984514622 --- /dev/null +++ b/open_clip_local/constants.py @@ -0,0 +1,6 @@ +OPENAI_DATASET_MEAN = (0.48145466, 0.4578275, 0.40821073) +OPENAI_DATASET_STD = (0.26862954, 0.26130258, 0.27577711) +IMAGENET_MEAN = (0.485, 0.456, 0.406) +IMAGENET_STD = (0.229, 0.224, 0.225) +INCEPTION_MEAN = (0.5, 0.5, 0.5) +INCEPTION_STD = (0.5, 0.5, 0.5) diff --git a/open_clip_local/factory.py b/open_clip_local/factory.py new file mode 100644 index 0000000000000000000000000000000000000000..627c7003d974c04939fbaca6d121713c8e6e23b7 --- /dev/null +++ b/open_clip_local/factory.py @@ -0,0 +1,463 @@ +import json +import logging +import os +import re +from copy import deepcopy +from dataclasses import asdict +from pathlib import Path +from typing import Any, Dict, Optional, Tuple, Union + +import torch + +from .constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD +from .model import CLIP, CustomTextCLIP, convert_weights_to_lp, convert_to_custom_text_state_dict,\ + resize_pos_embed, get_cast_dtype, resize_text_pos_embed, set_model_preprocess_cfg +from .coca_model import CoCa +from .loss import ClipLoss, DistillClipLoss, CoCaLoss, SigLipLoss +from .openai import load_openai_model +from .pretrained import is_pretrained_cfg, get_pretrained_cfg, download_pretrained,\ + list_pretrained_tags_by_model, download_pretrained_from_hf +from .transform import image_transform_v2, AugmentationCfg, PreprocessCfg, merge_preprocess_dict, merge_preprocess_kwargs +from .tokenizer import HFTokenizer, SimpleTokenizer, DEFAULT_CONTEXT_LENGTH + +HF_HUB_PREFIX = 'hf-hub:' +_MODEL_CONFIG_PATHS = [Path(__file__).parent / f"model_configs/"] +_MODEL_CONFIGS = {} # directory (model_name: config) of model architecture configs + + +def _natural_key(string_): + return [int(s) if s.isdigit() else s for s in re.split(r'(\d+)', string_.lower())] + + +def _rescan_model_configs(): + global _MODEL_CONFIGS + + config_ext = ('.json',) + config_files = [] + for config_path in _MODEL_CONFIG_PATHS: + if config_path.is_file() and config_path.suffix in config_ext: + config_files.append(config_path) + elif config_path.is_dir(): + for ext in config_ext: + config_files.extend(config_path.glob(f'*{ext}')) + + for cf in config_files: + with open(cf, 'r') as f: + model_cfg = json.load(f) + if all(a in model_cfg for a in ('embed_dim', 'vision_cfg', 'text_cfg')): + _MODEL_CONFIGS[cf.stem] = model_cfg + + _MODEL_CONFIGS = {k: v for k, v in sorted(_MODEL_CONFIGS.items(), key=lambda x: _natural_key(x[0]))} + + +_rescan_model_configs() # initial populate of model config registry + + +def list_models(): + """ enumerate available model architectures based on config files """ + return list(_MODEL_CONFIGS.keys()) + + +def add_model_config(path): + """ add model config path or file and update registry """ + if not isinstance(path, Path): + path = Path(path) + _MODEL_CONFIG_PATHS.append(path) + _rescan_model_configs() + + +def get_model_config(model_name): + if model_name in _MODEL_CONFIGS: + return deepcopy(_MODEL_CONFIGS[model_name]) + else: + return None + + +def _get_hf_config(model_id, cache_dir=None): + config_path = download_pretrained_from_hf(model_id, filename='open_clip_config.json', cache_dir=cache_dir) + with open(config_path, 'r', encoding='utf-8') as f: + config = json.load(f) + return config + + +def get_tokenizer( + model_name: str = '', + context_length: Optional[int] = None, + **kwargs, +): + if model_name.startswith(HF_HUB_PREFIX): + model_name = model_name[len(HF_HUB_PREFIX):] + try: + config = _get_hf_config(model_name)['model_cfg'] + except Exception: + tokenizer = HFTokenizer( + model_name, + context_length=context_length or DEFAULT_CONTEXT_LENGTH, + **kwargs, + ) + return tokenizer + else: + config = get_model_config(model_name) + assert config is not None, f"No valid model config found for {model_name}." + + text_config = config.get('text_cfg', {}) + if 'tokenizer_kwargs' in text_config: + tokenizer_kwargs = dict(text_config['tokenizer_kwargs'], **kwargs) + else: + tokenizer_kwargs = kwargs + + if context_length is None: + context_length = text_config.get('context_length', DEFAULT_CONTEXT_LENGTH) + + if 'hf_tokenizer_name' in text_config: + tokenizer = HFTokenizer( + text_config['hf_tokenizer_name'], + context_length=context_length, + **tokenizer_kwargs, + ) + else: + tokenizer = SimpleTokenizer( + context_length=context_length, + **tokenizer_kwargs, + ) + + return tokenizer + + +def load_state_dict(checkpoint_path: str, map_location='cpu'): + checkpoint = torch.load(checkpoint_path, map_location=map_location) + if isinstance(checkpoint, dict) and 'state_dict' in checkpoint: + state_dict = checkpoint['state_dict'] + elif isinstance(checkpoint, torch.jit.ScriptModule): + state_dict = checkpoint.state_dict() + for key in ["input_resolution", "context_length", "vocab_size"]: + state_dict.pop(key, None) + else: + state_dict = checkpoint + if next(iter(state_dict.items()))[0].startswith('module'): + state_dict = {k[7:]: v for k, v in state_dict.items()} + return state_dict + + +def load_checkpoint(model, checkpoint_path, strict=True): + if Path(checkpoint_path).suffix in ('.npz', '.npy'): + from .big_vision import load_big_vision_weights + load_big_vision_weights(model, checkpoint_path) + return {} + + state_dict = load_state_dict(checkpoint_path) + # detect old format and make compatible with new format + if 'positional_embedding' in state_dict and not hasattr(model, 'positional_embedding'): + state_dict = convert_to_custom_text_state_dict(state_dict) + # If loading a non-SigLIP model for SigLIP training. See https://github.com/mlfoundations/open_clip/issues/712 + if 'logit_bias' not in state_dict and model.logit_bias is not None: + state_dict["logit_bias"] = torch.zeros_like(state_dict["logit_scale"]) + # Certain text transformers no longer expect position_ids after transformers==4.31 + position_id_key = 'text.transformer.embeddings.position_ids' + if position_id_key in state_dict and not hasattr(model, position_id_key): + del state_dict[position_id_key] + resize_pos_embed(state_dict, model) + resize_text_pos_embed(state_dict, model) + incompatible_keys = model.load_state_dict(state_dict, strict=strict) + return incompatible_keys + + +def create_model( + model_name: str, + pretrained: Optional[str] = None, + precision: str = 'fp32', + device: Union[str, torch.device] = 'cpu', + jit: bool = False, + force_quick_gelu: bool = False, + force_custom_text: bool = False, + force_patch_dropout: Optional[float] = None, + force_image_size: Optional[Union[int, Tuple[int, int]]] = None, + force_preprocess_cfg: Optional[Dict[str, Any]] = None, + pretrained_image: bool = False, + pretrained_hf: bool = True, + cache_dir: Optional[str] = None, + output_dict: Optional[bool] = None, + require_pretrained: bool = False, + **model_kwargs, +): + force_preprocess_cfg = force_preprocess_cfg or {} + preprocess_cfg = asdict(PreprocessCfg()) + has_hf_hub_prefix = model_name.startswith(HF_HUB_PREFIX) + if has_hf_hub_prefix: + model_id = model_name[len(HF_HUB_PREFIX):] + checkpoint_path = download_pretrained_from_hf(model_id, cache_dir=cache_dir) + config = _get_hf_config(model_id, cache_dir) + preprocess_cfg = merge_preprocess_dict(preprocess_cfg, config['preprocess_cfg']) + model_cfg = config['model_cfg'] + pretrained_hf = False # override, no need to load original HF text weights + else: + model_name = model_name.replace('/', '-') # for callers using old naming with / in ViT names + checkpoint_path = None + model_cfg = None + + if isinstance(device, str): + device = torch.device(device) + + if pretrained and pretrained.lower() == 'openai': + logging.info(f'Loading pretrained {model_name} from OpenAI.') + model = load_openai_model( + model_name, + precision=precision, + device=device, + cache_dir=cache_dir, + ) + else: + model_cfg = model_cfg or get_model_config(model_name) + if model_cfg is not None: + logging.info(f'Loaded {model_name} model config.') + else: + logging.error(f'Model config for {model_name} not found; available models {list_models()}.') + raise RuntimeError(f'Model config for {model_name} not found.') + + if force_quick_gelu: + # override for use of QuickGELU on non-OpenAI transformer models + model_cfg["quick_gelu"] = True + + if force_patch_dropout is not None: + # override the default patch dropout value + model_cfg["vision_cfg"]["patch_dropout"] = force_patch_dropout + + if force_image_size is not None: + # override model config's image size + model_cfg["vision_cfg"]["image_size"] = force_image_size + + is_timm_model = 'timm_model_name' in model_cfg.get('vision_cfg', {}) + if pretrained_image: + if is_timm_model: + # pretrained weight loading for timm models set via vision_cfg + model_cfg['vision_cfg']['timm_model_pretrained'] = True + else: + assert False, 'pretrained image towers currently only supported for timm models' + + # cast_dtype set for fp16 and bf16 (manual mixed-precision), not set for 'amp' or 'pure' modes + cast_dtype = get_cast_dtype(precision) + is_hf_model = 'hf_model_name' in model_cfg.get('text_cfg', {}) + if is_hf_model: + # load pretrained weights for HF text model IFF no CLIP weights being loaded + model_cfg['text_cfg']['hf_model_pretrained'] = pretrained_hf and not pretrained + custom_text = model_cfg.pop('custom_text', False) or force_custom_text or is_hf_model + + model_cfg = dict(model_cfg, **model_kwargs) # merge cfg dict w/ kwargs (kwargs overrides cfg) + model_cfg['vision_cfg']['image_size'] = 448 + if custom_text: + if "multimodal_cfg" in model_cfg: + model = CoCa(**model_cfg, cast_dtype=cast_dtype) + else: + model = CustomTextCLIP(**model_cfg, cast_dtype=cast_dtype) + else: + model = CLIP(**model_cfg, cast_dtype=cast_dtype) + + if precision in ("fp16", "bf16"): + dtype = torch.float16 if 'fp16' in precision else torch.bfloat16 + # manual mixed precision that matches original OpenAI behaviour + if is_timm_model: + # FIXME this is a bit janky, create timm based model in low-precision and + # then cast only LayerNormFp32 instances back to float32 so they don't break. + # Why? The convert_weights_to_lp fn only works with native models. + model.to(device=device, dtype=dtype) + from .transformer import LayerNormFp32 + + def _convert_ln(m): + if isinstance(m, LayerNormFp32): + m.weight.data = m.weight.data.to(torch.float32) + m.bias.data = m.bias.data.to(torch.float32) + model.apply(_convert_ln) + else: + model.to(device=device) + convert_weights_to_lp(model, dtype=dtype) + elif precision in ("pure_fp16", "pure_bf16"): + dtype = torch.float16 if 'fp16' in precision else torch.bfloat16 + model.to(device=device, dtype=dtype) + else: + model.to(device=device) + + pretrained_loaded = False + if pretrained: + checkpoint_path = '' + pretrained_cfg = get_pretrained_cfg(model_name, pretrained) + if pretrained_cfg: + # cache_dir = '/data/yizhou/VAND2.0/wgd/single_scale' + checkpoint_path = download_pretrained(pretrained_cfg, cache_dir=cache_dir) + preprocess_cfg = merge_preprocess_dict(preprocess_cfg, pretrained_cfg) + elif os.path.exists(pretrained): + checkpoint_path = pretrained + + if checkpoint_path: + logging.info(f'Loading pretrained {model_name} weights ({pretrained}).') + load_checkpoint(model, checkpoint_path) + else: + error_str = ( + f'Pretrained weights ({pretrained}) not found for model {model_name}.' + f' Available pretrained tags ({list_pretrained_tags_by_model(model_name)}.') + logging.warning(error_str) + raise RuntimeError(error_str) + pretrained_loaded = True + elif has_hf_hub_prefix: + logging.info(f'Loading pretrained {model_name} weights ({checkpoint_path}).') + load_checkpoint(model, checkpoint_path) + pretrained_loaded = True + + if require_pretrained and not pretrained_loaded: + # callers of create_model_from_pretrained always expect pretrained weights + raise RuntimeError( + f'Pretrained weights were required for (model: {model_name}, pretrained: {pretrained}) but not loaded.') + + if output_dict and hasattr(model, "output_dict"): + model.output_dict = True + + if jit: + model = torch.jit.script(model) + + # set image preprocessing configuration in model attributes for convenience + if getattr(model.visual, 'image_size', None) is not None: + # use image_size set on model creation (via config or force_image_size arg) + force_preprocess_cfg['size'] = model.visual.image_size + set_model_preprocess_cfg(model, merge_preprocess_dict(preprocess_cfg, force_preprocess_cfg)) + + return model + + +def create_loss(args): + if args.distill: + return DistillClipLoss( + local_loss=args.local_loss, + gather_with_grad=args.gather_with_grad, + cache_labels=True, + rank=args.rank, + world_size=args.world_size, + use_horovod=args.horovod, + ) + elif "coca" in args.model.lower(): + return CoCaLoss( + caption_loss_weight=args.coca_caption_loss_weight, + clip_loss_weight=args.coca_contrastive_loss_weight, + local_loss=args.local_loss, + gather_with_grad=args.gather_with_grad, + cache_labels=True, + rank=args.rank, + world_size=args.world_size, + use_horovod=args.horovod, + ) + elif args.siglip: + assert not args.horovod, "Horovod not currently supported for SigLip" + return SigLipLoss( + rank=args.rank, + world_size=args.world_size, + ) + return ClipLoss( + local_loss=args.local_loss, + gather_with_grad=args.gather_with_grad, + cache_labels=True, + rank=args.rank, + world_size=args.world_size, + use_horovod=args.horovod, + ) + + +def create_model_and_transforms( + model_name: str, + pretrained: Optional[str] = None, + precision: str = 'fp32', + device: Union[str, torch.device] = 'cpu', + jit: bool = False, + force_quick_gelu: bool = False, + force_custom_text: bool = False, + force_patch_dropout: Optional[float] = None, + force_image_size: Optional[Union[int, Tuple[int, int]]] = None, + image_mean: Optional[Tuple[float, ...]] = None, + image_std: Optional[Tuple[float, ...]] = None, + image_interpolation: Optional[str] = None, + image_resize_mode: Optional[str] = None, # only effective for inference + aug_cfg: Optional[Union[Dict[str, Any], AugmentationCfg]] = None, + pretrained_image: bool = False, + pretrained_hf: bool = True, + cache_dir: Optional[str] = None, + output_dict: Optional[bool] = None, + **model_kwargs, +): + force_preprocess_cfg = merge_preprocess_kwargs( + {}, mean=image_mean, std=image_std, interpolation=image_interpolation, resize_mode=image_resize_mode) + + model = create_model( + model_name, + pretrained, + precision=precision, + device=device, + jit=jit, + force_quick_gelu=force_quick_gelu, + force_custom_text=force_custom_text, + force_patch_dropout=force_patch_dropout, + force_image_size=force_image_size, + force_preprocess_cfg=force_preprocess_cfg, + pretrained_image=pretrained_image, + pretrained_hf=pretrained_hf, + cache_dir=cache_dir, + output_dict=output_dict, + **model_kwargs, + ) + + model.visual.preprocess_cfg['resize_mode'] = 'squash' + pp_cfg = PreprocessCfg(**model.visual.preprocess_cfg) + + preprocess_train = image_transform_v2( + pp_cfg, + is_train=True, + aug_cfg=aug_cfg, + ) + preprocess_val = image_transform_v2( + pp_cfg, + is_train=False, + ) + + return model, preprocess_train, preprocess_val + + +def create_model_from_pretrained( + model_name: str, + pretrained: Optional[str] = None, + precision: str = 'fp32', + device: Union[str, torch.device] = 'cpu', + jit: bool = False, + force_quick_gelu: bool = False, + force_custom_text: bool = False, + force_image_size: Optional[Union[int, Tuple[int, int]]] = None, + image_mean: Optional[Tuple[float, ...]] = None, + image_std: Optional[Tuple[float, ...]] = None, + image_interpolation: Optional[str] = None, + image_resize_mode: Optional[str] = None, # only effective for inference + return_transform: bool = True, + cache_dir: Optional[str] = None, + **model_kwargs, +): + force_preprocess_cfg = merge_preprocess_kwargs( + {}, mean=image_mean, std=image_std, interpolation=image_interpolation, resize_mode=image_resize_mode) + + model = create_model( + model_name, + pretrained, + precision=precision, + device=device, + jit=jit, + force_quick_gelu=force_quick_gelu, + force_custom_text=force_custom_text, + force_image_size=force_image_size, + force_preprocess_cfg=force_preprocess_cfg, + cache_dir=cache_dir, + require_pretrained=True, + **model_kwargs, + ) + + if not return_transform: + return model + + preprocess = image_transform_v2( + PreprocessCfg(**model.visual.preprocess_cfg), + is_train=False, + ) + + return model, preprocess diff --git a/open_clip_local/hf_configs.py b/open_clip_local/hf_configs.py new file mode 100644 index 0000000000000000000000000000000000000000..3d2067476500a7c16511af18696fc5e23b066aff --- /dev/null +++ b/open_clip_local/hf_configs.py @@ -0,0 +1,67 @@ +# HF architecture dict: +arch_dict = { + # https://huggingface.co/docs/transformers/model_doc/roberta#roberta + "roberta": { + "config_names": { + "context_length": "max_position_embeddings", + "vocab_size": "vocab_size", + "width": "hidden_size", + "heads": "num_attention_heads", + "layers": "num_hidden_layers", + "layer_attr": "layer", + "token_embeddings_attr": "embeddings" + }, + "pooler": "mean_pooler", + }, + # https://huggingface.co/docs/transformers/model_doc/xlm-roberta#transformers.XLMRobertaConfig + "xlm-roberta": { + "config_names": { + "context_length": "max_position_embeddings", + "vocab_size": "vocab_size", + "width": "hidden_size", + "heads": "num_attention_heads", + "layers": "num_hidden_layers", + "layer_attr": "layer", + "token_embeddings_attr": "embeddings" + }, + "pooler": "mean_pooler", + }, + # https://huggingface.co/docs/transformers/model_doc/mt5#mt5 + "mt5": { + "config_names": { + # unlimited seqlen + # https://github.com/google-research/text-to-text-transfer-transformer/issues/273 + # https://github.com/huggingface/transformers/blob/v4.24.0/src/transformers/models/t5/modeling_t5.py#L374 + "context_length": "", + "vocab_size": "vocab_size", + "width": "d_model", + "heads": "num_heads", + "layers": "num_layers", + "layer_attr": "block", + "token_embeddings_attr": "embed_tokens" + }, + "pooler": "mean_pooler", + }, + # https://huggingface.co/docs/transformers/model_doc/bert + "bert": { + "config_names": { + "context_length": "max_position_embeddings", + "vocab_size": "vocab_size", + "width": "hidden_size", + "heads": "num_attention_heads", + "layers": "num_hidden_layers", + }, + "pooler": "cls_pooler", + }, + # https://huggingface.co/docs/transformers/model_doc/m2m_100 + "m2m_100": { + "config_names": { + "context_length": "max_position_embeddings", + "vocab_size": "vocab_size", + "width": "d_model", + "heads": "encoder_attention_heads", + "layers": "encoder_layers", + }, + "pooler": "cls_pooler", + }, +} diff --git a/open_clip_local/hf_model.py b/open_clip_local/hf_model.py new file mode 100644 index 0000000000000000000000000000000000000000..281a06cc5f16f41e17ba0e6ea9b5b29fab5bc076 --- /dev/null +++ b/open_clip_local/hf_model.py @@ -0,0 +1,193 @@ +""" huggingface model adapter + +Wraps HuggingFace transformers (https://github.com/huggingface/transformers) models for use as a text tower in CLIP model. +""" +import re + +import torch +import torch.nn as nn +from torch import TensorType + +try: + import transformers + from transformers import AutoModel, AutoTokenizer, AutoConfig, PretrainedConfig + from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, \ + BaseModelOutputWithPoolingAndCrossAttentions +except ImportError as e: + transformers = None + + + class BaseModelOutput: + pass + + + class PretrainedConfig: + pass + +from .hf_configs import arch_dict + + +# utils +def _camel2snake(s): + return re.sub(r'(? torch.Tensor: + # calculated ground-truth and cache if enabled + if self.prev_num_logits != num_logits or device not in self.labels: + labels = torch.arange(num_logits, device=device, dtype=torch.long) + if self.world_size > 1 and self.local_loss: + labels = labels + num_logits * self.rank + if self.cache_labels: + self.labels[device] = labels + self.prev_num_logits = num_logits + else: + labels = self.labels[device] + return labels + + def get_logits(self, image_features, text_features, logit_scale): + if self.world_size > 1: + all_image_features, all_text_features = gather_features( + image_features, text_features, + self.local_loss, self.gather_with_grad, self.rank, self.world_size, self.use_horovod) + + if self.local_loss: + logits_per_image = logit_scale * image_features @ all_text_features.T + logits_per_text = logit_scale * text_features @ all_image_features.T + else: + logits_per_image = logit_scale * all_image_features @ all_text_features.T + logits_per_text = logits_per_image.T + else: + logits_per_image = logit_scale * image_features @ text_features.T + logits_per_text = logit_scale * text_features @ image_features.T + + return logits_per_image, logits_per_text + + def forward(self, image_features, text_features, logit_scale, output_dict=False): + device = image_features.device + logits_per_image, logits_per_text = self.get_logits(image_features, text_features, logit_scale) + + labels = self.get_ground_truth(device, logits_per_image.shape[0]) + + total_loss = ( + F.cross_entropy(logits_per_image, labels) + + F.cross_entropy(logits_per_text, labels) + ) / 2 + + return {"contrastive_loss": total_loss} if output_dict else total_loss + + +class CoCaLoss(ClipLoss): + def __init__( + self, + caption_loss_weight, + clip_loss_weight, + pad_id=0, # pad_token for open_clip custom tokenizer + local_loss=False, + gather_with_grad=False, + cache_labels=False, + rank=0, + world_size=1, + use_horovod=False, + ): + super().__init__( + local_loss=local_loss, + gather_with_grad=gather_with_grad, + cache_labels=cache_labels, + rank=rank, + world_size=world_size, + use_horovod=use_horovod + ) + + self.clip_loss_weight = clip_loss_weight + self.caption_loss_weight = caption_loss_weight + self.caption_loss = nn.CrossEntropyLoss(ignore_index=pad_id) + + def forward(self, image_features, text_features, logits, labels, logit_scale, output_dict=False): + + clip_loss = torch.tensor(0) + + if self.clip_loss_weight: + clip_loss = super().forward(image_features, text_features, logit_scale) + clip_loss = self.clip_loss_weight * clip_loss + + caption_loss = self.caption_loss( + logits.permute(0, 2, 1), + labels, + ) + caption_loss = caption_loss * self.caption_loss_weight + + if output_dict: + return {"contrastive_loss": clip_loss, "caption_loss": caption_loss} + + return clip_loss, caption_loss + + +class DistillClipLoss(ClipLoss): + + def dist_loss(self, teacher_logits, student_logits): + return -(teacher_logits.softmax(dim=1) * student_logits.log_softmax(dim=1)).sum(dim=1).mean(dim=0) + + def forward( + self, + image_features, + text_features, + logit_scale, + dist_image_features, + dist_text_features, + dist_logit_scale, + output_dict=False, + ): + logits_per_image, logits_per_text = \ + self.get_logits(image_features, text_features, logit_scale) + + dist_logits_per_image, dist_logits_per_text = \ + self.get_logits(dist_image_features, dist_text_features, dist_logit_scale) + + labels = self.get_ground_truth(image_features.device, logits_per_image.shape[0]) + + contrastive_loss = ( + F.cross_entropy(logits_per_image, labels) + + F.cross_entropy(logits_per_text, labels) + ) / 2 + + distill_loss = ( + self.dist_loss(dist_logits_per_image, logits_per_image) + + self.dist_loss(dist_logits_per_text, logits_per_text) + ) / 2 + + if output_dict: + return {"contrastive_loss": contrastive_loss, "distill_loss": distill_loss} + + return contrastive_loss, distill_loss + + +def neighbour_exchange(from_rank, to_rank, tensor, group=None): + tensor_recv = torch.zeros_like(tensor) + send_op = torch.distributed.P2POp( + torch.distributed.isend, + tensor, + to_rank, + group=group, + ) + recv_op = torch.distributed.P2POp( + torch.distributed.irecv, + tensor_recv, + from_rank, + group=group, + ) + reqs = torch.distributed.batch_isend_irecv([send_op, recv_op]) + for req in reqs: + req.wait() + return tensor_recv + + +def neighbour_exchange_bidir(left_rank, right_rank, tensor_to_left, tensor_to_right, group=None): + tensor_from_left = torch.zeros_like(tensor_to_right) + tensor_from_right = torch.zeros_like(tensor_to_left) + send_op_left = torch.distributed.P2POp( + torch.distributed.isend, + tensor_to_left, + left_rank, + group=group, + ) + send_op_right = torch.distributed.P2POp( + torch.distributed.isend, + tensor_to_right, + right_rank, + group=group, + ) + recv_op_left = torch.distributed.P2POp( + torch.distributed.irecv, + tensor_from_left, + left_rank, + group=group, + ) + recv_op_right = torch.distributed.P2POp( + torch.distributed.irecv, + tensor_from_right, + right_rank, + group=group, + ) + reqs = torch.distributed.batch_isend_irecv([send_op_right, send_op_left, recv_op_right, recv_op_left]) + for req in reqs: + req.wait() + return tensor_from_right, tensor_from_left + + +class NeighbourExchange(torch.autograd.Function): + @staticmethod + def forward(ctx, from_rank, to_rank, group, tensor): + ctx.group = group + ctx.from_rank = from_rank + ctx.to_rank = to_rank + return neighbour_exchange(from_rank, to_rank, tensor, group=group) + + @staticmethod + def backward(ctx, grad_output): + return (None, None, None) + (NeighbourExchange.apply(ctx.to_rank, ctx.from_rank, ctx.group, grad_output),) + + +def neighbour_exchange_with_grad(from_rank, to_rank, tensor, group=None): + return NeighbourExchange.apply(from_rank, to_rank, group, tensor) + + +class NeighbourExchangeBidir(torch.autograd.Function): + @staticmethod + def forward(ctx, left_rank, right_rank, group, tensor_to_left, tensor_to_right): + ctx.group = group + ctx.left_rank = left_rank + ctx.right_rank = right_rank + return neighbour_exchange_bidir(left_rank, right_rank, tensor_to_left, tensor_to_right, group=group) + + @staticmethod + def backward(ctx, *grad_outputs): + return (None, None, None) + \ + NeighbourExchangeBidir.apply(ctx.right_rank, ctx.left_rank, ctx.group, *grad_outputs) + + +def neighbour_exchange_bidir_with_grad(left_rank, right_rank, tensor_to_left, tensor_to_right, group=None): + return NeighbourExchangeBidir.apply(left_rank, right_rank, group, tensor_to_left, tensor_to_right) + + +class SigLipLoss(nn.Module): + """ Sigmoid Loss for Language Image Pre-Training (SigLIP) - https://arxiv.org/abs/2303.15343 + + @article{zhai2023sigmoid, + title={Sigmoid loss for language image pre-training}, + author={Zhai, Xiaohua and Mustafa, Basil and Kolesnikov, Alexander and Beyer, Lucas}, + journal={arXiv preprint arXiv:2303.15343}, + year={2023} + } + """ + def __init__( + self, + cache_labels=False, + rank=0, + world_size=1, + bidir=True, + use_horovod=False, + ): + super().__init__() + self.cache_labels = cache_labels + self.rank = rank + self.world_size = world_size + assert not use_horovod # FIXME need to look at hvd ops for ring transfers + self.use_horovod = use_horovod + self.bidir = bidir + + # cache state FIXME cache not currently used, worthwhile? + self.prev_num_logits = 0 + self.labels = {} + + def get_ground_truth(self, device, dtype, num_logits, negative_only=False) -> torch.Tensor: + labels = -torch.ones((num_logits, num_logits), device=device, dtype=dtype) + if not negative_only: + labels = 2 * torch.eye(num_logits, device=device, dtype=dtype) + labels + return labels + + def get_logits(self, image_features, text_features, logit_scale, logit_bias=None): + logits = logit_scale * image_features @ text_features.T + if logit_bias is not None: + logits += logit_bias + return logits + + def _loss(self, image_features, text_features, logit_scale, logit_bias=None, negative_only=False): + logits = self.get_logits(image_features, text_features, logit_scale, logit_bias) + labels = self.get_ground_truth( + image_features.device, + image_features.dtype, + image_features.shape[0], + negative_only=negative_only, + ) + loss = -F.logsigmoid(labels * logits).sum() / image_features.shape[0] + return loss + + def forward(self, image_features, text_features, logit_scale, logit_bias, output_dict=False): + loss = self._loss(image_features, text_features, logit_scale, logit_bias) + + if self.world_size > 1: + # exchange text features w/ neighbour world_size - 1 times + right_rank = (self.rank + 1) % self.world_size + left_rank = (self.rank - 1 + self.world_size) % self.world_size + if self.bidir: + text_features_to_right = text_features_to_left = text_features + num_bidir, remainder = divmod(self.world_size - 1, 2) + for i in range(num_bidir): + text_features_recv = neighbour_exchange_bidir_with_grad( + left_rank, + right_rank, + text_features_to_left, + text_features_to_right, + ) + + for f in text_features_recv: + loss += self._loss( + image_features, + f, + logit_scale, + logit_bias, + negative_only=True, + ) + text_features_to_left, text_features_to_right = text_features_recv + + if remainder: + text_features_recv = neighbour_exchange_with_grad( + left_rank, right_rank, text_features_to_right) + + loss += self._loss( + image_features, + text_features_recv, + logit_scale, + logit_bias, + negative_only=True, + ) + else: + text_features_to_right = text_features + for i in range(self.world_size - 1): + text_features_from_left = neighbour_exchange_with_grad( + left_rank, right_rank, text_features_to_right) + + loss += self._loss( + image_features, + text_features_from_left, + logit_scale, + logit_bias, + negative_only=True, + ) + text_features_to_right = text_features_from_left + + return {"contrastive_loss": loss} if output_dict else loss diff --git a/open_clip_local/model.py b/open_clip_local/model.py new file mode 100644 index 0000000000000000000000000000000000000000..aa75a4b0dc6b562a000036d6cc7f33350a7e8bee --- /dev/null +++ b/open_clip_local/model.py @@ -0,0 +1,625 @@ +""" CLIP Model + +Adapted from https://github.com/openai/CLIP. Originally MIT License, Copyright (c) 2021 OpenAI. +""" +import copy +import logging +import math +from dataclasses import dataclass +from typing import Any, Dict, Optional, Tuple, Union + +import numpy as np +import torch +import torch.nn.functional as F +from torch import nn +from torch.utils.checkpoint import checkpoint +from functools import partial + +from .hf_model import HFTextEncoder +from .modified_resnet import ModifiedResNet +from .timm_model import TimmModel +from .transformer import LayerNormFp32, LayerNorm, QuickGELU, Attention, VisionTransformer, TextTransformer,\ + text_global_pool +from .utils import to_2tuple + + +@dataclass +class CLIPVisionCfg: + layers: Union[Tuple[int, int, int, int], int] = 12 + width: int = 768 + head_width: int = 64 + mlp_ratio: float = 4.0 + patch_size: int = 16 + image_size: Union[Tuple[int, int], int] = 224 + + ls_init_value: Optional[float] = None # layer scale initial value + patch_dropout: float = 0. # what fraction of patches to dropout during training (0 would mean disabled and no patches dropped) - 0.5 to 0.75 recommended in the paper for optimal results + attentional_pool: bool = False # whether to use attentional pooler in the last embedding layer (overrides pool_type) + attn_pooler_queries: int = 256 # n_queries for attentional pooler + attn_pooler_heads: int = 8 # n heads for attentional_pooling + no_ln_pre: bool = False # disable pre transformer LayerNorm + pos_embed_type: str = 'learnable' + final_ln_after_pool: bool = False # apply final LayerNorm after pooling + pool_type: str = 'tok' + output_tokens: bool = False + act_kwargs: Optional[dict] = None + norm_kwargs: Optional[dict] = None + + timm_model_name: Optional[str] = None # a valid model name overrides layers, width, patch_size + timm_model_pretrained: bool = False # use (imagenet) pretrained weights for named model + timm_pool: str = 'avg' # feature pooling for timm model ('abs_attn', 'rot_attn', 'avg', '') + timm_proj: str = 'linear' # linear projection for timm model output ('linear', 'mlp', '') + timm_proj_bias: bool = False # enable bias final projection + timm_drop: float = 0. # head dropout + timm_drop_path: Optional[float] = None # backbone stochastic depth + + +@dataclass +class CLIPTextCfg: + context_length: int = 77 + vocab_size: int = 49408 + hf_tokenizer_name: Optional[str] = None + tokenizer_kwargs: Optional[dict] = None + + width: int = 512 + heads: int = 8 + layers: int = 12 + mlp_ratio: float = 4.0 + ls_init_value: Optional[float] = None # layer scale initial value + embed_cls: bool = False + pad_id: int = 0 + no_causal_mask: bool = False # disable causal masking + final_ln_after_pool: bool = False # apply final LayerNorm after pooling + pool_type: str = 'argmax' + proj_bias: bool = False + output_tokens: bool = False + act_kwargs: dict = None + norm_kwargs: dict = None + + # HuggingFace specific text tower config + hf_model_name: Optional[str] = None + hf_model_pretrained: bool = True + hf_proj_type: str = 'mlp' + hf_pooler_type: str = 'mean_pooler' # attentional pooling for HF models + + +def get_cast_dtype(precision: str): + cast_dtype = None + if precision == 'bf16': + cast_dtype = torch.bfloat16 + elif precision == 'fp16': + cast_dtype = torch.float16 + return cast_dtype + + +def get_input_dtype(precision: str): + input_dtype = None + if precision in ('bf16', 'pure_bf16'): + input_dtype = torch.bfloat16 + elif precision in ('fp16', 'pure_fp16'): + input_dtype = torch.float16 + return input_dtype + + +def _build_vision_tower( + embed_dim: int, + vision_cfg: CLIPVisionCfg, + quick_gelu: bool = False, + cast_dtype: Optional[torch.dtype] = None +): + if isinstance(vision_cfg, dict): + vision_cfg = CLIPVisionCfg(**vision_cfg) + + # OpenAI models are pretrained w/ QuickGELU but native nn.GELU is both faster and more + # memory efficient in recent PyTorch releases (>= 1.10). + # NOTE: timm models always use native GELU regardless of quick_gelu flag. + act_layer = QuickGELU if quick_gelu else nn.GELU + + if vision_cfg.timm_model_name: + visual = TimmModel( + vision_cfg.timm_model_name, + pretrained=vision_cfg.timm_model_pretrained, + pool=vision_cfg.timm_pool, + proj=vision_cfg.timm_proj, + proj_bias=vision_cfg.timm_proj_bias, + drop=vision_cfg.timm_drop, + drop_path=vision_cfg.timm_drop_path, + patch_drop=vision_cfg.patch_dropout if vision_cfg.patch_dropout > 0 else None, + embed_dim=embed_dim, + image_size=vision_cfg.image_size, + ) + elif isinstance(vision_cfg.layers, (tuple, list)): + vision_heads = vision_cfg.width * 32 // vision_cfg.head_width + visual = ModifiedResNet( + layers=vision_cfg.layers, + output_dim=embed_dim, + heads=vision_heads, + image_size=vision_cfg.image_size, + width=vision_cfg.width, + ) + else: + vision_heads = vision_cfg.width // vision_cfg.head_width + norm_layer = LayerNormFp32 if cast_dtype in (torch.float16, torch.bfloat16) else LayerNorm + if vision_cfg.norm_kwargs: + norm_layer = partial(norm_layer, **vision_cfg.norm_kwargs) + if vision_cfg.act_kwargs is not None: + act_layer = partial(act_layer, **vision_cfg.act_kwargs) + + visual = VisionTransformer( + image_size=vision_cfg.image_size, + patch_size=vision_cfg.patch_size, + width=vision_cfg.width, + layers=vision_cfg.layers, + heads=vision_heads, + mlp_ratio=vision_cfg.mlp_ratio, + ls_init_value=vision_cfg.ls_init_value, + patch_dropout=vision_cfg.patch_dropout, + attentional_pool=vision_cfg.attentional_pool, + attn_pooler_queries=vision_cfg.attn_pooler_queries, + attn_pooler_heads=vision_cfg.attn_pooler_heads, + pos_embed_type=vision_cfg.pos_embed_type, + no_ln_pre=vision_cfg.no_ln_pre, + final_ln_after_pool=vision_cfg.final_ln_after_pool, + pool_type=vision_cfg.pool_type, + output_tokens=vision_cfg.output_tokens, + output_dim=embed_dim, + act_layer=act_layer, + norm_layer=norm_layer, + ) + + return visual + + +def _build_text_tower( + embed_dim: int, + text_cfg: CLIPTextCfg, + quick_gelu: bool = False, + cast_dtype: Optional[torch.dtype] = None, +): + if isinstance(text_cfg, dict): + text_cfg = CLIPTextCfg(**text_cfg) + + if text_cfg.hf_model_name: + text = HFTextEncoder( + text_cfg.hf_model_name, + output_dim=embed_dim, + proj_type=text_cfg.hf_proj_type, + pooler_type=text_cfg.hf_pooler_type, + pretrained=text_cfg.hf_model_pretrained, + output_tokens=text_cfg.output_tokens, + ) + else: + act_layer = QuickGELU if quick_gelu else nn.GELU + norm_layer = LayerNormFp32 if cast_dtype in (torch.float16, torch.bfloat16) else LayerNorm + if text_cfg.norm_kwargs: + norm_layer = partial(norm_layer, **text_cfg.norm_kwargs) + if text_cfg.act_kwargs is not None: + act_layer = partial(act_layer, **text_cfg.act_kwargs) + + text = TextTransformer( + context_length=text_cfg.context_length, + vocab_size=text_cfg.vocab_size, + width=text_cfg.width, + heads=text_cfg.heads, + layers=text_cfg.layers, + mlp_ratio=text_cfg.mlp_ratio, + ls_init_value=text_cfg.ls_init_value, + output_dim=embed_dim, + embed_cls=text_cfg.embed_cls, + no_causal_mask=text_cfg.no_causal_mask, + pad_id=text_cfg.pad_id, + pool_type=text_cfg.pool_type, + proj_bias=text_cfg.proj_bias, + output_tokens=text_cfg.output_tokens, + act_layer=act_layer, + norm_layer=norm_layer, + ) + return text + + +class CLIP(nn.Module): + output_dict: torch.jit.Final[bool] + + def __init__( + self, + embed_dim: int, + vision_cfg: CLIPVisionCfg, + text_cfg: CLIPTextCfg, + quick_gelu: bool = False, + init_logit_scale: float = np.log(1 / 0.07), + init_logit_bias: Optional[float] = None, + cast_dtype: Optional[torch.dtype] = None, + output_dict: bool = False, + ): + super().__init__() + self.output_dict = output_dict + + self.visual = _build_vision_tower(embed_dim, vision_cfg, quick_gelu, cast_dtype) + + text = _build_text_tower(embed_dim, text_cfg, quick_gelu, cast_dtype) + self.transformer = text.transformer + self.context_length = text.context_length + self.vocab_size = text.vocab_size + self.token_embedding = text.token_embedding + self.positional_embedding = text.positional_embedding + self.ln_final = text.ln_final + self.text_projection = text.text_projection + self.text_pool_type = text.pool_type + self.register_buffer('attn_mask', text.attn_mask, persistent=False) + + self.logit_scale = nn.Parameter(torch.ones([]) * init_logit_scale) + if init_logit_bias is not None: + self.logit_bias = nn.Parameter(torch.ones([]) * init_logit_bias) + else: + self.logit_bias = None + + def lock_image_tower(self, unlocked_groups=0, freeze_bn_stats=False): + # lock image tower as per LiT - https://arxiv.org/abs/2111.07991 + self.visual.lock(unlocked_groups=unlocked_groups, freeze_bn_stats=freeze_bn_stats) + + @torch.jit.ignore + def set_grad_checkpointing(self, enable=True): + self.visual.set_grad_checkpointing(enable) + self.transformer.grad_checkpointing = enable + + def encode_image(self, image, out_layers, normalize: bool = False): + features = self.visual(image, out_layers) + return F.normalize(features, dim=-1) if normalize else features + + def encode_text(self, text, normalize: bool = False): + cast_dtype = self.transformer.get_cast_dtype() + + x = self.token_embedding(text).to(cast_dtype) # [batch_size, n_ctx, d_model] + + x = x + self.positional_embedding.to(cast_dtype) + x = x.permute(1, 0, 2) # NLD -> LND + # x = self.transformer(x, attn_mask=self.attn_mask) + x, _, _ = self.transformer(x, attn_mask=self.attn_mask) + x = x.permute(1, 0, 2) # LND -> NLD + x = self.ln_final(x) # [batch_size, n_ctx, transformer.width] + x, _ = text_global_pool(x, text, self.text_pool_type) + if self.text_projection is not None: + if isinstance(self.text_projection, nn.Linear): + x = self.text_projection(x) + else: + x = x @ self.text_projection + + return F.normalize(x, dim=-1) if normalize else x + + def get_logits(self, image, text): + image_features = self.encode_image(image, normalize=True) + text_features = self.encode_text(text, normalize=True) + image_logits = self.logit_scale.exp() * image_features @ text_features.T + if self.logit_bias is not None: + image_logits += self.logit_bias + text_logits = image_logits.T + return image_logits, text_logits + + def forward( + self, + image: Optional[torch.Tensor] = None, + text: Optional[torch.Tensor] = None, + ): + image_features = self.encode_image(image, normalize=True) if image is not None else None + text_features = self.encode_text(text, normalize=True) if text is not None else None + + if self.output_dict: + out_dict = { + "image_features": image_features, + "text_features": text_features, + "logit_scale": self.logit_scale.exp() + } + if self.logit_bias is not None: + out_dict['logit_bias'] = self.logit_bias + return out_dict + + if self.logit_bias is not None: + return image_features, text_features, self.logit_scale.exp(), self.logit_bias + return image_features, text_features, self.logit_scale.exp() + + +class CustomTextCLIP(nn.Module): + output_dict: torch.jit.Final[bool] + + def __init__( + self, + embed_dim: int, + vision_cfg: CLIPVisionCfg, + text_cfg: CLIPTextCfg, + quick_gelu: bool = False, + init_logit_scale: float = np.log(1 / 0.07), + init_logit_bias: Optional[float] = None, + cast_dtype: Optional[torch.dtype] = None, + output_dict: bool = False, + ): + super().__init__() + self.output_dict = output_dict + self.visual = _build_vision_tower(embed_dim, vision_cfg, quick_gelu, cast_dtype) + self.text = _build_text_tower(embed_dim, text_cfg, quick_gelu, cast_dtype) + self.context_length = self.text.context_length + self.vocab_size = self.text.vocab_size + self.logit_scale = nn.Parameter(torch.ones([]) * init_logit_scale) + if init_logit_bias is not None: + self.logit_bias = nn.Parameter(torch.ones([]) * init_logit_bias) + else: + self.logit_bias = None + + def lock_image_tower(self, unlocked_groups=0, freeze_bn_stats=False): + # lock image tower as per LiT - https://arxiv.org/abs/2111.07991 + self.visual.lock(unlocked_groups=unlocked_groups, freeze_bn_stats=freeze_bn_stats) + + def lock_text_tower(self, unlocked_layers: int = 0, freeze_layer_norm: bool = True): + self.text.lock(unlocked_layers, freeze_layer_norm) + + @torch.jit.ignore + def set_grad_checkpointing(self, enable=True): + self.visual.set_grad_checkpointing(enable) + self.text.set_grad_checkpointing(enable) + + def encode_image(self, image, normalize: bool = False): + features = self.visual(image) + return F.normalize(features, dim=-1) if normalize else features + + def encode_text(self, text, normalize: bool = False): + features = self.text(text) + return F.normalize(features, dim=-1) if normalize else features + + def get_logits(self, image, text): + image_features = self.encode_image(image, normalize=True) + text_features = self.encode_text(text, normalize=True) + image_logits = self.logit_scale.exp() * image_features @ text_features.T + if self.logit_bias is not None: + image_logits += self.logit_bias + text_logits = image_logits.T + return image_logits, text_logits + + def forward( + self, + image: Optional[torch.Tensor] = None, + text: Optional[torch.Tensor] = None, + ): + image_features = self.encode_image(image, normalize=True) if image is not None else None + text_features = self.encode_text(text, normalize=True) if text is not None else None + + if self.output_dict: + out_dict = { + "image_features": image_features, + "text_features": text_features, + "logit_scale": self.logit_scale.exp() + } + if self.logit_bias is not None: + out_dict['logit_bias'] = self.logit_bias + return out_dict + + if self.logit_bias is not None: + return image_features, text_features, self.logit_scale.exp(), self.logit_bias + return image_features, text_features, self.logit_scale.exp() + + +def convert_weights_to_lp(model: nn.Module, dtype=torch.float16): + """Convert applicable model parameters to low-precision (bf16 or fp16)""" + + def _convert_weights(l): + if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)): + l.weight.data = l.weight.data.to(dtype) + if l.bias is not None: + l.bias.data = l.bias.data.to(dtype) + + if isinstance(l, (nn.MultiheadAttention, Attention)): + for attr in [*[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]], "in_proj_bias", "bias_k", "bias_v"]: + tensor = getattr(l, attr) + if tensor is not None: + tensor.data = tensor.data.to(dtype) + + if isinstance(l, (CLIP, TextTransformer)): + # convert text nn.Parameter projections + attr = getattr(l, "text_projection", None) + if attr is not None: + attr.data = attr.data.to(dtype) + + if isinstance(l, VisionTransformer): + # convert vision nn.Parameter projections + attr = getattr(l, "proj", None) + if attr is not None: + attr.data = attr.data.to(dtype) + + model.apply(_convert_weights) + + +convert_weights_to_fp16 = convert_weights_to_lp # backwards compat + + +# used to maintain checkpoint compatibility +def convert_to_custom_text_state_dict(state_dict: dict): + if 'text_projection' in state_dict: + # old format state_dict, move text tower -> .text + new_state_dict = {} + for k, v in state_dict.items(): + if any(k.startswith(p) for p in ( + 'text_projection', + 'positional_embedding', + 'token_embedding', + 'transformer', + 'ln_final', + )): + k = 'text.' + k + new_state_dict[k] = v + return new_state_dict + return state_dict + + +def build_model_from_openai_state_dict( + state_dict: dict, + quick_gelu=True, + cast_dtype=torch.float16, +): + vit = "visual.proj" in state_dict + + if vit: + vision_width = state_dict["visual.conv1.weight"].shape[0] + vision_layers = len( + [k for k in state_dict.keys() if k.startswith("visual.") and k.endswith(".attn.in_proj_weight")]) + vision_patch_size = state_dict["visual.conv1.weight"].shape[-1] + grid_size = round((state_dict["visual.positional_embedding"].shape[0] - 1) ** 0.5) + image_size = vision_patch_size * grid_size + else: + counts: list = [ + len(set(k.split(".")[2] for k in state_dict if k.startswith(f"visual.layer{b}"))) for b in [1, 2, 3, 4]] + vision_layers = tuple(counts) + vision_width = state_dict["visual.layer1.0.conv1.weight"].shape[0] + output_width = round((state_dict["visual.attnpool.positional_embedding"].shape[0] - 1) ** 0.5) + vision_patch_size = None + assert output_width ** 2 + 1 == state_dict["visual.attnpool.positional_embedding"].shape[0] + image_size = output_width * 32 + + embed_dim = state_dict["text_projection"].shape[1] + context_length = state_dict["positional_embedding"].shape[0] + vocab_size = state_dict["token_embedding.weight"].shape[0] + transformer_width = state_dict["ln_final.weight"].shape[0] + transformer_heads = transformer_width // 64 + transformer_layers = len(set(k.split(".")[2] for k in state_dict if k.startswith(f"transformer.resblocks"))) + + vision_cfg = CLIPVisionCfg( + layers=vision_layers, + width=vision_width, + patch_size=vision_patch_size, + image_size=image_size, + ) + text_cfg = CLIPTextCfg( + context_length=context_length, + vocab_size=vocab_size, + width=transformer_width, + heads=transformer_heads, + layers=transformer_layers, + ) + model = CLIP( + embed_dim, + vision_cfg=vision_cfg, + text_cfg=text_cfg, + quick_gelu=quick_gelu, # OpenAI models were trained with QuickGELU + cast_dtype=cast_dtype, + ) + + for key in ["input_resolution", "context_length", "vocab_size"]: + state_dict.pop(key, None) + convert_weights_to_fp16(model) # OpenAI state dicts are partially converted to float16 + model.load_state_dict(state_dict) + return model.eval() + + +def trace_model(model, batch_size=256, device=torch.device('cpu')): + model.eval() + image_size = model.visual.image_size + example_images = torch.ones((batch_size, 3, image_size, image_size), device=device) + example_text = torch.zeros((batch_size, model.context_length), dtype=torch.int, device=device) + model = torch.jit.trace_module( + model, + inputs=dict( + forward=(example_images, example_text), + encode_text=(example_text,), + encode_image=(example_images,) + )) + model.visual.image_size = image_size + return model + + +def resize_pos_embed(state_dict, model, interpolation: str = 'bicubic', antialias: bool = True): + # Rescale the grid of position embeddings when loading from state_dict + old_pos_embed = state_dict.get('visual.positional_embedding', None) + if old_pos_embed is None or not hasattr(model.visual, 'grid_size'): + return + grid_size = to_2tuple(model.visual.grid_size) + extra_tokens = 1 # FIXME detect different token configs (ie no class token, or more) + new_seq_len = grid_size[0] * grid_size[1] + extra_tokens + if new_seq_len == old_pos_embed.shape[0]: + return + + if extra_tokens: + pos_emb_tok, pos_emb_img = old_pos_embed[:extra_tokens], old_pos_embed[extra_tokens:] + else: + pos_emb_tok, pos_emb_img = None, old_pos_embed + old_grid_size = to_2tuple(int(math.sqrt(len(pos_emb_img)))) + + logging.info('Resizing position embedding grid-size from %s to %s', old_grid_size, grid_size) + pos_emb_img = pos_emb_img.reshape(1, old_grid_size[0], old_grid_size[1], -1).permute(0, 3, 1, 2) + pos_emb_img = F.interpolate( + pos_emb_img, + size=grid_size, + mode=interpolation, + antialias=antialias, + align_corners=False, + ) + pos_emb_img = pos_emb_img.permute(0, 2, 3, 1).reshape(1, grid_size[0] * grid_size[1], -1)[0] + if pos_emb_tok is not None: + new_pos_embed = torch.cat([pos_emb_tok, pos_emb_img], dim=0) + else: + new_pos_embed = pos_emb_img + state_dict['visual.positional_embedding'] = new_pos_embed + + +def resize_text_pos_embed(state_dict, model, interpolation: str = 'linear', antialias: bool = False): + old_pos_embed = state_dict.get('positional_embedding', None) + if old_pos_embed is None: + return + # FIXME add support for text cls_token + model_pos_embed = getattr(model, 'positional_embedding', None) + if model_pos_embed is None: + model_pos_embed = getattr(model.text, 'positional_embedding', None) + + old_num_pos = old_pos_embed.shape[0] + old_width = old_pos_embed.shape[1] + num_pos = model_pos_embed.shape[0] + width = model_pos_embed.shape[1] + assert old_width == width, 'text pos_embed width changed!' + if old_num_pos == num_pos: + return + + logging.info('Resizing text position embedding num_pos from %s to %s', old_num_pos, num_pos) + old_pos_embed = old_pos_embed.reshape(1, old_num_pos, old_width).permute(0, 2, 1) + old_pos_embed = F.interpolate( + old_pos_embed, + size=num_pos, + mode=interpolation, + antialias=antialias, + align_corners=False, + ) + old_pos_embed = old_pos_embed.permute(0, 2, 1)[0] + new_pos_embed = old_pos_embed + + state_dict['positional_embedding'] = new_pos_embed + + +def get_model_preprocess_cfg(model): + module = getattr(model, 'visual', model) + preprocess_cfg = getattr(module, 'preprocess_cfg', {}) + if not preprocess_cfg: + # use separate legacy attributes if preprocess_cfg dict not found + size = getattr(module, 'image_size') + if size is not None: + preprocess_cfg['size'] = size + mean = getattr(module, 'image_mean', None) + if mean is not None: + preprocess_cfg['mean'] = mean + std = getattr(module, 'image_std', None) + if std is not None: + preprocess_cfg['std'] = std + return preprocess_cfg + + +def set_model_preprocess_cfg(model, preprocess_cfg: Dict[str, Any]): + module = getattr(model, 'visual', model) + module.image_mean = preprocess_cfg['mean'] # legacy attribute, keeping for bwd compat + module.image_std = preprocess_cfg['std'] # legacy attribute, keeping for bwd compat + module.preprocess_cfg = copy.deepcopy(preprocess_cfg) # new attr, package all pp cfg as dict + + +def get_model_tokenize_cfg(model): + module = getattr(model, 'text', model) + cfg = {} + context_length = getattr(module, 'context_length', None) + if context_length is not None: + cfg['context_length'] = context_length + vocab_size = getattr(module, 'vocab_size', None) + if vocab_size is not None: + cfg['vocab_size'] = vocab_size + return cfg diff --git a/open_clip_local/model_configs/EVA01-g-14-plus.json b/open_clip_local/model_configs/EVA01-g-14-plus.json new file mode 100644 index 0000000000000000000000000000000000000000..73f46a71e664fce987218b8eb48903e7bd895f41 --- /dev/null +++ b/open_clip_local/model_configs/EVA01-g-14-plus.json @@ -0,0 +1,18 @@ +{ + "embed_dim": 1024, + "vision_cfg": { + "image_size": 224, + "timm_model_name": "eva_giant_patch14_224", + "timm_model_pretrained": false, + "timm_pool": "token", + "timm_proj": null + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 1024, + "heads": 16, + "layers": 24 + }, + "custom_text": true +} \ No newline at end of file diff --git a/open_clip_local/model_configs/EVA01-g-14.json b/open_clip_local/model_configs/EVA01-g-14.json new file mode 100644 index 0000000000000000000000000000000000000000..9d0e80f290d9491b7c46fafd576201b1258165aa --- /dev/null +++ b/open_clip_local/model_configs/EVA01-g-14.json @@ -0,0 +1,18 @@ +{ + "embed_dim": 1024, + "vision_cfg": { + "image_size": 224, + "timm_model_name": "eva_giant_patch14_224", + "timm_model_pretrained": false, + "timm_pool": "token", + "timm_proj": null + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 768, + "heads": 12, + "layers": 12 + }, + "custom_text": true +} \ No newline at end of file diff --git a/open_clip_local/model_configs/EVA02-B-16.json b/open_clip_local/model_configs/EVA02-B-16.json new file mode 100644 index 0000000000000000000000000000000000000000..3f92357287e1f6600da1e7f391cb6370d7f66de4 --- /dev/null +++ b/open_clip_local/model_configs/EVA02-B-16.json @@ -0,0 +1,18 @@ +{ + "embed_dim": 512, + "vision_cfg": { + "image_size": 224, + "timm_model_name": "eva02_base_patch16_clip_224", + "timm_model_pretrained": false, + "timm_pool": "token", + "timm_proj": null + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 512, + "heads": 8, + "layers": 12 + }, + "custom_text": true +} \ No newline at end of file diff --git a/open_clip_local/model_configs/EVA02-E-14-plus.json b/open_clip_local/model_configs/EVA02-E-14-plus.json new file mode 100644 index 0000000000000000000000000000000000000000..e250c2a404c86ff168c54cfcf71bc2492be1b74c --- /dev/null +++ b/open_clip_local/model_configs/EVA02-E-14-plus.json @@ -0,0 +1,18 @@ +{ + "embed_dim": 1024, + "vision_cfg": { + "image_size": 224, + "timm_model_name": "eva02_enormous_patch14_clip_224", + "timm_model_pretrained": false, + "timm_pool": "token", + "timm_proj": null + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 1280, + "heads": 20, + "layers": 32 + }, + "custom_text": true +} \ No newline at end of file diff --git a/open_clip_local/model_configs/EVA02-E-14.json b/open_clip_local/model_configs/EVA02-E-14.json new file mode 100644 index 0000000000000000000000000000000000000000..4b6648e25092b151a9095e0a66956c7ebf835b16 --- /dev/null +++ b/open_clip_local/model_configs/EVA02-E-14.json @@ -0,0 +1,18 @@ +{ + "embed_dim": 1024, + "vision_cfg": { + "image_size": 224, + "timm_model_name": "eva02_enormous_patch14_clip_224", + "timm_model_pretrained": false, + "timm_pool": "token", + "timm_proj": null + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 1024, + "heads": 16, + "layers": 24 + }, + "custom_text": true +} \ No newline at end of file diff --git a/open_clip_local/model_configs/EVA02-L-14-336.json b/open_clip_local/model_configs/EVA02-L-14-336.json new file mode 100644 index 0000000000000000000000000000000000000000..2bb07f3c082fd88c4e86131b272163aaacfaef9e --- /dev/null +++ b/open_clip_local/model_configs/EVA02-L-14-336.json @@ -0,0 +1,18 @@ +{ + "embed_dim": 768, + "vision_cfg": { + "image_size": 336, + "timm_model_name": "eva02_large_patch14_clip_336", + "timm_model_pretrained": false, + "timm_pool": "token", + "timm_proj": null + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 768, + "heads": 12, + "layers": 12 + }, + "custom_text": true +} \ No newline at end of file diff --git a/open_clip_local/model_configs/EVA02-L-14.json b/open_clip_local/model_configs/EVA02-L-14.json new file mode 100644 index 0000000000000000000000000000000000000000..b4c7f377bc543aa92a145358f2630a58ae9be989 --- /dev/null +++ b/open_clip_local/model_configs/EVA02-L-14.json @@ -0,0 +1,18 @@ +{ + "embed_dim": 768, + "vision_cfg": { + "image_size": 224, + "timm_model_name": "eva02_large_patch14_clip_224", + "timm_model_pretrained": false, + "timm_pool": "token", + "timm_proj": null + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 768, + "heads": 12, + "layers": 12 + }, + "custom_text": true +} \ No newline at end of file diff --git a/open_clip_local/model_configs/RN101-quickgelu.json b/open_clip_local/model_configs/RN101-quickgelu.json new file mode 100644 index 0000000000000000000000000000000000000000..d0db2c161d13138788c4609d373b023b8454d624 --- /dev/null +++ b/open_clip_local/model_configs/RN101-quickgelu.json @@ -0,0 +1,22 @@ +{ + "embed_dim": 512, + "quick_gelu": true, + "vision_cfg": { + "image_size": 224, + "layers": [ + 3, + 4, + 23, + 3 + ], + "width": 64, + "patch_size": null + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 512, + "heads": 8, + "layers": 12 + } +} \ No newline at end of file diff --git a/open_clip_local/model_configs/RN101.json b/open_clip_local/model_configs/RN101.json new file mode 100644 index 0000000000000000000000000000000000000000..b88b4d3acbaa701c614ab0ea65fc88fcfe289c32 --- /dev/null +++ b/open_clip_local/model_configs/RN101.json @@ -0,0 +1,21 @@ +{ + "embed_dim": 512, + "vision_cfg": { + "image_size": 224, + "layers": [ + 3, + 4, + 23, + 3 + ], + "width": 64, + "patch_size": null + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 512, + "heads": 8, + "layers": 12 + } +} \ No newline at end of file diff --git a/open_clip_local/model_configs/RN50-quickgelu.json b/open_clip_local/model_configs/RN50-quickgelu.json new file mode 100644 index 0000000000000000000000000000000000000000..8c2f91260cdeb043434dc1e893cce81d4ce7f0d1 --- /dev/null +++ b/open_clip_local/model_configs/RN50-quickgelu.json @@ -0,0 +1,22 @@ +{ + "embed_dim": 1024, + "quick_gelu": true, + "vision_cfg": { + "image_size": 224, + "layers": [ + 3, + 4, + 6, + 3 + ], + "width": 64, + "patch_size": null + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 512, + "heads": 8, + "layers": 12 + } +} diff --git a/open_clip_local/model_configs/RN50.json b/open_clip_local/model_configs/RN50.json new file mode 100644 index 0000000000000000000000000000000000000000..33aa884d54fee0076c33676831e49d5e1ffcb8f2 --- /dev/null +++ b/open_clip_local/model_configs/RN50.json @@ -0,0 +1,21 @@ +{ + "embed_dim": 1024, + "vision_cfg": { + "image_size": 224, + "layers": [ + 3, + 4, + 6, + 3 + ], + "width": 64, + "patch_size": null + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 512, + "heads": 8, + "layers": 12 + } +} \ No newline at end of file diff --git a/open_clip_local/model_configs/RN50x16.json b/open_clip_local/model_configs/RN50x16.json new file mode 100644 index 0000000000000000000000000000000000000000..3161e1a2c9a839161e652a4d729c2cdc971161db --- /dev/null +++ b/open_clip_local/model_configs/RN50x16.json @@ -0,0 +1,21 @@ +{ + "embed_dim": 768, + "vision_cfg": { + "image_size": 384, + "layers": [ + 6, + 8, + 18, + 8 + ], + "width": 96, + "patch_size": null + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 768, + "heads": 12, + "layers": 12 + } +} \ No newline at end of file diff --git a/open_clip_local/model_configs/RN50x4.json b/open_clip_local/model_configs/RN50x4.json new file mode 100644 index 0000000000000000000000000000000000000000..e155237f8ce1026aaaeecc80751eabe6f329f0bb --- /dev/null +++ b/open_clip_local/model_configs/RN50x4.json @@ -0,0 +1,21 @@ +{ + "embed_dim": 640, + "vision_cfg": { + "image_size": 288, + "layers": [ + 4, + 6, + 10, + 6 + ], + "width": 80, + "patch_size": null + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 640, + "heads": 10, + "layers": 12 + } +} \ No newline at end of file diff --git a/open_clip_local/model_configs/RN50x64.json b/open_clip_local/model_configs/RN50x64.json new file mode 100644 index 0000000000000000000000000000000000000000..f5aaa2ee3de21ddb03cbd12766a3419bf34898c7 --- /dev/null +++ b/open_clip_local/model_configs/RN50x64.json @@ -0,0 +1,21 @@ +{ + "embed_dim": 1024, + "vision_cfg": { + "image_size": 448, + "layers": [ + 3, + 15, + 36, + 10 + ], + "width": 128, + "patch_size": null + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 1024, + "heads": 16, + "layers": 12 + } +} \ No newline at end of file diff --git a/open_clip_local/model_configs/ViT-B-16-SigLIP-256.json b/open_clip_local/model_configs/ViT-B-16-SigLIP-256.json new file mode 100644 index 0000000000000000000000000000000000000000..d7ad3acba6bd37701ff8f19ca5f791c6342b73d6 --- /dev/null +++ b/open_clip_local/model_configs/ViT-B-16-SigLIP-256.json @@ -0,0 +1,29 @@ +{ + "embed_dim": 768, + "init_logit_bias": -10, + "custom_text": true, + "vision_cfg": { + "image_size": 256, + "timm_model_name": "vit_base_patch16_siglip_256", + "timm_model_pretrained": false, + "timm_pool": "map", + "timm_proj": "none" + }, + "text_cfg": { + "context_length": 64, + "vocab_size": 32000, + "hf_tokenizer_name": "timm/ViT-B-16-SigLIP", + "tokenizer_kwargs": { + "clean": "canonicalize" + }, + "width": 768, + "heads": 12, + "layers": 12, + "no_causal_mask": true, + "proj_bias": true, + "pool_type": "last", + "norm_kwargs":{ + "eps": 1e-6 + } + } +} \ No newline at end of file diff --git a/open_clip_local/model_configs/ViT-B-16-SigLIP-384.json b/open_clip_local/model_configs/ViT-B-16-SigLIP-384.json new file mode 100644 index 0000000000000000000000000000000000000000..df9a25cdca5207a8954801c0f2cf28514c15a1cd --- /dev/null +++ b/open_clip_local/model_configs/ViT-B-16-SigLIP-384.json @@ -0,0 +1,29 @@ +{ + "embed_dim": 768, + "init_logit_bias": -10, + "custom_text": true, + "vision_cfg": { + "image_size": 384, + "timm_model_name": "vit_base_patch16_siglip_384", + "timm_model_pretrained": false, + "timm_pool": "map", + "timm_proj": "none" + }, + "text_cfg": { + "context_length": 64, + "vocab_size": 32000, + "hf_tokenizer_name": "timm/ViT-B-16-SigLIP", + "tokenizer_kwargs": { + "clean": "canonicalize" + }, + "width": 768, + "heads": 12, + "layers": 12, + "no_causal_mask": true, + "proj_bias": true, + "pool_type": "last", + "norm_kwargs":{ + "eps": 1e-6 + } + } +} \ No newline at end of file diff --git a/open_clip_local/model_configs/ViT-B-16-SigLIP-512.json b/open_clip_local/model_configs/ViT-B-16-SigLIP-512.json new file mode 100644 index 0000000000000000000000000000000000000000..88b018528b2e7806cd11b95d5808136786ea0f97 --- /dev/null +++ b/open_clip_local/model_configs/ViT-B-16-SigLIP-512.json @@ -0,0 +1,29 @@ +{ + "embed_dim": 768, + "init_logit_bias": -10, + "custom_text": true, + "vision_cfg": { + "image_size": 512, + "timm_model_name": "vit_base_patch16_siglip_512", + "timm_model_pretrained": false, + "timm_pool": "map", + "timm_proj": "none" + }, + "text_cfg": { + "context_length": 64, + "vocab_size": 32000, + "hf_tokenizer_name": "timm/ViT-B-16-SigLIP", + "tokenizer_kwargs": { + "clean": "canonicalize" + }, + "width": 768, + "heads": 12, + "layers": 12, + "no_causal_mask": true, + "proj_bias": true, + "pool_type": "last", + "norm_kwargs":{ + "eps": 1e-6 + } + } +} \ No newline at end of file diff --git a/open_clip_local/model_configs/ViT-B-16-SigLIP-i18n-256.json b/open_clip_local/model_configs/ViT-B-16-SigLIP-i18n-256.json new file mode 100644 index 0000000000000000000000000000000000000000..7a28797a7e1487af986540872447a68da0dd69b2 --- /dev/null +++ b/open_clip_local/model_configs/ViT-B-16-SigLIP-i18n-256.json @@ -0,0 +1,29 @@ +{ + "embed_dim": 768, + "init_logit_bias": -10, + "custom_text": true, + "vision_cfg": { + "image_size": 256, + "timm_model_name": "vit_base_patch16_siglip_256", + "timm_model_pretrained": false, + "timm_pool": "map", + "timm_proj": "none" + }, + "text_cfg": { + "context_length": 64, + "vocab_size": 250000, + "hf_tokenizer_name": "timm/ViT-B-16-SigLIP-i18n-256", + "tokenizer_kwargs": { + "clean": "canonicalize" + }, + "width": 768, + "heads": 12, + "layers": 12, + "no_causal_mask": true, + "proj_bias": true, + "pool_type": "last", + "norm_kwargs":{ + "eps": 1e-6 + } + } +} \ No newline at end of file diff --git a/open_clip_local/model_configs/ViT-B-16-SigLIP.json b/open_clip_local/model_configs/ViT-B-16-SigLIP.json new file mode 100644 index 0000000000000000000000000000000000000000..a9f2b654a671c9bd235f351b2a253ca889758549 --- /dev/null +++ b/open_clip_local/model_configs/ViT-B-16-SigLIP.json @@ -0,0 +1,29 @@ +{ + "embed_dim": 768, + "init_logit_bias": -10, + "custom_text": true, + "vision_cfg": { + "image_size": 224, + "timm_model_name": "vit_base_patch16_siglip_224", + "timm_model_pretrained": false, + "timm_pool": "map", + "timm_proj": "none" + }, + "text_cfg": { + "context_length": 64, + "vocab_size": 32000, + "hf_tokenizer_name": "timm/ViT-B-16-SigLIP", + "tokenizer_kwargs": { + "clean": "canonicalize" + }, + "width": 768, + "heads": 12, + "layers": 12, + "no_causal_mask": true, + "proj_bias": true, + "pool_type": "last", + "norm_kwargs":{ + "eps": 1e-6 + } + } +} \ No newline at end of file diff --git a/open_clip_local/model_configs/ViT-B-16-plus-240.json b/open_clip_local/model_configs/ViT-B-16-plus-240.json new file mode 100644 index 0000000000000000000000000000000000000000..5bbd12bcd01f64d6d0a0aa8316b129327a0d169a --- /dev/null +++ b/open_clip_local/model_configs/ViT-B-16-plus-240.json @@ -0,0 +1,16 @@ +{ + "embed_dim": 640, + "vision_cfg": { + "image_size": 240, + "layers": 12, + "width": 896, + "patch_size": 16 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 640, + "heads": 10, + "layers": 12 + } +} \ No newline at end of file diff --git a/open_clip_local/model_configs/ViT-B-16-plus.json b/open_clip_local/model_configs/ViT-B-16-plus.json new file mode 100644 index 0000000000000000000000000000000000000000..5dc1e09baccef2b15055c1bffeb9903e760101c6 --- /dev/null +++ b/open_clip_local/model_configs/ViT-B-16-plus.json @@ -0,0 +1,16 @@ +{ + "embed_dim": 640, + "vision_cfg": { + "image_size": 224, + "layers": 12, + "width": 896, + "patch_size": 16 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 640, + "heads": 10, + "layers": 12 + } +} \ No newline at end of file diff --git a/open_clip_local/model_configs/ViT-B-16-quickgelu.json b/open_clip_local/model_configs/ViT-B-16-quickgelu.json new file mode 100644 index 0000000000000000000000000000000000000000..ff5431ea3065d18094de94d3c87d8814d3f651fe --- /dev/null +++ b/open_clip_local/model_configs/ViT-B-16-quickgelu.json @@ -0,0 +1,17 @@ +{ + "embed_dim": 512, + "quick_gelu": true, + "vision_cfg": { + "image_size": 224, + "layers": 12, + "width": 768, + "patch_size": 16 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 512, + "heads": 8, + "layers": 12 + } +} \ No newline at end of file diff --git a/open_clip_local/model_configs/ViT-B-16.json b/open_clip_local/model_configs/ViT-B-16.json new file mode 100644 index 0000000000000000000000000000000000000000..395eea77ec3907c0611531aba63459b193e67b9c --- /dev/null +++ b/open_clip_local/model_configs/ViT-B-16.json @@ -0,0 +1,16 @@ +{ + "embed_dim": 512, + "vision_cfg": { + "image_size": 224, + "layers": 12, + "width": 768, + "patch_size": 16 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 512, + "heads": 8, + "layers": 12 + } +} \ No newline at end of file diff --git a/open_clip_local/model_configs/ViT-B-32-256.json b/open_clip_local/model_configs/ViT-B-32-256.json new file mode 100644 index 0000000000000000000000000000000000000000..80a2597d8f7d5d500df2aacbded9507196dad6da --- /dev/null +++ b/open_clip_local/model_configs/ViT-B-32-256.json @@ -0,0 +1,16 @@ +{ + "embed_dim": 512, + "vision_cfg": { + "image_size": 256, + "layers": 12, + "width": 768, + "patch_size": 32 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 512, + "heads": 8, + "layers": 12 + } +} diff --git a/open_clip_local/model_configs/ViT-B-32-plus-256.json b/open_clip_local/model_configs/ViT-B-32-plus-256.json new file mode 100644 index 0000000000000000000000000000000000000000..2f09c857de9a4c01ae51297a7e2451984879f9de --- /dev/null +++ b/open_clip_local/model_configs/ViT-B-32-plus-256.json @@ -0,0 +1,16 @@ +{ + "embed_dim": 640, + "vision_cfg": { + "image_size": 256, + "layers": 12, + "width": 896, + "patch_size": 32 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 640, + "heads": 10, + "layers": 12 + } +} \ No newline at end of file diff --git a/open_clip_local/model_configs/ViT-B-32-quickgelu.json b/open_clip_local/model_configs/ViT-B-32-quickgelu.json new file mode 100644 index 0000000000000000000000000000000000000000..ce6bd923593293ed50dfcfb28b73ca7403bcf3c5 --- /dev/null +++ b/open_clip_local/model_configs/ViT-B-32-quickgelu.json @@ -0,0 +1,17 @@ +{ + "embed_dim": 512, + "quick_gelu": true, + "vision_cfg": { + "image_size": 224, + "layers": 12, + "width": 768, + "patch_size": 32 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 512, + "heads": 8, + "layers": 12 + } +} \ No newline at end of file diff --git a/open_clip_local/model_configs/ViT-B-32.json b/open_clip_local/model_configs/ViT-B-32.json new file mode 100644 index 0000000000000000000000000000000000000000..07c8e28eb06fa1813ba932fe4eec668262d1c47f --- /dev/null +++ b/open_clip_local/model_configs/ViT-B-32.json @@ -0,0 +1,16 @@ +{ + "embed_dim": 512, + "vision_cfg": { + "image_size": 224, + "layers": 12, + "width": 768, + "patch_size": 32 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 512, + "heads": 8, + "layers": 12 + } +} \ No newline at end of file diff --git a/open_clip_local/model_configs/ViT-H-14-378-quickgelu.json b/open_clip_local/model_configs/ViT-H-14-378-quickgelu.json new file mode 100644 index 0000000000000000000000000000000000000000..e2b2ecf9ae278eeb4f6b20d16e17a6523f961580 --- /dev/null +++ b/open_clip_local/model_configs/ViT-H-14-378-quickgelu.json @@ -0,0 +1,18 @@ +{ + "embed_dim": 1024, + "quick_gelu": true, + "vision_cfg": { + "image_size": 378, + "layers": 32, + "width": 1280, + "head_width": 80, + "patch_size": 14 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 1024, + "heads": 16, + "layers": 24 + } +} \ No newline at end of file diff --git a/open_clip_local/model_configs/ViT-H-14-CLIPA-336.json b/open_clip_local/model_configs/ViT-H-14-CLIPA-336.json new file mode 100644 index 0000000000000000000000000000000000000000..01fabb29db2bcbd9513e903064d61e3e1974d580 --- /dev/null +++ b/open_clip_local/model_configs/ViT-H-14-CLIPA-336.json @@ -0,0 +1,26 @@ +{ + "embed_dim": 1024, + "vision_cfg": { + "image_size": 336, + "layers": 32, + "width": 1280, + "head_width": 80, + "patch_size": 14, + "no_ln_pre": true, + "pool_type": "avg", + "final_ln_after_pool": true + }, + "text_cfg": { + "context_length": 32, + "vocab_size": 32000, + "hf_tokenizer_name": "bert-base-uncased", + "tokenizer_kwargs": { + "strip_sep_token": true + }, + "width": 1024, + "heads": 16, + "layers": 24, + "pool_type": "last", + "no_causal_mask": true + } +} \ No newline at end of file diff --git a/open_clip_local/model_configs/ViT-H-14-CLIPA.json b/open_clip_local/model_configs/ViT-H-14-CLIPA.json new file mode 100644 index 0000000000000000000000000000000000000000..7df0338844bfff4d30f3ca08711311f645dda866 --- /dev/null +++ b/open_clip_local/model_configs/ViT-H-14-CLIPA.json @@ -0,0 +1,26 @@ +{ + "embed_dim": 1024, + "vision_cfg": { + "image_size": 224, + "layers": 32, + "width": 1280, + "head_width": 80, + "patch_size": 14, + "no_ln_pre": true, + "pool_type": "avg", + "final_ln_after_pool": true + }, + "text_cfg": { + "context_length": 32, + "vocab_size": 32000, + "hf_tokenizer_name": "bert-base-uncased", + "tokenizer_kwargs": { + "strip_sep_token": true + }, + "width": 1024, + "heads": 16, + "layers": 24, + "pool_type": "last", + "no_causal_mask": true + } +} \ No newline at end of file diff --git a/open_clip_local/model_configs/ViT-H-14-quickgelu.json b/open_clip_local/model_configs/ViT-H-14-quickgelu.json new file mode 100644 index 0000000000000000000000000000000000000000..41f22f65bb002c320111790e0cd0f2425a575df7 --- /dev/null +++ b/open_clip_local/model_configs/ViT-H-14-quickgelu.json @@ -0,0 +1,18 @@ +{ + "embed_dim": 1024, + "quick_gelu": true, + "vision_cfg": { + "image_size": 224, + "layers": 32, + "width": 1280, + "head_width": 80, + "patch_size": 14 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 1024, + "heads": 16, + "layers": 24 + } +} \ No newline at end of file diff --git a/open_clip_local/model_configs/ViT-H-14.json b/open_clip_local/model_configs/ViT-H-14.json new file mode 100644 index 0000000000000000000000000000000000000000..3e3a7e934e7f02e41f4829996c4950e05f015a74 --- /dev/null +++ b/open_clip_local/model_configs/ViT-H-14.json @@ -0,0 +1,17 @@ +{ + "embed_dim": 1024, + "vision_cfg": { + "image_size": 224, + "layers": 32, + "width": 1280, + "head_width": 80, + "patch_size": 14 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 1024, + "heads": 16, + "layers": 24 + } +} \ No newline at end of file diff --git a/open_clip_local/model_configs/ViT-H-16.json b/open_clip_local/model_configs/ViT-H-16.json new file mode 100644 index 0000000000000000000000000000000000000000..588485455fdf8193ec16474450b94e31c91ea93c --- /dev/null +++ b/open_clip_local/model_configs/ViT-H-16.json @@ -0,0 +1,17 @@ +{ + "embed_dim": 1024, + "vision_cfg": { + "image_size": 224, + "layers": 32, + "width": 1280, + "head_width": 80, + "patch_size": 16 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 1024, + "heads": 16, + "layers": 24 + } +} \ No newline at end of file diff --git a/open_clip_local/model_configs/ViT-L-14-280.json b/open_clip_local/model_configs/ViT-L-14-280.json new file mode 100644 index 0000000000000000000000000000000000000000..2262deaefa82792d35d73c0d7c8e620525092581 --- /dev/null +++ b/open_clip_local/model_configs/ViT-L-14-280.json @@ -0,0 +1,16 @@ +{ + "embed_dim": 768, + "vision_cfg": { + "image_size": 280, + "layers": 24, + "width": 1024, + "patch_size": 14 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 768, + "heads": 12, + "layers": 12 + } +} \ No newline at end of file diff --git a/open_clip_local/model_configs/ViT-L-14-336.json b/open_clip_local/model_configs/ViT-L-14-336.json new file mode 100644 index 0000000000000000000000000000000000000000..8d1f74c2639c3a3705df9865b9c08215675ddc97 --- /dev/null +++ b/open_clip_local/model_configs/ViT-L-14-336.json @@ -0,0 +1,16 @@ +{ + "embed_dim": 768, + "vision_cfg": { + "image_size": 336, + "layers": 24, + "width": 1024, + "patch_size": 14 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 768, + "heads": 12, + "layers": 12 + } +} \ No newline at end of file diff --git a/open_clip_local/model_configs/ViT-L-14-CLIPA-336.json b/open_clip_local/model_configs/ViT-L-14-CLIPA-336.json new file mode 100644 index 0000000000000000000000000000000000000000..60a4df589b9e9ed269807204ec9788e613026382 --- /dev/null +++ b/open_clip_local/model_configs/ViT-L-14-CLIPA-336.json @@ -0,0 +1,25 @@ +{ + "embed_dim": 768, + "vision_cfg": { + "image_size": 336, + "layers": 24, + "width": 1024, + "patch_size": 14, + "no_ln_pre": true, + "pool_type": "avg", + "final_ln_after_pool": true + }, + "text_cfg": { + "context_length": 32, + "vocab_size": 32000, + "hf_tokenizer_name": "bert-base-uncased", + "tokenizer_kwargs": { + "strip_sep_token": true + }, + "width": 768, + "heads": 12, + "layers": 12, + "pool_type": "last", + "no_causal_mask": true + } +} \ No newline at end of file diff --git a/open_clip_local/model_configs/ViT-L-14-CLIPA.json b/open_clip_local/model_configs/ViT-L-14-CLIPA.json new file mode 100644 index 0000000000000000000000000000000000000000..b4dde7b546b6c53d5c55f2abe50b599ff2519964 --- /dev/null +++ b/open_clip_local/model_configs/ViT-L-14-CLIPA.json @@ -0,0 +1,25 @@ +{ + "embed_dim": 768, + "vision_cfg": { + "image_size": 224, + "layers": 24, + "width": 1024, + "patch_size": 14, + "no_ln_pre": true, + "pool_type": "avg", + "final_ln_after_pool": true + }, + "text_cfg": { + "context_length": 32, + "vocab_size": 32000, + "hf_tokenizer_name": "bert-base-uncased", + "tokenizer_kwargs": { + "strip_sep_token": true + }, + "width": 768, + "heads": 12, + "layers": 12, + "pool_type": "last", + "no_causal_mask": true + } +} \ No newline at end of file diff --git a/open_clip_local/model_configs/ViT-L-14-quickgelu.json b/open_clip_local/model_configs/ViT-L-14-quickgelu.json new file mode 100644 index 0000000000000000000000000000000000000000..d5a3fd36aa9cd9cc4a3dc29e362945cec13a02f3 --- /dev/null +++ b/open_clip_local/model_configs/ViT-L-14-quickgelu.json @@ -0,0 +1,17 @@ +{ + "embed_dim": 768, + "quick_gelu": true, + "vision_cfg": { + "image_size": 224, + "layers": 24, + "width": 1024, + "patch_size": 14 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 768, + "heads": 12, + "layers": 12 + } +} \ No newline at end of file diff --git a/open_clip_local/model_configs/ViT-L-14.json b/open_clip_local/model_configs/ViT-L-14.json new file mode 100644 index 0000000000000000000000000000000000000000..d4a4bbb1dd4ed4edb317d3ace4f3ad13b211c241 --- /dev/null +++ b/open_clip_local/model_configs/ViT-L-14.json @@ -0,0 +1,16 @@ +{ + "embed_dim": 768, + "vision_cfg": { + "image_size": 224, + "layers": 24, + "width": 1024, + "patch_size": 14 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 768, + "heads": 12, + "layers": 12 + } +} \ No newline at end of file diff --git a/open_clip_local/model_configs/ViT-L-16-320.json b/open_clip_local/model_configs/ViT-L-16-320.json new file mode 100644 index 0000000000000000000000000000000000000000..fc2d13ca9ec7f0b56a886ddaf66c4a7ba7a442ba --- /dev/null +++ b/open_clip_local/model_configs/ViT-L-16-320.json @@ -0,0 +1,16 @@ +{ + "embed_dim": 768, + "vision_cfg": { + "image_size": 320, + "layers": 24, + "width": 1024, + "patch_size": 16 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 768, + "heads": 12, + "layers": 12 + } +} \ No newline at end of file diff --git a/open_clip_local/model_configs/ViT-L-16-SigLIP-256.json b/open_clip_local/model_configs/ViT-L-16-SigLIP-256.json new file mode 100644 index 0000000000000000000000000000000000000000..5ba8f7abb68e5a798d38f976a828c63f74b94ae8 --- /dev/null +++ b/open_clip_local/model_configs/ViT-L-16-SigLIP-256.json @@ -0,0 +1,29 @@ +{ + "embed_dim": 1024, + "init_logit_bias": -10, + "custom_text": true, + "vision_cfg": { + "image_size": 256, + "timm_model_name": "vit_large_patch16_siglip_256", + "timm_model_pretrained": false, + "timm_pool": "map", + "timm_proj": "none" + }, + "text_cfg": { + "context_length": 64, + "vocab_size": 32000, + "hf_tokenizer_name": "timm/ViT-B-16-SigLIP", + "tokenizer_kwargs": { + "clean": "canonicalize" + }, + "width": 1024, + "heads": 16, + "layers": 24, + "no_causal_mask": true, + "proj_bias": true, + "pool_type": "last", + "norm_kwargs":{ + "eps": 1e-6 + } + } +} \ No newline at end of file diff --git a/open_clip_local/model_configs/ViT-L-16-SigLIP-384.json b/open_clip_local/model_configs/ViT-L-16-SigLIP-384.json new file mode 100644 index 0000000000000000000000000000000000000000..fd2cc2e346f7110a5de01cfaf7eae8c94360de3a --- /dev/null +++ b/open_clip_local/model_configs/ViT-L-16-SigLIP-384.json @@ -0,0 +1,29 @@ +{ + "embed_dim": 1024, + "init_logit_bias": -10, + "custom_text": true, + "vision_cfg": { + "image_size": 384, + "timm_model_name": "vit_large_patch16_siglip_384", + "timm_model_pretrained": false, + "timm_pool": "map", + "timm_proj": "none" + }, + "text_cfg": { + "context_length": 64, + "vocab_size": 32000, + "hf_tokenizer_name": "timm/ViT-B-16-SigLIP", + "tokenizer_kwargs": { + "clean": "canonicalize" + }, + "width": 1024, + "heads": 16, + "layers": 24, + "no_causal_mask": true, + "proj_bias": true, + "pool_type": "last", + "norm_kwargs":{ + "eps": 1e-6 + } + } +} \ No newline at end of file diff --git a/open_clip_local/model_configs/ViT-L-16.json b/open_clip_local/model_configs/ViT-L-16.json new file mode 100644 index 0000000000000000000000000000000000000000..82a1cedfa290adacbbdc02bc5d589734c22d41d3 --- /dev/null +++ b/open_clip_local/model_configs/ViT-L-16.json @@ -0,0 +1,16 @@ +{ + "embed_dim": 768, + "vision_cfg": { + "image_size": 224, + "layers": 24, + "width": 1024, + "patch_size": 16 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 768, + "heads": 12, + "layers": 12 + } +} \ No newline at end of file diff --git a/open_clip_local/model_configs/ViT-M-16-alt.json b/open_clip_local/model_configs/ViT-M-16-alt.json new file mode 100644 index 0000000000000000000000000000000000000000..1a317aad8e02d9c26d2decc7cc49a18dfdf9e0d8 --- /dev/null +++ b/open_clip_local/model_configs/ViT-M-16-alt.json @@ -0,0 +1,17 @@ +{ + "embed_dim": 384, + "vision_cfg": { + "image_size": 224, + "layers": 12, + "width": 512, + "patch_size": 16, + "ls_init_value": 1e-4 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 384, + "heads": 6, + "layers": 12 + } +} \ No newline at end of file diff --git a/open_clip_local/model_configs/ViT-M-16.json b/open_clip_local/model_configs/ViT-M-16.json new file mode 100644 index 0000000000000000000000000000000000000000..f2f3225a46e09237730a151d161f70c86b985172 --- /dev/null +++ b/open_clip_local/model_configs/ViT-M-16.json @@ -0,0 +1,16 @@ +{ + "embed_dim": 512, + "vision_cfg": { + "image_size": 224, + "layers": 12, + "width": 512, + "patch_size": 16 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 512, + "heads": 8, + "layers": 12 + } +} \ No newline at end of file diff --git a/open_clip_local/model_configs/ViT-M-32-alt.json b/open_clip_local/model_configs/ViT-M-32-alt.json new file mode 100644 index 0000000000000000000000000000000000000000..fd222aeac0f582ef6a1a33f1b3fec70a5b386ac0 --- /dev/null +++ b/open_clip_local/model_configs/ViT-M-32-alt.json @@ -0,0 +1,16 @@ +{ + "embed_dim": 384, + "vision_cfg": { + "image_size": 224, + "layers": 12, + "width": 512, + "patch_size": 32 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 384, + "heads": 6, + "layers": 12 + } +} \ No newline at end of file diff --git a/open_clip_local/model_configs/ViT-M-32.json b/open_clip_local/model_configs/ViT-M-32.json new file mode 100644 index 0000000000000000000000000000000000000000..4f718642821035d9776d1e006817d65ede074366 --- /dev/null +++ b/open_clip_local/model_configs/ViT-M-32.json @@ -0,0 +1,16 @@ +{ + "embed_dim": 512, + "vision_cfg": { + "image_size": 224, + "layers": 12, + "width": 512, + "patch_size": 32 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 512, + "heads": 8, + "layers": 12 + } +} \ No newline at end of file diff --git a/open_clip_local/model_configs/ViT-S-16-alt.json b/open_clip_local/model_configs/ViT-S-16-alt.json new file mode 100644 index 0000000000000000000000000000000000000000..a8c056555e4da3ba0d1475a61fc316362ecce76f --- /dev/null +++ b/open_clip_local/model_configs/ViT-S-16-alt.json @@ -0,0 +1,16 @@ +{ + "embed_dim": 256, + "vision_cfg": { + "image_size": 224, + "layers": 12, + "width": 384, + "patch_size": 16 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 256, + "heads": 4, + "layers": 10 + } +} \ No newline at end of file diff --git a/open_clip_local/model_configs/ViT-S-16.json b/open_clip_local/model_configs/ViT-S-16.json new file mode 100644 index 0000000000000000000000000000000000000000..1d8504e59658803f3093e5b05de45f30a09b8185 --- /dev/null +++ b/open_clip_local/model_configs/ViT-S-16.json @@ -0,0 +1,16 @@ +{ + "embed_dim": 384, + "vision_cfg": { + "image_size": 224, + "layers": 12, + "width": 384, + "patch_size": 16 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 384, + "heads": 6, + "layers": 12 + } +} \ No newline at end of file diff --git a/open_clip_local/model_configs/ViT-S-32-alt.json b/open_clip_local/model_configs/ViT-S-32-alt.json new file mode 100644 index 0000000000000000000000000000000000000000..e1dfdec9824df09a2010e991ccfa1d9ee2f45807 --- /dev/null +++ b/open_clip_local/model_configs/ViT-S-32-alt.json @@ -0,0 +1,16 @@ +{ + "embed_dim": 256, + "vision_cfg": { + "image_size": 224, + "layers": 12, + "width": 384, + "patch_size": 32 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 256, + "heads": 4, + "layers": 10 + } +} \ No newline at end of file diff --git a/open_clip_local/model_configs/ViT-S-32.json b/open_clip_local/model_configs/ViT-S-32.json new file mode 100644 index 0000000000000000000000000000000000000000..9b8b4191b268de267268cfcb90fc01c6b9df07d8 --- /dev/null +++ b/open_clip_local/model_configs/ViT-S-32.json @@ -0,0 +1,16 @@ +{ + "embed_dim": 384, + "vision_cfg": { + "image_size": 224, + "layers": 12, + "width": 384, + "patch_size": 32 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 384, + "heads": 6, + "layers": 12 + } +} \ No newline at end of file diff --git a/open_clip_local/model_configs/ViT-SO400M-14-SigLIP-384.json b/open_clip_local/model_configs/ViT-SO400M-14-SigLIP-384.json new file mode 100644 index 0000000000000000000000000000000000000000..4c527f581230938d7b39baf36b6bd749b0e7f169 --- /dev/null +++ b/open_clip_local/model_configs/ViT-SO400M-14-SigLIP-384.json @@ -0,0 +1,30 @@ +{ + "embed_dim": 1152, + "init_logit_bias": -10, + "custom_text": true, + "vision_cfg": { + "image_size": 384, + "timm_model_name": "vit_so400m_patch14_siglip_384", + "timm_model_pretrained": false, + "timm_pool": "map", + "timm_proj": "none" + }, + "text_cfg": { + "context_length": 64, + "vocab_size": 32000, + "hf_tokenizer_name": "timm/ViT-B-16-SigLIP", + "tokenizer_kwargs": { + "clean": "canonicalize" + }, + "width": 1152, + "heads": 16, + "layers": 27, + "mlp_ratio": 3.7362, + "no_causal_mask": true, + "proj_bias": true, + "pool_type": "last", + "norm_kwargs":{ + "eps": 1e-6 + } + } +} \ No newline at end of file diff --git a/open_clip_local/model_configs/ViT-SO400M-14-SigLIP.json b/open_clip_local/model_configs/ViT-SO400M-14-SigLIP.json new file mode 100644 index 0000000000000000000000000000000000000000..564eb78a49c8ff31cac047277b9344bbe85fef40 --- /dev/null +++ b/open_clip_local/model_configs/ViT-SO400M-14-SigLIP.json @@ -0,0 +1,30 @@ +{ + "embed_dim": 1152, + "init_logit_bias": -10, + "custom_text": true, + "vision_cfg": { + "image_size": 224, + "timm_model_name": "vit_so400m_patch14_siglip_224", + "timm_model_pretrained": false, + "timm_pool": "map", + "timm_proj": "none" + }, + "text_cfg": { + "context_length": 16, + "vocab_size": 32000, + "hf_tokenizer_name": "timm/ViT-B-16-SigLIP", + "tokenizer_kwargs": { + "clean": "canonicalize" + }, + "width": 1152, + "heads": 16, + "layers": 27, + "mlp_ratio": 3.7362, + "no_causal_mask": true, + "proj_bias": true, + "pool_type": "last", + "norm_kwargs":{ + "eps": 1e-6 + } + } +} \ No newline at end of file diff --git a/open_clip_local/model_configs/ViT-bigG-14-CLIPA-336.json b/open_clip_local/model_configs/ViT-bigG-14-CLIPA-336.json new file mode 100644 index 0000000000000000000000000000000000000000..75ba7675c643cd482f06886e58ded6fb934233fc --- /dev/null +++ b/open_clip_local/model_configs/ViT-bigG-14-CLIPA-336.json @@ -0,0 +1,27 @@ +{ + "embed_dim": 1280, + "vision_cfg": { + "image_size": 336, + "layers": 48, + "width": 1664, + "head_width": 104, + "mlp_ratio": 4.9231, + "patch_size": 14, + "no_ln_pre": true, + "pool_type": "avg", + "final_ln_after_pool": true + }, + "text_cfg": { + "context_length": 32, + "vocab_size": 32000, + "hf_tokenizer_name": "bert-base-uncased", + "tokenizer_kwargs": { + "strip_sep_token": true + }, + "width": 1280, + "heads": 20, + "layers": 32, + "pool_type": "last", + "no_causal_mask": true + } +} \ No newline at end of file diff --git a/open_clip_local/model_configs/ViT-bigG-14-CLIPA.json b/open_clip_local/model_configs/ViT-bigG-14-CLIPA.json new file mode 100644 index 0000000000000000000000000000000000000000..83ec709f8b8362d892067adafde9a0d78ce4db14 --- /dev/null +++ b/open_clip_local/model_configs/ViT-bigG-14-CLIPA.json @@ -0,0 +1,27 @@ +{ + "embed_dim": 1280, + "vision_cfg": { + "image_size": 224, + "layers": 48, + "width": 1664, + "head_width": 104, + "mlp_ratio": 4.9231, + "patch_size": 14, + "no_ln_pre": true, + "pool_type": "avg", + "final_ln_after_pool": true + }, + "text_cfg": { + "context_length": 32, + "vocab_size": 32000, + "hf_tokenizer_name": "bert-base-uncased", + "tokenizer_kwargs": { + "strip_sep_token": true + }, + "width": 1280, + "heads": 20, + "layers": 32, + "pool_type": "last", + "no_causal_mask": true + } +} \ No newline at end of file diff --git a/open_clip_local/model_configs/ViT-bigG-14.json b/open_clip_local/model_configs/ViT-bigG-14.json new file mode 100644 index 0000000000000000000000000000000000000000..2cfba479a2e8f3737e71ce240732bf3bc743d8b7 --- /dev/null +++ b/open_clip_local/model_configs/ViT-bigG-14.json @@ -0,0 +1,18 @@ +{ + "embed_dim": 1280, + "vision_cfg": { + "image_size": 224, + "layers": 48, + "width": 1664, + "head_width": 104, + "mlp_ratio": 4.9231, + "patch_size": 14 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 1280, + "heads": 20, + "layers": 32 + } +} \ No newline at end of file diff --git a/open_clip_local/model_configs/ViT-e-14.json b/open_clip_local/model_configs/ViT-e-14.json new file mode 100644 index 0000000000000000000000000000000000000000..91a0fe14d25a107fb8ec48dd7faae313fd26ed7b --- /dev/null +++ b/open_clip_local/model_configs/ViT-e-14.json @@ -0,0 +1,18 @@ +{ + "embed_dim": 1280, + "vision_cfg": { + "image_size": 224, + "layers": 56, + "width": 1792, + "head_width": 112, + "mlp_ratio": 8.5715, + "patch_size": 14 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 1280, + "heads": 20, + "layers": 36 + } +} \ No newline at end of file diff --git a/open_clip_local/model_configs/ViT-g-14.json b/open_clip_local/model_configs/ViT-g-14.json new file mode 100644 index 0000000000000000000000000000000000000000..8c4b7325cc75b6112be7107d36ae2cb5762d9091 --- /dev/null +++ b/open_clip_local/model_configs/ViT-g-14.json @@ -0,0 +1,18 @@ +{ + "embed_dim": 1024, + "vision_cfg": { + "image_size": 224, + "layers": 40, + "width": 1408, + "head_width": 88, + "mlp_ratio": 4.3637, + "patch_size": 14 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 1024, + "heads": 16, + "layers": 24 + } +} \ No newline at end of file diff --git a/open_clip_local/model_configs/coca_ViT-B-32.json b/open_clip_local/model_configs/coca_ViT-B-32.json new file mode 100644 index 0000000000000000000000000000000000000000..7e7eb520a6a0096e5602d509ecd6186e278f4725 --- /dev/null +++ b/open_clip_local/model_configs/coca_ViT-B-32.json @@ -0,0 +1,30 @@ +{ + "embed_dim": 512, + "vision_cfg": { + "image_size": 224, + "layers": 12, + "width": 768, + "patch_size": 32, + "attentional_pool": true, + "attn_pooler_heads": 8, + "output_tokens": true + }, + "text_cfg": { + "context_length": 76, + "vocab_size": 49408, + "width": 512, + "heads": 8, + "layers": 12, + "embed_cls": true, + "output_tokens": true + }, + "multimodal_cfg": { + "context_length": 76, + "vocab_size": 49408, + "width": 512, + "heads": 8, + "layers": 12, + "attn_pooler_heads": 8 + }, + "custom_text": true +} \ No newline at end of file diff --git a/open_clip_local/model_configs/coca_ViT-L-14.json b/open_clip_local/model_configs/coca_ViT-L-14.json new file mode 100644 index 0000000000000000000000000000000000000000..3d5ca4ca2338540f06852df5ff35ea6277e64555 --- /dev/null +++ b/open_clip_local/model_configs/coca_ViT-L-14.json @@ -0,0 +1,30 @@ +{ + "embed_dim": 768, + "vision_cfg": { + "image_size": 224, + "layers": 24, + "width": 1024, + "patch_size": 14, + "attentional_pool": true, + "attn_pooler_heads": 8, + "output_tokens": true + }, + "text_cfg": { + "context_length": 76, + "vocab_size": 49408, + "width": 768, + "heads": 12, + "layers": 12, + "embed_cls": true, + "output_tokens": true + }, + "multimodal_cfg": { + "context_length": 76, + "vocab_size": 49408, + "width": 768, + "heads": 12, + "layers": 12, + "attn_pooler_heads": 12 + }, + "custom_text": true +} diff --git a/open_clip_local/model_configs/coca_base.json b/open_clip_local/model_configs/coca_base.json new file mode 100644 index 0000000000000000000000000000000000000000..cf8c6cecb78a49d7e7140145a0307cbd561077c2 --- /dev/null +++ b/open_clip_local/model_configs/coca_base.json @@ -0,0 +1,31 @@ +{ + "embed_dim": 512, + "multimodal_cfg": { + "width": 768, + "context_length": 76, + "vocab_size": 64000, + "mlp_ratio": 4, + "layers": 12, + "dim_head": 64, + "heads": 12, + "n_queries": 256, + "attn_pooler_heads": 8 + }, + "vision_cfg": { + "image_size": 288, + "layers": 12, + "width": 768, + "patch_size": 18, + "output_tokens": true + }, + "text_cfg": { + "context_length": 76, + "vocab_size": 64000, + "layers": 12, + "heads": 12, + "width": 768, + "embed_cls": true, + "output_tokens": true + }, + "custom_text": true +} \ No newline at end of file diff --git a/open_clip_local/model_configs/coca_roberta-ViT-B-32.json b/open_clip_local/model_configs/coca_roberta-ViT-B-32.json new file mode 100644 index 0000000000000000000000000000000000000000..aa9d3f562057f849e6ced8b495de2dd73387fe61 --- /dev/null +++ b/open_clip_local/model_configs/coca_roberta-ViT-B-32.json @@ -0,0 +1,24 @@ +{ + "embed_dim": 512, + "vision_cfg": { + "image_size": 224, + "layers": 12, + "width": 768, + "patch_size": 32, + "output_tokens": true + }, + "text_cfg": { + "hf_model_name": "roberta-base", + "hf_tokenizer_name": "roberta-base", + "hf_proj_type": "linear", + "width": 768, + "output_tokens": true + }, + "multimodal_cfg": { + "context_length": 76, + "width": 768, + "heads": 8, + "layers": 12 + }, + "custom_text": true +} diff --git a/open_clip_local/model_configs/convnext_base.json b/open_clip_local/model_configs/convnext_base.json new file mode 100644 index 0000000000000000000000000000000000000000..bb6dba181d950ea5081155c90d47e72c94816b80 --- /dev/null +++ b/open_clip_local/model_configs/convnext_base.json @@ -0,0 +1,19 @@ +{ + "embed_dim": 512, + "vision_cfg": { + "timm_model_name": "convnext_base", + "timm_model_pretrained": false, + "timm_pool": "", + "timm_proj": "linear", + "timm_drop": 0.0, + "timm_drop_path": 0.1, + "image_size": 224 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 512, + "heads": 8, + "layers": 12 + } +} \ No newline at end of file diff --git a/open_clip_local/model_configs/convnext_base_w.json b/open_clip_local/model_configs/convnext_base_w.json new file mode 100644 index 0000000000000000000000000000000000000000..82ea7ae3659e5514f37ff982f0ab1141dff4bd18 --- /dev/null +++ b/open_clip_local/model_configs/convnext_base_w.json @@ -0,0 +1,19 @@ +{ + "embed_dim": 640, + "vision_cfg": { + "timm_model_name": "convnext_base", + "timm_model_pretrained": false, + "timm_pool": "", + "timm_proj": "linear", + "timm_drop": 0.0, + "timm_drop_path": 0.1, + "image_size": 256 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 640, + "heads": 10, + "layers": 12 + } +} \ No newline at end of file diff --git a/open_clip_local/model_configs/convnext_base_w_320.json b/open_clip_local/model_configs/convnext_base_w_320.json new file mode 100644 index 0000000000000000000000000000000000000000..0a07c4e16abaa4015ecc5f82ec845de16e1f9d88 --- /dev/null +++ b/open_clip_local/model_configs/convnext_base_w_320.json @@ -0,0 +1,19 @@ +{ + "embed_dim": 640, + "vision_cfg": { + "timm_model_name": "convnext_base", + "timm_model_pretrained": false, + "timm_pool": "", + "timm_proj": "linear", + "timm_drop": 0.0, + "timm_drop_path": 0.1, + "image_size": 320 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 640, + "heads": 10, + "layers": 12 + } +} \ No newline at end of file diff --git a/open_clip_local/model_configs/convnext_large.json b/open_clip_local/model_configs/convnext_large.json new file mode 100644 index 0000000000000000000000000000000000000000..c4a1fea73dbead71c218a0e74b9b15f9b252e3ef --- /dev/null +++ b/open_clip_local/model_configs/convnext_large.json @@ -0,0 +1,19 @@ +{ + "embed_dim": 768, + "vision_cfg": { + "timm_model_name": "convnext_large", + "timm_model_pretrained": false, + "timm_pool": "", + "timm_proj": "linear", + "timm_drop": 0.0, + "timm_drop_path": 0.1, + "image_size": 224 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 768, + "heads": 12, + "layers": 12 + } +} \ No newline at end of file diff --git a/open_clip_local/model_configs/convnext_large_d.json b/open_clip_local/model_configs/convnext_large_d.json new file mode 100644 index 0000000000000000000000000000000000000000..ae8fed21b58e1a6a411daf8b792ee50f0ab42346 --- /dev/null +++ b/open_clip_local/model_configs/convnext_large_d.json @@ -0,0 +1,19 @@ +{ + "embed_dim": 768, + "vision_cfg": { + "timm_model_name": "convnext_large", + "timm_model_pretrained": false, + "timm_pool": "", + "timm_proj": "mlp", + "timm_drop": 0.0, + "timm_drop_path": 0.1, + "image_size": 256 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 768, + "heads": 12, + "layers": 16 + } +} \ No newline at end of file diff --git a/open_clip_local/model_configs/convnext_large_d_320.json b/open_clip_local/model_configs/convnext_large_d_320.json new file mode 100644 index 0000000000000000000000000000000000000000..54c3df36a6f56ace0b12ada24c13058de96feed8 --- /dev/null +++ b/open_clip_local/model_configs/convnext_large_d_320.json @@ -0,0 +1,19 @@ +{ + "embed_dim": 768, + "vision_cfg": { + "timm_model_name": "convnext_large", + "timm_model_pretrained": false, + "timm_pool": "", + "timm_proj": "mlp", + "timm_drop": 0.0, + "timm_drop_path": 0.1, + "image_size": 320 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 768, + "heads": 12, + "layers": 16 + } +} \ No newline at end of file diff --git a/open_clip_local/model_configs/convnext_small.json b/open_clip_local/model_configs/convnext_small.json new file mode 100644 index 0000000000000000000000000000000000000000..3592c2a5cd21aae8d2544931773cf7603f67ea28 --- /dev/null +++ b/open_clip_local/model_configs/convnext_small.json @@ -0,0 +1,19 @@ +{ + "embed_dim": 512, + "vision_cfg": { + "timm_model_name": "convnext_small", + "timm_model_pretrained": false, + "timm_pool": "", + "timm_proj": "linear", + "timm_drop": 0.0, + "timm_drop_path": 0.1, + "image_size": 224 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 512, + "heads": 8, + "layers": 12 + } +} \ No newline at end of file diff --git a/open_clip_local/model_configs/convnext_tiny.json b/open_clip_local/model_configs/convnext_tiny.json new file mode 100644 index 0000000000000000000000000000000000000000..ad11470f5ec40ffec771096971ce58d3d5b9249b --- /dev/null +++ b/open_clip_local/model_configs/convnext_tiny.json @@ -0,0 +1,19 @@ +{ + "embed_dim": 1024, + "vision_cfg": { + "timm_model_name": "convnext_tiny", + "timm_model_pretrained": false, + "timm_pool": "", + "timm_proj": "linear", + "timm_drop": 0.0, + "timm_drop_path": 0.1, + "image_size": 224 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 512, + "heads": 8, + "layers": 12 + } +} \ No newline at end of file diff --git a/open_clip_local/model_configs/convnext_xlarge.json b/open_clip_local/model_configs/convnext_xlarge.json new file mode 100644 index 0000000000000000000000000000000000000000..2a909965932eef994177c829fefc2bdc1c219b3f --- /dev/null +++ b/open_clip_local/model_configs/convnext_xlarge.json @@ -0,0 +1,19 @@ +{ + "embed_dim": 1024, + "vision_cfg": { + "timm_model_name": "convnext_xlarge", + "timm_model_pretrained": false, + "timm_pool": "", + "timm_proj": "linear", + "timm_drop": 0.0, + "timm_drop_path": 0.1, + "image_size": 256 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 1024, + "heads": 16, + "layers": 20 + } +} \ No newline at end of file diff --git a/open_clip_local/model_configs/convnext_xxlarge.json b/open_clip_local/model_configs/convnext_xxlarge.json new file mode 100644 index 0000000000000000000000000000000000000000..23a55a681c346d1a315d8a163c1cb6ad495e6a91 --- /dev/null +++ b/open_clip_local/model_configs/convnext_xxlarge.json @@ -0,0 +1,19 @@ +{ + "embed_dim": 1024, + "vision_cfg": { + "timm_model_name": "convnext_xxlarge", + "timm_model_pretrained": false, + "timm_pool": "", + "timm_proj": "linear", + "timm_drop": 0.0, + "timm_drop_path": 0.1, + "image_size": 256 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 1024, + "heads": 16, + "layers": 24 + } +} \ No newline at end of file diff --git a/open_clip_local/model_configs/convnext_xxlarge_320.json b/open_clip_local/model_configs/convnext_xxlarge_320.json new file mode 100644 index 0000000000000000000000000000000000000000..ac5134ca12cbaa97772cde059270d345386a74c7 --- /dev/null +++ b/open_clip_local/model_configs/convnext_xxlarge_320.json @@ -0,0 +1,19 @@ +{ + "embed_dim": 1024, + "vision_cfg": { + "timm_model_name": "convnext_xxlarge", + "timm_model_pretrained": false, + "timm_pool": "", + "timm_proj": "linear", + "timm_drop": 0.0, + "timm_drop_path": 0.1, + "image_size": 320 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 1024, + "heads": 16, + "layers": 24 + } +} \ No newline at end of file diff --git a/open_clip_local/model_configs/mt5-base-ViT-B-32.json b/open_clip_local/model_configs/mt5-base-ViT-B-32.json new file mode 100644 index 0000000000000000000000000000000000000000..e22366897aa0a6719a09ff4dc168ef9724a3486c --- /dev/null +++ b/open_clip_local/model_configs/mt5-base-ViT-B-32.json @@ -0,0 +1,14 @@ +{ + "embed_dim": 512, + "vision_cfg": { + "image_size": 224, + "layers": 12, + "width": 768, + "patch_size": 32 + }, + "text_cfg": { + "hf_model_name": "google/mt5-base", + "hf_tokenizer_name": "google/mt5-base", + "hf_pooler_type": "mean_pooler" + } +} diff --git a/open_clip_local/model_configs/mt5-xl-ViT-H-14.json b/open_clip_local/model_configs/mt5-xl-ViT-H-14.json new file mode 100644 index 0000000000000000000000000000000000000000..f58717cdd5d4980ca2e099d15d5ee1ab7623c230 --- /dev/null +++ b/open_clip_local/model_configs/mt5-xl-ViT-H-14.json @@ -0,0 +1,15 @@ +{ + "embed_dim": 1024, + "vision_cfg": { + "image_size": 224, + "layers": 32, + "width": 1280, + "head_width": 80, + "patch_size": 14 + }, + "text_cfg": { + "hf_model_name": "google/mt5-xl", + "hf_tokenizer_name": "google/mt5-xl", + "hf_pooler_type": "mean_pooler" + } +} diff --git a/open_clip_local/model_configs/nllb-clip-base-siglip.json b/open_clip_local/model_configs/nllb-clip-base-siglip.json new file mode 100644 index 0000000000000000000000000000000000000000..f7152d0bb6b9fd3333b46cb75934e500f1aab348 --- /dev/null +++ b/open_clip_local/model_configs/nllb-clip-base-siglip.json @@ -0,0 +1,18 @@ +{ + "embed_dim": 768, + "custom_text": true, + "init_logit_bias": -10, + "vision_cfg": { + "image_size": 384, + "timm_model_name": "vit_base_patch16_siglip_384", + "timm_model_pretrained": false, + "timm_pool": "map", + "timm_proj": "none" + }, + "text_cfg": { + "hf_model_name": "facebook/nllb-200-distilled-600M", + "hf_tokenizer_name": "facebook/nllb-200-distilled-600M", + "hf_proj_type": "linear", + "hf_pooler_type": "cls_pooler" + } +} \ No newline at end of file diff --git a/open_clip_local/model_configs/nllb-clip-base.json b/open_clip_local/model_configs/nllb-clip-base.json new file mode 100644 index 0000000000000000000000000000000000000000..57265b33f7cfd21b07741744d50cbf30208017d1 --- /dev/null +++ b/open_clip_local/model_configs/nllb-clip-base.json @@ -0,0 +1,15 @@ +{ + "embed_dim": 512, + "vision_cfg": { + "image_size": 224, + "layers": 12, + "width": 768, + "patch_size": 32 + }, + "text_cfg": { + "hf_model_name": "facebook/nllb-200-distilled-600M", + "hf_tokenizer_name": "facebook/nllb-200-distilled-600M", + "hf_proj_type": "linear", + "hf_pooler_type": "cls_pooler" + } +} \ No newline at end of file diff --git a/open_clip_local/model_configs/nllb-clip-large-siglip.json b/open_clip_local/model_configs/nllb-clip-large-siglip.json new file mode 100644 index 0000000000000000000000000000000000000000..0ac3485762b5117597839b3274ed85340a2c76c2 --- /dev/null +++ b/open_clip_local/model_configs/nllb-clip-large-siglip.json @@ -0,0 +1,18 @@ +{ + "embed_dim": 1152, + "custom_text": true, + "init_logit_bias": -10, + "vision_cfg": { + "image_size": 384, + "timm_model_name": "vit_so400m_patch14_siglip_384", + "timm_model_pretrained": false, + "timm_pool": "map", + "timm_proj": "none" + }, + "text_cfg": { + "hf_model_name": "facebook/nllb-200-distilled-1.3B", + "hf_tokenizer_name": "facebook/nllb-200-distilled-1.3B", + "hf_proj_type": "linear", + "hf_pooler_type": "cls_pooler" + } +} \ No newline at end of file diff --git a/open_clip_local/model_configs/nllb-clip-large.json b/open_clip_local/model_configs/nllb-clip-large.json new file mode 100644 index 0000000000000000000000000000000000000000..72d04a73316e513135581f563c74f8cb69dac1c9 --- /dev/null +++ b/open_clip_local/model_configs/nllb-clip-large.json @@ -0,0 +1,16 @@ +{ + "embed_dim": 1024, + "vision_cfg": { + "image_size": 224, + "layers": 32, + "width": 1280, + "head_width": 80, + "patch_size": 14 + }, + "text_cfg": { + "hf_model_name": "facebook/nllb-200-distilled-1.3B", + "hf_tokenizer_name": "facebook/nllb-200-distilled-1.3B", + "hf_proj_type": "linear", + "hf_pooler_type": "cls_pooler" + } +} \ No newline at end of file diff --git a/open_clip_local/model_configs/roberta-ViT-B-32.json b/open_clip_local/model_configs/roberta-ViT-B-32.json new file mode 100644 index 0000000000000000000000000000000000000000..c0c7a55995d50230c6b0f0af5fbd81d5889a3d59 --- /dev/null +++ b/open_clip_local/model_configs/roberta-ViT-B-32.json @@ -0,0 +1,15 @@ +{ + "embed_dim": 512, + "quick_gelu": true, + "vision_cfg": { + "image_size": 224, + "layers": 12, + "width": 768, + "patch_size": 32 + }, + "text_cfg": { + "hf_model_name": "roberta-base", + "hf_tokenizer_name": "roberta-base", + "hf_pooler_type": "mean_pooler" + } +} diff --git a/open_clip_local/model_configs/swin_base_patch4_window7_224.json b/open_clip_local/model_configs/swin_base_patch4_window7_224.json new file mode 100644 index 0000000000000000000000000000000000000000..bd6820f0cf2aa655e0a2723287f4b78895a58e6a --- /dev/null +++ b/open_clip_local/model_configs/swin_base_patch4_window7_224.json @@ -0,0 +1,17 @@ +{ + "embed_dim": 640, + "vision_cfg": { + "timm_model_name": "swin_base_patch4_window7_224", + "timm_model_pretrained": false, + "timm_pool": "", + "timm_proj": "linear", + "image_size": 224 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 640, + "heads": 10, + "layers": 12 + } +} \ No newline at end of file diff --git a/open_clip_local/model_configs/vit_medium_patch16_gap_256.json b/open_clip_local/model_configs/vit_medium_patch16_gap_256.json new file mode 100644 index 0000000000000000000000000000000000000000..8843eaf08cad16c3e7b5f496fd650715c9573f65 --- /dev/null +++ b/open_clip_local/model_configs/vit_medium_patch16_gap_256.json @@ -0,0 +1,17 @@ +{ + "embed_dim": 512, + "vision_cfg": { + "timm_model_name": "vit_medium_patch16_gap_256", + "timm_model_pretrained": false, + "timm_pool": "", + "timm_proj": "linear", + "image_size": 256 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 512, + "heads": 8, + "layers": 12 + } +} \ No newline at end of file diff --git a/open_clip_local/model_configs/vit_relpos_medium_patch16_cls_224.json b/open_clip_local/model_configs/vit_relpos_medium_patch16_cls_224.json new file mode 100644 index 0000000000000000000000000000000000000000..ed217b202d5e6071c5307f4547c97ff4cfe2abd1 --- /dev/null +++ b/open_clip_local/model_configs/vit_relpos_medium_patch16_cls_224.json @@ -0,0 +1,17 @@ +{ + "embed_dim": 512, + "vision_cfg": { + "timm_model_name": "vit_relpos_medium_patch16_cls_224", + "timm_model_pretrained": false, + "timm_pool": "", + "timm_proj": "linear", + "image_size": 224 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 512, + "heads": 8, + "layers": 12 + } +} \ No newline at end of file diff --git a/open_clip_local/model_configs/xlm-roberta-base-ViT-B-32.json b/open_clip_local/model_configs/xlm-roberta-base-ViT-B-32.json new file mode 100644 index 0000000000000000000000000000000000000000..375fa9e12f1629ef049a715d43ba2a8b1822ff1c --- /dev/null +++ b/open_clip_local/model_configs/xlm-roberta-base-ViT-B-32.json @@ -0,0 +1,14 @@ +{ + "embed_dim": 512, + "vision_cfg": { + "image_size": 224, + "layers": 12, + "width": 768, + "patch_size": 32 + }, + "text_cfg": { + "hf_model_name": "xlm-roberta-base", + "hf_tokenizer_name": "xlm-roberta-base", + "hf_pooler_type": "mean_pooler" + } +} diff --git a/open_clip_local/model_configs/xlm-roberta-large-ViT-H-14.json b/open_clip_local/model_configs/xlm-roberta-large-ViT-H-14.json new file mode 100644 index 0000000000000000000000000000000000000000..c56b4e89883506ce41d0295d9a700b4a3dd2775f --- /dev/null +++ b/open_clip_local/model_configs/xlm-roberta-large-ViT-H-14.json @@ -0,0 +1,15 @@ +{ + "embed_dim": 1024, + "vision_cfg": { + "image_size": 224, + "layers": 32, + "width": 1280, + "head_width": 80, + "patch_size": 14 + }, + "text_cfg": { + "hf_model_name": "xlm-roberta-large", + "hf_tokenizer_name": "xlm-roberta-large", + "hf_pooler_type": "mean_pooler" + } +} diff --git a/open_clip_local/modified_resnet.py b/open_clip_local/modified_resnet.py new file mode 100644 index 0000000000000000000000000000000000000000..f7c0b033a80e7d08a20a367050c5b1bc5d5292e7 --- /dev/null +++ b/open_clip_local/modified_resnet.py @@ -0,0 +1,181 @@ +from collections import OrderedDict + +import torch +from torch import nn +from torch.nn import functional as F + +from open_clip.utils import freeze_batch_norm_2d + + +class Bottleneck(nn.Module): + expansion = 4 + + def __init__(self, inplanes, planes, stride=1): + super().__init__() + + # all conv layers have stride 1. an avgpool is performed after the second convolution when stride > 1 + self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False) + self.bn1 = nn.BatchNorm2d(planes) + self.act1 = nn.ReLU(inplace=True) + + self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False) + self.bn2 = nn.BatchNorm2d(planes) + self.act2 = nn.ReLU(inplace=True) + + self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity() + + self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False) + self.bn3 = nn.BatchNorm2d(planes * self.expansion) + self.act3 = nn.ReLU(inplace=True) + + self.downsample = None + self.stride = stride + + if stride > 1 or inplanes != planes * Bottleneck.expansion: + # downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1 + self.downsample = nn.Sequential(OrderedDict([ + ("-1", nn.AvgPool2d(stride)), + ("0", nn.Conv2d(inplanes, planes * self.expansion, 1, stride=1, bias=False)), + ("1", nn.BatchNorm2d(planes * self.expansion)) + ])) + + def forward(self, x: torch.Tensor): + identity = x + + out = self.act1(self.bn1(self.conv1(x))) + out = self.act2(self.bn2(self.conv2(out))) + out = self.avgpool(out) + out = self.bn3(self.conv3(out)) + + if self.downsample is not None: + identity = self.downsample(x) + + out += identity + out = self.act3(out) + return out + + +class AttentionPool2d(nn.Module): + def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None): + super().__init__() + self.positional_embedding = nn.Parameter(torch.randn(spacial_dim ** 2 + 1, embed_dim) / embed_dim ** 0.5) + self.k_proj = nn.Linear(embed_dim, embed_dim) + self.q_proj = nn.Linear(embed_dim, embed_dim) + self.v_proj = nn.Linear(embed_dim, embed_dim) + self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim) + self.num_heads = num_heads + + def forward(self, x): + x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3]).permute(2, 0, 1) # NCHW -> (HW)NC + x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC + x = x + self.positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC + x, _ = F.multi_head_attention_forward( + query=x, key=x, value=x, + embed_dim_to_check=x.shape[-1], + num_heads=self.num_heads, + q_proj_weight=self.q_proj.weight, + k_proj_weight=self.k_proj.weight, + v_proj_weight=self.v_proj.weight, + in_proj_weight=None, + in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]), + bias_k=None, + bias_v=None, + add_zero_attn=False, + dropout_p=0., + out_proj_weight=self.c_proj.weight, + out_proj_bias=self.c_proj.bias, + use_separate_proj_weight=True, + training=self.training, + need_weights=False + ) + + return x[0] + + +class ModifiedResNet(nn.Module): + """ + A ResNet class that is similar to torchvision's but contains the following changes: + - There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool. + - Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1 + - The final pooling layer is a QKV attention instead of an average pool + """ + + def __init__(self, layers, output_dim, heads, image_size=224, width=64): + super().__init__() + self.output_dim = output_dim + self.image_size = image_size + + # the 3-layer stem + self.conv1 = nn.Conv2d(3, width // 2, kernel_size=3, stride=2, padding=1, bias=False) + self.bn1 = nn.BatchNorm2d(width // 2) + self.act1 = nn.ReLU(inplace=True) + self.conv2 = nn.Conv2d(width // 2, width // 2, kernel_size=3, padding=1, bias=False) + self.bn2 = nn.BatchNorm2d(width // 2) + self.act2 = nn.ReLU(inplace=True) + self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False) + self.bn3 = nn.BatchNorm2d(width) + self.act3 = nn.ReLU(inplace=True) + self.avgpool = nn.AvgPool2d(2) + + # residual layers + self._inplanes = width # this is a *mutable* variable used during construction + self.layer1 = self._make_layer(width, layers[0]) + self.layer2 = self._make_layer(width * 2, layers[1], stride=2) + self.layer3 = self._make_layer(width * 4, layers[2], stride=2) + self.layer4 = self._make_layer(width * 8, layers[3], stride=2) + + embed_dim = width * 32 # the ResNet feature dimension + self.attnpool = AttentionPool2d(image_size // 32, embed_dim, heads, output_dim) + + self.init_parameters() + + def _make_layer(self, planes, blocks, stride=1): + layers = [Bottleneck(self._inplanes, planes, stride)] + + self._inplanes = planes * Bottleneck.expansion + for _ in range(1, blocks): + layers.append(Bottleneck(self._inplanes, planes)) + + return nn.Sequential(*layers) + + def init_parameters(self): + if self.attnpool is not None: + std = self.attnpool.c_proj.in_features ** -0.5 + nn.init.normal_(self.attnpool.q_proj.weight, std=std) + nn.init.normal_(self.attnpool.k_proj.weight, std=std) + nn.init.normal_(self.attnpool.v_proj.weight, std=std) + nn.init.normal_(self.attnpool.c_proj.weight, std=std) + + for resnet_block in [self.layer1, self.layer2, self.layer3, self.layer4]: + for name, param in resnet_block.named_parameters(): + if name.endswith("bn3.weight"): + nn.init.zeros_(param) + + def lock(self, unlocked_groups=0, freeze_bn_stats=False): + assert unlocked_groups == 0, 'partial locking not currently supported for this model' + for param in self.parameters(): + param.requires_grad = False + if freeze_bn_stats: + freeze_batch_norm_2d(self) + + @torch.jit.ignore + def set_grad_checkpointing(self, enable=True): + # FIXME support for non-transformer + pass + + def stem(self, x): + x = self.act1(self.bn1(self.conv1(x))) + x = self.act2(self.bn2(self.conv2(x))) + x = self.act3(self.bn3(self.conv3(x))) + x = self.avgpool(x) + return x + + def forward(self, x): + x = self.stem(x) + x = self.layer1(x) + x = self.layer2(x) + x = self.layer3(x) + x = self.layer4(x) + x = self.attnpool(x) + + return x diff --git a/open_clip_local/openai.py b/open_clip_local/openai.py new file mode 100644 index 0000000000000000000000000000000000000000..6c2c0235245c2e4f1217b3b2bfaf2acf78e74981 --- /dev/null +++ b/open_clip_local/openai.py @@ -0,0 +1,90 @@ +""" OpenAI pretrained model functions + +Adapted from https://github.com/openai/CLIP. Originally MIT License, Copyright (c) 2021 OpenAI. +""" + +import os +import warnings +from typing import List, Optional, Union + +import torch + +from .constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD +from .model import build_model_from_openai_state_dict, convert_weights_to_lp, get_cast_dtype +from .pretrained import get_pretrained_url, list_pretrained_models_by_tag, download_pretrained_from_url + +__all__ = ["list_openai_models", "load_openai_model"] + + +def list_openai_models() -> List[str]: + """Returns the names of available CLIP models""" + return list_pretrained_models_by_tag('openai') + + +def load_openai_model( + name: str, + precision: Optional[str] = None, + device: Optional[Union[str, torch.device]] = None, + cache_dir: Optional[str] = None, +): + """Load a CLIP model + + Parameters + ---------- + name : str + A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict + precision: str + Model precision, if None defaults to 'fp32' if device == 'cpu' else 'fp16'. + device : Union[str, torch.device] + The device to put the loaded model + cache_dir : Optional[str] + The directory to cache the downloaded model weights + + Returns + ------- + model : torch.nn.Module + The CLIP model + preprocess : Callable[[PIL.Image], torch.Tensor] + A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input + """ + if device is None: + device = "cuda" if torch.cuda.is_available() else "cpu" + if precision is None: + precision = 'fp32' if device == 'cpu' else 'fp16' + + if get_pretrained_url(name, 'openai'): + model_path = download_pretrained_from_url(get_pretrained_url(name, 'openai'), cache_dir=cache_dir) + elif os.path.isfile(name): + model_path = name + else: + raise RuntimeError(f"Model {name} not found; available models = {list_openai_models()}") + + try: + # loading JIT archive + model = torch.jit.load(model_path, map_location="cpu").eval() + state_dict = None + except RuntimeError: + # loading saved state dict + state_dict = torch.load(model_path, map_location="cpu") + + # Build a non-jit model from the OpenAI jitted model state dict + cast_dtype = get_cast_dtype(precision) + try: + model = build_model_from_openai_state_dict(state_dict or model.state_dict(), cast_dtype=cast_dtype) + except KeyError: + sd = {k[7:]: v for k, v in state_dict["state_dict"].items()} + model = build_model_from_openai_state_dict(sd, cast_dtype=cast_dtype) + + # model from OpenAI state dict is in manually cast fp16 mode, must be converted for AMP/fp32/bf16 use + model = model.to(device) + # FIXME support pure fp16/bf16 precision modes + if precision != 'fp16': + model.float() + if precision == 'bf16': + # for bf16, convert back to low-precision + convert_weights_to_lp(model, dtype=torch.bfloat16) + + # add mean / std attributes for consistency with OpenCLIP models + model.visual.image_mean = OPENAI_DATASET_MEAN + model.visual.image_std = OPENAI_DATASET_STD + return model diff --git a/open_clip_local/pos_embed.py b/open_clip_local/pos_embed.py new file mode 100644 index 0000000000000000000000000000000000000000..5c8082b34df2318dd25a4ec8346b3f9a888f38de --- /dev/null +++ b/open_clip_local/pos_embed.py @@ -0,0 +1,96 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +# -------------------------------------------------------- +# Position embedding utils +# -------------------------------------------------------- + +import numpy as np + +import torch + +# -------------------------------------------------------- +# 2D sine-cosine position embedding +# References: +# Transformer: https://github.com/tensorflow/models/blob/master/official/nlp/transformer/model_utils.py +# MoCo v3: https://github.com/facebookresearch/moco-v3 +# -------------------------------------------------------- +def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False): + """ + grid_size: int of the grid height and width + return: + pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token) + """ + grid_h = np.arange(grid_size, dtype=np.float32) + grid_w = np.arange(grid_size, dtype=np.float32) + grid = np.meshgrid(grid_w, grid_h) # here w goes first + grid = np.stack(grid, axis=0) + + grid = grid.reshape([2, 1, grid_size, grid_size]) + pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid) + if cls_token: + pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0) + return pos_embed + + +def get_2d_sincos_pos_embed_from_grid(embed_dim, grid): + assert embed_dim % 2 == 0 + + # use half of dimensions to encode grid_h + emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2) + emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2) + + emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D) + return emb + + +def get_1d_sincos_pos_embed_from_grid(embed_dim, pos): + """ + embed_dim: output dimension for each position + pos: a list of positions to be encoded: size (M,) + out: (M, D) + """ + assert embed_dim % 2 == 0 + omega = np.arange(embed_dim // 2, dtype=float) + omega /= embed_dim / 2. + omega = 1. / 10000**omega # (D/2,) + + pos = pos.reshape(-1) # (M,) + out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product + + emb_sin = np.sin(out) # (M, D/2) + emb_cos = np.cos(out) # (M, D/2) + + emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D) + return emb + + +# -------------------------------------------------------- +# Interpolate position embeddings for high-resolution +# References: +# DeiT: https://github.com/facebookresearch/deit +# -------------------------------------------------------- +def interpolate_pos_embed(model, checkpoint_model): + if 'pos_embed' in checkpoint_model: + pos_embed_checkpoint = checkpoint_model['pos_embed'] + embedding_size = pos_embed_checkpoint.shape[-1] + num_patches = model.patch_embed.num_patches + num_extra_tokens = model.pos_embed.shape[-2] - num_patches + # height (== width) for the checkpoint position embedding + orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5) + # height (== width) for the new position embedding + new_size = int(num_patches ** 0.5) + # class_token and dist_token are kept unchanged + if orig_size != new_size: + print("Position interpolate from %dx%d to %dx%d" % (orig_size, orig_size, new_size, new_size)) + extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens] + # only the position tokens are interpolated + pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:] + pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2) + pos_tokens = torch.nn.functional.interpolate( + pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False) + pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2) + new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1) + checkpoint_model['pos_embed'] = new_pos_embed diff --git a/open_clip_local/pretrained.py b/open_clip_local/pretrained.py new file mode 100644 index 0000000000000000000000000000000000000000..e43e773fda47916398b11379e8685adf8933ba01 --- /dev/null +++ b/open_clip_local/pretrained.py @@ -0,0 +1,586 @@ +import hashlib +import os +import urllib +import warnings +from functools import partial +from typing import Dict, Union + +from tqdm import tqdm + +from .constants import ( + IMAGENET_MEAN, + IMAGENET_STD, + INCEPTION_MEAN, + INCEPTION_STD, + OPENAI_DATASET_MEAN, + OPENAI_DATASET_STD, +) +from .version import __version__ + +try: + from huggingface_hub import hf_hub_download + hf_hub_download = partial(hf_hub_download, library_name="open_clip", library_version=__version__) + _has_hf_hub = True +except ImportError: + hf_hub_download = None + _has_hf_hub = False + + +def _pcfg(url='', hf_hub='', **kwargs): + # OpenAI / OpenCLIP defaults + return { + 'url': url, + 'hf_hub': hf_hub, + 'mean': OPENAI_DATASET_MEAN, + 'std': OPENAI_DATASET_STD, + 'interpolation': 'bicubic', + 'resize_mode': 'shortest', + **kwargs, + } + + +def _slpcfg(url='', hf_hub='', **kwargs): + # SiGLIP defaults + return { + 'url': url, + 'hf_hub': hf_hub, + 'mean': INCEPTION_MEAN, + 'std': INCEPTION_STD, + 'interpolation': 'bicubic', + 'resize_mode': 'squash', + **kwargs, + } + + +def _apcfg(url='', hf_hub='', **kwargs): + # CLIPA defaults + return { + 'url': url, + 'hf_hub': hf_hub, + 'mean': IMAGENET_MEAN, + 'std': IMAGENET_STD, + 'interpolation': 'bilinear', + 'resize_mode': 'squash', + **kwargs, + } + + +_RN50 = dict( + openai=_pcfg( + "https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt"), + yfcc15m=_pcfg( + "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn50-quickgelu-yfcc15m-455df137.pt"), + cc12m=_pcfg( + "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn50-quickgelu-cc12m-f000538c.pt"), +) + +_RN50_quickgelu = dict( + openai=_pcfg( + "https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt"), + yfcc15m=_pcfg( + "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn50-quickgelu-yfcc15m-455df137.pt"), + cc12m=_pcfg( + "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn50-quickgelu-cc12m-f000538c.pt"), +) + +_RN101 = dict( + openai=_pcfg( + "https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt"), + yfcc15m=_pcfg( + "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn101-quickgelu-yfcc15m-3e04b30e.pt"), +) + +_RN101_quickgelu = dict( + openai=_pcfg( + "https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt"), + yfcc15m=_pcfg( + "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn101-quickgelu-yfcc15m-3e04b30e.pt"), +) + +_RN50x4 = dict( + openai=_pcfg( + "https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt"), +) + +_RN50x16 = dict( + openai=_pcfg( + "https://openaipublic.azureedge.net/clip/models/52378b407f34354e150460fe41077663dd5b39c54cd0bfd2b27167a4a06ec9aa/RN50x16.pt"), +) + +_RN50x64 = dict( + openai=_pcfg( + "https://openaipublic.azureedge.net/clip/models/be1cfb55d75a9666199fb2206c106743da0f6468c9d327f3e0d0a543a9919d9c/RN50x64.pt"), +) + +_VITB32 = dict( + openai=_pcfg( + "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt"), + laion400m_e31=_pcfg( + "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e31-d867053b.pt"), + laion400m_e32=_pcfg( + "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e32-46683a32.pt"), + laion2b_e16=_pcfg( + "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-laion2b_e16-af8dbd0c.pth"), + laion2b_s34b_b79k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-laion2B-s34B-b79K/'), + # DataComp-XL models + datacomp_xl_s13b_b90k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-DataComp.XL-s13B-b90K/'), + # DataComp-M models + datacomp_m_s128m_b4k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-DataComp.M-s128M-b4K/'), + commonpool_m_clip_s128m_b4k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-CommonPool.M.clip-s128M-b4K/'), + commonpool_m_laion_s128m_b4k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-CommonPool.M.laion-s128M-b4K/'), + commonpool_m_image_s128m_b4k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-CommonPool.M.image-s128M-b4K/'), + commonpool_m_text_s128m_b4k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-CommonPool.M.text-s128M-b4K/'), + commonpool_m_basic_s128m_b4k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-CommonPool.M.basic-s128M-b4K/'), + commonpool_m_s128m_b4k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-CommonPool.M-s128M-b4K/'), + # DataComp-S models + datacomp_s_s13m_b4k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-DataComp.S-s13M-b4K/'), + commonpool_s_clip_s13m_b4k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-CommonPool.S.clip-s13M-b4K/'), + commonpool_s_laion_s13m_b4k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-CommonPool.S.laion-s13M-b4K/'), + commonpool_s_image_s13m_b4k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-CommonPool.S.image-s13M-b4K/'), + commonpool_s_text_s13m_b4k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-CommonPool.S.text-s13M-b4K/'), + commonpool_s_basic_s13m_b4k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-CommonPool.S.basic-s13M-b4K/'), + commonpool_s_s13m_b4k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-CommonPool.S-s13M-b4K/'), +) + +_VITB32_quickgelu = dict( + openai=_pcfg( + "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt"), + laion400m_e31=_pcfg( + "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e31-d867053b.pt"), + laion400m_e32=_pcfg( + "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e32-46683a32.pt"), + metaclip_400m=_pcfg( + "https://dl.fbaipublicfiles.com/MMPT/metaclip/b32_400m.pt"), + metaclip_fullcc=_pcfg( + "https://dl.fbaipublicfiles.com/MMPT/metaclip/b32_fullcc2.5b.pt"), +) + +_VITB32_256 = dict( + datacomp_s34b_b86k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-256x256-DataComp-s34B-b86K/'), +) + +_VITB16 = dict( + openai=_pcfg( + "https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt"), + laion400m_e31=_pcfg( + "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_16-laion400m_e31-00efa78f.pt"), + laion400m_e32=_pcfg( + "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_16-laion400m_e32-55e67d44.pt"), + laion2b_s34b_b88k=_pcfg(hf_hub='laion/CLIP-ViT-B-16-laion2B-s34B-b88K/'), + # DataComp-XL models + datacomp_xl_s13b_b90k=_pcfg(hf_hub='laion/CLIP-ViT-B-16-DataComp.XL-s13B-b90K/'), + # DataComp-L models + datacomp_l_s1b_b8k=_pcfg(hf_hub='laion/CLIP-ViT-B-16-DataComp.L-s1B-b8K/'), + commonpool_l_clip_s1b_b8k=_pcfg(hf_hub='laion/CLIP-ViT-B-16-CommonPool.L.clip-s1B-b8K/'), + commonpool_l_laion_s1b_b8k=_pcfg(hf_hub='laion/CLIP-ViT-B-16-CommonPool.L.laion-s1B-b8K/'), + commonpool_l_image_s1b_b8k=_pcfg(hf_hub='laion/CLIP-ViT-B-16-CommonPool.L.image-s1B-b8K/'), + commonpool_l_text_s1b_b8k=_pcfg(hf_hub='laion/CLIP-ViT-B-16-CommonPool.L.text-s1B-b8K/'), + commonpool_l_basic_s1b_b8k=_pcfg(hf_hub='laion/CLIP-ViT-B-16-CommonPool.L.basic-s1B-b8K/'), + commonpool_l_s1b_b8k=_pcfg(hf_hub='laion/CLIP-ViT-B-16-CommonPool.L-s1B-b8K/'), + # DFN + dfn2b=_pcfg(hf_hub='apple/DFN2B-CLIP-ViT-B-16/') +) + +_VITB16_quickgelu = dict( + metaclip_400m=_pcfg( + "https://dl.fbaipublicfiles.com/MMPT/metaclip/b16_400m.pt"), + metaclip_fullcc=_pcfg( + "https://dl.fbaipublicfiles.com/MMPT/metaclip/b16_fullcc2.5b.pt"), +) + +_VITB16_PLUS_240 = dict( + laion400m_e31=_pcfg( + "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_16_plus_240-laion400m_e31-8fb26589.pt"), + laion400m_e32=_pcfg( + "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_16_plus_240-laion400m_e32-699c4b84.pt"), +) + +_VITL14 = dict( + openai=_pcfg( + "https://openaipublic.azureedge.net/clip/models/b8cca3fd41ae0c99ba7e8951adf17d267cdb84cd88be6f7c2e0eca1737a03836/ViT-L-14.pt"), + laion400m_e31=_pcfg( + "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_l_14-laion400m_e31-69988bb6.pt"), + laion400m_e32=_pcfg( + "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_l_14-laion400m_e32-3d133497.pt"), + laion2b_s32b_b82k=_pcfg( + hf_hub='laion/CLIP-ViT-L-14-laion2B-s32B-b82K/', + mean=INCEPTION_MEAN, std=INCEPTION_STD), + # DataComp-XL models + datacomp_xl_s13b_b90k=_pcfg(hf_hub='laion/CLIP-ViT-L-14-DataComp.XL-s13B-b90K/'), + commonpool_xl_clip_s13b_b90k=_pcfg(hf_hub='laion/CLIP-ViT-L-14-CommonPool.XL.clip-s13B-b90K/'), + commonpool_xl_laion_s13b_b90k=_pcfg(hf_hub='laion/CLIP-ViT-L-14-CommonPool.XL.laion-s13B-b90K/'), + commonpool_xl_s13b_b90k=_pcfg(hf_hub='laion/CLIP-ViT-L-14-CommonPool.XL-s13B-b90K/'), +) + +_VITL14_quickgelu = dict( + metaclip_400m=_pcfg( + "https://dl.fbaipublicfiles.com/MMPT/metaclip/l14_400m.pt"), + metaclip_fullcc=_pcfg( + "https://dl.fbaipublicfiles.com/MMPT/metaclip/l14_fullcc2.5b.pt"), + dfn2b=_pcfg(hf_hub='apple/DFN2B-CLIP-ViT-L-14/'), +) + +_VITL14_336 = dict( + openai=_pcfg( + "https://openaipublic.azureedge.net/clip/models/3035c92b350959924f9f00213499208652fc7ea050643e8b385c2dac08641f02/ViT-L-14-336px.pt"), +) + +_VITH14 = dict( + laion2b_s32b_b79k=_pcfg(hf_hub='laion/CLIP-ViT-H-14-laion2B-s32B-b79K/'), +) + +_VITH14_quickgelu = dict( + metaclip_fullcc=_pcfg( + "https://dl.fbaipublicfiles.com/MMPT/metaclip/h14_fullcc2.5b.pt"), + dfn5b=_pcfg( + hf_hub='apple/DFN5B-CLIP-ViT-H-14/', + interpolation="bicubic", + resize_mode="squash" + ), +) + +_VITH14_378_quickgelu = dict( + dfn5b=_pcfg( + hf_hub='apple/DFN5B-CLIP-ViT-H-14-378/', + interpolation="bicubic", + resize_mode="squash" + ), +) + +_VITg14 = dict( + laion2b_s12b_b42k=_pcfg(hf_hub='laion/CLIP-ViT-g-14-laion2B-s12B-b42K/'), + laion2b_s34b_b88k=_pcfg(hf_hub='laion/CLIP-ViT-g-14-laion2B-s34B-b88K/'), +) + +_VITbigG14 = dict( + laion2b_s39b_b160k=_pcfg(hf_hub='laion/CLIP-ViT-bigG-14-laion2B-39B-b160k/'), +) + +_robertaViTB32 = dict( + laion2b_s12b_b32k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-roberta-base-laion2B-s12B-b32k/'), +) + +_xlmRobertaBaseViTB32 = dict( + laion5b_s13b_b90k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-xlm-roberta-base-laion5B-s13B-b90k/'), +) + +_xlmRobertaLargeFrozenViTH14 = dict( + frozen_laion5b_s13b_b90k=_pcfg(hf_hub='laion/CLIP-ViT-H-14-frozen-xlm-roberta-large-laion5B-s13B-b90k/'), +) + +_convnext_base = dict( + laion400m_s13b_b51k=_pcfg(hf_hub='laion/CLIP-convnext_base-laion400M-s13B-b51K/'), +) + +_convnext_base_w = dict( + laion2b_s13b_b82k=_pcfg(hf_hub='laion/CLIP-convnext_base_w-laion2B-s13B-b82K/'), + laion2b_s13b_b82k_augreg=_pcfg(hf_hub='laion/CLIP-convnext_base_w-laion2B-s13B-b82K-augreg/'), + laion_aesthetic_s13b_b82k=_pcfg(hf_hub='laion/CLIP-convnext_base_w-laion_aesthetic-s13B-b82K/'), +) + +_convnext_base_w_320 = dict( + laion_aesthetic_s13b_b82k=_pcfg(hf_hub='laion/CLIP-convnext_base_w_320-laion_aesthetic-s13B-b82K/'), + laion_aesthetic_s13b_b82k_augreg=_pcfg(hf_hub='laion/CLIP-convnext_base_w_320-laion_aesthetic-s13B-b82K-augreg/'), +) + +_convnext_large_d = dict( + laion2b_s26b_b102k_augreg=_pcfg(hf_hub='laion/CLIP-convnext_large_d.laion2B-s26B-b102K-augreg/'), +) + +_convnext_large_d_320 = dict( + laion2b_s29b_b131k_ft=_pcfg(hf_hub='laion/CLIP-convnext_large_d_320.laion2B-s29B-b131K-ft/'), + laion2b_s29b_b131k_ft_soup=_pcfg(hf_hub='laion/CLIP-convnext_large_d_320.laion2B-s29B-b131K-ft-soup/'), +) + +_convnext_xxlarge = dict( + laion2b_s34b_b82k_augreg=_pcfg(hf_hub='laion/CLIP-convnext_xxlarge-laion2B-s34B-b82K-augreg/'), + laion2b_s34b_b82k_augreg_rewind=_pcfg(hf_hub='laion/CLIP-convnext_xxlarge-laion2B-s34B-b82K-augreg-rewind/'), + laion2b_s34b_b82k_augreg_soup=_pcfg(hf_hub='laion/CLIP-convnext_xxlarge-laion2B-s34B-b82K-augreg-soup/'), +) + +_coca_VITB32 = dict( + laion2b_s13b_b90k=_pcfg(hf_hub='laion/CoCa-ViT-B-32-laion2B-s13B-b90k/'), + mscoco_finetuned_laion2b_s13b_b90k=_pcfg(hf_hub='laion/mscoco_finetuned_CoCa-ViT-B-32-laion2B-s13B-b90k/') +) + +_coca_VITL14 = dict( + laion2b_s13b_b90k=_pcfg(hf_hub='laion/CoCa-ViT-L-14-laion2B-s13B-b90k/'), + mscoco_finetuned_laion2b_s13b_b90k=_pcfg(hf_hub='laion/mscoco_finetuned_CoCa-ViT-L-14-laion2B-s13B-b90k/') +) + + +_PRETRAINED = { + "RN50": _RN50, + "RN50-quickgelu": _RN50_quickgelu, + "RN101": _RN101, + "RN101-quickgelu": _RN101_quickgelu, + "RN50x4": _RN50x4, + "RN50x16": _RN50x16, + "RN50x64": _RN50x64, + + "ViT-B-32": _VITB32, + "ViT-B-32-256": _VITB32_256, + "ViT-B-32-quickgelu": _VITB32_quickgelu, + "ViT-B-16": _VITB16, + "ViT-B-16-quickgelu": _VITB16_quickgelu, + "ViT-B-16-plus-240": _VITB16_PLUS_240, + "ViT-L-14": _VITL14, + "ViT-L-14-quickgelu": _VITL14_quickgelu, + "ViT-L-14-336": _VITL14_336, + "ViT-H-14": _VITH14, + "ViT-H-14-quickgelu": _VITH14_quickgelu, + "ViT-H-14-378-quickgelu": _VITH14_378_quickgelu, + "ViT-g-14": _VITg14, + "ViT-bigG-14": _VITbigG14, + + "roberta-ViT-B-32": _robertaViTB32, + "xlm-roberta-base-ViT-B-32": _xlmRobertaBaseViTB32, + "xlm-roberta-large-ViT-H-14": _xlmRobertaLargeFrozenViTH14, + + "convnext_base": _convnext_base, + "convnext_base_w": _convnext_base_w, + "convnext_base_w_320": _convnext_base_w_320, + "convnext_large_d": _convnext_large_d, + "convnext_large_d_320": _convnext_large_d_320, + "convnext_xxlarge": _convnext_xxlarge, + + "coca_ViT-B-32": _coca_VITB32, + "coca_ViT-L-14": _coca_VITL14, + + "EVA01-g-14": dict( + # from QuanSun/EVA-CLIP/EVA01_CLIP_g_14_psz14_s11B.pt + laion400m_s11b_b41k=_pcfg(hf_hub='timm/eva_giant_patch14_clip_224.laion400m_s11b_b41k/'), + ), + "EVA01-g-14-plus": dict( + # from QuanSun/EVA-CLIP/EVA01_CLIP_g_14_plus_psz14_s11B.pt + merged2b_s11b_b114k=_pcfg(hf_hub='timm/eva_giant_patch14_plus_clip_224.merged2b_s11b_b114k/'), + ), + "EVA02-B-16": dict( + # from QuanSun/EVA-CLIP/EVA02_CLIP_B_psz16_s8B.pt + merged2b_s8b_b131k=_pcfg(hf_hub='timm/eva02_base_patch16_clip_224.merged2b_s8b_b131k/'), + ), + "EVA02-L-14": dict( + # from QuanSun/EVA-CLIP/EVA02_CLIP_L_psz14_s4B.pt + merged2b_s4b_b131k=_pcfg(hf_hub='timm/eva02_large_patch14_clip_224.merged2b_s4b_b131k/'), + ), + "EVA02-L-14-336": dict( + # from QuanSun/EVA-CLIP/EVA02_CLIP_L_336_psz14_s6B.pt + merged2b_s6b_b61k=_pcfg(hf_hub='timm/eva02_large_patch14_clip_336.merged2b_s6b_b61k/'), + ), + "EVA02-E-14": dict( + # from QuanSun/EVA-CLIP/EVA02_CLIP_E_psz14_s4B.pt + laion2b_s4b_b115k=_pcfg(hf_hub='timm/eva02_enormous_patch14_clip_224.laion2b_s4b_b115k/'), + ), + "EVA02-E-14-plus": dict( + # from QuanSun/EVA-CLIP/EVA02_CLIP_E_psz14_plus_s9B.pt + laion2b_s9b_b144k=_pcfg(hf_hub='timm/eva02_enormous_patch14_plus_clip_224.laion2b_s9b_b144k/'), + ), + + "ViT-B-16-SigLIP": dict( + webli=_slpcfg(hf_hub='timm/ViT-B-16-SigLIP/'), + ), + "ViT-B-16-SigLIP-256": dict( + webli=_slpcfg(hf_hub='timm/ViT-B-16-SigLIP-256/'), + ), + "ViT-B-16-SigLIP-i18n-256": dict( + webli=_slpcfg(hf_hub='timm/ViT-B-16-SigLIP-i18n-256/'), + ), + "ViT-B-16-SigLIP-384": dict( + webli=_slpcfg(hf_hub='timm/ViT-B-16-SigLIP-384/'), + ), + "ViT-B-16-SigLIP-512": dict( + webli=_slpcfg(hf_hub='timm/ViT-B-16-SigLIP-512/'), + ), + "ViT-L-16-SigLIP-256": dict( + webli=_slpcfg(hf_hub='timm/ViT-L-16-SigLIP-256/'), + ), + "ViT-L-16-SigLIP-384": dict( + webli=_slpcfg(hf_hub='timm/ViT-L-16-SigLIP-384/'), + ), + "ViT-SO400M-14-SigLIP": dict( + webli=_slpcfg(hf_hub='timm/ViT-SO400M-14-SigLIP/'), + ), + "ViT-SO400M-14-SigLIP-384": dict( + webli=_slpcfg(hf_hub='timm/ViT-SO400M-14-SigLIP-384/'), + ), + + "ViT-L-14-CLIPA": dict( + datacomp1b=_apcfg(hf_hub='UCSC-VLAA/ViT-L-14-CLIPA-datacomp1B/'), + ), + "ViT-L-14-CLIPA-336": dict( + datacomp1b=_apcfg(hf_hub='UCSC-VLAA/ViT-L-14-CLIPA-336-datacomp1B/'), + ), + "ViT-H-14-CLIPA": dict( + datacomp1b=_apcfg(hf_hub='UCSC-VLAA/ViT-H-14-CLIPA-datacomp1B/'), + ), + "ViT-H-14-CLIPA-336": dict( + laion2b=_apcfg(hf_hub='UCSC-VLAA/ViT-H-14-CLIPA-336-laion2B/'), + datacomp1b=_apcfg(hf_hub='UCSC-VLAA/ViT-H-14-CLIPA-336-datacomp1B/'), + ), + "ViT-bigG-14-CLIPA": dict( + datacomp1b=_apcfg(hf_hub='UCSC-VLAA/ViT-bigG-14-CLIPA-datacomp1B/'), + ), + "ViT-bigG-14-CLIPA-336": dict( + datacomp1b=_apcfg(hf_hub='UCSC-VLAA/ViT-bigG-14-CLIPA-336-datacomp1B/'), + ), + + "nllb-clip-base": dict( + v1=_pcfg(hf_hub='visheratin/nllb-clip-base-oc/'), + ), + "nllb-clip-large": dict( + v1=_pcfg(hf_hub='visheratin/nllb-clip-large-oc/'), + ), + + "nllb-clip-base-siglip": dict( + v1=_slpcfg(hf_hub='visheratin/nllb-clip-base-siglip/'), + mrl=_slpcfg(hf_hub='visheratin/nllb-siglip-mrl-base/'), + ), + "nllb-clip-large-siglip": dict( + v1=_slpcfg(hf_hub='visheratin/nllb-clip-large-siglip/'), + mrl=_slpcfg(hf_hub='visheratin/nllb-siglip-mrl-large/'), + ) +} + + +def _clean_tag(tag: str): + # normalize pretrained tags + return tag.lower().replace('-', '_') + + +def list_pretrained(as_str: bool = False): + """ returns list of pretrained models + Returns a tuple (model_name, pretrain_tag) by default or 'name:tag' if as_str == True + """ + return [':'.join([k, t]) if as_str else (k, t) for k in _PRETRAINED.keys() for t in _PRETRAINED[k].keys()] + + +def list_pretrained_models_by_tag(tag: str): + """ return all models having the specified pretrain tag """ + models = [] + tag = _clean_tag(tag) + for k in _PRETRAINED.keys(): + if tag in _PRETRAINED[k]: + models.append(k) + return models + + +def list_pretrained_tags_by_model(model: str): + """ return all pretrain tags for the specified model architecture """ + tags = [] + if model in _PRETRAINED: + tags.extend(_PRETRAINED[model].keys()) + return tags + + +def is_pretrained_cfg(model: str, tag: str): + if model not in _PRETRAINED: + return False + return _clean_tag(tag) in _PRETRAINED[model] + + +def get_pretrained_cfg(model: str, tag: str): + if model not in _PRETRAINED: + return {} + model_pretrained = _PRETRAINED[model] + return model_pretrained.get(_clean_tag(tag), {}) + + +def get_pretrained_url(model: str, tag: str): + cfg = get_pretrained_cfg(model, _clean_tag(tag)) + return cfg.get('url', '') + + +def download_pretrained_from_url( + url: str, + cache_dir: Union[str, None] = None, +): + if not cache_dir: + cache_dir = os.path.expanduser("~/.cache/clip") + os.makedirs(cache_dir, exist_ok=True) + filename = os.path.basename(url) + + if 'openaipublic' in url: + expected_sha256 = url.split("/")[-2] + elif 'mlfoundations' in url: + expected_sha256 = os.path.splitext(filename)[0].split("-")[-1] + else: + expected_sha256 = '' + + download_target = os.path.join(cache_dir, filename) + + if os.path.exists(download_target) and not os.path.isfile(download_target): + raise RuntimeError(f"{download_target} exists and is not a regular file") + + if os.path.isfile(download_target): + if expected_sha256: + if hashlib.sha256(open(download_target, "rb").read()).hexdigest().startswith(expected_sha256): + return download_target + else: + warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file") + else: + return download_target + + with urllib.request.urlopen(url) as source, open(download_target, "wb") as output: + with tqdm(total=int(source.headers.get("Content-Length")), ncols=80, unit='iB', unit_scale=True) as loop: + while True: + buffer = source.read(8192) + if not buffer: + break + + output.write(buffer) + loop.update(len(buffer)) + + if expected_sha256 and not hashlib.sha256(open(download_target, "rb").read()).hexdigest().startswith(expected_sha256): + raise RuntimeError(f"Model has been downloaded but the SHA256 checksum does not not match") + + return download_target + + +def has_hf_hub(necessary=False): + if not _has_hf_hub and necessary: + # if no HF Hub module installed, and it is necessary to continue, raise error + raise RuntimeError( + 'Hugging Face hub model specified but package not installed. Run `pip install huggingface_hub`.') + return _has_hf_hub + + +def download_pretrained_from_hf( + model_id: str, + filename: str = 'open_clip_pytorch_model.bin', + revision=None, + cache_dir: Union[str, None] = None, +): + has_hf_hub(True) + cached_file = hf_hub_download(model_id, filename, revision=revision, cache_dir=cache_dir) + return cached_file + + +def download_pretrained( + cfg: Dict, + force_hf_hub: bool = False, + cache_dir: Union[str, None] = None, +): + target = '' + if not cfg: + return target + + download_url = cfg.get('url', '') + download_hf_hub = cfg.get('hf_hub', '') + if download_hf_hub and force_hf_hub: + # use HF hub even if url exists + download_url = '' + + if download_url: + target = download_pretrained_from_url(download_url, cache_dir=cache_dir) + elif download_hf_hub: + has_hf_hub(True) + # we assume the hf_hub entries in pretrained config combine model_id + filename in + # 'org/model_name/filename.pt' form. To specify just the model id w/o filename and + # use 'open_clip_pytorch_model.bin' default, there must be a trailing slash 'org/model_name/'. + model_id, filename = os.path.split(download_hf_hub) + if filename: + target = download_pretrained_from_hf(model_id, filename=filename, cache_dir=cache_dir) + else: + target = download_pretrained_from_hf(model_id, cache_dir=cache_dir) + + return target diff --git a/open_clip_local/push_to_hf_hub.py b/open_clip_local/push_to_hf_hub.py new file mode 100644 index 0000000000000000000000000000000000000000..dcb8a78b587a585dcf3e3518d66cc00b371e4a82 --- /dev/null +++ b/open_clip_local/push_to_hf_hub.py @@ -0,0 +1,317 @@ +import argparse +import json +import os +from pathlib import Path +from tempfile import TemporaryDirectory +from typing import Optional, Tuple, Union + +import torch + +try: + from huggingface_hub import ( + create_repo, + get_hf_file_metadata, + hf_hub_download, + hf_hub_url, + repo_type_and_id_from_hf_id, + upload_folder, + list_repo_files, + ) + from huggingface_hub.utils import EntryNotFoundError + _has_hf_hub = True +except ImportError: + _has_hf_hub = False + +try: + import safetensors.torch + _has_safetensors = True +except ImportError: + _has_safetensors = False + +from .factory import create_model_from_pretrained, get_model_config, get_tokenizer +from .tokenizer import HFTokenizer + +# Default name for a weights file hosted on the Huggingface Hub. +HF_WEIGHTS_NAME = "open_clip_pytorch_model.bin" # default pytorch pkl +HF_SAFE_WEIGHTS_NAME = "open_clip_model.safetensors" # safetensors version +HF_CONFIG_NAME = 'open_clip_config.json' + + +def save_config_for_hf( + model, + config_path: str, + model_config: Optional[dict] +): + preprocess_cfg = { + 'mean': model.visual.image_mean, + 'std': model.visual.image_std, + } + other_pp = getattr(model.visual, 'preprocess_cfg', {}) + if 'interpolation' in other_pp: + preprocess_cfg['interpolation'] = other_pp['interpolation'] + if 'resize_mode' in other_pp: + preprocess_cfg['resize_mode'] = other_pp['resize_mode'] + hf_config = { + 'model_cfg': model_config, + 'preprocess_cfg': preprocess_cfg, + } + + with config_path.open('w') as f: + json.dump(hf_config, f, indent=2) + + +def save_for_hf( + model, + tokenizer: HFTokenizer, + model_config: dict, + save_directory: str, + safe_serialization: Union[bool, str] = 'both', + skip_weights : bool = False, +): + config_filename = HF_CONFIG_NAME + + save_directory = Path(save_directory) + save_directory.mkdir(exist_ok=True, parents=True) + + if not skip_weights: + tensors = model.state_dict() + if safe_serialization is True or safe_serialization == "both": + assert _has_safetensors, "`pip install safetensors` to use .safetensors" + safetensors.torch.save_file(tensors, save_directory / HF_SAFE_WEIGHTS_NAME) + if safe_serialization is False or safe_serialization == "both": + torch.save(tensors, save_directory / HF_WEIGHTS_NAME) + + tokenizer.save_pretrained(save_directory) + + config_path = save_directory / config_filename + save_config_for_hf(model, config_path, model_config=model_config) + + +def push_to_hf_hub( + model, + tokenizer, + model_config: Optional[dict], + repo_id: str, + commit_message: str = 'Add model', + token: Optional[str] = None, + revision: Optional[str] = None, + private: bool = False, + create_pr: bool = False, + model_card: Optional[dict] = None, + safe_serialization: Union[bool, str] = False, +): + if not isinstance(tokenizer, HFTokenizer): + # FIXME this makes it awkward to push models with new tokenizers, come up with better soln. + # default CLIP tokenizers use https://huggingface.co/openai/clip-vit-large-patch14 + tokenizer = HFTokenizer('openai/clip-vit-large-patch14') + + # Create repo if it doesn't exist yet + repo_url = create_repo(repo_id, token=token, private=private, exist_ok=True) + + # Infer complete repo_id from repo_url + # Can be different from the input `repo_id` if repo_owner was implicit + _, repo_owner, repo_name = repo_type_and_id_from_hf_id(repo_url) + repo_id = f"{repo_owner}/{repo_name}" + + # Check if repo already exists and determine what needs updating + repo_exists = False + repo_files = {} + try: + repo_files = set(list_repo_files(repo_id)) + repo_exists = True + except Exception as e: + print('Repo does not exist', e) + + try: + get_hf_file_metadata(hf_hub_url(repo_id=repo_id, filename="README.md", revision=revision)) + has_readme = True + except EntryNotFoundError: + has_readme = False + + # Dump model and push to Hub + with TemporaryDirectory() as tmpdir: + # Save model weights and config. + save_for_hf( + model, + tokenizer=tokenizer, + model_config=model_config, + save_directory=tmpdir, + safe_serialization=safe_serialization, + ) + + # Add readme if it does not exist + if not has_readme: + model_card = model_card or {} + model_name = repo_id.split('/')[-1] + readme_path = Path(tmpdir) / "README.md" + readme_text = generate_readme(model_card, model_name) + readme_path.write_text(readme_text) + + # Upload model and return + return upload_folder( + repo_id=repo_id, + folder_path=tmpdir, + revision=revision, + create_pr=create_pr, + commit_message=commit_message, + ) + + +def push_pretrained_to_hf_hub( + model_name, + pretrained: str, + repo_id: str, + precision: str = 'fp32', + image_mean: Optional[Tuple[float, ...]] = None, + image_std: Optional[Tuple[float, ...]] = None, + image_interpolation: Optional[str] = None, + image_resize_mode: Optional[str] = None, # only effective for inference + commit_message: str = 'Add model', + token: Optional[str] = None, + revision: Optional[str] = None, + private: bool = False, + create_pr: bool = False, + model_card: Optional[dict] = None, + hf_tokenizer_self: bool = False, +): + model, preprocess_eval = create_model_from_pretrained( + model_name, + pretrained=pretrained, + precision=precision, + image_mean=image_mean, + image_std=image_std, + image_interpolation=image_interpolation, + image_resize_mode=image_resize_mode, + ) + model_config = get_model_config(model_name) + assert model_config + + tokenizer = get_tokenizer(model_name) + if hf_tokenizer_self: + # make hf tokenizer config in the uploaded model point to self instead of original location + model_config['text']['hf_tokenizer_name'] = repo_id + + push_to_hf_hub( + model=model, + tokenizer=tokenizer, + model_config=model_config, + repo_id=repo_id, + commit_message=commit_message, + token=token, + revision=revision, + private=private, + create_pr=create_pr, + model_card=model_card, + safe_serialization='both', + ) + + +def generate_readme(model_card: dict, model_name: str): + tags = model_card.pop('tags', ('clip',)) + pipeline_tag = model_card.pop('pipeline_tag', 'zero-shot-image-classification') + readme_text = "---\n" + if tags: + readme_text += "tags:\n" + for t in tags: + readme_text += f"- {t}\n" + readme_text += "library_name: open_clip\n" + readme_text += f"pipeline_tag: {pipeline_tag}\n" + readme_text += f"license: {model_card.get('license', 'mit')}\n" + if 'details' in model_card and 'Dataset' in model_card['details']: + readme_text += 'datasets:\n' + readme_text += f"- {model_card['details']['Dataset'].lower()}\n" + readme_text += "---\n" + readme_text += f"# Model card for {model_name}\n" + if 'description' in model_card: + readme_text += f"\n{model_card['description']}\n" + if 'details' in model_card: + readme_text += f"\n## Model Details\n" + for k, v in model_card['details'].items(): + if isinstance(v, (list, tuple)): + readme_text += f"- **{k}:**\n" + for vi in v: + readme_text += f" - {vi}\n" + elif isinstance(v, dict): + readme_text += f"- **{k}:**\n" + for ki, vi in v.items(): + readme_text += f" - {ki}: {vi}\n" + else: + readme_text += f"- **{k}:** {v}\n" + if 'usage' in model_card: + readme_text += f"\n## Model Usage\n" + readme_text += model_card['usage'] + readme_text += '\n' + + if 'comparison' in model_card: + readme_text += f"\n## Model Comparison\n" + readme_text += model_card['comparison'] + readme_text += '\n' + + if 'citation' in model_card: + readme_text += f"\n## Citation\n" + if not isinstance(model_card['citation'], (list, tuple)): + citations = [model_card['citation']] + else: + citations = model_card['citation'] + for c in citations: + readme_text += f"```bibtex\n{c}\n```\n" + + return readme_text + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Push to Hugging Face Hub") + parser.add_argument( + "--model", type=str, help="Name of the model to use.", + ) + parser.add_argument( + "--pretrained", type=str, + help="Use a pretrained CLIP model weights with the specified tag or file path.", + ) + parser.add_argument( + "--repo-id", type=str, + help="Destination HF Hub repo-id ie 'organization/model_id'.", + ) + parser.add_argument( + "--precision", type=str, default='fp32', + ) + parser.add_argument( + '--image-mean', type=float, nargs='+', default=None, metavar='MEAN', + help='Override default image mean value of dataset') + parser.add_argument( + '--image-std', type=float, nargs='+', default=None, metavar='STD', + help='Override default image std deviation of of dataset') + parser.add_argument( + '--image-interpolation', + default=None, type=str, choices=['bicubic', 'bilinear', 'random'], + help="image resize interpolation" + ) + parser.add_argument( + '--image-resize-mode', + default=None, type=str, choices=['shortest', 'longest', 'squash'], + help="image resize mode during inference" + ) + parser.add_argument( + "--hf-tokenizer-self", + default=False, + action="store_true", + help="make hf_tokenizer_name point in uploaded config point to itself" + ) + args = parser.parse_args() + + print(f'Saving model {args.model} with pretrained weights {args.pretrained} to Hugging Face Hub at {args.repo_id}') + + # FIXME add support to pass model_card json / template from file via cmd line + + push_pretrained_to_hf_hub( + args.model, + args.pretrained, + args.repo_id, + precision=args.precision, + image_mean=args.image_mean, # override image mean/std if trained w/ non defaults + image_std=args.image_std, + image_interpolation=args.image_interpolation, + image_resize_mode=args.image_resize_mode, + ) + + print(f'{args.model} saved.') diff --git a/open_clip_local/timm_model.py b/open_clip_local/timm_model.py new file mode 100644 index 0000000000000000000000000000000000000000..5ddb9a76bf085feeb8c20f3a39a6cfa4c2b643b4 --- /dev/null +++ b/open_clip_local/timm_model.py @@ -0,0 +1,152 @@ +""" timm model adapter + +Wraps timm (https://github.com/rwightman/pytorch-image-models) models for use as a vision tower in CLIP model. +""" +import logging +from collections import OrderedDict + +import torch +import torch.nn as nn + +try: + import timm + from timm.models.layers import Mlp, to_2tuple + try: + # old timm imports < 0.8.1 + from timm.models.layers.attention_pool2d import RotAttentionPool2d + from timm.models.layers.attention_pool2d import AttentionPool2d as AbsAttentionPool2d + except ImportError: + # new timm imports >= 0.8.1 + from timm.layers import RotAttentionPool2d + from timm.layers import AttentionPool2d as AbsAttentionPool2d +except ImportError: + timm = None + +from .utils import freeze_batch_norm_2d + + +class TimmModel(nn.Module): + """ timm model adapter + """ + + def __init__( + self, + model_name, + embed_dim, + image_size=224, + pool='avg', + proj='linear', + proj_bias=False, + drop=0., + drop_path=None, + patch_drop=None, + pretrained=False, + ): + super().__init__() + if timm is None: + raise RuntimeError("Please `pip install timm` to use timm models.") + self.image_size = to_2tuple(image_size) + + # setup kwargs that may not be common across all models + timm_kwargs = {} + if drop_path is not None: + timm_kwargs['drop_path_rate'] = drop_path + if patch_drop is not None: + timm_kwargs['patch_drop_rate'] = patch_drop + + custom_pool = pool in ('abs_attn', 'rot_attn') + if proj: + assert proj in ("linear", "mlp", "none") + extra_proj = proj in ("linear", "mlp") + if not extra_proj and not custom_pool: + # use network classifier head as projection if no proj specified and no custom pooling used + # if projection is explicitly set to "none" will be pass through from network trunk + proj_dim = 0 if proj == 'none' else embed_dim + self.trunk = timm.create_model( + model_name, + num_classes=proj_dim, + global_pool=pool, + pretrained=pretrained, + **timm_kwargs, + ) + prev_chs = embed_dim + else: + self.trunk = timm.create_model( + model_name, + pretrained=pretrained, + **timm_kwargs, + ) + feat_size = self.trunk.default_cfg.get('pool_size', None) + feature_ndim = 1 if not feat_size else 2 + if custom_pool: + assert feature_ndim == 2 + # if attn pooling used, remove both classifier and default pool + self.trunk.reset_classifier(0, global_pool='') + else: + # reset global pool if pool config set, otherwise leave as network default + reset_kwargs = dict(global_pool=pool) if pool else {} + self.trunk.reset_classifier(0, **reset_kwargs) + prev_chs = self.trunk.num_features + + head_layers = OrderedDict() + + # Add custom pooling to head + if pool == 'abs_attn': + head_layers['pool'] = AbsAttentionPool2d(prev_chs, feat_size=feat_size, out_features=embed_dim) + prev_chs = embed_dim + elif pool == 'rot_attn': + head_layers['pool'] = RotAttentionPool2d(prev_chs, out_features=embed_dim) + prev_chs = embed_dim + + # NOTE attention pool ends with a projection layer, so proj should usually be set to '' if such pooling is used + if proj == 'linear': + head_layers['drop'] = nn.Dropout(drop) + head_layers['proj'] = nn.Linear(prev_chs, embed_dim, bias=proj_bias) + elif proj == 'mlp': + head_layers['mlp'] = Mlp(prev_chs, 2 * embed_dim, embed_dim, drop=(drop, 0), bias=(True, proj_bias)) + + self.head = nn.Sequential(head_layers) + + def lock(self, unlocked_groups=0, freeze_bn_stats=False): + """ lock modules + Args: + unlocked_groups (int): leave last n layer groups unlocked (default: 0) + """ + if not unlocked_groups: + # lock full model + for param in self.trunk.parameters(): + param.requires_grad = False + if freeze_bn_stats: + freeze_batch_norm_2d(self.trunk) + else: + # NOTE: partial freeze requires latest timm (master) branch and is subject to change + try: + # FIXME import here until API stable and in an official release + from timm.models.helpers import group_parameters, group_modules + except ImportError: + raise RuntimeError( + 'Please install latest timm `pip install git+https://github.com/rwightman/pytorch-image-models`') + matcher = self.trunk.group_matcher() + gparams = group_parameters(self.trunk, matcher) + max_layer_id = max(gparams.keys()) + max_layer_id = max_layer_id - unlocked_groups + for group_idx in range(max_layer_id + 1): + group = gparams[group_idx] + for param in group: + self.trunk.get_parameter(param).requires_grad = False + if freeze_bn_stats: + gmodules = group_modules(self.trunk, matcher, reverse=True) + gmodules = {k for k, v in gmodules.items() if v <= max_layer_id} + freeze_batch_norm_2d(self.trunk, gmodules) + + @torch.jit.ignore + def set_grad_checkpointing(self, enable=True): + try: + self.trunk.set_grad_checkpointing(enable) + except Exception as e: + logging.warning('grad checkpointing not supported for this timm image tower, continuing without...') + + def forward(self, x): + x = self.trunk(x) + x = self.head(x) + return x diff --git a/open_clip_local/tokenizer.py b/open_clip_local/tokenizer.py new file mode 100644 index 0000000000000000000000000000000000000000..3b762c2fad79c473ae6f254f9ac2f7511000e3dd --- /dev/null +++ b/open_clip_local/tokenizer.py @@ -0,0 +1,517 @@ +""" CLIP tokenizer + +Copied from https://github.com/openai/CLIP. Originally MIT License, Copyright (c) 2021 OpenAI. +""" +import gzip +import html +import os +import random +import string +from functools import lru_cache, partial +from typing import Callable, List, Optional, Union +import warnings + +import ftfy +import numpy as np +import regex as re +import torch + +# https://stackoverflow.com/q/62691279 +os.environ["TOKENIZERS_PARALLELISM"] = "false" +_nltk_init = False + +DEFAULT_CONTEXT_LENGTH = 77 # default context length for OpenAI CLIP + + +@lru_cache() +def default_bpe(): + return os.path.join(os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz") + + +@lru_cache() +def bytes_to_unicode(): + """ + Returns list of utf-8 byte and a corresponding list of unicode strings. + The reversible bpe codes work on unicode strings. + This means you need a large # of unicode characters in your vocab if you want to avoid UNKs. + When you're at something like a 10B token dataset you end up needing around 5K for decent coverage. + This is a significant percentage of your normal, say, 32K bpe vocab. + To avoid that, we want lookup tables between utf-8 bytes and unicode strings. + And avoids mapping to whitespace/control characters the bpe code barfs on. + """ + bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1)) + cs = bs[:] + n = 0 + for b in range(2**8): + if b not in bs: + bs.append(b) + cs.append(2**8+n) + n += 1 + cs = [chr(n) for n in cs] + return dict(zip(bs, cs)) + + +def get_pairs(word): + """Return set of symbol pairs in a word. + Word is represented as tuple of symbols (symbols being variable-length strings). + """ + pairs = set() + prev_char = word[0] + for char in word[1:]: + pairs.add((prev_char, char)) + prev_char = char + return pairs + + +def basic_clean(text): + text = ftfy.fix_text(text) + text = html.unescape(html.unescape(text)) + return text.strip() + + +def whitespace_clean(text): + text = " ".join(text.split()) + text = text.strip() + return text + + +def _clean_canonicalize(x): + # basic, remove whitespace, remove punctuation, lower case + return canonicalize_text(basic_clean(x)) + + +def _clean_lower(x): + # basic, remove whitespace, lower case + return whitespace_clean(basic_clean(x)).lower() + + +def _clean_whitespace(x): + # basic, remove whitespace + return whitespace_clean(basic_clean(x)) + + +def get_clean_fn(type: str): + if type == 'canonicalize': + return _clean_canonicalize + elif type == 'lower': + return _clean_lower + elif type == 'whitespace': + return _clean_whitespace + else: + assert False, f"Invalid clean function ({type})." + + +def canonicalize_text( + text, + *, + keep_punctuation_exact_string=None, + trans_punctuation: dict = str.maketrans("", "", string.punctuation), +): + """Returns canonicalized `text` (lowercase and punctuation removed). + + From: https://github.com/google-research/big_vision/blob/53f18caf27a9419231bbf08d3388b07671616d3d/big_vision/evaluators/proj/image_text/prompt_engineering.py#L94 + + Args: + text: string to be canonicalized. + keep_punctuation_exact_string: If provided, then this exact string kept. + For example providing '{}' will keep any occurrences of '{}' (but will + still remove '{' and '}' that appear separately). + """ + text = text.replace("_", " ") + if keep_punctuation_exact_string: + text = keep_punctuation_exact_string.join( + part.translate(trans_punctuation) + for part in text.split(keep_punctuation_exact_string) + ) + else: + text = text.translate(trans_punctuation) + text = text.lower() + text = " ".join(text.split()) + return text.strip() + + +class SimpleTokenizer(object): + def __init__( + self, + bpe_path: str = default_bpe(), + additional_special_tokens: Optional[List[str]] = None, + context_length: Optional[int] = DEFAULT_CONTEXT_LENGTH, + clean: str = 'lower', + reduction_mask: str = '' + ): + self.byte_encoder = bytes_to_unicode() + self.byte_decoder = {v: k for k, v in self.byte_encoder.items()} + merges = gzip.open(bpe_path).read().decode("utf-8").split('\n') + merges = merges[1:49152-256-2+1] + merges = [tuple(merge.split()) for merge in merges] + vocab = list(bytes_to_unicode().values()) + vocab = vocab + [v+'' for v in vocab] + for merge in merges: + vocab.append(''.join(merge)) + special_tokens = ['', ''] + if additional_special_tokens: + special_tokens += additional_special_tokens + vocab.extend(special_tokens) + self.encoder = dict(zip(vocab, range(len(vocab)))) + self.decoder = {v: k for k, v in self.encoder.items()} + self.bpe_ranks = dict(zip(merges, range(len(merges)))) + self.cache = {t:t for t in special_tokens} + special = "|".join(special_tokens) + self.pat = re.compile( + special + r"""|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", + re.IGNORECASE, + ) + self.vocab_size = len(self.encoder) + self.all_special_ids = [self.encoder[t] for t in special_tokens] + self.sot_token_id = self.all_special_ids[0] + self.eot_token_id = self.all_special_ids[1] + self.context_length = context_length + self.clean_fn = get_clean_fn(clean) + self.reduction_fn = get_reduction_mask_fn(reduction_mask) if reduction_mask else None + + def bpe(self, token): + if token in self.cache: + return self.cache[token] + word = tuple(token[:-1]) + ( token[-1] + '',) + pairs = get_pairs(word) + + if not pairs: + return token+'' + + while True: + bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf'))) + if bigram not in self.bpe_ranks: + break + first, second = bigram + new_word = [] + i = 0 + while i < len(word): + try: + j = word.index(first, i) + new_word.extend(word[i:j]) + i = j + except Exception: + new_word.extend(word[i:]) + break + + if word[i] == first and i < len(word)-1 and word[i+1] == second: + new_word.append(first+second) + i += 2 + else: + new_word.append(word[i]) + i += 1 + new_word = tuple(new_word) + word = new_word + if len(word) == 1: + break + else: + pairs = get_pairs(word) + word = ' '.join(word) + self.cache[token] = word + return word + + def encode(self, text): + bpe_tokens = [] + text = self.clean_fn(text) + for token in re.findall(self.pat, text): + token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8')) + bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' ')) + return bpe_tokens + + def decode(self, tokens): + text = ''.join([self.decoder[token] for token in tokens]) + text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('', ' ') + return text + + def __call__(self, texts: Union[str, List[str]], context_length: Optional[int] = None) -> torch.LongTensor: + """ Returns the tokenized representation of given input string(s) + + Parameters + ---------- + texts : Union[str, List[str]] + An input string or a list of input strings to tokenize + context_length : int + The context length to use; all CLIP models use 77 as the context length + + Returns + ------- + A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length] + """ + if isinstance(texts, str): + texts = [texts] + + context_length = context_length or self.context_length + assert context_length, 'Please set a valid context length' + + if self.reduction_fn is not None: + # use reduction strategy for tokenize if set, otherwise default to truncation below + return self.reduction_fn( + texts, + context_length=context_length, + sot_token_id=self.sot_token_id, + eot_token_id=self.eot_token_id, + encode_fn=self.encode, + ) + + all_tokens = [[self.sot_token_id] + self.encode(text) + [self.eot_token_id] for text in texts] + result = torch.zeros(len(all_tokens), context_length, dtype=torch.long) + + for i, tokens in enumerate(all_tokens): + if len(tokens) > context_length: + tokens = tokens[:context_length] # Truncate + tokens[-1] = self.eot_token_id + result[i, :len(tokens)] = torch.tensor(tokens) + + return result + + +_tokenizer = SimpleTokenizer() + + +def decode(output_ids: torch.Tensor): + output_ids = output_ids.cpu().numpy() + return _tokenizer.decode(output_ids) + + +def tokenize(texts: Union[str, List[str]], context_length: int = DEFAULT_CONTEXT_LENGTH) -> torch.LongTensor: + return _tokenizer(texts, context_length=context_length) + + +def random_mask_tokenize( + texts: Union[str, List[str]], + context_length: int, + sot_token_id: int, + eot_token_id: int, + encode_fn: Callable, + shuffle: bool = False, +): + all_tokens = [encode_fn(text) for text in texts] + result = torch.zeros(len(all_tokens), context_length, dtype=torch.long) + + for i, tokens in enumerate(all_tokens): + tokens = torch.tensor(tokens) + num_tokens = len(tokens) + if num_tokens > context_length - 2: # 2 for sot and eot token + num_keep = context_length - 2 + indices = torch.randperm(len(tokens)) + indices = indices[:num_keep] + if not shuffle: + indices = indices.msort() + tokens = tokens[indices] + num_tokens = num_keep + result[i, 0] = sot_token_id + result[i, 1:num_tokens + 1] = tokens + result[i, num_tokens + 1] = eot_token_id + + return result + + +def simple_mask_tokenize( + texts: Union[str, List[str]], + context_length: int, + sot_token_id: int, + eot_token_id: int, + encode_fn: Callable, +): + all_tokens = [encode_fn(text) for text in texts] + result = torch.zeros(len(all_tokens), context_length, dtype=torch.long) + + for i, tokens in enumerate(all_tokens): + num_tokens = len(tokens) + if num_tokens > context_length - 2: # 2 for sot and eot token + num_keep = context_length - 2 + start_index = random.randint(0, num_tokens - num_keep) # high is incl + tokens = tokens[start_index: start_index + num_keep] + tokens = [sot_token_id] + tokens + [eot_token_id] + result[i, :len(tokens)] = torch.tensor(tokens) + + return result + + +def syntax_mask_tokenize( + texts: Union[str, List[str]], + context_length: int, + sot_token_id: int, + eot_token_id: int, + encode_fn: Callable, +) -> torch.LongTensor: + """ Returns the tokenized representation of given input string(s). + Apply syntax masking before tokenize. + """ + import nltk + global _nltk_init + if not _nltk_init: + # run them for the first time + nltk.download('punkt') + nltk.download('averaged_perceptron_tagger') + _nltk_init = True + + def get_order(x): + if x.startswith('NN'): + return 1 + elif x.startswith('JJ'): + return 2 + elif x.startswith('VB'): + return 3 + else: + return 4 + + # syntax masking + new_texts = [] + for text in texts: + list_tokens = nltk.tokenize.word_tokenize(text) + pos_tags = nltk.pos_tag(list_tokens) + # sample the words by get_order method + order_list = [get_order(tag) for _, tag in pos_tags] + sorted_ids = np.argsort(np.array(order_list)) + sampled_ids = sorted(sorted_ids[:context_length - 2]) # need 2 slots for sot and eot tokens + sampled_tokens = np.take(np.array(list_tokens), sampled_ids, axis=0) # sample the tokens + + new_text = '' + for token in sampled_tokens: + new_text = new_text + str(token) + ' ' + new_text = new_text.strip() + new_texts.append(new_text) + texts = new_texts + + all_tokens = [[sot_token_id] + encode_fn(text) + [eot_token_id] for text in texts] + result = torch.zeros(len(all_tokens), context_length, dtype=torch.long) + + for i, tokens in enumerate(all_tokens): + # still need first truncate because some words produces two tokens + if len(tokens) > context_length: + tokens = tokens[:context_length] # Truncate + tokens[-1] = eot_token_id + result[i, :len(tokens)] = torch.tensor(tokens) + + return result + + +def get_reduction_mask_fn(type: str): + """ Choose strategy for dropping (masking) tokens to achieve target context length""" + assert type in ('simple', 'random', 'shuffle', 'syntax') + if type == 'simple': + return simple_mask_tokenize # randomly select block [start:end] + elif type == 'random': + return random_mask_tokenize # randomly drop tokens (keep order) + elif type == 'shuffle': + return partial(random_mask_tokenize, shuffle=True) # randomly drop tokens (shuffle order) + elif type == 'syntax': + return syntax_mask_tokenize # randomly drop prioritized by syntax + + +class HFTokenizer: + """HuggingFace tokenizer wrapper""" + + def __init__( + self, + tokenizer_name: str, + context_length: Optional[int] = DEFAULT_CONTEXT_LENGTH, + clean: str = 'whitespace', + strip_sep_token: bool = False, + language: Optional[str] = None, + **kwargs + ): + from transformers import AutoTokenizer + self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name, **kwargs) + set_lang_fn = getattr(self.tokenizer, 'set_src_lang_special_tokens', None) + if callable(set_lang_fn): + self.set_lang_fn = set_lang_fn + if language is not None: + self.set_language(language) + self.context_length = context_length + self.clean_fn = get_clean_fn(clean) + self.strip_sep_token = strip_sep_token + + def save_pretrained(self, dest): + self.tokenizer.save_pretrained(dest) + + def __call__(self, texts: Union[str, List[str]], context_length: Optional[int] = None) -> torch.Tensor: + # same cleaning as for default tokenizer, except lowercasing + # adding lower (for case-sensitive tokenizers) will make it more robust but less sensitive to nuance + if isinstance(texts, str): + texts = [texts] + + context_length = context_length or self.context_length + assert context_length, 'Please set a valid context length in class init or call.' + + texts = [self.clean_fn(text) for text in texts] + input_ids = self.tokenizer.batch_encode_plus( + texts, + return_tensors='pt', + max_length=context_length, + padding='max_length', + truncation=True, + ).input_ids + + if self.strip_sep_token: + input_ids = torch.where( + input_ids == self.tokenizer.sep_token_id, + torch.zeros_like(input_ids), + input_ids, + ) + + return input_ids + + def set_language(self, src_lang): + if hasattr(self, 'set_lang_fn'): + self.set_lang_fn(src_lang) + else: + warnings.warn('Cannot set language for the tokenizer.') + + +class SigLipTokenizer: + """HuggingFace tokenizer wrapper for SigLIP T5 compatible sentencepiece vocabs + """ + VOCAB_FILES = { + # english, vocab_size=32_000 + "c4-en": "http://storage.googleapis.com/t5-data/vocabs/cc_en.32000/sentencepiece.model", + # used in multilingual models (mT5, PaLI), vocab_size=250_000 + "mc4": "http://storage.googleapis.com/t5-data/vocabs/mc4.250000.100extra/sentencepiece.model", + } + + def __init__( + self, + tokenizer_name: str, + context_length: Optional[int] = 64, + ): + from transformers import T5TokenizerFast + + if tokenizer_name in self.VOCAB_FILES: + # FIXME temporary hack? + import tempfile + + import fsspec + vocab_file = self.VOCAB_FILES[tokenizer_name] + with tempfile.NamedTemporaryFile('wb') as dst: + with fsspec.open(vocab_file, 'rb') as src: + dst.write(src.read()) + self.tokenizer = T5TokenizerFast(dst.name, legacy=False) + else: + self.tokenizer = T5TokenizerFast(tokenizer_name, legacy=False) + + self.tokenizer.pad_token_id = 1 + self.tokenizer.eos_token_id = 1 + self.context_length = context_length + + def save_pretrained(self, dest): + self.tokenizer.save_pretrained(dest) + + def __call__(self, texts: Union[str, List[str]], context_length: Optional[int] = None) -> torch.Tensor: + # same cleaning as for default tokenizer, except lowercasing + # adding lower (for case-sensitive tokenizers) will make it more robust but less sensitive to nuance + if isinstance(texts, str): + texts = [texts] + + context_length = context_length or self.context_length + assert context_length, 'Please set a valid context length in class init or call.' + + texts = [canonicalize_text(basic_clean(text)) for text in texts] + output = self.tokenizer( + texts, + return_tensors='pt', + max_length=context_length, + padding='max_length', + truncation=True, + ) + return output.input_ids diff --git a/open_clip_local/transform.py b/open_clip_local/transform.py new file mode 100644 index 0000000000000000000000000000000000000000..f83200e898a6e3662620974e2021fce7276b112b --- /dev/null +++ b/open_clip_local/transform.py @@ -0,0 +1,411 @@ +import numbers +import random +import warnings +from dataclasses import dataclass, asdict +from typing import Any, Dict, List, Optional, Sequence, Tuple, Union + +import torch +import torchvision.transforms.functional as F +from torchvision.transforms import Normalize, Compose, RandomResizedCrop, InterpolationMode, ToTensor, Resize, \ + CenterCrop, ColorJitter, Grayscale + +from .constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD +from .utils import to_2tuple + + +@dataclass +class PreprocessCfg: + size: Union[int, Tuple[int, int]] = 224 + mode: str = 'RGB' + mean: Tuple[float, ...] = OPENAI_DATASET_MEAN + std: Tuple[float, ...] = OPENAI_DATASET_STD + interpolation: str = 'bicubic' + resize_mode: str = 'shortest' + fill_color: int = 0 + + def __post_init__(self): + assert self.mode in ('RGB',) + + @property + def num_channels(self): + return 3 + + @property + def input_size(self): + return (self.num_channels,) + to_2tuple(self.size) + +_PREPROCESS_KEYS = set(asdict(PreprocessCfg()).keys()) + + +def merge_preprocess_dict( + base: Union[PreprocessCfg, Dict], + overlay: Dict, +): + """ Merge overlay key-value pairs on top of base preprocess cfg or dict. + Input dicts are filtered based on PreprocessCfg fields. + """ + if isinstance(base, PreprocessCfg): + base_clean = asdict(base) + else: + base_clean = {k: v for k, v in base.items() if k in _PREPROCESS_KEYS} + if overlay: + overlay_clean = {k: v for k, v in overlay.items() if k in _PREPROCESS_KEYS and v is not None} + base_clean.update(overlay_clean) + return base_clean + + +def merge_preprocess_kwargs(base: PreprocessCfg, **kwargs): + return merge_preprocess_dict(base, kwargs) + + +@dataclass +class AugmentationCfg: + scale: Tuple[float, float] = (0.9, 1.0) + ratio: Optional[Tuple[float, float]] = None + color_jitter: Optional[Union[float, Tuple[float, float, float], Tuple[float, float, float, float]]] = None + re_prob: Optional[float] = None + re_count: Optional[int] = None + use_timm: bool = False + + # params for simclr_jitter_gray + color_jitter_prob: float = None + gray_scale_prob: float = None + + +def _setup_size(size, error_msg): + if isinstance(size, numbers.Number): + return int(size), int(size) + + if isinstance(size, Sequence) and len(size) == 1: + return size[0], size[0] + + if len(size) != 2: + raise ValueError(error_msg) + + return size + + +class ResizeKeepRatio: + """ Resize and Keep Ratio + + Copy & paste from `timm` + """ + + def __init__( + self, + size, + longest=0., + interpolation=InterpolationMode.BICUBIC, + random_scale_prob=0., + random_scale_range=(0.85, 1.05), + random_aspect_prob=0., + random_aspect_range=(0.9, 1.11) + ): + if isinstance(size, (list, tuple)): + self.size = tuple(size) + else: + self.size = (size, size) + self.interpolation = interpolation + self.longest = float(longest) # [0, 1] where 0 == shortest edge, 1 == longest + self.random_scale_prob = random_scale_prob + self.random_scale_range = random_scale_range + self.random_aspect_prob = random_aspect_prob + self.random_aspect_range = random_aspect_range + + @staticmethod + def get_params( + img, + target_size, + longest, + random_scale_prob=0., + random_scale_range=(0.85, 1.05), + random_aspect_prob=0., + random_aspect_range=(0.9, 1.11) + ): + """Get parameters + """ + source_size = img.size[::-1] # h, w + h, w = source_size + target_h, target_w = target_size + ratio_h = h / target_h + ratio_w = w / target_w + ratio = max(ratio_h, ratio_w) * longest + min(ratio_h, ratio_w) * (1. - longest) + if random_scale_prob > 0 and random.random() < random_scale_prob: + ratio_factor = random.uniform(random_scale_range[0], random_scale_range[1]) + ratio_factor = (ratio_factor, ratio_factor) + else: + ratio_factor = (1., 1.) + if random_aspect_prob > 0 and random.random() < random_aspect_prob: + aspect_factor = random.uniform(random_aspect_range[0], random_aspect_range[1]) + ratio_factor = (ratio_factor[0] / aspect_factor, ratio_factor[1] * aspect_factor) + size = [round(x * f / ratio) for x, f in zip(source_size, ratio_factor)] + return size + + def __call__(self, img): + """ + Args: + img (PIL Image): Image to be cropped and resized. + + Returns: + PIL Image: Resized, padded to at least target size, possibly cropped to exactly target size + """ + size = self.get_params( + img, self.size, self.longest, + self.random_scale_prob, self.random_scale_range, + self.random_aspect_prob, self.random_aspect_range + ) + img = F.resize(img, size, self.interpolation) + return img + + def __repr__(self): + format_string = self.__class__.__name__ + '(size={0}'.format(self.size) + format_string += f', interpolation={self.interpolation})' + format_string += f', longest={self.longest:.3f})' + return format_string + + +def center_crop_or_pad(img: torch.Tensor, output_size: List[int], fill=0) -> torch.Tensor: + """Center crops and/or pads the given image. + If the image is torch Tensor, it is expected + to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions. + If image size is smaller than output size along any edge, image is padded with 0 and then center cropped. + + Args: + img (PIL Image or Tensor): Image to be cropped. + output_size (sequence or int): (height, width) of the crop box. If int or sequence with single int, + it is used for both directions. + fill (int, Tuple[int]): Padding color + + Returns: + PIL Image or Tensor: Cropped image. + """ + if isinstance(output_size, numbers.Number): + output_size = (int(output_size), int(output_size)) + elif isinstance(output_size, (tuple, list)) and len(output_size) == 1: + output_size = (output_size[0], output_size[0]) + + _, image_height, image_width = F.get_dimensions(img) + crop_height, crop_width = output_size + + if crop_width > image_width or crop_height > image_height: + padding_ltrb = [ + (crop_width - image_width) // 2 if crop_width > image_width else 0, + (crop_height - image_height) // 2 if crop_height > image_height else 0, + (crop_width - image_width + 1) // 2 if crop_width > image_width else 0, + (crop_height - image_height + 1) // 2 if crop_height > image_height else 0, + ] + img = F.pad(img, padding_ltrb, fill=fill) + _, image_height, image_width = F.get_dimensions(img) + if crop_width == image_width and crop_height == image_height: + return img + + crop_top = int(round((image_height - crop_height) / 2.0)) + crop_left = int(round((image_width - crop_width) / 2.0)) + return F.crop(img, crop_top, crop_left, crop_height, crop_width) + + +class CenterCropOrPad(torch.nn.Module): + """Crops the given image at the center. + If the image is torch Tensor, it is expected + to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions. + If image size is smaller than output size along any edge, image is padded with 0 and then center cropped. + + Args: + size (sequence or int): Desired output size of the crop. If size is an + int instead of sequence like (h, w), a square crop (size, size) is + made. If provided a sequence of length 1, it will be interpreted as (size[0], size[0]). + """ + + def __init__(self, size, fill=0): + super().__init__() + self.size = _setup_size(size, error_msg="Please provide only two dimensions (h, w) for size.") + self.fill = fill + + def forward(self, img): + """ + Args: + img (PIL Image or Tensor): Image to be cropped. + + Returns: + PIL Image or Tensor: Cropped image. + """ + return center_crop_or_pad(img, self.size, fill=self.fill) + + def __repr__(self) -> str: + return f"{self.__class__.__name__}(size={self.size})" + + +def _convert_to_rgb(image): + return image.convert('RGB') + + +class color_jitter(object): + """ + Apply Color Jitter to the PIL image with a specified probability. + """ + def __init__(self, brightness=0., contrast=0., saturation=0., hue=0., p=0.8): + assert 0. <= p <= 1. + self.p = p + self.transf = ColorJitter(brightness=brightness, contrast=contrast, saturation=saturation, hue=hue) + + def __call__(self, img): + if random.random() < self.p: + return self.transf(img) + else: + return img + + +class gray_scale(object): + """ + Apply Gray Scale to the PIL image with a specified probability. + """ + def __init__(self, p=0.2): + assert 0. <= p <= 1. + self.p = p + self.transf = Grayscale(num_output_channels=3) + + def __call__(self, img): + if random.random() < self.p: + return self.transf(img) + else: + return img + + +def image_transform( + image_size: Union[int, Tuple[int, int]], + is_train: bool, + mean: Optional[Tuple[float, ...]] = None, + std: Optional[Tuple[float, ...]] = None, + resize_mode: Optional[str] = None, + interpolation: Optional[str] = None, + fill_color: int = 0, + aug_cfg: Optional[Union[Dict[str, Any], AugmentationCfg]] = None, +): + mean = mean or OPENAI_DATASET_MEAN + if not isinstance(mean, (list, tuple)): + mean = (mean,) * 3 + + std = std or OPENAI_DATASET_STD + if not isinstance(std, (list, tuple)): + std = (std,) * 3 + + interpolation = interpolation or 'bicubic' + assert interpolation in ['bicubic', 'bilinear', 'random'] + # NOTE random is ignored for interpolation_mode, so defaults to BICUBIC for inference if set + interpolation_mode = InterpolationMode.BILINEAR if interpolation == 'bilinear' else InterpolationMode.BICUBIC + + resize_mode = resize_mode or 'shortest' + assert resize_mode in ('shortest', 'longest', 'squash') + + if isinstance(aug_cfg, dict): + aug_cfg = AugmentationCfg(**aug_cfg) + else: + aug_cfg = aug_cfg or AugmentationCfg() + + normalize = Normalize(mean=mean, std=std) + + if is_train: + aug_cfg_dict = {k: v for k, v in asdict(aug_cfg).items() if v is not None} + use_timm = aug_cfg_dict.pop('use_timm', False) + if use_timm: + from timm.data import create_transform # timm can still be optional + if isinstance(image_size, (tuple, list)): + assert len(image_size) >= 2 + input_size = (3,) + image_size[-2:] + else: + input_size = (3, image_size, image_size) + + aug_cfg_dict.setdefault('color_jitter', None) # disable by default + # drop extra non-timm items + aug_cfg_dict.pop('color_jitter_prob', None) + aug_cfg_dict.pop('gray_scale_prob', None) + + train_transform = create_transform( + input_size=input_size, + is_training=True, + hflip=0., + mean=mean, + std=std, + re_mode='pixel', + interpolation=interpolation, + **aug_cfg_dict, + ) + else: + train_transform = [ + RandomResizedCrop( + image_size, + scale=aug_cfg_dict.pop('scale'), + interpolation=InterpolationMode.BICUBIC, + ), + _convert_to_rgb, + ] + if aug_cfg.color_jitter_prob: + assert aug_cfg.color_jitter is not None and len(aug_cfg.color_jitter) == 4 + train_transform.extend([ + color_jitter(*aug_cfg.color_jitter, p=aug_cfg.color_jitter_prob) + ]) + if aug_cfg.gray_scale_prob: + train_transform.extend([ + gray_scale(aug_cfg.gray_scale_prob) + ]) + train_transform.extend([ + ToTensor(), + normalize, + ]) + train_transform = Compose(train_transform) + if aug_cfg_dict: + warnings.warn(f'Unused augmentation cfg items, specify `use_timm` to use ({list(aug_cfg_dict.keys())}).') + return train_transform + else: + if resize_mode == 'longest': + transforms = [ + ResizeKeepRatio(image_size, interpolation=interpolation_mode, longest=1), + CenterCropOrPad(image_size, fill=fill_color) + ] + elif resize_mode == 'squash': + if isinstance(image_size, int): + image_size = (image_size, image_size) + transforms = [ + Resize((256, 256), interpolation=interpolation_mode), + Resize((448, 448), interpolation=interpolation_mode), + # Resize(image_size, interpolation=interpolation_mode), + # Resize((256, 256), interpolation=interpolation_mode), + # CenterCrop((224, 224)), + ] + else: + assert resize_mode == 'shortest' + if not isinstance(image_size, (tuple, list)): + image_size = (image_size, image_size) + if image_size[0] == image_size[1]: + # simple case, use torchvision built-in Resize w/ shortest edge mode (scalar size arg) + transforms = [ + Resize(image_size[0], interpolation=interpolation_mode) + ] + else: + # resize shortest edge to matching target dim for non-square target + transforms = [ResizeKeepRatio(image_size)] + transforms += [CenterCrop(image_size)] + + transforms.extend([ + _convert_to_rgb, + ToTensor(), + normalize, + ]) + return Compose(transforms) + + +def image_transform_v2( + cfg: PreprocessCfg, + is_train: bool, + aug_cfg: Optional[Union[Dict[str, Any], AugmentationCfg]] = None, +): + return image_transform( + image_size=cfg.size, + is_train=is_train, + mean=cfg.mean, + std=cfg.std, + interpolation=cfg.interpolation, + resize_mode=cfg.resize_mode, + fill_color=cfg.fill_color, + aug_cfg=aug_cfg, + ) diff --git a/open_clip_local/transformer.py b/open_clip_local/transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..000da3e9b857575e1e282a3e6869edd19f84f759 --- /dev/null +++ b/open_clip_local/transformer.py @@ -0,0 +1,913 @@ +from collections import OrderedDict +import math +import numpy as np +from typing import Callable, Optional, Sequence, Tuple +from functools import partial + +import torch +from torch import nn +from torch.nn import functional as F +from torch.utils.checkpoint import checkpoint + +from .utils import to_2tuple +from .pos_embed import get_2d_sincos_pos_embed + + +class LayerNormFp32(nn.LayerNorm): + """Subclass torch's LayerNorm to handle fp16 (by casting to float32 and back).""" + + def forward(self, x: torch.Tensor): + orig_type = x.dtype + x = F.layer_norm(x.to(torch.float32), self.normalized_shape, self.weight, self.bias, self.eps) + return x.to(orig_type) + + +class LayerNorm(nn.LayerNorm): + """Subclass torch's LayerNorm (with cast back to input dtype).""" + + def forward(self, x: torch.Tensor): + orig_type = x.dtype + x = F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps) + return x.to(orig_type) + + +class QuickGELU(nn.Module): + # NOTE This is slower than nn.GELU or nn.SiLU and uses more GPU memory + def forward(self, x: torch.Tensor): + return x * torch.sigmoid(1.702 * x) + + +class LayerScale(nn.Module): + def __init__(self, dim, init_values=1e-5, inplace=False): + super().__init__() + self.inplace = inplace + self.gamma = nn.Parameter(init_values * torch.ones(dim)) + + def forward(self, x): + return x.mul_(self.gamma) if self.inplace else x * self.gamma + + +class PatchDropout(nn.Module): + """ + https://arxiv.org/abs/2212.00794 + """ + + def __init__(self, prob, exclude_first_token=True): + super().__init__() + assert 0 <= prob < 1. + self.prob = prob + self.exclude_first_token = exclude_first_token # exclude CLS token + + def forward(self, x): + if not self.training or self.prob == 0.: + return x + + if self.exclude_first_token: + cls_tokens, x = x[:, :1], x[:, 1:] + else: + cls_tokens = torch.jit.annotate(torch.Tensor, x[:, :1]) + + batch = x.size()[0] + num_tokens = x.size()[1] + + batch_indices = torch.arange(batch) + batch_indices = batch_indices[..., None] + + keep_prob = 1 - self.prob + num_patches_keep = max(1, int(num_tokens * keep_prob)) + + rand = torch.randn(batch, num_tokens) + patch_indices_keep = rand.topk(num_patches_keep, dim=-1).indices + + x = x[batch_indices, patch_indices_keep] + + if self.exclude_first_token: + x = torch.cat((cls_tokens, x), dim=1) + + return x + + +class Attention(nn.Module): + def __init__( + self, + dim, + num_heads=8, + qkv_bias=True, + scaled_cosine=False, + scale_heads=False, + logit_scale_max=math.log(1. / 0.01), + attn_drop=0., + proj_drop=0. + ): + super().__init__() + self.scaled_cosine = scaled_cosine + self.scale_heads = scale_heads + assert dim % num_heads == 0, 'dim should be divisible by num_heads' + self.num_heads = num_heads + self.head_dim = dim // num_heads + self.scale = self.head_dim ** -0.5 + self.logit_scale_max = logit_scale_max + + # keeping in_proj in this form (instead of nn.Linear) to match weight scheme of original + self.in_proj_weight = nn.Parameter(torch.randn((dim * 3, dim)) * self.scale) + if qkv_bias: + self.in_proj_bias = nn.Parameter(torch.zeros(dim * 3)) + else: + self.in_proj_bias = None + + if self.scaled_cosine: + self.logit_scale = nn.Parameter(torch.log(10 * torch.ones((num_heads, 1, 1)))) + else: + self.logit_scale = None + self.attn_drop = nn.Dropout(attn_drop) + if self.scale_heads: + self.head_scale = nn.Parameter(torch.ones((num_heads, 1, 1))) + else: + self.head_scale = None + self.out_proj = nn.Linear(dim, dim) + self.out_drop = nn.Dropout(proj_drop) + + def forward(self, x, attn_mask: Optional[torch.Tensor] = None): + L, N, C = x.shape + q, k, v = F.linear(x, self.in_proj_weight, self.in_proj_bias).chunk(3, dim=-1) + q = q.contiguous().view(L, N * self.num_heads, -1).transpose(0, 1) + k = k.contiguous().view(L, N * self.num_heads, -1).transpose(0, 1) + v = v.contiguous().view(L, N * self.num_heads, -1).transpose(0, 1) + + if self.logit_scale is not None: + attn = torch.bmm(F.normalize(q, dim=-1), F.normalize(k, dim=-1).transpose(-1, -2)) + logit_scale = torch.clamp(self.logit_scale, max=self.logit_scale_max).exp() + attn = attn.view(N, self.num_heads, L, L) * logit_scale + attn = attn.view(-1, L, L) + else: + q = q * self.scale + attn = torch.bmm(q, k.transpose(-1, -2)) + + if attn_mask is not None: + if attn_mask.dtype == torch.bool: + new_attn_mask = torch.zeros_like(attn_mask, dtype=q.dtype) + new_attn_mask.masked_fill_(attn_mask, float("-inf")) + attn_mask = new_attn_mask + attn += attn_mask + + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + + x = torch.bmm(attn, v) + if self.head_scale is not None: + x = x.view(N, self.num_heads, L, C) * self.head_scale + x = x.view(-1, L, C) + x = x.transpose(0, 1).reshape(L, N, C) + x = self.out_proj(x) + x = self.out_drop(x) + return x + + +class AttentionalPooler(nn.Module): + def __init__( + self, + d_model: int, + context_dim: int, + n_head: int = 8, + n_queries: int = 256, + norm_layer: Callable = LayerNorm + ): + super().__init__() + self.query = nn.Parameter(torch.randn(n_queries, d_model)) + self.attn = nn.MultiheadAttention(d_model, n_head, kdim=context_dim, vdim=context_dim) + self.ln_q = norm_layer(d_model) + self.ln_k = norm_layer(context_dim) + + def forward(self, x: torch.Tensor): + x = self.ln_k(x).permute(1, 0, 2) # NLD -> LND + N = x.shape[1] + q = self.ln_q(self.query) + out = self.attn(q.unsqueeze(1).expand(-1, N, -1), x, x, need_weights=False)[0] + return out.permute(1, 0, 2) # LND -> NLD + + +class ResidualAttentionBlock(nn.Module): + def __init__( + self, + d_model: int, + n_head: int, + mlp_ratio: float = 4.0, + ls_init_value: float = None, + act_layer: Callable = nn.GELU, + norm_layer: Callable = LayerNorm, + is_cross_attention: bool = False, + ): + super().__init__() + + self.ln_1 = norm_layer(d_model) + self.attn = nn.MultiheadAttention(d_model, n_head) + self.ls_1 = LayerScale(d_model, ls_init_value) if ls_init_value is not None else nn.Identity() + if is_cross_attention: + self.ln_1_kv = norm_layer(d_model) + + self.ln_2 = norm_layer(d_model) + mlp_width = int(d_model * mlp_ratio) + self.mlp = nn.Sequential(OrderedDict([ + ("c_fc", nn.Linear(d_model, mlp_width)), + ("gelu", act_layer()), + ("c_proj", nn.Linear(mlp_width, d_model)) + ])) + self.ls_2 = LayerScale(d_model, ls_init_value) if ls_init_value is not None else nn.Identity() + + def attention( + self, + q_x: torch.Tensor, + k_x: Optional[torch.Tensor] = None, + v_x: Optional[torch.Tensor] = None, + attn_mask: Optional[torch.Tensor] = None, + ): + k_x = k_x if k_x is not None else q_x + v_x = v_x if v_x is not None else q_x + + attn_mask = attn_mask.to(q_x.dtype) if attn_mask is not None else None + return self.attn( + q_x, k_x, v_x, need_weights=True, attn_mask=attn_mask # need_weights=False, + ) + + def forward( + self, + q_x: torch.Tensor, + k_x: Optional[torch.Tensor] = None, + v_x: Optional[torch.Tensor] = None, + attn_mask: Optional[torch.Tensor] = None, + ): + k_x = self.ln_1_kv(k_x) if hasattr(self, "ln_1_kv") and k_x is not None else None + v_x = self.ln_1_kv(v_x) if hasattr(self, "ln_1_kv") and v_x is not None else None + + tmp, attn = self.attention(q_x=self.ln_1(q_x), k_x=k_x, v_x=v_x, attn_mask=attn_mask) + x = q_x + self.ls_1(tmp) + x = x + self.ls_2(self.mlp(self.ln_2(x))) + return x, attn + + +class CustomResidualAttentionBlock(nn.Module): + def __init__( + self, + d_model: int, + n_head: int, + mlp_ratio: float = 4.0, + ls_init_value: float = None, + act_layer: Callable = nn.GELU, + norm_layer: Callable = LayerNorm, + scale_cosine_attn: bool = False, + scale_heads: bool = False, + scale_attn: bool = False, + scale_fc: bool = False, + ): + super().__init__() + + self.ln_1 = norm_layer(d_model) + self.attn = Attention( + d_model, n_head, + scaled_cosine=scale_cosine_attn, + scale_heads=scale_heads, + ) + self.ln_attn = norm_layer(d_model) if scale_attn else nn.Identity() + self.ls_1 = LayerScale(d_model, ls_init_value) if ls_init_value is not None else nn.Identity() + + self.ln_2 = norm_layer(d_model) + mlp_width = int(d_model * mlp_ratio) + self.mlp = nn.Sequential(OrderedDict([ + ("c_fc", nn.Linear(d_model, mlp_width)), + ("gelu", act_layer()), + ('ln', norm_layer(mlp_width) if scale_fc else nn.Identity()), + ("c_proj", nn.Linear(mlp_width, d_model)) + ])) + self.ls_2 = LayerScale(d_model, ls_init_value) if ls_init_value is not None else nn.Identity() + + def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None): + x = x + self.ls_1(self.ln_attn(self.attn(self.ln_1(x), attn_mask=attn_mask))) + x = x + self.ls_2(self.mlp(self.ln_2(x))) + return x + + +def _expand_token(token, batch_size: int): + return token.view(1, 1, -1).expand(batch_size, -1, -1) + + +class Transformer(nn.Module): + def __init__( + self, + width: int, + layers: int, + heads: int, + mlp_ratio: float = 4.0, + ls_init_value: float = None, + act_layer: Callable = nn.GELU, + norm_layer: Callable = LayerNorm, + ): + super().__init__() + self.width = width + self.layers = layers + self.grad_checkpointing = False + + self.resblocks = nn.ModuleList([ + ResidualAttentionBlock( + width, heads, mlp_ratio, ls_init_value=ls_init_value, act_layer=act_layer, norm_layer=norm_layer) + for _ in range(layers) + ]) + self.addition_cache = dict() + self.arch, self.attn_strategy, self.gaussian_std = 'reduced', 'naclip', 5.0 + + def get_cast_dtype(self) -> torch.dtype: + if hasattr(self.resblocks[0].mlp.c_fc, 'int8_original_dtype'): + return self.resblocks[0].mlp.c_fc.int8_original_dtype + return self.resblocks[0].mlp.c_fc.weight.dtype + + + def custom_attn(self, attn_layer, x, return_attn=False, with_attn=False, n_patches=None): + num_heads = attn_layer.num_heads + num_tokens, bsz, embed_dim = x.size() + head_dim = embed_dim // num_heads + scale = head_dim ** -0.5 + + q, k, v = F.linear(x, attn_layer.in_proj_weight, attn_layer.in_proj_bias).chunk(3, dim=-1) + q = q.contiguous().view(-1, bsz * num_heads, head_dim).transpose(0, 1) + k = k.contiguous().view(-1, bsz * num_heads, head_dim).transpose(0, 1) + v = v.contiguous().view(-1, bsz * num_heads, head_dim).transpose(0, 1) + + if self.attn_strategy in ['naclip', 'nonly', 'gav']: + if n_patches is None: # Assume a rectangular image + n_patches = 2 * (int((num_tokens - 1) ** 0.5),) + addition = self.addition_cache.get(n_patches) + if addition is None: + window_size = [side * 2 - 1 for side in n_patches] + window = Transformer.gaussian_window(*window_size, std=self.gaussian_std) + addition = Transformer.get_attention_addition(*n_patches, window).unsqueeze(0).to(x.dtype).to(x.device) + self.addition_cache[n_patches] = addition + omega = addition.clone() + + if self.attn_strategy == 'naclip': + attn_weights = torch.bmm(k, k.transpose(1, 2)) * scale + elif self.attn_strategy == 'nonly': + attn_weights = torch.zeros((num_heads, num_tokens, num_tokens)).to(x.dtype).to(x.device) + omega = omega * scale * torch.einsum('hop,hPO->hpP', q.norm(dim=2).unsqueeze(1), k.norm(dim=2).unsqueeze(2)).detach() + elif self.attn_strategy == 'gav': + attn_weights = torch.bmm(q, k.transpose(1, 2)) * scale + omega = omega * scale * torch.einsum('hop,hPO->hpP', q.norm(dim=2).unsqueeze(1), k.norm(dim=2).unsqueeze(2)).detach() + + else: + raise NotImplemented + + attn_weights += omega + attn_weights = F.softmax(attn_weights, dim=-1) + + elif self.attn_strategy == 'csa': + q_attn = torch.bmm(q, q.transpose(1, 2)) * scale + k_attn = torch.bmm(k, k.transpose(1, 2)) * scale + attn_weights = F.softmax(q_attn, dim=-1) + F.softmax(k_attn, dim=-1) + elif self.attn_strategy == 'vanilla': + attn_weights = torch.bmm(q * scale, k.transpose(1, 2)) + attn_weights = F.softmax(attn_weights, dim=-1) + else: + raise NotImplemented(f'attn_strategy {self.attn_strategy} is not implemented') + + if return_attn: + return attn_weights + + attn_output = torch.bmm(attn_weights, v) + attn_output = attn_output.transpose(0, 1).contiguous().view(-1, bsz, embed_dim) + attn_output = attn_layer.out_proj(attn_output) + + if with_attn: + return attn_output, attn_weights + + return attn_output + + @staticmethod + def gaussian_window(dim1, dim2, std=1.): + constant = 1 / (std * math.sqrt(2)) + ks = list() + for dim in [dim1, dim2]: + start = -(dim - 1) / 2.0 + k = torch.linspace(start=start * constant, + end=(start + (dim - 1)) * constant, + steps=dim, + dtype=torch.float) + ks.append(k) + dist_square_to_mu = (torch.stack(torch.meshgrid(*ks, indexing='ij')) ** 2).sum(0) + return torch.exp(-dist_square_to_mu) + + + @staticmethod + def get_attention_addition(dim1, dim2, window=None, adjust_for_cls=True): + m = torch.einsum('ij,kl->ijkl', torch.eye(dim1), torch.eye(dim2)) + m = m.permute((0, 3, 1, 2)).contiguous() # m[ijkl] = 1 iff (i, j) == (k, l) + out = F.conv2d(m.view(-1, dim1, dim2).unsqueeze(1), window.unsqueeze(0).unsqueeze(1), padding='same').squeeze(1) + out = out.view(dim1 * dim2, dim1 * dim2) + if adjust_for_cls: + v_adjusted = torch.vstack([torch.zeros((1, dim1 * dim2)), out]) + out = torch.hstack([torch.zeros((dim1 * dim2 + 1, 1)), v_adjusted]) + return out + + + def forward(self, x: torch.Tensor, out_layers: list = [3, 6, 9], attn_mask: Optional[torch.Tensor] = None): + idx = 0 + out_tokens = [] + out_attn = [] + for r in self.resblocks: + idx += 1 + if self.grad_checkpointing and not torch.jit.is_scripting(): + # TODO: handle kwargs https://github.com/pytorch/pytorch/issues/79887#issuecomment-1161758372 + x = checkpoint(r, x, None, None, attn_mask) + else: + if idx == 12: + x, attn = r(x, attn_mask=attn_mask) + out_attn.append(attn) + elif idx == len(self.resblocks) and x.size(0) != 77: + # x = self.custom_attn(r.attn, r.ln_1(x), n_patches=(16, 16)) + x = self.custom_attn(r.attn, r.ln_1(x), n_patches=(32, 32)) + else: + x, attn_tmp = r(x, attn_mask=attn_mask) + if idx in out_layers: + out_tokens.append(x) + return x, out_attn, out_tokens + + +class VisionTransformer(nn.Module): + output_tokens: torch.jit.Final[bool] + + def __init__( + self, + image_size: int, + patch_size: int, + width: int, + layers: int, + heads: int, + mlp_ratio: float, + ls_init_value: float = None, + attentional_pool: bool = False, + attn_pooler_queries: int = 256, + attn_pooler_heads: int = 8, + output_dim: int = 512, + patch_dropout: float = 0., + no_ln_pre: bool = False, + pos_embed_type: str = 'learnable', + pool_type: str = 'tok', + final_ln_after_pool: bool = False, + act_layer: Callable = nn.GELU, + norm_layer: Callable = LayerNorm, + output_tokens: bool = False, + ): + super().__init__() + assert pool_type in ('tok', 'avg', 'none') + self.output_tokens = output_tokens + image_height, image_width = self.image_size = to_2tuple(image_size) + patch_height, patch_width = self.patch_size = to_2tuple(patch_size) + self.grid_size = (image_height // patch_height, image_width // patch_width) + self.final_ln_after_pool = final_ln_after_pool # currently ignored w/ attn pool enabled + self.output_dim = output_dim + + self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False) + + # class embeddings and positional embeddings + scale = width ** -0.5 + self.class_embedding = nn.Parameter(scale * torch.randn(width)) + if pos_embed_type == 'learnable': + self.positional_embedding = nn.Parameter( + scale * torch.randn(self.grid_size[0] * self.grid_size[1] + 1, width)) + elif pos_embed_type == 'sin_cos_2d': + # fixed sin-cos embedding + assert self.grid_size[0] == self.grid_size[1],\ + 'currently sin cos 2d pos embedding only supports square input' + self.positional_embedding = nn.Parameter( + torch.zeros(self.grid_size[0] * self.grid_size[1] + 1, width), requires_grad=False) + pos_embed_type = get_2d_sincos_pos_embed(width, self.grid_size[0], cls_token=True) + self.positional_embedding.data.copy_(torch.from_numpy(pos_embed_type).float()) + else: + raise ValueError + + # setting a patch_dropout of 0. would mean it is disabled and this function would be the identity fn + self.patch_dropout = PatchDropout(patch_dropout) if patch_dropout > 0. else nn.Identity() + + self.ln_pre = nn.Identity() if no_ln_pre else norm_layer(width) + self.transformer = Transformer( + width, + layers, + heads, + mlp_ratio, + ls_init_value=ls_init_value, + act_layer=act_layer, + norm_layer=norm_layer, + ) + + if attentional_pool: + if isinstance(attentional_pool, str): + self.attn_pool_type = attentional_pool + self.pool_type = 'none' + if attentional_pool in ('parallel', 'cascade'): + self.attn_pool = AttentionalPooler( + output_dim, + width, + n_head=attn_pooler_heads, + n_queries=attn_pooler_queries, + ) + self.attn_pool_contrastive = AttentionalPooler( + output_dim, + width, + n_head=attn_pooler_heads, + n_queries=1, + ) + else: + assert False + else: + self.attn_pool_type = '' + self.pool_type = pool_type + self.attn_pool = AttentionalPooler( + output_dim, + width, + n_head=attn_pooler_heads, + n_queries=attn_pooler_queries, + ) + self.attn_pool_contrastive = None + pool_dim = output_dim + else: + self.attn_pool = None + pool_dim = width + self.pool_type = pool_type + + self.ln_post = norm_layer(pool_dim) + self.proj = nn.Parameter(scale * torch.randn(pool_dim, output_dim)) + + self.init_parameters() + + def lock(self, unlocked_groups=0, freeze_bn_stats=False): + for param in self.parameters(): + param.requires_grad = False + + if unlocked_groups != 0: + groups = [ + [ + self.conv1, + self.class_embedding, + self.positional_embedding, + self.ln_pre, + ], + *self.transformer.resblocks[:-1], + [ + self.transformer.resblocks[-1], + self.ln_post, + ], + self.proj, + ] + + def _unlock(x): + if isinstance(x, Sequence): + for g in x: + _unlock(g) + else: + if isinstance(x, torch.nn.Parameter): + x.requires_grad = True + else: + for p in x.parameters(): + p.requires_grad = True + + _unlock(groups[-unlocked_groups:]) + + def init_parameters(self): + # FIXME OpenAI CLIP did not define an init for the VisualTransformer + # TODO experiment if default PyTorch init, below, or alternate init is best. + + # nn.init.normal_(self.class_embedding, std=self.scale) + # nn.init.normal_(self.positional_embedding, std=self.scale) + # + # proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5) + # attn_std = self.transformer.width ** -0.5 + # fc_std = (2 * self.transformer.width) ** -0.5 + # for block in self.transformer.resblocks: + # nn.init.normal_(block.attn.in_proj_weight, std=attn_std) + # nn.init.normal_(block.attn.out_proj.weight, std=proj_std) + # nn.init.normal_(block.mlp.c_fc.weight, std=fc_std) + # nn.init.normal_(block.mlp.c_proj.weight, std=proj_std) + # + # if self.text_projection is not None: + # nn.init.normal_(self.text_projection, std=self.scale) + pass + + @torch.jit.ignore + def set_grad_checkpointing(self, enable=True): + self.transformer.grad_checkpointing = enable + + def _global_pool(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + if self.pool_type == 'avg': + pooled, tokens = x[:, 1:].mean(dim=1), x[:, 1:] + elif self.pool_type == 'tok': + pooled, tokens = x[:, 0], x[:, 1:] + else: + pooled = tokens = x + + return pooled, tokens + + def forward(self, x: torch.Tensor, out_layers: list): + x = self.conv1(x) # shape = [*, width, grid, grid] + x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2] + x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width] + + # class embeddings and positional embeddings + x = torch.cat([_expand_token(self.class_embedding, x.shape[0]).to(x.dtype), x], dim=1) + # shape = [*, grid ** 2 + 1, width] + x = x + self.positional_embedding.to(x.dtype) + + x = self.patch_dropout(x) + x = self.ln_pre(x) + + x = x.permute(1, 0, 2) # NLD -> LND + x, attn, patch_tokens = self.transformer(x, out_layers) + B, C, L = attn[0].shape + H = int(np.sqrt(L-1)) + out_attn = torch.zeros([H, H]).to('cuda') + for i in range(len(attn)): + out_attn += attn[i][0, 0, 1:].view(H, H) + x = x.permute(1, 0, 2) # LND -> NLD + patch_tokens = [patch_tokens[t].permute(1, 0, 2) for t in range(len(patch_tokens))] # LND -> NLD + + if self.attn_pool is not None: + if self.attn_pool_contrastive is not None: + # This is untested, WIP pooling that should match paper + x = self.ln_post(x) # TBD LN first or separate one after each pool? + tokens = self.attn_pool(x) + if self.attn_pool_type == 'parallel': + pooled = self.attn_pool_contrastive(x) + else: + assert self.attn_pool_type == 'cascade' + pooled = self.attn_pool_contrastive(tokens) + else: + # this is the original OpenCLIP CoCa setup, does not match paper + x = self.attn_pool(x) + x = self.ln_post(x) + pooled, tokens = self._global_pool(x) + elif self.final_ln_after_pool: + pooled, tokens = self._global_pool(x) + pooled = self.ln_post(pooled) + else: + x = self.ln_post(x) + pooled, tokens = self._global_pool(x) + + if self.proj is not None: + pooled = pooled @ self.proj + + if self.output_tokens: + return pooled, patch_tokens + + return pooled, patch_tokens, tokens @ self.proj + + +def text_global_pool(x, text: Optional[torch.Tensor] = None, pool_type: str = 'argmax'): + if pool_type == 'first': + pooled, tokens = x[:, 0], x[:, 1:] + elif pool_type == 'last': + pooled, tokens = x[:, -1], x[:, :-1] + elif pool_type == 'argmax': + # take features from the eot embedding (eot_token is the highest number in each sequence) + assert text is not None + pooled, tokens = x[torch.arange(x.shape[0]), text.argmax(dim=-1)], x + else: + pooled = tokens = x + + return pooled, tokens + + +class TextTransformer(nn.Module): + output_tokens: torch.jit.Final[bool] + + def __init__( + self, + context_length: int = 77, + vocab_size: int = 49408, + width: int = 512, + heads: int = 8, + layers: int = 12, + mlp_ratio: float = 4.0, + ls_init_value: float = None, + output_dim: int = 512, + embed_cls: bool = False, + no_causal_mask: bool = False, + pad_id: int = 0, + pool_type: str = 'argmax', + proj_bias: bool = False, + act_layer: Callable = nn.GELU, + norm_layer: Callable = LayerNorm, + output_tokens: bool = False, + ): + super().__init__() + assert pool_type in ('first', 'last', 'argmax', 'none') + self.output_tokens = output_tokens + self.num_pos = self.context_length = context_length + self.vocab_size = vocab_size + self.width = width + self.output_dim = output_dim + self.heads = heads + self.pad_id = pad_id + self.pool_type = pool_type + + self.token_embedding = nn.Embedding(vocab_size, width) + if embed_cls: + self.cls_emb = nn.Parameter(torch.empty(width)) + self.num_pos += 1 + else: + self.cls_emb = None + self.positional_embedding = nn.Parameter(torch.empty(self.num_pos, width)) + self.transformer = Transformer( + width=width, + layers=layers, + heads=heads, + mlp_ratio=mlp_ratio, + ls_init_value=ls_init_value, + act_layer=act_layer, + norm_layer=norm_layer, + ) + self.ln_final = norm_layer(width) + + if no_causal_mask: + self.attn_mask = None + else: + self.register_buffer('attn_mask', self.build_causal_mask(), persistent=False) + + if proj_bias: + self.text_projection = nn.Linear(width, output_dim) + else: + self.text_projection = nn.Parameter(torch.empty(width, output_dim)) + + self.init_parameters() + + def init_parameters(self): + nn.init.normal_(self.token_embedding.weight, std=0.02) + nn.init.normal_(self.positional_embedding, std=0.01) + if self.cls_emb is not None: + nn.init.normal_(self.cls_emb, std=0.01) + + proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5) + attn_std = self.transformer.width ** -0.5 + fc_std = (2 * self.transformer.width) ** -0.5 + for block in self.transformer.resblocks: + nn.init.normal_(block.attn.in_proj_weight, std=attn_std) + nn.init.normal_(block.attn.out_proj.weight, std=proj_std) + nn.init.normal_(block.mlp.c_fc.weight, std=fc_std) + nn.init.normal_(block.mlp.c_proj.weight, std=proj_std) + + if self.text_projection is not None: + if isinstance(self.text_projection, nn.Linear): + nn.init.normal_(self.text_projection.weight, std=self.transformer.width ** -0.5) + if self.text_projection.bias is not None: + nn.init.zeros_(self.text_projection.bias) + else: + nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5) + + @torch.jit.ignore + def set_grad_checkpointing(self, enable=True): + self.transformer.grad_checkpointing = enable + + def build_causal_mask(self): + # lazily create causal attention mask, with full attention between the tokens + # pytorch uses additive attention mask; fill with -inf + mask = torch.empty(self.num_pos, self.num_pos) + mask.fill_(float("-inf")) + mask.triu_(1) # zero out the lower diagonal + return mask + + def build_cls_mask(self, text, cast_dtype: torch.dtype): + cls_mask = (text != self.pad_id).unsqueeze(1) + cls_mask = F.pad(cls_mask, (1, 0, cls_mask.shape[2], 0), value=True) + additive_mask = torch.empty(cls_mask.shape, dtype=cast_dtype, device=cls_mask.device) + additive_mask.fill_(0) + additive_mask.masked_fill_(~cls_mask, float("-inf")) + additive_mask = torch.repeat_interleave(additive_mask, self.heads, 0) + return additive_mask + + def forward(self, text): + cast_dtype = self.transformer.get_cast_dtype() + seq_len = text.shape[1] + + x = self.token_embedding(text).to(cast_dtype) # [batch_size, n_ctx, d_model] + attn_mask = self.attn_mask + if self.cls_emb is not None: + seq_len += 1 + x = torch.cat([x, _expand_token(self.cls_emb, x.shape[0])], dim=1) + cls_mask = self.build_cls_mask(text, cast_dtype) + if attn_mask is not None: + attn_mask = attn_mask[None, :seq_len, :seq_len] + cls_mask[:, :seq_len, :seq_len] + + x = x + self.positional_embedding[:seq_len].to(cast_dtype) + x = x.permute(1, 0, 2) # NLD -> LND + x = self.transformer(x, attn_mask=self.attn_mask) + x = x.permute(1, 0, 2) # LND -> NLD + + # x.shape = [batch_size, n_ctx, transformer.width] + if self.cls_emb is not None: + # presence of appended cls embed (CoCa) overrides pool_type, always take last token + pooled, tokens = text_global_pool(x, pool_type='last') + pooled = self.ln_final(pooled) # final LN applied after pooling in this case + else: + x = self.ln_final(x) + pooled, tokens = text_global_pool(x, text, pool_type=self.pool_type) + + if self.text_projection is not None: + if isinstance(self.text_projection, nn.Linear): + pooled = self.text_projection(pooled) + else: + pooled = pooled @ self.text_projection + + if self.output_tokens: + return pooled, tokens + + return pooled + + +class MultimodalTransformer(Transformer): + def __init__( + self, + width: int, + layers: int, + heads: int, + context_length: int = 77, + mlp_ratio: float = 4.0, + ls_init_value: float = None, + act_layer: Callable = nn.GELU, + norm_layer: Callable = LayerNorm, + output_dim: int = 512, + ): + + super().__init__( + width=width, + layers=layers, + heads=heads, + mlp_ratio=mlp_ratio, + ls_init_value=ls_init_value, + act_layer=act_layer, + norm_layer=norm_layer, + ) + self.context_length = context_length + self.cross_attn = nn.ModuleList([ + ResidualAttentionBlock( + width, + heads, + mlp_ratio, + ls_init_value=ls_init_value, + act_layer=act_layer, + norm_layer=norm_layer, + is_cross_attention=True, + ) + for _ in range(layers) + ]) + + self.register_buffer('attn_mask', self.build_attention_mask(), persistent=False) + + self.ln_final = norm_layer(width) + self.text_projection = nn.Parameter(torch.empty(width, output_dim)) + + def init_parameters(self): + proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5) + attn_std = self.transformer.width ** -0.5 + fc_std = (2 * self.transformer.width) ** -0.5 + for block in self.transformer.resblocks: + nn.init.normal_(block.attn.in_proj_weight, std=attn_std) + nn.init.normal_(block.attn.out_proj.weight, std=proj_std) + nn.init.normal_(block.mlp.c_fc.weight, std=fc_std) + nn.init.normal_(block.mlp.c_proj.weight, std=proj_std) + for block in self.transformer.cross_attn: + nn.init.normal_(block.attn.in_proj_weight, std=attn_std) + nn.init.normal_(block.attn.out_proj.weight, std=proj_std) + nn.init.normal_(block.mlp.c_fc.weight, std=fc_std) + nn.init.normal_(block.mlp.c_proj.weight, std=proj_std) + + if self.text_projection is not None: + nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5) + + def build_attention_mask(self): + # lazily create causal attention mask, with full attention between the tokens + # pytorch uses additive attention mask; fill with -inf + mask = torch.empty(self.context_length, self.context_length) + mask.fill_(float("-inf")) + mask.triu_(1) # zero out the lower diagonal + return mask + + def forward(self, image_embs, text_embs): + text_embs = text_embs.permute(1, 0, 2) # NLD -> LNDsq + image_embs = image_embs.permute(1, 0, 2) # NLD -> LND + seq_len = text_embs.shape[0] + + for resblock, cross_attn in zip(self.resblocks, self.cross_attn): + if self.grad_checkpointing and not torch.jit.is_scripting(): + # TODO: handle kwargs https://github.com/pytorch/pytorch/issues/79887#issuecomment-1161758372 + text_embs = checkpoint(resblock, text_embs, None, None, self.attn_mask[:seq_len, :seq_len]) + text_embs = checkpoint(cross_attn, text_embs, image_embs, image_embs, None) + else: + text_embs = resblock(text_embs, attn_mask=self.attn_mask[:seq_len, :seq_len]) + text_embs = cross_attn(text_embs, k_x=image_embs, v_x=image_embs) + + x = text_embs.permute(1, 0, 2) # LND -> NLD + x = self.ln_final(x) + + if self.text_projection is not None: + x = x @ self.text_projection + + return x + + @torch.jit.ignore + def set_grad_checkpointing(self, enable=True): + self.grad_checkpointing = enable diff --git a/open_clip_local/transformer_raw.py b/open_clip_local/transformer_raw.py new file mode 100644 index 0000000000000000000000000000000000000000..0eee4c2cc586111862ce931a695beca184f8947c --- /dev/null +++ b/open_clip_local/transformer_raw.py @@ -0,0 +1,803 @@ +from collections import OrderedDict +import math +from typing import Callable, Optional, Sequence, Tuple +from functools import partial + +import torch +from torch import nn +from torch.nn import functional as F +from torch.utils.checkpoint import checkpoint + +from .utils import to_2tuple +from .pos_embed import get_2d_sincos_pos_embed + + +class LayerNormFp32(nn.LayerNorm): + """Subclass torch's LayerNorm to handle fp16 (by casting to float32 and back).""" + + def forward(self, x: torch.Tensor): + orig_type = x.dtype + x = F.layer_norm(x.to(torch.float32), self.normalized_shape, self.weight, self.bias, self.eps) + return x.to(orig_type) + + +class LayerNorm(nn.LayerNorm): + """Subclass torch's LayerNorm (with cast back to input dtype).""" + + def forward(self, x: torch.Tensor): + orig_type = x.dtype + x = F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps) + return x.to(orig_type) + + +class QuickGELU(nn.Module): + # NOTE This is slower than nn.GELU or nn.SiLU and uses more GPU memory + def forward(self, x: torch.Tensor): + return x * torch.sigmoid(1.702 * x) + + +class LayerScale(nn.Module): + def __init__(self, dim, init_values=1e-5, inplace=False): + super().__init__() + self.inplace = inplace + self.gamma = nn.Parameter(init_values * torch.ones(dim)) + + def forward(self, x): + return x.mul_(self.gamma) if self.inplace else x * self.gamma + + +class PatchDropout(nn.Module): + """ + https://arxiv.org/abs/2212.00794 + """ + + def __init__(self, prob, exclude_first_token=True): + super().__init__() + assert 0 <= prob < 1. + self.prob = prob + self.exclude_first_token = exclude_first_token # exclude CLS token + + def forward(self, x): + if not self.training or self.prob == 0.: + return x + + if self.exclude_first_token: + cls_tokens, x = x[:, :1], x[:, 1:] + else: + cls_tokens = torch.jit.annotate(torch.Tensor, x[:, :1]) + + batch = x.size()[0] + num_tokens = x.size()[1] + + batch_indices = torch.arange(batch) + batch_indices = batch_indices[..., None] + + keep_prob = 1 - self.prob + num_patches_keep = max(1, int(num_tokens * keep_prob)) + + rand = torch.randn(batch, num_tokens) + patch_indices_keep = rand.topk(num_patches_keep, dim=-1).indices + + x = x[batch_indices, patch_indices_keep] + + if self.exclude_first_token: + x = torch.cat((cls_tokens, x), dim=1) + + return x + + +class Attention(nn.Module): + def __init__( + self, + dim, + num_heads=8, + qkv_bias=True, + scaled_cosine=False, + scale_heads=False, + logit_scale_max=math.log(1. / 0.01), + attn_drop=0., + proj_drop=0. + ): + super().__init__() + self.scaled_cosine = scaled_cosine + self.scale_heads = scale_heads + assert dim % num_heads == 0, 'dim should be divisible by num_heads' + self.num_heads = num_heads + self.head_dim = dim // num_heads + self.scale = self.head_dim ** -0.5 + self.logit_scale_max = logit_scale_max + + # keeping in_proj in this form (instead of nn.Linear) to match weight scheme of original + self.in_proj_weight = nn.Parameter(torch.randn((dim * 3, dim)) * self.scale) + if qkv_bias: + self.in_proj_bias = nn.Parameter(torch.zeros(dim * 3)) + else: + self.in_proj_bias = None + + if self.scaled_cosine: + self.logit_scale = nn.Parameter(torch.log(10 * torch.ones((num_heads, 1, 1)))) + else: + self.logit_scale = None + self.attn_drop = nn.Dropout(attn_drop) + if self.scale_heads: + self.head_scale = nn.Parameter(torch.ones((num_heads, 1, 1))) + else: + self.head_scale = None + self.out_proj = nn.Linear(dim, dim) + self.out_drop = nn.Dropout(proj_drop) + + def forward(self, x, attn_mask: Optional[torch.Tensor] = None): + L, N, C = x.shape + q, k, v = F.linear(x, self.in_proj_weight, self.in_proj_bias).chunk(3, dim=-1) + q = q.contiguous().view(L, N * self.num_heads, -1).transpose(0, 1) + k = k.contiguous().view(L, N * self.num_heads, -1).transpose(0, 1) + v = v.contiguous().view(L, N * self.num_heads, -1).transpose(0, 1) + + if self.logit_scale is not None: + attn = torch.bmm(F.normalize(q, dim=-1), F.normalize(k, dim=-1).transpose(-1, -2)) + logit_scale = torch.clamp(self.logit_scale, max=self.logit_scale_max).exp() + attn = attn.view(N, self.num_heads, L, L) * logit_scale + attn = attn.view(-1, L, L) + else: + q = q * self.scale + attn = torch.bmm(q, k.transpose(-1, -2)) + + if attn_mask is not None: + if attn_mask.dtype == torch.bool: + new_attn_mask = torch.zeros_like(attn_mask, dtype=q.dtype) + new_attn_mask.masked_fill_(attn_mask, float("-inf")) + attn_mask = new_attn_mask + attn += attn_mask + + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + + x = torch.bmm(attn, v) + if self.head_scale is not None: + x = x.view(N, self.num_heads, L, C) * self.head_scale + x = x.view(-1, L, C) + x = x.transpose(0, 1).reshape(L, N, C) + x = self.out_proj(x) + x = self.out_drop(x) + return x + + +class AttentionalPooler(nn.Module): + def __init__( + self, + d_model: int, + context_dim: int, + n_head: int = 8, + n_queries: int = 256, + norm_layer: Callable = LayerNorm + ): + super().__init__() + self.query = nn.Parameter(torch.randn(n_queries, d_model)) + self.attn = nn.MultiheadAttention(d_model, n_head, kdim=context_dim, vdim=context_dim) + self.ln_q = norm_layer(d_model) + self.ln_k = norm_layer(context_dim) + + def forward(self, x: torch.Tensor): + x = self.ln_k(x).permute(1, 0, 2) # NLD -> LND + N = x.shape[1] + q = self.ln_q(self.query) + out = self.attn(q.unsqueeze(1).expand(-1, N, -1), x, x, need_weights=False)[0] + return out.permute(1, 0, 2) # LND -> NLD + + +class ResidualAttentionBlock(nn.Module): + def __init__( + self, + d_model: int, + n_head: int, + mlp_ratio: float = 4.0, + ls_init_value: float = None, + act_layer: Callable = nn.GELU, + norm_layer: Callable = LayerNorm, + is_cross_attention: bool = False, + ): + super().__init__() + + self.ln_1 = norm_layer(d_model) + self.attn = nn.MultiheadAttention(d_model, n_head) + self.ls_1 = LayerScale(d_model, ls_init_value) if ls_init_value is not None else nn.Identity() + if is_cross_attention: + self.ln_1_kv = norm_layer(d_model) + + self.ln_2 = norm_layer(d_model) + mlp_width = int(d_model * mlp_ratio) + self.mlp = nn.Sequential(OrderedDict([ + ("c_fc", nn.Linear(d_model, mlp_width)), + ("gelu", act_layer()), + ("c_proj", nn.Linear(mlp_width, d_model)) + ])) + self.ls_2 = LayerScale(d_model, ls_init_value) if ls_init_value is not None else nn.Identity() + + def attention( + self, + q_x: torch.Tensor, + k_x: Optional[torch.Tensor] = None, + v_x: Optional[torch.Tensor] = None, + attn_mask: Optional[torch.Tensor] = None, + ): + k_x = k_x if k_x is not None else q_x + v_x = v_x if v_x is not None else q_x + + attn_mask = attn_mask.to(q_x.dtype) if attn_mask is not None else None + return self.attn( + q_x, k_x, v_x, need_weights=False, attn_mask=attn_mask + )[0] + + def forward( + self, + q_x: torch.Tensor, + k_x: Optional[torch.Tensor] = None, + v_x: Optional[torch.Tensor] = None, + attn_mask: Optional[torch.Tensor] = None, + ): + k_x = self.ln_1_kv(k_x) if hasattr(self, "ln_1_kv") and k_x is not None else None + v_x = self.ln_1_kv(v_x) if hasattr(self, "ln_1_kv") and v_x is not None else None + + x = q_x + self.ls_1(self.attention(q_x=self.ln_1(q_x), k_x=k_x, v_x=v_x, attn_mask=attn_mask)) + x = x + self.ls_2(self.mlp(self.ln_2(x))) + return x + + +class CustomResidualAttentionBlock(nn.Module): + def __init__( + self, + d_model: int, + n_head: int, + mlp_ratio: float = 4.0, + ls_init_value: float = None, + act_layer: Callable = nn.GELU, + norm_layer: Callable = LayerNorm, + scale_cosine_attn: bool = False, + scale_heads: bool = False, + scale_attn: bool = False, + scale_fc: bool = False, + ): + super().__init__() + + self.ln_1 = norm_layer(d_model) + self.attn = Attention( + d_model, n_head, + scaled_cosine=scale_cosine_attn, + scale_heads=scale_heads, + ) + self.ln_attn = norm_layer(d_model) if scale_attn else nn.Identity() + self.ls_1 = LayerScale(d_model, ls_init_value) if ls_init_value is not None else nn.Identity() + + self.ln_2 = norm_layer(d_model) + mlp_width = int(d_model * mlp_ratio) + self.mlp = nn.Sequential(OrderedDict([ + ("c_fc", nn.Linear(d_model, mlp_width)), + ("gelu", act_layer()), + ('ln', norm_layer(mlp_width) if scale_fc else nn.Identity()), + ("c_proj", nn.Linear(mlp_width, d_model)) + ])) + self.ls_2 = LayerScale(d_model, ls_init_value) if ls_init_value is not None else nn.Identity() + + def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None): + x = x + self.ls_1(self.ln_attn(self.attn(self.ln_1(x), attn_mask=attn_mask))) + x = x + self.ls_2(self.mlp(self.ln_2(x))) + return x + + +def _expand_token(token, batch_size: int): + return token.view(1, 1, -1).expand(batch_size, -1, -1) + + +class Transformer(nn.Module): + def __init__( + self, + width: int, + layers: int, + heads: int, + mlp_ratio: float = 4.0, + ls_init_value: float = None, + act_layer: Callable = nn.GELU, + norm_layer: Callable = LayerNorm, + ): + super().__init__() + self.width = width + self.layers = layers + self.grad_checkpointing = False + + self.resblocks = nn.ModuleList([ + ResidualAttentionBlock( + width, heads, mlp_ratio, ls_init_value=ls_init_value, act_layer=act_layer, norm_layer=norm_layer) + for _ in range(layers) + ]) + + def get_cast_dtype(self) -> torch.dtype: + if hasattr(self.resblocks[0].mlp.c_fc, 'int8_original_dtype'): + return self.resblocks[0].mlp.c_fc.int8_original_dtype + return self.resblocks[0].mlp.c_fc.weight.dtype + + def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None): + for r in self.resblocks: + if self.grad_checkpointing and not torch.jit.is_scripting(): + # TODO: handle kwargs https://github.com/pytorch/pytorch/issues/79887#issuecomment-1161758372 + x = checkpoint(r, x, None, None, attn_mask) + else: + x = r(x, attn_mask=attn_mask) + return x + + +class VisionTransformer(nn.Module): + output_tokens: torch.jit.Final[bool] + + def __init__( + self, + image_size: int, + patch_size: int, + width: int, + layers: int, + heads: int, + mlp_ratio: float, + ls_init_value: float = None, + attentional_pool: bool = False, + attn_pooler_queries: int = 256, + attn_pooler_heads: int = 8, + output_dim: int = 512, + patch_dropout: float = 0., + no_ln_pre: bool = False, + pos_embed_type: str = 'learnable', + pool_type: str = 'tok', + final_ln_after_pool: bool = False, + act_layer: Callable = nn.GELU, + norm_layer: Callable = LayerNorm, + output_tokens: bool = False, + ): + super().__init__() + assert pool_type in ('tok', 'avg', 'none') + self.output_tokens = output_tokens + image_height, image_width = self.image_size = to_2tuple(image_size) + patch_height, patch_width = self.patch_size = to_2tuple(patch_size) + self.grid_size = (image_height // patch_height, image_width // patch_width) + self.final_ln_after_pool = final_ln_after_pool # currently ignored w/ attn pool enabled + self.output_dim = output_dim + + self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False) + + # class embeddings and positional embeddings + scale = width ** -0.5 + self.class_embedding = nn.Parameter(scale * torch.randn(width)) + if pos_embed_type == 'learnable': + self.positional_embedding = nn.Parameter( + scale * torch.randn(self.grid_size[0] * self.grid_size[1] + 1, width)) + elif pos_embed_type == 'sin_cos_2d': + # fixed sin-cos embedding + assert self.grid_size[0] == self.grid_size[1],\ + 'currently sin cos 2d pos embedding only supports square input' + self.positional_embedding = nn.Parameter( + torch.zeros(self.grid_size[0] * self.grid_size[1] + 1, width), requires_grad=False) + pos_embed_type = get_2d_sincos_pos_embed(width, self.grid_size[0], cls_token=True) + self.positional_embedding.data.copy_(torch.from_numpy(pos_embed_type).float()) + else: + raise ValueError + + # setting a patch_dropout of 0. would mean it is disabled and this function would be the identity fn + self.patch_dropout = PatchDropout(patch_dropout) if patch_dropout > 0. else nn.Identity() + + self.ln_pre = nn.Identity() if no_ln_pre else norm_layer(width) + self.transformer = Transformer( + width, + layers, + heads, + mlp_ratio, + ls_init_value=ls_init_value, + act_layer=act_layer, + norm_layer=norm_layer, + ) + + if attentional_pool: + if isinstance(attentional_pool, str): + self.attn_pool_type = attentional_pool + self.pool_type = 'none' + if attentional_pool in ('parallel', 'cascade'): + self.attn_pool = AttentionalPooler( + output_dim, + width, + n_head=attn_pooler_heads, + n_queries=attn_pooler_queries, + ) + self.attn_pool_contrastive = AttentionalPooler( + output_dim, + width, + n_head=attn_pooler_heads, + n_queries=1, + ) + else: + assert False + else: + self.attn_pool_type = '' + self.pool_type = pool_type + self.attn_pool = AttentionalPooler( + output_dim, + width, + n_head=attn_pooler_heads, + n_queries=attn_pooler_queries, + ) + self.attn_pool_contrastive = None + pool_dim = output_dim + else: + self.attn_pool = None + pool_dim = width + self.pool_type = pool_type + + self.ln_post = norm_layer(pool_dim) + self.proj = nn.Parameter(scale * torch.randn(pool_dim, output_dim)) + + self.init_parameters() + + def lock(self, unlocked_groups=0, freeze_bn_stats=False): + for param in self.parameters(): + param.requires_grad = False + + if unlocked_groups != 0: + groups = [ + [ + self.conv1, + self.class_embedding, + self.positional_embedding, + self.ln_pre, + ], + *self.transformer.resblocks[:-1], + [ + self.transformer.resblocks[-1], + self.ln_post, + ], + self.proj, + ] + + def _unlock(x): + if isinstance(x, Sequence): + for g in x: + _unlock(g) + else: + if isinstance(x, torch.nn.Parameter): + x.requires_grad = True + else: + for p in x.parameters(): + p.requires_grad = True + + _unlock(groups[-unlocked_groups:]) + + def init_parameters(self): + # FIXME OpenAI CLIP did not define an init for the VisualTransformer + # TODO experiment if default PyTorch init, below, or alternate init is best. + + # nn.init.normal_(self.class_embedding, std=self.scale) + # nn.init.normal_(self.positional_embedding, std=self.scale) + # + # proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5) + # attn_std = self.transformer.width ** -0.5 + # fc_std = (2 * self.transformer.width) ** -0.5 + # for block in self.transformer.resblocks: + # nn.init.normal_(block.attn.in_proj_weight, std=attn_std) + # nn.init.normal_(block.attn.out_proj.weight, std=proj_std) + # nn.init.normal_(block.mlp.c_fc.weight, std=fc_std) + # nn.init.normal_(block.mlp.c_proj.weight, std=proj_std) + # + # if self.text_projection is not None: + # nn.init.normal_(self.text_projection, std=self.scale) + pass + + @torch.jit.ignore + def set_grad_checkpointing(self, enable=True): + self.transformer.grad_checkpointing = enable + + def _global_pool(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + if self.pool_type == 'avg': + pooled, tokens = x[:, 1:].mean(dim=1), x[:, 1:] + elif self.pool_type == 'tok': + pooled, tokens = x[:, 0], x[:, 1:] + else: + pooled = tokens = x + + return pooled, tokens + + def forward(self, x: torch.Tensor): + x = self.conv1(x) # shape = [*, width, grid, grid] + x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2] + x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width] + + # class embeddings and positional embeddings + x = torch.cat([_expand_token(self.class_embedding, x.shape[0]).to(x.dtype), x], dim=1) + # shape = [*, grid ** 2 + 1, width] + x = x + self.positional_embedding.to(x.dtype) + + x = self.patch_dropout(x) + x = self.ln_pre(x) + + x = x.permute(1, 0, 2) # NLD -> LND + x = self.transformer(x) + x = x.permute(1, 0, 2) # LND -> NLD + + if self.attn_pool is not None: + if self.attn_pool_contrastive is not None: + # This is untested, WIP pooling that should match paper + x = self.ln_post(x) # TBD LN first or separate one after each pool? + tokens = self.attn_pool(x) + if self.attn_pool_type == 'parallel': + pooled = self.attn_pool_contrastive(x) + else: + assert self.attn_pool_type == 'cascade' + pooled = self.attn_pool_contrastive(tokens) + else: + # this is the original OpenCLIP CoCa setup, does not match paper + x = self.attn_pool(x) + x = self.ln_post(x) + pooled, tokens = self._global_pool(x) + elif self.final_ln_after_pool: + pooled, tokens = self._global_pool(x) + pooled = self.ln_post(pooled) + else: + x = self.ln_post(x) + pooled, tokens = self._global_pool(x) + + if self.proj is not None: + pooled = pooled @ self.proj + + if self.output_tokens: + return pooled, tokens + + return pooled + + +def text_global_pool(x, text: Optional[torch.Tensor] = None, pool_type: str = 'argmax'): + if pool_type == 'first': + pooled, tokens = x[:, 0], x[:, 1:] + elif pool_type == 'last': + pooled, tokens = x[:, -1], x[:, :-1] + elif pool_type == 'argmax': + # take features from the eot embedding (eot_token is the highest number in each sequence) + assert text is not None + pooled, tokens = x[torch.arange(x.shape[0]), text.argmax(dim=-1)], x + else: + pooled = tokens = x + + return pooled, tokens + + +class TextTransformer(nn.Module): + output_tokens: torch.jit.Final[bool] + + def __init__( + self, + context_length: int = 77, + vocab_size: int = 49408, + width: int = 512, + heads: int = 8, + layers: int = 12, + mlp_ratio: float = 4.0, + ls_init_value: float = None, + output_dim: int = 512, + embed_cls: bool = False, + no_causal_mask: bool = False, + pad_id: int = 0, + pool_type: str = 'argmax', + proj_bias: bool = False, + act_layer: Callable = nn.GELU, + norm_layer: Callable = LayerNorm, + output_tokens: bool = False, + ): + super().__init__() + assert pool_type in ('first', 'last', 'argmax', 'none') + self.output_tokens = output_tokens + self.num_pos = self.context_length = context_length + self.vocab_size = vocab_size + self.width = width + self.output_dim = output_dim + self.heads = heads + self.pad_id = pad_id + self.pool_type = pool_type + + self.token_embedding = nn.Embedding(vocab_size, width) + if embed_cls: + self.cls_emb = nn.Parameter(torch.empty(width)) + self.num_pos += 1 + else: + self.cls_emb = None + self.positional_embedding = nn.Parameter(torch.empty(self.num_pos, width)) + self.transformer = Transformer( + width=width, + layers=layers, + heads=heads, + mlp_ratio=mlp_ratio, + ls_init_value=ls_init_value, + act_layer=act_layer, + norm_layer=norm_layer, + ) + self.ln_final = norm_layer(width) + + if no_causal_mask: + self.attn_mask = None + else: + self.register_buffer('attn_mask', self.build_causal_mask(), persistent=False) + + if proj_bias: + self.text_projection = nn.Linear(width, output_dim) + else: + self.text_projection = nn.Parameter(torch.empty(width, output_dim)) + + self.init_parameters() + + def init_parameters(self): + nn.init.normal_(self.token_embedding.weight, std=0.02) + nn.init.normal_(self.positional_embedding, std=0.01) + if self.cls_emb is not None: + nn.init.normal_(self.cls_emb, std=0.01) + + proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5) + attn_std = self.transformer.width ** -0.5 + fc_std = (2 * self.transformer.width) ** -0.5 + for block in self.transformer.resblocks: + nn.init.normal_(block.attn.in_proj_weight, std=attn_std) + nn.init.normal_(block.attn.out_proj.weight, std=proj_std) + nn.init.normal_(block.mlp.c_fc.weight, std=fc_std) + nn.init.normal_(block.mlp.c_proj.weight, std=proj_std) + + if self.text_projection is not None: + if isinstance(self.text_projection, nn.Linear): + nn.init.normal_(self.text_projection.weight, std=self.transformer.width ** -0.5) + if self.text_projection.bias is not None: + nn.init.zeros_(self.text_projection.bias) + else: + nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5) + + @torch.jit.ignore + def set_grad_checkpointing(self, enable=True): + self.transformer.grad_checkpointing = enable + + def build_causal_mask(self): + # lazily create causal attention mask, with full attention between the tokens + # pytorch uses additive attention mask; fill with -inf + mask = torch.empty(self.num_pos, self.num_pos) + mask.fill_(float("-inf")) + mask.triu_(1) # zero out the lower diagonal + return mask + + def build_cls_mask(self, text, cast_dtype: torch.dtype): + cls_mask = (text != self.pad_id).unsqueeze(1) + cls_mask = F.pad(cls_mask, (1, 0, cls_mask.shape[2], 0), value=True) + additive_mask = torch.empty(cls_mask.shape, dtype=cast_dtype, device=cls_mask.device) + additive_mask.fill_(0) + additive_mask.masked_fill_(~cls_mask, float("-inf")) + additive_mask = torch.repeat_interleave(additive_mask, self.heads, 0) + return additive_mask + + def forward(self, text): + cast_dtype = self.transformer.get_cast_dtype() + seq_len = text.shape[1] + + x = self.token_embedding(text).to(cast_dtype) # [batch_size, n_ctx, d_model] + attn_mask = self.attn_mask + if self.cls_emb is not None: + seq_len += 1 + x = torch.cat([x, _expand_token(self.cls_emb, x.shape[0])], dim=1) + cls_mask = self.build_cls_mask(text, cast_dtype) + if attn_mask is not None: + attn_mask = attn_mask[None, :seq_len, :seq_len] + cls_mask[:, :seq_len, :seq_len] + + x = x + self.positional_embedding[:seq_len].to(cast_dtype) + x = x.permute(1, 0, 2) # NLD -> LND + x = self.transformer(x, attn_mask=attn_mask) + x = x.permute(1, 0, 2) # LND -> NLD + + # x.shape = [batch_size, n_ctx, transformer.width] + if self.cls_emb is not None: + # presence of appended cls embed (CoCa) overrides pool_type, always take last token + pooled, tokens = text_global_pool(x, pool_type='last') + pooled = self.ln_final(pooled) # final LN applied after pooling in this case + else: + x = self.ln_final(x) + pooled, tokens = text_global_pool(x, text, pool_type=self.pool_type) + + if self.text_projection is not None: + if isinstance(self.text_projection, nn.Linear): + pooled = self.text_projection(pooled) + else: + pooled = pooled @ self.text_projection + + if self.output_tokens: + return pooled, tokens + + return pooled + + +class MultimodalTransformer(Transformer): + def __init__( + self, + width: int, + layers: int, + heads: int, + context_length: int = 77, + mlp_ratio: float = 4.0, + ls_init_value: float = None, + act_layer: Callable = nn.GELU, + norm_layer: Callable = LayerNorm, + output_dim: int = 512, + ): + + super().__init__( + width=width, + layers=layers, + heads=heads, + mlp_ratio=mlp_ratio, + ls_init_value=ls_init_value, + act_layer=act_layer, + norm_layer=norm_layer, + ) + self.context_length = context_length + self.cross_attn = nn.ModuleList([ + ResidualAttentionBlock( + width, + heads, + mlp_ratio, + ls_init_value=ls_init_value, + act_layer=act_layer, + norm_layer=norm_layer, + is_cross_attention=True, + ) + for _ in range(layers) + ]) + + self.register_buffer('attn_mask', self.build_attention_mask(), persistent=False) + + self.ln_final = norm_layer(width) + self.text_projection = nn.Parameter(torch.empty(width, output_dim)) + + def init_parameters(self): + proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5) + attn_std = self.transformer.width ** -0.5 + fc_std = (2 * self.transformer.width) ** -0.5 + for block in self.transformer.resblocks: + nn.init.normal_(block.attn.in_proj_weight, std=attn_std) + nn.init.normal_(block.attn.out_proj.weight, std=proj_std) + nn.init.normal_(block.mlp.c_fc.weight, std=fc_std) + nn.init.normal_(block.mlp.c_proj.weight, std=proj_std) + for block in self.transformer.cross_attn: + nn.init.normal_(block.attn.in_proj_weight, std=attn_std) + nn.init.normal_(block.attn.out_proj.weight, std=proj_std) + nn.init.normal_(block.mlp.c_fc.weight, std=fc_std) + nn.init.normal_(block.mlp.c_proj.weight, std=proj_std) + + if self.text_projection is not None: + nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5) + + def build_attention_mask(self): + # lazily create causal attention mask, with full attention between the tokens + # pytorch uses additive attention mask; fill with -inf + mask = torch.empty(self.context_length, self.context_length) + mask.fill_(float("-inf")) + mask.triu_(1) # zero out the lower diagonal + return mask + + def forward(self, image_embs, text_embs): + text_embs = text_embs.permute(1, 0, 2) # NLD -> LNDsq + image_embs = image_embs.permute(1, 0, 2) # NLD -> LND + seq_len = text_embs.shape[0] + + for resblock, cross_attn in zip(self.resblocks, self.cross_attn): + if self.grad_checkpointing and not torch.jit.is_scripting(): + # TODO: handle kwargs https://github.com/pytorch/pytorch/issues/79887#issuecomment-1161758372 + text_embs = checkpoint(resblock, text_embs, None, None, self.attn_mask[:seq_len, :seq_len]) + text_embs = checkpoint(cross_attn, text_embs, image_embs, image_embs, None) + else: + text_embs = resblock(text_embs, attn_mask=self.attn_mask[:seq_len, :seq_len]) + text_embs = cross_attn(text_embs, k_x=image_embs, v_x=image_embs) + + x = text_embs.permute(1, 0, 2) # LND -> NLD + x = self.ln_final(x) + + if self.text_projection is not None: + x = x @ self.text_projection + + return x + + @torch.jit.ignore + def set_grad_checkpointing(self, enable=True): + self.grad_checkpointing = enable \ No newline at end of file diff --git a/open_clip_local/transformer_sclip.py b/open_clip_local/transformer_sclip.py new file mode 100644 index 0000000000000000000000000000000000000000..db5b372a9a57e8fdf4bca050398fb27067829dd4 --- /dev/null +++ b/open_clip_local/transformer_sclip.py @@ -0,0 +1,859 @@ +from collections import OrderedDict +import math +import numpy as np +from typing import Callable, Optional, Sequence, Tuple +from functools import partial + +import torch +from torch import nn +from torch.nn import functional as F +from torch.utils.checkpoint import checkpoint + +from .utils import to_2tuple +from .pos_embed import get_2d_sincos_pos_embed + + +class LayerNormFp32(nn.LayerNorm): + """Subclass torch's LayerNorm to handle fp16 (by casting to float32 and back).""" + + def forward(self, x: torch.Tensor): + orig_type = x.dtype + x = F.layer_norm(x.to(torch.float32), self.normalized_shape, self.weight, self.bias, self.eps) + return x.to(orig_type) + + +class LayerNorm(nn.LayerNorm): + """Subclass torch's LayerNorm (with cast back to input dtype).""" + + def forward(self, x: torch.Tensor): + orig_type = x.dtype + x = F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps) + return x.to(orig_type) + + +class QuickGELU(nn.Module): + # NOTE This is slower than nn.GELU or nn.SiLU and uses more GPU memory + def forward(self, x: torch.Tensor): + return x * torch.sigmoid(1.702 * x) + + +class LayerScale(nn.Module): + def __init__(self, dim, init_values=1e-5, inplace=False): + super().__init__() + self.inplace = inplace + self.gamma = nn.Parameter(init_values * torch.ones(dim)) + + def forward(self, x): + return x.mul_(self.gamma) if self.inplace else x * self.gamma + + +class PatchDropout(nn.Module): + """ + https://arxiv.org/abs/2212.00794 + """ + + def __init__(self, prob, exclude_first_token=True): + super().__init__() + assert 0 <= prob < 1. + self.prob = prob + self.exclude_first_token = exclude_first_token # exclude CLS token + + def forward(self, x): + if not self.training or self.prob == 0.: + return x + + if self.exclude_first_token: + cls_tokens, x = x[:, :1], x[:, 1:] + else: + cls_tokens = torch.jit.annotate(torch.Tensor, x[:, :1]) + + batch = x.size()[0] + num_tokens = x.size()[1] + + batch_indices = torch.arange(batch) + batch_indices = batch_indices[..., None] + + keep_prob = 1 - self.prob + num_patches_keep = max(1, int(num_tokens * keep_prob)) + + rand = torch.randn(batch, num_tokens) + patch_indices_keep = rand.topk(num_patches_keep, dim=-1).indices + + x = x[batch_indices, patch_indices_keep] + + if self.exclude_first_token: + x = torch.cat((cls_tokens, x), dim=1) + + return x + + +class Attention(nn.Module): + def __init__( + self, + dim, + num_heads=8, + qkv_bias=True, + scaled_cosine=False, + scale_heads=False, + logit_scale_max=math.log(1. / 0.01), + attn_drop=0., + proj_drop=0. + ): + super().__init__() + self.scaled_cosine = scaled_cosine + self.scale_heads = scale_heads + assert dim % num_heads == 0, 'dim should be divisible by num_heads' + self.num_heads = num_heads + self.head_dim = dim // num_heads + self.scale = self.head_dim ** -0.5 + self.logit_scale_max = logit_scale_max + + # keeping in_proj in this form (instead of nn.Linear) to match weight scheme of original + self.in_proj_weight = nn.Parameter(torch.randn((dim * 3, dim)) * self.scale) + if qkv_bias: + self.in_proj_bias = nn.Parameter(torch.zeros(dim * 3)) + else: + self.in_proj_bias = None + + if self.scaled_cosine: + self.logit_scale = nn.Parameter(torch.log(10 * torch.ones((num_heads, 1, 1)))) + else: + self.logit_scale = None + self.attn_drop = nn.Dropout(attn_drop) + if self.scale_heads: + self.head_scale = nn.Parameter(torch.ones((num_heads, 1, 1))) + else: + self.head_scale = None + self.out_proj = nn.Linear(dim, dim) + self.out_drop = nn.Dropout(proj_drop) + + def forward(self, x, attn_mask: Optional[torch.Tensor] = None): + L, N, C = x.shape + q, k, v = F.linear(x, self.in_proj_weight, self.in_proj_bias).chunk(3, dim=-1) + q = q.contiguous().view(L, N * self.num_heads, -1).transpose(0, 1) + k = k.contiguous().view(L, N * self.num_heads, -1).transpose(0, 1) + v = v.contiguous().view(L, N * self.num_heads, -1).transpose(0, 1) + + if self.logit_scale is not None: + attn = torch.bmm(F.normalize(q, dim=-1), F.normalize(k, dim=-1).transpose(-1, -2)) + logit_scale = torch.clamp(self.logit_scale, max=self.logit_scale_max).exp() + attn = attn.view(N, self.num_heads, L, L) * logit_scale + attn = attn.view(-1, L, L) + else: + q = q * self.scale + attn = torch.bmm(q, k.transpose(-1, -2)) + + if attn_mask is not None: + if attn_mask.dtype == torch.bool: + new_attn_mask = torch.zeros_like(attn_mask, dtype=q.dtype) + new_attn_mask.masked_fill_(attn_mask, float("-inf")) + attn_mask = new_attn_mask + attn += attn_mask + + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + + x = torch.bmm(attn, v) + if self.head_scale is not None: + x = x.view(N, self.num_heads, L, C) * self.head_scale + x = x.view(-1, L, C) + x = x.transpose(0, 1).reshape(L, N, C) + x = self.out_proj(x) + x = self.out_drop(x) + return x + + +class AttentionalPooler(nn.Module): + def __init__( + self, + d_model: int, + context_dim: int, + n_head: int = 8, + n_queries: int = 256, + norm_layer: Callable = LayerNorm + ): + super().__init__() + self.query = nn.Parameter(torch.randn(n_queries, d_model)) + self.attn = nn.MultiheadAttention(d_model, n_head, kdim=context_dim, vdim=context_dim) + self.ln_q = norm_layer(d_model) + self.ln_k = norm_layer(context_dim) + + def forward(self, x: torch.Tensor): + x = self.ln_k(x).permute(1, 0, 2) # NLD -> LND + N = x.shape[1] + q = self.ln_q(self.query) + out = self.attn(q.unsqueeze(1).expand(-1, N, -1), x, x, need_weights=False)[0] + return out.permute(1, 0, 2) # LND -> NLD + + +class ResidualAttentionBlock(nn.Module): + def __init__( + self, + d_model: int, + n_head: int, + mlp_ratio: float = 4.0, + ls_init_value: float = None, + act_layer: Callable = nn.GELU, + norm_layer: Callable = LayerNorm, + is_cross_attention: bool = False, + ): + super().__init__() + + self.ln_1 = norm_layer(d_model) + self.attn = nn.MultiheadAttention(d_model, n_head) + self.ls_1 = LayerScale(d_model, ls_init_value) if ls_init_value is not None else nn.Identity() + if is_cross_attention: + self.ln_1_kv = norm_layer(d_model) + + self.ln_2 = norm_layer(d_model) + mlp_width = int(d_model * mlp_ratio) + self.mlp = nn.Sequential(OrderedDict([ + ("c_fc", nn.Linear(d_model, mlp_width)), + ("gelu", act_layer()), + ("c_proj", nn.Linear(mlp_width, d_model)) + ])) + self.ls_2 = LayerScale(d_model, ls_init_value) if ls_init_value is not None else nn.Identity() + + def attention( + self, + q_x: torch.Tensor, + k_x: Optional[torch.Tensor] = None, + v_x: Optional[torch.Tensor] = None, + attn_mask: Optional[torch.Tensor] = None, + ): + k_x = k_x if k_x is not None else q_x + v_x = v_x if v_x is not None else q_x + + attn_mask = attn_mask.to(q_x.dtype) if attn_mask is not None else None + return self.attn( + q_x, k_x, v_x, need_weights=True, attn_mask=attn_mask + ) + + def forward( + self, + q_x: torch.Tensor, + k_x: Optional[torch.Tensor] = None, + v_x: Optional[torch.Tensor] = None, + attn_mask: Optional[torch.Tensor] = None, + ): + k_x = self.ln_1_kv(k_x) if hasattr(self, "ln_1_kv") and k_x is not None else None + v_x = self.ln_1_kv(v_x) if hasattr(self, "ln_1_kv") and v_x is not None else None + + tmp, attn = self.attention(q_x=self.ln_1(q_x), k_x=k_x, v_x=v_x, attn_mask=attn_mask) + x = q_x + self.ls_1(tmp) + x = x + self.ls_2(self.mlp(self.ln_2(x))) + return x, attn + + +class CustomResidualAttentionBlock(nn.Module): + def __init__( + self, + d_model: int, + n_head: int, + mlp_ratio: float = 4.0, + ls_init_value: float = None, + act_layer: Callable = nn.GELU, + norm_layer: Callable = LayerNorm, + scale_cosine_attn: bool = False, + scale_heads: bool = False, + scale_attn: bool = False, + scale_fc: bool = False, + ): + super().__init__() + + self.ln_1 = norm_layer(d_model) + self.attn = Attention( + d_model, n_head, + scaled_cosine=scale_cosine_attn, + scale_heads=scale_heads, + ) + self.ln_attn = norm_layer(d_model) if scale_attn else nn.Identity() + self.ls_1 = LayerScale(d_model, ls_init_value) if ls_init_value is not None else nn.Identity() + + self.ln_2 = norm_layer(d_model) + mlp_width = int(d_model * mlp_ratio) + self.mlp = nn.Sequential(OrderedDict([ + ("c_fc", nn.Linear(d_model, mlp_width)), + ("gelu", act_layer()), + ('ln', norm_layer(mlp_width) if scale_fc else nn.Identity()), + ("c_proj", nn.Linear(mlp_width, d_model)) + ])) + self.ls_2 = LayerScale(d_model, ls_init_value) if ls_init_value is not None else nn.Identity() + + def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None): + x = x + self.ls_1(self.ln_attn(self.attn(self.ln_1(x), attn_mask=attn_mask))) + x = x + self.ls_2(self.mlp(self.ln_2(x))) + return x + + +def _expand_token(token, batch_size: int): + return token.view(1, 1, -1).expand(batch_size, -1, -1) + + +class Transformer(nn.Module): + def __init__( + self, + width: int, + layers: int, + heads: int, + mlp_ratio: float = 4.0, + ls_init_value: float = None, + act_layer: Callable = nn.GELU, + norm_layer: Callable = LayerNorm, + ): + super().__init__() + self.width = width + self.layers = layers + self.grad_checkpointing = False + + self.resblocks = nn.ModuleList([ + ResidualAttentionBlock( + width, heads, mlp_ratio, ls_init_value=ls_init_value, act_layer=act_layer, norm_layer=norm_layer) + for _ in range(layers) + ]) + + def get_cast_dtype(self) -> torch.dtype: + if hasattr(self.resblocks[0].mlp.c_fc, 'int8_original_dtype'): + return self.resblocks[0].mlp.c_fc.int8_original_dtype + return self.resblocks[0].mlp.c_fc.weight.dtype + + + def custom_attn(self, attn_layer, x, return_attn=False, with_attn=False, csa=False): + + num_heads = attn_layer.num_heads + _, bsz, embed_dim = x.size() + head_dim = embed_dim // num_heads + scale = head_dim ** -0.5 + + q, k, v = F.linear(x, attn_layer.in_proj_weight, attn_layer.in_proj_bias).chunk(3, dim=-1) + q = q.contiguous().view(-1, bsz * num_heads, head_dim).transpose(0, 1) + k = k.contiguous().view(-1, bsz * num_heads, head_dim).transpose(0, 1) + v = v.contiguous().view(-1, bsz * num_heads, head_dim).transpose(0, 1) + + if csa: + q_attn = torch.bmm(q, q.transpose(1, 2)) * scale + k_attn = torch.bmm(k, k.transpose(1, 2)) * scale + attn_weights = F.softmax(q_attn, dim=-1) + F.softmax(k_attn, dim=-1) + else: + attn_weights = torch.bmm(q * scale, k.transpose(1, 2)) + attn_weights = F.softmax(attn_weights, dim=-1) + + if return_attn: + return attn_weights + + attn_output = torch.bmm(attn_weights, v) + attn_output = attn_output.transpose(0, 1).contiguous().view(-1, bsz, embed_dim) + attn_output = attn_layer.out_proj(attn_output) + + if with_attn: + return attn_output, attn_weights + + return attn_output + + + def forward(self, x: torch.Tensor, out_layers: list = [3, 6, 9], attn_mask: Optional[torch.Tensor] = None): + idx = 0 + out_tokens = [] + out_attn = [] + for r in self.resblocks: + idx += 1 + if self.grad_checkpointing and not torch.jit.is_scripting(): + # TODO: handle kwargs https://github.com/pytorch/pytorch/issues/79887#issuecomment-1161758372 + x = checkpoint(r, x, None, None, attn_mask) + else: + if idx == 12: + x, attn = r(x, attn_mask=attn_mask) + out_attn.append(attn) + elif idx == len(self.resblocks): + x = x + self.custom_attn(r.attn, r.ln_1(x), csa=True) + x = x + r.mlp(r.ln_2(x)) + else: + x, attn_tmp = r(x, attn_mask=attn_mask) + if idx in out_layers: + out_tokens.append(x) + return x, out_attn, out_tokens + + +class VisionTransformer(nn.Module): + output_tokens: torch.jit.Final[bool] + + def __init__( + self, + image_size: int, + patch_size: int, + width: int, + layers: int, + heads: int, + mlp_ratio: float, + ls_init_value: float = None, + attentional_pool: bool = False, + attn_pooler_queries: int = 256, + attn_pooler_heads: int = 8, + output_dim: int = 512, + patch_dropout: float = 0., + no_ln_pre: bool = False, + pos_embed_type: str = 'learnable', + pool_type: str = 'tok', + final_ln_after_pool: bool = False, + act_layer: Callable = nn.GELU, + norm_layer: Callable = LayerNorm, + output_tokens: bool = False, + ): + super().__init__() + assert pool_type in ('tok', 'avg', 'none') + self.output_tokens = output_tokens + image_height, image_width = self.image_size = to_2tuple(image_size) + patch_height, patch_width = self.patch_size = to_2tuple(patch_size) + self.grid_size = (image_height // patch_height, image_width // patch_width) + self.final_ln_after_pool = final_ln_after_pool # currently ignored w/ attn pool enabled + self.output_dim = output_dim + + self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False) + + # class embeddings and positional embeddings + scale = width ** -0.5 + self.class_embedding = nn.Parameter(scale * torch.randn(width)) + if pos_embed_type == 'learnable': + self.positional_embedding = nn.Parameter( + scale * torch.randn(self.grid_size[0] * self.grid_size[1] + 1, width)) + elif pos_embed_type == 'sin_cos_2d': + # fixed sin-cos embedding + assert self.grid_size[0] == self.grid_size[1],\ + 'currently sin cos 2d pos embedding only supports square input' + self.positional_embedding = nn.Parameter( + torch.zeros(self.grid_size[0] * self.grid_size[1] + 1, width), requires_grad=False) + pos_embed_type = get_2d_sincos_pos_embed(width, self.grid_size[0], cls_token=True) + self.positional_embedding.data.copy_(torch.from_numpy(pos_embed_type).float()) + else: + raise ValueError + + # setting a patch_dropout of 0. would mean it is disabled and this function would be the identity fn + self.patch_dropout = PatchDropout(patch_dropout) if patch_dropout > 0. else nn.Identity() + + self.ln_pre = nn.Identity() if no_ln_pre else norm_layer(width) + self.transformer = Transformer( + width, + layers, + heads, + mlp_ratio, + ls_init_value=ls_init_value, + act_layer=act_layer, + norm_layer=norm_layer, + ) + + if attentional_pool: + if isinstance(attentional_pool, str): + self.attn_pool_type = attentional_pool + self.pool_type = 'none' + if attentional_pool in ('parallel', 'cascade'): + self.attn_pool = AttentionalPooler( + output_dim, + width, + n_head=attn_pooler_heads, + n_queries=attn_pooler_queries, + ) + self.attn_pool_contrastive = AttentionalPooler( + output_dim, + width, + n_head=attn_pooler_heads, + n_queries=1, + ) + else: + assert False + else: + self.attn_pool_type = '' + self.pool_type = pool_type + self.attn_pool = AttentionalPooler( + output_dim, + width, + n_head=attn_pooler_heads, + n_queries=attn_pooler_queries, + ) + self.attn_pool_contrastive = None + pool_dim = output_dim + else: + self.attn_pool = None + pool_dim = width + self.pool_type = pool_type + + self.ln_post = norm_layer(pool_dim) + self.proj = nn.Parameter(scale * torch.randn(pool_dim, output_dim)) + + self.init_parameters() + + def lock(self, unlocked_groups=0, freeze_bn_stats=False): + for param in self.parameters(): + param.requires_grad = False + + if unlocked_groups != 0: + groups = [ + [ + self.conv1, + self.class_embedding, + self.positional_embedding, + self.ln_pre, + ], + *self.transformer.resblocks[:-1], + [ + self.transformer.resblocks[-1], + self.ln_post, + ], + self.proj, + ] + + def _unlock(x): + if isinstance(x, Sequence): + for g in x: + _unlock(g) + else: + if isinstance(x, torch.nn.Parameter): + x.requires_grad = True + else: + for p in x.parameters(): + p.requires_grad = True + + _unlock(groups[-unlocked_groups:]) + + def init_parameters(self): + # FIXME OpenAI CLIP did not define an init for the VisualTransformer + # TODO experiment if default PyTorch init, below, or alternate init is best. + + # nn.init.normal_(self.class_embedding, std=self.scale) + # nn.init.normal_(self.positional_embedding, std=self.scale) + # + # proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5) + # attn_std = self.transformer.width ** -0.5 + # fc_std = (2 * self.transformer.width) ** -0.5 + # for block in self.transformer.resblocks: + # nn.init.normal_(block.attn.in_proj_weight, std=attn_std) + # nn.init.normal_(block.attn.out_proj.weight, std=proj_std) + # nn.init.normal_(block.mlp.c_fc.weight, std=fc_std) + # nn.init.normal_(block.mlp.c_proj.weight, std=proj_std) + # + # if self.text_projection is not None: + # nn.init.normal_(self.text_projection, std=self.scale) + pass + + @torch.jit.ignore + def set_grad_checkpointing(self, enable=True): + self.transformer.grad_checkpointing = enable + + def _global_pool(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + if self.pool_type == 'avg': + pooled, tokens = x[:, 1:].mean(dim=1), x[:, 1:] + elif self.pool_type == 'tok': + pooled, tokens = x[:, 0], x[:, 1:] + else: + pooled = tokens = x + + return pooled, tokens + + def forward(self, x: torch.Tensor, out_layers: list): + x = self.conv1(x) # shape = [*, width, grid, grid] + x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2] + x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width] + + # class embeddings and positional embeddings + x = torch.cat([_expand_token(self.class_embedding, x.shape[0]).to(x.dtype), x], dim=1) + # shape = [*, grid ** 2 + 1, width] + x = x + self.positional_embedding.to(x.dtype) + + x = self.patch_dropout(x) + x = self.ln_pre(x) + + x = x.permute(1, 0, 2) # NLD -> LND + x, attn, patch_tokens = self.transformer(x, out_layers) + B, C, L = attn[0].shape + H = int(np.sqrt(L-1)) + out_attn = torch.zeros([H, H]).to('cuda') + for i in range(len(attn)): + out_attn += attn[i][0, 0, 1:].view(H, H) + x = x.permute(1, 0, 2) # LND -> NLD + patch_tokens = [patch_tokens[t].permute(1, 0, 2) for t in range(len(patch_tokens))] # LND -> NLD + + if self.attn_pool is not None: + if self.attn_pool_contrastive is not None: + # This is untested, WIP pooling that should match paper + x = self.ln_post(x) # TBD LN first or separate one after each pool? + tokens = self.attn_pool(x) + if self.attn_pool_type == 'parallel': + pooled = self.attn_pool_contrastive(x) + else: + assert self.attn_pool_type == 'cascade' + pooled = self.attn_pool_contrastive(tokens) + else: + # this is the original OpenCLIP CoCa setup, does not match paper + x = self.attn_pool(x) + x = self.ln_post(x) + pooled, tokens = self._global_pool(x) + elif self.final_ln_after_pool: + pooled, tokens = self._global_pool(x) + pooled = self.ln_post(pooled) + else: + x = self.ln_post(x) + pooled, tokens = self._global_pool(x) + + if self.proj is not None: + pooled = pooled @ self.proj + + if self.output_tokens: + return pooled, patch_tokens + + return pooled, patch_tokens, tokens @ self.proj + + +def text_global_pool(x, text: Optional[torch.Tensor] = None, pool_type: str = 'argmax'): + if pool_type == 'first': + pooled, tokens = x[:, 0], x[:, 1:] + elif pool_type == 'last': + pooled, tokens = x[:, -1], x[:, :-1] + elif pool_type == 'argmax': + # take features from the eot embedding (eot_token is the highest number in each sequence) + assert text is not None + pooled, tokens = x[torch.arange(x.shape[0]), text.argmax(dim=-1)], x + else: + pooled = tokens = x + + return pooled, tokens + + +class TextTransformer(nn.Module): + output_tokens: torch.jit.Final[bool] + + def __init__( + self, + context_length: int = 77, + vocab_size: int = 49408, + width: int = 512, + heads: int = 8, + layers: int = 12, + mlp_ratio: float = 4.0, + ls_init_value: float = None, + output_dim: int = 512, + embed_cls: bool = False, + no_causal_mask: bool = False, + pad_id: int = 0, + pool_type: str = 'argmax', + proj_bias: bool = False, + act_layer: Callable = nn.GELU, + norm_layer: Callable = LayerNorm, + output_tokens: bool = False, + ): + super().__init__() + assert pool_type in ('first', 'last', 'argmax', 'none') + self.output_tokens = output_tokens + self.num_pos = self.context_length = context_length + self.vocab_size = vocab_size + self.width = width + self.output_dim = output_dim + self.heads = heads + self.pad_id = pad_id + self.pool_type = pool_type + + self.token_embedding = nn.Embedding(vocab_size, width) + if embed_cls: + self.cls_emb = nn.Parameter(torch.empty(width)) + self.num_pos += 1 + else: + self.cls_emb = None + self.positional_embedding = nn.Parameter(torch.empty(self.num_pos, width)) + self.transformer = Transformer( + width=width, + layers=layers, + heads=heads, + mlp_ratio=mlp_ratio, + ls_init_value=ls_init_value, + act_layer=act_layer, + norm_layer=norm_layer, + ) + self.ln_final = norm_layer(width) + + if no_causal_mask: + self.attn_mask = None + else: + self.register_buffer('attn_mask', self.build_causal_mask(), persistent=False) + + if proj_bias: + self.text_projection = nn.Linear(width, output_dim) + else: + self.text_projection = nn.Parameter(torch.empty(width, output_dim)) + + self.init_parameters() + + def init_parameters(self): + nn.init.normal_(self.token_embedding.weight, std=0.02) + nn.init.normal_(self.positional_embedding, std=0.01) + if self.cls_emb is not None: + nn.init.normal_(self.cls_emb, std=0.01) + + proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5) + attn_std = self.transformer.width ** -0.5 + fc_std = (2 * self.transformer.width) ** -0.5 + for block in self.transformer.resblocks: + nn.init.normal_(block.attn.in_proj_weight, std=attn_std) + nn.init.normal_(block.attn.out_proj.weight, std=proj_std) + nn.init.normal_(block.mlp.c_fc.weight, std=fc_std) + nn.init.normal_(block.mlp.c_proj.weight, std=proj_std) + + if self.text_projection is not None: + if isinstance(self.text_projection, nn.Linear): + nn.init.normal_(self.text_projection.weight, std=self.transformer.width ** -0.5) + if self.text_projection.bias is not None: + nn.init.zeros_(self.text_projection.bias) + else: + nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5) + + @torch.jit.ignore + def set_grad_checkpointing(self, enable=True): + self.transformer.grad_checkpointing = enable + + def build_causal_mask(self): + # lazily create causal attention mask, with full attention between the tokens + # pytorch uses additive attention mask; fill with -inf + mask = torch.empty(self.num_pos, self.num_pos) + mask.fill_(float("-inf")) + mask.triu_(1) # zero out the lower diagonal + return mask + + def build_cls_mask(self, text, cast_dtype: torch.dtype): + cls_mask = (text != self.pad_id).unsqueeze(1) + cls_mask = F.pad(cls_mask, (1, 0, cls_mask.shape[2], 0), value=True) + additive_mask = torch.empty(cls_mask.shape, dtype=cast_dtype, device=cls_mask.device) + additive_mask.fill_(0) + additive_mask.masked_fill_(~cls_mask, float("-inf")) + additive_mask = torch.repeat_interleave(additive_mask, self.heads, 0) + return additive_mask + + def forward(self, text): + cast_dtype = self.transformer.get_cast_dtype() + seq_len = text.shape[1] + + x = self.token_embedding(text).to(cast_dtype) # [batch_size, n_ctx, d_model] + attn_mask = self.attn_mask + if self.cls_emb is not None: + seq_len += 1 + x = torch.cat([x, _expand_token(self.cls_emb, x.shape[0])], dim=1) + cls_mask = self.build_cls_mask(text, cast_dtype) + if attn_mask is not None: + attn_mask = attn_mask[None, :seq_len, :seq_len] + cls_mask[:, :seq_len, :seq_len] + + x = x + self.positional_embedding[:seq_len].to(cast_dtype) + x = x.permute(1, 0, 2) # NLD -> LND + # x = self.transformer(x, attn_mask=attn_mask) + x = self.transformer(x, attn_mask=self.attn_mask) + x = x.permute(1, 0, 2) # LND -> NLD + + # x.shape = [batch_size, n_ctx, transformer.width] + if self.cls_emb is not None: + # presence of appended cls embed (CoCa) overrides pool_type, always take last token + pooled, tokens = text_global_pool(x, pool_type='last') + pooled = self.ln_final(pooled) # final LN applied after pooling in this case + else: + x = self.ln_final(x) + pooled, tokens = text_global_pool(x, text, pool_type=self.pool_type) + + if self.text_projection is not None: + if isinstance(self.text_projection, nn.Linear): + pooled = self.text_projection(pooled) + else: + pooled = pooled @ self.text_projection + + if self.output_tokens: + return pooled, tokens + + return pooled + + +class MultimodalTransformer(Transformer): + def __init__( + self, + width: int, + layers: int, + heads: int, + context_length: int = 77, + mlp_ratio: float = 4.0, + ls_init_value: float = None, + act_layer: Callable = nn.GELU, + norm_layer: Callable = LayerNorm, + output_dim: int = 512, + ): + + super().__init__( + width=width, + layers=layers, + heads=heads, + mlp_ratio=mlp_ratio, + ls_init_value=ls_init_value, + act_layer=act_layer, + norm_layer=norm_layer, + ) + self.context_length = context_length + self.cross_attn = nn.ModuleList([ + ResidualAttentionBlock( + width, + heads, + mlp_ratio, + ls_init_value=ls_init_value, + act_layer=act_layer, + norm_layer=norm_layer, + is_cross_attention=True, + ) + for _ in range(layers) + ]) + + self.register_buffer('attn_mask', self.build_attention_mask(), persistent=False) + + self.ln_final = norm_layer(width) + self.text_projection = nn.Parameter(torch.empty(width, output_dim)) + + def init_parameters(self): + proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5) + attn_std = self.transformer.width ** -0.5 + fc_std = (2 * self.transformer.width) ** -0.5 + for block in self.transformer.resblocks: + nn.init.normal_(block.attn.in_proj_weight, std=attn_std) + nn.init.normal_(block.attn.out_proj.weight, std=proj_std) + nn.init.normal_(block.mlp.c_fc.weight, std=fc_std) + nn.init.normal_(block.mlp.c_proj.weight, std=proj_std) + for block in self.transformer.cross_attn: + nn.init.normal_(block.attn.in_proj_weight, std=attn_std) + nn.init.normal_(block.attn.out_proj.weight, std=proj_std) + nn.init.normal_(block.mlp.c_fc.weight, std=fc_std) + nn.init.normal_(block.mlp.c_proj.weight, std=proj_std) + + if self.text_projection is not None: + nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5) + + def build_attention_mask(self): + # lazily create causal attention mask, with full attention between the tokens + # pytorch uses additive attention mask; fill with -inf + mask = torch.empty(self.context_length, self.context_length) + mask.fill_(float("-inf")) + mask.triu_(1) # zero out the lower diagonal + return mask + + def forward(self, image_embs, text_embs): + text_embs = text_embs.permute(1, 0, 2) # NLD -> LNDsq + image_embs = image_embs.permute(1, 0, 2) # NLD -> LND + seq_len = text_embs.shape[0] + + for resblock, cross_attn in zip(self.resblocks, self.cross_attn): + if self.grad_checkpointing and not torch.jit.is_scripting(): + # TODO: handle kwargs https://github.com/pytorch/pytorch/issues/79887#issuecomment-1161758372 + text_embs = checkpoint(resblock, text_embs, None, None, self.attn_mask[:seq_len, :seq_len]) + text_embs = checkpoint(cross_attn, text_embs, image_embs, image_embs, None) + else: + text_embs = resblock(text_embs, attn_mask=self.attn_mask[:seq_len, :seq_len]) + text_embs = cross_attn(text_embs, k_x=image_embs, v_x=image_embs) + + x = text_embs.permute(1, 0, 2) # LND -> NLD + x = self.ln_final(x) + + if self.text_projection is not None: + x = x @ self.text_projection + + return x + + @torch.jit.ignore + def set_grad_checkpointing(self, enable=True): + self.grad_checkpointing = enable diff --git a/open_clip_local/utils.py b/open_clip_local/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..bb0bb8868ae1f2d31493ca32b73accd6bf1d3cdb --- /dev/null +++ b/open_clip_local/utils.py @@ -0,0 +1,89 @@ +from itertools import repeat +import collections.abc + +import torch +from torch import nn as nn +from torchvision.ops.misc import FrozenBatchNorm2d + + +def freeze_batch_norm_2d(module, module_match={}, name=''): + """ + Converts all `BatchNorm2d` and `SyncBatchNorm` layers of provided module into `FrozenBatchNorm2d`. If `module` is + itself an instance of either `BatchNorm2d` or `SyncBatchNorm`, it is converted into `FrozenBatchNorm2d` and + returned. Otherwise, the module is walked recursively and submodules are converted in place. + + Args: + module (torch.nn.Module): Any PyTorch module. + module_match (dict): Dictionary of full module names to freeze (all if empty) + name (str): Full module name (prefix) + + Returns: + torch.nn.Module: Resulting module + + Inspired by https://github.com/pytorch/pytorch/blob/a5895f85be0f10212791145bfedc0261d364f103/torch/nn/modules/batchnorm.py#L762 + """ + res = module + is_match = True + if module_match: + is_match = name in module_match + if is_match and isinstance(module, (nn.modules.batchnorm.BatchNorm2d, nn.modules.batchnorm.SyncBatchNorm)): + res = FrozenBatchNorm2d(module.num_features) + res.num_features = module.num_features + res.affine = module.affine + if module.affine: + res.weight.data = module.weight.data.clone().detach() + res.bias.data = module.bias.data.clone().detach() + res.running_mean.data = module.running_mean.data + res.running_var.data = module.running_var.data + res.eps = module.eps + else: + for child_name, child in module.named_children(): + full_child_name = '.'.join([name, child_name]) if name else child_name + new_child = freeze_batch_norm_2d(child, module_match, full_child_name) + if new_child is not child: + res.add_module(child_name, new_child) + return res + + +# From PyTorch internals +def _ntuple(n): + def parse(x): + if isinstance(x, collections.abc.Iterable): + return x + return tuple(repeat(x, n)) + return parse + + +to_1tuple = _ntuple(1) +to_2tuple = _ntuple(2) +to_3tuple = _ntuple(3) +to_4tuple = _ntuple(4) +to_ntuple = lambda n, x: _ntuple(n)(x) + +# Replaces all linear layers with linear_replacement +# TODO: add int8 support for other linear layers including attn and convnets +def replace_linear(model, linear_replacement, include_modules=['c_fc', 'c_proj'], copy_weights=True): + for name, module in model.named_children(): + if len(list(module.children())) > 0: + replace_linear(module, linear_replacement, include_modules, copy_weights) + + if isinstance(module, torch.nn.Linear) and name in include_modules: + old_module = model._modules[name] + model._modules[name] = linear_replacement( + module.in_features, + module.out_features, + module.bias is not None, + ) + if copy_weights: + model._modules[name].weight.data.copy_(old_module.weight.data) + if model._modules[name].bias is not None: + model._modules[name].bias.data.copy_(old_module.bias) + + return model + +def convert_int8_model_to_inference_mode(model): + for m in model.modules(): + if hasattr(m, 'prepare_for_eval'): + int8_original_dtype = m.weight.dtype + m.prepare_for_eval() + m.int8_original_dtype = int8_original_dtype \ No newline at end of file diff --git a/open_clip_local/version.py b/open_clip_local/version.py new file mode 100644 index 0000000000000000000000000000000000000000..78afda8502b16f06c6a1b8a9f97f48ee0db9f6ce --- /dev/null +++ b/open_clip_local/version.py @@ -0,0 +1 @@ +__version__ = '2.24.0' diff --git a/open_clip_local/zero_shot_classifier.py b/open_clip_local/zero_shot_classifier.py new file mode 100644 index 0000000000000000000000000000000000000000..535ec9696d27a1dcbe2c43da18f5fd20b599cb9b --- /dev/null +++ b/open_clip_local/zero_shot_classifier.py @@ -0,0 +1,110 @@ +from functools import partial +from itertools import islice +from typing import Callable, List, Optional, Sequence, Union + +import torch +import torch.nn.functional as F + + +def batched(iterable, n): + """Batch data into lists of length *n*. The last batch may be shorter. + NOTE based on more-itertools impl, to be replaced by python 3.12 itertools.batched impl + """ + it = iter(iterable) + while True: + batch = list(islice(it, n)) + if not batch: + break + yield batch + + +def build_zero_shot_classifier( + model, + tokenizer, + classnames: Sequence[str], + templates: Sequence[Union[Callable, str]], + num_classes_per_batch: Optional[int] = 10, + device: Union[str, torch.device] = 'cpu', + use_tqdm: bool = False, +): + """ Build zero-shot classifier weights by iterating over class names in batches + Args: + model: CLIP model instance + tokenizer: CLIP tokenizer instance + classnames: A sequence of class (label) names + templates: A sequence of callables or format() friendly strings to produce templates per class name + num_classes_per_batch: The number of classes to batch together in each forward, all if None + device: Device to use. + use_tqdm: Enable TQDM progress bar. + """ + assert isinstance(templates, Sequence) and len(templates) > 0 + assert isinstance(classnames, Sequence) and len(classnames) > 0 + use_format = isinstance(templates[0], str) + num_templates = len(templates) + num_classes = len(classnames) + if use_tqdm: + import tqdm + num_iter = 1 if num_classes_per_batch is None else ((num_classes - 1) // num_classes_per_batch + 1) + iter_wrap = partial(tqdm.tqdm, total=num_iter, unit_scale=num_classes_per_batch) + else: + iter_wrap = iter + + def _process_batch(batch_classnames): + num_batch_classes = len(batch_classnames) + texts = [template.format(c) if use_format else template(c) for c in batch_classnames for template in templates] + texts = tokenizer(texts).to(device) + class_embeddings = model.encode_text(texts, normalize=True) + class_embeddings = class_embeddings.reshape(num_batch_classes, num_templates, -1).mean(dim=1) + class_embeddings = class_embeddings / class_embeddings.norm(dim=1, keepdim=True) + class_embeddings = class_embeddings.T + return class_embeddings + + with torch.no_grad(): + if num_classes_per_batch: + batched_embeds = [_process_batch(batch) for batch in iter_wrap(batched(classnames, num_classes_per_batch))] + zeroshot_weights = torch.cat(batched_embeds, dim=1) + else: + zeroshot_weights = _process_batch(classnames) + return zeroshot_weights + + +def build_zero_shot_classifier_legacy( + model, + tokenizer, + classnames: Sequence[str], + templates: Sequence[Union[Callable, str]], + device: Union[str, torch.device] = 'cpu', + use_tqdm: bool = False, +): + """ Build zero-shot classifier weights by iterating over class names 1 by 1 + Args: + model: CLIP model instance + tokenizer: CLIP tokenizer instance + classnames: A sequence of class (label) names + templates: A sequence of callables or format() friendly strings to produce templates per class name + device: Device to use. + use_tqdm: Enable TQDM progress bar. + """ + assert isinstance(templates, Sequence) and len(templates) > 0 + assert isinstance(classnames, Sequence) and len(classnames) > 0 + if use_tqdm: + import tqdm + iter_wrap = tqdm.tqdm + else: + iter_wrap = iter + + use_format = isinstance(templates[0], str) + + with torch.no_grad(): + zeroshot_weights = [] + for classname in iter_wrap(classnames): + texts = [template.format(classname) if use_format else template(classname) for template in templates] + texts = tokenizer(texts).to(device) # tokenize + class_embeddings = model.encode_text(texts) + class_embedding = F.normalize(class_embeddings, dim=-1).mean(dim=0) + class_embedding /= class_embedding.norm() + zeroshot_weights.append(class_embedding) + zeroshot_weights = torch.stack(zeroshot_weights, dim=1).to(device) + + return zeroshot_weights + diff --git a/open_clip_local/zero_shot_metadata.py b/open_clip_local/zero_shot_metadata.py new file mode 100644 index 0000000000000000000000000000000000000000..ccb452bbb6e27b71cff1dd27e2bb263259b9363f --- /dev/null +++ b/open_clip_local/zero_shot_metadata.py @@ -0,0 +1,266 @@ + +OPENAI_IMAGENET_TEMPLATES = ( + lambda c: f'a bad photo of a {c}.', + lambda c: f'a photo of many {c}.', + lambda c: f'a sculpture of a {c}.', + lambda c: f'a photo of the hard to see {c}.', + lambda c: f'a low resolution photo of the {c}.', + lambda c: f'a rendering of a {c}.', + lambda c: f'graffiti of a {c}.', + lambda c: f'a bad photo of the {c}.', + lambda c: f'a cropped photo of the {c}.', + lambda c: f'a tattoo of a {c}.', + lambda c: f'the embroidered {c}.', + lambda c: f'a photo of a hard to see {c}.', + lambda c: f'a bright photo of a {c}.', + lambda c: f'a photo of a clean {c}.', + lambda c: f'a photo of a dirty {c}.', + lambda c: f'a dark photo of the {c}.', + lambda c: f'a drawing of a {c}.', + lambda c: f'a photo of my {c}.', + lambda c: f'the plastic {c}.', + lambda c: f'a photo of the cool {c}.', + lambda c: f'a close-up photo of a {c}.', + lambda c: f'a black and white photo of the {c}.', + lambda c: f'a painting of the {c}.', + lambda c: f'a painting of a {c}.', + lambda c: f'a pixelated photo of the {c}.', + lambda c: f'a sculpture of the {c}.', + lambda c: f'a bright photo of the {c}.', + lambda c: f'a cropped photo of a {c}.', + lambda c: f'a plastic {c}.', + lambda c: f'a photo of the dirty {c}.', + lambda c: f'a jpeg corrupted photo of a {c}.', + lambda c: f'a blurry photo of the {c}.', + lambda c: f'a photo of the {c}.', + lambda c: f'a good photo of the {c}.', + lambda c: f'a rendering of the {c}.', + lambda c: f'a {c} in a video game.', + lambda c: f'a photo of one {c}.', + lambda c: f'a doodle of a {c}.', + lambda c: f'a close-up photo of the {c}.', + lambda c: f'a photo of a {c}.', + lambda c: f'the origami {c}.', + lambda c: f'the {c} in a video game.', + lambda c: f'a sketch of a {c}.', + lambda c: f'a doodle of the {c}.', + lambda c: f'a origami {c}.', + lambda c: f'a low resolution photo of a {c}.', + lambda c: f'the toy {c}.', + lambda c: f'a rendition of the {c}.', + lambda c: f'a photo of the clean {c}.', + lambda c: f'a photo of a large {c}.', + lambda c: f'a rendition of a {c}.', + lambda c: f'a photo of a nice {c}.', + lambda c: f'a photo of a weird {c}.', + lambda c: f'a blurry photo of a {c}.', + lambda c: f'a cartoon {c}.', + lambda c: f'art of a {c}.', + lambda c: f'a sketch of the {c}.', + lambda c: f'a embroidered {c}.', + lambda c: f'a pixelated photo of a {c}.', + lambda c: f'itap of the {c}.', + lambda c: f'a jpeg corrupted photo of the {c}.', + lambda c: f'a good photo of a {c}.', + lambda c: f'a plushie {c}.', + lambda c: f'a photo of the nice {c}.', + lambda c: f'a photo of the small {c}.', + lambda c: f'a photo of the weird {c}.', + lambda c: f'the cartoon {c}.', + lambda c: f'art of the {c}.', + lambda c: f'a drawing of the {c}.', + lambda c: f'a photo of the large {c}.', + lambda c: f'a black and white photo of a {c}.', + lambda c: f'the plushie {c}.', + lambda c: f'a dark photo of a {c}.', + lambda c: f'itap of a {c}.', + lambda c: f'graffiti of the {c}.', + lambda c: f'a toy {c}.', + lambda c: f'itap of my {c}.', + lambda c: f'a photo of a cool {c}.', + lambda c: f'a photo of a small {c}.', + lambda c: f'a tattoo of the {c}.', +) + + +# a much smaller subset of above prompts +# from https://github.com/openai/CLIP/blob/main/notebooks/Prompt_Engineering_for_ImageNet.ipynb +SIMPLE_IMAGENET_TEMPLATES = ( + lambda c: f'itap of a {c}.', + lambda c: f'a bad photo of the {c}.', + lambda c: f'a origami {c}.', + lambda c: f'a photo of the large {c}.', + lambda c: f'a {c} in a video game.', + lambda c: f'art of the {c}.', + lambda c: f'a photo of the small {c}.', +) + + +IMAGENET_CLASSNAMES = ( + "tench", "goldfish", "great white shark", "tiger shark", "hammerhead shark", "electric ray", + "stingray", "rooster", "hen", "ostrich", "brambling", "goldfinch", "house finch", "junco", + "indigo bunting", "American robin", "bulbul", "jay", "magpie", "chickadee", "American dipper", + "kite (bird of prey)", "bald eagle", "vulture", "great grey owl", "fire salamander", + "smooth newt", "newt", "spotted salamander", "axolotl", "American bullfrog", "tree frog", + "tailed frog", "loggerhead sea turtle", "leatherback sea turtle", "mud turtle", "terrapin", + "box turtle", "banded gecko", "green iguana", "Carolina anole", + "desert grassland whiptail lizard", "agama", "frilled-necked lizard", "alligator lizard", + "Gila monster", "European green lizard", "chameleon", "Komodo dragon", "Nile crocodile", + "American alligator", "triceratops", "worm snake", "ring-necked snake", + "eastern hog-nosed snake", "smooth green snake", "kingsnake", "garter snake", "water snake", + "vine snake", "night snake", "boa constrictor", "African rock python", "Indian cobra", + "green mamba", "sea snake", "Saharan horned viper", "eastern diamondback rattlesnake", + "sidewinder rattlesnake", "trilobite", "harvestman", "scorpion", "yellow garden spider", + "barn spider", "European garden spider", "southern black widow", "tarantula", "wolf spider", + "tick", "centipede", "black grouse", "ptarmigan", "ruffed grouse", "prairie grouse", "peafowl", + "quail", "partridge", "african grey parrot", "macaw", "sulphur-crested cockatoo", "lorikeet", + "coucal", "bee eater", "hornbill", "hummingbird", "jacamar", "toucan", "duck", + "red-breasted merganser", "goose", "black swan", "tusker", "echidna", "platypus", "wallaby", + "koala", "wombat", "jellyfish", "sea anemone", "brain coral", "flatworm", "nematode", "conch", + "snail", "slug", "sea slug", "chiton", "chambered nautilus", "Dungeness crab", "rock crab", + "fiddler crab", "red king crab", "American lobster", "spiny lobster", "crayfish", "hermit crab", + "isopod", "white stork", "black stork", "spoonbill", "flamingo", "little blue heron", + "great egret", "bittern bird", "crane bird", "limpkin", "common gallinule", "American coot", + "bustard", "ruddy turnstone", "dunlin", "common redshank", "dowitcher", "oystercatcher", + "pelican", "king penguin", "albatross", "grey whale", "killer whale", "dugong", "sea lion", + "Chihuahua", "Japanese Chin", "Maltese", "Pekingese", "Shih Tzu", "King Charles Spaniel", + "Papillon", "toy terrier", "Rhodesian Ridgeback", "Afghan Hound", "Basset Hound", "Beagle", + "Bloodhound", "Bluetick Coonhound", "Black and Tan Coonhound", "Treeing Walker Coonhound", + "English foxhound", "Redbone Coonhound", "borzoi", "Irish Wolfhound", "Italian Greyhound", + "Whippet", "Ibizan Hound", "Norwegian Elkhound", "Otterhound", "Saluki", "Scottish Deerhound", + "Weimaraner", "Staffordshire Bull Terrier", "American Staffordshire Terrier", + "Bedlington Terrier", "Border Terrier", "Kerry Blue Terrier", "Irish Terrier", + "Norfolk Terrier", "Norwich Terrier", "Yorkshire Terrier", "Wire Fox Terrier", + "Lakeland Terrier", "Sealyham Terrier", "Airedale Terrier", "Cairn Terrier", + "Australian Terrier", "Dandie Dinmont Terrier", "Boston Terrier", "Miniature Schnauzer", + "Giant Schnauzer", "Standard Schnauzer", "Scottish Terrier", "Tibetan Terrier", + "Australian Silky Terrier", "Soft-coated Wheaten Terrier", "West Highland White Terrier", + "Lhasa Apso", "Flat-Coated Retriever", "Curly-coated Retriever", "Golden Retriever", + "Labrador Retriever", "Chesapeake Bay Retriever", "German Shorthaired Pointer", "Vizsla", + "English Setter", "Irish Setter", "Gordon Setter", "Brittany dog", "Clumber Spaniel", + "English Springer Spaniel", "Welsh Springer Spaniel", "Cocker Spaniel", "Sussex Spaniel", + "Irish Water Spaniel", "Kuvasz", "Schipperke", "Groenendael dog", "Malinois", "Briard", + "Australian Kelpie", "Komondor", "Old English Sheepdog", "Shetland Sheepdog", "collie", + "Border Collie", "Bouvier des Flandres dog", "Rottweiler", "German Shepherd Dog", "Dobermann", + "Miniature Pinscher", "Greater Swiss Mountain Dog", "Bernese Mountain Dog", + "Appenzeller Sennenhund", "Entlebucher Sennenhund", "Boxer", "Bullmastiff", "Tibetan Mastiff", + "French Bulldog", "Great Dane", "St. Bernard", "husky", "Alaskan Malamute", "Siberian Husky", + "Dalmatian", "Affenpinscher", "Basenji", "pug", "Leonberger", "Newfoundland dog", + "Great Pyrenees dog", "Samoyed", "Pomeranian", "Chow Chow", "Keeshond", "brussels griffon", + "Pembroke Welsh Corgi", "Cardigan Welsh Corgi", "Toy Poodle", "Miniature Poodle", + "Standard Poodle", "Mexican hairless dog (xoloitzcuintli)", "grey wolf", "Alaskan tundra wolf", + "red wolf or maned wolf", "coyote", "dingo", "dhole", "African wild dog", "hyena", "red fox", + "kit fox", "Arctic fox", "grey fox", "tabby cat", "tiger cat", "Persian cat", "Siamese cat", + "Egyptian Mau", "cougar", "lynx", "leopard", "snow leopard", "jaguar", "lion", "tiger", + "cheetah", "brown bear", "American black bear", "polar bear", "sloth bear", "mongoose", + "meerkat", "tiger beetle", "ladybug", "ground beetle", "longhorn beetle", "leaf beetle", + "dung beetle", "rhinoceros beetle", "weevil", "fly", "bee", "ant", "grasshopper", + "cricket insect", "stick insect", "cockroach", "praying mantis", "cicada", "leafhopper", + "lacewing", "dragonfly", "damselfly", "red admiral butterfly", "ringlet butterfly", + "monarch butterfly", "small white butterfly", "sulphur butterfly", "gossamer-winged butterfly", + "starfish", "sea urchin", "sea cucumber", "cottontail rabbit", "hare", "Angora rabbit", + "hamster", "porcupine", "fox squirrel", "marmot", "beaver", "guinea pig", "common sorrel horse", + "zebra", "pig", "wild boar", "warthog", "hippopotamus", "ox", "water buffalo", "bison", + "ram (adult male sheep)", "bighorn sheep", "Alpine ibex", "hartebeest", "impala (antelope)", + "gazelle", "arabian camel", "llama", "weasel", "mink", "European polecat", + "black-footed ferret", "otter", "skunk", "badger", "armadillo", "three-toed sloth", "orangutan", + "gorilla", "chimpanzee", "gibbon", "siamang", "guenon", "patas monkey", "baboon", "macaque", + "langur", "black-and-white colobus", "proboscis monkey", "marmoset", "white-headed capuchin", + "howler monkey", "titi monkey", "Geoffroy's spider monkey", "common squirrel monkey", + "ring-tailed lemur", "indri", "Asian elephant", "African bush elephant", "red panda", + "giant panda", "snoek fish", "eel", "silver salmon", "rock beauty fish", "clownfish", + "sturgeon", "gar fish", "lionfish", "pufferfish", "abacus", "abaya", "academic gown", + "accordion", "acoustic guitar", "aircraft carrier", "airliner", "airship", "altar", "ambulance", + "amphibious vehicle", "analog clock", "apiary", "apron", "trash can", "assault rifle", + "backpack", "bakery", "balance beam", "balloon", "ballpoint pen", "Band-Aid", "banjo", + "baluster / handrail", "barbell", "barber chair", "barbershop", "barn", "barometer", "barrel", + "wheelbarrow", "baseball", "basketball", "bassinet", "bassoon", "swimming cap", "bath towel", + "bathtub", "station wagon", "lighthouse", "beaker", "military hat (bearskin or shako)", + "beer bottle", "beer glass", "bell tower", "baby bib", "tandem bicycle", "bikini", + "ring binder", "binoculars", "birdhouse", "boathouse", "bobsleigh", "bolo tie", "poke bonnet", + "bookcase", "bookstore", "bottle cap", "hunting bow", "bow tie", "brass memorial plaque", "bra", + "breakwater", "breastplate", "broom", "bucket", "buckle", "bulletproof vest", + "high-speed train", "butcher shop", "taxicab", "cauldron", "candle", "cannon", "canoe", + "can opener", "cardigan", "car mirror", "carousel", "tool kit", "cardboard box / carton", + "car wheel", "automated teller machine", "cassette", "cassette player", "castle", "catamaran", + "CD player", "cello", "mobile phone", "chain", "chain-link fence", "chain mail", "chainsaw", + "storage chest", "chiffonier", "bell or wind chime", "china cabinet", "Christmas stocking", + "church", "movie theater", "cleaver", "cliff dwelling", "cloak", "clogs", "cocktail shaker", + "coffee mug", "coffeemaker", "spiral or coil", "combination lock", "computer keyboard", + "candy store", "container ship", "convertible", "corkscrew", "cornet", "cowboy boot", + "cowboy hat", "cradle", "construction crane", "crash helmet", "crate", "infant bed", + "Crock Pot", "croquet ball", "crutch", "cuirass", "dam", "desk", "desktop computer", + "rotary dial telephone", "diaper", "digital clock", "digital watch", "dining table", + "dishcloth", "dishwasher", "disc brake", "dock", "dog sled", "dome", "doormat", "drilling rig", + "drum", "drumstick", "dumbbell", "Dutch oven", "electric fan", "electric guitar", + "electric locomotive", "entertainment center", "envelope", "espresso machine", "face powder", + "feather boa", "filing cabinet", "fireboat", "fire truck", "fire screen", "flagpole", "flute", + "folding chair", "football helmet", "forklift", "fountain", "fountain pen", "four-poster bed", + "freight car", "French horn", "frying pan", "fur coat", "garbage truck", + "gas mask or respirator", "gas pump", "goblet", "go-kart", "golf ball", "golf cart", "gondola", + "gong", "gown", "grand piano", "greenhouse", "radiator grille", "grocery store", "guillotine", + "hair clip", "hair spray", "half-track", "hammer", "hamper", "hair dryer", "hand-held computer", + "handkerchief", "hard disk drive", "harmonica", "harp", "combine harvester", "hatchet", + "holster", "home theater", "honeycomb", "hook", "hoop skirt", "gymnastic horizontal bar", + "horse-drawn vehicle", "hourglass", "iPod", "clothes iron", "carved pumpkin", "jeans", "jeep", + "T-shirt", "jigsaw puzzle", "rickshaw", "joystick", "kimono", "knee pad", "knot", "lab coat", + "ladle", "lampshade", "laptop computer", "lawn mower", "lens cap", "letter opener", "library", + "lifeboat", "lighter", "limousine", "ocean liner", "lipstick", "slip-on shoe", "lotion", + "music speaker", "loupe magnifying glass", "sawmill", "magnetic compass", "messenger bag", + "mailbox", "tights", "one-piece bathing suit", "manhole cover", "maraca", "marimba", "mask", + "matchstick", "maypole", "maze", "measuring cup", "medicine cabinet", "megalith", "microphone", + "microwave oven", "military uniform", "milk can", "minibus", "miniskirt", "minivan", "missile", + "mitten", "mixing bowl", "mobile home", "ford model t", "modem", "monastery", "monitor", + "moped", "mortar and pestle", "graduation cap", "mosque", "mosquito net", "vespa", + "mountain bike", "tent", "computer mouse", "mousetrap", "moving van", "muzzle", "metal nail", + "neck brace", "necklace", "baby pacifier", "notebook computer", "obelisk", "oboe", "ocarina", + "odometer", "oil filter", "pipe organ", "oscilloscope", "overskirt", "bullock cart", + "oxygen mask", "product packet / packaging", "paddle", "paddle wheel", "padlock", "paintbrush", + "pajamas", "palace", "pan flute", "paper towel", "parachute", "parallel bars", "park bench", + "parking meter", "railroad car", "patio", "payphone", "pedestal", "pencil case", + "pencil sharpener", "perfume", "Petri dish", "photocopier", "plectrum", "Pickelhaube", + "picket fence", "pickup truck", "pier", "piggy bank", "pill bottle", "pillow", "ping-pong ball", + "pinwheel", "pirate ship", "drink pitcher", "block plane", "planetarium", "plastic bag", + "plate rack", "farm plow", "plunger", "Polaroid camera", "pole", "police van", "poncho", + "pool table", "soda bottle", "plant pot", "potter's wheel", "power drill", "prayer rug", + "printer", "prison", "missile", "projector", "hockey puck", "punching bag", "purse", "quill", + "quilt", "race car", "racket", "radiator", "radio", "radio telescope", "rain barrel", + "recreational vehicle", "fishing casting reel", "reflex camera", "refrigerator", + "remote control", "restaurant", "revolver", "rifle", "rocking chair", "rotisserie", "eraser", + "rugby ball", "ruler measuring stick", "sneaker", "safe", "safety pin", "salt shaker", "sandal", + "sarong", "saxophone", "scabbard", "weighing scale", "school bus", "schooner", "scoreboard", + "CRT monitor", "screw", "screwdriver", "seat belt", "sewing machine", "shield", "shoe store", + "shoji screen / room divider", "shopping basket", "shopping cart", "shovel", "shower cap", + "shower curtain", "ski", "balaclava ski mask", "sleeping bag", "slide rule", "sliding door", + "slot machine", "snorkel", "snowmobile", "snowplow", "soap dispenser", "soccer ball", "sock", + "solar thermal collector", "sombrero", "soup bowl", "keyboard space bar", "space heater", + "space shuttle", "spatula", "motorboat", "spider web", "spindle", "sports car", "spotlight", + "stage", "steam locomotive", "through arch bridge", "steel drum", "stethoscope", "scarf", + "stone wall", "stopwatch", "stove", "strainer", "tram", "stretcher", "couch", "stupa", + "submarine", "suit", "sundial", "sunglasses", "sunglasses", "sunscreen", "suspension bridge", + "mop", "sweatshirt", "swim trunks / shorts", "swing", "electrical switch", "syringe", + "table lamp", "tank", "tape player", "teapot", "teddy bear", "television", "tennis ball", + "thatched roof", "front curtain", "thimble", "threshing machine", "throne", "tile roof", + "toaster", "tobacco shop", "toilet seat", "torch", "totem pole", "tow truck", "toy store", + "tractor", "semi-trailer truck", "tray", "trench coat", "tricycle", "trimaran", "tripod", + "triumphal arch", "trolleybus", "trombone", "hot tub", "turnstile", "typewriter keyboard", + "umbrella", "unicycle", "upright piano", "vacuum cleaner", "vase", "vaulted or arched ceiling", + "velvet fabric", "vending machine", "vestment", "viaduct", "violin", "volleyball", + "waffle iron", "wall clock", "wallet", "wardrobe", "military aircraft", "sink", + "washing machine", "water bottle", "water jug", "water tower", "whiskey jug", "whistle", + "hair wig", "window screen", "window shade", "Windsor tie", "wine bottle", "airplane wing", + "wok", "wooden spoon", "wool", "split-rail fence", "shipwreck", "sailboat", "yurt", "website", + "comic book", "crossword", "traffic or street sign", "traffic light", "dust jacket", "menu", + "plate", "guacamole", "consomme", "hot pot", "trifle", "ice cream", "popsicle", "baguette", + "bagel", "pretzel", "cheeseburger", "hot dog", "mashed potatoes", "cabbage", "broccoli", + "cauliflower", "zucchini", "spaghetti squash", "acorn squash", "butternut squash", "cucumber", + "artichoke", "bell pepper", "cardoon", "mushroom", "Granny Smith apple", "strawberry", "orange", + "lemon", "fig", "pineapple", "banana", "jackfruit", "cherimoya (custard apple)", "pomegranate", + "hay", "carbonara", "chocolate syrup", "dough", "meatloaf", "pizza", "pot pie", "burrito", + "red wine", "espresso", "tea cup", "eggnog", "mountain", "bubble", "cliff", "coral reef", + "geyser", "lakeshore", "promontory", "sandbar", "beach", "valley", "volcano", "baseball player", + "bridegroom", "scuba diver", "rapeseed", "daisy", "yellow lady's slipper", "corn", "acorn", + "rose hip", "horse chestnut seed", "coral fungus", "agaric", "gyromitra", "stinkhorn mushroom", + "earth star fungus", "hen of the woods mushroom", "bolete", "corn cob", "toilet paper" +) +