zhiqing0205 committed
Commit 3de7bf6 · 1 Parent(s): 74acc06

Add core libraries: anomalib, dinov2, open_clip_local

This view is limited to 50 files because the commit contains too many changes. See the raw diff for the full set.
Files changed (50):
  1. anomalib/__init__.py +24 -0
  2. anomalib/callbacks/__init__.py +64 -0
  3. anomalib/callbacks/checkpoint.py +58 -0
  4. anomalib/callbacks/graph.py +61 -0
  5. anomalib/callbacks/metrics.py +226 -0
  6. anomalib/callbacks/model_loader.py +39 -0
  7. anomalib/callbacks/nncf/__init__.py +4 -0
  8. anomalib/callbacks/nncf/callback.py +106 -0
  9. anomalib/callbacks/nncf/utils.py +243 -0
  10. anomalib/callbacks/normalization/__init__.py +12 -0
  11. anomalib/callbacks/normalization/base.py +29 -0
  12. anomalib/callbacks/normalization/min_max_normalization.py +109 -0
  13. anomalib/callbacks/normalization/utils.py +78 -0
  14. anomalib/callbacks/post_processor.py +125 -0
  15. anomalib/callbacks/thresholding.py +197 -0
  16. anomalib/callbacks/tiler_configuration.py +74 -0
  17. anomalib/callbacks/timer.py +109 -0
  18. anomalib/callbacks/visualizer.py +182 -0
  19. anomalib/cli/__init__.py +8 -0
  20. anomalib/cli/cli.py +483 -0
  21. anomalib/cli/install.py +81 -0
  22. anomalib/cli/utils/__init__.py +8 -0
  23. anomalib/cli/utils/help_formatter.py +268 -0
  24. anomalib/cli/utils/installation.py +430 -0
  25. anomalib/cli/utils/openvino.py +32 -0
  26. anomalib/data/__init__.py +72 -0
  27. anomalib/data/base/__init__.py +18 -0
  28. anomalib/data/base/datamodule.py +305 -0
  29. anomalib/data/base/dataset.py +208 -0
  30. anomalib/data/base/depth.py +76 -0
  31. anomalib/data/base/video.py +213 -0
  32. anomalib/data/depth/__init__.py +20 -0
  33. anomalib/data/depth/folder_3d.py +433 -0
  34. anomalib/data/depth/mvtec_3d.py +302 -0
  35. anomalib/data/errors.py +19 -0
  36. anomalib/data/image/__init__.py +33 -0
  37. anomalib/data/image/btech.py +362 -0
  38. anomalib/data/image/folder.py +478 -0
  39. anomalib/data/image/kolektor.py +342 -0
  40. anomalib/data/image/mvtec.py +414 -0
  41. anomalib/data/image/mvtec_loco.py +480 -0
  42. anomalib/data/image/visa.py +364 -0
  43. anomalib/data/predict.py +52 -0
  44. anomalib/data/transforms/__init__.py +8 -0
  45. anomalib/data/transforms/center_crop.py +87 -0
  46. anomalib/data/utils/__init__.py +56 -0
  47. anomalib/data/utils/augmenter.py +172 -0
  48. anomalib/data/utils/boxes.py +117 -0
  49. anomalib/data/utils/download.py +364 -0
  50. anomalib/data/utils/generators/__init__.py +8 -0
anomalib/__init__.py ADDED
@@ -0,0 +1,24 @@
+ """Anomalib library for research and benchmarking."""
+
+ # Copyright (C) 2022-2024 Intel Corporation
+ # SPDX-License-Identifier: Apache-2.0
+
+ from enum import Enum
+
+ __version__ = "1.1.0dev"
+
+
+ class LearningType(str, Enum):
+     """Learning type defining how the model learns from the dataset samples."""
+
+     ONE_CLASS = "one_class"
+     ZERO_SHOT = "zero_shot"
+     FEW_SHOT = "few_shot"
+
+
+ class TaskType(str, Enum):
+     """Task type used when generating predictions on the dataset."""
+
+     CLASSIFICATION = "classification"
+     DETECTION = "detection"
+     SEGMENTATION = "segmentation"
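
Because both enums subclass ``str``, members compare equal to plain strings and can be constructed directly from config values. A minimal sketch of that behaviour:

    from anomalib import LearningType, TaskType

    # str-backed enums compare equal to their plain-string values ...
    assert TaskType.SEGMENTATION == "segmentation"
    # ... and can be looked up from a raw config string.
    assert TaskType("classification") is TaskType.CLASSIFICATION
    assert LearningType("zero_shot") is LearningType.ZERO_SHOT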
anomalib/callbacks/__init__.py ADDED
@@ -0,0 +1,64 @@
+ """Callbacks for Anomalib models."""
+
+ # Copyright (C) 2022 Intel Corporation
+ # SPDX-License-Identifier: Apache-2.0
+
+
+ import logging
+ from importlib import import_module
+ from pathlib import Path
+
+ import yaml
+ from jsonargparse import Namespace
+ from lightning.pytorch.callbacks import Callback
+ from omegaconf import DictConfig, ListConfig, OmegaConf
+
+ from .checkpoint import ModelCheckpoint
+ from .graph import GraphLogger
+ from .model_loader import LoadModelCallback
+ from .tiler_configuration import TilerConfigurationCallback
+ from .timer import TimerCallback
+
+ __all__ = [
+     "ModelCheckpoint",
+     "GraphLogger",
+     "LoadModelCallback",
+     "TilerConfigurationCallback",
+     "TimerCallback",
+ ]
+
+
+ logger = logging.getLogger(__name__)
+
+
+ def get_callbacks(config: DictConfig | ListConfig | Namespace) -> list[Callback]:
+     """Return base callbacks for all the lightning models.
+
+     Args:
+         config (DictConfig | ListConfig | Namespace): Model config.
+
+     Return:
+         (list[Callback]): List of callbacks.
+     """
+     logger.info("Loading the callbacks")
+
+     callbacks: list[Callback] = []
+
+     if "ckpt_path" in config.trainer and config.ckpt_path is not None:
+         load_model = LoadModelCallback(config.ckpt_path)
+         callbacks.append(load_model)
+
+     if "optimization" in config and "nncf" in config.optimization and config.optimization.nncf.apply:
+         # NNCF wraps torch's jit, which conflicts with kornia's jit calls.
+         # Hence, nncf is imported only when required.
+         nncf_module = import_module("anomalib.callbacks.nncf.callback")
+         nncf_callback = nncf_module.NNCFCallback
+         nncf_config = yaml.safe_load(OmegaConf.to_yaml(config.optimization.nncf))
+         callbacks.append(
+             nncf_callback(
+                 config=nncf_config,
+                 export_dir=str(Path(config.project.path) / "compressed"),
+             ),
+         )
+
+     return callbacks
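
With no checkpoint to restore and no NNCF optimization enabled, ``get_callbacks`` returns an empty list. A minimal sketch, assuming a hypothetical OmegaConf config of the shape the function expects:

    from omegaconf import OmegaConf

    from anomalib.callbacks import get_callbacks

    # Hypothetical minimal config: no "ckpt_path" under trainer and no
    # "optimization" section, so neither branch above fires.
    config = OmegaConf.create({"trainer": {}})
    assert get_callbacks(config) == []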
anomalib/callbacks/checkpoint.py ADDED
@@ -0,0 +1,58 @@
+ """Anomalib Model Checkpoint Callback."""
+
+ # Copyright (C) 2024 Intel Corporation
+ # SPDX-License-Identifier: Apache-2.0
+
+ from lightning.pytorch import Trainer
+ from lightning.pytorch.callbacks import ModelCheckpoint as LightningCheckpoint
+ from lightning.pytorch.trainer.states import TrainerFn
+
+ from anomalib import LearningType
+
+
+ class ModelCheckpoint(LightningCheckpoint):
+     """Anomalib Model Checkpoint Callback.
+
+     This class overrides the Lightning ModelCheckpoint callback to enable saving checkpoints without running any
+     training steps. This is useful for zero-/few-shot models, where the fit sequence only consists of validation.
+
+     To enable saving checkpoints without running any training steps, we need to override two checks which are being
+     called in the ``on_validation_end`` method of the parent class:
+     - ``_should_save_on_train_epoch_end``: This method checks whether the checkpoint should be saved at the end of a
+       training epoch, or at the end of the validation sequence. We modify this method to default to saving at the end
+       of the validation sequence when the model is of zero- or few-shot type, unless ``save_on_train_epoch_end`` is
+       specifically set by the user.
+     - ``_should_skip_saving_checkpoint``: This method checks whether the checkpoint should be saved at all. We modify
+       this method to allow saving during both the ``FITTING`` and ``VALIDATING`` states. In addition, we allow saving
+       if the global step has not changed since the last checkpoint, but only for zero- and few-shot models. This is
+       needed because both the last global step and the last checkpoint remain unchanged during zero-/few-shot
+       training, which would otherwise prevent saving checkpoints during validation.
+     """
+
+     def _should_skip_saving_checkpoint(self, trainer: Trainer) -> bool:
+         """Check whether the checkpoint should be saved.
+
+         Overrides the parent method to allow saving during both the ``FITTING`` and ``VALIDATING`` states, and to
+         allow saving when the global step and last_global_step_saved are both 0 (only for zero-/few-shot models).
+         """
+         is_zero_or_few_shot = trainer.model.learning_type in [LearningType.ZERO_SHOT, LearningType.FEW_SHOT]
+         return (
+             bool(trainer.fast_dev_run)  # disable checkpointing with fast_dev_run
+             or trainer.state.fn not in [TrainerFn.FITTING, TrainerFn.VALIDATING]  # don't save anything during non-fit
+             or trainer.sanity_checking  # don't save anything during sanity check
+             or (self._last_global_step_saved == trainer.global_step and not is_zero_or_few_shot)
+         )
+
+     def _should_save_on_train_epoch_end(self, trainer: Trainer) -> bool:
+         """Check whether the checkpoint should be saved at the end of a training epoch or validation sequence.
+
+         Overrides the parent method to default to saving at the end of the validation sequence when the model is of
+         zero- or few-shot type, unless ``save_on_train_epoch_end`` is specifically set by the user.
+         """
+         if self._save_on_train_epoch_end is not None:
+             return self._save_on_train_epoch_end
+
+         if trainer.model.learning_type in [LearningType.ZERO_SHOT, LearningType.FEW_SHOT]:
+             return False
+
+         return super()._should_save_on_train_epoch_end(trainer)
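
Since the class keeps the parent's constructor, it is configured like any Lightning checkpoint callback. A hedged sketch (the monitored metric name is hypothetical and depends on which metrics are enabled):

    from anomalib.callbacks.checkpoint import ModelCheckpoint

    # Hypothetical setup: keep the best checkpoint by image-level F1 score.
    # For zero-/few-shot models, the overrides above let this fire at the end
    # of the validation sequence even though trainer.global_step never advances.
    checkpoint = ModelCheckpoint(
        dirpath="results/weights",
        filename="best",
        monitor="image_F1Score",  # assumed metric name; the metrics callback prefixes image metrics with "image_"
        mode="max",
    )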
anomalib/callbacks/graph.py ADDED
@@ -0,0 +1,61 @@
+ """Log model graph to respective logger."""
+
+ # Copyright (C) 2022 Intel Corporation
+ # SPDX-License-Identifier: Apache-2.0
+
+ import torch
+ from lightning.pytorch import Callback, LightningModule, Trainer
+
+ from anomalib.loggers import AnomalibCometLogger, AnomalibTensorBoardLogger, AnomalibWandbLogger
+
+
+ class GraphLogger(Callback):
+     """Log model graph to respective logger.
+
+     Examples:
+         Log model graph to Tensorboard
+
+         >>> from anomalib.callbacks import GraphLogger
+         >>> from anomalib.loggers import AnomalibTensorBoardLogger
+         >>> from anomalib.engine import Engine
+         ...
+         >>> logger = AnomalibTensorBoardLogger()
+         >>> callbacks = [GraphLogger()]
+         >>> engine = Engine(logger=logger, callbacks=callbacks)
+
+         Log model graph to Comet
+
+         >>> from anomalib.loggers import AnomalibCometLogger
+         >>> from anomalib.engine import Engine
+         ...
+         >>> logger = AnomalibCometLogger()
+         >>> callbacks = [GraphLogger()]
+         >>> engine = Engine(logger=logger, callbacks=callbacks)
+     """
+
+     def on_train_start(self, trainer: Trainer, pl_module: LightningModule) -> None:
+         """Log model graph to respective logger.
+
+         Args:
+             trainer: Trainer object which contains reference to loggers.
+             pl_module: LightningModule object which is logged.
+         """
+         for logger in trainer.loggers:
+             if isinstance(logger, AnomalibWandbLogger):
+                 # NOTE: log graph gets populated only after one backward pass. This won't work for models which do
+                 # not require training such as Padim.
+                 logger.watch(pl_module, log_graph=True, log="all")
+                 break
+
+     def on_train_end(self, trainer: Trainer, pl_module: LightningModule) -> None:
+         """Unwatch the model if configured for wandb, and log the model graph to TensorBoard or Comet if specified.
+
+         Args:
+             trainer: Trainer object which contains reference to loggers.
+             pl_module: LightningModule object which is logged.
+         """
+         for logger in trainer.loggers:
+             if isinstance(logger, AnomalibCometLogger | AnomalibTensorBoardLogger):
+                 logger.log_graph(pl_module, input_array=torch.ones((1, 3, 256, 256)))
+             elif isinstance(logger, AnomalibWandbLogger):
+                 logger.experiment.unwatch(pl_module)
anomalib/callbacks/metrics.py ADDED
@@ -0,0 +1,226 @@
+ """MetricsManager callback."""
+
+ # Copyright (C) 2023 Intel Corporation
+ # SPDX-License-Identifier: Apache-2.0
+
+
+ import logging
+ from enum import Enum
+ from typing import Any
+
+ import torch
+ from lightning.pytorch import Callback, Trainer
+ from lightning.pytorch.utilities.types import STEP_OUTPUT
+
+ from anomalib import TaskType
+ from anomalib.metrics import create_metric_collection
+ from anomalib.models import AnomalyModule
+
+ logger = logging.getLogger(__name__)
+
+
+ class Device(str, Enum):
+     """Device on which to compute metrics."""
+
+     CPU = "cpu"
+     GPU = "gpu"
+
+
+ class _MetricsCallback(Callback):
+     """Create image and pixel-level AnomalibMetricsCollection.
+
+     This callback creates AnomalibMetricsCollection based on the
+     list of strings provided for image and pixel-level metrics.
+     After these MetricCollections are created, the callback assigns
+     them to the lightning module.
+
+     Args:
+         task (TaskType | str): Task type of the current run.
+         image_metrics (list[str] | str | dict[str, dict[str, Any]] | None): List of image-level metrics.
+         pixel_metrics (list[str] | str | dict[str, dict[str, Any]] | None): List of pixel-level metrics.
+         device (str): Whether to compute metrics on cpu or gpu. Defaults to cpu.
+     """
+
+     def __init__(
+         self,
+         task: TaskType | str = TaskType.SEGMENTATION,
+         image_metrics: list[str] | str | dict[str, dict[str, Any]] | None = None,
+         pixel_metrics: list[str] | str | dict[str, dict[str, Any]] | None = None,
+         device: Device = Device.CPU,
+     ) -> None:
+         super().__init__()
+         self.task = TaskType(task)
+         self.image_metric_names = image_metrics
+         self.pixel_metric_names = pixel_metrics
+         self.device = device
+
+     def setup(
+         self,
+         trainer: Trainer,
+         pl_module: AnomalyModule,
+         stage: str | None = None,
+     ) -> None:
+         """Set image and pixel-level AnomalibMetricsCollection within the Anomalib Model.
+
+         Args:
+             trainer (pl.Trainer): PyTorch Lightning Trainer.
+             pl_module (AnomalyModule): Anomalib Model that inherits pl LightningModule.
+             stage (str | None, optional): fit, validate, test or predict. Defaults to None.
+         """
+         del stage, trainer  # these variables are not used.
+         image_metric_names = [] if self.image_metric_names is None else self.image_metric_names
+         if isinstance(image_metric_names, str):
+             image_metric_names = [image_metric_names]
+
+         pixel_metric_names: list[str] | dict[str, dict[str, Any]]
+         if self.pixel_metric_names is None:
+             pixel_metric_names = []
+         elif self.task == TaskType.CLASSIFICATION:
+             pixel_metric_names = []
+             logger.warning(
+                 "Cannot perform pixel-level evaluation when task type is classification. "
+                 "Ignoring the following pixel-level metrics: %s",
+                 self.pixel_metric_names,
+             )
+         else:
+             pixel_metric_names = (
+                 self.pixel_metric_names.copy()
+                 if not isinstance(self.pixel_metric_names, str)
+                 else [self.pixel_metric_names]
+             )
+
+         # create a separate metric collection for metrics that operate over the semantic segmentation mask
+         # (segmentation mask with a separate channel for each defect type)
+         semantic_pixel_metric_names: list[str] | dict[str, dict[str, Any]] = []
+         # currently only the SPRO metric is supported as a semantic segmentation metric
+         if "SPRO" in pixel_metric_names:
+             if isinstance(pixel_metric_names, list):
+                 pixel_metric_names.remove("SPRO")
+                 semantic_pixel_metric_names = ["SPRO"]
+             elif isinstance(pixel_metric_names, dict):
+                 spro_metric = pixel_metric_names.pop("SPRO")
+                 semantic_pixel_metric_names = {"SPRO": spro_metric}
+             else:
+                 logger.warning("Unexpected type for pixel_metric_names: %s", type(pixel_metric_names))
+
+         if isinstance(pl_module, AnomalyModule):
+             pl_module.image_metrics = create_metric_collection(image_metric_names, "image_")
+             if hasattr(pl_module, "pixel_metrics"):  # in case metrics are loaded from a model checkpoint
+                 new_metrics = create_metric_collection(pixel_metric_names)
+                 for name in new_metrics:
+                     if name not in pl_module.pixel_metrics:
+                         pl_module.pixel_metrics.add_metrics(new_metrics[name])
+             else:
+                 pl_module.pixel_metrics = create_metric_collection(pixel_metric_names, "pixel_")
+             pl_module.semantic_pixel_metrics = create_metric_collection(semantic_pixel_metric_names, "pixel_")
+             self._set_threshold(pl_module)
+
+     def on_validation_epoch_start(
+         self,
+         trainer: Trainer,
+         pl_module: AnomalyModule,
+     ) -> None:
+         del trainer  # Unused argument.
+
+         pl_module.image_metrics.reset()
+         pl_module.pixel_metrics.reset()
+         pl_module.semantic_pixel_metrics.reset()
+
+     def on_validation_batch_end(
+         self,
+         trainer: Trainer,
+         pl_module: AnomalyModule,
+         outputs: STEP_OUTPUT | None,
+         batch: Any,  # noqa: ANN401
+         batch_idx: int,
+         dataloader_idx: int = 0,
+     ) -> None:
+         del trainer, batch, batch_idx, dataloader_idx  # Unused arguments.
+
+         if outputs is not None:
+             self._outputs_to_device(outputs)
+             self._update_metrics(pl_module, outputs)
+
+     def on_validation_epoch_end(
+         self,
+         trainer: Trainer,
+         pl_module: AnomalyModule,
+     ) -> None:
+         del trainer  # Unused argument.
+
+         self._set_threshold(pl_module)
+         self._log_metrics(pl_module)
+
+     def on_test_epoch_start(
+         self,
+         trainer: Trainer,
+         pl_module: AnomalyModule,
+     ) -> None:
+         del trainer  # Unused argument.
+
+         pl_module.image_metrics.reset()
+         pl_module.pixel_metrics.reset()
+         pl_module.semantic_pixel_metrics.reset()
+
+     def on_test_batch_end(
+         self,
+         trainer: Trainer,
+         pl_module: AnomalyModule,
+         outputs: STEP_OUTPUT | None,
+         batch: Any,  # noqa: ANN401
+         batch_idx: int,
+         dataloader_idx: int = 0,
+     ) -> None:
+         del trainer, batch, batch_idx, dataloader_idx  # Unused arguments.
+
+         if outputs is not None:
+             self._outputs_to_device(outputs)
+             self._update_metrics(pl_module, outputs)
+
+     def on_test_epoch_end(
+         self,
+         trainer: Trainer,
+         pl_module: AnomalyModule,
+     ) -> None:
+         del trainer  # Unused argument.
+
+         self._log_metrics(pl_module)
+
+     def _set_threshold(self, pl_module: AnomalyModule) -> None:
+         pl_module.image_metrics.set_threshold(pl_module.image_threshold.value.item())
+         pl_module.pixel_metrics.set_threshold(pl_module.pixel_threshold.value.item())
+         pl_module.semantic_pixel_metrics.set_threshold(pl_module.pixel_threshold.value.item())
+
+     def _update_metrics(
+         self,
+         pl_module: AnomalyModule,
+         output: STEP_OUTPUT,
+     ) -> None:
+         pl_module.image_metrics.to(self.device)
+         pl_module.image_metrics.update(output["pred_scores"], output["label"].int())
+         if "mask" in output and "anomaly_maps" in output:
+             pl_module.pixel_metrics.to(self.device)
+             pl_module.pixel_metrics.update(torch.squeeze(output["anomaly_maps"]), torch.squeeze(output["mask"].int()))
+         if "semantic_mask" in output and "anomaly_maps" in output:
+             pl_module.semantic_pixel_metrics.to(self.device)
+             pl_module.semantic_pixel_metrics.update(torch.squeeze(output["anomaly_maps"]), output["semantic_mask"])
+
+     def _outputs_to_device(self, output: STEP_OUTPUT) -> STEP_OUTPUT | dict[str, Any]:
+         if isinstance(output, dict):
+             for key, value in output.items():
+                 output[key] = self._outputs_to_device(value)
+         elif isinstance(output, torch.Tensor):
+             output = output.to(self.device)
+         elif isinstance(output, list):
+             for i, value in enumerate(output):
+                 output[i] = self._outputs_to_device(value)
+         return output
+
+     @staticmethod
+     def _log_metrics(pl_module: AnomalyModule) -> None:
+         """Log computed performance metrics."""
+         pl_module.log_dict(pl_module.image_metrics, prog_bar=True)
+         if pl_module.pixel_metrics.update_called:
+             pl_module.log_dict(pl_module.pixel_metrics, prog_bar=False)
+         if pl_module.semantic_pixel_metrics.update_called:
+             pl_module.log_dict(pl_module.semantic_pixel_metrics, prog_bar=False)
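
``_outputs_to_device`` walks arbitrarily nested dicts and lists and moves every tensor it finds. A standalone sketch of the same traversal pattern:

    import torch

    def outputs_to_device(output, device):
        # Recurse through dicts and lists, moving each tensor to ``device``;
        # mirrors the traversal in _MetricsCallback._outputs_to_device.
        if isinstance(output, dict):
            return {key: outputs_to_device(value, device) for key, value in output.items()}
        if isinstance(output, list):
            return [outputs_to_device(value, device) for value in output]
        if isinstance(output, torch.Tensor):
            return output.to(device)
        return output

    batch = {"pred_scores": torch.rand(4), "box_scores": [torch.rand(2)]}
    moved = outputs_to_device(batch, "cpu")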
anomalib/callbacks/model_loader.py ADDED
@@ -0,0 +1,39 @@
+ """Callback that loads model weights from the state dict."""
+
+ # Copyright (C) 2022 Intel Corporation
+ # SPDX-License-Identifier: Apache-2.0
+
+
+ import logging
+
+ import torch
+ from lightning.pytorch import Callback, Trainer
+
+ from anomalib.models.components import AnomalyModule
+
+ logger = logging.getLogger(__name__)
+
+
+ class LoadModelCallback(Callback):
+     """Callback that loads the model weights from the state dict.
+
+     Examples:
+         >>> from anomalib.callbacks import LoadModelCallback
+         >>> from anomalib.engine import Engine
+         ...
+         >>> callbacks = [LoadModelCallback(weights_path="path/to/weights.pt")]
+         >>> engine = Engine(callbacks=callbacks)
+     """
+
+     def __init__(self, weights_path: str) -> None:
+         self.weights_path = weights_path
+
+     def setup(self, trainer: Trainer, pl_module: AnomalyModule, stage: str | None = None) -> None:
+         """Call when inference begins.
+
+         Loads the model weights from ``weights_path`` into the PyTorch module.
+         """
+         del trainer, stage  # These variables are not used.
+
+         logger.info("Loading the model from %s", self.weights_path)
+         pl_module.load_state_dict(torch.load(self.weights_path, map_location=pl_module.device)["state_dict"])
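
The callback indexes the loaded file with ``["state_dict"]``, so the weights file must be a dict containing that key, which is the layout Lightning checkpoints use. A sketch of producing a compatible file, using a hypothetical stand-in module:

    import torch
    from torch import nn

    model = nn.Linear(4, 2)  # hypothetical stand-in for an AnomalyModule

    # The file must be a dict with a "state_dict" entry, because setup() reads
    # torch.load(weights_path)["state_dict"].
    torch.save({"state_dict": model.state_dict()}, "weights.pt")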
anomalib/callbacks/nncf/__init__.py ADDED
@@ -0,0 +1,4 @@
+ """NNCF integration."""
+
+ # Copyright (C) 2022 Intel Corporation
+ # SPDX-License-Identifier: Apache-2.0
anomalib/callbacks/nncf/callback.py ADDED
@@ -0,0 +1,106 @@
+ """Callbacks for NNCF optimization."""
+
+ # Copyright (C) 2022 Intel Corporation
+ # SPDX-License-Identifier: Apache-2.0
+
+
+ import subprocess  # nosec B404
+ from pathlib import Path
+ from typing import TYPE_CHECKING, Any
+
+ import lightning.pytorch as pl
+ from lightning.pytorch import Callback
+ from nncf import NNCFConfig
+ from nncf.torch import register_default_init_args
+
+ from anomalib.callbacks.nncf.utils import InitLoader, wrap_nncf_model
+
+ if TYPE_CHECKING:
+     from nncf.api.compression import CompressionAlgorithmController
+
+
+ class NNCFCallback(Callback):
+     """Callback for NNCF compression.
+
+     Assumes that the pl module contains a 'model' attribute, which is
+     the PyTorch module that must be compressed.
+
+     Args:
+         config (dict): NNCF Configuration.
+         export_dir (str): Path where the exported ``onnx`` model and the OpenVINO ``xml`` and ``bin`` IR are saved.
+             If ``None``, the model will not be exported.
+     """
+
+     def __init__(self, config: dict, export_dir: str | None = None) -> None:
+         self.export_dir = export_dir
+         self.config = NNCFConfig(config)
+         self.nncf_ctrl: CompressionAlgorithmController | None = None
+
+     def setup(self, trainer: pl.Trainer, pl_module: pl.LightningModule, stage: str | None = None) -> None:
+         """Call when fit or test begins.
+
+         Takes the pytorch model and wraps it using the compression controller
+         so that it is ready for nncf fine-tuning.
+         """
+         del stage  # `stage` variable is not used.
+
+         if self.nncf_ctrl is not None:
+             return
+
+         # Get the validation subset to initialize quantization,
+         # because the train subset does not contain anomalous images.
+         init_loader = InitLoader(trainer.datamodule.val_dataloader())
+         config = register_default_init_args(self.config, init_loader)
+
+         self.nncf_ctrl, pl_module.model = wrap_nncf_model(
+             model=pl_module.model,
+             config=config,
+             dataloader=trainer.datamodule.train_dataloader(),
+             init_state_dict=None,  # type: ignore[arg-type]
+         )
+
+     def on_train_batch_start(
+         self,
+         trainer: pl.Trainer,
+         pl_module: pl.LightningModule,
+         batch: Any,  # noqa: ANN401
+         batch_idx: int,
+         unused: int = 0,
+     ) -> None:
+         """Call when the train batch begins.
+
+         Prepare compression method to continue training the model in the next step.
+         """
+         del trainer, pl_module, batch, batch_idx, unused  # These variables are not used.
+
+         if self.nncf_ctrl:
+             self.nncf_ctrl.scheduler.step()
+
+     def on_train_epoch_start(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None:
+         """Call when the train epoch starts.
+
+         Prepare compression method to continue training the model in the next epoch.
+         """
+         del trainer, pl_module  # `trainer` and `pl_module` variables are not used.
+
+         if self.nncf_ctrl:
+             self.nncf_ctrl.scheduler.epoch_step()
+
+     def on_train_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None:
+         """Call when the train ends.
+
+         Exports the onnx model and, if the compression controller is not None, uses the onnx model to generate the
+         OpenVINO IR.
+         """
+         del trainer, pl_module  # `trainer` and `pl_module` variables are not used.
+
+         if self.export_dir is None or self.nncf_ctrl is None:
+             return
+
+         Path(self.export_dir).mkdir(parents=True, exist_ok=True)
+         onnx_path = str(Path(self.export_dir) / "model_nncf.onnx")
+         self.nncf_ctrl.export_model(onnx_path)
+
+         optimize_command = ["mo", "--input_model", onnx_path, "--output_dir", self.export_dir]
+         # TODO(samet-akcay): Check if mo can be done via python API
+         # CVS-122665
+         subprocess.run(optimize_command, check=True)  # noqa: S603 # nosec B603
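
The ``config`` dict follows NNCF's own schema rather than anything defined by this callback. A hedged sketch of an 8-bit quantization setup (the field values are illustrative; see the NNCF documentation for the full schema):

    from anomalib.callbacks.nncf.callback import NNCFCallback

    # Illustrative NNCF config: declare the input shape and request the
    # quantization algorithm.
    nncf_config = {
        "input_info": {"sample_size": [1, 3, 256, 256]},
        "compression": {"algorithm": "quantization"},
    }
    callback = NNCFCallback(config=nncf_config, export_dir="results/compressed")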
anomalib/callbacks/nncf/utils.py ADDED
@@ -0,0 +1,243 @@
+ """Utils for NNCF optimization."""
+
+ # Copyright (C) 2022 Intel Corporation
+ # SPDX-License-Identifier: Apache-2.0
+
+
+ import logging
+ from copy import copy
+ from typing import TYPE_CHECKING, Any
+
+ import torch
+ from nncf import NNCFConfig
+ from nncf.api.compression import CompressionAlgorithmController
+ from nncf.torch import create_compressed_model, load_state, register_default_init_args
+ from nncf.torch.initialization import PTInitializingDataLoader
+ from nncf.torch.nncf_network import NNCFNetwork
+ from torch import nn
+ from torch.utils.data.dataloader import DataLoader
+
+ if TYPE_CHECKING:
+     from collections.abc import Iterator
+
+
+ logger = logging.getLogger(name="NNCF compression")
+
+
+ class InitLoader(PTInitializingDataLoader):
+     """Initializing data loader for NNCF to be used with unsupervised training algorithms."""
+
+     def __init__(self, data_loader: DataLoader) -> None:
+         super().__init__(data_loader)
+         self._data_loader_iter: Iterator
+
+     def __iter__(self) -> "InitLoader":
+         """Create iterator for dataloader."""
+         self._data_loader_iter = iter(self._data_loader)
+         return self
+
+     def __next__(self) -> torch.Tensor:
+         """Return next item from dataloader iterator."""
+         loaded_item = next(self._data_loader_iter)
+         return loaded_item["image"]
+
+     def get_inputs(self, dataloader_output: dict[str, str | torch.Tensor]) -> tuple[tuple, dict]:
+         """Get input to model.
+
+         Returns:
+             (dataloader_output,), {}: tuple[tuple, dict]: The current model call to be made during
+             the initialization process.
+         """
+         return (dataloader_output,), {}
+
+     def get_target(self, _):  # noqa: ANN001, ANN201
+         """Return structure for ground truth in loss criterion based on dataloader output.
+
+         This implementation does not do anything and is a placeholder.
+
+         Returns:
+             None
+         """
+         return
+
+
+ def wrap_nncf_model(
+     model: nn.Module,
+     config: dict,
+     dataloader: DataLoader,
+     init_state_dict: dict,
+ ) -> tuple[CompressionAlgorithmController, NNCFNetwork]:
+     """Wrap model by NNCF.
+
+     :param model: Anomalib model.
+     :param config: NNCF config.
+     :param dataloader: Dataloader for initialization of NNCF model.
+     :param init_state_dict: Optional state dict (with ``model`` weights and ``compression_state``) to resume from.
+     :return: compression controller, compressed model
+     """
+     nncf_config = NNCFConfig.from_dict(config)
+
+     if not dataloader and not init_state_dict:
+         logger.warning(
+             "Either dataloader or NNCF pre-trained "
+             "model checkpoint should be set. Without this, "
+             "quantizers will not be initialized",
+         )
+
+     compression_state = None
+     resuming_state_dict = None
+     if init_state_dict:
+         resuming_state_dict = init_state_dict.get("model")
+         compression_state = init_state_dict.get("compression_state")
+
+     if dataloader:
+         init_loader = InitLoader(dataloader)
+         nncf_config = register_default_init_args(nncf_config, init_loader)
+
+     nncf_ctrl, nncf_model = create_compressed_model(
+         model=model,
+         config=nncf_config,
+         dump_graphs=False,
+         compression_state=compression_state,
+     )
+
+     if resuming_state_dict:
+         load_state(nncf_model, resuming_state_dict, is_resume=True)
+
+     return nncf_ctrl, nncf_model
+
+
+ def is_state_nncf(state: dict) -> bool:
+     """Check if state is the result of an NNCF-compressed model."""
+     return bool(state.get("meta", {}).get("nncf_enable_compression", False))
+
+
+ def compose_nncf_config(nncf_config: dict, enabled_options: list[str]) -> dict:
+     """Compose NNCF config from the selected options.
+
+     :param nncf_config: Optimisation config containing a "base" part and optional named parts.
+     :param enabled_options: Names of the optional parts to merge into the base config.
+     :return: config
+     """
+     optimisation_parts = nncf_config
+     optimisation_parts_to_choose = []
+     if "order_of_parts" in optimisation_parts:
+         # The result of applying the changes from optimisation parts
+         # may depend on the order of applying the changes
+         # (e.g. if for nncf_quantization it is sufficient to have `total_epochs=2`,
+         # but for sparsity it is required `total_epochs=50`).
+         # So, the user can define `order_of_parts` in the optimisation config
+         # to specify the order of applying the parts.
+         order_of_parts = optimisation_parts["order_of_parts"]
+         if not isinstance(order_of_parts, list):
+             msg = 'The field "order_of_parts" in optimization config should be a list'
+             raise TypeError(msg)
+
+         for part in enabled_options:
+             if part not in order_of_parts:
+                 msg = f"The part {part} is selected, but it is absent in order_of_parts={order_of_parts}"
+                 raise ValueError(msg)
+
+         optimisation_parts_to_choose = [part for part in order_of_parts if part in enabled_options]
+
+     if "base" not in optimisation_parts:
+         msg = 'Error: the optimisation config does not contain the "base" part'
+         raise KeyError(msg)
+     nncf_config_part = optimisation_parts["base"]
+
+     for part in optimisation_parts_to_choose:
+         if part not in optimisation_parts:
+             msg = f'Error: the optimisation config does not contain the part "{part}"'
+             raise KeyError(msg)
+         optimisation_part_dict = optimisation_parts[part]
+         try:
+             nncf_config_part = merge_dicts_and_lists_b_into_a(nncf_config_part, optimisation_part_dict)
+         except AssertionError as cur_error:
+             err_descr = (
+                 f"Error during merging the parts of nncf configs:\n"
+                 f"the current part={part}, "
+                 f"the order of merging parts into base is {optimisation_parts_to_choose}.\n"
+                 f"The error is:\n{cur_error}"
+             )
+             raise RuntimeError(err_descr) from None
+
+     return nncf_config_part
+
+
+ def merge_dicts_and_lists_b_into_a(
+     a: dict[Any, Any] | list[Any],
+     b: dict[Any, Any] | list[Any],
+ ) -> dict[Any, Any] | list[Any]:
+     """Merge dict configs.
+
+     Args:
+         a (dict[Any, Any] | list[Any]): First dict or list.
+         b (dict[Any, Any] | list[Any]): Second dict or list.
+
+     Returns:
+         dict[Any, Any] | list[Any]: Merged dict or list.
+     """
+     return _merge_dicts_and_lists_b_into_a(a, b, "")
+
+
+ def _merge_dicts_and_lists_b_into_a(
+     a: dict[Any, Any] | list[Any],
+     b: dict[Any, Any] | list[Any],
+     cur_key: int | str | None = None,
+ ) -> dict[Any, Any] | list[Any]:
+     """Merge dict configs.
+
+     * works with usual dicts and lists and derived types
+     * supports merging of lists (by concatenating the lists)
+     * makes recursive merging for dict + dict case
+     * overwrites when merging scalar into scalar
+
+     Note that we merge b into a (whereas Config merges a into b),
+     since otherwise the order of list merging is counter-intuitive.
+
+     Args:
+         a (dict[Any, Any] | list[Any]): First dict or list.
+         b (dict[Any, Any] | list[Any]): Second dict or list.
+         cur_key (int | str | None, optional): Key for the current level of recursion. Defaults to None.
+
+     Returns:
+         dict[Any, Any] | list[Any]: Merged dict or list.
+     """
+
+     def _err_str(_a: dict | list, _b: dict | list, _key: int | str | None = None) -> str:
+         _key_str = "of whole structures" if _key is None else f"during merging for key=`{_key}`"
+         return (
+             f"Error in merging parts of config: different types {_key_str},"
+             f" type(a) = {type(_a)},"
+             f" type(b) = {type(_b)}"
+         )
+
+     if not isinstance(a, dict | list):
+         msg = f"Can merge only dicts and lists, whereas type(a)={type(a)}"
+         raise TypeError(msg)
+
+     if not isinstance(b, dict | list):
+         raise TypeError(_err_str(a, b, cur_key))
+
+     if (isinstance(a, list) and not isinstance(b, list)) or (isinstance(b, list) and not isinstance(a, list)):
+         raise TypeError(_err_str(a, b, cur_key))
+
+     if isinstance(a, list) and isinstance(b, list):
+         # the main diff w.r.t. mmcf.Config -- merging of lists
+         return a + b
+
+     a = copy(a)
+     for k in b:
+         if k not in a:
+             a[k] = copy(b[k])
+             continue
+         new_cur_key = str(cur_key) + "." + k if cur_key else k
+         if isinstance(a[k], dict | list):
+             a[k] = _merge_dicts_and_lists_b_into_a(a[k], b[k], new_cur_key)
+             continue
+
+         if any(isinstance(b[k], t) for t in [dict, list]):
+             raise TypeError(_err_str(a[k], b[k], new_cur_key))
+
+         # suppose here that a[k] and b[k] are scalars, just overwrite
+         a[k] = b[k]
+     return a
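
The merge rules (recurse on dicts, concatenate lists, overwrite scalars) can be seen on a small example:

    from anomalib.callbacks.nncf.utils import merge_dicts_and_lists_b_into_a

    base = {"total_epochs": 2, "parts": ["base"], "optimizer": {"lr": 0.1}}
    extra = {"total_epochs": 50, "parts": ["sparsity"], "optimizer": {"momentum": 0.9}}

    merged = merge_dicts_and_lists_b_into_a(base, extra)
    # Scalars from ``extra`` win, lists are concatenated, dicts merge recursively:
    # {"total_epochs": 50, "parts": ["base", "sparsity"],
    #  "optimizer": {"lr": 0.1, "momentum": 0.9}}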
anomalib/callbacks/normalization/__init__.py ADDED
@@ -0,0 +1,12 @@
+ """Normalization callbacks.
+
+ Note: These callbacks are used within the Engine.
+ """
+
+ # Copyright (C) 2023-2024 Intel Corporation
+ # SPDX-License-Identifier: Apache-2.0
+
+ from .min_max_normalization import _MinMaxNormalizationCallback
+ from .utils import get_normalization_callback
+
+ __all__ = ["get_normalization_callback", "_MinMaxNormalizationCallback"]
anomalib/callbacks/normalization/base.py ADDED
@@ -0,0 +1,29 @@
+ """Base Normalization Callback."""
+
+ # Copyright (C) 2024 Intel Corporation
+ # SPDX-License-Identifier: Apache-2.0
+
+ from abc import ABC, abstractmethod
+
+ from lightning.pytorch import Callback
+ from lightning.pytorch.utilities.types import STEP_OUTPUT
+
+ from anomalib.models.components import AnomalyModule
+
+
+ class NormalizationCallback(Callback, ABC):
+     """Base normalization callback."""
+
+     @staticmethod
+     @abstractmethod
+     def _normalize_batch(batch: STEP_OUTPUT, pl_module: AnomalyModule) -> None:
+         """Normalize an output batch in place.
+
+         Args:
+             batch (dict[str, torch.Tensor]): Output batch.
+             pl_module (AnomalyModule): AnomalyModule instance.
+         """
+         raise NotImplementedError
anomalib/callbacks/normalization/min_max_normalization.py ADDED
@@ -0,0 +1,109 @@
+ """Anomaly Score Normalization Callback that uses min-max normalization."""
+
+ # Copyright (C) 2022-2024 Intel Corporation
+ # SPDX-License-Identifier: Apache-2.0
+
+
+ from typing import Any
+
+ import torch
+ from lightning.pytorch import Trainer
+ from lightning.pytorch.utilities.types import STEP_OUTPUT
+
+ from anomalib.metrics import MinMax
+ from anomalib.models.components import AnomalyModule
+ from anomalib.utils.normalization.min_max import normalize
+
+ from .base import NormalizationCallback
+
+
+ class _MinMaxNormalizationCallback(NormalizationCallback):
+     """Callback that normalizes the image-level and pixel-level anomaly scores using min-max normalization.
+
+     Note: This callback is set within the Engine.
+     """
+
+     def setup(self, trainer: Trainer, pl_module: AnomalyModule, stage: str | None = None) -> None:
+         """Add min_max metrics to normalization metrics."""
+         del trainer, stage  # These variables are not used.
+
+         if not hasattr(pl_module, "normalization_metrics"):
+             pl_module.normalization_metrics = MinMax().cpu()
+         elif not isinstance(pl_module.normalization_metrics, MinMax):
+             msg = f"Expected normalization_metrics to be of type MinMax, got {type(pl_module.normalization_metrics)}"
+             raise AttributeError(msg)
+
+     def on_test_start(self, trainer: Trainer, pl_module: AnomalyModule) -> None:
+         """Call when the test begins."""
+         del trainer  # `trainer` variable is not used.
+
+         for metric in (pl_module.image_metrics, pl_module.pixel_metrics, pl_module.semantic_pixel_metrics):
+             if metric is not None:
+                 metric.set_threshold(0.5)
+
+     def on_validation_batch_end(
+         self,
+         trainer: Trainer,
+         pl_module: AnomalyModule,
+         outputs: STEP_OUTPUT,
+         batch: Any,  # noqa: ANN401
+         batch_idx: int,
+         dataloader_idx: int = 0,
+     ) -> None:
+         """Call when the validation batch ends; update the min and max observed values."""
+         del trainer, batch, batch_idx, dataloader_idx  # These variables are not used.
+
+         if "anomaly_maps" in outputs:
+             pl_module.normalization_metrics(outputs["anomaly_maps"])
+         elif "box_scores" in outputs:
+             pl_module.normalization_metrics(torch.cat(outputs["box_scores"]))
+         elif "pred_scores" in outputs:
+             pl_module.normalization_metrics(outputs["pred_scores"])
+         else:
+             msg = "No values found for normalization, provide anomaly maps, bbox scores, or image scores"
+             raise ValueError(msg)
+
+     def on_test_batch_end(
+         self,
+         trainer: Trainer,
+         pl_module: AnomalyModule,
+         outputs: STEP_OUTPUT | None,
+         batch: Any,  # noqa: ANN401
+         batch_idx: int,
+         dataloader_idx: int = 0,
+     ) -> None:
+         """Call when the test batch ends; normalize the predicted scores and anomaly maps."""
+         del trainer, batch, batch_idx, dataloader_idx  # These variables are not used.
+
+         self._normalize_batch(outputs, pl_module)
+
+     def on_predict_batch_end(
+         self,
+         trainer: Trainer,
+         pl_module: AnomalyModule,
+         outputs: Any,  # noqa: ANN401
+         batch: Any,  # noqa: ANN401
+         batch_idx: int,
+         dataloader_idx: int = 0,
+     ) -> None:
+         """Call when the predict batch ends; normalize the predicted scores and anomaly maps."""
+         del trainer, batch, batch_idx, dataloader_idx  # These variables are not used.
+
+         self._normalize_batch(outputs, pl_module)
+
+     @staticmethod
+     def _normalize_batch(outputs: Any, pl_module: AnomalyModule) -> None:  # noqa: ANN401
+         """Normalize a batch of predictions."""
+         image_threshold = pl_module.image_threshold.value.cpu()
+         pixel_threshold = pl_module.pixel_threshold.value.cpu()
+         stats = pl_module.normalization_metrics.cpu()
+         if "pred_scores" in outputs:
+             outputs["pred_scores"] = normalize(outputs["pred_scores"], image_threshold, stats.min, stats.max)
+         if "anomaly_maps" in outputs:
+             outputs["anomaly_maps"] = normalize(outputs["anomaly_maps"], pixel_threshold, stats.min, stats.max)
+         if "box_scores" in outputs:
+             outputs["box_scores"] = [
+                 normalize(scores, pixel_threshold, stats.min, stats.max) for scores in outputs["box_scores"]
+             ]
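
For reference, the min-max scheme applied by ``normalize`` shifts scores so that the adaptive threshold lands at 0.5, scales by the observed value range, and clips to [0, 1], which is why ``on_test_start`` above resets every metric threshold to 0.5. A standalone sketch of that formula (mirroring, not importing, ``anomalib.utils.normalization.min_max.normalize``; treat the exact expression as an assumption):

    import torch

    def min_max_normalize(values: torch.Tensor, threshold: float, v_min: float, v_max: float) -> torch.Tensor:
        # Shift so the threshold maps to 0.5, scale by the observed range, clip to [0, 1].
        return (((values - threshold) / (v_max - v_min)) + 0.5).clamp(min=0, max=1)

    scores = torch.tensor([0.1, 0.4, 0.9])
    print(min_max_normalize(scores, threshold=0.4, v_min=0.1, v_max=0.9))
    # tensor([0.1250, 0.5000, 1.0000])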
anomalib/callbacks/normalization/utils.py ADDED
@@ -0,0 +1,78 @@
+ """Normalization callback utils."""
+
+ # Copyright (C) 2022 Intel Corporation
+ # SPDX-License-Identifier: Apache-2.0
+
+ import importlib
+
+ from lightning.pytorch import Callback
+ from omegaconf import DictConfig
+
+ from anomalib.utils.normalization import NormalizationMethod
+ from anomalib.utils.types import NORMALIZATION
+
+ from .min_max_normalization import _MinMaxNormalizationCallback
+
+
+ def get_normalization_callback(
+     normalization_method: NORMALIZATION = NormalizationMethod.MIN_MAX,
+ ) -> Callback | None:
+     """Return normalization object.
+
+     If ``normalization_method`` is an instance of ``Callback``, it is returned as is.
+
+     If ``normalization_method`` is of type ``NormalizationMethod`` (or the equivalent string), the matching
+     callback is instantiated.
+
+     Otherwise, it expects a dictionary containing class_path and init_args:
+         normalization_method:
+             class_path: MinMaxNormalizer
+             init_args:
+                 -
+                 -
+
+     Example:
+         >>> normalizer = get_normalization_callback(NormalizationMethod.MIN_MAX)
+         or
+         >>> normalizer = get_normalization_callback("min_max")
+         or
+         >>> normalizer = get_normalization_callback({"class_path": "MinMaxNormalizationCallback", "init_args": {}})
+         or
+         >>> normalizer = get_normalization_callback(MinMaxNormalizationCallback())
+     """
+     normalizer: Callback | None
+     if isinstance(normalization_method, NormalizationMethod | str):
+         normalizer = _get_normalizer_from_method(NormalizationMethod(normalization_method))
+     elif isinstance(normalization_method, Callback):
+         normalizer = normalization_method
+     elif isinstance(normalization_method, DictConfig):
+         normalizer = _parse_normalizer_config(normalization_method)
+     else:
+         msg = f"Unknown normalizer type {normalization_method}"
+         raise TypeError(msg)
+     return normalizer
+
+
+ def _get_normalizer_from_method(normalization_method: NormalizationMethod | str) -> Callback | None:
+     if normalization_method == NormalizationMethod.NONE:
+         normalizer = None
+     elif normalization_method == NormalizationMethod.MIN_MAX:
+         normalizer = _MinMaxNormalizationCallback()
+     else:
+         msg = f"Unknown normalization method {normalization_method}"
+         raise ValueError(msg)
+     return normalizer
+
+
+ def _parse_normalizer_config(normalization_method: DictConfig) -> Callback:
+     class_path = normalization_method.class_path
+     init_args = normalization_method.init_args
+
+     if len(class_path.split(".")) == 1:
+         module_path = "anomalib.callbacks.normalization"
+     else:
+         module_path = ".".join(class_path.split(".")[:-1])
+     class_path = class_path.split(".")[-1]
+     module = importlib.import_module(module_path)
+     class_ = getattr(module, class_path)
+     return class_(**init_args)
anomalib/callbacks/post_processor.py ADDED
@@ -0,0 +1,125 @@
+ """Callback that attaches necessary pre/post-processing to the model."""
+
+ # Copyright (C) 2023 Intel Corporation
+ # SPDX-License-Identifier: Apache-2.0
+
+
+ from typing import Any
+
+ import torch
+ from lightning import Callback
+ from lightning.pytorch import Trainer
+ from lightning.pytorch.utilities.types import STEP_OUTPUT
+
+ from anomalib.data.utils import boxes_to_anomaly_maps, boxes_to_masks, masks_to_boxes
+ from anomalib.models import AnomalyModule
+
+
+ class _PostProcessorCallback(Callback):
+     """Applies post-processing to the model outputs.
+
+     Note: This callback is set within the Engine.
+     """
+
+     def __init__(self) -> None:
+         super().__init__()
+
+     def on_validation_batch_end(
+         self,
+         trainer: Trainer,
+         pl_module: AnomalyModule,
+         outputs: STEP_OUTPUT | None,
+         batch: Any,  # noqa: ANN401
+         batch_idx: int,
+         dataloader_idx: int = 0,
+     ) -> None:
+         del batch, batch_idx, dataloader_idx  # Unused arguments.
+
+         if outputs is not None:
+             self.post_process(trainer, pl_module, outputs)
+
+     def on_test_batch_end(
+         self,
+         trainer: Trainer,
+         pl_module: AnomalyModule,
+         outputs: STEP_OUTPUT | None,
+         batch: Any,  # noqa: ANN401
+         batch_idx: int,
+         dataloader_idx: int = 0,
+     ) -> None:
+         del batch, batch_idx, dataloader_idx  # Unused arguments.
+
+         if outputs is not None:
+             self.post_process(trainer, pl_module, outputs)
+
+     def on_predict_batch_end(
+         self,
+         trainer: Trainer,
+         pl_module: AnomalyModule,
+         outputs: Any,  # noqa: ANN401
+         batch: Any,  # noqa: ANN401
+         batch_idx: int,
+         dataloader_idx: int = 0,
+     ) -> None:
+         del batch, batch_idx, dataloader_idx  # Unused arguments.
+
+         if outputs is not None:
+             self.post_process(trainer, pl_module, outputs)
+
+     def post_process(self, trainer: Trainer, pl_module: AnomalyModule, outputs: STEP_OUTPUT) -> None:
+         if isinstance(outputs, dict):
+             self._post_process(outputs)
+             if trainer.predicting or trainer.testing:
+                 self._compute_scores_and_labels(pl_module, outputs)
+
+     @staticmethod
+     def _compute_scores_and_labels(
+         pl_module: AnomalyModule,
+         outputs: dict[str, Any],
+     ) -> None:
+         if "pred_scores" in outputs:
+             outputs["pred_labels"] = outputs["pred_scores"] >= pl_module.image_threshold.value
+         if "anomaly_maps" in outputs:
+             outputs["pred_masks"] = outputs["anomaly_maps"] >= pl_module.pixel_threshold.value
+             if "pred_boxes" not in outputs:
+                 outputs["pred_boxes"], outputs["box_scores"] = masks_to_boxes(
+                     outputs["pred_masks"],
+                     outputs["anomaly_maps"],
+                 )
+                 outputs["box_labels"] = [torch.ones(boxes.shape[0]) for boxes in outputs["pred_boxes"]]
+         # apply thresholding to boxes
+         if "box_scores" in outputs and "box_labels" not in outputs:
+             # apply threshold to assign normal/anomalous label to boxes
+             is_anomalous = [scores > pl_module.pixel_threshold.value for scores in outputs["box_scores"]]
+             outputs["box_labels"] = [labels.int() for labels in is_anomalous]
+
+     @staticmethod
+     def _post_process(outputs: STEP_OUTPUT) -> None:
+         """Compute labels based on model predictions."""
+         if isinstance(outputs, dict):
+             if "pred_scores" not in outputs and "anomaly_maps" in outputs:
+                 # infer image scores from anomaly maps
+                 outputs["pred_scores"] = (
+                     outputs["anomaly_maps"]  # noqa: PD011
+                     .reshape(outputs["anomaly_maps"].shape[0], -1)
+                     .max(dim=1)
+                     .values
+                 )
+             elif "pred_scores" not in outputs and "box_scores" in outputs and "label" in outputs:
+                 # infer image score from bbox confidence scores
+                 outputs["pred_scores"] = torch.zeros_like(outputs["label"]).float()
+                 for idx, (boxes, scores) in enumerate(zip(outputs["pred_boxes"], outputs["box_scores"], strict=True)):
+                     if boxes.numel():
+                         outputs["pred_scores"][idx] = scores.max().item()
+
+             if "pred_boxes" in outputs and "anomaly_maps" not in outputs:
+                 # create anomaly maps from bbox predictions for thresholding and evaluation
+                 image_size: tuple[int, int] = outputs["image"].shape[-2:]
+                 pred_boxes: torch.Tensor = outputs["pred_boxes"]
+                 box_scores: torch.Tensor = outputs["box_scores"]
+
+                 outputs["anomaly_maps"] = boxes_to_anomaly_maps(pred_boxes, box_scores, image_size)
+
+                 if "boxes" in outputs:
+                     true_boxes: list[torch.Tensor] = outputs["boxes"]
+                     outputs["mask"] = boxes_to_masks(true_boxes, image_size)
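
The first branch of ``_post_process`` derives image-level scores from pixel-level maps by taking the per-image maximum. A standalone sketch:

    import torch

    # Per the first branch of _post_process above: flatten each anomaly map
    # and take its maximum as the image-level anomaly score.
    anomaly_maps = torch.rand(8, 1, 256, 256)  # (batch, channels, height, width)
    pred_scores = anomaly_maps.reshape(anomaly_maps.shape[0], -1).max(dim=1).values
    assert pred_scores.shape == (8,)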
anomalib/callbacks/thresholding.py ADDED
@@ -0,0 +1,197 @@
+ """Thresholding callback."""
+
+ # Copyright (C) 2023 Intel Corporation
+ # SPDX-License-Identifier: Apache-2.0
+
+ import importlib
+ from typing import Any
+
+ import torch
+ from lightning.pytorch import Callback, Trainer
+ from lightning.pytorch.utilities.types import STEP_OUTPUT
+ from omegaconf import DictConfig, ListConfig
+
+ from anomalib.metrics.threshold import BaseThreshold
+ from anomalib.models import AnomalyModule
+ from anomalib.utils.types import THRESHOLD
+
+
+ class _ThresholdCallback(Callback):
+     """Setup/apply thresholding.
+
+     Note: This callback is set within the Engine.
+     """
+
+     def __init__(
+         self,
+         threshold: THRESHOLD = "F1AdaptiveThreshold",
+     ) -> None:
+         super().__init__()
+         self._initialize_thresholds(threshold)
+         self.image_threshold: BaseThreshold
+         self.pixel_threshold: BaseThreshold
+
+     def setup(self, trainer: Trainer, pl_module: AnomalyModule, stage: str) -> None:
+         del trainer, stage  # Unused arguments.
+         if not hasattr(pl_module, "image_threshold"):
+             pl_module.image_threshold = self.image_threshold
+         if not hasattr(pl_module, "pixel_threshold"):
+             pl_module.pixel_threshold = self.pixel_threshold
+
+     def on_validation_epoch_start(self, trainer: Trainer, pl_module: AnomalyModule) -> None:
+         del trainer  # Unused argument.
+         self._reset(pl_module)
+
+     def on_validation_batch_end(
+         self,
+         trainer: Trainer,
+         pl_module: AnomalyModule,
+         outputs: STEP_OUTPUT | None,
+         batch: Any,  # noqa: ANN401
+         batch_idx: int,
+         dataloader_idx: int = 0,
+     ) -> None:
+         del trainer, batch, batch_idx, dataloader_idx  # Unused arguments.
+         if outputs is not None:
+             self._outputs_to_cpu(outputs)
+             self._update(pl_module, outputs)
+
+     def on_validation_epoch_end(self, trainer: Trainer, pl_module: AnomalyModule) -> None:
+         del trainer  # Unused argument.
+         self._compute(pl_module)
+
+     def _initialize_thresholds(
+         self,
+         threshold: THRESHOLD,
+     ) -> None:
+         """Initialize ``self.image_threshold`` and ``self.pixel_threshold``.
+
+         Args:
+             threshold (THRESHOLD):
+                 Threshold configuration.
+
+         Example:
+             >>> _initialize_thresholds(F1AdaptiveThreshold())
+             or
+             >>> _initialize_thresholds((ManualThreshold(0.5), ManualThreshold(0.5)))
+             or configuration
+
+         For more details on configuration see :func:`_load_from_config`.
+
+         Raises:
+             TypeError: Unknown threshold class or incorrect configuration.
+         """
+         # TODO(djdameln): Add tests for each case
+         # CVS-122661
+         # When only a single threshold class is passed.
+         # This initializes image and pixel thresholds with the same class
+         # >>> _initialize_thresholds(F1AdaptiveThreshold())
+         if isinstance(threshold, BaseThreshold):
+             self.image_threshold = threshold
+             self.pixel_threshold = threshold.clone()
+
+         # When a tuple of threshold classes is passed
+         # >>> _initialize_thresholds((ManualThreshold(0.5), ManualThreshold(0.5)))
+         elif isinstance(threshold, tuple) and isinstance(threshold[0], BaseThreshold):
+             self.image_threshold = threshold[0]
+             self.pixel_threshold = threshold[1]
+         # When the passed threshold is not an instance of a Threshold class.
+         elif isinstance(threshold, str | DictConfig | ListConfig | list):
+             self._load_from_config(threshold)
+         else:
+             msg = f"Invalid threshold type {type(threshold)}"
+             raise TypeError(msg)
+
+     def _load_from_config(self, threshold: DictConfig | str | ListConfig | list[dict[str, str | float]]) -> None:
+         """Load the thresholding class based on the config.
+
+         Example:
+             threshold: F1AdaptiveThreshold
+             or
+             threshold:
+                 class_path: F1AdaptiveThreshold
+                 init_args:
+                     -
+             or
+             threshold:
+                 - F1AdaptiveThreshold
+                 - F1AdaptiveThreshold
+             or
+             threshold:
+                 - class_path: F1AdaptiveThreshold
+                   init_args:
+                       -
+                 - class_path: F1AdaptiveThreshold
+         """
+         if isinstance(threshold, str | DictConfig):
+             self.image_threshold = self._get_threshold_from_config(threshold)
+             self.pixel_threshold = self.image_threshold.clone()
+         elif isinstance(threshold, ListConfig | list):
+             self.image_threshold = self._get_threshold_from_config(threshold[0])
+             self.pixel_threshold = self._get_threshold_from_config(threshold[1])
+         else:
+             msg = f"Invalid threshold config {threshold}"
+             raise TypeError(msg)
+
+     def _get_threshold_from_config(self, threshold: DictConfig | str | dict[str, str | float]) -> BaseThreshold:
+         """Return the instantiated threshold object.
+
+         Example:
+             >>> _get_threshold_from_config("F1AdaptiveThreshold")
+             or
+             >>> config = DictConfig({
+             ...     "class_path": "ManualThreshold",
+             ...     "init_args": {"default_value": 0.7}
+             ... })
+             >>> _get_threshold_from_config(config)
+             or
+             >>> config = DictConfig({
+             ...     "class_path": "anomalib.metrics.threshold.F1AdaptiveThreshold"
+             ... })
+             >>> _get_threshold_from_config(config)
+
+         Returns:
+             (BaseThreshold): Instance of threshold object.
+         """
+         if isinstance(threshold, str):
+             threshold = DictConfig({"class_path": threshold})
+
+         class_path = threshold["class_path"]
+         init_args = threshold.get("init_args", {})
+
+         if len(class_path.split(".")) == 1:
+             module_path = "anomalib.metrics.threshold"
+         else:
+             module_path = ".".join(class_path.split(".")[:-1])
+             class_path = class_path.split(".")[-1]
+
+         module = importlib.import_module(module_path)
+         class_ = getattr(module, class_path)
+         return class_(**init_args)
+
+     def _reset(self, pl_module: AnomalyModule) -> None:
+         pl_module.image_threshold.reset()
+         pl_module.pixel_threshold.reset()
+
+     def _outputs_to_cpu(self, output: STEP_OUTPUT) -> STEP_OUTPUT | dict[str, Any]:
+         if isinstance(output, dict):
+             for key, value in output.items():
+                 output[key] = self._outputs_to_cpu(value)
+         elif isinstance(output, torch.Tensor):
+             output = output.cpu()
+         return output
+
+     def _update(self, pl_module: AnomalyModule, outputs: STEP_OUTPUT) -> None:
+         pl_module.image_threshold.cpu()
+         pl_module.image_threshold.update(outputs["pred_scores"], outputs["label"].int())
+         if "mask" in outputs and "anomaly_maps" in outputs:
+             pl_module.pixel_threshold.cpu()
+             pl_module.pixel_threshold.update(outputs["anomaly_maps"], outputs["mask"].int())
+
+     def _compute(self, pl_module: AnomalyModule) -> None:
+         pl_module.image_threshold.compute()
+         if pl_module.pixel_threshold._update_called:  # noqa: SLF001
+             pl_module.pixel_threshold.compute()
+         else:
+             pl_module.pixel_threshold.value = pl_module.image_threshold.value
anomalib/callbacks/tiler_configuration.py ADDED
@@ -0,0 +1,74 @@
+ """Tiler Callback."""
+
+ # Copyright (C) 2022 Intel Corporation
+ # SPDX-License-Identifier: Apache-2.0
+
+
+ from collections.abc import Sequence
+
+ import lightning.pytorch as pl
+ from lightning.pytorch.callbacks import Callback
+
+ from anomalib.data.utils.tiler import ImageUpscaleMode, Tiler
+ from anomalib.models.components import AnomalyModule
+
+ __all__ = ["TilerConfigurationCallback"]
+
+
+ class TilerConfigurationCallback(Callback):
+     """Tiler Configuration Callback."""
+
+     def __init__(
+         self,
+         enable: bool = False,
+         tile_size: int | Sequence = 256,
+         stride: int | Sequence | None = None,
+         remove_border_count: int = 0,
+         mode: ImageUpscaleMode = ImageUpscaleMode.PADDING,
+     ) -> None:
+         """Set tiling configuration from the command line.
+
+         Args:
+             enable (bool): Boolean to enable tiling operation.
+                 Defaults to False.
+             tile_size (int | Sequence): Tile size.
+                 Defaults to 256.
+             stride (int | Sequence | None): Stride to move tiles on the image.
+                 Defaults to None.
+             remove_border_count (int, optional): Number of pixels to remove from the image before
+                 tiling. Defaults to 0.
+             mode (str, optional): Up-scaling mode when untiling overlapping tiles.
+                 Defaults to "padding".
+         """
+         self.enable = enable
+         self.tile_size = tile_size
+         self.stride = stride
+         self.remove_border_count = remove_border_count
+         self.mode = mode
+
+     def setup(self, trainer: pl.Trainer, pl_module: pl.LightningModule, stage: str | None = None) -> None:
+         """Set Tiler object within Anomalib Model.
+
+         Args:
+             trainer (pl.Trainer): PyTorch Lightning Trainer.
+             pl_module (pl.LightningModule): Anomalib model that inherits pl LightningModule.
+             stage (str | None, optional): fit, validate, test or predict. Defaults to None.
+
+         Raises:
+             ValueError: When the Anomalib model does not contain a ``Tiler`` object, which means
+                 the model does not support the tiling operation.
+         """
+         del trainer, stage  # These variables are not used.
+
+         if self.enable:
+             if isinstance(pl_module, AnomalyModule) and hasattr(pl_module.model, "tiler"):
+                 pl_module.model.tiler = Tiler(
+                     tile_size=self.tile_size,
+                     stride=self.stride,
+                     remove_border_count=self.remove_border_count,
+                     mode=self.mode,
+                 )
+             else:
+                 msg = "Model does not support tiling."
+                 raise ValueError(msg)
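
A minimal usage sketch for the callback above, following the `Engine(callbacks=...)` pattern used elsewhere in this commit; the tile size and stride values are illustrative:

from anomalib.callbacks.tiler_configuration import TilerConfigurationCallback
from anomalib.engine import Engine

# Tile inputs into 128x128 patches moved with a 64-pixel stride; a model
# whose ``model`` attribute has no ``tiler`` raises ValueError during setup.
engine = Engine(callbacks=[TilerConfigurationCallback(enable=True, tile_size=128, stride=64)])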
anomalib/callbacks/timer.py ADDED
@@ -0,0 +1,109 @@
+ """Callback to measure training and testing time of a PyTorch Lightning module."""
+
+ # Copyright (C) 2022 Intel Corporation
+ # SPDX-License-Identifier: Apache-2.0
+
+ import logging
+ import time
+
+ import torch
+ from lightning.pytorch import Callback, LightningModule, Trainer
+
+ logger = logging.getLogger(__name__)
+
+
+ class TimerCallback(Callback):
+     """Callback that measures the training and testing time of a PyTorch Lightning module.
+
+     Examples:
+         >>> from anomalib.callbacks import TimerCallback
+         >>> from anomalib.engine import Engine
+         ...
+         >>> callbacks = [TimerCallback()]
+         >>> engine = Engine(callbacks=callbacks)
+     """
+
+     def __init__(self) -> None:
+         self.start: float
+         self.num_images: int = 0
+
+     def on_fit_start(self, trainer: Trainer, pl_module: LightningModule) -> None:
+         """Call when fit begins.
+
+         Sets the start time to the time training started.
+
+         Args:
+             trainer (Trainer): PyTorch Lightning trainer.
+             pl_module (LightningModule): Current training module.
+
+         Returns:
+             None
+         """
+         del trainer, pl_module  # These variables are not used.
+
+         self.start = time.time()
+
+     def on_fit_end(self, trainer: Trainer, pl_module: LightningModule) -> None:
+         """Call when fit ends.
+
+         Prints the time taken for training.
+
+         Args:
+             trainer (Trainer): PyTorch Lightning trainer.
+             pl_module (LightningModule): Current training module.
+
+         Returns:
+             None
+         """
+         del trainer, pl_module  # Unused arguments.
+         logger.info("Training took %5.2f seconds", (time.time() - self.start))
+
+     def on_test_start(self, trainer: Trainer, pl_module: LightningModule) -> None:
+         """Call when the test begins.
+
+         Sets the start time to the time testing started.
+         Goes over all the test dataloaders and adds the number of images in each.
+
+         Args:
+             trainer (Trainer): PyTorch Lightning trainer.
+             pl_module (LightningModule): Current training module.
+
+         Returns:
+             None
+         """
+         del pl_module  # Unused argument.
+
+         self.start = time.time()
+         self.num_images = 0
+
+         if trainer.test_dataloaders is not None:  # Check to placate Mypy.
+             if isinstance(trainer.test_dataloaders, torch.utils.data.dataloader.DataLoader):
+                 self.num_images += len(trainer.test_dataloaders.dataset)
+             else:
+                 for dataloader in trainer.test_dataloaders:
+                     self.num_images += len(dataloader.dataset)
+
+     def on_test_end(self, trainer: Trainer, pl_module: LightningModule) -> None:
+         """Call when the test ends.
+
+         Prints the time taken for testing and the throughput in frames per second.
+
+         Args:
+             trainer (Trainer): PyTorch Lightning trainer.
+             pl_module (LightningModule): Current training module.
+
+         Returns:
+             None
+         """
+         del pl_module  # Unused argument.
+
+         testing_time = time.time() - self.start
+         output = f"Testing took {testing_time} seconds\nThroughput "
+         if trainer.test_dataloaders is not None:
+             if isinstance(trainer.test_dataloaders, torch.utils.data.dataloader.DataLoader):
+                 test_data_loader = trainer.test_dataloaders
+             else:
+                 test_data_loader = trainer.test_dataloaders[0]
+             output += f"(batch_size={test_data_loader.batch_size})"
+         output += f" : {self.num_images / testing_time} FPS"
+         logger.info(output)
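
The throughput logged by ``on_test_end`` is simply the image count divided by the elapsed seconds; a worked example with hypothetical numbers:

num_images, testing_time = 400, 8.0  # hypothetical: 400 test images in 8 seconds
fps = num_images / testing_time
print(f"Throughput: {fps} FPS")  # Throughput: 50.0 FPS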
anomalib/callbacks/visualizer.py ADDED
@@ -0,0 +1,182 @@
+ """Visualizer Callback.
+
+ This is assigned by Anomalib Engine internally.
+ """
+
+ # Copyright (C) 2024 Intel Corporation
+ # SPDX-License-Identifier: Apache-2.0
+
+ import logging
+ from pathlib import Path
+ from typing import Any, cast
+
+ from lightning.pytorch import Callback, Trainer
+ from lightning.pytorch.utilities.types import STEP_OUTPUT
+
+ from anomalib.data.utils.image import save_image, show_image
+ from anomalib.loggers import AnomalibWandbLogger
+ from anomalib.loggers.base import ImageLoggerBase
+ from anomalib.models import AnomalyModule
+ from anomalib.utils.visualization import (
+     BaseVisualizer,
+     GeneratorResult,
+     VisualizationStep,
+ )
+
+ logger = logging.getLogger(__name__)
+
+
+ class _VisualizationCallback(Callback):
+     """Callback for visualization that is used internally by the Engine.
+
+     Args:
+         visualizers (BaseVisualizer | list[BaseVisualizer]):
+             Visualizer objects that are used for computing the visualizations.
+         save (bool, optional): Save the image. Defaults to False.
+         root (Path | None, optional): The path to save the images. Defaults to None.
+         log (bool, optional): Log the images into the loggers. Defaults to False.
+         show (bool, optional): Show the images. Defaults to False.
+
+     Example:
+         >>> visualizers = [ImageVisualizer(), MetricsVisualizer()]
+         >>> visualization_callback = _VisualizationCallback(
+         ...     visualizers=visualizers,
+         ...     save=True,
+         ...     root="results/images"
+         ... )
+
+     CLI
+         $ anomalib train --model Padim --data MVTec \
+           --visualization.visualizers ImageVisualizer \
+           --visualization.visualizers+=MetricsVisualizer
+         or
+         $ anomalib train --model Padim --data MVTec \
+           --visualization.visualizers '[ImageVisualizer, MetricsVisualizer]'
+
+     Raises:
+         ValueError: In case `root` is None and `save` is True.
+     """
+
+     def __init__(
+         self,
+         visualizers: BaseVisualizer | list[BaseVisualizer],
+         save: bool = False,
+         root: Path | None = None,
+         log: bool = False,
+         show: bool = False,
+     ) -> None:
+         self.save = save
+         if save and root is None:
+             msg = "`root` must be provided if save is True"
+             raise ValueError(msg)
+         self.root: Path = root if root is not None else Path()  # need this check for mypy
+         self.log = log
+         self.show = show
+         self.generators = visualizers if isinstance(visualizers, list) else [visualizers]
+
+     def on_test_batch_end(
+         self,
+         trainer: Trainer,
+         pl_module: AnomalyModule,
+         outputs: STEP_OUTPUT | None,
+         batch: Any,  # noqa: ANN401
+         batch_idx: int,
+         dataloader_idx: int = 0,
+     ) -> None:
+         for generator in self.generators:
+             if generator.visualize_on == VisualizationStep.BATCH:
+                 for result in generator(
+                     trainer=trainer,
+                     pl_module=pl_module,
+                     outputs=outputs,
+                     batch=batch,
+                     batch_idx=batch_idx,
+                     dataloader_idx=dataloader_idx,
+                 ):
+                     if self.save:
+                         if result.file_name is None:
+                             msg = "``save`` is set to ``True`` but file name is ``None``"
+                             raise ValueError(msg)
+
+                         # Get the filename to save the image.
+                         # Filename is split based on the datamodule name and category.
+                         # For example, if the filename is `MVTec/bottle/000.png`, then the
+                         # filename is split based on `MVTec/bottle` and `000.png` is saved.
+                         if trainer.datamodule is not None:
+                             filename = str(result.file_name).split(
+                                 sep=f"{trainer.datamodule.name}/{trainer.datamodule.category}",
+                             )[-1]
+                         else:
+                             filename = Path(result.file_name).name
+                         save_image(image=result.image, root=self.root, filename=filename)
+                     if self.show:
+                         show_image(image=result.image, title=str(result.file_name))
+                     if self.log:
+                         self._add_to_logger(result, pl_module, trainer)
+
+     def on_test_end(self, trainer: Trainer, pl_module: AnomalyModule) -> None:
+         for generator in self.generators:
+             if generator.visualize_on == VisualizationStep.STAGE_END:
+                 for result in generator(trainer=trainer, pl_module=pl_module):
+                     if self.save:
+                         if result.file_name is None:
+                             msg = "``save`` is set to ``True`` but file name is ``None``"
+                             raise ValueError(msg)
+                         save_image(image=result.image, root=self.root, filename=result.file_name)
+                     if self.show:
+                         show_image(image=result.image, title=str(result.file_name))
+                     if self.log:
+                         self._add_to_logger(result, pl_module, trainer)
+
+         for logger in trainer.loggers:
+             if isinstance(logger, AnomalibWandbLogger):
+                 logger.save()
+
+     def on_predict_batch_end(
+         self,
+         trainer: Trainer,
+         pl_module: AnomalyModule,
+         outputs: STEP_OUTPUT | None,
+         batch: Any,  # noqa: ANN401
+         batch_idx: int,
+         dataloader_idx: int = 0,
+     ) -> None:
+         return self.on_test_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx)
+
+     def on_predict_end(self, trainer: Trainer, pl_module: AnomalyModule) -> None:
+         return self.on_test_end(trainer, pl_module)
+
+     def _add_to_logger(
+         self,
+         result: GeneratorResult,
+         module: AnomalyModule,
+         trainer: Trainer,
+     ) -> None:
+         """Add image to logger.
+
+         Args:
+             result (GeneratorResult): Output from the generators.
+             module (AnomalyModule): LightningModule from which the global step is extracted.
+             trainer (Trainer): Trainer object.
+         """
+         # Store the normalized names of the loggers and the logger objects in a dict.
+         available_loggers = {
+             type(logger).__name__.lower().replace("logger", "").replace("anomalib", ""): logger
+             for logger in trainer.loggers
+         }
+         # Save the image to each image-capable logger.
+         if result.file_name is None:
+             msg = "File name is None"
+             raise ValueError(msg)
+         filename = result.file_name
+         image = result.image
+         for log_to in available_loggers:
+             # Check if the logger object supports image logging.
+             if isinstance(available_loggers[log_to], ImageLoggerBase):
+                 logger: ImageLoggerBase = cast(ImageLoggerBase, available_loggers[log_to])  # placate mypy
+                 _name = filename.parent.name + "_" + filename.name if isinstance(filename, Path) else filename
+                 logger.add_image(
+                     image=image,
+                     name=_name,
+                     global_step=module.global_step,
+                 )
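
The ``available_loggers`` mapping above keys each logger by a normalized class name; a quick sketch of the same string transformation (``AnomalibWandbLogger`` and ``AnomalibTensorBoardLogger`` are real anomalib logger names):

# "AnomalibWandbLogger" -> "wandb", "AnomalibTensorBoardLogger" -> "tensorboard"
for name in ("AnomalibWandbLogger", "AnomalibTensorBoardLogger"):
    print(name.lower().replace("logger", "").replace("anomalib", ""))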
anomalib/cli/__init__.py ADDED
@@ -0,0 +1,8 @@
+ """Anomalib CLI."""
+
+ # Copyright (C) 2022 Intel Corporation
+ # SPDX-License-Identifier: Apache-2.0
+
+ from .cli import AnomalibCLI
+
+ __all__ = ["AnomalibCLI"]
anomalib/cli/cli.py ADDED
@@ -0,0 +1,483 @@
+ """Anomalib CLI."""
+
+ # Copyright (C) 2023-2024 Intel Corporation
+ # SPDX-License-Identifier: Apache-2.0
+
+ import logging
+ from collections.abc import Callable, Sequence
+ from functools import partial
+ from pathlib import Path
+ from types import MethodType
+ from typing import Any
+
+ from jsonargparse import ActionConfigFile, ArgumentParser, Namespace
+ from jsonargparse._actions import _ActionSubCommands
+ from rich import traceback
+
+ from anomalib import TaskType, __version__
+ from anomalib.cli.utils.help_formatter import CustomHelpFormatter, get_short_docstring
+ from anomalib.cli.utils.openvino import add_openvino_export_arguments
+ from anomalib.loggers import configure_logger
+
+ traceback.install()
+ logger = logging.getLogger("anomalib.cli")
+
+ _LIGHTNING_AVAILABLE = True
+ try:
+     from lightning.pytorch import Trainer
+     from torch.utils.data import DataLoader, Dataset
+
+     from anomalib.data import AnomalibDataModule
+     from anomalib.engine import Engine
+     from anomalib.metrics.threshold import BaseThreshold
+     from anomalib.models import AnomalyModule
+     from anomalib.utils.config import update_config
+
+ except ImportError:
+     _LIGHTNING_AVAILABLE = False
+
+
+ class AnomalibCLI:
+     """Implementation of a fully configurable CLI tool for anomalib.
+
+     The advantage of this tool is its flexibility to configure the pipeline
+     from both the CLI and a configuration file (.yaml or .json). It is even
+     possible to use both the CLI and a configuration file simultaneously.
+     For more details, the reader could refer to PyTorch Lightning CLI
+     documentation.
+
+     ``save_config_kwargs`` is set to ``overwrite=True`` so that the
+     ``SaveConfigCallback`` overwrites the config if it already exists.
+     """
+
+     def __init__(self, args: Sequence[str] | None = None) -> None:
+         self.parser = self.init_parser()
+         self.subcommand_parsers: dict[str, ArgumentParser] = {}
+         self.subcommand_method_arguments: dict[str, list[str]] = {}
+         self.add_subcommands()
+         self.config = self.parser.parse_args(args=args)
+         self.subcommand = self.config["subcommand"]
+         if _LIGHTNING_AVAILABLE:
+             self.before_instantiate_classes()
+             self.instantiate_classes()
+         self._run_subcommand()
+
+     def init_parser(self, **kwargs) -> ArgumentParser:
+         """Method that instantiates the argument parser."""
+         kwargs.setdefault("dump_header", [f"anomalib=={__version__}"])
+         parser = ArgumentParser(formatter_class=CustomHelpFormatter, **kwargs)
+         parser.add_argument(
+             "-c",
+             "--config",
+             action=ActionConfigFile,
+             help="Path to a configuration file in json or yaml format.",
+         )
+         return parser
+
+     @staticmethod
+     def subcommands() -> dict[str, set[str]]:
+         """Skip predict subcommand as it is added later."""
+         return {
+             "fit": {"model", "train_dataloaders", "val_dataloaders", "datamodule"},
+             "validate": {"model", "dataloaders", "datamodule"},
+             "test": {"model", "dataloaders", "datamodule"},
+         }
+
+     @staticmethod
+     def anomalib_subcommands() -> dict[str, dict[str, str]]:
+         """Return a dictionary of subcommands and their description."""
+         return {
+             "train": {"description": "Fit the model and then call test on the trained model."},
+             "predict": {"description": "Run inference on a model."},
+             "export": {"description": "Export the model to ONNX or OpenVINO format."},
+         }
+
+     def add_subcommands(self, **kwargs) -> None:
+         """Initialize base subcommands and add anomalib specific on top of it."""
+         parser_subcommands = self.parser.add_subcommands()
+
+         # Extra subcommand: install
+         self._set_install_subcommand(parser_subcommands)
+
+         if not _LIGHTNING_AVAILABLE:
+             # If the environment is not configured to use pl, do not add a subcommand for Engine.
+             return
+
+         # Add Trainer subcommands
+         for subcommand in self.subcommands():
+             sub_parser = self.init_parser(**kwargs)
+
+             fn = getattr(Trainer, subcommand)
+             # extract the first line description in the docstring for the subcommand help message
+             description = get_short_docstring(fn)
+             subparser_kwargs = kwargs.get(subcommand, {})
+             subparser_kwargs.setdefault("description", description)
+
+             self.subcommand_parsers[subcommand] = sub_parser
+             parser_subcommands.add_subcommand(subcommand, sub_parser, help=description)
+             self.add_trainer_arguments(sub_parser, subcommand)
+
+         # Add anomalib subcommands
+         for subcommand in self.anomalib_subcommands():
+             sub_parser = self.init_parser(**kwargs)
+
+             self.subcommand_parsers[subcommand] = sub_parser
+             parser_subcommands.add_subcommand(
+                 subcommand,
+                 sub_parser,
+                 help=self.anomalib_subcommands()[subcommand]["description"],
+             )
+             # add arguments to subcommand
+             getattr(self, f"add_{subcommand}_arguments")(sub_parser)
+
+     def add_arguments_to_parser(self, parser: ArgumentParser) -> None:
+         """Extend trainer's arguments to add engine arguments.
+
+         .. note::
+             Since ``Engine`` parameters are manually added, any change to the
+             ``Engine`` class should be reflected manually.
+         """
+         from anomalib.callbacks.normalization import get_normalization_callback
+
+         parser.add_function_arguments(get_normalization_callback, "normalization")
+         parser.add_argument("--task", type=TaskType | str, default=TaskType.SEGMENTATION)
+         parser.add_argument(
+             "--metrics.image",
+             type=list[str] | str | dict[str, dict[str, Any]] | None,
+             default=["F1Score", "AUROC"],
+         )
+         parser.add_argument(
+             "--metrics.pixel",
+             type=list[str] | str | dict[str, dict[str, Any]] | None,
+             default=None,
+             required=False,
+         )
+         parser.add_argument("--metrics.threshold", type=BaseThreshold | str, default="F1AdaptiveThreshold")
+         parser.add_argument("--logging.log_graph", type=bool, help="Log the model to the logger", default=False)
+         if hasattr(parser, "subcommand") and parser.subcommand not in ("export", "predict"):
+             parser.link_arguments("task", "data.init_args.task")
+         parser.add_argument(
+             "--default_root_dir",
+             type=Path,
+             help="Path to save the results.",
+             default=Path("./results"),
+         )
+         parser.link_arguments("default_root_dir", "trainer.default_root_dir")
+         # TODO(ashwinvaidya17): Tiling should also be a category of its own
+         # CVS-122659
+
+     def add_trainer_arguments(self, parser: ArgumentParser, subcommand: str) -> None:
+         """Add train arguments to the parser."""
+         self._add_default_arguments_to_parser(parser)
+         self._add_trainer_arguments_to_parser(parser, add_optimizer=True, add_scheduler=True)
+         parser.add_subclass_arguments(
+             AnomalyModule,
+             "model",
+             fail_untyped=False,
+             required=True,
+         )
+         parser.add_subclass_arguments(AnomalibDataModule, "data")
+         self.add_arguments_to_parser(parser)
+         skip: set[str | int] = set(self.subcommands()[subcommand])
+         added = parser.add_method_arguments(
+             Trainer,
+             subcommand,
+             skip=skip,
+         )
+         self.subcommand_method_arguments[subcommand] = added
+
+     def add_train_arguments(self, parser: ArgumentParser) -> None:
+         """Add train arguments to the parser."""
+         self._add_default_arguments_to_parser(parser)
+         self._add_trainer_arguments_to_parser(parser, add_optimizer=True, add_scheduler=True)
+         parser.add_subclass_arguments(
+             AnomalyModule,
+             "model",
+             fail_untyped=False,
+             required=True,
+         )
+         parser.add_subclass_arguments(AnomalibDataModule, "data")
+         self.add_arguments_to_parser(parser)
+         added = parser.add_method_arguments(
+             Engine,
+             "train",
+             skip={"model", "datamodule", "val_dataloaders", "test_dataloaders", "train_dataloaders"},
+         )
+         self.subcommand_method_arguments["train"] = added
+
+     def add_predict_arguments(self, parser: ArgumentParser) -> None:
+         """Add predict arguments to the parser."""
+         self._add_default_arguments_to_parser(parser)
+         self._add_trainer_arguments_to_parser(parser)
+         parser.add_subclass_arguments(
+             AnomalyModule,
+             "model",
+             fail_untyped=False,
+             required=True,
+         )
+         parser.add_argument(
+             "--data",
+             type=Dataset | AnomalibDataModule | DataLoader | str | Path,
+             required=True,
+         )
+         added = parser.add_method_arguments(
+             Engine,
+             "predict",
+             skip={"model", "dataloaders", "datamodule", "dataset", "data_path"},
+         )
+         self.subcommand_method_arguments["predict"] = added
+         self.add_arguments_to_parser(parser)
+
+     def add_export_arguments(self, parser: ArgumentParser) -> None:
+         """Add export arguments to the parser."""
+         self._add_default_arguments_to_parser(parser)
+         self._add_trainer_arguments_to_parser(parser)
+         parser.add_subclass_arguments(
+             AnomalyModule,
+             "model",
+             fail_untyped=False,
+             required=True,
+         )
+         added = parser.add_method_arguments(
+             Engine,
+             "export",
+             skip={"ov_args", "model"},
+         )
+         self.subcommand_method_arguments["export"] = added
+         add_openvino_export_arguments(parser)
+         self.add_arguments_to_parser(parser)
+
+     def _set_install_subcommand(self, action_subcommand: _ActionSubCommands) -> None:
+         sub_parser = ArgumentParser(formatter_class=CustomHelpFormatter)
+         sub_parser.add_argument(
+             "--option",
+             help="Install the full or optional-dependencies.",
+             default="full",
+             type=str,
+             choices=["full", "core", "dev", "loggers", "notebooks", "openvino"],
+         )
+         sub_parser.add_argument(
+             "-v",
+             "--verbose",
+             help="Set Logger level to INFO",
+             action="store_true",
+         )
+
+         self.subcommand_parsers["install"] = sub_parser
+         action_subcommand.add_subcommand(
+             "install",
+             sub_parser,
+             help="Install the full-package for anomalib.",
+         )
+
+     def before_instantiate_classes(self) -> None:
+         """Modify the configuration to properly instantiate classes and set up the tiler."""
+         subcommand = self.config["subcommand"]
+         if subcommand in (*self.subcommands(), "train", "predict"):
+             self.config[subcommand] = update_config(self.config[subcommand])
+
+     def instantiate_classes(self) -> None:
+         """Instantiate classes depending on the subcommand.
+
+         For trainer-related commands, this instantiates the model, datamodule, and trainer classes.
+         For the remaining subcommands, no trainer-specific classes such as the datamodule or model
+         are instantiated here, because each subcommand is responsible for instantiating and
+         executing code based on the passed config.
+         """
+         if self.config["subcommand"] in (*self.subcommands(), "predict"):  # trainer commands
+             # since all classes are instantiated, the LightningCLI also creates an unused ``Trainer`` object.
+             # the minor change here is that engine is instantiated instead of trainer
+             self.config_init = self.parser.instantiate_classes(self.config)
+             self.datamodule = self._get(self.config_init, "data")
+             if isinstance(self.datamodule, Dataset):
+                 self.datamodule = DataLoader(self.datamodule)
+             self.model = self._get(self.config_init, "model")
+             self._configure_optimizers_method_to_model()
+             self.instantiate_engine()
+         else:
+             self.config_init = self.parser.instantiate_classes(self.config)
+             subcommand = self.config["subcommand"]
+             if subcommand in ("train", "export"):
+                 self.instantiate_engine()
+             if "model" in self.config_init[subcommand]:
+                 self.model = self._get(self.config_init, "model")
+             else:
+                 self.model = None
+             if "data" in self.config_init[subcommand]:
+                 self.datamodule = self._get(self.config_init, "data")
+             else:
+                 self.datamodule = None
+
+     def instantiate_engine(self) -> None:
+         """Instantiate the engine.
+
+         .. note::
+             Most of the code in this method is taken from ``LightningCLI``'s
+             ``instantiate_trainer`` method. Refer to that method for more
+             details.
+         """
+         from lightning.pytorch.cli import SaveConfigCallback
+
+         from anomalib.callbacks import get_callbacks
+
+         engine_args = {
+             "normalization": self._get(self.config_init, "normalization.normalization_method"),
+             "threshold": self._get(self.config_init, "metrics.threshold"),
+             "task": self._get(self.config_init, "task"),
+             "image_metrics": self._get(self.config_init, "metrics.image"),
+             "pixel_metrics": self._get(self.config_init, "metrics.pixel"),
+         }
+         trainer_config = {**self._get(self.config_init, "trainer", default={}), **engine_args}
+         key = "callbacks"
+         if key in trainer_config:
+             if trainer_config[key] is None:
+                 trainer_config[key] = []
+             elif not isinstance(trainer_config[key], list):
+                 trainer_config[key] = [trainer_config[key]]
+             if not trainer_config.get("fast_dev_run", False):
+                 config_callback = SaveConfigCallback(
+                     self._parser(self.subcommand),
+                     self.config.get(str(self.subcommand), self.config),
+                     overwrite=True,
+                 )
+                 trainer_config[key].append(config_callback)
+             trainer_config[key].extend(get_callbacks(self.config[self.subcommand]))
+         self.engine = Engine(**trainer_config)
+
+     def _run_subcommand(self) -> None:
+         """Run subcommand depending on the subcommand.
+
+         This overrides the original ``_run_subcommand`` to run the ``Engine``
+         method rather than the ``Trainer`` method.
+         """
+         if self.subcommand == "install":
+             from anomalib.cli.install import anomalib_install
+
+             install_kwargs = self.config.get("install", {})
+             anomalib_install(**install_kwargs)
+         elif self.config["subcommand"] in (*self.subcommands(), "train", "export", "predict"):
+             fn = getattr(self.engine, self.subcommand)
+             fn_kwargs = self._prepare_subcommand_kwargs(self.subcommand)
+             fn(**fn_kwargs)
+         else:
+             self.config_init = self.parser.instantiate_classes(self.config)
+             getattr(self, f"{self.subcommand}")()
+
+     @property
+     def fit(self) -> Callable:
+         """Fit the model using engine's fit method."""
+         return self.engine.fit
+
+     @property
+     def validate(self) -> Callable:
+         """Validate the model using engine's validate method."""
+         return self.engine.validate
+
+     @property
+     def test(self) -> Callable:
+         """Test the model using engine's test method."""
+         return self.engine.test
+
+     @property
+     def predict(self) -> Callable:
+         """Predict using engine's predict method."""
+         return self.engine.predict
+
+     @property
+     def train(self) -> Callable:
+         """Train the model using engine's train method."""
+         return self.engine.train
+
+     @property
+     def export(self) -> Callable:
+         """Export the model using engine's export method."""
+         return self.engine.export
+
+     def _add_trainer_arguments_to_parser(
+         self,
+         parser: ArgumentParser,
+         add_optimizer: bool = False,
+         add_scheduler: bool = False,
+     ) -> None:
+         """Add trainer arguments to the parser."""
+         parser.add_class_arguments(Trainer, "trainer", fail_untyped=False, instantiate=False, sub_configs=True)
+
+         if add_optimizer:
+             from torch.optim import Optimizer
+
+             optim_kwargs = {"instantiate": False, "fail_untyped": False, "skip": {"params"}}
+             parser.add_subclass_arguments(
+                 baseclass=(Optimizer,),
+                 nested_key="optimizer",
+                 **optim_kwargs,
+             )
+         if add_scheduler:
+             from lightning.pytorch.cli import LRSchedulerTypeTuple
+
+             scheduler_kwargs = {"instantiate": False, "fail_untyped": False, "skip": {"optimizer"}}
+             parser.add_subclass_arguments(
+                 baseclass=LRSchedulerTypeTuple,
+                 nested_key="lr_scheduler",
+                 **scheduler_kwargs,
+             )
+
+     def _add_default_arguments_to_parser(self, parser: ArgumentParser) -> None:
+         """Add default arguments to the parser."""
+         parser.add_argument(
+             "--seed_everything",
+             type=bool | int,
+             default=True,
+             help=(
+                 "Set to an int to run seed_everything with this value before classes instantiation. "
+                 "Set to True to use a random seed."
+             ),
+         )
+
+     def _get(self, config: Namespace, key: str, default: Any = None) -> Any:  # noqa: ANN401
+         """Utility to get a config value which might be inside a subcommand."""
+         return config.get(str(self.subcommand), config).get(key, default)
+
+     def _prepare_subcommand_kwargs(self, subcommand: str) -> dict[str, Any]:
+         """Prepare the keyword arguments to pass to the subcommand to run."""
+         fn_kwargs = {
+             k: v for k, v in self.config_init[subcommand].items() if k in self.subcommand_method_arguments[subcommand]
+         }
+         fn_kwargs["model"] = self.model
+         if self.datamodule is not None:
+             if isinstance(self.datamodule, AnomalibDataModule):
+                 fn_kwargs["datamodule"] = self.datamodule
+             elif isinstance(self.datamodule, DataLoader):
+                 fn_kwargs["dataloaders"] = self.datamodule
+             elif isinstance(self.datamodule, Path | str):
+                 fn_kwargs["data_path"] = self.datamodule
+         return fn_kwargs
+
+     def _parser(self, subcommand: str | None) -> ArgumentParser:
+         if subcommand is None:
+             return self.parser
+         # return the subcommand parser for the subcommand passed
+         return self.subcommand_parsers[subcommand]
+
+     def _configure_optimizers_method_to_model(self) -> None:
+         from lightning.pytorch.cli import LightningCLI, instantiate_class
+
+         optimizer_cfg = self._get(self.config_init, "optimizer", None)
+         if optimizer_cfg is None:
+             return
+         lr_scheduler_cfg = self._get(self.config_init, "lr_scheduler", {})
+
+         optimizer = instantiate_class(self.model.parameters(), optimizer_cfg)
+         lr_scheduler = instantiate_class(optimizer, lr_scheduler_cfg) if lr_scheduler_cfg else None
+         fn = partial(LightningCLI.configure_optimizers, optimizer=optimizer, lr_scheduler=lr_scheduler)
+
+         # override the existing method
+         self.model.configure_optimizers = MethodType(fn, self.model)
+
+
+ def main() -> None:
+     """Trainer via Anomalib CLI."""
+     configure_logger()
+     AnomalibCLI()
+
+
+ if __name__ == "__main__":
+     main()
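
Because ``AnomalibCLI.__init__`` accepts an explicit ``args`` sequence, the CLI can be driven from Python as well as from the shell. A minimal sketch, mirroring the ``anomalib train --model Padim --data MVTec`` usage shown in the visualizer docstring above:

from anomalib.cli import AnomalibCLI

# Equivalent to: $ anomalib train --model Padim --data MVTec
AnomalibCLI(args=["train", "--model", "Padim", "--data", "MVTec"])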
anomalib/cli/install.py ADDED
@@ -0,0 +1,81 @@
+ """Anomalib install subcommand code."""
+
+ # Copyright (C) 2024 Intel Corporation
+ # SPDX-License-Identifier: Apache-2.0
+
+ import logging
+
+ from pkg_resources import Requirement
+ from rich.console import Console
+ from rich.logging import RichHandler
+
+ from anomalib.cli.utils.installation import (
+     get_requirements,
+     get_torch_install_args,
+     parse_requirements,
+ )
+
+ logger = logging.getLogger("pip")
+ logger.setLevel(logging.WARNING)  # setLevel: CRITICAL, ERROR, WARNING, INFO, DEBUG, NOTSET
+ console = Console()
+ handler = RichHandler(
+     console=console,
+     show_level=False,
+     show_path=False,
+ )
+ logger.addHandler(handler)
+
+
+ def anomalib_install(option: str = "full", verbose: bool = False) -> int:
+     """Install Anomalib requirements.
+
+     Args:
+         option (str): Optional-dependency to install requirements for. Defaults to "full".
+         verbose (bool): Set pip logger level to INFO.
+
+     Raises:
+         ValueError: When the task is not supported.
+
+     Returns:
+         int: Status code of the pip install command.
+     """
+     from pip._internal.commands import create_command
+
+     requirements_dict = get_requirements("anomalib")
+
+     requirements = []
+     if option == "full":
+         for extra in requirements_dict:
+             requirements.extend(requirements_dict[extra])
+     elif option in requirements_dict:
+         requirements.extend(requirements_dict[option])
+     elif option is not None:
+         requirements.append(Requirement.parse(option))
+
+     # Parse requirements into torch and other requirements.
+     # This is done to parse the correct version of torch (cpu/cuda).
+     torch_requirement, other_requirements = parse_requirements(requirements, skip_torch=option not in ("full", "core"))
+
+     # Get install args for torch to install it from a specific index-url
+     install_args: list[str] = []
+     torch_install_args = []
+     if option in ("full", "core") and torch_requirement is not None:
+         torch_install_args = get_torch_install_args(torch_requirement)
+
+     # Combine torch and other requirements.
+     install_args = other_requirements + torch_install_args
+
+     # Install requirements.
+     with console.status("[bold green]Installing packages...  This may take a few minutes.\n") as status:
+         if verbose:
+             logger.setLevel(logging.INFO)
+             status.stop()
+         console.log(f"Installation list: [yellow]{install_args}[/yellow]")
+         status_code = create_command("install").main(install_args)
+         if status_code == 0:
+             console.log(f"Installation Complete: {install_args}")
+
+     if status_code == 0:
+         console.print("Anomalib Installation [bold green]Complete.[/bold green]")
+
+     return status_code
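
A toy illustration of the option-selection logic in ``anomalib_install``, with a hypothetical requirements mapping (the real mapping comes from package metadata via ``get_requirements("anomalib")``):

requirements_dict = {"core": ["torch==2.1.1"], "loggers": ["wandb>=0.12.17"]}

def select(option: str) -> list[str]:
    if option == "full":  # every extra
        return [req for extra in requirements_dict for req in requirements_dict[extra]]
    # a known extra, or fall back to treating the option as a literal requirement
    return requirements_dict.get(option, [option])

print(select("loggers"))  # ['wandb>=0.12.17']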
anomalib/cli/utils/__init__.py ADDED
@@ -0,0 +1,8 @@
+ """Anomalib CLI Utils."""
+
+ # Copyright (C) 2023 Intel Corporation
+ # SPDX-License-Identifier: Apache-2.0
+
+ from .help_formatter import CustomHelpFormatter
+
+ __all__ = ["CustomHelpFormatter"]
anomalib/cli/utils/help_formatter.py ADDED
@@ -0,0 +1,268 @@
+ """Custom Help Formatters for Anomalib CLI."""
+
+ # Copyright (C) 2023 Intel Corporation
+ # SPDX-License-Identifier: Apache-2.0
+
+ import argparse
+ import re
+ import sys
+ from typing import TypeVar
+
+ import docstring_parser
+ from jsonargparse import DefaultHelpFormatter
+ from rich.markdown import Markdown
+ from rich.panel import Panel
+ from rich_argparse import RichHelpFormatter
+
+ REQUIRED_ARGUMENTS = {
+     "train": {"model", "model.help", "data", "data.help", "ckpt_path", "config"},
+     "fit": {"model", "model.help", "data", "data.help", "ckpt_path", "config"},
+     "validate": {"model", "model.help", "data", "data.help", "ckpt_path", "config"},
+     "test": {"model", "model.help", "data", "data.help", "ckpt_path", "config"},
+     "predict": {"model", "model.help", "data", "data.help", "ckpt_path", "config"},
+     "export": {"model", "model.help", "export_type", "ckpt_path", "config"},
+ }
+
+ try:
+     from anomalib.engine import Engine
+
+     DOCSTRING_USAGE = {
+         "train": Engine.train,
+         "fit": Engine.fit,
+         "validate": Engine.validate,
+         "test": Engine.test,
+         "predict": Engine.predict,
+         "export": Engine.export,
+     }
+ except ImportError:
+     DOCSTRING_USAGE = {}  # Keep render_guide usable when the engine is unavailable.
+     print("To use the other subcommands, install the full package via `anomalib install`.")
+
+
+ def get_short_docstring(component: TypeVar) -> str:
+     """Get the short description from the docstring.
+
+     Args:
+         component (TypeVar): The component to get the docstring from.
+
+     Returns:
+         str: The short description.
+     """
+     if component.__doc__ is None:
+         return ""
+     docstring = docstring_parser.parse(component.__doc__)
+     return docstring.short_description
+
+
+ def get_verbosity_subcommand() -> dict:
+     """Parse command line arguments and return a dictionary of key-value pairs.
+
+     Returns:
+         A dictionary containing the parsed command line arguments.
+
+     Examples:
+         >>> import sys
+         >>> sys.argv = ['anomalib', 'train', '-h', '-v']
+         >>> get_verbosity_subcommand()
+         {'subcommand': 'train', 'help': True, 'verbosity': 1}
+     """
+     arguments: dict = {"subcommand": None, "help": False, "verbosity": 2}
+     if len(sys.argv) >= 2 and sys.argv[1] not in ("--help", "-h"):
+         arguments["subcommand"] = sys.argv[1]
+     if "--help" in sys.argv or "-h" in sys.argv:
+         arguments["help"] = True
+         if arguments["subcommand"] in REQUIRED_ARGUMENTS:
+             arguments["verbosity"] = 0
+     if "-v" in sys.argv or "--verbose" in sys.argv:
+         arguments["verbosity"] = 1
+     if "-vv" in sys.argv:
+         arguments["verbosity"] = 2
+     return arguments
+
+
+ def get_intro() -> Markdown:
+     """Return a Markdown object containing the introduction text for the Anomalib CLI Guide.
+
+     The introduction text includes a brief description of the guide and links to the Github repository and documentation.
+
+     Returns:
+         A Markdown object containing the introduction text for the Anomalib CLI Guide.
+     """
+     intro_markdown = (
+         "# Anomalib CLI Guide\n\n"
+         "Github Repository: [https://github.com/openvinotoolkit/anomalib](https://github.com/openvinotoolkit/anomalib)."
+         "\n\n"
+         "A better guide is provided by the [documentation](https://anomalib.readthedocs.io/en/latest/index.html)."
+     )
+     return Markdown(intro_markdown)
+
+
+ def get_verbose_usage(subcommand: str = "train") -> str:
+     """Return a string containing verbose usage information for the specified subcommand.
+
+     Args:
+     ----
+         subcommand (str): The name of the subcommand to get verbose usage information for. Defaults to "train".
+
+     Returns:
+     -------
+         str: A string containing verbose usage information for the specified subcommand.
+     """
+     return (
+         "To get more overridable argument information, run the command below.\n"
+         "```python\n"
+         "# Verbosity Level 1\n"
+         f"anomalib {subcommand} [optional_arguments] -h -v\n"
+         "# Verbosity Level 2\n"
+         f"anomalib {subcommand} [optional_arguments] -h -vv\n"
+         "```"
+     )
+
+
+ def get_cli_usage_docstring(component: object | None) -> str | None:
+     r"""Get the cli usage from the docstring.
+
+     Args:
+     ----
+         component (Optional[object]): The component to get the docstring from.
+
+     Returns:
+     -------
+         Optional[str]: The quick-start guide as Markdown format.
+
+     Example:
+     -------
+         component.__doc__ = '''
+             <Prev Section>
+
+             CLI Usage:
+                 1. First Step.
+                 2. Second Step.
+
+             <Next Section>
+         '''
+         >>> get_cli_usage_docstring(component)
+         "1. First Step.\n2. Second Step."
+     """
+     if component is None or component.__doc__ is None or "CLI Usage" not in component.__doc__:
+         return None
+
+     pattern = r"CLI Usage:(.*?)(?=\n{2,}|\Z)"
+     match = re.search(pattern, component.__doc__, re.DOTALL)
+
+     if match:
+         contents = match.group(1).strip().split("\n")
+         return "\n".join([content.strip() for content in contents])
+     return None
+
+
+ def render_guide(subcommand: str | None = None) -> list:
+     """Render a guide for the specified subcommand.
+
+     Args:
+     ----
+         subcommand (Optional[str]): The subcommand to render the guide for.
+
+     Returns:
+     -------
+         list: A list of contents to be displayed in the guide.
+     """
+     if subcommand is None or subcommand not in DOCSTRING_USAGE:
+         return []
+     contents = [get_intro()]
+     target_command = DOCSTRING_USAGE[subcommand]
+     cli_usage = get_cli_usage_docstring(target_command)
+     if cli_usage is not None:
+         cli_usage += f"\n{get_verbose_usage(subcommand)}"
+         quick_start = Panel(Markdown(cli_usage), border_style="dim", title="Quick-Start", title_align="left")
+         contents.append(quick_start)
+     return contents
+
+
+ class CustomHelpFormatter(RichHelpFormatter, DefaultHelpFormatter):
+     """A custom help formatter for Anomalib CLI.
+
+     This formatter extends the RichHelpFormatter and DefaultHelpFormatter classes to provide
+     a more detailed and customizable help output for Anomalib CLI.
+
+     Attributes:
+         verbosity_level : int
+             The level of verbosity for the help output.
+         subcommand : str | None
+             The subcommand to render the guide for.
+
+     Methods:
+         add_usage(usage, actions, *args, **kwargs)
+             Add usage information to the help output.
+         add_argument(action)
+             Add an argument to the help output.
+         format_help()
+             Format the help output.
+     """
+
+     verbosity_dict = get_verbosity_subcommand()
+     verbosity_level = verbosity_dict["verbosity"]
+     subcommand = verbosity_dict["subcommand"]
+
+     def add_usage(self, usage: str | None, actions: list, *args, **kwargs) -> None:
+         """Add usage information to the formatter.
+
+         Args:
+         ----
+             usage (str | None): A string describing the usage of the program.
+             actions (list): A list of argparse.Action objects.
+             *args (Any): Additional positional arguments to pass to the superclass method.
+             **kwargs (Any): Additional keyword arguments to pass to the superclass method.
+
+         Returns:
+         -------
+             None
+         """
+         if self.subcommand in REQUIRED_ARGUMENTS:
+             if self.verbosity_level == 0:
+                 actions = []
+             elif self.verbosity_level == 1:
+                 actions = [action for action in actions if action.dest in REQUIRED_ARGUMENTS[self.subcommand]]
+
+         super().add_usage(usage, actions, *args, **kwargs)
+
+     def add_argument(self, action: argparse.Action) -> None:
+         """Add an argument to the help formatter.
+
+         If the verbose level is set to 0, the argument is not added.
+         If the verbose level is set to 1 and the argument is not in the non-skip list, the argument is not added.
+
+         Args:
+         ----
+             action (argparse.Action): The action to add to the help formatter.
+         """
+         if self.subcommand in REQUIRED_ARGUMENTS:
+             if self.verbosity_level == 0:
+                 return
+             if self.verbosity_level == 1 and action.dest not in REQUIRED_ARGUMENTS[self.subcommand]:
+                 return
+         super().add_argument(action)
+
+     def format_help(self) -> str:
+         """Format the help message for the current command and return it as a string.
+
+         The help message includes information about the command's arguments and options,
+         as well as any additional information provided by the command's help guide.
+
+         Returns:
+             str: A string containing the formatted help message.
+         """
+         with self.console.capture() as capture:
+             section = self._root_section
+             if self.subcommand in REQUIRED_ARGUMENTS and self.verbosity_level in (0, 1) and len(section.rich_items) > 1:
+                 contents = render_guide(self.subcommand)
+                 for content in contents:
+                     self.console.print(content)
+             if self.verbosity_level > 0:
+                 if len(section.rich_items) > 1:
+                     section = Panel(section, border_style="dim", title="Arguments", title_align="left")
+                 self.console.print(section, highlight=False, soft_wrap=True)
+         help_msg = capture.get()
+
+         if help_msg:
+             help_msg = self._long_break_matcher.sub("\n\n", help_msg).rstrip() + "\n"
+         return help_msg
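
A quick demonstration of the "CLI Usage" extraction implemented by ``get_cli_usage_docstring`` above, using a stand-in component whose docstring follows the documented layout:

class _Demo:
    """Some component.

    CLI Usage:
        1. First Step.
        2. Second Step.
    """

# get_cli_usage_docstring(_Demo) returns "1. First Step.\n2. Second Step.",
# i.e. the indented steps after "CLI Usage:" with their leading whitespace stripped.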
anomalib/cli/utils/installation.py ADDED
@@ -0,0 +1,430 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Anomalib installation util functions."""
2
+
3
+ # Copyright (C) 2024 Intel Corporation
4
+ # SPDX-License-Identifier: Apache-2.0
5
+
6
+ from __future__ import annotations
7
+
8
+ import json
9
+ import os
10
+ import platform
11
+ import re
12
+ from importlib.metadata import requires
13
+ from pathlib import Path
14
+ from warnings import warn
15
+
16
+ from pkg_resources import Requirement
17
+
18
+ AVAILABLE_TORCH_VERSIONS = {
19
+ "2.0.0": {"torchvision": "0.15.1", "cuda": ("11.7", "11.8")},
20
+ "2.0.1": {"torchvision": "0.15.2", "cuda": ("11.7", "11.8")},
21
+ "2.1.1": {"torchvision": "0.16.1", "cuda": ("11.8", "12.1")},
22
+ "2.1.2": {"torchvision": "0.16.2", "cuda": ("11.8", "12.1")},
23
+ "2.2.0": {"torchvision": "0.16.2", "cuda": ("11.8", "12.1")},
24
+ }
25
+
26
+
27
+ def get_requirements(module: str = "anomalib") -> dict[str, list[Requirement]]:
28
+ """Get requirements of module from importlib.metadata.
29
+
30
+ This function returns list of required packages from importlib_metadata.
31
+
32
+ Example:
33
+ >>> get_requirements("anomalib")
34
+ {
35
+ "base": ["jsonargparse==4.27.1", ...],
36
+ "core": ["torch==2.1.1", ...],
37
+ ...
38
+ }
39
+
40
+ Returns:
41
+ dict[str, list[Requirement]]: List of required packages for each optional-extras.
42
+ """
43
+ requirement_list: list[str] | None = requires(module)
44
+ extra_requirement: dict[str, list[Requirement]] = {}
45
+ if requirement_list is None:
46
+ return extra_requirement
47
+ for requirement in requirement_list:
48
+ extra = "core"
49
+ requirement_extra: list[str] = requirement.replace(" ", "").split(";")
50
+ if isinstance(requirement_extra, list) and len(requirement_extra) > 1:
51
+ extra = requirement_extra[-1].split("==")[-1].strip("'\"")
52
+ _requirement_name = requirement_extra[0]
53
+ _requirement = Requirement.parse(_requirement_name)
54
+ if extra in extra_requirement:
55
+ extra_requirement[extra].append(_requirement)
56
+ else:
57
+ extra_requirement[extra] = [_requirement]
58
+ return extra_requirement
59
+
60
+
61
+ def parse_requirements(
62
+ requirements: list[Requirement],
63
+ skip_torch: bool = False,
64
+ ) -> tuple[str | None, list[str]]:
65
+ """Parse requirements and returns torch and other requirements.
66
+
67
+ Args:
68
+ requirements (list[Requirement]): List of requirements.
69
+ skip_torch (bool): Whether to skip torch requirement. Defaults to False.
70
+
71
+ Raises:
72
+ ValueError: If torch requirement is not found.
73
+
74
+ Examples:
75
+ >>> requirements = [
76
+ ... Requirement.parse("torch==1.13.0"),
77
+ ... Requirement.parse("onnx>=1.8.1"),
78
+ ... ]
79
+ >>> parse_requirements(requirements=requirements)
80
+ (Requirement.parse("torch==1.13.0"),
81
+ Requirement.parse("onnx>=1.8.1"))
82
+
83
+ Returns:
84
+ tuple[str, list[str], list[str]]: Tuple of torch and other requirements.
85
+ """
86
+ torch_requirement: str | None = None
87
+ other_requirements: list[str] = []
88
+
89
+ for requirement in requirements:
90
+ if requirement.unsafe_name == "torch":
91
+ torch_requirement = str(requirement)
92
+ if len(requirement.specs) > 1:
93
+ warn(
94
+ "requirements.txt contains. Please remove other versions of torch from requirements.",
95
+ stacklevel=2,
96
+ )
97
+
98
+ # Rest of the requirements are task requirements.
99
+ # Other torch-related requirements such as `torchvision` are to be excluded.
100
+ # This is because torch-related requirements are already handled in torch_requirement.
101
+ else:
102
+ # if not requirement.unsafe_name.startswith("torch"):
103
+ other_requirements.append(str(requirement))
104
+
105
+ if not skip_torch and not torch_requirement:
106
+ msg = "Could not find torch requirement. Anoamlib depends on torch. Please add torch to your requirements."
107
+ raise ValueError(msg)
108
+
109
+ # Get the unique list of the requirements.
110
+ other_requirements = list(set(other_requirements))
111
+
112
+ return torch_requirement, other_requirements
113
+
114
+
115
+ def get_cuda_version() -> str | None:
116
+ """Get CUDA version installed on the system.
117
+
118
+ Examples:
119
+ >>> # Assume that CUDA version is 11.2
120
+ >>> get_cuda_version()
121
+ "11.2"
122
+
123
+ >>> # Assume that CUDA is not installed on the system
124
+ >>> get_cuda_version()
125
+ None
126
+
127
+ Returns:
128
+ str | None: CUDA version installed on the system.
129
+ """
130
+ # 1. Check CUDA_HOME Environment variable
131
+ cuda_home = os.environ.get("CUDA_HOME", "/usr/local/cuda")
132
+
133
+ if Path(cuda_home).exists():
134
+ # Check $CUDA_HOME/version.json file.
135
+ version_file = Path(cuda_home) / "version.json"
136
+ if version_file.is_file():
137
+ with Path(version_file).open() as file:
138
+ data = json.load(file)
139
+ cuda_version = data.get("cuda", {}).get("version", None)
140
+ if cuda_version is not None:
141
+ cuda_version_parts = cuda_version.split(".")
142
+ return ".".join(cuda_version_parts[:2])
143
+ # 2. 'nvcc --version' check & without version.json case
144
+ try:
145
+ result = os.popen(cmd="nvcc --version")
146
+ output = result.read()
147
+
148
+ cuda_version_pattern = r"cuda_(\d+\.\d+)"
149
+ cuda_version_match = re.search(cuda_version_pattern, output)
150
+
151
+ if cuda_version_match is not None:
152
+ return cuda_version_match.group(1)
153
+ except OSError:
154
+ msg = "Could not find cuda-version. Instead, the CPU version of torch will be installed."
155
+ warn(msg, stacklevel=2)
156
+ return None
157
+
158
+
159
+ def update_cuda_version_with_available_torch_cuda_build(cuda_version: str, torch_version: str) -> str:
160
+ """Update the installed CUDA version with the highest supported CUDA version by PyTorch.
161
+
162
+ Args:
163
+ cuda_version (str): The installed CUDA version.
164
+ torch_version (str): The PyTorch version.
165
+
166
+ Raises:
167
+ Warning: If the installed CUDA version is not supported by PyTorch.
168
+
169
+ Examples:
170
+ >>> update_cuda_version_with_available_torch_cuda_builds("11.1", "1.13.0")
171
+ "11.6"
172
+
173
+ >>> update_cuda_version_with_available_torch_cuda_builds("11.7", "1.13.0")
174
+ "11.7"
175
+
176
+ >>> update_cuda_version_with_available_torch_cuda_builds("11.8", "1.13.0")
177
+ "11.7"
178
+
179
+ >>> update_cuda_version_with_available_torch_cuda_builds("12.1", "2.0.1")
180
+ "11.8"
181
+
182
+ Returns:
183
+ str: The updated CUDA version.
184
+ """
185
+ max_supported_cuda = max(AVAILABLE_TORCH_VERSIONS[torch_version]["cuda"])
186
+ min_supported_cuda = min(AVAILABLE_TORCH_VERSIONS[torch_version]["cuda"])
187
+ bounded_cuda_version = max(min(cuda_version, max_supported_cuda), min_supported_cuda)
188
+
189
+ if cuda_version != bounded_cuda_version:
190
+ warn(
191
+ f"Installed CUDA version is v{cuda_version}. \n"
192
+ f"v{min_supported_cuda} <= Supported CUDA version <= v{max_supported_cuda}.\n"
193
+ f"This script will use CUDA v{bounded_cuda_version}.\n"
194
+ f"However, this may not be safe, and you are advised to install the correct version of CUDA.\n"
195
+ f"For more details, refer to https://pytorch.org/get-started/locally/",
196
+ stacklevel=2,
197
+ )
198
+ cuda_version = bounded_cuda_version
199
+
200
+ return cuda_version
201
+
202
+
203
+ def get_cuda_suffix(cuda_version: str) -> str:
204
+ """Get CUDA suffix for PyTorch versions.
205
+
206
+ Args:
207
+ cuda_version (str): CUDA version installed on the system.
208
+
209
+ Note:
210
+ The CUDA version of PyTorch is not always the same as the CUDA version
211
+ that is installed on the system. For example, the latest PyTorch
212
+ version (1.10.0) supports CUDA 11.3, but the latest CUDA version
213
+ that is available for download is 11.2. Therefore, we need to use
214
+ the latest available CUDA version for PyTorch instead of the CUDA
215
+ version that is installed on the system. Therefore, this function
216
+ shoudl be regularly updated to reflect the latest available CUDA.
217
+
218
+ Examples:
219
+ >>> get_cuda_suffix(cuda_version="11.2")
220
+ "cu112"
221
+
222
+ >>> get_cuda_suffix(cuda_version="11.8")
223
+ "cu118"
224
+
225
+ Returns:
226
+ str: CUDA suffix for PyTorch or mmX version.
227
+ """
228
+ return f"cu{cuda_version.replace('.', '')}"
229
+
230
+
231
+ def get_hardware_suffix(with_available_torch_build: bool = False, torch_version: str | None = None) -> str:
232
+ """Get hardware suffix for PyTorch or mmX versions.
233
+
234
+ Args:
235
+ with_available_torch_build (bool): Whether to use the latest available
236
+ PyTorch build or not. If True, the latest available PyTorch build
237
+ will be used. If False, the installed PyTorch build will be used.
238
+ Defaults to False.
239
+ torch_version (str | None): PyTorch version. This is only used when the
240
+ ``with_available_torch_build`` is True.
241
+
242
+ Examples:
243
+ >>> # Assume that CUDA version is 11.2
244
+ >>> get_hardware_suffix()
245
+ "cu112"
246
+
247
+ >>> # Assume that CUDA is not installed on the system
248
+ >>> get_hardware_suffix()
249
+ "cpu"
250
+
251
+ Assume that that installed CUDA version is 12.1.
252
+ However, the latest available CUDA version for PyTorch v2.0 is 11.8.
253
+ Therefore, we use 11.8 instead of 12.1. This is because PyTorch does not
254
+ support CUDA 12.1 yet. In this case, we could correct the CUDA version
255
+ by setting `with_available_torch_build` to True.
256
+
257
+ >>> cuda_version = get_cuda_version()
258
+ "12.1"
259
+ >>> get_hardware_suffix(with_available_torch_build=True, torch_version="2.0.1")
260
+ "cu118"
261
+
262
+ Returns:
263
+ str: Hardware suffix for PyTorch or mmX version.
264
+ """
265
+ cuda_version = get_cuda_version()
266
+ if cuda_version:
267
+ if with_available_torch_build:
268
+ if torch_version is None:
269
+ msg = "``torch_version`` must be provided when with_available_torch_build is True."
270
+ raise ValueError(msg)
271
+ cuda_version = update_cuda_version_with_available_torch_cuda_build(cuda_version, torch_version)
272
+ hardware_suffix = get_cuda_suffix(cuda_version)
273
+ else:
274
+ hardware_suffix = "cpu"
275
+
276
+ return hardware_suffix
277
+
278
+
279
+ def add_hardware_suffix_to_torch(
280
+ requirement: Requirement,
281
+ hardware_suffix: str | None = None,
282
+ with_available_torch_build: bool = False,
283
+ ) -> str:
284
+ """Add hardware suffix to the torch requirement.
285
+
286
+ Args:
287
+ requirement (Requirement): Requirement object comprising requirement
288
+ details.
289
+ hardware_suffix (str | None): Hardware suffix. If None, it will be set
290
+ to the correct hardware suffix. Defaults to None.
291
+ with_available_torch_build (bool): To check whether the installed
292
+ CUDA version is supported by the latest available PyTorch build.
293
+ Defaults to False.
294
+
295
+ Examples:
296
+ >>> from pkg_resources import Requirement
297
+ >>> req = "torch>=1.13.0, <=2.0.1"
298
+ >>> requirement = Requirement.parse(req)
299
+ >>> requirement.name, requirement.specs
300
+ ('torch', [('>=', '1.13.0'), ('<=', '2.0.1')])
301
+
302
+ >>> add_hardware_suffix_to_torch(requirement)
303
+ 'torch>=1.13.0+cu121, <=2.0.1+cu121'
304
+
305
+ ``with_available_torch_build=True`` will use the latest available PyTorch build.
306
+ >>> req = "torch==2.0.1"
307
+ >>> requirement = Requirement.parse(req)
308
+ >>> add_hardware_suffix_to_torch(requirement, with_available_torch_build=True)
309
+ 'torch==2.0.1+cu118'
310
+
311
+ It is possible to pass the ``hardware_suffix`` manually.
312
+ >>> req = "torch==2.0.1"
313
+ >>> requirement = Requirement.parse(req)
314
+ >>> add_hardware_suffix_to_torch(requirement, hardware_suffix="cu121")
315
+ 'torch==2.0.1+cu121'
316
+
317
+ Raises:
318
+ ValueError: When the requirement has more than two version criteria.
319
+
320
+ Returns:
321
+ str: Updated torch package with the right CUDA suffix.
322
+ """
323
+ name = requirement.unsafe_name
324
+ updated_specs: list[str] = []
325
+
326
+ for operator, version in requirement.specs:
327
+ hardware_suffix = hardware_suffix or get_hardware_suffix(with_available_torch_build, version)
328
+ updated_version = version + f"+{hardware_suffix}" if not version.startswith(("2.1", "2.2")) else version
329
+
330
+ # ``specs`` contains (operator, version) pairs, e.g. [('>=', '1.13.0'), ('<=', '2.0.1')].
331
+ # These are concatenated again to form the updated requirement string.
332
+ updated_specs.append(operator + updated_version)
333
+
334
+ updated_requirement: str = ""
335
+
336
+ if updated_specs:
337
+ # This is the case when specs are e.g. ['<=1.9.1+cu111']
338
+ if len(updated_specs) == 1:
339
+ updated_requirement = name + updated_specs[0]
340
+ # This is the case when specs are e.g., ['<=1.9.1+cu111', '>=1.8.1+cu111']
341
+ elif len(updated_specs) == 2:
342
+ updated_requirement = name + updated_specs[0] + ", " + updated_specs[1]
343
+ else:
344
+ msg = (
345
+ "Requirement version can be a single value or a range. \n"
346
+ "For example it could be torch>=1.8.1 "
347
+ "or torch>=1.8.1, <=1.9.1\n"
348
+ f"Got {updated_specs} instead."
349
+ )
350
+ raise ValueError(msg)
351
+ return updated_requirement
352
+
353
+
354
+ def get_torch_install_args(requirement: str | Requirement) -> list[str]:
355
+ """Get the install arguments for Torch requirement.
356
+
357
+ This function will return the install arguments for the Torch requirement
358
+ and its corresponding torchvision requirement.
359
+
360
+ Args:
361
+ requirement (str | Requirement): The torch requirement.
362
+
363
+ Raises:
364
+ RuntimeError: If the OS is not supported.
365
+
366
+ Example:
367
+ >>> from pkg_resources import Requirement
368
+ >>> requriment = "torch>=1.13.0"
369
+ >>> get_torch_install_args(requirement)
370
+ ['--extra-index-url', 'https://download.pytorch.org/whl/cpu',
371
+ 'torch==1.13.0+cpu', 'torchvision==0.14.0+cpu']
372
+
373
+ Returns:
374
+ list[str]: The install arguments.
375
+ """
376
+ if isinstance(requirement, str):
377
+ requirement = Requirement.parse(requirement)
378
+
379
+ # NOTE: This does not take into account if the requirement has multiple versions
380
+ # such as torch<2.0.1,>=1.13.0
381
+ if len(requirement.specs) < 1:
382
+ return [str(requirement)]
383
+ select_spec_idx = 0
384
+ for i, spec in enumerate(requirement.specs):
385
+ if "=" in spec[0]:
386
+ select_spec_idx = i
387
+ break
388
+ operator, version = requirement.specs[select_spec_idx]
389
+ if version not in AVAILABLE_TORCH_VERSIONS:
390
+ version = max(AVAILABLE_TORCH_VERSIONS.keys())
391
+ warn(
392
+ f"Torch Version will be selected as {version}.",
393
+ stacklevel=2,
394
+ )
395
+ install_args: list[str] = []
396
+
397
+ if platform.system() in ("Linux", "Windows"):
398
+ # Get the hardware suffix (eg., +cpu, +cu116 and +cu118 etc.)
399
+ hardware_suffix = get_hardware_suffix(with_available_torch_build=True, torch_version=version)
400
+
401
+ # Create the PyTorch Index URL to download the correct wheel.
402
+ index_url = f"https://download.pytorch.org/whl/{hardware_suffix}"
403
+
404
+ # Create the PyTorch version depending on the CUDA version. For example,
405
+ # If CUDA version is 11.2, then the PyTorch version is 1.8.0+cu112.
406
+ # If CUDA version is None, then the PyTorch version is 1.8.0+cpu.
407
+ torch_version = add_hardware_suffix_to_torch(requirement, hardware_suffix, with_available_torch_build=True)
408
+
409
+ # Get the torchvision version depending on the torch version.
410
+ torchvision_version = AVAILABLE_TORCH_VERSIONS[version]["torchvision"]
411
+ torchvision_requirement = f"torchvision{operator}{torchvision_version}"
412
+ if isinstance(torchvision_version, str) and not torchvision_version.startswith("0.16"):
413
+ torchvision_requirement += f"+{hardware_suffix}"
414
+
415
+ # Return the install arguments.
416
+ install_args += [
417
+ "--extra-index-url",
418
+ # "--index-url",
419
+ index_url,
420
+ torch_version,
421
+ torchvision_requirement,
422
+ ]
423
+ elif platform.system() in ("macos", "Darwin"):
424
+ torch_version = str(requirement)
425
+ install_args += [torch_version]
426
+ else:
427
+ msg = f"Unsupported OS: {platform.system()}"
428
+ raise RuntimeError(msg)
429
+
430
+ return install_args
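Taken together, the helpers above turn a bare torch requirement into pip install arguments. A minimal sketch of how they compose (illustrative only: the printed suffixes depend on the local CUDA setup and the `AVAILABLE_TORCH_VERSIONS` table, and the example outputs are assumptions rather than guaranteed values):

    from pkg_resources import Requirement

    # Resolve the hardware suffix for a pinned torch version.
    req = Requirement.parse("torch==2.0.1")
    suffix = get_hardware_suffix(with_available_torch_build=True, torch_version="2.0.1")

    # Attach the suffix to the requirement string.
    print(add_hardware_suffix_to_torch(req, hardware_suffix=suffix))
    # e.g. 'torch==2.0.1+cu118' on a CUDA host, 'torch==2.0.1+cpu' otherwise

    # Or let get_torch_install_args do all of the above, including torchvision.
    print(get_torch_install_args("torch==2.0.1"))
    # e.g. ['--extra-index-url', 'https://download.pytorch.org/whl/cu118',
    #       'torch==2.0.1+cu118', 'torchvision==0.15.2+cu118']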
anomalib/cli/utils/openvino.py ADDED
@@ -0,0 +1,32 @@
1
+ """Utils for OpenVINO parser."""
2
+
3
+ # Copyright (C) 2023 Intel Corporation
4
+ # SPDX-License-Identifier: Apache-2.0
5
+
6
+ import logging
7
+
8
+ from jsonargparse import ArgumentParser
9
+
10
+ from anomalib.utils.exceptions import try_import
11
+
12
+ logger = logging.getLogger(__name__)
13
+
14
+
15
+ if try_import("openvino"):
16
+ from openvino.tools.ovc.cli_parser import get_common_cli_parser
17
+ else:
18
+ get_common_cli_parser = None
19
+
20
+
21
+ def add_openvino_export_arguments(parser: ArgumentParser) -> None:
22
+ """Add OpenVINO arguments to parser under --mo key."""
23
+ if get_common_cli_parser is not None:
24
+ group = parser.add_argument_group("OpenVINO Model Optimizer arguments (optional)")
25
+ ov_parser = get_common_cli_parser()
26
+ # remove redundant keys from mo keys
27
+ for arg in ov_parser._actions: # noqa: SLF001
28
+ if arg.dest in ("help", "input_model", "output_dir"):
29
+ continue
30
+ group.add_argument(f"--ov_args.{arg.dest}", type=arg.type, default=arg.default, help=arg.help)
31
+ else:
32
+ logger.info("OpenVINO is possibly not installed in the environment. Skipping adding it to parser.")
anomalib/data/__init__.py ADDED
@@ -0,0 +1,72 @@
1
+ """Anomalib Datasets."""
2
+
3
+ # Copyright (C) 2022-2024 Intel Corporation
4
+ # SPDX-License-Identifier: Apache-2.0
5
+
6
+
7
+ import importlib
8
+ import logging
9
+ from enum import Enum
10
+ from itertools import chain
11
+
12
+ from omegaconf import DictConfig, ListConfig
13
+
14
+ from anomalib.utils.config import to_tuple
15
+
16
+ from .base import AnomalibDataModule, AnomalibDataset
17
+ from .depth import DepthDataFormat, Folder3D, MVTec3D
18
+ from .image import BTech, Folder, ImageDataFormat, Kolektor, MVTec, MVTecLoco, Visa
19
+ from .predict import PredictDataset
20
+ from .utils import LabelName
21
+ from .video import Avenue, ShanghaiTech, UCSDped, VideoDataFormat
22
+
23
+ logger = logging.getLogger(__name__)
24
+
25
+
26
+ DataFormat = Enum( # type: ignore[misc]
27
+ "DataFormat",
28
+ {i.name: i.value for i in chain(DepthDataFormat, ImageDataFormat, VideoDataFormat)},
29
+ )
30
+
31
+
32
+ def get_datamodule(config: DictConfig | ListConfig) -> AnomalibDataModule:
33
+ """Get Anomaly Datamodule.
34
+
35
+ Args:
36
+ config (DictConfig | ListConfig): Configuration of the anomaly model.
37
+
38
+ Returns:
39
+ PyTorch Lightning DataModule
40
+ """
41
+ logger.info("Loading the datamodule")
42
+
43
+ module = importlib.import_module(".".join(config.data.class_path.split(".")[:-1]))
44
+ dataclass = getattr(module, config.data.class_path.split(".")[-1])
45
+ init_args = {**config.data.get("init_args", {})} # get dict
46
+ if "image_size" in init_args:
47
+ init_args["image_size"] = to_tuple(init_args["image_size"])
48
+
49
+ return dataclass(**init_args)
50
+
51
+
52
+ __all__ = [
53
+ "AnomalibDataset",
54
+ "AnomalibDataModule",
55
+ "DepthDataFormat",
56
+ "ImageDataFormat",
57
+ "VideoDataFormat",
58
+ "get_datamodule",
59
+ "BTech",
60
+ "Folder",
61
+ "Folder3D",
62
+ "PredictDataset",
63
+ "Kolektor",
64
+ "MVTec",
65
+ "MVTec3D",
66
+ "MVTecLoco",
67
+ "Avenue",
68
+ "UCSDped",
69
+ "ShanghaiTech",
70
+ "Visa",
71
+ "LabelName",
72
+ ]
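For reference, `get_datamodule` only needs a config with `data.class_path` and optional `data.init_args`. A hypothetical example (the dataset name and init args are illustrative):

    from omegaconf import OmegaConf

    config = OmegaConf.create(
        {
            "data": {
                "class_path": "anomalib.data.MVTec",
                "init_args": {"category": "bottle", "image_size": [256, 256]},
            },
        },
    )
    datamodule = get_datamodule(config)  # image_size is converted to a tuple via to_tuple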
anomalib/data/base/__init__.py ADDED
@@ -0,0 +1,18 @@
1
+ """Base classes for custom dataset and datamodules."""
2
+
3
+ # Copyright (C) 2022 Intel Corporation
4
+ # SPDX-License-Identifier: Apache-2.0
5
+
6
+
7
+ from .datamodule import AnomalibDataModule
8
+ from .dataset import AnomalibDataset
9
+ from .depth import AnomalibDepthDataset
10
+ from .video import AnomalibVideoDataModule, AnomalibVideoDataset
11
+
12
+ __all__ = [
13
+ "AnomalibDataset",
14
+ "AnomalibDataModule",
15
+ "AnomalibVideoDataset",
16
+ "AnomalibVideoDataModule",
17
+ "AnomalibDepthDataset",
18
+ ]
anomalib/data/base/datamodule.py ADDED
@@ -0,0 +1,305 @@
1
+ """Anomalib datamodule base class."""
2
+
3
+ # Copyright (C) 2022-2024 Intel Corporation
4
+ # SPDX-License-Identifier: Apache-2.0
5
+
6
+
7
+ import logging
8
+ from abc import ABC, abstractmethod
9
+ from typing import TYPE_CHECKING, Any
10
+
11
+ from lightning.pytorch import LightningDataModule
12
+ from lightning.pytorch.trainer.states import TrainerFn
13
+ from lightning.pytorch.utilities.types import EVAL_DATALOADERS, TRAIN_DATALOADERS
14
+ from torch.utils.data.dataloader import DataLoader, default_collate
15
+ from torchvision.transforms.v2 import Resize, Transform
16
+
17
+ from anomalib.data.utils import TestSplitMode, ValSplitMode, random_split, split_by_label
18
+ from anomalib.data.utils.synthetic import SyntheticAnomalyDataset
19
+
20
+ if TYPE_CHECKING:
21
+ from pandas import DataFrame
22
+
23
+ from anomalib.data.base.dataset import AnomalibDataset
24
+
25
+ logger = logging.getLogger(__name__)
26
+
27
+
28
+ def collate_fn(batch: list) -> dict[str, Any]:
29
+ """Collate bounding boxes as lists.
30
+
31
+ Bounding boxes and semantic masks are collated as lists of tensors. If the `mask_path` entries are
32
+ lists, they are also collated as a list, since each element in the batch could be unequal in size.
33
+ For all other entries, the default collate function is used.
34
+
35
+ Args:
36
+ batch (List): list of items in the batch where len(batch) is equal to the batch size.
37
+
38
+ Returns:
39
+ dict[str, Any]: Dictionary containing the collated batch information.
40
+ """
41
+ elem = batch[0] # sample an element from the batch to check the type.
42
+ out_dict = {}
43
+ if isinstance(elem, dict):
44
+ if "boxes" in elem:
45
+ # collate boxes as list
46
+ out_dict["boxes"] = [item.pop("boxes") for item in batch]
47
+ if "semantic_mask" in elem:
48
+ # semantic masks have a variable number of channels, so we collate them as a list
49
+ out_dict["semantic_mask"] = [item.pop("semantic_mask") for item in batch]
50
+ if "mask_path" in elem and isinstance(elem["mask_path"], list):
51
+ # collate mask paths as list
52
+ out_dict["mask_path"] = [item.pop("mask_path") for item in batch]
53
+ # collate other data normally
54
+ out_dict.update({key: default_collate([item[key] for item in batch]) for key in elem})
55
+ return out_dict
56
+ return default_collate(batch)
57
+
58
+
59
+ class AnomalibDataModule(LightningDataModule, ABC):
60
+ """Base Anomalib data module.
61
+
62
+ Args:
63
+ train_batch_size (int): Batch size used by the train dataloader.
64
+ eval_batch_size (int): Batch size used by the val and test dataloaders.
65
+ num_workers (int): Number of workers used by the train, val and test dataloaders.
66
+ val_split_mode (ValSplitMode): Determines how the validation split is obtained.
67
+ Options: [none, same_as_test, from_test, from_train, from_dir, synthetic]
68
+ val_split_ratio (float): Fraction of the train or test images held out for validation.
69
+ test_split_mode (Optional[TestSplitMode], optional): Determines how the test split is obtained.
70
+ Options: [none, from_dir, synthetic].
71
+ Defaults to ``None``.
72
+ test_split_ratio (float): Fraction of the train images held out for testing.
73
+ Defaults to ``None``.
74
+ image_size (tuple[int, int], optional): Size to which input images should be resized.
75
+ Defaults to ``None``.
76
+ transform (Transform, optional): Transforms that should be applied to the input images.
77
+ Defaults to ``None``.
78
+ train_transform (Transform, optional): Transforms that should be applied to the input images during training.
79
+ Defaults to ``None``.
80
+ eval_transform (Transform, optional): Transforms that should be applied to the input images during evaluation.
81
+ Defaults to ``None``.
82
+ seed (int | None, optional): Seed used during random subset splitting.
83
+ Defaults to ``None``.
84
+ """
85
+
86
+ def __init__(
87
+ self,
88
+ train_batch_size: int,
89
+ eval_batch_size: int,
90
+ num_workers: int,
91
+ val_split_mode: ValSplitMode | str,
92
+ val_split_ratio: float,
93
+ test_split_mode: TestSplitMode | str | None = None,
94
+ test_split_ratio: float | None = None,
95
+ image_size: tuple[int, int] | None = None,
96
+ transform: Transform | None = None,
97
+ train_transform: Transform | None = None,
98
+ eval_transform: Transform | None = None,
99
+ seed: int | None = None,
100
+ ) -> None:
101
+ super().__init__()
102
+ self.train_batch_size = train_batch_size
103
+ self.eval_batch_size = eval_batch_size
104
+ self.num_workers = num_workers
105
+ self.test_split_mode = TestSplitMode(test_split_mode) if test_split_mode else TestSplitMode.NONE
106
+ self.test_split_ratio = test_split_ratio
107
+ self.val_split_mode = ValSplitMode(val_split_mode)
108
+ self.val_split_ratio = val_split_ratio
109
+ self.image_size = image_size
110
+ self.seed = seed
111
+
112
+ # set transforms
113
+ if bool(train_transform) != bool(eval_transform):
114
+ msg = "Only one of train_transform and eval_transform was specified. This is not recommended because \
115
+ it could lead to unexpected behaviour. Please ensure training and eval transforms have the same \
116
+ reshape and normalization characteristics."
117
+ logger.warning(msg)
118
+ self._train_transform = train_transform or transform
119
+ self._eval_transform = eval_transform or transform
120
+
121
+ self.train_data: AnomalibDataset
122
+ self.val_data: AnomalibDataset
123
+ self.test_data: AnomalibDataset
124
+
125
+ self._samples: DataFrame | None = None
126
+ self._category: str = ""
127
+
128
+ self._is_setup = False # flag to track if setup has been called from the trainer
129
+
130
+ @property
131
+ def name(self) -> str:
132
+ """Name of the datamodule."""
133
+ return self.__class__.__name__
134
+
135
+ def setup(self, stage: str | None = None) -> None:
136
+ """Set up train, validation and test data.
137
+
138
+ Args:
139
+ stage: str | None: Train/Val/Test stages.
140
+ Defaults to ``None``.
141
+ """
142
+ has_subset = any(hasattr(self, subset) for subset in ["train_data", "val_data", "test_data"])
143
+ if not has_subset or not self._is_setup:
144
+ self._setup(stage)
145
+ self._create_test_split()
146
+ self._create_val_split()
147
+ if isinstance(stage, TrainerFn):
148
+ # only set the flag if the stage is a TrainerFn, which means the setup has been called from a trainer
149
+ self._is_setup = True
150
+
151
+ @abstractmethod
152
+ def _setup(self, _stage: str | None = None) -> None:
153
+ """Set up the datasets and perform dynamic subset splitting.
154
+
155
+ This method may be overridden in subclass for custom splitting behaviour.
156
+
157
+ Note:
158
+ The stage argument is not used here. This is because, for a given instance of an AnomalibDataModule
159
+ subclass, all three subsets are created at the first call of setup(). This is to accommodate the subset
160
+ splitting behaviour of anomaly tasks, where the validation set is usually extracted from the test set, and
161
+ the test set must therefore be created as early as the `fit` stage.
162
+
163
+ """
164
+ raise NotImplementedError
165
+
166
+ @property
167
+ def category(self) -> str:
168
+ """Get the category of the datamodule."""
169
+ return self._category
170
+
171
+ @category.setter
172
+ def category(self, category: str) -> None:
173
+ """Set the category of the datamodule."""
174
+ self._category = category
175
+
176
+ def _create_test_split(self) -> None:
177
+ """Obtain the test set based on the settings in the config."""
178
+ if self.test_data.has_normal:
179
+ # split the test data into normal and anomalous so these can be processed separately
180
+ normal_test_data, self.test_data = split_by_label(self.test_data)
181
+ elif self.test_split_mode != TestSplitMode.NONE:
182
+ # when the user did not provide any normal images for testing, we sample some from the training set,
183
+ # except when the user explicitly requested no test splitting.
184
+ logger.info(
185
+ "No normal test images found. Sampling from training set using a split ratio of %0.2f",
186
+ self.test_split_ratio,
187
+ )
188
+ if self.test_split_ratio is not None:
189
+ self.train_data, normal_test_data = random_split(self.train_data, self.test_split_ratio, seed=self.seed)
190
+
191
+ if self.test_split_mode == TestSplitMode.FROM_DIR:
192
+ self.test_data += normal_test_data
193
+ elif self.test_split_mode == TestSplitMode.SYNTHETIC:
194
+ self.test_data = SyntheticAnomalyDataset.from_dataset(normal_test_data)
195
+ elif self.test_split_mode != TestSplitMode.NONE:
196
+ msg = f"Unsupported Test Split Mode: {self.test_split_mode}"
197
+ raise ValueError(msg)
198
+
199
+ def _create_val_split(self) -> None:
200
+ """Obtain the validation set based on the settings in the config."""
201
+ if self.val_split_mode == ValSplitMode.FROM_TRAIN:
202
+ # randomly sampled from train set
203
+ self.train_data, self.val_data = random_split(
204
+ self.train_data,
205
+ self.val_split_ratio,
206
+ label_aware=True,
207
+ seed=self.seed,
208
+ )
209
+ elif self.val_split_mode == ValSplitMode.FROM_TEST:
210
+ # randomly sampled from test set
211
+ self.test_data, self.val_data = random_split(
212
+ self.test_data,
213
+ self.val_split_ratio,
214
+ label_aware=True,
215
+ seed=self.seed,
216
+ )
217
+ elif self.val_split_mode == ValSplitMode.SAME_AS_TEST:
218
+ # equal to test set
219
+ self.val_data = self.test_data
220
+ elif self.val_split_mode == ValSplitMode.SYNTHETIC:
221
+ # converted from random training sample
222
+ self.train_data, normal_val_data = random_split(self.train_data, self.val_split_ratio, seed=self.seed)
223
+ self.val_data = SyntheticAnomalyDataset.from_dataset(normal_val_data)
224
+ elif self.val_split_mode == ValSplitMode.FROM_DIR:
225
+ # the val_data is prepared in subclass
226
+ assert hasattr(
227
+ self,
228
+ "val_data",
229
+ ), f"FROM_DIR is not supported for {self.__class__.__name__} which does not assign val_data in _setup."
230
+ elif self.val_split_mode != ValSplitMode.NONE:
231
+ msg = f"Unknown validation split mode: {self.val_split_mode}"
232
+ raise ValueError(msg)
233
+
234
+ def train_dataloader(self) -> TRAIN_DATALOADERS:
235
+ """Get train dataloader."""
236
+ return DataLoader(
237
+ dataset=self.train_data,
238
+ shuffle=True,
239
+ batch_size=self.train_batch_size,
240
+ num_workers=self.num_workers,
241
+ )
242
+
243
+ def val_dataloader(self) -> EVAL_DATALOADERS:
244
+ """Get validation dataloader."""
245
+ return DataLoader(
246
+ dataset=self.val_data,
247
+ shuffle=False,
248
+ batch_size=self.eval_batch_size,
249
+ num_workers=self.num_workers,
250
+ collate_fn=collate_fn,
251
+ )
252
+
253
+ def test_dataloader(self) -> EVAL_DATALOADERS:
254
+ """Get test dataloader."""
255
+ return DataLoader(
256
+ dataset=self.test_data,
257
+ shuffle=False,
258
+ batch_size=self.eval_batch_size,
259
+ num_workers=self.num_workers,
260
+ collate_fn=collate_fn,
261
+ )
262
+
263
+ def predict_dataloader(self) -> EVAL_DATALOADERS:
264
+ """Use the test dataloader for inference unless overridden."""
265
+ return self.test_dataloader()
266
+
267
+ @property
268
+ def transform(self) -> Transform:
269
+ """Property that returns the user-specified transform for the datamodule, if any.
270
+
271
+ This property is accessed by the engine to set the transform for the model. The eval_transform takes precedence
272
+ over the train_transform, because the transform that we store in the model is the one that should be used during
273
+ inference.
274
+ """
275
+ if self._eval_transform:
276
+ return self._eval_transform
277
+ return None
278
+
279
+ @property
280
+ def train_transform(self) -> Transform:
281
+ """Get the transforms that will be passed to the train dataset.
282
+
283
+ If the train_transform is not set, the engine will request the transform from the model.
284
+ """
285
+ if self._train_transform:
286
+ return self._train_transform
287
+ if getattr(self, "trainer", None) and self.trainer.model and self.trainer.model.transform:
288
+ return self.trainer.model.transform
289
+ if self.image_size:
290
+ return Resize(self.image_size, antialias=True)
291
+ return None
292
+
293
+ @property
294
+ def eval_transform(self) -> Transform:
295
+ """Get the transform that will be passed to the val/test/predict datasets.
296
+
297
+ If the eval_transform is not set, the engine will request the transform from the model.
298
+ """
299
+ if self._eval_transform:
300
+ return self._eval_transform
301
+ if getattr(self, "trainer", None) and self.trainer.model and self.trainer.model.transform:
302
+ return self.trainer.model.transform
303
+ if self.image_size:
304
+ return Resize(self.image_size, antialias=True)
305
+ return None
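To make the `collate_fn` contract above concrete, here is a small sketch with made-up shapes: variable-size entries become lists, everything else goes through `default_collate`:

    import torch

    batch = [
        {"image": torch.rand(3, 8, 8), "label": torch.tensor(0), "boxes": torch.empty(0, 4)},
        {"image": torch.rand(3, 8, 8), "label": torch.tensor(1), "boxes": torch.tensor([[1.0, 1.0, 4.0, 4.0]])},
    ]
    collated = collate_fn(batch)
    # collated["image"].shape == (2, 3, 8, 8); collated["boxes"] is a list of two tensors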
anomalib/data/base/dataset.py ADDED
@@ -0,0 +1,208 @@
1
+ """Anomalib dataset base class."""
2
+
3
+ # Copyright (C) 2022-2024 Intel Corporation
4
+ # SPDX-License-Identifier: Apache-2.0
5
+
6
+ import copy
7
+ import logging
8
+ from abc import ABC
9
+ from collections.abc import Sequence
10
+ from pathlib import Path
11
+
12
+ import pandas as pd
13
+ import torch
14
+ from pandas import DataFrame
15
+ from torch.utils.data import Dataset
16
+ from torchvision.transforms.v2 import Transform
17
+ from torchvision.tv_tensors import Mask
18
+
19
+ from anomalib import TaskType
20
+ from anomalib.data.utils import LabelName, masks_to_boxes, read_image, read_mask
21
+
22
+ _EXPECTED_COLUMNS_CLASSIFICATION = ["image_path", "split"]
23
+ _EXPECTED_COLUMNS_SEGMENTATION = [*_EXPECTED_COLUMNS_CLASSIFICATION, "mask_path"]
24
+ _EXPECTED_COLUMNS_PERTASK = {
25
+ "classification": _EXPECTED_COLUMNS_CLASSIFICATION,
26
+ "segmentation": _EXPECTED_COLUMNS_SEGMENTATION,
27
+ "detection": _EXPECTED_COLUMNS_SEGMENTATION,
28
+ }
29
+
30
+ logger = logging.getLogger(__name__)
31
+
32
+
33
+ class AnomalibDataset(Dataset, ABC):
34
+ """Anomalib dataset.
35
+
36
+ The dataset is based on a dataframe that contains the information needed by the dataloader to load each of
37
+ the dataset items into memory.
38
+
39
+ The samples dataframe must be set from the subclass using the setter of the `samples` property.
40
+
41
+ The DataFrame must, at least, include the following columns:
42
+ - `split` (str): The subset to which the dataset item is assigned (e.g., 'train', 'test').
43
+ - `image_path` (str): Path to the file system location where the image is stored.
44
+ - `label_index` (int): Index of the anomaly label, typically 0 for 'normal' and 1 for 'anomalous'.
45
+ - `mask_path` (str, optional): Path to the ground truth masks (for the anomalous images only).
46
+ Required if task is 'segmentation'.
47
+
48
+ Example DataFrame:
49
+ +---+-------------------+-----------+-------------+------------------+-------+
50
+ | | image_path | label | label_index | mask_path | split |
51
+ +---+-------------------+-----------+-------------+------------------+-------+
52
+ | 0 | path/to/image.png | anomalous | 1 | path/to/mask.png | train |
53
+ +---+-------------------+-----------+-------------+------------------+-------+
54
+
55
+ Note:
56
+ The example above is illustrative and may need to be adjusted based on the specific dataset structure.
57
+
58
+ Args:
59
+ task (str): Task type, either 'classification' or 'segmentation'
60
+ transform (Transform, optional): Transforms that should be applied to the input images.
61
+ Defaults to ``None``.
62
+ """
63
+
64
+ def __init__(self, task: TaskType | str, transform: Transform | None = None) -> None:
65
+ super().__init__()
66
+ self.task = TaskType(task)
67
+ self.transform = transform
68
+ self._samples: DataFrame | None = None
69
+ self._category: str | None = None
70
+
71
+ @property
72
+ def name(self) -> str:
73
+ """Name of the dataset."""
74
+ class_name = self.__class__.__name__
75
+
76
+ # Remove the `Dataset` suffix from the class name
77
+ if class_name.endswith("Dataset"):
78
+ class_name = class_name[:-7]
79
+
80
+ return class_name
81
+
82
+ def __len__(self) -> int:
83
+ """Get length of the dataset."""
84
+ return len(self.samples)
85
+
86
+ def subsample(self, indices: Sequence[int], inplace: bool = False) -> "AnomalibDataset":
87
+ """Subsamples the dataset at the provided indices.
88
+
89
+ Args:
90
+ indices (Sequence[int]): Indices at which the dataset is to be subsampled.
91
+ inplace (bool): When true, the subsampling will be performed on the instance itself.
92
+ Defaults to ``False``.
93
+ """
94
+ if len(set(indices)) != len(indices):
95
+ msg = "No duplicates allowed in indices."
96
+ raise ValueError(msg)
97
+ dataset = self if inplace else copy.deepcopy(self)
98
+ dataset.samples = self.samples.iloc[indices].reset_index(drop=True)
99
+ return dataset
100
+
101
+ @property
102
+ def samples(self) -> DataFrame:
103
+ """Get the samples dataframe."""
104
+ if self._samples is None:
105
+ msg = (
106
+ "Dataset does not have a samples dataframe. Ensure that a dataframe has been assigned to "
107
+ "`dataset.samples`."
108
+ )
109
+ raise RuntimeError(msg)
110
+ return self._samples
111
+
112
+ @samples.setter
113
+ def samples(self, samples: DataFrame) -> None:
114
+ """Overwrite the samples with a new dataframe.
115
+
116
+ Args:
117
+ samples (DataFrame): DataFrame with new samples.
118
+ """
119
+ # validate the passed samples by checking the type, the expected columns and the file paths
120
+ if not isinstance(samples, DataFrame):
121
+ msg = f"samples must be a pandas.DataFrame, found {type(samples)}"
122
+ raise TypeError(msg)
123
+
124
+ expected_columns = _EXPECTED_COLUMNS_PERTASK[self.task]
125
+ if not all(col in samples.columns for col in expected_columns):
126
+ msg = f"samples must have (at least) columns {expected_columns}, found {samples.columns}"
127
+ raise ValueError(msg)
128
+
129
+ if not samples["image_path"].apply(lambda p: Path(p).exists()).all():
130
+ msg = "missing file path(s) in samples"
131
+ raise FileNotFoundError(msg)
132
+
133
+ self._samples = samples.sort_values(by="image_path", ignore_index=True)
134
+
135
+ @property
136
+ def category(self) -> str | None:
137
+ """Get the category of the dataset."""
138
+ return self._category
139
+
140
+ @category.setter
141
+ def category(self, category: str) -> None:
142
+ """Set the category of the dataset."""
143
+ self._category = category
144
+
145
+ @property
146
+ def has_normal(self) -> bool:
147
+ """Check if the dataset contains any normal samples."""
148
+ return LabelName.NORMAL in list(self.samples.label_index)
149
+
150
+ @property
151
+ def has_anomalous(self) -> bool:
152
+ """Check if the dataset contains any anomalous samples."""
153
+ return LabelName.ABNORMAL in list(self.samples.label_index)
154
+
155
+ def __getitem__(self, index: int) -> dict[str, str | torch.Tensor]:
156
+ """Get dataset item for the index ``index``.
157
+
158
+ Args:
159
+ index (int): Index to get the item.
160
+
161
+ Returns:
162
+ dict[str, str | torch.Tensor]: Dictionary containing the image path, image tensor and label, plus
163
+ the ground truth mask for segmentation/detection tasks and the bounding boxes for detection.
164
+ """
165
+ image_path = self.samples.iloc[index].image_path
166
+ mask_path = self.samples.iloc[index].mask_path
167
+ label_index = self.samples.iloc[index].label_index
168
+
169
+ image = read_image(image_path, as_tensor=True)
170
+ item = {"image_path": image_path, "label": label_index}
171
+
172
+ if self.task == TaskType.CLASSIFICATION:
173
+ item["image"] = self.transform(image) if self.transform else image
174
+ elif self.task in (TaskType.DETECTION, TaskType.SEGMENTATION):
175
+ # Only Anomalous (1) images have masks in anomaly datasets
176
+ # Therefore, create empty mask for Normal (0) images.
177
+ mask = (
178
+ Mask(torch.zeros(image.shape[-2:])).to(torch.uint8)
179
+ if label_index == LabelName.NORMAL
180
+ else read_mask(mask_path, as_tensor=True)
181
+ )
182
+ item["image"], item["mask"] = self.transform(image, mask) if self.transform else (image, mask)
183
+
184
+ if self.task == TaskType.DETECTION:
185
+ # create boxes from masks for detection task
186
+ boxes, _ = masks_to_boxes(item["mask"])
187
+ item["boxes"] = boxes[0]
188
+ else:
189
+ msg = f"Unknown task type: {self.task}"
190
+ raise ValueError(msg)
191
+
192
+ return item
193
+
194
+ def __add__(self, other_dataset: "AnomalibDataset") -> "AnomalibDataset":
195
+ """Concatenate this dataset with another dataset.
196
+
197
+ Args:
198
+ other_dataset (AnomalibDataset): Dataset to concatenate with.
199
+
200
+ Returns:
201
+ AnomalibDataset: Concatenated dataset.
202
+ """
203
+ if not isinstance(other_dataset, self.__class__):
204
+ msg = "Cannot concatenate datasets that are not of the same type."
205
+ raise TypeError(msg)
206
+ dataset = copy.deepcopy(self)
207
+ dataset.samples = pd.concat([self.samples, other_dataset.samples], ignore_index=True)
208
+ return dataset
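The `samples` contract enforced by the setter above can be summarized with a small hypothetical dataframe (the paths are illustrative and must actually exist on disk for the validation to pass):

    import pandas as pd

    samples = pd.DataFrame(
        {
            "image_path": ["images/000.png", "images/001.png"],  # hypothetical paths
            "label": ["normal", "anomalous"],
            "label_index": [0, 1],
            "mask_path": ["", "masks/001.png"],
            "split": ["train", "test"],
        },
    )
    # dataset.samples = samples  # sorts by image_path and validates columns/paths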
anomalib/data/base/depth.py ADDED
@@ -0,0 +1,76 @@
1
+ """Base Depth Dataset."""
2
+
3
+ # Copyright (C) 2023-2024 Intel Corporation
4
+ # SPDX-License-Identifier: Apache-2.0
5
+
6
+ from abc import ABC
7
+
8
+ import torch
9
+ from PIL import Image
10
+ from torchvision.transforms.functional import to_tensor
11
+ from torchvision.transforms.v2 import Transform
12
+ from torchvision.tv_tensors import Mask
13
+
14
+ from anomalib import TaskType
15
+ from anomalib.data.base.dataset import AnomalibDataset
16
+ from anomalib.data.utils import LabelName, masks_to_boxes, read_depth_image
17
+
18
+
19
+ class AnomalibDepthDataset(AnomalibDataset, ABC):
20
+ """Base depth anomalib dataset class.
21
+
22
+ Args:
23
+ task (str): Task type, either 'classification' or 'segmentation'
24
+ transform (Transform, optional): Transforms that should be applied to the input images.
25
+ Defaults to ``None``.
26
+ """
27
+
28
+ def __init__(self, task: TaskType, transform: Transform | None = None) -> None:
29
+ super().__init__(task, transform)
30
+
31
+ self.transform = transform
32
+
33
+ def __getitem__(self, index: int) -> dict[str, str | torch.Tensor]:
34
+ """Return rgb image, depth image and mask.
35
+
36
+ Args:
37
+ index (int): Index of the item to be returned.
38
+
39
+ Returns:
40
+ dict[str, str | torch.Tensor]: Dictionary containing the image, depth image and mask.
41
+ """
42
+ image_path = self.samples.iloc[index].image_path
43
+ mask_path = self.samples.iloc[index].mask_path
44
+ label_index = self.samples.iloc[index].label_index
45
+ depth_path = self.samples.iloc[index].depth_path
46
+
47
+ image = to_tensor(Image.open(image_path))
48
+ depth_image = to_tensor(read_depth_image(depth_path))
49
+ item = {"image_path": image_path, "depth_path": depth_path, "label": label_index}
50
+
51
+ if self.task == TaskType.CLASSIFICATION:
52
+ item["image"], item["depth_image"] = (
53
+ self.transform(image, depth_image) if self.transform else (image, depth_image)
54
+ )
55
+ elif self.task in (TaskType.DETECTION, TaskType.SEGMENTATION):
56
+ # Only Anomalous (1) images have masks in anomaly datasets
57
+ # Therefore, create empty mask for Normal (0) images.
58
+ mask = (
59
+ Mask(torch.zeros(image.shape[-2:]))
60
+ if label_index == LabelName.NORMAL
61
+ else Mask(to_tensor(Image.open(mask_path)).squeeze())
62
+ )
63
+ item["image"], item["depth_image"], item["mask"] = (
64
+ self.transform(image, depth_image, mask) if self.transform else (image, depth_image, mask)
65
+ )
66
+ item["mask_path"] = mask_path
67
+
68
+ if self.task == TaskType.DETECTION:
69
+ # create boxes from masks for detection task
70
+ boxes, _ = masks_to_boxes(item["mask"])
71
+ item["boxes"] = boxes[0]
72
+ else:
73
+ msg = f"Unknown task type: {self.task}"
74
+ raise ValueError(msg)
75
+
76
+ return item
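Both this class and `AnomalibDataset.__getitem__` rely on the same convention for normal samples: there is no mask file, so an all-zero mask matching the image's spatial size is synthesized. A standalone sketch:

    import torch
    from torchvision.tv_tensors import Mask

    image = torch.rand(3, 32, 32)  # stand-in RGB image
    empty_mask = Mask(torch.zeros(image.shape[-2:])).to(torch.uint8)  # (32, 32), all zeros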
anomalib/data/base/video.py ADDED
@@ -0,0 +1,213 @@
1
+ """Base Video Dataset."""
2
+
3
+ # Copyright (C) 2023-2024 Intel Corporation
4
+ # SPDX-License-Identifier: Apache-2.0
5
+
6
+ from abc import ABC
7
+ from enum import Enum
8
+ from typing import TYPE_CHECKING, Any
9
+
10
+ import torch
11
+ from pandas import DataFrame
12
+ from torchvision.transforms.v2 import Transform
13
+ from torchvision.transforms.v2.functional import to_dtype_video
14
+ from torchvision.tv_tensors import Mask
15
+
16
+ from anomalib import TaskType
17
+ from anomalib.data.base.datamodule import AnomalibDataModule
18
+ from anomalib.data.base.dataset import AnomalibDataset
19
+ from anomalib.data.utils import ValSplitMode, masks_to_boxes
20
+ from anomalib.data.utils.video import ClipsIndexer
21
+
22
+ if TYPE_CHECKING:
23
+ from collections.abc import Callable
24
+
25
+
26
+ class VideoTargetFrame(str, Enum):
27
+ """Target frame for a video-clip.
28
+
29
+ Used in multi-frame models to determine which frame's ground truth information will be used.
30
+ """
31
+
32
+ FIRST = "first"
33
+ LAST = "last"
34
+ MID = "mid"
35
+ ALL = "all"
36
+
37
+
38
+ class AnomalibVideoDataset(AnomalibDataset, ABC):
39
+ """Base video anomalib dataset class.
40
+
41
+ Args:
42
+ task (str): Task type, either 'classification' or 'segmentation'
43
+ clip_length_in_frames (int): Number of video frames in each clip.
44
+ frames_between_clips (int): Number of frames between each consecutive video clip.
45
+ transform (Transform, optional): Transforms that should be applied to the input clips.
46
+ Defaults to ``None``.
47
+ target_frame (VideoTargetFrame): Specifies the target frame in the video clip, used for ground truth retrieval.
48
+ Defaults to ``VideoTargetFrame.LAST``.
49
+ """
50
+
51
+ def __init__(
52
+ self,
53
+ task: TaskType,
54
+ clip_length_in_frames: int,
55
+ frames_between_clips: int,
56
+ transform: Transform | None = None,
57
+ target_frame: VideoTargetFrame = VideoTargetFrame.LAST,
58
+ ) -> None:
59
+ super().__init__(task, transform)
60
+
61
+ self.clip_length_in_frames = clip_length_in_frames
62
+ self.frames_between_clips = frames_between_clips
63
+ self.transform = transform
64
+
65
+ self.indexer: ClipsIndexer | None = None
66
+ self.indexer_cls: Callable | None = None
67
+
68
+ self.target_frame = target_frame
69
+
70
+ def __len__(self) -> int:
71
+ """Get length of the dataset."""
72
+ if not isinstance(self.indexer, ClipsIndexer):
73
+ msg = "self.indexer must be an instance of ClipsIndexer."
74
+ raise TypeError(msg)
75
+ return self.indexer.num_clips()
76
+
77
+ @property
78
+ def samples(self) -> DataFrame:
79
+ """Get the samples dataframe."""
80
+ return super().samples
81
+
82
+ @samples.setter
83
+ def samples(self, samples: DataFrame) -> None:
84
+ """Overwrite samples and re-index subvideos.
85
+
86
+ Args:
87
+ samples (DataFrame): DataFrame with new samples.
88
+
89
+ Raises:
90
+ TypeError: If the indexer class is not callable.
91
+ """
92
+ super(AnomalibVideoDataset, self.__class__).samples.fset(self, samples) # type: ignore[attr-defined]
93
+ self._setup_clips()
94
+
95
+ def _setup_clips(self) -> None:
96
+ """Compute the video and frame indices of the subvideos.
97
+
98
+ Should be called after each change to self._samples
99
+ """
100
+ if not callable(self.indexer_cls):
101
+ msg = "self.indexer_cls must be callable."
102
+ raise TypeError(msg)
103
+ self.indexer = self.indexer_cls( # pylint: disable=not-callable
104
+ video_paths=list(self.samples.image_path),
105
+ mask_paths=list(self.samples.mask_path),
106
+ clip_length_in_frames=self.clip_length_in_frames,
107
+ frames_between_clips=self.frames_between_clips,
108
+ )
109
+
110
+ def _select_targets(self, item: dict[str, Any]) -> dict[str, Any]:
111
+ """Select the target frame from the clip.
112
+
113
+ Args:
114
+ item (dict[str, Any]): Item containing the clip information.
115
+
116
+ Raises:
117
+ ValueError: If the target frame is not one of the supported options.
118
+
119
+ Returns:
120
+ dict[str, Any]: Selected item from the clip.
121
+ """
122
+ if self.target_frame == VideoTargetFrame.FIRST:
123
+ idx = 0
124
+ elif self.target_frame == VideoTargetFrame.LAST:
125
+ idx = -1
126
+ elif self.target_frame == VideoTargetFrame.MID:
127
+ idx = int(self.clip_length_in_frames / 2)
128
+ else:
129
+ msg = f"Unknown video target frame: {self.target_frame}"
130
+ raise ValueError(msg)
131
+
132
+ if item.get("mask") is not None:
133
+ item["mask"] = item["mask"][idx, ...]
134
+ if item.get("boxes") is not None:
135
+ item["boxes"] = item["boxes"][idx]
136
+ if item.get("label") is not None:
137
+ item["label"] = item["label"][idx]
138
+ if item.get("original_image") is not None:
139
+ item["original_image"] = item["original_image"][idx]
140
+ if item.get("frames") is not None:
141
+ item["frames"] = item["frames"][idx]
142
+ return item
143
+
144
+ def __getitem__(self, index: int) -> dict[str, str | torch.Tensor]:
145
+ """Get the dataset item for the index ``index``.
146
+
147
+ Args:
148
+ index (int): Index of the item to be returned.
149
+
150
+ Returns:
151
+ dict[str, str | torch.Tensor]: Dictionary containing the mask, clip and file system information.
152
+ """
153
+ if not isinstance(self.indexer, ClipsIndexer):
154
+ msg = "self.indexer must be an instance of ClipsIndexer."
155
+ raise TypeError(msg)
156
+ item = self.indexer.get_item(index)
157
+ item["image"] = to_dtype_video(video=item["image"], scale=True)
158
+ # include the untransformed image for visualization
159
+ item["original_image"] = item["image"].to(torch.uint8)
160
+
161
+ # apply transforms
162
+ if item.get("mask") is not None:
163
+ if self.transform:
164
+ item["image"], item["mask"] = self.transform(item["image"], Mask(item["mask"]))
165
+ item["label"] = torch.Tensor([1 in frame for frame in item["mask"]]).int().squeeze(0)
166
+ if self.task == TaskType.DETECTION:
167
+ item["boxes"], _ = masks_to_boxes(item["mask"])
168
+ item["boxes"] = item["boxes"][0] if len(item["boxes"]) == 1 else item["boxes"]
169
+ elif self.transform:
170
+ item["image"] = self.transform(item["image"])
171
+
172
+ # squeeze temporal dimensions in case clip length is 1
173
+ item["image"] = item["image"].squeeze(0)
174
+
175
+ # include only target frame in gt
176
+ if self.clip_length_in_frames > 1 and self.target_frame != VideoTargetFrame.ALL:
177
+ item = self._select_targets(item)
178
+
179
+ if item["mask"] is None:
180
+ item.pop("mask")
181
+
182
+ return item
183
+
184
+
185
+ class AnomalibVideoDataModule(AnomalibDataModule):
186
+ """Base class for video data modules."""
187
+
188
+ def _create_test_split(self) -> None:
189
+ """Video datamodules do not support dynamic assignment of the test split."""
190
+
191
+ def _setup(self, _stage: str | None = None) -> None:
192
+ """Set up the datasets and perform dynamic subset splitting.
193
+
194
+ This method may be overridden in subclass for custom splitting behaviour.
195
+
196
+ Video datamodules are not compatible with synthetic anomaly generation.
197
+ """
198
+ if self.train_data is None:
199
+ msg = "self.train_data cannot be None."
200
+ raise ValueError(msg)
201
+
202
+ if self.test_data is None:
203
+ msg = "self.test_data cannot be None."
204
+ raise ValueError(msg)
205
+
206
+ self.train_data.setup()
207
+ self.test_data.setup()
208
+
209
+ if self.val_split_mode == ValSplitMode.SYNTHETIC:
210
+ msg = f"Val split mode {self.test_split_mode} not supported for video datasets."
211
+ raise ValueError(msg)
212
+
213
+ self._create_val_split()
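In the detection branch of `__getitem__` above, boxes are derived from the clip's masks via `masks_to_boxes`. A hedged sketch of that call (the exact return convention lives in anomalib.data.utils.boxes):

    import torch
    from anomalib.data.utils import masks_to_boxes

    masks = torch.zeros(2, 16, 16)  # two frames of binary masks
    masks[1, 4:8, 4:8] = 1          # a square anomaly in the last frame
    boxes, _ = masks_to_boxes(masks)
    # boxes is a per-frame list of (N, 4) xyxy tensors; boxes[0] is empty here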
anomalib/data/depth/__init__.py ADDED
@@ -0,0 +1,20 @@
1
+ """Anomalib Depth Datasets."""
2
+
3
+ # Copyright (C) 2023 Intel Corporation
4
+ # SPDX-License-Identifier: Apache-2.0
5
+
6
+
7
+ from enum import Enum
8
+
9
+ from .folder_3d import Folder3D
10
+ from .mvtec_3d import MVTec3D
11
+
12
+
13
+ class DepthDataFormat(str, Enum):
14
+ """Supported Depth Dataset Types."""
15
+
16
+ MVTEC_3D = "mvtec_3d"
17
+ FOLDER_3D = "folder_3d"
18
+
19
+
20
+ __all__ = ["Folder3D", "MVTec3D"]
anomalib/data/depth/folder_3d.py ADDED
@@ -0,0 +1,433 @@
1
+ """Custom Folder Dataset.
2
+
3
+ This script creates a custom dataset from a folder.
4
+ """
5
+
6
+ # Copyright (C) 2022 Intel Corporation
7
+ # SPDX-License-Identifier: Apache-2.0
8
+
9
+
10
+ from pathlib import Path
11
+
12
+ from pandas import DataFrame, isna
13
+ from torchvision.transforms.v2 import Transform
14
+
15
+ from anomalib import TaskType
16
+ from anomalib.data.base import AnomalibDataModule, AnomalibDepthDataset
17
+ from anomalib.data.errors import MisMatchError
18
+ from anomalib.data.utils import (
19
+ DirType,
20
+ LabelName,
21
+ Split,
22
+ TestSplitMode,
23
+ ValSplitMode,
24
+ )
25
+ from anomalib.data.utils.path import _prepare_files_labels, validate_and_resolve_path
26
+
27
+
28
+ def make_folder3d_dataset( # noqa: C901
29
+ normal_dir: str | Path,
30
+ root: str | Path | None = None,
31
+ abnormal_dir: str | Path | None = None,
32
+ normal_test_dir: str | Path | None = None,
33
+ mask_dir: str | Path | None = None,
34
+ normal_depth_dir: str | Path | None = None,
35
+ abnormal_depth_dir: str | Path | None = None,
36
+ normal_test_depth_dir: str | Path | None = None,
37
+ split: str | Split | None = None,
38
+ extensions: tuple[str, ...] | None = None,
39
+ ) -> DataFrame:
40
+ """Make Folder Dataset.
41
+
42
+ Args:
43
+ normal_dir (str | Path): Path to the directory containing normal images.
44
+ root (str | Path | None): Path to the root directory of the dataset.
45
+ Defaults to ``None``.
46
+ abnormal_dir (str | Path | None, optional): Path to the directory containing abnormal images.
47
+ Defaults to ``None``.
48
+ normal_test_dir (str | Path | None, optional): Path to the directory containing normal images for the test
49
+ dataset. Normal test images will be a split of `normal_dir` if `None`.
50
+ Defaults to ``None``.
51
+ mask_dir (str | Path | None, optional): Path to the directory containing the mask annotations.
52
+ Defaults to ``None``.
53
+ normal_depth_dir (str | Path | None, optional): Path to the directory containing
54
+ normal depth images for the test dataset. Normal test depth images will be a split of `normal_dir`
55
+ Defaults to ``None``.
56
+ abnormal_depth_dir (str | Path | None, optional): Path to the directory containing abnormal depth images for
57
+ the test dataset.
58
+ Defaults to ``None``.
59
+ normal_test_depth_dir (str | Path | None, optional): Path to the directory containing normal depth images for
60
+ the test dataset. Normal test images will be a split of `normal_dir` if `None`.
61
+ Defaults to ``None``.
62
+ split (str | Split | None, optional): Dataset split (ie., Split.FULL, Split.TRAIN or Split.TEST).
63
+ Defaults to ``None``.
64
+ extensions (tuple[str, ...] | None, optional): Type of the image extensions to read from the directory.
65
+ Defaults to ``None``.
66
+
67
+ Returns:
68
+ DataFrame: an output dataframe containing samples for the requested split (ie., train or test)
69
+ """
70
+ normal_dir = validate_and_resolve_path(normal_dir, root)
71
+ abnormal_dir = validate_and_resolve_path(abnormal_dir, root) if abnormal_dir else None
72
+ normal_test_dir = validate_and_resolve_path(normal_test_dir, root) if normal_test_dir else None
73
+ mask_dir = validate_and_resolve_path(mask_dir, root) if mask_dir else None
74
+ normal_depth_dir = validate_and_resolve_path(normal_depth_dir, root) if normal_depth_dir else None
75
+ abnormal_depth_dir = validate_and_resolve_path(abnormal_depth_dir, root) if abnormal_depth_dir else None
76
+ normal_test_depth_dir = validate_and_resolve_path(normal_test_depth_dir, root) if normal_test_depth_dir else None
77
+
78
+ if not normal_dir.is_dir():
79
+ msg = "A folder location must be provided in normal_dir."
80
+ raise ValueError(msg)
81
+
82
+ filenames = []
83
+ labels = []
84
+ dirs = {DirType.NORMAL: normal_dir}
85
+
86
+ if abnormal_dir:
87
+ dirs[DirType.ABNORMAL] = abnormal_dir
88
+
89
+ if normal_test_dir:
90
+ dirs[DirType.NORMAL_TEST] = normal_test_dir
91
+
92
+ if normal_depth_dir:
93
+ dirs[DirType.NORMAL_DEPTH] = normal_depth_dir
94
+
95
+ if abnormal_depth_dir:
96
+ dirs[DirType.ABNORMAL_DEPTH] = abnormal_depth_dir
97
+
98
+ if normal_test_depth_dir:
99
+ dirs[DirType.NORMAL_TEST_DEPTH] = normal_test_depth_dir
100
+
101
+ if mask_dir:
102
+ dirs[DirType.MASK] = mask_dir
103
+
104
+ for dir_type, path in dirs.items():
105
+ filename, label = _prepare_files_labels(path, dir_type, extensions)
106
+ filenames += filename
107
+ labels += label
108
+
109
+ samples = DataFrame({"image_path": filenames, "label": labels})
110
+ samples = samples.sort_values(by="image_path", ignore_index=True)
111
+
112
+ # Create label index for normal (0) and abnormal (1) images.
113
+ samples.loc[
114
+ (samples.label == DirType.NORMAL) | (samples.label == DirType.NORMAL_TEST),
115
+ "label_index",
116
+ ] = LabelName.NORMAL
117
+ samples.loc[(samples.label == DirType.ABNORMAL), "label_index"] = LabelName.ABNORMAL
118
+ samples.label_index = samples.label_index.astype("Int64")
119
+
120
+ # If depth paths are provided, add them to the samples dataframe.
121
+ if normal_depth_dir:
122
+ samples.loc[samples.label == DirType.NORMAL, "depth_path"] = samples.loc[
123
+ samples.label == DirType.NORMAL_DEPTH
124
+ ].image_path.to_numpy()
125
+ samples.loc[samples.label == DirType.ABNORMAL, "depth_path"] = samples.loc[
126
+ samples.label == DirType.ABNORMAL_DEPTH
127
+ ].image_path.to_numpy()
128
+
129
+ if normal_test_dir:
130
+ samples.loc[samples.label == DirType.NORMAL_TEST, "depth_path"] = samples.loc[
131
+ samples.label == DirType.NORMAL_TEST_DEPTH
132
+ ].image_path.to_numpy()
133
+
134
+ # make sure every rgb image has a corresponding depth image and that the file exists
135
+ samples_match = (
136
+ samples.loc[samples.label_index == LabelName.ABNORMAL]
137
+ .apply(lambda x: Path(x.image_path).stem in Path(x.depth_path).stem, axis=1)
138
+ .all()
139
+ )
140
+ if not samples_match:
141
+ msg = """Mismatch between anomalous images and depth images. Make sure the depth files
142
+ in the depth folders follow the same naming convention as the anomalous images in the dataset
143
+ (e.g. image: '000.png', depth: '000.tiff')."""
144
+ raise MisMatchError(msg)
145
+
146
+ depth_files_exist = samples.depth_path.apply(
147
+ lambda x: Path(x).exists() if not isna(x) else True,
148
+ ).all()
149
+ if not depth_files_exist:
150
+ msg = "Missing depth image files."
151
+ raise FileNotFoundError(msg)
152
+
153
+ samples = samples.astype({"depth_path": "str"})
154
+
155
+ # If a path to mask is provided, add it to the sample dataframe.
156
+ if mask_dir and abnormal_dir:
157
+ samples.loc[samples.label == DirType.ABNORMAL, "mask_path"] = samples.loc[
158
+ samples.label == DirType.MASK
159
+ ].image_path.to_numpy()
160
+ samples["mask_path"] = samples["mask_path"].fillna("")
161
+ samples = samples.astype({"mask_path": "str"})
162
+
163
+ # make sure all the files exist
164
+ if not samples.mask_path.apply(
165
+ lambda x: Path(x).exists() if x != "" else True,
166
+ ).all():
167
+ msg = f"Missing mask files. mask_dir={mask_dir}"
168
+ raise FileNotFoundError(msg)
169
+ else:
170
+ samples["mask_path"] = ""
171
+
172
+ # keep only the image rows; the depth and mask rows have already been assigned to columns above
173
+ samples = samples.loc[
174
+ (samples.label == DirType.NORMAL) | (samples.label == DirType.ABNORMAL) | (samples.label == DirType.NORMAL_TEST)
175
+ ]
176
+
177
+ # Ensure the pathlib objects are converted to str.
178
+ # This is because torch dataloader doesn't like pathlib.
179
+ samples = samples.astype({"image_path": "str"})
180
+
181
+ # Create train/test split.
182
+ # By default, all the normal samples are assigned as train.
183
+ # and all the abnormal samples are test.
184
+ samples.loc[(samples.label == DirType.NORMAL), "split"] = Split.TRAIN
185
+ samples.loc[(samples.label == DirType.ABNORMAL) | (samples.label == DirType.NORMAL_TEST), "split"] = Split.TEST
186
+
187
+ # Get the data frame for the split.
188
+ if split:
189
+ samples = samples[samples.split == split]
190
+ samples = samples.reset_index(drop=True)
191
+
192
+ return samples
193
+
194
+
195
+ class Folder3DDataset(AnomalibDepthDataset):
196
+ """Folder dataset.
197
+
198
+ Args:
199
+ name (str): Name of the dataset.
200
+ task (TaskType): Task type. (``classification``, ``detection`` or ``segmentation``).
202
+ normal_dir (str | Path): Path to the directory containing normal images.
203
+ root (str | Path | None): Root folder of the dataset.
204
+ Defaults to ``None``.
205
+ abnormal_dir (str | Path | None, optional): Path to the directory containing abnormal images.
206
+ Defaults to ``None``.
207
+ normal_test_dir (str | Path | None, optional): Path to the directory containing
208
+ normal images for the test dataset.
209
+ Defaults to ``None``.
210
+ mask_dir (str | Path | None, optional): Path to the directory containing
211
+ the mask annotations.
212
+ Defaults to ``None``.
213
+ normal_depth_dir (str | Path | None, optional): Path to the directory containing
214
+ normal depth images for the test dataset. Normal test depth images will be a split of `normal_dir`
215
+ Defaults to ``None``.
216
+ abnormal_depth_dir (str | Path | None, optional): Path to the directory containing abnormal depth images for
217
+ the test dataset.
218
+ Defaults to ``None``.
219
+ normal_test_depth_dir (str | Path | None, optional): Path to the directory containing
220
+ normal depth images for the test dataset. Normal test images will be a split of `normal_dir` if `None`.
221
+ Defaults to ``None``.
222
+ transform (Transform, optional): Transforms that should be applied to the input images.
223
+ Defaults to ``None``.
224
+ split (str | Split | None): Fixed subset split that follows from folder structure on file system.
225
+ Choose from [Split.FULL, Split.TRAIN, Split.TEST]
226
+ Defaults to ``None``.
227
+ extensions (tuple[str, ...] | None, optional): Type of the image extensions to read from the directory.
228
+ Defaults to ``None``.
229
+
230
+ Raises:
231
+ ValueError: When task is set to classification and `mask_dir` is provided. When `mask_dir` is
232
+ provided, `task` should be set to `segmentation`.
233
+ """
234
+
235
+ def __init__(
236
+ self,
237
+ name: str,
238
+ task: TaskType,
239
+ normal_dir: str | Path,
240
+ root: str | Path | None = None,
241
+ abnormal_dir: str | Path | None = None,
242
+ normal_test_dir: str | Path | None = None,
243
+ mask_dir: str | Path | None = None,
244
+ normal_depth_dir: str | Path | None = None,
245
+ abnormal_depth_dir: str | Path | None = None,
246
+ normal_test_depth_dir: str | Path | None = None,
247
+ transform: Transform | None = None,
248
+ split: str | Split | None = None,
249
+ extensions: tuple[str, ...] | None = None,
250
+ ) -> None:
251
+ super().__init__(task, transform)
252
+
253
+ self._name = name
254
+ self.split = split
255
+ self.root = root
256
+ self.normal_dir = normal_dir
257
+ self.abnormal_dir = abnormal_dir
258
+ self.normal_test_dir = normal_test_dir
259
+ self.mask_dir = mask_dir
260
+ self.normal_depth_dir = normal_depth_dir
261
+ self.abnormal_depth_dir = abnormal_depth_dir
262
+ self.normal_test_depth_dir = normal_test_depth_dir
263
+ self.extensions = extensions
264
+
265
+ self.samples = make_folder3d_dataset(
266
+ root=self.root,
267
+ normal_dir=self.normal_dir,
268
+ abnormal_dir=self.abnormal_dir,
269
+ normal_test_dir=self.normal_test_dir,
270
+ mask_dir=self.mask_dir,
271
+ normal_depth_dir=self.normal_depth_dir,
272
+ abnormal_depth_dir=self.abnormal_depth_dir,
273
+ normal_test_depth_dir=self.normal_test_depth_dir,
274
+ split=self.split,
275
+ extensions=self.extensions,
276
+ )
277
+
278
+ @property
279
+ def name(self) -> str:
280
+ """Name of the dataset.
281
+
282
+ Folder3D dataset overrides the name property to provide a custom name.
283
+ """
284
+ return self._name
285
+
286
+
287
+ class Folder3D(AnomalibDataModule):
288
+ """Folder DataModule.
289
+
290
+ Args:
291
+ name (str): Name of the dataset. This is used to name the datamodule, especially when logging/saving.
292
+ normal_dir (str | Path): Name of the directory containing normal images.
293
+ root (str | Path | None): Path to the root folder containing normal and abnormal dirs.
294
+ Defaults to ``None``.
295
+ abnormal_dir (str | Path | None): Name of the directory containing abnormal images.
296
+ Defaults to ``None``.
297
+ normal_test_dir (str | Path | None, optional): Path to the directory containing normal images for the test
298
+ dataset.
299
+ Defaults to ``None``.
300
+ mask_dir (str | Path | None, optional): Path to the directory containing the mask annotations.
301
+ Defaults to ``None``.
302
+ normal_depth_dir (str | Path | None, optional): Path to the directory containing
304
+ depth images corresponding to the normal images in ``normal_dir``. Defaults to ``None``.
305
+ abnormal_depth_dir (str | Path | None, optional): Path to the directory containing
306
+ abnormal depth images for the test dataset. Defaults to ``None``.
306
+ normal_test_depth_dir (str | Path | None, optional): Path to the directory containing
307
+ normal depth images for the test dataset. Normal test images will be a split of `normal_dir`
308
+ if ``None``. Defaults to ``None``.
312
+ extensions (tuple[str, ...] | None, optional): Type of the image extensions to read from the
313
+ directory. Defaults to ``None``.
314
+ train_batch_size (int, optional): Training batch size.
315
+ Defaults to ``32``.
316
+ eval_batch_size (int, optional): Test batch size.
317
+ Defaults to ``32``.
318
+ num_workers (int, optional): Number of workers.
319
+ Defaults to ``8``.
320
+ task (TaskType, optional): Task type. Could be ``classification``, ``detection`` or ``segmentation``.
321
+ Defaults to ``TaskType.SEGMENTATION``.
322
+ image_size (tuple[int, int], optional): Size to which input images should be resized.
323
+ Defaults to ``None``.
324
+ transform (Transform, optional): Transforms that should be applied to the input images.
325
+ Defaults to ``None``.
326
+ train_transform (Transform, optional): Transforms that should be applied to the input images during training.
327
+ Defaults to ``None``.
328
+ eval_transform (Transform, optional): Transforms that should be applied to the input images during evaluation.
329
+ Defaults to ``None``.
330
+ test_split_mode (TestSplitMode): Setting that determines how the testing subset is obtained.
331
+ Defaults to ``TestSplitMode.FROM_DIR``.
332
+ test_split_ratio (float): Fraction of images from the train set that will be reserved for testing.
333
+ Defaults to ``0.2``.
334
+ val_split_mode (ValSplitMode): Setting that determines how the validation subset is obtained.
335
+ Defaults to ``ValSplitMode.FROM_TEST``.
336
+ val_split_ratio (float): Fraction of train or test images that will be reserved for validation.
337
+ Defaults to ``0.5``.
338
+ seed (int | None, optional): Seed used during random subset splitting.
339
+ Defaults to ``None``.
340
+ """
341
+
342
+ def __init__(
343
+ self,
344
+ name: str,
345
+ normal_dir: str | Path,
346
+ root: str | Path,
347
+ abnormal_dir: str | Path | None = None,
348
+ normal_test_dir: str | Path | None = None,
349
+ mask_dir: str | Path | None = None,
350
+ normal_depth_dir: str | Path | None = None,
351
+ abnormal_depth_dir: str | Path | None = None,
352
+ normal_test_depth_dir: str | Path | None = None,
353
+ extensions: tuple[str, ...] | None = None,
354
+ train_batch_size: int = 32,
355
+ eval_batch_size: int = 32,
356
+ num_workers: int = 8,
357
+ task: TaskType | str = TaskType.SEGMENTATION,
358
+ image_size: tuple[int, int] | None = None,
359
+ transform: Transform | None = None,
360
+ train_transform: Transform | None = None,
361
+ eval_transform: Transform | None = None,
362
+ test_split_mode: TestSplitMode | str = TestSplitMode.FROM_DIR,
363
+ test_split_ratio: float = 0.2,
364
+ val_split_mode: ValSplitMode | str = ValSplitMode.FROM_TEST,
365
+ val_split_ratio: float = 0.5,
366
+ seed: int | None = None,
367
+ ) -> None:
368
+ super().__init__(
369
+ train_batch_size=train_batch_size,
370
+ eval_batch_size=eval_batch_size,
371
+ num_workers=num_workers,
372
+ image_size=image_size,
373
+ transform=transform,
374
+ train_transform=train_transform,
375
+ eval_transform=eval_transform,
376
+ test_split_mode=test_split_mode,
377
+ test_split_ratio=test_split_ratio,
378
+ val_split_mode=val_split_mode,
379
+ val_split_ratio=val_split_ratio,
380
+ seed=seed,
381
+ )
382
+ self._name = name
383
+ self.task = TaskType(task)
384
+ self.root = Path(root)
385
+ self.normal_dir = normal_dir
386
+ self.abnormal_dir = abnormal_dir
387
+ self.normal_test_dir = normal_test_dir
388
+ self.mask_dir = mask_dir
389
+ self.normal_depth_dir = normal_depth_dir
390
+ self.abnormal_depth_dir = abnormal_depth_dir
391
+ self.normal_test_depth_dir = normal_test_depth_dir
392
+ self.extensions = extensions
393
+
394
+ def _setup(self, _stage: str | None = None) -> None:
395
+ self.train_data = Folder3DDataset(
396
+ name=self.name,
397
+ task=self.task,
398
+ transform=self.train_transform,
399
+ split=Split.TRAIN,
400
+ root=self.root,
401
+ normal_dir=self.normal_dir,
402
+ abnormal_dir=self.abnormal_dir,
403
+ normal_test_dir=self.normal_test_dir,
404
+ mask_dir=self.mask_dir,
405
+ normal_depth_dir=self.normal_depth_dir,
406
+ abnormal_depth_dir=self.abnormal_depth_dir,
407
+ normal_test_depth_dir=self.normal_test_depth_dir,
408
+ extensions=self.extensions,
409
+ )
410
+
411
+ self.test_data = Folder3DDataset(
412
+ name=self.name,
413
+ task=self.task,
414
+ transform=self.eval_transform,
415
+ split=Split.TEST,
416
+ root=self.root,
417
+ normal_dir=self.normal_dir,
418
+ abnormal_dir=self.abnormal_dir,
419
+ normal_test_dir=self.normal_test_dir,
420
+ normal_depth_dir=self.normal_depth_dir,
421
+ abnormal_depth_dir=self.abnormal_depth_dir,
422
+ normal_test_depth_dir=self.normal_test_depth_dir,
423
+ mask_dir=self.mask_dir,
424
+ extensions=self.extensions,
425
+ )
426
+
427
+ @property
428
+ def name(self) -> str:
429
+ """Name of the datamodule.
430
+
431
+ Folder3D datamodule overrides the name property to provide a custom name.
432
+ """
433
+ return self._name
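
A minimal usage sketch for the Folder3D datamodule defined above. The directory names are hypothetical placeholders; only the keyword arguments come from the signature shown in this file:

    from anomalib import TaskType
    from anomalib.data.depth.folder_3d import Folder3D

    # Hypothetical layout: ./datasets/toy3d/{good, crack, good_xyz, crack_xyz, mask}
    datamodule = Folder3D(
        name="toy3d",
        root="./datasets/toy3d",
        normal_dir="good",
        abnormal_dir="crack",
        mask_dir="mask",
        normal_depth_dir="good_xyz",
        abnormal_depth_dir="crack_xyz",
        task=TaskType.SEGMENTATION,
    )
    datamodule.setup()  # builds the train/test Folder3DDataset instances via _setup()
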
anomalib/data/depth/mvtec_3d.py ADDED
@@ -0,0 +1,302 @@
1
+ """MVTec 3D-AD Dataset (CC BY-NC-SA 4.0).
2
+
3
+ Description:
4
+ This script contains PyTorch Dataset, Dataloader and PyTorch Lightning DataModule for the MVTec 3D-AD dataset.
5
+ If the dataset is not on the file system, the script downloads and extracts the dataset and creates PyTorch data
6
+ objects.
7
+
8
+ License:
9
+ MVTec 3D-AD dataset is released under the Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International
10
+ License (CC BY-NC-SA 4.0)(https://creativecommons.org/licenses/by-nc-sa/4.0/).
11
+
12
+ Reference:
13
+ - Paul Bergmann, Xin Jin, David Sattlegger, Carsten Steger: The MVTec 3D-AD Dataset for Unsupervised 3D Anomaly
14
+ Detection and Localization in: Proceedings of the 17th International Joint Conference on Computer Vision,
15
+ Imaging and Computer Graphics Theory and Applications - Volume 5: VISAPP, 202-213, 2022,
16
+ DOI: 10.5220/0010865000003124.
17
+ """
18
+
19
+ # Copyright (C) 2022 Intel Corporation
20
+ # SPDX-License-Identifier: Apache-2.0
21
+
22
+
23
+ import logging
24
+ from collections.abc import Sequence
25
+ from pathlib import Path
26
+
27
+ from pandas import DataFrame
28
+ from torchvision.transforms.v2 import Transform
29
+
30
+ from anomalib import TaskType
31
+ from anomalib.data.base import AnomalibDataModule, AnomalibDepthDataset
32
+ from anomalib.data.errors import MisMatchError
33
+ from anomalib.data.utils import (
34
+ DownloadInfo,
35
+ LabelName,
36
+ Split,
37
+ TestSplitMode,
38
+ ValSplitMode,
39
+ download_and_extract,
40
+ validate_path,
41
+ )
42
+
43
+ logger = logging.getLogger(__name__)
44
+
45
+
46
+ IMG_EXTENSIONS = [".png", ".PNG", ".tiff"]
47
+
48
+ DOWNLOAD_INFO = DownloadInfo(
49
+ name="mvtec_3d",
50
+ url="https://www.mydrive.ch/shares/45920/dd1eb345346df066c63b5c95676b961b/download/428824485-1643285832"
51
+ "/mvtec_3d_anomaly_detection.tar.xz",
52
+ hashsum="d8bb2800fbf3ac88e798da6ae10dc819",
53
+ )
54
+
55
+ CATEGORIES = ("bagel", "cable_gland", "carrot", "cookie", "dowel", "foam", "peach", "potato", "rope", "tire")
56
+
57
+
58
+ def make_mvtec_3d_dataset(
59
+ root: str | Path,
60
+ split: str | Split | None = None,
61
+ extensions: Sequence[str] | None = None,
62
+ ) -> DataFrame:
63
+ """Create MVTec 3D-AD samples by parsing the MVTec AD data file structure.
64
+
65
+ The files are expected to follow this structure:
66
+ - `path/to/dataset/split/category/image_filename.png`
67
+ - `path/to/dataset/ground_truth/category/mask_filename.png`
68
+
69
+ This function creates a DataFrame to store the parsed information. The DataFrame follows this format:
70
+
71
+ +---+---------------+-------+---------+---------------+---------------------------------------+-------------+
72
+ | | path | split | label | image_path | mask_path | label_index |
73
+ +---+---------------+-------+---------+---------------+---------------------------------------+-------------+
74
+ | 0 | datasets/name | test | defect | filename.png | ground_truth/defect/filename_mask.png | 1 |
75
+ +---+---------------+-------+---------+---------------+---------------------------------------+-------------+
76
+
77
+ Args:
78
+ root (Path): Path to the dataset.
79
+ split (str | Split | None, optional): Dataset split (e.g., 'train' or 'test').
80
+ Defaults to ``None``.
81
+ extensions (Sequence[str] | None, optional): List of file extensions to be included in the dataset.
82
+ Defaults to ``None``.
83
+
84
+ Examples:
85
+ The following example shows how to get training samples from the MVTec 3D-AD 'bagel' category:
86
+
87
+ >>> from pathlib import Path
88
+ >>> root = Path('./MVTec3D')
89
+ >>> category = 'bagel'
90
+ >>> path = root / category
91
+ >>> print(path)
92
+ PosixPath('MVTec3D/bagel')
93
+
94
+ >>> samples = make_mvtec_3d_dataset(path, split='train')
95
+ >>> print(samples.head())
96
+ path split label image_path mask_path label_index
97
+ MVTec3D/bagel train good MVTec3D/bagel/train/good/rgb/105.png MVTec3D/bagel/ground_truth/good/gt/105.png 0
98
+ MVTec3D/bagel train good MVTec3D/bagel/train/good/rgb/017.png MVTec3D/bagel/ground_truth/good/gt/017.png 0
99
+
100
+ Returns:
101
+ DataFrame: An output DataFrame containing the samples of the dataset.
102
+ """
103
+ if extensions is None:
104
+ extensions = IMG_EXTENSIONS
105
+
106
+ root = validate_path(root)
107
+ samples_list = [(str(root),) + f.parts[-4:] for f in root.glob(r"**/*") if f.suffix in extensions]
108
+ if not samples_list:
109
+ msg = f"Found 0 images in {root}"
110
+ raise RuntimeError(msg)
111
+
112
+ samples = DataFrame(samples_list, columns=["path", "split", "label", "type", "file_name"])
113
+
114
+ # Construct the image_path and depth_path columns from the parsed file parts
115
+ samples.loc[(samples.type == "rgb"), "image_path"] = (
116
+ samples.path + "/" + samples.split + "/" + samples.label + "/" + "rgb/" + samples.file_name
117
+ )
118
+ samples.loc[(samples.type == "rgb"), "depth_path"] = (
119
+ samples.path
120
+ + "/"
121
+ + samples.split
122
+ + "/"
123
+ + samples.label
124
+ + "/"
125
+ + "xyz/"
126
+ + samples.file_name.str.split(".").str[0]
127
+ + ".tiff"
128
+ )
129
+
130
+ # Create label index for normal (0) and anomalous (1) images.
131
+ samples.loc[(samples.label == "good"), "label_index"] = LabelName.NORMAL
132
+ samples.loc[(samples.label != "good"), "label_index"] = LabelName.ABNORMAL
133
+ samples.label_index = samples.label_index.astype(int)
134
+
135
+ # separate masks from samples
136
+ mask_samples = samples.loc[((samples.split == "test") & (samples.type == "rgb"))].sort_values(
137
+ by="image_path",
138
+ ignore_index=True,
139
+ )
140
+ samples = samples.sort_values(by="image_path", ignore_index=True)
141
+
142
+ # assign mask paths to all test images
143
+ samples.loc[((samples.split == "test") & (samples.type == "rgb")), "mask_path"] = (
144
+ mask_samples.path + "/" + samples.split + "/" + samples.label + "/" + "gt/" + samples.file_name
145
+ )
146
+ samples = samples.dropna(subset=["image_path"])
147
+ samples = samples.astype({"image_path": "str", "mask_path": "str", "depth_path": "str"})
148
+
149
+ # assert that the right mask files are associated with the right test images
150
+ masks_match = (
151
+ samples.loc[samples.label_index == LabelName.ABNORMAL]
152
+ .apply(lambda x: Path(x.image_path).stem in Path(x.mask_path).stem, axis=1)
153
+ .all()
154
+ )
155
+ if not masks_match:
156
+ msg = """Mismatch between anomalous images and ground truth masks. Make sure the mask files
157
+ in 'ground_truth' folder follow the same naming convention as the anomalous images in
158
+ the dataset (e.g. image: '000.png', mask: '000.png' or '000_mask.png')."""
159
+ raise MisMatchError(msg)
160
+
161
+ depth_match = (
162
+ samples.loc[samples.label_index == LabelName.ABNORMAL]
163
+ .apply(lambda x: Path(x.image_path).stem in Path(x.depth_path).stem, axis=1)
164
+ .all()
165
+ )
166
+ if not depth_match:
167
+ msg = """Mismatch between anomalous images and depth images. Make sure the mask files in
168
+ 'xyz' folder follow the same naming convention as the anomalous images in the dataset
169
+ (e.g. image: '000.png', depth: '000.tiff')."""
170
+ raise MisMatchError(msg)
171
+
172
+ if split:
173
+ samples = samples[samples.split == split].reset_index(drop=True)
174
+
175
+ return samples
176
+
177
+
178
+ class MVTec3DDataset(AnomalibDepthDataset):
179
+ """MVTec 3D dataset class.
180
+
181
+ Args:
182
+ task (TaskType): Task type, ``classification``, ``detection`` or ``segmentation``
183
+ root (Path | str): Path to the root of the dataset
184
+ Defaults to ``"./datasets/MVTec3D"``.
185
+ category (str): Sub-category of the dataset, e.g. 'bagel'
186
+ Defaults to ``"bagel"``.
187
+ transform (Transform, optional): Transforms that should be applied to the input images.
188
+ Defaults to ``None``.
189
+ split (str | Split | None): Split of the dataset, usually Split.TRAIN or Split.TEST
190
+ Defaults to ``None``.
191
+ """
192
+
193
+ def __init__(
194
+ self,
195
+ task: TaskType,
196
+ root: Path | str = "./datasets/MVTec3D",
197
+ category: str = "bagel",
198
+ transform: Transform | None = None,
199
+ split: str | Split | None = None,
200
+ ) -> None:
201
+ super().__init__(task=task, transform=transform)
202
+
203
+ self.root_category = Path(root) / Path(category)
204
+ self.split = split
205
+ self.samples = make_mvtec_3d_dataset(self.root_category, split=self.split, extensions=IMG_EXTENSIONS)
206
+
207
+
208
+ class MVTec3D(AnomalibDataModule):
209
+ """MVTec Datamodule.
210
+
211
+ Args:
212
+ root (Path | str): Path to the root of the dataset
213
+ Defaults to ``"./datasets/MVTec3D"``.
214
+ category (str): Category of the MVTec 3D-AD dataset (e.g. "bagel" or "cable_gland").
215
+ Defaults to ``bagel``.
216
+ train_batch_size (int, optional): Training batch size.
217
+ Defaults to ``32``.
218
+ eval_batch_size (int, optional): Test batch size.
219
+ Defaults to ``32``.
220
+ num_workers (int, optional): Number of workers.
221
+ Defaults to ``8``.
222
+ task (TaskType): Task type, 'classification', 'detection' or 'segmentation'
223
+ Defaults to ``TaskType.SEGMENTATION``.
224
+ image_size (tuple[int, int], optional): Size to which input images should be resized.
225
+ Defaults to ``None``.
226
+ transform (Transform, optional): Transforms that should be applied to the input images.
227
+ Defaults to ``None``.
228
+ train_transform (Transform, optional): Transforms that should be applied to the input images during training.
229
+ Defaults to ``None``.
230
+ eval_transform (Transform, optional): Transforms that should be applied to the input images during evaluation.
231
+ Defaults to ``None``.
232
+ test_split_mode (TestSplitMode): Setting that determines how the testing subset is obtained.
233
+ Defaults to ``TestSplitMode.FROM_DIR``.
234
+ test_split_ratio (float): Fraction of images from the train set that will be reserved for testing.
235
+ Defaults to ``0.2``.
236
+ val_split_mode (ValSplitMode): Setting that determines how the validation subset is obtained.
237
+ Defaults to ``ValSplitMode.SAME_AS_TEST``.
238
+ val_split_ratio (float): Fraction of train or test images that will be reserved for validation.
239
+ Defaults to ``0.5``.
240
+ seed (int | None, optional): Seed which may be set to a fixed value for reproducibility.
241
+ Defaults to ``None``.
242
+ """
243
+
244
+ def __init__(
245
+ self,
246
+ root: Path | str = "./datasets/MVTec3D",
247
+ category: str = "bagel",
248
+ train_batch_size: int = 32,
249
+ eval_batch_size: int = 32,
250
+ num_workers: int = 8,
251
+ task: TaskType | str = TaskType.SEGMENTATION,
252
+ image_size: tuple[int, int] | None = None,
253
+ transform: Transform | None = None,
254
+ train_transform: Transform | None = None,
255
+ eval_transform: Transform | None = None,
256
+ test_split_mode: TestSplitMode | str = TestSplitMode.FROM_DIR,
257
+ test_split_ratio: float = 0.2,
258
+ val_split_mode: ValSplitMode | str = ValSplitMode.SAME_AS_TEST,
259
+ val_split_ratio: float = 0.5,
260
+ seed: int | None = None,
261
+ ) -> None:
262
+ super().__init__(
263
+ train_batch_size=train_batch_size,
264
+ eval_batch_size=eval_batch_size,
265
+ num_workers=num_workers,
266
+ image_size=image_size,
267
+ transform=transform,
268
+ train_transform=train_transform,
269
+ eval_transform=eval_transform,
270
+ test_split_mode=test_split_mode,
271
+ test_split_ratio=test_split_ratio,
272
+ val_split_mode=val_split_mode,
273
+ val_split_ratio=val_split_ratio,
274
+ seed=seed,
275
+ )
276
+
277
+ self.task = TaskType(task)
278
+ self.root = Path(root)
279
+ self.category = category
280
+
281
+ def _setup(self, _stage: str | None = None) -> None:
282
+ self.train_data = MVTec3DDataset(
283
+ task=self.task,
284
+ transform=self.train_transform,
285
+ split=Split.TRAIN,
286
+ root=self.root,
287
+ category=self.category,
288
+ )
289
+ self.test_data = MVTec3DDataset(
290
+ task=self.task,
291
+ transform=self.eval_transform,
292
+ split=Split.TEST,
293
+ root=self.root,
294
+ category=self.category,
295
+ )
296
+
297
+ def prepare_data(self) -> None:
298
+ """Download the dataset if not available."""
299
+ if (self.root / self.category).is_dir():
300
+ logger.info("Found the dataset.")
301
+ else:
302
+ download_and_extract(self.root, DOWNLOAD_INFO)
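
A minimal end-to-end sketch for the MVTec3D datamodule above; `prepare_data` only downloads the archive when the category folder is missing, and the ``depth_image`` batch key is an assumption based on AnomalibDepthDataset:

    from anomalib.data.depth.mvtec_3d import MVTec3D

    datamodule = MVTec3D(root="./datasets/MVTec3D", category="bagel")
    datamodule.prepare_data()   # download and extract only if ./datasets/MVTec3D/bagel is absent
    datamodule.setup()
    batch = next(iter(datamodule.train_dataloader()))
    print(batch["image"].shape)        # e.g. torch.Size([32, 3, H, W]) with the default batch size
    print(batch["depth_image"].shape)  # matching .tiff depth maps (key name assumed)
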
anomalib/data/errors.py ADDED
@@ -0,0 +1,19 @@
1
+ """Custom Exception Class for Mismatch Detection (MisMatchError)."""
2
+
3
+ # Copyright (C) 2024 Intel Corporation
4
+ # SPDX-License-Identifier: Apache-2.0
5
+
6
+
7
+ class MisMatchError(Exception):
8
+ """Exception raised when a mismatch is detected.
9
+
10
+ Attributes:
11
+ message (str): Explanation of the error.
12
+ """
13
+
14
+ def __init__(self, message: str = "") -> None:
15
+ if message:
16
+ self.message = message
17
+ else:
18
+ self.message = "Mismatch detected."
19
+ super().__init__(self.message)
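
MisMatchError stores an optional explanation and falls back to a default message; a short sketch of both paths:

    from anomalib.data.errors import MisMatchError

    try:
        raise MisMatchError("Image and mask stems do not match.")
    except MisMatchError as err:
        print(err.message)  # Image and mask stems do not match.

    try:
        raise MisMatchError()  # no message given
    except MisMatchError as err:
        print(err.message)  # Mismatch detected.
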
anomalib/data/image/__init__.py ADDED
@@ -0,0 +1,33 @@
1
+ """Anomalib Image Datasets.
2
+
3
+ This module contains the supported image datasets for Anomalib.
4
+ """
5
+
6
+ # Copyright (C) 2024 Intel Corporation
7
+ # SPDX-License-Identifier: Apache-2.0
8
+
9
+
10
+ from enum import Enum
11
+
12
+ from .btech import BTech
13
+ from .folder import Folder
14
+ from .kolektor import Kolektor
15
+ from .mvtec import MVTec
16
+ from .mvtec_loco import MVTecLoco
17
+ from .visa import Visa
18
+
19
+
20
+ class ImageDataFormat(str, Enum):
21
+ """Supported Image Dataset Types."""
22
+
23
+ MVTEC = "mvtec"
24
+ MVTEC_3D = "mvtec_3d"
25
+ MVTEC_LOCO = "mvtec_loco"
26
+ BTECH = "btech"
27
+ KOLEKTOR = "kolektor"
28
+ FOLDER = "folder"
29
+ FOLDER_3D = "folder_3d"
30
+ VISA = "visa"
31
+
32
+
33
+ __all__ = ["BTech", "Folder", "Kolektor", "MVTec", "MVTecLoco", "Visa"]
anomalib/data/image/btech.py ADDED
@@ -0,0 +1,362 @@
1
+ """BTech Dataset.
2
+
3
+ This script contains PyTorch Lightning DataModule for the BTech dataset.
4
+
5
+ If the dataset is not on the file system, the script downloads and
6
+ extracts the dataset and creates PyTorch data objects.
7
+ """
8
+
9
+ # Copyright (C) 2022-2024 Intel Corporation
10
+ # SPDX-License-Identifier: Apache-2.0
11
+
12
+ import logging
13
+ import shutil
14
+ from pathlib import Path
15
+
16
+ import cv2
17
+ import pandas as pd
18
+ from pandas.core.frame import DataFrame
19
+ from torchvision.transforms.v2 import Transform
20
+ from tqdm import tqdm
21
+
22
+ from anomalib import TaskType
23
+ from anomalib.data.base import AnomalibDataModule, AnomalibDataset
24
+ from anomalib.data.utils import (
25
+ DownloadInfo,
26
+ LabelName,
27
+ Split,
28
+ TestSplitMode,
29
+ ValSplitMode,
30
+ download_and_extract,
31
+ validate_path,
32
+ )
33
+
34
+ logger = logging.getLogger(__name__)
35
+
36
+ DOWNLOAD_INFO = DownloadInfo(
37
+ name="btech",
38
+ url="https://avires.dimi.uniud.it/papers/btad/btad.zip",
39
+ hashsum="461c9387e515bfed41ecaae07c50cf6b10def647b36c9e31d239ab2736b10d2a",
40
+ )
41
+
42
+ CATEGORIES = ("01", "02", "03")
43
+
44
+
45
+ def make_btech_dataset(path: Path, split: str | Split | None = None) -> DataFrame:
46
+ """Create BTech samples by parsing the BTech data file structure.
47
+
48
+ The files are expected to follow the structure:
49
+
50
+ .. code-block:: bash
51
+
52
+ path/to/dataset/split/category/image_filename.png
53
+ path/to/dataset/ground_truth/category/mask_filename.png
54
+
55
+ Args:
56
+ path (Path): Path to dataset
57
+ split (str | Split | None, optional): Dataset split (i.e., either train or test).
58
+ Defaults to ``None``.
59
+
60
+ Example:
61
+ The following example shows how to get training samples from BTech 01 category:
62
+
63
+ .. code-block:: python
64
+
65
+ >>> root = Path('./BTech')
66
+ >>> category = '01'
67
+ >>> path = root / category
68
+ >>> path
69
+ PosixPath('BTech/01')
70
+
71
+ >>> samples = make_btech_dataset(path, split='train')
72
+ >>> samples.head()
73
+ path split label image_path mask_path label_index
74
+ 0 BTech/01 train 01 BTech/01/train/ok/105.bmp BTech/01/ground_truth/ok/105.png 0
75
+ 1 BTech/01 train 01 BTech/01/train/ok/017.bmp BTech/01/ground_truth/ok/017.png 0
76
+ ...
77
+
78
+ Returns:
79
+ DataFrame: An output dataframe containing samples for the requested split (i.e., train or test).
80
+ """
81
+ path = validate_path(path)
82
+
83
+ samples_list = [
84
+ (str(path),) + filename.parts[-3:] for filename in path.glob("**/*") if filename.suffix in (".bmp", ".png")
85
+ ]
86
+ if not samples_list:
87
+ msg = f"Found 0 images in {path}"
88
+ raise RuntimeError(msg)
89
+
90
+ samples = pd.DataFrame(samples_list, columns=["path", "split", "label", "image_path"])
91
+ samples = samples[samples.split != "ground_truth"]
92
+
93
+ # Create mask_path column
94
+ # (safely handles cases where non-mask image_paths end with either .png or .bmp)
95
+ samples["mask_path"] = (
96
+ samples.path
97
+ + "/ground_truth/"
98
+ + samples.label
99
+ + "/"
100
+ + samples.image_path.str.rstrip("png").str.rstrip(".").str.rstrip("bmp").str.rstrip(".")
101
+ + ".png"
102
+ )
103
+
104
+ # Modify image_path column by converting to absolute path
105
+ samples["image_path"] = samples.path + "/" + samples.split + "/" + samples.label + "/" + samples.image_path
106
+
107
+ # Good images don't have mask
108
+ samples.loc[(samples.split == "test") & (samples.label == "ok"), "mask_path"] = ""
109
+
110
+ # Create label index for normal (0) and anomalous (1) images.
111
+ samples.loc[(samples.label == "ok"), "label_index"] = LabelName.NORMAL
112
+ samples.loc[(samples.label != "ok"), "label_index"] = LabelName.ABNORMAL
113
+ samples.label_index = samples.label_index.astype(int)
114
+
115
+ # Get the data frame for the split.
116
+ if split:
117
+ samples = samples[samples.split == split]
118
+ samples = samples.reset_index(drop=True)
119
+
120
+ return samples
121
+
122
+
123
+ class BTechDataset(AnomalibDataset):
124
+ """Btech Dataset class.
125
+
126
+ Args:
127
+ root: Path to the BTech dataset
128
+ category: Name of the BTech category.
129
+ transform (Transform, optional): Transforms that should be applied to the input images.
130
+ Defaults to ``None``.
131
+ split: 'train', 'val' or 'test'
132
+ task: ``classification``, ``detection`` or ``segmentation``
134
+
135
+ Examples:
136
+ >>> from anomalib.data.image.btech import BTechDataset
137
+ >>> from anomalib.data.utils.transforms import get_transforms
138
+ >>> transform = get_transforms(image_size=256)
139
+ >>> dataset = BTechDataset(
140
+ ... task="classification",
141
+ ... transform=transform,
142
+ ... root='./datasets/BTech',
143
+ ... category='01',
144
+ ... )
145
+ >>> dataset.setup()
146
+ >>> dataset[0].keys()
147
+ dict_keys(['image'])
148
+
149
+ >>> dataset.split = "test"
150
+ >>> dataset[0].keys()
151
+ dict_keys(['image', 'image_path', 'label'])
152
+
153
+ >>> dataset.task = "segmentation"
154
+ >>> dataset.split = "train"
155
+ >>> dataset[0].keys()
156
+ dict_keys(['image'])
157
+
158
+ >>> dataset.split = "test"
159
+ >>> dataset[0].keys()
160
+ dict_keys(['image_path', 'label', 'mask_path', 'image', 'mask'])
161
+
162
+ >>> dataset[0]["image"].shape, dataset[0]["mask"].shape
163
+ (torch.Size([3, 256, 256]), torch.Size([256, 256]))
164
+ """
165
+
166
+ def __init__(
167
+ self,
168
+ root: str | Path,
169
+ category: str,
170
+ transform: Transform | None = None,
171
+ split: str | Split | None = None,
172
+ task: TaskType | str = TaskType.SEGMENTATION,
173
+ ) -> None:
174
+ super().__init__(task, transform)
175
+
176
+ self.root_category = Path(root) / category
177
+ self.split = split
178
+ self.samples = make_btech_dataset(path=self.root_category, split=self.split)
179
+
180
+
181
+ class BTech(AnomalibDataModule):
182
+ """BTech Lightning Data Module.
183
+
184
+ Args:
185
+ root (Path | str): Path to the BTech dataset.
186
+ Defaults to ``"./datasets/BTech"``.
187
+ category (str): Name of the BTech category.
188
+ Defaults to ``"01"``.
189
+ train_batch_size (int, optional): Training batch size.
190
+ Defaults to ``32``.
191
+ eval_batch_size (int, optional): Eval batch size.
192
+ Defaults to ``32``.
193
+ num_workers (int, optional): Number of workers.
194
+ Defaults to ``8``.
195
+ task (TaskType, optional): Task type.
196
+ Defaults to ``TaskType.SEGMENTATION``.
197
+ image_size (tuple[int, int], optional): Size to which input images should be resized.
198
+ Defaults to ``None``.
199
+ transform (Transform, optional): Transforms that should be applied to the input images.
200
+ Defaults to ``None``.
201
+ train_transform (Transform, optional): Transforms that should be applied to the input images during training.
202
+ Defaults to ``None``.
203
+ eval_transform (Transform, optional): Transforms that should be applied to the input images during evaluation.
204
+ Defaults to ``None``.
205
+ test_split_mode (TestSplitMode, optional): Setting that determines how the testing subset is obtained.
206
+ Defaults to ``TestSplitMode.FROM_DIR``.
207
+ test_split_ratio (float, optional): Fraction of images from the train set that will be reserved for testing.
208
+ Defaults to ``0.2``.
209
+ val_split_mode (ValSplitMode, optional): Setting that determines how the validation subset is obtained.
210
+ Defaults to ``ValSplitMode.SAME_AS_TEST``.
211
+ val_split_ratio (float, optional): Fraction of train or test images that will be reserved for validation.
212
+ Defaults to ``0.5``.
213
+ seed (int | None, optional): Seed which may be set to a fixed value for reproducibility.
214
+ Defaults to ``None``.
215
+
216
+ Examples:
217
+ To create the BTech datamodule, we need to instantiate the class, and call the ``setup`` method.
218
+
219
+ >>> from anomalib.data import BTech
220
+ >>> datamodule = BTech(
221
+ ... root="./datasets/BTech",
222
+ ... category="01",
223
+ ... image_size=(256, 256),
224
+ ... train_batch_size=32,
225
+ ... eval_batch_size=32,
226
+ ... num_workers=8,
229
+ ... )
230
+ >>> datamodule.setup()
231
+
232
+ To get the train dataloader and the first batch of data:
233
+
234
+ >>> i, data = next(enumerate(datamodule.train_dataloader()))
235
+ >>> data.keys()
236
+ dict_keys(['image'])
237
+ >>> data["image"].shape
238
+ torch.Size([32, 3, 256, 256])
239
+
240
+ To access the validation dataloader and the first batch of data:
241
+
242
+ >>> i, data = next(enumerate(datamodule.val_dataloader()))
243
+ >>> data.keys()
244
+ dict_keys(['image_path', 'label', 'mask_path', 'image', 'mask'])
245
+ >>> data["image"].shape, data["mask"].shape
246
+ (torch.Size([32, 3, 256, 256]), torch.Size([32, 256, 256]))
247
+
248
+ Similarly, to access the test dataloader and the first batch of data:
249
+
250
+ >>> i, data = next(enumerate(datamodule.test_dataloader()))
251
+ >>> data.keys()
252
+ dict_keys(['image_path', 'label', 'mask_path', 'image', 'mask'])
253
+ >>> data["image"].shape, data["mask"].shape
254
+ (torch.Size([32, 3, 256, 256]), torch.Size([32, 256, 256]))
255
+ """
256
+
257
+ def __init__(
258
+ self,
259
+ root: Path | str = "./datasets/BTech",
260
+ category: str = "01",
261
+ train_batch_size: int = 32,
262
+ eval_batch_size: int = 32,
263
+ num_workers: int = 8,
264
+ task: TaskType | str = TaskType.SEGMENTATION,
265
+ image_size: tuple[int, int] | None = None,
266
+ transform: Transform | None = None,
267
+ train_transform: Transform | None = None,
268
+ eval_transform: Transform | None = None,
269
+ test_split_mode: TestSplitMode | str = TestSplitMode.FROM_DIR,
270
+ test_split_ratio: float = 0.2,
271
+ val_split_mode: ValSplitMode | str = ValSplitMode.SAME_AS_TEST,
272
+ val_split_ratio: float = 0.5,
273
+ seed: int | None = None,
274
+ ) -> None:
275
+ super().__init__(
276
+ train_batch_size=train_batch_size,
277
+ eval_batch_size=eval_batch_size,
278
+ num_workers=num_workers,
279
+ image_size=image_size,
280
+ transform=transform,
281
+ train_transform=train_transform,
282
+ eval_transform=eval_transform,
283
+ test_split_mode=test_split_mode,
284
+ test_split_ratio=test_split_ratio,
285
+ val_split_mode=val_split_mode,
286
+ val_split_ratio=val_split_ratio,
287
+ seed=seed,
288
+ )
289
+
290
+ self.root = Path(root)
291
+ self.category = category
292
+ self.task = TaskType(task)
293
+
294
+ def _setup(self, _stage: str | None = None) -> None:
295
+ self.train_data = BTechDataset(
296
+ task=self.task,
297
+ transform=self.train_transform,
298
+ split=Split.TRAIN,
299
+ root=self.root,
300
+ category=self.category,
301
+ )
302
+ self.test_data = BTechDataset(
303
+ task=self.task,
304
+ transform=self.eval_transform,
305
+ split=Split.TEST,
306
+ root=self.root,
307
+ category=self.category,
308
+ )
309
+
310
+ def prepare_data(self) -> None:
311
+ """Download the dataset if not available.
312
+
313
+ This method checks if the specified dataset is available in the file system.
314
+ If not, it downloads and extracts the dataset into the appropriate directory.
315
+
316
+ Example:
317
+ Assume the dataset is not available on the file system.
318
+ Here's how the directory structure looks before and after calling the
319
+ `prepare_data` method:
320
+
321
+ Before:
322
+
323
+ .. code-block:: bash
324
+
325
+ $ tree datasets
326
+ datasets
327
+ ├── dataset1
328
+ └── dataset2
329
+
330
+ Calling the method:
331
+
332
+ .. code-block:: python
333
+
334
+ >> datamodule = BTech(root="./datasets/BTech", category="01")
335
+ >> datamodule.prepare_data()
336
+
337
+ After:
338
+
339
+ .. code-block:: bash
340
+
341
+ $ tree datasets
342
+ datasets
343
+ ├── dataset1
344
+ ├── dataset2
345
+ └── BTech
346
+ ├── 01
347
+ ├── 02
348
+ └── 03
349
+ """
350
+ if (self.root / self.category).is_dir():
351
+ logger.info("Found the dataset.")
352
+ else:
353
+ download_and_extract(self.root.parent, DOWNLOAD_INFO)
354
+
355
+ # rename folder and convert images
356
+ logger.info("Renaming the dataset directory")
357
+ shutil.move(src=str(self.root.parent / "BTech_Dataset_transformed"), dst=str(self.root))
358
+ logger.info("Convert the bmp formats to png to have consistent image extensions")
359
+ for filename in tqdm(self.root.glob("**/*.bmp"), desc="Converting bmp to png"):
360
+ image = cv2.imread(str(filename))
361
+ cv2.imwrite(str(filename.with_suffix(".png")), image)
362
+ filename.unlink()
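
The mask_path construction in make_btech_dataset relies on pandas ``str.rstrip``, which strips a trailing character set rather than a literal suffix. A standalone sketch of why the chained calls recover the stem for both extensions; note that this only works because BTech file names are numeric (a name such as 'lamp.png' would be over-stripped):

    import pandas as pd

    # rstrip removes trailing characters drawn from the given set, not a suffix,
    # so alternating set-strips and dot-strips peel off either ".png" or ".bmp".
    names = pd.Series(["105.bmp", "017.png"])
    stems = names.str.rstrip("png").str.rstrip(".").str.rstrip("bmp").str.rstrip(".")
    print(stems.tolist())  # ['105', '017']
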
anomalib/data/image/folder.py ADDED
@@ -0,0 +1,478 @@
1
+ """Custom Folder Dataset.
2
+
3
+ This script creates a custom dataset from a folder.
4
+ """
5
+
6
+ # Copyright (C) 2022-2024 Intel Corporation
7
+ # SPDX-License-Identifier: Apache-2.0
8
+
9
+ from collections.abc import Sequence
10
+ from pathlib import Path
11
+
12
+ from pandas import DataFrame
13
+ from torchvision.transforms.v2 import Transform
14
+
15
+ from anomalib import TaskType
16
+ from anomalib.data.base import AnomalibDataModule, AnomalibDataset
17
+ from anomalib.data.errors import MisMatchError
18
+ from anomalib.data.utils import (
19
+ DirType,
20
+ LabelName,
21
+ Split,
22
+ TestSplitMode,
23
+ ValSplitMode,
24
+ )
25
+ from anomalib.data.utils.path import _prepare_files_labels, validate_and_resolve_path
26
+
27
+
28
+ def make_folder_dataset(
29
+ normal_dir: str | Path | Sequence[str | Path],
30
+ root: str | Path | None = None,
31
+ abnormal_dir: str | Path | Sequence[str | Path] | None = None,
32
+ normal_test_dir: str | Path | Sequence[str | Path] | None = None,
33
+ mask_dir: str | Path | Sequence[str | Path] | None = None,
34
+ split: str | Split | None = None,
35
+ extensions: tuple[str, ...] | None = None,
36
+ ) -> DataFrame:
37
+ """Make Folder Dataset.
38
+
39
+ Args:
40
+ normal_dir (str | Path | Sequence): Path to the directory containing normal images.
41
+ root (str | Path | None): Path to the root directory of the dataset.
42
+ Defaults to ``None``.
43
+ abnormal_dir (str | Path | Sequence | None, optional): Path to the directory containing abnormal images.
44
+ Defaults to ``None``.
45
+ normal_test_dir (str | Path | Sequence | None, optional): Path to the directory containing normal images for
46
+ the test dataset. Normal test images will be a split of `normal_dir` if `None`.
47
+ Defaults to ``None``.
48
+ mask_dir (str | Path | Sequence | None, optional): Path to the directory containing the mask annotations.
49
+ Defaults to ``None``.
50
+ split (str | Split | None, optional): Dataset split (i.e., Split.FULL, Split.TRAIN or Split.TEST).
51
+ Defaults to ``None``.
52
+ extensions (tuple[str, ...] | None, optional): Type of the image extensions to read from the directory.
53
+ Defaults to ``None``.
54
+
55
+ Returns:
56
+ DataFrame: An output dataframe containing samples for the requested split (i.e., train or test).
57
+
58
+ Examples:
59
+ Assume that we would like to use this ``make_folder_dataset`` to create a dataset from a folder.
60
+ We could then create the dataset as follows,
61
+
62
+ .. code-block:: python
63
+
64
+ folder_df = make_folder_dataset(
65
+ normal_dir=dataset_root / "good",
66
+ abnormal_dir=dataset_root / "crack",
67
+ split="train",
68
+ )
69
+ folder_df.head()
70
+
71
+ .. code-block:: bash
72
+
73
+ image_path label label_index mask_path split
74
+ 0 ./toy/good/00.jpg DirType.NORMAL 0 Split.TRAIN
75
+ 1 ./toy/good/01.jpg DirType.NORMAL 0 Split.TRAIN
76
+ 2 ./toy/good/02.jpg DirType.NORMAL 0 Split.TRAIN
77
+ 3 ./toy/good/03.jpg DirType.NORMAL 0 Split.TRAIN
78
+ 4 ./toy/good/04.jpg DirType.NORMAL 0 Split.TRAIN
79
+ """
80
+
81
+ def _resolve_path_and_convert_to_list(path: str | Path | Sequence[str | Path] | None) -> list[Path]:
82
+ """Convert path to list of paths.
83
+
84
+ Args:
85
+ path (str | Path | Sequence | None): Path to replace with Sequence[str | Path].
86
+
87
+ Examples:
88
+ >>> _resolve_path_and_convert_to_list("dir")
89
+ [Path("path/to/dir")]
90
+ >>> _resolve_path_and_convert_to_list(["dir1", "dir2"])
91
+ [Path("path/to/dir1"), Path("path/to/dir2")]
92
+
93
+ Returns:
94
+ list[Path]: The result of path replaced by Sequence[str | Path].
95
+ """
96
+ if isinstance(path, Sequence) and not isinstance(path, str):
97
+ return [validate_and_resolve_path(dir_path, root) for dir_path in path]
98
+ return [validate_and_resolve_path(path, root)] if path is not None else []
99
+
100
+ # All paths are changed to the List[Path] type and used.
101
+ normal_dir = _resolve_path_and_convert_to_list(normal_dir)
102
+ abnormal_dir = _resolve_path_and_convert_to_list(abnormal_dir)
103
+ normal_test_dir = _resolve_path_and_convert_to_list(normal_test_dir)
104
+ mask_dir = _resolve_path_and_convert_to_list(mask_dir)
105
+ if len(normal_dir) == 0:
106
+ msg = "A folder location must be provided in normal_dir."
107
+ raise ValueError(msg)
108
+
109
+ filenames = []
110
+ labels = []
111
+ dirs = {DirType.NORMAL: normal_dir}
112
+
113
+ if abnormal_dir:
114
+ dirs[DirType.ABNORMAL] = abnormal_dir
115
+
116
+ if normal_test_dir:
117
+ dirs[DirType.NORMAL_TEST] = normal_test_dir
118
+
119
+ if mask_dir:
120
+ dirs[DirType.MASK] = mask_dir
121
+
122
+ for dir_type, paths in dirs.items():
123
+ for path in paths:
124
+ filename, label = _prepare_files_labels(path, dir_type, extensions)
125
+ filenames += filename
126
+ labels += label
127
+
128
+ samples = DataFrame({"image_path": filenames, "label": labels})
129
+ samples = samples.sort_values(by="image_path", ignore_index=True)
130
+
131
+ # Create label index for normal (0) and abnormal (1) images.
132
+ samples.loc[
133
+ (samples.label == DirType.NORMAL) | (samples.label == DirType.NORMAL_TEST),
134
+ "label_index",
135
+ ] = LabelName.NORMAL
136
+ samples.loc[(samples.label == DirType.ABNORMAL), "label_index"] = LabelName.ABNORMAL
137
+ samples.label_index = samples.label_index.astype("Int64")
138
+
139
+ # If a path to mask is provided, add it to the sample dataframe.
140
+
141
+ if len(mask_dir) > 0 and len(abnormal_dir) > 0:
142
+ samples.loc[samples.label == DirType.ABNORMAL, "mask_path"] = samples.loc[
143
+ samples.label == DirType.MASK
144
+ ].image_path.to_numpy()
145
+ samples["mask_path"] = samples["mask_path"].fillna("")
146
+ samples = samples.astype({"mask_path": "str"})
147
+
148
+ # make sure every anomalous image has a corresponding mask image.
149
+ if not (
150
+ samples.loc[samples.label_index == LabelName.ABNORMAL]
151
+ .apply(lambda x: Path(x.image_path).stem in Path(x.mask_path).stem, axis=1)
152
+ .all()
153
+ ):
154
+ msg = """Mismatch between anomalous images and mask images. Make sure the mask files "
155
+ "folder follow the same naming convention as the anomalous images in the dataset "
156
+ "(e.g. image: '000.png', mask: '000.png')."""
157
+ raise MisMatchError(msg)
158
+
159
+ else:
160
+ samples["mask_path"] = ""
161
+
162
+ # remove the mask rows, since their paths have already been assigned to the anomalous samples above
163
+ samples = samples.loc[
164
+ (samples.label == DirType.NORMAL) | (samples.label == DirType.ABNORMAL) | (samples.label == DirType.NORMAL_TEST)
165
+ ]
166
+
167
+ # Ensure the pathlib objects are converted to str.
168
+ # This is because torch dataloader doesn't like pathlib.
169
+ samples = samples.astype({"image_path": "str"})
170
+
171
+ # Create train/test split.
172
+ # By default, all the normal samples are assigned as train.
173
+ # and all the abnormal samples are test.
174
+ samples.loc[(samples.label == DirType.NORMAL), "split"] = Split.TRAIN
175
+ samples.loc[(samples.label == DirType.ABNORMAL) | (samples.label == DirType.NORMAL_TEST), "split"] = Split.TEST
176
+
177
+ # Get the data frame for the split.
178
+ if split:
179
+ samples = samples[samples.split == split]
180
+ samples = samples.reset_index(drop=True)
181
+
182
+ return samples
183
+
184
+
185
+ class FolderDataset(AnomalibDataset):
186
+ """Folder dataset.
187
+
188
+ This class is used to create a dataset from a folder. The class utilizes the Torch Dataset class.
189
+
190
+ Args:
191
+ name (str): Name of the dataset. This is used to name the datamodule, especially when logging/saving.
192
+ task (TaskType): Task type. (``classification``, ``detection`` or ``segmentation``).
193
+ transform (Transform, optional): Transforms that should be applied to the input images.
194
+ Defaults to ``None``.
195
+ normal_dir (str | Path | Sequence): Path to the directory containing normal images.
196
+ root (str | Path | None): Root folder of the dataset.
197
+ Defaults to ``None``.
198
+ abnormal_dir (str | Path | Sequence | None, optional): Path to the directory containing abnormal images.
199
+ Defaults to ``None``.
200
+ normal_test_dir (str | Path | Sequence | None, optional): Path to the directory containing
201
+ normal images for the test dataset.
202
+ Defaults to ``None``.
203
+ mask_dir (str | Path | Sequence | None, optional): Path to the directory containing
204
+ the mask annotations.
205
+ Defaults to ``None``.
206
+ split (str | Split | None): Fixed subset split that follows from folder structure on file system.
207
+ Choose from [Split.FULL, Split.TRAIN, Split.TEST]
208
+ Defaults to ``None``.
209
+ extensions (tuple[str, ...] | None, optional): Type of the image extensions to read from the directory.
210
+ Defaults to ``None``.
211
+
212
+ Raises:
213
+ ValueError: When task is set to classification and `mask_dir` is provided. When `mask_dir` is
214
+ provided, `task` should be set to `segmentation`.
215
+
216
+ Examples:
217
+ Assume that we would like to use this ``FolderDataset`` to create a dataset from a folder for a classification
218
+ task. We could first create the transforms,
219
+
220
+ >>> from anomalib.data.utils import InputNormalizationMethod, get_transforms
221
+ >>> transform = get_transforms(image_size=256, normalization=InputNormalizationMethod.NONE)
222
+
223
+ We could then create the dataset as follows,
224
+
225
+ .. code-block:: python
226
+
227
+ folder_dataset_classification_train = FolderDataset(
228
+ normal_dir=dataset_root / "good",
229
+ abnormal_dir=dataset_root / "crack",
230
+ split="train",
231
+ transform=transform,
232
+ task=TaskType.CLASSIFICATION,
233
+ )
234
+
235
+ """
236
+
237
+ def __init__(
238
+ self,
239
+ name: str,
240
+ task: TaskType,
241
+ normal_dir: str | Path | Sequence[str | Path],
242
+ transform: Transform | None = None,
243
+ root: str | Path | None = None,
244
+ abnormal_dir: str | Path | Sequence[str | Path] | None = None,
245
+ normal_test_dir: str | Path | Sequence[str | Path] | None = None,
246
+ mask_dir: str | Path | Sequence[str | Path] | None = None,
247
+ split: str | Split | None = None,
248
+ extensions: tuple[str, ...] | None = None,
249
+ ) -> None:
250
+ super().__init__(task, transform)
251
+
252
+ self._name = name
253
+ self.split = split
254
+ self.root = root
255
+ self.normal_dir = normal_dir
256
+ self.abnormal_dir = abnormal_dir
257
+ self.normal_test_dir = normal_test_dir
258
+ self.mask_dir = mask_dir
259
+ self.extensions = extensions
260
+
261
+ self.samples = make_folder_dataset(
262
+ root=self.root,
263
+ normal_dir=self.normal_dir,
264
+ abnormal_dir=self.abnormal_dir,
265
+ normal_test_dir=self.normal_test_dir,
266
+ mask_dir=self.mask_dir,
267
+ split=self.split,
268
+ extensions=self.extensions,
269
+ )
270
+
271
+ @property
272
+ def name(self) -> str:
273
+ """Name of the dataset.
274
+
275
+ Folder dataset overrides the name property to provide a custom name.
276
+ """
277
+ return self._name
278
+
279
+
280
+ class Folder(AnomalibDataModule):
281
+ """Folder DataModule.
282
+
283
+ Args:
284
+ name (str): Name of the dataset. This is used to name the datamodule, especially when logging/saving.
285
+ normal_dir (str | Path | Sequence): Name of the directory containing normal images.
286
+ root (str | Path | None): Path to the root folder containing normal and abnormal dirs.
287
+ Defaults to ``None``.
288
+ abnormal_dir (str | Path | None | Sequence): Name of the directory containing abnormal images.
289
+ Defaults to ``None``.
290
+ normal_test_dir (str | Path | Sequence | None, optional): Path to the directory containing
291
+ normal images for the test dataset.
292
+ Defaults to ``None``.
293
+ mask_dir (str | Path | Sequence | None, optional): Path to the directory containing
294
+ the mask annotations.
295
+ Defaults to ``None``.
296
+ normal_split_ratio (float, optional): Ratio to split normal training images and add to the
297
+ test set in case test set doesn't contain any normal images.
298
+ Defaults to 0.2.
299
+ extensions (tuple[str, ...] | None, optional): Type of the image extensions to read from the
300
+ directory.
301
+ Defaults to ``None``.
302
+ train_batch_size (int, optional): Training batch size.
303
+ Defaults to ``32``.
304
+ eval_batch_size (int, optional): Validation, test and predict batch size.
305
+ Defaults to ``32``.
306
+ num_workers (int, optional): Number of workers.
307
+ Defaults to ``8``.
308
+ task (TaskType, optional): Task type. Could be ``classification``, ``detection`` or ``segmentation``.
309
+ Defaults to ``segmentation``.
310
+ image_size (tuple[int, int], optional): Size to which input images should be resized.
311
+ Defaults to ``None``.
312
+ transform (Transform, optional): Transforms that should be applied to the input images.
313
+ Defaults to ``None``.
314
+ train_transform (Transform, optional): Transforms that should be applied to the input images during training.
315
+ Defaults to ``None``.
316
+ eval_transform (Transform, optional): Transforms that should be applied to the input images during evaluation.
317
+ Defaults to ``None``.
318
+ test_split_mode (TestSplitMode): Setting that determines how the testing subset is obtained.
319
+ Defaults to ``TestSplitMode.FROM_DIR``.
320
+ test_split_ratio (float): Fraction of images from the train set that will be reserved for testing.
321
+ Defaults to ``0.2``.
322
+ val_split_mode (ValSplitMode): Setting that determines how the validation subset is obtained.
323
+ Defaults to ``ValSplitMode.FROM_TEST``.
324
+ val_split_ratio (float): Fraction of train or test images that will be reserved for validation.
325
+ Defaults to ``0.5``.
326
+ seed (int | None, optional): Seed used during random subset splitting.
327
+ Defaults to ``None``.
328
+
329
+ Examples:
330
+ The following code demonstrates how to use the ``Folder`` datamodule. Assume that the dataset is structured
331
+ as follows:
332
+
333
+ .. code-block:: bash
334
+
335
+ $ tree sample_dataset
336
+ sample_dataset
337
+ ├── colour
338
+ │ ├── 00.jpg
339
+ │ ├── ...
340
+ │ └── x.jpg
341
+ ├── crack
342
+ │ ├── 00.jpg
343
+ │ ├── ...
344
+ │ └── y.jpg
345
+ ├── good
346
+ │ ├── ...
347
+ │ └── z.jpg
348
+ ├── LICENSE
349
+ └── mask
350
+ ├── colour
351
+ │ ├── ...
352
+ │ └── x.jpg
353
+ └── crack
354
+ ├── ...
355
+ └── y.jpg
356
+
357
+ .. code-block:: python
358
+
359
+ folder_datamodule = Folder(
360
+ name="sample_dataset",
361
+ root=dataset_root,
362
+ normal_dir="good",
363
+ abnormal_dir="crack",
364
+ task=TaskType.SEGMENTATION,
365
+ mask_dir=dataset_root / "mask" / "crack",
366
+ image_size=(256, 256),
367
+ )
368
+ folder_datamodule.setup()
369
+
370
+ To access the training images,
371
+
372
+ .. code-block:: python
373
+
374
+ >> i, data = next(enumerate(folder_datamodule.train_dataloader()))
375
+ >> print(data.keys(), data["image"].shape)
376
+
377
+ To access the test images,
378
+
379
+ .. code-block:: python
380
+
381
+ >> i, data = next(enumerate(folder_datamodule.test_dataloader()))
382
+ >> print(data.keys(), data["image"].shape)
383
+ """
384
+
385
+ def __init__(
386
+ self,
387
+ name: str,
388
+ normal_dir: str | Path | Sequence[str | Path],
389
+ root: str | Path | None = None,
390
+ abnormal_dir: str | Path | Sequence[str | Path] | None = None,
391
+ normal_test_dir: str | Path | Sequence[str | Path] | None = None,
392
+ mask_dir: str | Path | Sequence[str | Path] | None = None,
393
+ normal_split_ratio: float = 0.2,
394
+ extensions: tuple[str, ...] | None = None,
395
+ train_batch_size: int = 32,
396
+ eval_batch_size: int = 32,
397
+ num_workers: int = 8,
398
+ task: TaskType | str = TaskType.SEGMENTATION,
399
+ image_size: tuple[int, int] | None = None,
400
+ transform: Transform | None = None,
401
+ train_transform: Transform | None = None,
402
+ eval_transform: Transform | None = None,
403
+ test_split_mode: TestSplitMode | str = TestSplitMode.FROM_DIR,
404
+ test_split_ratio: float = 0.2,
405
+ val_split_mode: ValSplitMode | str = ValSplitMode.FROM_TEST,
406
+ val_split_ratio: float = 0.5,
407
+ seed: int | None = None,
408
+ ) -> None:
409
+ self._name = name
410
+ self.root = root
411
+ self.normal_dir = normal_dir
412
+ self.abnormal_dir = abnormal_dir
413
+ self.normal_test_dir = normal_test_dir
414
+ self.mask_dir = mask_dir
415
+ self.task = TaskType(task)
416
+ self.extensions = extensions
417
+ test_split_mode = TestSplitMode(test_split_mode)
418
+ val_split_mode = ValSplitMode(val_split_mode)
419
+ super().__init__(
420
+ train_batch_size=train_batch_size,
421
+ eval_batch_size=eval_batch_size,
422
+ num_workers=num_workers,
423
+ test_split_mode=test_split_mode,
424
+ test_split_ratio=test_split_ratio,
425
+ val_split_mode=val_split_mode,
426
+ val_split_ratio=val_split_ratio,
427
+ image_size=image_size,
428
+ transform=transform,
429
+ train_transform=train_transform,
430
+ eval_transform=eval_transform,
431
+ seed=seed,
432
+ )
433
+
434
+ if task == TaskType.SEGMENTATION and test_split_mode == TestSplitMode.FROM_DIR and mask_dir is None:
435
+ msg = (
436
+ f"Segmentation task requires mask directory if test_split_mode is {test_split_mode}. "
437
+ "You could set test_split_mode to {TestSplitMode.NONE} or provide a mask directory."
438
+ )
439
+ raise ValueError(msg)
442
+
443
+ self.normal_split_ratio = normal_split_ratio
444
+
445
+ def _setup(self, _stage: str | None = None) -> None:
446
+ self.train_data = FolderDataset(
447
+ name=self.name,
448
+ task=self.task,
449
+ transform=self.train_transform,
450
+ split=Split.TRAIN,
451
+ root=self.root,
452
+ normal_dir=self.normal_dir,
453
+ abnormal_dir=self.abnormal_dir,
454
+ normal_test_dir=self.normal_test_dir,
455
+ mask_dir=self.mask_dir,
456
+ extensions=self.extensions,
457
+ )
458
+
459
+ self.test_data = FolderDataset(
460
+ name=self.name,
461
+ task=self.task,
462
+ transform=self.eval_transform,
463
+ split=Split.TEST,
464
+ root=self.root,
465
+ normal_dir=self.normal_dir,
466
+ abnormal_dir=self.abnormal_dir,
467
+ normal_test_dir=self.normal_test_dir,
468
+ mask_dir=self.mask_dir,
469
+ extensions=self.extensions,
470
+ )
471
+
472
+ @property
473
+ def name(self) -> str:
474
+ """Name of the datamodule.
475
+
476
+ Folder datamodule overrides the name property to provide a custom name.
477
+ """
478
+ return self._name
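
A minimal usage sketch for the Folder datamodule above, mirroring the sample_dataset layout from the docstring. The paths are hypothetical, and mask_dir is assumed to resolve against root the same way the image directories do in make_folder_dataset:

    from anomalib import TaskType
    from anomalib.data.image.folder import Folder

    datamodule = Folder(
        name="sample_dataset",
        root="./sample_dataset",
        normal_dir="good",
        abnormal_dir="crack",
        mask_dir="mask/crack",
        task=TaskType.SEGMENTATION,
    )
    datamodule.setup()
    print(len(datamodule.train_data), len(datamodule.test_data))
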
anomalib/data/image/kolektor.py ADDED
@@ -0,0 +1,342 @@
1
+ """Kolektor Surface-Defect Dataset (CC BY-NC-SA 4.0).
2
+
3
+ Description:
4
+ This script provides a PyTorch Dataset, DataLoader, and PyTorch Lightning DataModule for the Kolektor
5
+ Surface-Defect dataset. The dataset can be accessed at `Kolektor Surface-Defect Dataset <https://www.vicos.si/resources/kolektorsdd/>`_.
6
+
7
+ License:
8
+ The Kolektor Surface-Defect dataset is released under the Creative Commons Attribution-NonCommercial-ShareAlike
9
+ 4.0 International License (CC BY-NC-SA 4.0). For more details, visit
10
+ `Creative Commons License <https://creativecommons.org/licenses/by-nc-sa/4.0/>`_.
11
+
12
+ Reference:
13
+ Tabernik, Domen, Samo Šela, Jure Skvarč, and Danijel Skočaj. "Segmentation-based deep-learning approach
14
+ for surface-defect detection." Journal of Intelligent Manufacturing 31, no. 3 (2020): 759-776.
15
+ """
16
+
17
+ # Copyright (C) 2023-2024 Intel Corporation
18
+ # SPDX-License-Identifier: Apache-2.0
19
+
20
+ import logging
21
+ from pathlib import Path
22
+
23
+ import numpy as np
24
+ from cv2 import imread
25
+ from pandas import DataFrame
26
+ from sklearn.model_selection import train_test_split
27
+ from torchvision.transforms.v2 import Transform
28
+
29
+ from anomalib import TaskType
30
+ from anomalib.data.base import AnomalibDataModule, AnomalibDataset
31
+ from anomalib.data.errors import MisMatchError
32
+ from anomalib.data.utils import (
33
+ DownloadInfo,
34
+ Split,
35
+ TestSplitMode,
36
+ ValSplitMode,
37
+ download_and_extract,
38
+ validate_path,
39
+ )
40
+
41
+ __all__ = ["Kolektor", "KolektorDataset", "make_kolektor_dataset"]
42
+
43
+ logger = logging.getLogger(__name__)
44
+
45
+ DOWNLOAD_INFO = DownloadInfo(
46
+ name="kolektor",
47
+ url="https://go.vicos.si/kolektorsdd",
48
+ hashsum="65dc621693418585de9c4467d1340ea7958a6181816f0dc2883a1e8b61f9d4dc",
49
+ filename="KolektorSDD.zip",
50
+ )
51
+
52
+
53
+ def is_mask_anomalous(path: str) -> int:
54
+ """Check if a mask shows defects.
55
+
56
+ Args:
57
+ path (str): Path to the mask file.
58
+
59
+ Returns:
60
+ int: 1 if the mask shows defects, 0 otherwise.
61
+
62
+ Example:
63
+ Assume that the following image is a mask for a defective image.
64
+ Then the function will return 1.
65
+
66
+ >>> from anomalib.data.image.kolektor import is_mask_anomalous
67
+ >>> path = './KolektorSDD/kos01/Part0_label.bmp'
68
+ >>> is_mask_anomalous(path)
69
+ 1
70
+ """
71
+ img_arr = imread(path)
72
+ if np.all(img_arr == 0):
73
+ return 0
74
+ return 1
75
+
76
+
77
+ def make_kolektor_dataset(
78
+ root: str | Path,
79
+ train_split_ratio: float = 0.8,
80
+ split: str | Split | None = None,
81
+ ) -> DataFrame:
82
+ """Create Kolektor samples by parsing the Kolektor data file structure.
83
+
84
+ The files are expected to follow this structure:
85
+ - Image files: `path/to/dataset/item/image_filename.jpg`, `path/to/dataset/kos01/Part0.jpg`
86
+ - Mask files: `path/to/dataset/item/mask_filename.bmp`, `path/to/dataset/kos01/Part0_label.bmp`
87
+
88
+ This function creates a DataFrame to store the parsed information in the following format:
89
+
90
+ +---+-------------------+--------+-------+---------+-----------------------+------------------------+-------------+
91
+ | | path | item | split | label | image_path | mask_path | label_index |
92
+ +---+-------------------+--------+-------+---------+-----------------------+------------------------+-------------+
93
+ | 0 | KolektorSDD | kos01 | test | Bad | /path/to/image_file | /path/to/mask_file | 1 |
94
+ +---+-------------------+--------+-------+---------+-----------------------+------------------------+-------------+
95
+
96
+ Args:
97
+ root (Path): Path to the dataset.
98
+ train_split_ratio (float, optional): Ratio for splitting good images into train/test sets.
99
+ Defaults to ``0.8``.
100
+ split (str | Split | None, optional): Dataset split (either 'train' or 'test').
101
+ Defaults to ``None``.
102
+
103
+ Returns:
104
+ pandas.DataFrame: An output DataFrame containing the samples of the dataset.
105
+
106
+ Example:
107
+ The following example shows how to get training samples from the Kolektor Dataset:
108
+
109
+ >>> from pathlib import Path
110
+ >>> root = Path('./KolektorSDD/')
111
+ >>> samples = create_kolektor_samples(root, train_split_ratio=0.8)
112
+ >>> samples.head()
113
+ path item split label image_path mask_path label_index
114
+ 0 KolektorSDD kos01 train Good KolektorSDD/kos01/Part0.jpg KolektorSDD/kos01/Part0_label.bmp 0
115
+ 1 KolektorSDD kos01 train Good KolektorSDD/kos01/Part1.jpg KolektorSDD/kos01/Part1_label.bmp 0
116
+ 2 KolektorSDD kos01 train Good KolektorSDD/kos01/Part2.jpg KolektorSDD/kos01/Part2_label.bmp 0
117
+ 3 KolektorSDD kos01 test Good KolektorSDD/kos01/Part3.jpg KolektorSDD/kos01/Part3_label.bmp 0
118
+ 4 KolektorSDD kos01 train Good KolektorSDD/kos01/Part4.jpg KolektorSDD/kos01/Part4_label.bmp 0
119
+ """
120
+ root = validate_path(root)
121
+
122
+ # Get list of images and masks
123
+ samples_list = [(str(root),) + f.parts[-2:] for f in root.glob(r"**/*") if f.suffix == ".jpg"]
124
+ masks_list = [(str(root),) + f.parts[-2:] for f in root.glob(r"**/*") if f.suffix == ".bmp"]
125
+
126
+ if not samples_list:
127
+ msg = f"Found 0 images in {root}"
128
+ raise RuntimeError(msg)
129
+
130
+ # Create dataframes
131
+ samples = DataFrame(samples_list, columns=["path", "item", "image_path"])
132
+ masks = DataFrame(masks_list, columns=["path", "item", "image_path"])
133
+
134
+ # Modify image_path column by converting to absolute path
135
+ samples["image_path"] = samples.path + "/" + samples.item + "/" + samples.image_path
136
+ masks["image_path"] = masks.path + "/" + masks.item + "/" + masks.image_path
137
+
138
+ # Sort samples by image path
139
+ samples = samples.sort_values(by="image_path", ignore_index=True)
140
+ masks = masks.sort_values(by="image_path", ignore_index=True)
141
+
142
+ # Add mask paths for sample images
143
+ samples["mask_path"] = masks.image_path.to_numpy()
144
+
145
+ # Use is_good func to configure the label_index
146
+ samples["label_index"] = samples["mask_path"].apply(is_mask_anomalous)
147
+ samples.label_index = samples.label_index.astype(int)
148
+
149
+ # Use label indexes to label data
150
+ samples.loc[(samples.label_index == 0), "label"] = "Good"
151
+ samples.loc[(samples.label_index == 1), "label"] = "Bad"
152
+
153
+ # Add all 'Bad' samples to test set
154
+ samples.loc[(samples.label == "Bad"), "split"] = "test"
155
+
156
+ # Divide 'good' images to train/test on 0.8/0.2 ratio
157
+ train_samples, test_samples = train_test_split(
158
+ samples[samples.label == "Good"],
159
+ train_size=train_split_ratio,
160
+ random_state=42,
161
+ )
162
+ samples.loc[train_samples.index, "split"] = "train"
163
+ samples.loc[test_samples.index, "split"] = "test"
164
+
165
+ # Reorder columns
166
+ samples = samples[["path", "item", "split", "label", "image_path", "mask_path", "label_index"]]
167
+
168
+ # assert that the right mask files are associated with the right test images
169
+ if not (
170
+ samples.loc[samples.label_index == 1]
171
+ .apply(lambda x: Path(x.image_path).stem in Path(x.mask_path).stem, axis=1)
172
+ .all()
173
+ ):
174
+ msg = """Mismatch between anomalous images and ground truth masks. Make sure the mask files
175
+ follow the same naming convention as the anomalous images in the dataset
176
+ (e.g. image: 'Part0.jpg', mask: 'Part0_label.bmp')."""
177
+ raise MisMatchError(msg)
178
+
179
+ # Get the dataframe for the required split
180
+ if split:
181
+ samples = samples[samples.split == split].reset_index(drop=True)
182
+
183
+ return samples
184
+
185
+
186
+ class KolektorDataset(AnomalibDataset):
187
+ """Kolektor dataset class.
188
+
189
+ Args:
190
+ task (TaskType): Task type, ``classification``, ``detection`` or ``segmentation``
191
+ root (Path | str): Path to the root of the dataset
192
+ Defaults to ``./datasets/kolektor``.
193
+ transform (Transform, optional): Transforms that should be applied to the input images.
194
+ Defaults to ``None``.
195
+ split (str | Split | None): Split of the dataset, usually Split.TRAIN or Split.TEST
196
+ Defaults to ``None``.
197
+ """
198
+
199
+ def __init__(
200
+ self,
201
+ task: TaskType,
202
+ root: Path | str = "./datasets/kolektor",
203
+ transform: Transform | None = None,
204
+ split: str | Split | None = None,
205
+ ) -> None:
206
+ super().__init__(task=task, transform=transform)
207
+
208
+ self.root = root
209
+ self.split = split
210
+ self.samples = make_kolektor_dataset(self.root, train_split_ratio=0.8, split=self.split)
211
+
212
+
213
+ class Kolektor(AnomalibDataModule):
214
+ """Kolektor Datamodule.
215
+
216
+ Args:
217
+ root (Path | str): Path to the root of the dataset
218
+ train_batch_size (int, optional): Training batch size.
219
+ Defaults to ``32``.
220
+ eval_batch_size (int, optional): Test batch size.
221
+ Defaults to ``32``.
222
+ num_workers (int, optional): Number of workers.
223
+ Defaults to ``8``.
224
+ task TaskType): Task type, 'classification', 'detection' or 'segmentation'
225
+ Defaults to ``TaskType.SEGMENTATION``.
226
+ image_size (tuple[int, int], optional): Size to which input images should be resized.
227
+ Defaults to ``None``.
228
+ transform (Transform, optional): Transforms that should be applied to the input images.
229
+ Defaults to ``None``.
230
+ train_transform (Transform, optional): Transforms that should be applied to the input images during training.
231
+ Defaults to ``None``.
232
+ eval_transform (Transform, optional): Transforms that should be applied to the input images during evaluation.
233
+ Defaults to ``None``.
234
+ test_split_mode (TestSplitMode): Setting that determines how the testing subset is obtained.
235
+ Defaults to ``TestSplitMode.FROM_DIR``
236
+ test_split_ratio (float): Fraction of images from the train set that will be reserved for testing.
237
+ Defaults to ``0.2``
238
+ val_split_mode (ValSplitMode): Setting that determines how the validation subset is obtained.
239
+ Defaults to ``ValSplitMode.SAME_AS_TEST``
240
+ val_split_ratio (float): Fraction of train or test images that will be reserved for validation.
241
+ Defaults to ``0.5``
242
+ seed (int | None, optional): Seed which may be set to a fixed value for reproducibility.
243
+ Defaults to ``None``.
244
+ """
245
+
246
+ def __init__(
247
+ self,
248
+ root: Path | str = "./datasets/kolektor",
249
+ train_batch_size: int = 32,
250
+ eval_batch_size: int = 32,
251
+ num_workers: int = 8,
252
+ task: TaskType | str = TaskType.SEGMENTATION,
253
+ image_size: tuple[int, int] | None = None,
254
+ transform: Transform | None = None,
255
+ train_transform: Transform | None = None,
256
+ eval_transform: Transform | None = None,
257
+ test_split_mode: TestSplitMode | str = TestSplitMode.FROM_DIR,
258
+ test_split_ratio: float = 0.2,
259
+ val_split_mode: ValSplitMode | str = ValSplitMode.SAME_AS_TEST,
260
+ val_split_ratio: float = 0.5,
261
+ seed: int | None = None,
262
+ ) -> None:
263
+ super().__init__(
264
+ train_batch_size=train_batch_size,
265
+ eval_batch_size=eval_batch_size,
266
+ num_workers=num_workers,
267
+ image_size=image_size,
268
+ transform=transform,
269
+ train_transform=train_transform,
270
+ eval_transform=eval_transform,
271
+ test_split_mode=test_split_mode,
272
+ test_split_ratio=test_split_ratio,
273
+ val_split_mode=val_split_mode,
274
+ val_split_ratio=val_split_ratio,
275
+ seed=seed,
276
+ )
277
+
278
+ self.task = TaskType(task)
279
+ self.root = Path(root)
280
+
281
+ def _setup(self, _stage: str | None = None) -> None:
282
+ self.train_data = KolektorDataset(
283
+ task=self.task,
284
+ transform=self.train_transform,
285
+ split=Split.TRAIN,
286
+ root=self.root,
287
+ )
288
+ self.test_data = KolektorDataset(
289
+ task=self.task,
290
+ transform=self.eval_transform,
291
+ split=Split.TEST,
292
+ root=self.root,
293
+ )
294
+
295
+ def prepare_data(self) -> None:
296
+ """Download the dataset if not available.
297
+
298
+ This method checks if the specified dataset is available in the file system.
299
+ If not, it downloads and extracts the dataset into the appropriate directory.
300
+
301
+ Example:
302
+ Assume the dataset is not available on the file system.
303
+ Here's how the directory structure looks before and after calling the
304
+ `prepare_data` method:
305
+
306
+ Before:
307
+
308
+ .. code-block:: bash
309
+
310
+ $ tree datasets
311
+ datasets
312
+ ├── dataset1
313
+ └── dataset2
314
+
315
+ Calling the method:
316
+
317
+ .. code-block:: python
318
+
319
+ >> datamodule = Kolektor(root="./datasets/kolektor")
320
+ >> datamodule.prepare_data()
321
+
322
+ After:
323
+
324
+ .. code-block:: bash
325
+
326
+ $ tree datasets
327
+ datasets
328
+ ├── dataset1
329
+ ├── dataset2
330
+ └── kolektor
331
+ ├── kolektorsdd
332
+ ├── kos01
333
+ ├── ...
334
+ └── kos50
335
+ ├── Part0.jpg
336
+ ├── Part0_label.bmp
337
+ └── ...
338
+ """
339
+ if (self.root).is_dir():
340
+ logger.info("Found the dataset.")
341
+ else:
342
+ download_and_extract(self.root, DOWNLOAD_INFO)
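
A minimal end-to-end usage sketch of the Kolektor datamodule added above (not part of the commit; it assumes anomalib is installed and the download URL is reachable, and mirrors the ``prepare_data``/``setup`` flow shown in the other datamodule docstrings in this commit):

    from anomalib import TaskType
    from anomalib.data.image.kolektor import Kolektor

    datamodule = Kolektor(root="./datasets/kolektor", task=TaskType.SEGMENTATION)
    datamodule.prepare_data()  # downloads and extracts KolektorSDD.zip if missing
    datamodule.setup()         # builds the train/test (and derived val) subsets
    batch = next(iter(datamodule.train_dataloader()))
    print(batch["image"].shape)  # (train_batch_size, C, H, W)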
anomalib/data/image/mvtec.py ADDED
@@ -0,0 +1,414 @@
+ """MVTec AD Dataset (CC BY-NC-SA 4.0).
+
+ Description:
+     This script contains PyTorch Dataset, Dataloader and PyTorch Lightning
+     DataModule for the MVTec AD dataset. If the dataset is not on the file system,
+     the script downloads and extracts the dataset and creates PyTorch data objects.
+
+ License:
+     MVTec AD dataset is released under the Creative Commons
+     Attribution-NonCommercial-ShareAlike 4.0 International License
+     (CC BY-NC-SA 4.0)(https://creativecommons.org/licenses/by-nc-sa/4.0/).
+
+ References:
+     - Paul Bergmann, Kilian Batzner, Michael Fauser, David Sattlegger, Carsten Steger:
+       The MVTec Anomaly Detection Dataset: A Comprehensive Real-World Dataset for
+       Unsupervised Anomaly Detection; in: International Journal of Computer Vision
+       129(4):1038-1059, 2021, DOI: 10.1007/s11263-020-01400-4.
+
+     - Paul Bergmann, Michael Fauser, David Sattlegger, Carsten Steger: MVTec AD —
+       A Comprehensive Real-World Dataset for Unsupervised Anomaly Detection;
+       in: IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR),
+       9584-9592, 2019, DOI: 10.1109/CVPR.2019.00982.
+ """
+
+ # Copyright (C) 2022-2024 Intel Corporation
+ # SPDX-License-Identifier: Apache-2.0
+
+ import logging
+ from collections.abc import Sequence
+ from pathlib import Path
+
+ from pandas import DataFrame
+ from torchvision.transforms.v2 import Transform
+
+ from anomalib import TaskType
+ from anomalib.data.base import AnomalibDataModule, AnomalibDataset
+ from anomalib.data.errors import MisMatchError
+ from anomalib.data.utils import (
+     DownloadInfo,
+     LabelName,
+     Split,
+     TestSplitMode,
+     ValSplitMode,
+     download_and_extract,
+     validate_path,
+ )
+
+ logger = logging.getLogger(__name__)
+
+
+ IMG_EXTENSIONS = (".png", ".PNG")
+
+ DOWNLOAD_INFO = DownloadInfo(
+     name="mvtec",
+     url="https://www.mydrive.ch/shares/38536/3830184030e49fe74747669442f0f282/download/420938113-1629952094"
+     "/mvtec_anomaly_detection.tar.xz",
+     hashsum="cf4313b13603bec67abb49ca959488f7eedce2a9f7795ec54446c649ac98cd3d",
+ )
+
+ CATEGORIES = (
+     "bottle",
+     "cable",
+     "capsule",
+     "carpet",
+     "grid",
+     "hazelnut",
+     "leather",
+     "metal_nut",
+     "pill",
+     "screw",
+     "tile",
+     "toothbrush",
+     "transistor",
+     "wood",
+     "zipper",
+ )
+
+
+ def make_mvtec_dataset(
+     root: str | Path,
+     split: str | Split | None = None,
+     extensions: Sequence[str] | None = None,
+ ) -> DataFrame:
+     """Create MVTec AD samples by parsing the MVTec AD data file structure.
+
+     The files are expected to follow the structure:
+         path/to/dataset/split/category/image_filename.png
+         path/to/dataset/ground_truth/category/mask_filename.png
+
+     This function creates a dataframe to store the parsed information based on the following format:
+
+     +---+---------------+-------+---------+---------------+---------------------------------------+-------------+
+     |   | path          | split | label   | image_path    | mask_path                             | label_index |
+     +===+===============+=======+=========+===============+=======================================+=============+
+     | 0 | datasets/name | test  | defect  | filename.png  | ground_truth/defect/filename_mask.png | 1           |
+     +---+---------------+-------+---------+---------------+---------------------------------------+-------------+
+
+     Args:
+         root (Path): Path to dataset
+         split (str | Split | None, optional): Dataset split (i.e., either train or test).
+             Defaults to ``None``.
+         extensions (Sequence[str] | None, optional): List of file extensions to be included in the dataset.
+             Defaults to ``None``.
+
+     Examples:
+         The following example shows how to get training samples from MVTec AD bottle category:
+
+         >>> root = Path('./MVTec')
+         >>> category = 'bottle'
+         >>> path = root / category
+         >>> path
+         PosixPath('MVTec/bottle')
+
+         >>> samples = make_mvtec_dataset(path, split='train')
+         >>> samples.head()
+                    path  split label                       image_path mask_path  label_index
+         0  MVTec/bottle  train  good  MVTec/bottle/train/good/105.png                      0
+         1  MVTec/bottle  train  good  MVTec/bottle/train/good/017.png                      0
+         2  MVTec/bottle  train  good  MVTec/bottle/train/good/137.png                      0
+         3  MVTec/bottle  train  good  MVTec/bottle/train/good/152.png                      0
+         4  MVTec/bottle  train  good  MVTec/bottle/train/good/109.png                      0
+
+     Returns:
+         DataFrame: an output dataframe containing the samples of the dataset.
+     """
+     if extensions is None:
+         extensions = IMG_EXTENSIONS
+
+     root = validate_path(root)
+     samples_list = [(str(root),) + f.parts[-3:] for f in root.glob(r"**/*") if f.suffix in extensions]
+     if not samples_list:
+         msg = f"Found 0 images in {root}"
+         raise RuntimeError(msg)
+
+     samples = DataFrame(samples_list, columns=["path", "split", "label", "image_path"])
+
+     # Modify image_path column by converting to absolute path
+     samples["image_path"] = samples.path + "/" + samples.split + "/" + samples.label + "/" + samples.image_path
+
+     # Create label index for normal (0) and anomalous (1) images.
+     samples.loc[(samples.label == "good"), "label_index"] = LabelName.NORMAL
+     samples.loc[(samples.label != "good"), "label_index"] = LabelName.ABNORMAL
+     samples.label_index = samples.label_index.astype(int)
+
+     # separate masks from samples
+     mask_samples = samples.loc[samples.split == "ground_truth"].sort_values(by="image_path", ignore_index=True)
+     samples = samples[samples.split != "ground_truth"].sort_values(by="image_path", ignore_index=True)
+
+     # assign mask paths to anomalous test images
+     samples["mask_path"] = ""
+     samples.loc[
+         (samples.split == "test") & (samples.label_index == LabelName.ABNORMAL),
+         "mask_path",
+     ] = mask_samples.image_path.to_numpy()
+
+     # assert that the right mask files are associated with the right test images
+     abnormal_samples = samples.loc[samples.label_index == LabelName.ABNORMAL]
+     if (
+         len(abnormal_samples)
+         and not abnormal_samples.apply(lambda x: Path(x.image_path).stem in Path(x.mask_path).stem, axis=1).all()
+     ):
+         msg = """Mismatch between anomalous images and ground truth masks. Make sure the
+               mask files in 'ground_truth' folder follow the same naming convention as the
+               anomalous images in the dataset (e.g. image: '000.png', mask: '000.png' or '000_mask.png')."""
+         raise MisMatchError(msg)
+
+     if split:
+         samples = samples[samples.split == split].reset_index(drop=True)
+
+     return samples
+
+
+ class MVTecDataset(AnomalibDataset):
+     """MVTec dataset class.
+
+     Args:
+         task (TaskType): Task type, ``classification``, ``detection`` or ``segmentation``.
+         root (Path | str): Path to the root of the dataset.
+             Defaults to ``./datasets/MVTec``.
+         category (str): Sub-category of the dataset, e.g. 'bottle'
+             Defaults to ``bottle``.
+         transform (Transform, optional): Transforms that should be applied to the input images.
+             Defaults to ``None``.
+         split (str | Split | None): Split of the dataset, usually Split.TRAIN or Split.TEST
+             Defaults to ``None``.
+
+     Examples:
+         .. code-block:: python
+
+             from anomalib.data.image.mvtec import MVTecDataset
+             from anomalib.data.utils.transforms import get_transforms
+
+             transform = get_transforms(image_size=256)
+             dataset = MVTecDataset(
+                 task="classification",
+                 transform=transform,
+                 root='./datasets/MVTec',
+                 category='zipper',
+             )
+             dataset.setup()
+             print(dataset[0].keys())
+             # Output: dict_keys(['image_path', 'label', 'image'])
+
+         When the task is segmentation, the dataset will also contain the mask:
+
+         .. code-block:: python
+
+             dataset.task = "segmentation"
+             dataset.setup()
+             print(dataset[0].keys())
+             # Output: dict_keys(['image_path', 'label', 'image', 'mask_path', 'mask'])
+
+         The image is a torch tensor of shape (C, H, W) and the mask is a torch tensor of shape (H, W).
+
+         .. code-block:: python
+
+             print(dataset[0]["image"].shape, dataset[0]["mask"].shape)
+             # Output: (torch.Size([3, 256, 256]), torch.Size([256, 256]))
+
+     """
+
+     def __init__(
+         self,
+         task: TaskType,
+         root: Path | str = "./datasets/MVTec",
+         category: str = "bottle",
+         transform: Transform | None = None,
+         split: str | Split | None = None,
+     ) -> None:
+         super().__init__(task=task, transform=transform)
+
+         self.root_category = Path(root) / Path(category)
+         self.category = category
+         self.split = split
+         self.samples = make_mvtec_dataset(self.root_category, split=self.split, extensions=IMG_EXTENSIONS)
+
+
+ class MVTec(AnomalibDataModule):
+     """MVTec Datamodule.
+
+     Args:
+         root (Path | str): Path to the root of the dataset.
+             Defaults to ``"./datasets/MVTec"``.
+         category (str): Category of the MVTec dataset (e.g. "bottle" or "cable").
+             Defaults to ``"bottle"``.
+         train_batch_size (int, optional): Training batch size.
+             Defaults to ``32``.
+         eval_batch_size (int, optional): Test batch size.
+             Defaults to ``32``.
+         num_workers (int, optional): Number of workers.
+             Defaults to ``8``.
+         task (TaskType): Task type, 'classification', 'detection' or 'segmentation'
+             Defaults to ``TaskType.SEGMENTATION``.
+         image_size (tuple[int, int], optional): Size to which input images should be resized.
+             Defaults to ``None``.
+         transform (Transform, optional): Transforms that should be applied to the input images.
+             Defaults to ``None``.
+         train_transform (Transform, optional): Transforms that should be applied to the input images during training.
+             Defaults to ``None``.
+         eval_transform (Transform, optional): Transforms that should be applied to the input images during evaluation.
+             Defaults to ``None``.
+         test_split_mode (TestSplitMode): Setting that determines how the testing subset is obtained.
+             Defaults to ``TestSplitMode.FROM_DIR``.
+         test_split_ratio (float): Fraction of images from the train set that will be reserved for testing.
+             Defaults to ``0.2``.
+         val_split_mode (ValSplitMode): Setting that determines how the validation subset is obtained.
+             Defaults to ``ValSplitMode.SAME_AS_TEST``.
+         val_split_ratio (float): Fraction of train or test images that will be reserved for validation.
+             Defaults to ``0.5``.
+         seed (int | None, optional): Seed which may be set to a fixed value for reproducibility.
+             Defaults to ``None``.
+
+     Examples:
+         To create an MVTec AD datamodule with default settings:
+
+         >>> datamodule = MVTec()
+         >>> datamodule.setup()
+         >>> i, data = next(enumerate(datamodule.train_dataloader()))
+         >>> data.keys()
+         dict_keys(['image_path', 'label', 'image', 'mask_path', 'mask'])
+
+         >>> data["image"].shape
+         torch.Size([32, 3, 256, 256])
+
+         To change the category of the dataset:
+
+         >>> datamodule = MVTec(category="cable")
+
+         To change the image and batch size:
+
+         >>> datamodule = MVTec(image_size=(512, 512), train_batch_size=16, eval_batch_size=8)
+
+         MVTec AD dataset does not provide a validation set. If you would like
+         to use a separate validation set, you can use the ``val_split_mode`` and
+         ``val_split_ratio`` arguments to create a validation set.
+
+         >>> datamodule = MVTec(val_split_mode=ValSplitMode.FROM_TEST, val_split_ratio=0.1)
+
+         This will subsample the test set by 10% and use it as the validation set.
+         If you would like to create a validation set synthetically that would
+         not change the test set, you can use the ``ValSplitMode.SYNTHETIC`` option.
+
+         >>> datamodule = MVTec(val_split_mode=ValSplitMode.SYNTHETIC, val_split_ratio=0.2)
+
+     """
+
+     def __init__(
+         self,
+         root: Path | str = "./datasets/MVTec",
+         category: str = "bottle",
+         train_batch_size: int = 32,
+         eval_batch_size: int = 32,
+         num_workers: int = 8,
+         task: TaskType | str = TaskType.SEGMENTATION,
+         image_size: tuple[int, int] | None = None,
+         transform: Transform | None = None,
+         train_transform: Transform | None = None,
+         eval_transform: Transform | None = None,
+         test_split_mode: TestSplitMode | str = TestSplitMode.FROM_DIR,
+         test_split_ratio: float = 0.2,
+         val_split_mode: ValSplitMode | str = ValSplitMode.SAME_AS_TEST,
+         val_split_ratio: float = 0.5,
+         seed: int | None = None,
+     ) -> None:
+         super().__init__(
+             train_batch_size=train_batch_size,
+             eval_batch_size=eval_batch_size,
+             image_size=image_size,
+             transform=transform,
+             train_transform=train_transform,
+             eval_transform=eval_transform,
+             num_workers=num_workers,
+             test_split_mode=test_split_mode,
+             test_split_ratio=test_split_ratio,
+             val_split_mode=val_split_mode,
+             val_split_ratio=val_split_ratio,
+             seed=seed,
+         )
+
+         self.task = TaskType(task)
+         self.root = Path(root)
+         self.category = category
+
+     def _setup(self, _stage: str | None = None) -> None:
+         """Set up the datasets and perform dynamic subset splitting.
+
+         This method may be overridden in subclass for custom splitting behaviour.
+
+         Note:
+             The stage argument is not used here. This is because, for a given instance of an AnomalibDataModule
+             subclass, all three subsets are created at the first call of setup(). This is to accommodate the subset
+             splitting behaviour of anomaly tasks, where the validation set is usually extracted from the test set, and
+             the test set must therefore be created as early as the `fit` stage.
+
+         """
+         self.train_data = MVTecDataset(
+             task=self.task,
+             transform=self.train_transform,
+             split=Split.TRAIN,
+             root=self.root,
+             category=self.category,
+         )
+         self.test_data = MVTecDataset(
+             task=self.task,
+             transform=self.eval_transform,
+             split=Split.TEST,
+             root=self.root,
+             category=self.category,
+         )
+
+     def prepare_data(self) -> None:
+         """Download the dataset if not available.
+
+         This method checks if the specified dataset is available in the file system.
+         If not, it downloads and extracts the dataset into the appropriate directory.
+
+         Example:
+             Assume the dataset is not available on the file system.
+             Here's how the directory structure looks before and after calling the
+             `prepare_data` method:
+
+             Before:
+
+             .. code-block:: bash
+
+                 $ tree datasets
+                 datasets
+                 ├── dataset1
+                 └── dataset2
+
+             Calling the method:
+
+             .. code-block:: python
+
+                 >> datamodule = MVTec(root="./datasets/MVTec", category="bottle")
+                 >> datamodule.prepare_data()
+
+             After:
+
+             .. code-block:: bash
+
+                 $ tree datasets
+                 datasets
+                 ├── dataset1
+                 ├── dataset2
+                 └── MVTec
+                     ├── bottle
+                     ├── ...
+                     └── zipper
+         """
+         if (self.root / self.category).is_dir():
+             logger.info("Found the dataset.")
+         else:
+             download_and_extract(self.root, DOWNLOAD_INFO)
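
To make the image/mask pairing rule in ``make_mvtec_dataset`` above concrete, here is a minimal sketch of the stem check it applies to anomalous test images (paths are illustrative only):

    from pathlib import Path

    image_path = "MVTec/bottle/test/broken_large/000.png"
    mask_path = "MVTec/bottle/ground_truth/broken_large/000_mask.png"

    # The dataset builder accepts the pair when the image stem is contained
    # in the mask stem, e.g. '000' in '000_mask'; otherwise it raises MisMatchError.
    assert Path(image_path).stem in Path(mask_path).stem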
anomalib/data/image/mvtec_loco.py ADDED
@@ -0,0 +1,480 @@
+ """MVTec LOCO AD Dataset (CC BY-NC-SA 4.0).
+
+ Description:
+     This script contains PyTorch Dataset, Dataloader and PyTorch Lightning
+     DataModule for the MVTec LOCO AD dataset. If the dataset is not on the file system,
+     the script downloads and extracts the dataset and creates PyTorch data objects.
+
+ License:
+     MVTec LOCO AD dataset is released under the Creative Commons
+     Attribution-NonCommercial-ShareAlike 4.0 International License
+     (CC BY-NC-SA 4.0)(https://creativecommons.org/licenses/by-nc-sa/4.0/).
+
+ References:
+     - Paul Bergmann, Kilian Batzner, Michael Fauser, David Sattlegger, and Carsten Steger:
+       Beyond Dents and Scratches: Logical Constraints in Unsupervised Anomaly Detection and Localization;
+       in: International Journal of Computer Vision (IJCV) 130, 947-969, 2022, DOI: 10.1007/s11263-022-01578-9
+ """
+
+ # Copyright (C) 2024 Intel Corporation
+ # SPDX-License-Identifier: Apache-2.0
+
+ import logging
+ from collections.abc import Sequence
+ from pathlib import Path
+
+ import torch
+ from pandas import DataFrame
+ from PIL import Image
+ from torchvision.transforms.v2 import Transform
+ from torchvision.transforms.v2.functional import to_image
+ from torchvision.tv_tensors import Mask
+
+ from anomalib import TaskType
+ from anomalib.data.base import AnomalibDataModule, AnomalibDataset
+ from anomalib.data.utils import (
+     DownloadInfo,
+     LabelName,
+     Split,
+     TestSplitMode,
+     ValSplitMode,
+     download_and_extract,
+     masks_to_boxes,
+     read_image,
+     validate_path,
+ )
+
+ logger = logging.getLogger(__name__)
+
+
+ IMG_EXTENSIONS = (".png", ".PNG")
+
+ DOWNLOAD_INFO = DownloadInfo(
+     name="mvtec_loco",
+     url="https://www.mydrive.ch/shares/48237/1b9106ccdfbb09a0c414bd49fe44a14a/download/430647091-1646842701"
+     "/mvtec_loco_anomaly_detection.tar.xz",
+     hashsum="9e7c84dba550fd2e59d8e9e231c929c45ba737b6b6a6d3814100f54d63aae687",
+ )
+
+ CATEGORIES = (
+     "breakfast_box",
+     "juice_bottle",
+     "pushpins",
+     "screw_bag",
+     "splicing_connectors",
+ )
+
+
+ def make_mvtec_loco_dataset(
+     root: str | Path,
+     split: str | Split | None = None,
+     extensions: Sequence[str] = IMG_EXTENSIONS,
+ ) -> DataFrame:
+     """Create MVTec LOCO AD samples by parsing the original MVTec LOCO AD data file structure.
+
+     The files are expected to follow the structure:
+         path/to/dataset/split/category/image_filename.png
+         path/to/dataset/ground_truth/category/image_filename/000.png
+
+     where there can be multiple ground-truth masks for the corresponding anomalous images.
+
+     This function creates a dataframe to store the parsed information based on the following format:
+
+     +---+---------------+-------+---------+-------------------------+-----------------------------+-------------+
+     |   | path          | split | label   | image_path              | mask_path                   | label_index |
+     +===+===============+=======+=========+=========================+=============================+=============+
+     | 0 | datasets/name | test  | defect  | path/to/image/file.png  | [path/to/masks/file.png]    | 1           |
+     +---+---------------+-------+---------+-------------------------+-----------------------------+-------------+
+
+     Args:
+         root (str | Path): Path to dataset
+         split (str | Split | None): Dataset split (i.e., either train or test).
+             Defaults to ``None``.
+         extensions (Sequence[str]): List of file extensions to be included in the dataset.
+             Defaults to ``IMG_EXTENSIONS``.
+
+     Returns:
+         DataFrame: an output dataframe containing the samples of the dataset.
+
+     Examples:
+         The following example shows how to get test samples from MVTec LOCO AD pushpins category:
+
+         >>> root = Path('./MVTec_LOCO')
+         >>> category = 'pushpins'
+         >>> path = root / category
+         >>> samples = make_mvtec_loco_dataset(path, split='test')
+     """
+     root = validate_path(root)
+
+     # Retrieve the image and mask files
+     samples_list = []
+     for f in root.glob("**/*"):
+         if f.suffix in extensions:
+             parts = f.parts
+             # 'ground_truth' and non-'ground_truth' paths have different structures
+             if "ground_truth" not in parts:
+                 split_folder, label_folder, image_file = parts[-3:]
+                 image_path = f"{root}/{split_folder}/{label_folder}/{image_file}"
+                 samples_list.append((str(root), split_folder, label_folder, "", image_path))
+             else:
+                 split_folder, label_folder, image_folder, image_file = parts[-4:]
+                 image_path = f"{root}/{split_folder}/{label_folder}/{image_folder}/{image_file}"
+                 samples_list.append((str(root), split_folder, label_folder, image_folder, image_path))
+
+     if not samples_list:
+         msg = f"Found 0 images in {root}"
+         raise RuntimeError(msg)
+
+     samples = DataFrame(samples_list, columns=["path", "split", "label", "image_folder", "image_path"])
+
+     # Replace 'validation' with Split.VAL.value in the split column
+     samples["split"] = samples["split"].replace("validation", Split.VAL.value)
+
+     # Create label index for normal (0) and anomalous (1) images.
+     samples.loc[(samples.label == "good"), "label_index"] = LabelName.NORMAL
+     samples.loc[(samples.label != "good"), "label_index"] = LabelName.ABNORMAL
+     samples.label_index = samples.label_index.astype(int)
+
+     # separate ground-truth masks from samples
+     mask_samples = samples.loc[samples.split == "ground_truth"].sort_values(by="image_path", ignore_index=True)
+     samples = samples[samples.split != "ground_truth"].sort_values(by="image_path", ignore_index=True)
+
+     # Group masks and aggregate the path into a list
+     mask_samples = (
+         mask_samples.groupby(["path", "split", "label", "image_folder"])["image_path"]
+         .agg(list)
+         .reset_index()
+         .rename(columns={"image_path": "mask_path"})
+     )
+
+     # assign mask paths to anomalous test images
+     samples["mask_path"] = ""
+     samples.loc[
+         (samples.split == "test") & (samples.label_index == LabelName.ABNORMAL),
+         "mask_path",
+     ] = mask_samples.mask_path.to_numpy()
+
+     # validate that the right mask files are associated with the right test images
+     if len(samples.loc[samples.label_index == LabelName.ABNORMAL]):
+         image_stems = samples.loc[samples.label_index == LabelName.ABNORMAL]["image_path"].apply(lambda x: Path(x).stem)
+         mask_parent_stems = samples.loc[samples.label_index == LabelName.ABNORMAL]["mask_path"].apply(
+             lambda x: {Path(mask_path).parent.stem for mask_path in x},
+         )
+
+         if not all(
+             next(iter(mask_stems)) == image_stem
+             for image_stem, mask_stems in zip(image_stems, mask_parent_stems, strict=True)
+         ):
+             error_message = (
+                 "Mismatch between anomalous images and ground truth masks. "
+                 "Make sure the parent folder of the mask files in 'ground_truth' folder "
+                 "follows the same naming convention as the anomalous images in the dataset "
+                 "(e.g., image: '005.png', mask: '005/000.png')."
+             )
+             raise ValueError(error_message)
+
+     if split:
+         samples = samples[samples.split == split].reset_index(drop=True)
+
+     return samples
+
+
+ class MVTecLocoDataset(AnomalibDataset):
+     """MVTec LOCO dataset class.
+
+     Args:
+         task (TaskType): Task type, ``classification``, ``detection`` or ``segmentation``.
+         root (Path | str): Path to the root of the dataset.
+             Defaults to ``./datasets/MVTec_LOCO``.
+         category (str): Sub-category of the dataset, e.g. 'breakfast_box'
+             Defaults to ``breakfast_box``.
+         transform (Transform, optional): Transforms that should be applied to the input images.
+             Defaults to ``None``.
+         split (str | Split | None): Split of the dataset, Split.TRAIN, Split.VAL, or Split.TEST
+             Defaults to ``None``.
+
+     Examples:
+         .. code-block:: python
+
+             from anomalib.data.image.mvtec_loco import MVTecLocoDataset
+             from anomalib.data.utils.transforms import get_transforms
+             from torchvision.transforms.v2 import Resize
+
+             transform = Resize((256, 256))
+             dataset = MVTecLocoDataset(
+                 task="classification",
+                 transform=transform,
+                 root='./datasets/MVTec_LOCO',
+                 category='breakfast_box',
+             )
+             dataset.setup()
+             print(dataset[0].keys())
+             # Output: dict_keys(['image_path', 'label', 'image'])
+
+         When the task is segmentation, the dataset will also contain the mask:
+
+         .. code-block:: python
+
+             dataset.task = "segmentation"
+             dataset.setup()
+             print(dataset[0].keys())
+             # Output: dict_keys(['image_path', 'label', 'image', 'mask_path', 'mask'])
+
+         The image is a torch tensor of shape (C, H, W) and the mask is a torch tensor of shape (H, W).
+
+         .. code-block:: python
+
+             print(dataset[0]["image"].shape, dataset[0]["mask"].shape)
+             # Output: (torch.Size([3, 256, 256]), torch.Size([256, 256]))
+     """
+
+     def __init__(
+         self,
+         task: TaskType,
+         root: Path | str = "./datasets/MVTec_LOCO",
+         category: str = "breakfast_box",
+         transform: Transform | None = None,
+         split: str | Split | None = None,
+     ) -> None:
+         super().__init__(task=task, transform=transform)
+
+         self.root_category = Path(root) / category
+         self.split = split
+         self.samples = make_mvtec_loco_dataset(
+             self.root_category,
+             split=self.split,
+             extensions=IMG_EXTENSIONS,
+         )
+
+     @staticmethod
+     def _read_mask(mask_path: str | Path) -> Mask:
+         image = Image.open(mask_path).convert("L")
+         return Mask(to_image(image).squeeze(), dtype=torch.uint8)
+
+     def __getitem__(self, index: int) -> dict[str, str | torch.Tensor]:
+         """Get dataset item for the index ``index``.
+
+         This method is mostly based on the super class implementation, with some differences as follows:
+         - The 'mask' in the returned item is binarized by merging all ground-truth masks of the image
+         - An additional 'semantic_mask' is added: the non-binary masks with original size for the SPRO metric calculation
+         Args:
+             index (int): Index to get the item.
+
+         Returns:
+             dict[str, str | torch.Tensor]: Dict of image tensor during training. Otherwise, Dict containing image path,
+                 target path, image tensor, label and transformed bounding box.
+         """
+         image_path = self.samples.iloc[index].image_path
+         mask_path = self.samples.iloc[index].mask_path
+         label_index = self.samples.iloc[index].label_index
+
+         image = read_image(image_path, as_tensor=True)
+         item = {"image_path": image_path, "label": label_index}
+
+         if self.task == TaskType.CLASSIFICATION:
+             item["image"] = self.transform(image) if self.transform else image
+         elif self.task in (TaskType.DETECTION, TaskType.SEGMENTATION):
+             # Only Anomalous (1) images have masks in anomaly datasets
+             # Therefore, create empty mask for Normal (0) images.
+             if isinstance(mask_path, str):
+                 mask_path = [mask_path]
+             semantic_mask = (
+                 Mask(torch.zeros(image.shape[-2:])).to(torch.uint8)
+                 if label_index == LabelName.NORMAL
+                 else Mask(torch.stack([self._read_mask(path) for path in mask_path]))
+             )
+
+             binary_mask = Mask(semantic_mask.view(-1, *semantic_mask.shape[-2:]).int().any(dim=0).to(torch.uint8))
+             item["image"], item["mask"] = self.transform(image, binary_mask) if self.transform else (image, binary_mask)
+
+             item["mask_path"] = mask_path
+             # List of masks with the original size for saturation based metrics calculation
+             item["semantic_mask"] = semantic_mask
+
+             if self.task == TaskType.DETECTION:
+                 # create boxes from masks for detection task
+                 boxes, _ = masks_to_boxes(item["mask"])
+                 item["boxes"] = boxes[0]
+         else:
+             msg = f"Unknown task type: {self.task}"
+             raise ValueError(msg)
+
+         return item
+
+
+ class MVTecLoco(AnomalibDataModule):
+     """MVTec LOCO Datamodule.
+
+     Args:
+         root (Path | str): Path to the root of the dataset.
+             Defaults to ``"./datasets/MVTec_LOCO"``.
+         category (str): Category of the MVTec LOCO dataset (e.g. "breakfast_box").
+             Defaults to ``"breakfast_box"``.
+         train_batch_size (int, optional): Training batch size.
+             Defaults to ``32``.
+         eval_batch_size (int, optional): Test batch size.
+             Defaults to ``32``.
+         num_workers (int, optional): Number of workers.
+             Defaults to ``8``.
+         task (TaskType): Task type, 'classification', 'detection' or 'segmentation'
+             Defaults to ``TaskType.SEGMENTATION``.
+         image_size (tuple[int, int], optional): Size to which input images should be resized.
+             Defaults to ``None``.
+         transform (Transform, optional): Transforms that should be applied to the input images.
+             Defaults to ``None``.
+         train_transform (Transform, optional): Transforms that should be applied to the input images during training.
+             Defaults to ``None``.
+         eval_transform (Transform, optional): Transforms that should be applied to the input images during evaluation.
+             Defaults to ``None``.
+         test_split_mode (TestSplitMode): Setting that determines how the testing subset is obtained.
+             Defaults to ``TestSplitMode.FROM_DIR``.
+         test_split_ratio (float): Fraction of images from the train set that will be reserved for testing.
+             Defaults to ``0.2``.
+         val_split_mode (ValSplitMode): Setting that determines how the validation subset is obtained.
+             Defaults to ``ValSplitMode.FROM_DIR``.
+         val_split_ratio (float): Fraction of train or test images that will be reserved for validation.
+             Defaults to ``0.5``.
+         seed (int | None, optional): Seed which may be set to a fixed value for reproducibility.
+             Defaults to ``None``.
+
+     Examples:
+         To create an MVTec LOCO AD datamodule with default settings:
+
+         >>> datamodule = MVTecLoco(root="anomalib/datasets/MVTec_LOCO")
+         >>> datamodule.setup()
+         >>> i, data = next(enumerate(datamodule.train_dataloader()))
+         >>> data.keys()
+         dict_keys(['image_path', 'label', 'image', 'mask_path', 'mask'])
+
+         >>> data["image"].shape
+         torch.Size([32, 3, 256, 256])
+
+         To change the category of the dataset:
+
+         >>> datamodule = MVTecLoco(category="pushpins")
+
+         To change the image and batch size:
+
+         >>> datamodule = MVTecLoco(image_size=(512, 512), train_batch_size=16, eval_batch_size=8)
+
+         MVTec LOCO AD dataset provides an independent validation set with normal images only in the 'validation' folder.
+         If you would like to use a different validation set split from the train or test set,
+         you can use the ``val_split_mode`` and ``val_split_ratio`` arguments to create a new validation set.
+
+         >>> datamodule = MVTecLoco(val_split_mode=ValSplitMode.FROM_TEST, val_split_ratio=0.1)
+
+         This will subsample the test set by 10% and use it as the validation set.
+         If you would like to create a validation set synthetically that would
+         not change the test set, you can use the ``ValSplitMode.SYNTHETIC`` option.
+
+         >>> datamodule = MVTecLoco(val_split_mode=ValSplitMode.SYNTHETIC, val_split_ratio=0.2)
+     """
+
+     def __init__(
+         self,
+         root: Path | str = "./datasets/MVTec_LOCO",
+         category: str = "breakfast_box",
+         train_batch_size: int = 32,
+         eval_batch_size: int = 32,
+         num_workers: int = 8,
+         task: TaskType = TaskType.SEGMENTATION,
+         image_size: tuple[int, int] | None = None,
+         transform: Transform | None = None,
+         train_transform: Transform | None = None,
+         eval_transform: Transform | None = None,
+         test_split_mode: TestSplitMode = TestSplitMode.FROM_DIR,
+         test_split_ratio: float = 0.2,
+         val_split_mode: ValSplitMode = ValSplitMode.FROM_DIR,
+         val_split_ratio: float = 0.5,
+         seed: int | None = None,
+     ) -> None:
+         super().__init__(
+             train_batch_size=train_batch_size,
+             eval_batch_size=eval_batch_size,
+             image_size=image_size,
+             transform=transform,
+             train_transform=train_transform,
+             eval_transform=eval_transform,
+             num_workers=num_workers,
+             test_split_mode=test_split_mode,
+             test_split_ratio=test_split_ratio,
+             val_split_mode=val_split_mode,
+             val_split_ratio=val_split_ratio,
+             seed=seed,
+         )
+         self.task = task
+         self.root = Path(root)
+         self.category = category
+
+     def _setup(self, _stage: str | None = None) -> None:
+         """Set up the datasets, configs, and perform dynamic subset splitting.
+
+         This method overrides the parent class's method to also set up the val dataset.
+         The MVTec LOCO dataset provides an independent validation subset.
+         """
+         self.train_data = MVTecLocoDataset(
+             task=self.task,
+             transform=self.train_transform,
+             split=Split.TRAIN,
+             root=self.root,
+             category=self.category,
+         )
+         self.val_data = MVTecLocoDataset(
+             task=self.task,
+             transform=self.eval_transform,
+             split=Split.VAL,
+             root=self.root,
+             category=self.category,
+         )
+         self.test_data = MVTecLocoDataset(
+             task=self.task,
+             transform=self.eval_transform,
+             split=Split.TEST,
+             root=self.root,
+             category=self.category,
+         )
+
+     def prepare_data(self) -> None:
+         """Download the dataset if not available.
+
+         This method checks if the specified dataset is available in the file system.
+         If not, it downloads and extracts the dataset into the appropriate directory.
+
+         Example:
+             Assume the dataset is not available on the file system.
+             Here's how the directory structure looks before and after calling the
+             `prepare_data` method:
+
+             Before:
+
+             .. code-block:: bash
+
+                 $ tree datasets
+                 datasets
+                 ├── dataset1
+                 └── dataset2
+
+             Calling the method:
+
+             .. code-block:: python
+
+                 >> datamodule = MVTecLoco(root="./datasets/MVTec_LOCO", category="breakfast_box")
+                 >> datamodule.prepare_data()
+
+             After:
+
+             .. code-block:: bash
+
+                 $ tree datasets
+                 datasets
+                 ├── dataset1
+                 ├── dataset2
+                 └── MVTec_LOCO
+                     ├── breakfast_box
+                     ├── ...
+                     └── splicing_connectors
+         """
+         if (self.root / self.category).is_dir():
+             logger.info("Found the dataset.")
+         else:
+             download_and_extract(self.root, DOWNLOAD_INFO)
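
The multi-mask handling in ``MVTecLocoDataset.__getitem__`` above can be hard to read in diff form, so here is a small sketch that reproduces the same tensor logic (without the ``tv_tensors.Mask`` wrapper) for merging several ground-truth masks into one binary mask, using toy tensors rather than real data:

    import torch

    # Stack of two single-channel defect masks, as produced by _read_mask
    semantic_mask = torch.stack([
        torch.tensor([[0, 255], [0, 0]], dtype=torch.uint8),   # first defect mask
        torch.tensor([[0, 0], [255, 0]], dtype=torch.uint8),   # second defect mask
    ])
    # Any non-zero pixel in any of the stacked masks becomes 1 in the merged mask
    binary_mask = semantic_mask.view(-1, *semantic_mask.shape[-2:]).int().any(dim=0).to(torch.uint8)
    print(binary_mask)  # tensor([[0, 1], [1, 0]], dtype=torch.uint8)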
anomalib/data/image/visa.py ADDED
@@ -0,0 +1,364 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Visual Anomaly (VisA) Dataset (CC BY-NC-SA 4.0).
2
+
3
+ Description:
4
+ This script contains PyTorch Dataset, Dataloader and PyTorch
5
+ Lightning DataModule for the Visual Anomal (VisA) dataset.
6
+ If the dataset is not on the file system, the script downloads and
7
+ extracts the dataset and create PyTorch data objects.
8
+ License:
9
+ The VisA dataset is released under the Creative Commons
10
+ Attribution-NonCommercial-ShareAlike 4.0 International License
11
+ (CC BY-NC-SA 4.0)(https://creativecommons.org/licenses/by-nc-sa/4.0/).
12
+ Reference:
13
+ - Zou, Y., Jeong, J., Pemula, L., Zhang, D., & Dabeer, O. (2022). SPot-the-Difference
14
+ Self-supervised Pre-training for Anomaly Detection and Segmentation. In European
15
+ Conference on Computer Vision (pp. 392-408). Springer, Cham.
16
+ """
17
+
18
+ # Copyright (C) 2022-2024 Intel Corporation
19
+ # SPDX-License-Identifier: Apache-2.0
20
+
21
+ # Subset splitting code adapted from https://github.com/amazon-science/spot-diff
22
+ # Original licence: Apache-2.0
23
+
24
+
25
+ import csv
26
+ import logging
27
+ import shutil
28
+ from pathlib import Path
29
+
30
+ import cv2
31
+ from torchvision.transforms.v2 import Transform
32
+
33
+ from anomalib import TaskType
34
+ from anomalib.data.base import AnomalibDataModule, AnomalibDataset
35
+ from anomalib.data.utils import (
36
+ DownloadInfo,
37
+ Split,
38
+ TestSplitMode,
39
+ ValSplitMode,
40
+ download_and_extract,
41
+ )
42
+
43
+ from .mvtec import make_mvtec_dataset
44
+
45
+ logger = logging.getLogger(__name__)
46
+
47
+ EXTENSIONS = (".png", ".jpg", ".JPG")
48
+
49
+ DOWNLOAD_INFO = DownloadInfo(
50
+ name="VisA",
51
+ url="https://amazon-visual-anomaly.s3.us-west-2.amazonaws.com/VisA_20220922.tar",
52
+ hashsum="2eb8690c803ab37de0324772964100169ec8ba1fa3f7e94291c9ca673f40f362",
53
+ )
54
+
55
+ CATEGORIES = (
56
+ "candle",
57
+ "capsules",
58
+ "cashew",
59
+ "chewinggum",
60
+ "fryum",
61
+ "macaroni1",
62
+ "macaroni2",
63
+ "pcb1",
64
+ "pcb2",
65
+ "pcb3",
66
+ "pcb4",
67
+ "pipe_fryum",
68
+ )
69
+
70
+
71
+ class VisaDataset(AnomalibDataset):
72
+ """VisA dataset class.
73
+
74
+ Args:
75
+ task (TaskType): Task type, ``classification``, ``detection`` or ``segmentation``
76
+ root (str | Path): Path to the root of the dataset
77
+ category (str): Sub-category of the dataset, e.g. 'candle'
78
+ transform (Transform, optional): Transforms that should be applied to the input images.
79
+ Defaults to ``None``.
80
+ split (str | Split | None): Split of the dataset, usually Split.TRAIN or Split.TEST
81
+ Defaults to ``None``.
82
+
83
+ Examples:
84
+ To create a Visa dataset for classification:
85
+
86
+ .. code-block:: python
87
+
88
+ from anomalib.data.image.visa import VisaDataset
89
+ from anomalib.data.utils.transforms import get_transforms
90
+
91
+ transform = get_transforms(image_size=256)
92
+ dataset = VisaDataset(
93
+ task="classification",
94
+ transform=transform,
95
+ split="train",
96
+ root="./datasets/visa/visa_pytorch/",
97
+ category="candle",
98
+ )
99
+ dataset.setup()
100
+ dataset[0].keys()
101
+
102
+ # Output
103
+ dict_keys(['image_path', 'label', 'image'])
104
+
105
+ If you want to use the dataset for segmentation, you can use the same
106
+ code as above, with the task set to ``segmentation``. The dataset will
107
+ then have a ``mask`` key in the output dictionary.
108
+
109
+ .. code-block:: python
110
+
111
+ from anomalib.data.image.visa import VisaDataset
112
+ from anomalib.data.utils.transforms import get_transforms
113
+
114
+ transform = get_transforms(image_size=256)
115
+ dataset = VisaDataset(
116
+ task="segmentation",
117
+ transform=transform,
118
+ split="train",
119
+ root="./datasets/visa/visa_pytorch/",
120
+ category="candle",
121
+ )
122
+ dataset.setup()
123
+ dataset[0].keys()
124
+
125
+ # Output
126
+ dict_keys(['image_path', 'label', 'image', 'mask_path', 'mask'])
127
+
128
+ """
129
+
130
+ def __init__(
131
+ self,
132
+ task: TaskType,
133
+ root: str | Path,
134
+ category: str,
135
+ transform: Transform | None = None,
136
+ split: str | Split | None = None,
137
+ ) -> None:
138
+ super().__init__(task=task, transform=transform)
139
+
140
+ self.root_category = Path(root) / category
141
+ self.split = split
142
+ self.samples = make_mvtec_dataset(self.root_category, split=self.split, extensions=EXTENSIONS)
143
+
144
+
145
+ class Visa(AnomalibDataModule):
146
+ """VisA Datamodule.
147
+
148
+ Args:
149
+ root (Path | str): Path to the root of the dataset
150
+ Defaults to ``"./datasets/visa"``.
151
+ category (str): Category of the Visa dataset such as ``candle``.
152
+ Defaults to ``"candle"``.
153
+ train_batch_size (int, optional): Training batch size.
154
+ Defaults to ``32``.
155
+ eval_batch_size (int, optional): Test batch size.
156
+ Defaults to ``32``.
157
+ num_workers (int, optional): Number of workers.
158
+ Defaults to ``8``.
159
+ task (TaskType): Task type, 'classification', 'detection' or 'segmentation'
160
+ Defaults to ``TaskType.SEGMENTATION``.
161
+ image_size (tuple[int, int], optional): Size to which input images should be resized.
162
+ Defaults to ``None``.
163
+ transform (Transform, optional): Transforms that should be applied to the input images.
164
+ Defaults to ``None``.
165
+ train_transform (Transform, optional): Transforms that should be applied to the input images during training.
166
+ Defaults to ``None``.
167
+ eval_transform (Transform, optional): Transforms that should be applied to the input images during evaluation.
168
+             Defaults to ``None``.
+         test_split_mode (TestSplitMode): Setting that determines how the testing subset is obtained.
+             Defaults to ``TestSplitMode.FROM_DIR``.
+         test_split_ratio (float): Fraction of images from the train set that will be reserved for testing.
+             Defaults to ``0.2``.
+         val_split_mode (ValSplitMode): Setting that determines how the validation subset is obtained.
+             Defaults to ``ValSplitMode.SAME_AS_TEST``.
+         val_split_ratio (float): Fraction of train or test images that will be reserved for validation.
+             Defaults to ``0.5``.
+         seed (int | None, optional): Seed which may be set to a fixed value for reproducibility.
+             Defaults to ``None``.
+     """
+ 
+     def __init__(
+         self,
+         root: Path | str = "./datasets/visa",
+         category: str = "capsules",
+         train_batch_size: int = 32,
+         eval_batch_size: int = 32,
+         num_workers: int = 8,
+         task: TaskType | str = TaskType.SEGMENTATION,
+         image_size: tuple[int, int] | None = None,
+         transform: Transform | None = None,
+         train_transform: Transform | None = None,
+         eval_transform: Transform | None = None,
+         test_split_mode: TestSplitMode | str = TestSplitMode.FROM_DIR,
+         test_split_ratio: float = 0.2,
+         val_split_mode: ValSplitMode | str = ValSplitMode.SAME_AS_TEST,
+         val_split_ratio: float = 0.5,
+         seed: int | None = None,
+     ) -> None:
+         super().__init__(
+             train_batch_size=train_batch_size,
+             eval_batch_size=eval_batch_size,
+             num_workers=num_workers,
+             image_size=image_size,
+             transform=transform,
+             train_transform=train_transform,
+             eval_transform=eval_transform,
+             test_split_mode=test_split_mode,
+             test_split_ratio=test_split_ratio,
+             val_split_mode=val_split_mode,
+             val_split_ratio=val_split_ratio,
+             seed=seed,
+         )
+ 
+         self.task = TaskType(task)
+         self.root = Path(root)
+         self.split_root = self.root / "visa_pytorch"
+         self.category = category
+ 
+     def _setup(self, _stage: str | None = None) -> None:
+         self.train_data = VisaDataset(
+             task=self.task,
+             transform=self.train_transform,
+             split=Split.TRAIN,
+             root=self.split_root,
+             category=self.category,
+         )
+         self.test_data = VisaDataset(
+             task=self.task,
+             transform=self.eval_transform,
+             split=Split.TEST,
+             root=self.split_root,
+             category=self.category,
+         )
+ 
+     def prepare_data(self) -> None:
+         """Download the dataset if not available.
+ 
+         This method checks if the specified dataset is available in the file system.
+         If not, it downloads and extracts the dataset into the appropriate directory.
+ 
+         Example:
+             Assume the dataset is not available on the file system.
+             Here's how the directory structure looks before and after calling the
+             `prepare_data` method:
+ 
+             Before:
+ 
+             .. code-block:: bash
+ 
+                 $ tree datasets
+                 datasets
+                 ├── dataset1
+                 └── dataset2
+ 
+             Calling the method:
+ 
+             .. code-block:: python
+ 
+                 >>> datamodule = Visa()
+                 >>> datamodule.prepare_data()
+ 
+             After:
+ 
+             .. code-block:: bash
+ 
+                 $ tree datasets
+                 datasets
+                 ├── dataset1
+                 ├── dataset2
+                 └── visa
+                     ├── candle
+                     ├── ...
+                     ├── pipe_fryum
+                     │   ├── Data
+                     │   └── image_anno.csv
+                     ├── split_csv
+                     │   ├── 1cls.csv
+                     │   ├── 2cls_fewshot.csv
+                     │   └── 2cls_highshot.csv
+                     ├── VisA_20220922.tar
+                     └── visa_pytorch
+                         ├── candle
+                         ├── ...
+                         ├── pcb4
+                         └── pipe_fryum
+ 
+         ``prepare_data`` ensures that the dataset is converted to MVTec
+         format. ``visa_pytorch`` is the directory that contains the dataset
+         in the MVTec format. ``visa`` is the directory that contains the
+         original dataset.
+         """
+         if (self.split_root / self.category).is_dir():
+             # dataset is available, and split has been applied
+             logger.info("Found the dataset and train/test split.")
+         elif (self.root / self.category).is_dir():
+             # dataset is available, but split has not yet been applied
+             logger.info("Found the dataset. Applying train/test split.")
+             self.apply_cls1_split()
+         else:
+             # dataset is not available
+             download_and_extract(self.root, DOWNLOAD_INFO)
+             logger.info("Downloaded the dataset. Applying train/test split.")
+             self.apply_cls1_split()
+ 
+     def apply_cls1_split(self) -> None:
+         """Apply the 1-class subset splitting using the fixed split in the csv file.
+ 
+         Adapted from https://github.com/amazon-science/spot-diff.
+         """
+         logger.info("preparing data")
+         categories = [
+             "candle",
+             "capsules",
+             "cashew",
+             "chewinggum",
+             "fryum",
+             "macaroni1",
+             "macaroni2",
+             "pcb1",
+             "pcb2",
+             "pcb3",
+             "pcb4",
+             "pipe_fryum",
+         ]
+ 
+         split_file = self.root / "split_csv" / "1cls.csv"
+ 
+         for category in categories:
+             train_folder = self.split_root / category / "train"
+             test_folder = self.split_root / category / "test"
+             mask_folder = self.split_root / category / "ground_truth"
+ 
+             train_img_good_folder = train_folder / "good"
+             test_img_good_folder = test_folder / "good"
+             test_img_bad_folder = test_folder / "bad"
+             test_mask_bad_folder = mask_folder / "bad"
+ 
+             train_img_good_folder.mkdir(parents=True, exist_ok=True)
+             test_img_good_folder.mkdir(parents=True, exist_ok=True)
+             test_img_bad_folder.mkdir(parents=True, exist_ok=True)
+             test_mask_bad_folder.mkdir(parents=True, exist_ok=True)
+ 
+         with split_file.open(encoding="utf-8") as file:
+             csvreader = csv.reader(file)
+             next(csvreader)
+             for row in csvreader:
+                 category, split, label, image_path, mask_path = row
+                 label = "good" if label == "normal" else "bad"
+                 image_name = image_path.split("/")[-1]
+                 mask_name = mask_path.split("/")[-1]
+ 
+                 img_src_path = self.root / image_path
+                 msk_src_path = self.root / mask_path
+                 img_dst_path = self.split_root / category / split / label / image_name
+                 msk_dst_path = self.split_root / category / "ground_truth" / label / mask_name
+ 
+                 shutil.copyfile(img_src_path, img_dst_path)
+                 if split == "test" and label == "bad":
+                     mask = cv2.imread(str(msk_src_path))
+ 
+                     # binarize mask
+                     mask[mask != 0] = 255
+ 
+                     cv2.imwrite(str(msk_dst_path), mask)
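Reviewer note: a minimal usage sketch of the datamodule above. The category and root are illustrative defaults, and the import path assumes the usual re-export from `anomalib.data`; nothing below is part of the diff itself.

from anomalib.data import Visa

# Downloads VisA on first use, applies the fixed 1-class split, then
# exposes the usual Lightning dataloaders.
datamodule = Visa(root="./datasets/visa", category="candle")
datamodule.prepare_data()   # download + apply_cls1_split if needed
datamodule.setup()          # builds train_data / test_data
batch = next(iter(datamodule.train_dataloader()))
print(batch["image"].shape)  # e.g. torch.Size([32, 3, H, W])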
anomalib/data/predict.py ADDED
@@ -0,0 +1,52 @@
+ """Inference Dataset."""
+ 
+ # Copyright (C) 2022-2024 Intel Corporation
+ # SPDX-License-Identifier: Apache-2.0
+ 
+ 
+ from pathlib import Path
+ from typing import Any
+ 
+ from torch.utils.data.dataset import Dataset
+ from torchvision.transforms.v2 import Transform
+ 
+ from anomalib.data.utils import get_image_filenames, read_image
+ 
+ 
+ class PredictDataset(Dataset):
+     """Inference Dataset to perform prediction.
+ 
+     Args:
+         path (str | Path): Path to an image or image-folder.
+         transform (Transform | None, optional): Transform object describing the transforms that are
+             applied to the inputs. Defaults to ``None``.
+         image_size (int | tuple[int, int], optional): Target image size
+             to resize the original image. Defaults to ``(256, 256)``.
+     """
+ 
+     def __init__(
+         self,
+         path: str | Path,
+         transform: Transform | None = None,
+         image_size: int | tuple[int, int] = (256, 256),
+     ) -> None:
+         super().__init__()
+ 
+         self.image_filenames = get_image_filenames(path)
+         self.transform = transform
+         self.image_size = image_size
+ 
+     def __len__(self) -> int:
+         """Get the number of images in the given path."""
+         return len(self.image_filenames)
+ 
+     def __getitem__(self, index: int) -> dict[str, Any]:
+         """Get the image based on the `index`."""
+         image_filename = self.image_filenames[index]
+         image = read_image(image_filename, as_tensor=True)
+         if self.transform:
+             image = self.transform(image)
+         pre_processed = {"image": image, "image_path": str(image_filename)}
+ 
+         return pre_processed
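Reviewer note: a short sketch of how `PredictDataset` is typically consumed; the folder path is hypothetical.

from torch.utils.data import DataLoader

from anomalib.data.predict import PredictDataset

dataset = PredictDataset(path="./my_images")  # hypothetical image folder
loader = DataLoader(dataset, batch_size=4, shuffle=False)
for batch in loader:
    # default collate stacks the image tensors and lists the path strings
    print(batch["image"].shape, batch["image_path"][0])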
anomalib/data/transforms/__init__.py ADDED
@@ -0,0 +1,8 @@
+ """Custom input transforms for Anomalib."""
+ 
+ # Copyright (C) 2024 Intel Corporation
+ # SPDX-License-Identifier: Apache-2.0
+ 
+ from .center_crop import ExportableCenterCrop
+ 
+ __all__ = ["ExportableCenterCrop"]
anomalib/data/transforms/center_crop.py ADDED
@@ -0,0 +1,87 @@
+ """Custom Torchvision transforms for Anomalib."""
+ 
+ # Original Code
+ # Copyright (c) Soumith Chintala 2016
+ # https://github.com/pytorch/vision/blob/v0.16.1/torchvision/transforms/v2/functional/_geometry.py
+ # SPDX-License-Identifier: BSD-3-Clause
+ #
+ # Modified
+ # Copyright (C) 2024 Intel Corporation
+ # SPDX-License-Identifier: Apache-2.0
+ 
+ from typing import Any
+ 
+ import torch
+ from torch.nn.functional import pad
+ from torchvision.transforms.v2 import Transform
+ from torchvision.transforms.v2.functional._geometry import (
+     _center_crop_compute_padding,
+     _center_crop_parse_output_size,
+     _parse_pad_padding,
+ )
+ 
+ 
+ def _center_crop_compute_crop_anchor(
+     crop_height: int,
+     crop_width: int,
+     image_height: int,
+     image_width: int,
+ ) -> tuple[int, int]:
+     """Compute the anchor point for center-cropping.
+ 
+     This function is a modified version of the torchvision.transforms.functional._center_crop_compute_crop_anchor
+     function. The original function uses `round` to compute the anchor point, which is not compatible with ONNX.
+ 
+     Args:
+         crop_height (int): Desired height of the crop.
+         crop_width (int): Desired width of the crop.
+         image_height (int): Height of the input image.
+         image_width (int): Width of the input image.
+ 
+     Returns:
+         tuple[int, int]: Top and left coordinates of the crop anchor.
+     """
+     crop_top = torch.tensor((image_height - crop_height) / 2.0).round().int().item()
+     crop_left = torch.tensor((image_width - crop_width) / 2.0).round().int().item()
+     return crop_top, crop_left
+ 
+ 
+ def center_crop_image(image: torch.Tensor, output_size: list[int]) -> torch.Tensor:
+     """Apply center-cropping to an input image.
+ 
+     Uses the modified anchor point computation function to compute the anchor point for center-cropping.
+ 
+     Args:
+         image (torch.Tensor): Input image to be center-cropped.
+         output_size (list[int]): Desired output size of the crop.
+     """
+     crop_height, crop_width = _center_crop_parse_output_size(output_size)
+     shape = image.shape
+     if image.numel() == 0:
+         return image.reshape(shape[:-2] + (crop_height, crop_width))
+     image_height, image_width = shape[-2:]
+ 
+     if crop_height > image_height or crop_width > image_width:
+         padding_ltrb = _center_crop_compute_padding(crop_height, crop_width, image_height, image_width)
+         image = pad(image, _parse_pad_padding(padding_ltrb), value=0.0)
+ 
+         image_height, image_width = image.shape[-2:]
+         if crop_width == image_width and crop_height == image_height:
+             return image
+ 
+     crop_top, crop_left = _center_crop_compute_crop_anchor(crop_height, crop_width, image_height, image_width)
+     return image[..., crop_top : (crop_top + crop_height), crop_left : (crop_left + crop_width)]
+ 
+ 
+ class ExportableCenterCrop(Transform):
+     """Transform that applies center-cropping to an input image and can be exported to ONNX.
+ 
+     Args:
+         size (int | tuple[int, int]): Desired output size of the crop.
+     """
+ 
+     def __init__(self, size: int | tuple[int, int]) -> None:
+         super().__init__()
+         self.size = list(size) if isinstance(size, tuple) else [size, size]
+ 
+     def _transform(self, inpt: torch.Tensor, params: dict[str, Any]) -> torch.Tensor:
+         """Apply the transform."""
+         del params
+         return center_crop_image(inpt, output_size=self.size)
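Reviewer note: the value of this transform is that the crop anchor is computed with tensor ops instead of Python's `round`, so it traces cleanly. A sketch under that assumption (output filename arbitrary; exporting a bare `Transform` module this way is untested here):

import torch

from anomalib.data.transforms import ExportableCenterCrop

crop = ExportableCenterCrop(size=224)
dummy = torch.rand(1, 3, 256, 256)
assert crop(dummy).shape == (1, 3, 224, 224)

# Tracing succeeds because the anchor computation avoids Python's `round`.
torch.onnx.export(crop, dummy, "center_crop.onnx")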
anomalib/data/utils/__init__.py ADDED
@@ -0,0 +1,56 @@
+ """Helper utilities for data."""
+ 
+ # Copyright (C) 2022-2024 Intel Corporation
+ # SPDX-License-Identifier: Apache-2.0
+ 
+ from .augmenter import Augmenter
+ from .boxes import boxes_to_anomaly_maps, boxes_to_masks, masks_to_boxes
+ from .download import DownloadInfo, download_and_extract
+ from .generators import random_2d_perlin
+ from .image import (
+     generate_output_image_filename,
+     get_image_filenames,
+     get_image_height_and_width,
+     read_depth_image,
+     read_image,
+     read_mask,
+ )
+ from .label import LabelName
+ from .path import (
+     DirType,
+     _check_and_convert_path,
+     _prepare_files_labels,
+     resolve_path,
+     validate_and_resolve_path,
+     validate_path,
+ )
+ from .split import Split, TestSplitMode, ValSplitMode, concatenate_datasets, random_split, split_by_label
+ 
+ __all__ = [
+     "generate_output_image_filename",
+     "get_image_filenames",
+     "get_image_height_and_width",
+     "random_2d_perlin",
+     "read_image",
+     "read_mask",
+     "read_depth_image",
+     "random_split",
+     "split_by_label",
+     "concatenate_datasets",
+     "Split",
+     "ValSplitMode",
+     "TestSplitMode",
+     "LabelName",
+     "DirType",
+     "Augmenter",
+     "masks_to_boxes",
+     "boxes_to_masks",
+     "boxes_to_anomaly_maps",
+     "download_and_extract",
+     "DownloadInfo",
+     "_check_and_convert_path",
+     "_prepare_files_labels",
+     "resolve_path",
+     "validate_path",
+     "validate_and_resolve_path",
+ ]
anomalib/data/utils/augmenter.py ADDED
@@ -0,0 +1,172 @@
+ """Augmenter module to generate out-of-distribution samples for the DRAEM implementation."""
+ 
+ # Original Code
+ # Copyright (c) 2021 VitjanZ
+ # https://github.com/VitjanZ/DRAEM.
+ # SPDX-License-Identifier: MIT
+ #
+ # Modified
+ # Copyright (C) 2022-2024 Intel Corporation
+ # SPDX-License-Identifier: Apache-2.0
+ 
+ 
+ import math
+ import random
+ from pathlib import Path
+ 
+ import cv2
+ import imgaug.augmenters as iaa
+ import numpy as np
+ import torch
+ from PIL import Image
+ from torchvision.datasets.folder import IMG_EXTENSIONS
+ 
+ from anomalib.data.utils.generators.perlin import random_2d_perlin
+ 
+ 
+ def nextpow2(value: int) -> int:
+     """Return the smallest power of 2 greater than or equal to the input value."""
+     return 2 ** (math.ceil(math.log(value, 2)))
+ 
+ 
+ class Augmenter:
+     """Class that generates noisy augmentations of input images.
+ 
+     Args:
+         anomaly_source_path (str | None): Path to a folder of images that will be used as source of the anomalous
+             noise. If not specified, random noise will be used instead.
+         p_anomalous (float): Probability that the anomalous perturbation will be applied to a given image.
+         beta (float): Parameter that determines the opacity of the noise mask.
+     """
+ 
+     def __init__(
+         self,
+         anomaly_source_path: str | None = None,
+         p_anomalous: float = 0.5,
+         beta: float | tuple[float, float] = (0.2, 1.0),
+     ) -> None:
+         self.p_anomalous = p_anomalous
+         self.beta = beta
+ 
+         self.anomaly_source_paths: list[Path] = []
+         if anomaly_source_path is not None:
+             for img_ext in IMG_EXTENSIONS:
+                 self.anomaly_source_paths.extend(Path(anomaly_source_path).rglob("*" + img_ext))
+ 
+         self.augmenters = [
+             iaa.GammaContrast((0.5, 2.0), per_channel=True),
+             iaa.MultiplyAndAddToBrightness(mul=(0.8, 1.2), add=(-30, 30)),
+             iaa.pillike.EnhanceSharpness(),
+             iaa.AddToHueAndSaturation((-50, 50), per_channel=True),
+             iaa.Solarize(0.5, threshold=(32, 128)),
+             iaa.Posterize(),
+             iaa.Invert(),
+             iaa.pillike.Autocontrast(),
+             iaa.pillike.Equalize(),
+             iaa.Affine(rotate=(-45, 45)),
+         ]
+         self.rot = iaa.Sequential([iaa.Affine(rotate=(-90, 90))])
+ 
+     def rand_augmenter(self) -> iaa.Sequential:
+         """Select 3 random transforms that will be applied to the anomaly source images.
+ 
+         Returns:
+             A selection of 3 transforms.
+         """
+         aug_ind = np.random.default_rng().choice(np.arange(len(self.augmenters)), 3, replace=False)
+         return iaa.Sequential([self.augmenters[aug_ind[0]], self.augmenters[aug_ind[1]], self.augmenters[aug_ind[2]]])
+ 
+     def generate_perturbation(
+         self,
+         height: int,
+         width: int,
+         anomaly_source_path: Path | str | None = None,
+     ) -> tuple[np.ndarray, np.ndarray]:
+         """Generate an image containing a random anomalous perturbation using a source image.
+ 
+         Args:
+             height (int): Height of the generated image.
+             width (int): Width of the generated image.
+             anomaly_source_path (Path | str | None): Path to an image file. If not provided, random noise will be used
+                 instead.
+ 
+         Returns:
+             Image containing a random anomalous perturbation, and the corresponding ground truth anomaly mask.
+         """
+         # Generate random Perlin noise
+         perlin_scale = 6
+         min_perlin_scale = 0
+ 
+         perlin_scalex = 2 ** np.random.default_rng().integers(min_perlin_scale, perlin_scale)
+         perlin_scaley = 2 ** np.random.default_rng().integers(min_perlin_scale, perlin_scale)
+ 
+         perlin_noise = random_2d_perlin((nextpow2(height), nextpow2(width)), (perlin_scalex, perlin_scaley))[
+             :height,
+             :width,
+         ]
+         perlin_noise = self.rot(image=perlin_noise)
+ 
+         # Create mask from Perlin noise
+         mask = np.where(perlin_noise > 0.5, np.ones_like(perlin_noise), np.zeros_like(perlin_noise))
+         mask = np.expand_dims(mask, axis=2).astype(np.float32)
+ 
+         # Load anomaly source image
+         if anomaly_source_path:
+             anomaly_source_img = np.array(Image.open(anomaly_source_path))
+             anomaly_source_img = cv2.resize(anomaly_source_img, dsize=(width, height))
+         else:  # if no anomaly source is specified, we use the Perlin noise as anomalous source
+             anomaly_source_img = np.expand_dims(perlin_noise, 2).repeat(3, 2)
+             anomaly_source_img = (anomaly_source_img * 255).astype(np.uint8)
+ 
+         # Augment anomaly source image
+         aug = self.rand_augmenter()
+         anomaly_img_augmented = aug(image=anomaly_source_img)
+ 
+         # Create anomalous perturbation that we will apply to the image
+         perturbation = anomaly_img_augmented.astype(np.float32) * mask / 255.0
+ 
+         return perturbation, mask
+ 
+     def augment_batch(self, batch: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
+         """Generate anomalous augmentations for a batch of input images.
+ 
+         Args:
+             batch (torch.Tensor): Batch of input images.
+ 
+         Returns:
+             - Augmented image to which anomalous perturbations have been added.
+             - Ground truth masks corresponding to the anomalous perturbations.
+         """
+         batch_size, channels, height, width = batch.shape
+ 
+         # Collect perturbations
+         perturbations_list = []
+         masks_list = []
+         for _ in range(batch_size):
+             if torch.rand(1) > self.p_anomalous:  # include normal samples
+                 perturbations_list.append(torch.zeros((channels, height, width)))
+                 masks_list.append(torch.zeros((1, height, width)))
+             else:
+                 anomaly_source_path = (
+                     random.sample(self.anomaly_source_paths, 1)[0] if len(self.anomaly_source_paths) > 0 else None
+                 )
+                 perturbation, mask = self.generate_perturbation(height, width, anomaly_source_path)
+                 perturbations_list.append(torch.Tensor(perturbation).permute((2, 0, 1)))
+                 masks_list.append(torch.Tensor(mask).permute((2, 0, 1)))
+ 
+         perturbations = torch.stack(perturbations_list).to(batch.device)
+         masks = torch.stack(masks_list).to(batch.device)
+ 
+         # Apply perturbations batch wise
+         if isinstance(self.beta, float):
+             beta = self.beta
+         elif isinstance(self.beta, tuple):
+             beta = torch.rand(batch_size) * (self.beta[1] - self.beta[0]) + self.beta[0]
+             beta = beta.view(batch_size, 1, 1, 1).expand_as(batch).to(batch.device)  # type: ignore[attr-defined]
+         else:
+             msg = "Beta must be either float or tuple of floats"
+             raise TypeError(msg)
+ 
+         augmented_batch = batch * (1 - masks) + beta * perturbations + (1 - beta) * batch * masks
+ 
+         return augmented_batch, masks
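Reviewer note: a minimal sketch of the augmenter in isolation. With `p_anomalous=1.0` every sample is perturbed, and without `anomaly_source_path` the Perlin noise itself serves as the anomaly source.

import torch

from anomalib.data.utils.augmenter import Augmenter

augmenter = Augmenter(p_anomalous=1.0, beta=(0.2, 1.0))
batch = torch.rand(4, 3, 256, 256)
augmented, masks = augmenter.augment_batch(batch)
print(augmented.shape, masks.shape)  # (4, 3, 256, 256) and (4, 1, 256, 256)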
anomalib/data/utils/boxes.py ADDED
@@ -0,0 +1,117 @@
+ """Helper functions for processing bounding box detections and annotations."""
+ 
+ # Copyright (C) 2022-2024 Intel Corporation
+ # SPDX-License-Identifier: Apache-2.0
+ 
+ 
+ import torch
+ 
+ from anomalib.utils.cv import connected_components_cpu, connected_components_gpu
+ 
+ 
+ def masks_to_boxes(
+     masks: torch.Tensor,
+     anomaly_maps: torch.Tensor | None = None,
+ ) -> tuple[list[torch.Tensor], list[torch.Tensor]]:
+     """Convert a batch of segmentation masks to bounding box coordinates.
+ 
+     Args:
+         masks (torch.Tensor): Input tensor of shape (B, 1, H, W), (B, H, W) or (H, W).
+         anomaly_maps (Tensor | None, optional): Anomaly maps of shape (B, 1, H, W), (B, H, W) or (H, W) which are
+             used to determine an anomaly score for the converted bounding boxes.
+ 
+     Returns:
+         list[torch.Tensor]: A list of length B where each element is a tensor of shape (N, 4)
+             containing the bounding box coordinates of the objects in the masks in xyxy format.
+         list[torch.Tensor]: A list of length B where each element is a tensor of length (N)
+             containing an anomaly score for each of the converted boxes.
+     """
+     height, width = masks.shape[-2:]
+     masks = masks.view((-1, 1, height, width)).float()  # reshape to (B, 1, H, W) and cast to float
+     if anomaly_maps is not None:
+         anomaly_maps = anomaly_maps.view((-1,) + masks.shape[-2:])
+ 
+     if masks.is_cpu:
+         batch_comps = connected_components_cpu(masks).squeeze(1)
+     else:
+         batch_comps = connected_components_gpu(masks).squeeze(1)
+ 
+     batch_boxes = []
+     batch_scores = []
+     for im_idx, im_comps in enumerate(batch_comps):
+         labels = torch.unique(im_comps)
+         im_boxes = []
+         im_scores = []
+         for label in labels[labels != 0]:
+             y_loc, x_loc = torch.where(im_comps == label)
+             # add box
+             box = torch.Tensor([torch.min(x_loc), torch.min(y_loc), torch.max(x_loc), torch.max(y_loc)]).to(
+                 masks.device,
+             )
+             im_boxes.append(box)
+             if anomaly_maps is not None:
+                 im_scores.append(torch.max(anomaly_maps[im_idx, y_loc, x_loc]))
+         batch_boxes.append(torch.stack(im_boxes) if im_boxes else torch.empty((0, 4), device=masks.device))
+         batch_scores.append(torch.stack(im_scores) if im_scores else torch.empty(0, device=masks.device))
+ 
+     return batch_boxes, batch_scores
+ 
+ 
+ def boxes_to_masks(boxes: list[torch.Tensor], image_size: tuple[int, int]) -> torch.Tensor:
+     """Convert bounding boxes to segmentation masks.
+ 
+     Args:
+         boxes (list[torch.Tensor]): A list of length B where each element is a tensor of shape (N, 4)
+             containing the bounding box coordinates of the regions of interest in xyxy format.
+         image_size (tuple[int, int]): Image size of the output masks in (H, W) format.
+ 
+     Returns:
+         Tensor: torch.Tensor of shape (B, H, W) in which each slice is a binary mask showing the pixels contained by a
+             bounding box.
+     """
+     masks = torch.zeros((len(boxes), *image_size)).to(boxes[0].device)
+     for im_idx, im_boxes in enumerate(boxes):
+         for box in im_boxes:
+             x_1, y_1, x_2, y_2 = box.int()
+             masks[im_idx, y_1 : y_2 + 1, x_1 : x_2 + 1] = 1
+     return masks
+ 
+ 
+ def boxes_to_anomaly_maps(boxes: torch.Tensor, scores: torch.Tensor, image_size: tuple[int, int]) -> torch.Tensor:
+     """Convert bounding box coordinates to anomaly heatmaps.
+ 
+     Args:
+         boxes (list[torch.Tensor]): A list of length B where each element is a tensor of shape (N, 4)
+             containing the bounding box coordinates of the regions of interest in xyxy format.
+         scores (list[torch.Tensor]): A list of length B where each element is a 1D tensor of length N
+             containing the anomaly scores for each region of interest.
+         image_size (tuple[int, int]): Image size of the output masks in (H, W) format.
+ 
+     Returns:
+         Tensor: torch.Tensor of shape (B, H, W). The pixel locations within each bounding box are collectively
+             assigned the anomaly score of the bounding box. In the case of overlapping bounding boxes,
+             the highest score is used.
+     """
+     anomaly_maps = torch.zeros((len(boxes), *image_size)).to(boxes[0].device)
+     for im_idx, (im_boxes, im_scores) in enumerate(zip(boxes, scores, strict=False)):
+         im_map = torch.zeros((im_boxes.shape[0], *image_size))
+         for box_idx, (box, score) in enumerate(zip(im_boxes, im_scores, strict=True)):
+             x_1, y_1, x_2, y_2 = box.int()
+             im_map[box_idx, y_1 : y_2 + 1, x_1 : x_2 + 1] = score
+         anomaly_maps[im_idx], _ = im_map.max(dim=0)
+     return anomaly_maps
+ 
+ 
+ def scale_boxes(boxes: torch.Tensor, image_size: torch.Size, new_size: torch.Size) -> torch.Tensor:
+     """Scale bbox coordinates to a new image size.
+ 
+     Args:
+         boxes (torch.Tensor): Boxes of shape (N, 4) - (x1, y1, x2, y2).
+         image_size (Size): Size of the original image in which the bbox coordinates were retrieved.
+         new_size (Size): New image size to which the bbox coordinates will be scaled.
+ 
+     Returns:
+         Tensor: Updated boxes of shape (N, 4) - (x1, y1, x2, y2).
+     """
+     scale = torch.Tensor([*new_size]) / torch.Tensor([*image_size])
+     return boxes * scale.repeat(2).to(boxes.device)
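Reviewer note: a small round-trip sketch of these helpers; all shapes and scores are made up.

import torch

from anomalib.data.utils.boxes import boxes_to_anomaly_maps, boxes_to_masks, masks_to_boxes

# One 32x32 mask with a single 10x10 blob.
masks = torch.zeros(1, 32, 32)
masks[0, 5:15, 5:15] = 1

boxes, _ = masks_to_boxes(masks)  # [[ 5.,  5., 14., 14.]]
recovered = boxes_to_masks(boxes, image_size=(32, 32))
score_maps = boxes_to_anomaly_maps(boxes, [torch.tensor([0.9])], image_size=(32, 32))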
anomalib/data/utils/download.py ADDED
@@ -0,0 +1,364 @@
+ """Helper to show progress bars with `urlretrieve`, check hash of file."""
+ 
+ # Copyright (C) 2022-2024 Intel Corporation
+ # SPDX-License-Identifier: Apache-2.0
+ 
+ 
+ import hashlib
+ import io
+ import logging
+ import os
+ import re
+ import tarfile
+ from collections.abc import Iterable
+ from dataclasses import dataclass
+ from pathlib import Path
+ from tarfile import TarFile, TarInfo
+ from urllib.request import urlretrieve
+ from zipfile import ZipFile
+ 
+ from tqdm import tqdm
+ 
+ logger = logging.getLogger(__name__)
+ 
+ 
+ @dataclass
+ class DownloadInfo:
+     """Info needed to download a dataset from a url."""
+ 
+     name: str
+     url: str
+     hashsum: str
+     filename: str | None = None
+ 
+ 
+ class DownloadProgressBar(tqdm):
+     """Create progress bar for urlretrieve. Subclasses `tqdm`.
+ 
+     For information about the parameters in constructor, refer to `tqdm`'s documentation.
+ 
+     Args:
+         iterable (Iterable | None): Iterable to decorate with a progressbar.
+             Leave blank to manually manage the updates.
+         desc (str | None): Prefix for the progressbar.
+         total (int | float | None): The number of expected iterations. If unspecified,
+             len(iterable) is used if possible. If float("inf") or as a last
+             resort, only basic progress statistics are displayed
+             (no ETA, no progressbar).
+             If `gui` is True and this parameter needs subsequent updating,
+             specify an initial arbitrary large positive number, e.g. 9e9.
+         leave (bool | None): Whether to keep all traces of the progressbar upon termination
+             of iteration. If `None`, will leave only if `position` is `0`.
+         file (io.TextIOWrapper | io.StringIO | None): Specifies where to output the progress messages
+             (default: sys.stderr). Uses `file.write(str)` and `file.flush()`
+             methods. For encoding, see `write_bytes`.
+         ncols (int | None): The width of the entire output message. If specified,
+             dynamically resizes the progressbar to stay within this bound.
+             If unspecified, attempts to use environment width. The
+             fallback is a meter width of 10 and no limit for the counter and
+             statistics. If 0, will not print any meter (only stats).
+         mininterval (float | None): Minimum progress display update interval [default: 0.1] seconds.
+         maxinterval (float | None): Maximum progress display update interval [default: 10] seconds.
+             Automatically adjusts `miniters` to correspond to `mininterval`
+             after long display update lag. Only works if `dynamic_miniters`
+             or monitor thread is enabled.
+         miniters (int | float | None): Minimum progress display update interval, in iterations.
+             If 0 and `dynamic_miniters`, will automatically adjust to equal
+             `mininterval` (more CPU efficient, good for tight loops).
+             If > 0, will skip display of specified number of iterations.
+             Tweak this and `mininterval` to get very efficient loops.
+             If your progress is erratic with both fast and slow iterations
+             (network, skipping items, etc) you should set miniters=1.
+         use_ascii (str | bool | None): If unspecified or False, use unicode (smooth blocks) to fill
+             the meter. The fallback is to use ASCII characters " 123456789#".
+         disable (bool | None): Whether to disable the entire progressbar wrapper
+             [default: False]. If set to None, disable on non-TTY.
+         unit (str | None): String that will be used to define the unit of each iteration
+             [default: it].
+         unit_scale (int | float | bool): If 1 or True, the number of iterations will be reduced/scaled
+             automatically and a metric prefix following the
+             International System of Units standard will be added
+             (kilo, mega, etc.) [default: False]. If any other non-zero
+             number, will scale `total` and `n`.
+         dynamic_ncols (bool | None): If set, constantly alters `ncols` and `nrows` to the
+             environment (allowing for window resizes) [default: False].
+         smoothing (float | None): Exponential moving average smoothing factor for speed estimates
+             (ignored in GUI mode). Ranges from 0 (average speed) to 1
+             (current/instantaneous speed) [default: 0.3].
+         bar_format (str | None): Specify a custom bar string formatting. May impact performance.
+             [default: '{l_bar}{bar}{r_bar}'], where
+             l_bar='{desc}: {percentage:3.0f}%|' and
+             r_bar='| {n_fmt}/{total_fmt} [{elapsed}<{remaining}, '
+             '{rate_fmt}{postfix}]'
+             Possible vars: l_bar, bar, r_bar, n, n_fmt, total, total_fmt,
+             percentage, elapsed, elapsed_s, ncols, nrows, desc, unit,
+             rate, rate_fmt, rate_noinv, rate_noinv_fmt,
+             rate_inv, rate_inv_fmt, postfix, unit_divisor,
+             remaining, remaining_s, eta.
+             Note that a trailing ": " is automatically removed after {desc}
+             if the latter is empty.
+         initial (int | float | None): The initial counter value. Useful when restarting a progress
+             bar [default: 0]. If using float, consider specifying `{n:.3f}`
+             or similar in `bar_format`, or specifying `unit_scale`.
+         position (int | None): Specify the line offset to print this bar (starting from 0).
+             Automatic if unspecified.
+             Useful to manage multiple bars at once (eg, from threads).
+         postfix (dict | None): Specify additional stats to display at the end of the bar.
+             Calls `set_postfix(**postfix)` if possible (dict).
+         unit_divisor (float | None): [default: 1000], ignored unless `unit_scale` is True.
+         write_bytes (bool | None): If (default: None) and `file` is unspecified,
+             bytes will be written in Python 2. If `True` will also write
+             bytes. In all other cases will default to unicode.
+         lock_args (tuple | None): Passed to `refresh` for intermediate output
+             (initialisation, iterating, and updating).
+         nrows (int | None): The screen height. If specified, hides nested bars
+             outside this bound. If unspecified, attempts to use environment height.
+             The fallback is 20.
+         colour (str | None): Bar colour (e.g. 'green', '#00ff00').
+         delay (float | None): Don't display until [default: 0] seconds have elapsed.
+         gui (bool | None): WARNING: internal parameter - do not use.
+             Use tqdm.gui.tqdm(...) instead. If set, will attempt to use
+             matplotlib animations for a graphical output [default: False].
+ 
+     Example:
+         >>> with DownloadProgressBar(unit='B', unit_scale=True, miniters=1, desc=url.split('/')[-1]) as p_bar:
+         ...     urllib.request.urlretrieve(url, filename=output_path, reporthook=p_bar.update_to)
+     """
+ 
+     def __init__(
+         self,
+         iterable: Iterable | None = None,
+         desc: str | None = None,
+         total: int | float | None = None,
+         leave: bool | None = True,
+         file: io.TextIOWrapper | io.StringIO | None = None,
+         ncols: int | None = None,
+         mininterval: float | None = 0.1,
+         maxinterval: float | None = 10.0,
+         miniters: int | float | None = None,
+         use_ascii: bool | str | None = None,
+         disable: bool | None = False,
+         unit: str | None = "it",
+         unit_scale: bool | int | float | None = False,
+         dynamic_ncols: bool | None = False,
+         smoothing: float | None = 0.3,
+         bar_format: str | None = None,
+         initial: int | float | None = 0,
+         position: int | None = None,
+         postfix: dict | None = None,
+         unit_divisor: float | None = 1000,
+         write_bytes: bool | None = None,
+         lock_args: tuple | None = None,
+         nrows: int | None = None,
+         colour: str | None = None,
+         delay: float | None = 0,
+         gui: bool | None = False,
+         **kwargs,
+     ) -> None:
+         super().__init__(
+             iterable=iterable,
+             desc=desc,
+             total=total,
+             leave=leave,
+             file=file,
+             ncols=ncols,
+             mininterval=mininterval,
+             maxinterval=maxinterval,
+             miniters=miniters,
+             ascii=use_ascii,
+             disable=disable,
+             unit=unit,
+             unit_scale=unit_scale,
+             dynamic_ncols=dynamic_ncols,
+             smoothing=smoothing,
+             bar_format=bar_format,
+             initial=initial,
+             position=position,
+             postfix=postfix,
+             unit_divisor=unit_divisor,
+             write_bytes=write_bytes,
+             lock_args=lock_args,
+             nrows=nrows,
+             colour=colour,
+             delay=delay,
+             gui=gui,
+             **kwargs,
+         )
+         self.total: int | float | None
+ 
+     def update_to(self, chunk_number: int = 1, max_chunk_size: int = 1, total_size: int | None = None) -> None:
+         """Progress bar hook for tqdm.
+ 
+         Based on https://stackoverflow.com/a/53877507
+         The implementor does not have to bother about passing parameters to this as it gets them from urlretrieve.
+         However, the calling context needs a few parameters. Refer to the example.
+ 
+         Args:
+             chunk_number (int, optional): The current chunk being processed. Defaults to 1.
+             max_chunk_size (int, optional): Maximum size of each chunk. Defaults to 1.
+             total_size (int, optional): Total download size. Defaults to None.
+         """
+         if total_size is not None:
+             self.total = total_size
+         self.update(chunk_number * max_chunk_size - self.n)
+ 
+ 
+ def is_file_potentially_dangerous(file_name: str) -> bool:
+     """Check if a file is potentially dangerous.
+ 
+     Args:
+         file_name (str): Filename.
+ 
+     Returns:
+         bool: True if the member is potentially dangerous, False otherwise.
+     """
+     # Some example criteria. We could expand this.
+     unsafe_patterns = ["/etc/", "/root/"]
+     return any(re.search(pattern, file_name) for pattern in unsafe_patterns)
+ 
+ 
+ def safe_extract(tar_file: TarFile, root: Path, members: list[TarInfo]) -> None:
+     """Extract safe members from a tar archive.
+ 
+     Args:
+         tar_file (TarFile): TarFile object.
+         root (Path): Root directory where the dataset will be stored.
+         members (List[TarInfo]): List of safe members to be extracted.
+     """
+     for member in members:
+         tar_file.extract(member, root)
+ 
+ 
+ def generate_hash(file_path: str | Path, algorithm: str = "sha256") -> str:
+     """Generate a hash of a file using the specified algorithm.
+ 
+     Args:
+         file_path (str | Path): Path to the file to hash.
+         algorithm (str): The hashing algorithm to use (e.g., 'sha256', 'sha3_512').
+ 
+     Returns:
+         str: The hexadecimal hash string of the file.
+ 
+     Raises:
+         ValueError: If the specified hashing algorithm is not supported.
+     """
+     # Get the hashing algorithm.
+     try:
+         hasher = getattr(hashlib, algorithm)()
+     except AttributeError as err:
+         msg = f"Unsupported hashing algorithm: {algorithm}"
+         raise ValueError(msg) from err
+ 
+     # Read the file in chunks to avoid loading it all into memory
+     with Path(file_path).open("rb") as file:
+         for chunk in iter(lambda: file.read(4096), b""):
+             hasher.update(chunk)
+ 
+     # Return the computed hash value in hexadecimal format
+     return hasher.hexdigest()
+ 
+ 
+ def check_hash(file_path: Path, expected_hash: str, algorithm: str = "sha256") -> None:
+     """Raise value error if hash does not match the calculated hash of the file.
+ 
+     Args:
+         file_path (Path): Path to file.
+         expected_hash (str): Expected hash of the file.
+         algorithm (str): Hashing algorithm to use ('sha256', 'sha3_512', etc.).
+     """
+     # Compare the calculated hash with the expected hash
+     calculated_hash = generate_hash(file_path, algorithm)
+     if calculated_hash != expected_hash:
+         msg = (
+             f"Calculated hash {calculated_hash} of downloaded file {file_path} does not match the required hash "
+             f"{expected_hash}."
+         )
+         raise ValueError(msg)
+ 
+ 
+ def extract(file_name: Path, root: Path) -> None:
+     """Extract a dataset.
+ 
+     Args:
+         file_name (Path): Path of the file to be extracted.
+         root (Path): Root directory where the dataset will be stored.
+     """
+     logger.info("Extracting dataset into root folder.")
+ 
+     # Safely extract zip files
+     if file_name.suffix == ".zip":
+         with ZipFile(file_name, "r") as zip_file:
+             for file_info in zip_file.infolist():
+                 if not is_file_potentially_dangerous(file_info.filename):
+                     zip_file.extract(file_info, root)
+ 
+     # Safely extract tar files.
+     elif file_name.suffix in (".tar", ".gz", ".xz", ".tgz"):
+         with tarfile.open(file_name) as tar_file:
+             members = tar_file.getmembers()
+             safe_members = [member for member in members if not is_file_potentially_dangerous(member.name)]
+             safe_extract(tar_file, root, safe_members)
+ 
+     else:
+         msg = f"Unrecognized file format: {file_name}"
+         raise ValueError(msg)
+ 
+     logger.info("Cleaning up files.")
+     file_name.unlink()
+ 
+ 
+ def download_and_extract(root: Path, info: DownloadInfo) -> None:
+     """Download and extract a dataset.
+ 
+     Args:
+         root (Path): Root directory where the dataset will be stored.
+         info (DownloadInfo): Info needed to download the dataset.
+     """
+     root.mkdir(parents=True, exist_ok=True)
+ 
+     # save the compressed file in the specified root directory, using the same file name as on the server
+     downloaded_file_path = root / info.filename if info.filename else root / info.url.split("/")[-1]
+ 
+     if downloaded_file_path.exists():
+         logger.info("Existing dataset archive found. Skipping download stage.")
+     else:
+         logger.info("Downloading the %s dataset.", info.name)
+         # audit url. allowing only http:// or https://
+         if info.url.startswith(("http://", "https://")):
+             with DownloadProgressBar(unit="B", unit_scale=True, miniters=1, desc=info.name) as progress_bar:
+                 urlretrieve(  # noqa: S310  # nosec B310
+                     url=f"{info.url}",
+                     filename=downloaded_file_path,
+                     reporthook=progress_bar.update_to,
+                 )
+             logger.info("Checking the hash of the downloaded file.")
+             check_hash(downloaded_file_path, info.hashsum)
+         else:
+             msg = f"Invalid URL to download dataset. Only 'http://' and 'https://' are supported, but '{info.url}' was requested."
+             raise RuntimeError(msg)
+ 
+     extract(downloaded_file_path, root)
+ 
+ 
+ def is_within_directory(directory: Path, target: Path) -> bool:
+     """Check if a target path is located within a given directory.
+ 
+     Args:
+         directory (Path): path of the parent directory
+         target (Path): path of the target
+ 
+     Returns:
+         (bool): True if the target is within the directory, False otherwise
+     """
+     abs_directory = directory.resolve()
+     abs_target = target.resolve()
+ 
+     # TODO(djdameln): Replace with pathlib is_relative_to after switching to Python 3.10
+     # CVS-122655
+     prefix = os.path.commonprefix([abs_directory, abs_target])
+     return prefix == str(abs_directory)
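Reviewer note: how a dataset definition wires into this helper. The URL and hash below are placeholders, not a real dataset; real datamodules define these as module-level `DOWNLOAD_INFO` constants (see the VisA file above).

from pathlib import Path

from anomalib.data.utils.download import DownloadInfo, download_and_extract

INFO = DownloadInfo(
    name="my_dataset",                            # placeholder
    url="https://example.com/my_dataset.tar.gz",  # placeholder
    hashsum="<sha256 of the archive>",            # placeholder
)
download_and_extract(root=Path("./datasets/my_dataset"), info=INFO)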
anomalib/data/utils/generators/__init__.py ADDED
@@ -0,0 +1,8 @@
+ """Utilities to generate synthetic data."""
+ 
+ # Copyright (C) 2022 Intel Corporation
+ # SPDX-License-Identifier: Apache-2.0
+ 
+ from .perlin import random_2d_perlin
+ 
+ __all__ = ["random_2d_perlin"]
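Reviewer note: a one-line sketch of the re-exported generator; the signature is inferred from its use in `augmenter.py` above (output shape first, then the per-axis noise scale, both powers of two).

from anomalib.data.utils.generators import random_2d_perlin

noise = random_2d_perlin((256, 256), (8, 8))
print(noise.shape)  # (256, 256)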