zhiqing0205 committed
Commit 3de7bf6 · Parent(s): 74acc06
Add core libraries: anomalib, dinov2, open_clip_local

Note: this view is limited to 50 files because the commit contains too many changes; the raw diff contains the full set of changes.
- anomalib/__init__.py +24 -0
- anomalib/callbacks/__init__.py +64 -0
- anomalib/callbacks/checkpoint.py +58 -0
- anomalib/callbacks/graph.py +61 -0
- anomalib/callbacks/metrics.py +226 -0
- anomalib/callbacks/model_loader.py +39 -0
- anomalib/callbacks/nncf/__init__.py +4 -0
- anomalib/callbacks/nncf/callback.py +106 -0
- anomalib/callbacks/nncf/utils.py +243 -0
- anomalib/callbacks/normalization/__init__.py +12 -0
- anomalib/callbacks/normalization/base.py +29 -0
- anomalib/callbacks/normalization/min_max_normalization.py +109 -0
- anomalib/callbacks/normalization/utils.py +78 -0
- anomalib/callbacks/post_processor.py +125 -0
- anomalib/callbacks/thresholding.py +197 -0
- anomalib/callbacks/tiler_configuration.py +74 -0
- anomalib/callbacks/timer.py +109 -0
- anomalib/callbacks/visualizer.py +182 -0
- anomalib/cli/__init__.py +8 -0
- anomalib/cli/cli.py +483 -0
- anomalib/cli/install.py +81 -0
- anomalib/cli/utils/__init__.py +8 -0
- anomalib/cli/utils/help_formatter.py +268 -0
- anomalib/cli/utils/installation.py +430 -0
- anomalib/cli/utils/openvino.py +32 -0
- anomalib/data/__init__.py +72 -0
- anomalib/data/base/__init__.py +18 -0
- anomalib/data/base/datamodule.py +305 -0
- anomalib/data/base/dataset.py +208 -0
- anomalib/data/base/depth.py +76 -0
- anomalib/data/base/video.py +213 -0
- anomalib/data/depth/__init__.py +20 -0
- anomalib/data/depth/folder_3d.py +433 -0
- anomalib/data/depth/mvtec_3d.py +302 -0
- anomalib/data/errors.py +19 -0
- anomalib/data/image/__init__.py +33 -0
- anomalib/data/image/btech.py +362 -0
- anomalib/data/image/folder.py +478 -0
- anomalib/data/image/kolektor.py +342 -0
- anomalib/data/image/mvtec.py +414 -0
- anomalib/data/image/mvtec_loco.py +480 -0
- anomalib/data/image/visa.py +364 -0
- anomalib/data/predict.py +52 -0
- anomalib/data/transforms/__init__.py +8 -0
- anomalib/data/transforms/center_crop.py +87 -0
- anomalib/data/utils/__init__.py +56 -0
- anomalib/data/utils/augmenter.py +172 -0
- anomalib/data/utils/boxes.py +117 -0
- anomalib/data/utils/download.py +364 -0
- anomalib/data/utils/generators/__init__.py +8 -0
anomalib/__init__.py
ADDED
@@ -0,0 +1,24 @@
+"""Anomalib library for research and benchmarking."""
+
+# Copyright (C) 2022-2024 Intel Corporation
+# SPDX-License-Identifier: Apache-2.0
+
+from enum import Enum
+
+__version__ = "1.1.0dev"
+
+
+class LearningType(str, Enum):
+    """Learning type defining how the model learns from the dataset samples."""
+
+    ONE_CLASS = "one_class"
+    ZERO_SHOT = "zero_shot"
+    FEW_SHOT = "few_shot"
+
+
+class TaskType(str, Enum):
+    """Task type used when generating predictions on the dataset."""
+
+    CLASSIFICATION = "classification"
+    DETECTION = "detection"
+    SEGMENTATION = "segmentation"
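Both enums above inherit from ``str``, so their members compare equal to plain strings and can be built directly from config values. A minimal standalone sketch (the enum is re-declared here so the snippet runs without anomalib installed):

from enum import Enum


class TaskType(str, Enum):
    """Task type used when generating predictions on the dataset."""

    CLASSIFICATION = "classification"
    DETECTION = "detection"
    SEGMENTATION = "segmentation"


# str-based enums round-trip cleanly from configuration files.
task = TaskType("segmentation")
assert task is TaskType.SEGMENTATION
assert task == "segmentation"  # plain string comparison also works
print(task.value)  # -> segmentation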
anomalib/callbacks/__init__.py
ADDED
@@ -0,0 +1,64 @@
+"""Callbacks for Anomalib models."""
+
+# Copyright (C) 2022 Intel Corporation
+# SPDX-License-Identifier: Apache-2.0
+
+
+import logging
+from importlib import import_module
+from pathlib import Path
+
+import yaml
+from jsonargparse import Namespace
+from lightning.pytorch.callbacks import Callback
+from omegaconf import DictConfig, ListConfig, OmegaConf
+
+from .checkpoint import ModelCheckpoint
+from .graph import GraphLogger
+from .model_loader import LoadModelCallback
+from .tiler_configuration import TilerConfigurationCallback
+from .timer import TimerCallback
+
+__all__ = [
+    "ModelCheckpoint",
+    "GraphLogger",
+    "LoadModelCallback",
+    "TilerConfigurationCallback",
+    "TimerCallback",
+]
+
+
+logger = logging.getLogger(__name__)
+
+
+def get_callbacks(config: DictConfig | ListConfig | Namespace) -> list[Callback]:
+    """Return base callbacks for all the lightning models.
+
+    Args:
+        config (DictConfig | ListConfig | Namespace): Model config
+
+    Return:
+        (list[Callback]): List of callbacks.
+    """
+    logger.info("Loading the callbacks")
+
+    callbacks: list[Callback] = []
+
+    if "ckpt_path" in config.trainer and config.ckpt_path is not None:
+        load_model = LoadModelCallback(config.ckpt_path)
+        callbacks.append(load_model)
+
+    if "optimization" in config and "nncf" in config.optimization and config.optimization.nncf.apply:
+        # NNCF wraps torch's jit which conflicts with kornia's jit calls.
+        # Hence, nncf is imported only when required
+        nncf_module = import_module("anomalib.utils.callbacks.nncf.callback")
+        nncf_callback = nncf_module.NNCFCallback
+        nncf_config = yaml.safe_load(OmegaConf.to_yaml(config.optimization.nncf))
+        callbacks.append(
+            nncf_callback(
+                config=nncf_config,
+                export_dir=str(Path(config.project.path) / "compressed"),
+            ),
+        )
+
+    return callbacks
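``get_callbacks`` only reads a handful of keys from the config: ``trainer.ckpt_path`` (presence check), the top-level ``ckpt_path``, ``optimization.nncf.apply``, and ``project.path``. A rough usage sketch with an OmegaConf config; the key names come from the function body above, while the surrounding config shape is assumed and anomalib must be installed:

from omegaconf import OmegaConf

from anomalib.callbacks import get_callbacks

# Minimal config covering only the keys that get_callbacks() inspects.
config = OmegaConf.create(
    {
        "trainer": {"ckpt_path": None},             # presence is checked on config.trainer
        "ckpt_path": "results/weights/model.ckpt",  # value is read from the top level
        "project": {"path": "results"},
    },
)

callbacks = get_callbacks(config)  # -> [LoadModelCallback("results/weights/model.ckpt")]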
anomalib/callbacks/checkpoint.py
ADDED
@@ -0,0 +1,58 @@
+"""Anomalib Model Checkpoint Callback."""
+
+# Copyright (C) 2024 Intel Corporation
+# SPDX-License-Identifier: Apache-2.0
+
+from lightning.pytorch import Trainer
+from lightning.pytorch.callbacks import ModelCheckpoint as LightningCheckpoint
+from lightning.pytorch.trainer.states import TrainerFn
+
+from anomalib import LearningType
+
+
+class ModelCheckpoint(LightningCheckpoint):
+    """Anomalib Model Checkpoint Callback.
+
+    This class overrides the Lightning ModelCheckpoint callback to enable saving checkpoints without running any
+    training steps. This is useful for zero-/few-shot models, where the fit sequence only consists of validation.
+
+    To enable saving checkpoints without running any training steps, we need to override two checks which are being
+    called in the ``on_validation_end`` method of the parent class:
+    - ``_should_save_on_train_epoch_end``: This method checks whether the checkpoint should be saved at the end of a
+      training epoch, or at the end of the validation sequence. We modify this method to default to saving at the end
+      of the validation sequence when the model is of zero- or few-shot type, unless ``save_on_train_epoch_end`` is
+      specifically set by the user.
+    - ``_should_skip_saving_checkpoint``: This method checks whether the checkpoint should be saved at all. We modify
+      this method to allow saving during both the ``FITTING`` and ``VALIDATING`` states. In addition, we allow saving
+      if the global step has not changed since the last checkpoint, but only for zero- and few-shot models. This is
+      needed because both the last global step and the last checkpoint remain unchanged during zero-/few-shot
+      training, which would otherwise prevent saving checkpoints during validation.
+    """
+
+    def _should_skip_saving_checkpoint(self, trainer: Trainer) -> bool:
+        """Checks whether the checkpoint should be saved.
+
+        Overrides the parent method to allow saving during both the ``FITTING`` and ``VALIDATING`` states, and to allow
+        saving when the global step and last_global_step_saved are both 0 (only for zero-/few-shot models).
+        """
+        is_zero_or_few_shot = trainer.model.learning_type in [LearningType.ZERO_SHOT, LearningType.FEW_SHOT]
+        return (
+            bool(trainer.fast_dev_run)  # disable checkpointing with fast_dev_run
+            or trainer.state.fn not in [TrainerFn.FITTING, TrainerFn.VALIDATING]  # don't save anything during non-fit
+            or trainer.sanity_checking  # don't save anything during sanity check
+            or (self._last_global_step_saved == trainer.global_step and not is_zero_or_few_shot)
+        )
+
+    def _should_save_on_train_epoch_end(self, trainer: Trainer) -> bool:
+        """Checks whether the checkpoint should be saved at the end of a training epoch or validation sequence.
+
+        Overrides the parent method to default to saving at the end of the validation sequence when the model is of
+        zero- or few-shot type, unless ``save_on_train_epoch_end`` is specifically set by the user.
+        """
+        if self._save_on_train_epoch_end is not None:
+            return self._save_on_train_epoch_end
+
+        if trainer.model.learning_type in [LearningType.ZERO_SHOT, LearningType.FEW_SHOT]:
+            return False
+
+        return super()._should_save_on_train_epoch_end(trainer)
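Since the class keeps Lightning's ``ModelCheckpoint`` constructor untouched, it is configured exactly like the upstream callback; only the two skip checks above behave differently. A hedged usage sketch (the ``Engine`` wiring is assumed from the other callback docstrings in this commit, and the keyword arguments are standard Lightning ``ModelCheckpoint`` options):

from anomalib.callbacks import ModelCheckpoint
from anomalib.engine import Engine

# Standard Lightning checkpoint options still apply; the override only changes
# when saving is allowed (e.g. validation-only zero-/few-shot "fit" runs).
checkpoint = ModelCheckpoint(dirpath="results/weights", filename="model")
engine = Engine(callbacks=[checkpoint])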
anomalib/callbacks/graph.py
ADDED
@@ -0,0 +1,61 @@
+"""Log model graph to respective logger."""
+
+# Copyright (C) 2022 Intel Corporation
+# SPDX-License-Identifier: Apache-2.0
+
+import torch
+from lightning.pytorch import Callback, LightningModule, Trainer
+
+from anomalib.loggers import AnomalibCometLogger, AnomalibTensorBoardLogger, AnomalibWandbLogger
+
+
+class GraphLogger(Callback):
+    """Log model graph to respective logger.
+
+    Examples:
+        Log model graph to Tensorboard
+
+        >>> from anomalib.callbacks import GraphLogger
+        >>> from anomalib.loggers import AnomalibTensorBoardLogger
+        >>> from anomalib.engine import Engine
+        ...
+        >>> logger = AnomalibTensorBoardLogger()
+        >>> callbacks = [GraphLogger()]
+        >>> engine = Engine(logger=logger, callbacks=callbacks)
+
+        Log model graph to Comet
+
+        >>> from anomalib.loggers import AnomalibCometLogger
+        >>> from anomalib.engine import Engine
+        ...
+        >>> logger = AnomalibCometLogger()
+        >>> callbacks = [GraphLogger()]
+        >>> engine = Engine(logger=logger, callbacks=callbacks)
+    """
+
+    def on_train_start(self, trainer: Trainer, pl_module: LightningModule) -> None:
+        """Log model graph to respective logger.
+
+        Args:
+            trainer: Trainer object which contains references to loggers.
+            pl_module: LightningModule object which is logged.
+        """
+        for logger in trainer.loggers:
+            if isinstance(logger, AnomalibWandbLogger):
+                # NOTE: log graph gets populated only after one backward pass. This won't work for models which do not
+                # require training such as Padim
+                logger.watch(pl_module, log_graph=True, log="all")
+                break
+
+    def on_train_end(self, trainer: Trainer, pl_module: LightningModule) -> None:
+        """Unwatch the model if configured for wandb, and log its graph to Tensorboard/Comet if specified.
+
+        Args:
+            trainer: Trainer object which contains references to loggers.
+            pl_module: LightningModule object which is logged.
+        """
+        for logger in trainer.loggers:
+            if isinstance(logger, AnomalibCometLogger | AnomalibTensorBoardLogger):
+                logger.log_graph(pl_module, input_array=torch.ones((1, 3, 256, 256)))
+            elif isinstance(logger, AnomalibWandbLogger):
+                logger.experiment.unwatch(pl_module)
anomalib/callbacks/metrics.py
ADDED
@@ -0,0 +1,226 @@
+"""MetricsManager callback."""
+
+# Copyright (C) 2023 Intel Corporation
+# SPDX-License-Identifier: Apache-2.0
+
+
+import logging
+from enum import Enum
+from typing import Any
+
+import torch
+from lightning.pytorch import Callback, Trainer
+from lightning.pytorch.utilities.types import STEP_OUTPUT
+
+from anomalib import TaskType
+from anomalib.metrics import create_metric_collection
+from anomalib.models import AnomalyModule
+
+logger = logging.getLogger(__name__)
+
+
+class Device(str, Enum):
+    """Device on which to compute metrics."""
+
+    CPU = "cpu"
+    GPU = "gpu"
+
+
+class _MetricsCallback(Callback):
+    """Create image and pixel-level AnomalibMetricsCollection.
+
+    This callback creates AnomalibMetricsCollection based on the
+    list of strings provided for image and pixel-level metrics.
+    After these MetricCollections are created, the callback assigns
+    these to the lightning module.
+
+    Args:
+        task (TaskType | str): Task type of the current run.
+        image_metrics (list[str] | str | dict[str, dict[str, Any]] | None): List of image-level metrics.
+        pixel_metrics (list[str] | str | dict[str, dict[str, Any]] | None): List of pixel-level metrics.
+        device (str): Whether to compute metrics on cpu or gpu. Defaults to cpu.
+    """
+
+    def __init__(
+        self,
+        task: TaskType | str = TaskType.SEGMENTATION,
+        image_metrics: list[str] | str | dict[str, dict[str, Any]] | None = None,
+        pixel_metrics: list[str] | str | dict[str, dict[str, Any]] | None = None,
+        device: Device = Device.CPU,
+    ) -> None:
+        super().__init__()
+        self.task = TaskType(task)
+        self.image_metric_names = image_metrics
+        self.pixel_metric_names = pixel_metrics
+        self.device = device
+
+    def setup(
+        self,
+        trainer: Trainer,
+        pl_module: AnomalyModule,
+        stage: str | None = None,
+    ) -> None:
+        """Set image and pixel-level AnomalibMetricsCollection within Anomalib Model.
+
+        Args:
+            trainer (pl.Trainer): PyTorch Lightning Trainer
+            pl_module (AnomalyModule): Anomalib Model that inherits pl LightningModule.
+            stage (str | None, optional): fit, validate, test or predict. Defaults to None.
+        """
+        del stage, trainer  # These variables are not used.
+        image_metric_names = [] if self.image_metric_names is None else self.image_metric_names
+        if isinstance(image_metric_names, str):
+            image_metric_names = [image_metric_names]
+
+        pixel_metric_names: list[str] | dict[str, dict[str, Any]]
+        if self.pixel_metric_names is None:
+            pixel_metric_names = []
+        elif self.task == TaskType.CLASSIFICATION:
+            pixel_metric_names = []
+            logger.warning(
+                "Cannot perform pixel-level evaluation when task type is classification. "
+                "Ignoring the following pixel-level metrics: %s",
+                self.pixel_metric_names,
+            )
+        else:
+            pixel_metric_names = (
+                self.pixel_metric_names.copy()
+                if not isinstance(self.pixel_metric_names, str)
+                else [self.pixel_metric_names]
+            )
+
+        # create a separate metric collection for metrics that operate over the semantic segmentation mask
+        # (segmentation mask with a separate channel for each defect type)
+        semantic_pixel_metric_names: list[str] | dict[str, dict[str, Any]] = []
+        # currently only SPRO metric is supported as semantic segmentation metric
+        if "SPRO" in pixel_metric_names:
+            if isinstance(pixel_metric_names, list):
+                pixel_metric_names.remove("SPRO")
+                semantic_pixel_metric_names = ["SPRO"]
+            elif isinstance(pixel_metric_names, dict):
+                spro_metric = pixel_metric_names.pop("SPRO")
+                semantic_pixel_metric_names = {"SPRO": spro_metric}
+            else:
+                logger.warning("Unexpected type for pixel_metric_names: %s", type(pixel_metric_names))
+
+        if isinstance(pl_module, AnomalyModule):
+            pl_module.image_metrics = create_metric_collection(image_metric_names, "image_")
+            if hasattr(pl_module, "pixel_metrics"):  # in case metrics are loaded from model checkpoint
+                new_metrics = create_metric_collection(pixel_metric_names)
+                for name in new_metrics:
+                    if name not in pl_module.pixel_metrics:
+                        pl_module.pixel_metrics.add_metrics(new_metrics[name])
+            else:
+                pl_module.pixel_metrics = create_metric_collection(pixel_metric_names, "pixel_")
+            pl_module.semantic_pixel_metrics = create_metric_collection(semantic_pixel_metric_names, "pixel_")
+            self._set_threshold(pl_module)
+
+    def on_validation_epoch_start(
+        self,
+        trainer: Trainer,
+        pl_module: AnomalyModule,
+    ) -> None:
+        del trainer  # Unused argument.
+
+        pl_module.image_metrics.reset()
+        pl_module.pixel_metrics.reset()
+        pl_module.semantic_pixel_metrics.reset()
+
+    def on_validation_batch_end(
+        self,
+        trainer: Trainer,
+        pl_module: AnomalyModule,
+        outputs: STEP_OUTPUT | None,
+        batch: Any,  # noqa: ANN401
+        batch_idx: int,
+        dataloader_idx: int = 0,
+    ) -> None:
+        del trainer, batch, batch_idx, dataloader_idx  # Unused arguments.
+
+        if outputs is not None:
+            self._outputs_to_device(outputs)
+            self._update_metrics(pl_module, outputs)
+
+    def on_validation_epoch_end(
+        self,
+        trainer: Trainer,
+        pl_module: AnomalyModule,
+    ) -> None:
+        del trainer  # Unused argument.
+
+        self._set_threshold(pl_module)
+        self._log_metrics(pl_module)
+
+    def on_test_epoch_start(
+        self,
+        trainer: Trainer,
+        pl_module: AnomalyModule,
+    ) -> None:
+        del trainer  # Unused argument.
+
+        pl_module.image_metrics.reset()
+        pl_module.pixel_metrics.reset()
+        pl_module.semantic_pixel_metrics.reset()
+
+    def on_test_batch_end(
+        self,
+        trainer: Trainer,
+        pl_module: AnomalyModule,
+        outputs: STEP_OUTPUT | None,
+        batch: Any,  # noqa: ANN401
+        batch_idx: int,
+        dataloader_idx: int = 0,
+    ) -> None:
+        del trainer, batch, batch_idx, dataloader_idx  # Unused arguments.
+
+        if outputs is not None:
+            self._outputs_to_device(outputs)
+            self._update_metrics(pl_module, outputs)
+
+    def on_test_epoch_end(
+        self,
+        trainer: Trainer,
+        pl_module: AnomalyModule,
+    ) -> None:
+        del trainer  # Unused argument.
+
+        self._log_metrics(pl_module)
+
+    def _set_threshold(self, pl_module: AnomalyModule) -> None:
+        pl_module.image_metrics.set_threshold(pl_module.image_threshold.value.item())
+        pl_module.pixel_metrics.set_threshold(pl_module.pixel_threshold.value.item())
+        pl_module.semantic_pixel_metrics.set_threshold(pl_module.pixel_threshold.value.item())
+
+    def _update_metrics(
+        self,
+        pl_module: AnomalyModule,
+        output: STEP_OUTPUT,
+    ) -> None:
+        pl_module.image_metrics.to(self.device)
+        pl_module.image_metrics.update(output["pred_scores"], output["label"].int())
+        if "mask" in output and "anomaly_maps" in output:
+            pl_module.pixel_metrics.to(self.device)
+            pl_module.pixel_metrics.update(torch.squeeze(output["anomaly_maps"]), torch.squeeze(output["mask"].int()))
+        if "semantic_mask" in output and "anomaly_maps" in output:
+            pl_module.semantic_pixel_metrics.to(self.device)
+            pl_module.semantic_pixel_metrics.update(torch.squeeze(output["anomaly_maps"]), output["semantic_mask"])
+
+    def _outputs_to_device(self, output: STEP_OUTPUT) -> STEP_OUTPUT | dict[str, Any]:
+        if isinstance(output, dict):
+            for key, value in output.items():
+                output[key] = self._outputs_to_device(value)
+        elif isinstance(output, torch.Tensor):
+            output = output.to(self.device)
+        elif isinstance(output, list):
+            for i, value in enumerate(output):
+                output[i] = self._outputs_to_device(value)
+        return output
+
+    @staticmethod
+    def _log_metrics(pl_module: AnomalyModule) -> None:
+        """Log computed performance metrics."""
+        pl_module.log_dict(pl_module.image_metrics, prog_bar=True)
+        if pl_module.pixel_metrics.update_called:
+            pl_module.log_dict(pl_module.pixel_metrics, prog_bar=False)
+        if pl_module.semantic_pixel_metrics.update_called:
+            pl_module.log_dict(pl_module.semantic_pixel_metrics, prog_bar=False)
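``_outputs_to_device`` walks arbitrarily nested dict/list structures and moves every tensor it finds to the metrics device. A standalone sketch of the same pattern, independent of anomalib so it can be run directly:

import torch


def to_device(output, device: str):
    """Recursively move tensors inside dicts and lists to the given device."""
    if isinstance(output, dict):
        return {key: to_device(value, device) for key, value in output.items()}
    if isinstance(output, list):
        return [to_device(value, device) for value in output]
    if isinstance(output, torch.Tensor):
        return output.to(device)
    return output


batch = {"pred_scores": torch.rand(8), "box_scores": [torch.rand(3), torch.rand(0)]}
batch = to_device(batch, "cpu")  # every tensor in the nested structure is now on CPU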
anomalib/callbacks/model_loader.py
ADDED
@@ -0,0 +1,39 @@
+"""Callback that loads model weights from the state dict."""
+
+# Copyright (C) 2022 Intel Corporation
+# SPDX-License-Identifier: Apache-2.0
+
+
+import logging
+
+import torch
+from lightning.pytorch import Callback, Trainer
+
+from anomalib.models.components import AnomalyModule
+
+logger = logging.getLogger(__name__)
+
+
+class LoadModelCallback(Callback):
+    """Callback that loads the model weights from the state dict.
+
+    Examples:
+        >>> from anomalib.callbacks import LoadModelCallback
+        >>> from anomalib.engine import Engine
+        ...
+        >>> callbacks = [LoadModelCallback(weights_path="path/to/weights.pt")]
+        >>> engine = Engine(callbacks=callbacks)
+    """
+
+    def __init__(self, weights_path: str) -> None:
+        self.weights_path = weights_path
+
+    def setup(self, trainer: Trainer, pl_module: AnomalyModule, stage: str | None = None) -> None:
+        """Call when inference begins.
+
+        Loads the model weights from ``weights_path`` into the PyTorch module.
+        """
+        del trainer, stage  # These variables are not used.
+
+        logger.info("Loading the model from %s", self.weights_path)
+        pl_module.load_state_dict(torch.load(self.weights_path, map_location=pl_module.device)["state_dict"])
anomalib/callbacks/nncf/__init__.py
ADDED
@@ -0,0 +1,4 @@
+"""Integration NNCF."""
+
+# Copyright (C) 2022 Intel Corporation
+# SPDX-License-Identifier: Apache-2.0
anomalib/callbacks/nncf/callback.py
ADDED
@@ -0,0 +1,106 @@
+"""Callbacks for NNCF optimization."""
+
+# Copyright (C) 2022 Intel Corporation
+# SPDX-License-Identifier: Apache-2.0
+
+
+import subprocess  # nosec B404
+from pathlib import Path
+from typing import TYPE_CHECKING, Any
+
+import lightning.pytorch as pl
+from lightning.pytorch import Callback
+from nncf import NNCFConfig
+from nncf.torch import register_default_init_args
+
+from anomalib.callbacks.nncf.utils import InitLoader, wrap_nncf_model
+
+if TYPE_CHECKING:
+    from nncf.api.compression import CompressionAlgorithmController
+
+
+class NNCFCallback(Callback):
+    """Callback for NNCF compression.
+
+    Assumes that the pl module contains a 'model' attribute, which is
+    the PyTorch module that must be compressed.
+
+    Args:
+        config (dict): NNCF Configuration
+        export_dir (str): Path where the export `onnx` and the OpenVINO `xml` and `bin` IR are saved.
+            If None model will not be exported.
+    """
+
+    def __init__(self, config: dict, export_dir: str | None = None) -> None:
+        self.export_dir = export_dir
+        self.config = NNCFConfig(config)
+        self.nncf_ctrl: CompressionAlgorithmController | None = None
+
+    def setup(self, trainer: pl.Trainer, pl_module: pl.LightningModule, stage: str | None = None) -> None:
+        """Call when fit or test begins.
+
+        Takes the pytorch model and wraps it using the compression controller
+        so that it is ready for nncf fine-tuning.
+        """
+        del stage  # `stage` variable is not used.
+
+        if self.nncf_ctrl is not None:
+            return
+
+        # Get validate subset to initialize quantization,
+        # because train subset does not contain anomalous images.
+        init_loader = InitLoader(trainer.datamodule.val_dataloader())
+        config = register_default_init_args(self.config, init_loader)
+
+        self.nncf_ctrl, pl_module.model = wrap_nncf_model(
+            model=pl_module.model,
+            config=config,
+            dataloader=trainer.datamodule.train_dataloader(),
+            init_state_dict=None,  # type: ignore[arg-type]
+        )
+
+    def on_train_batch_start(
+        self,
+        trainer: pl.Trainer,
+        pl_module: pl.LightningModule,
+        batch: Any,  # noqa: ANN401
+        batch_idx: int,
+        unused: int = 0,
+    ) -> None:
+        """Call when the train batch begins.
+
+        Prepare compression method to continue training the model in the next step.
+        """
+        del trainer, pl_module, batch, batch_idx, unused  # These variables are not used.
+
+        if self.nncf_ctrl:
+            self.nncf_ctrl.scheduler.step()
+
+    def on_train_epoch_start(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None:
+        """Call when the train epoch starts.
+
+        Prepare compression method to continue training the model in the next epoch.
+        """
+        del trainer, pl_module  # `trainer` and `pl_module` variables are not used.
+
+        if self.nncf_ctrl:
+            self.nncf_ctrl.scheduler.epoch_step()
+
+    def on_train_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None:
+        """Call when the train ends.
+
+        Exports onnx model and if compression controller is not None, uses the onnx model to generate the OpenVINO IR.
+        """
+        del trainer, pl_module  # `trainer` and `pl_module` variables are not used.
+
+        if self.export_dir is None or self.nncf_ctrl is None:
+            return
+
+        Path(self.export_dir).mkdir(parents=True, exist_ok=True)
+        onnx_path = str(Path(self.export_dir) / "model_nncf.onnx")
+        self.nncf_ctrl.export_model(onnx_path)
+
+        optimize_command = ["mo", "--input_model", onnx_path, "--output_dir", self.export_dir]
+        # TODO(samet-akcay): Check if mo can be done via python API
+        # CVS-122665
+        subprocess.run(optimize_command, check=True)  # noqa: S603  # nosec B603
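The ``config`` dict handed to ``NNCFCallback`` follows NNCF's own configuration schema. A hedged sketch of a typical quantization config (the keys shown are standard NNCF options rather than anything defined by this commit, and NNCF plus OpenVINO's ``mo`` tool must be installed for the export step):

from anomalib.callbacks.nncf.callback import NNCFCallback

nncf_config = {
    "input_info": {"sample_size": [1, 3, 256, 256]},  # shape used to trace the model
    "compression": {"algorithm": "quantization"},     # INT8 quantization
}
callback = NNCFCallback(config=nncf_config, export_dir="results/compressed")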
anomalib/callbacks/nncf/utils.py
ADDED
@@ -0,0 +1,243 @@
+"""Utils for NNCF optimization."""
+
+# Copyright (C) 2022 Intel Corporation
+# SPDX-License-Identifier: Apache-2.0
+
+
+import logging
+from copy import copy
+from typing import TYPE_CHECKING, Any
+
+import torch
+from nncf import NNCFConfig
+from nncf.api.compression import CompressionAlgorithmController
+from nncf.torch import create_compressed_model, load_state, register_default_init_args
+from nncf.torch.initialization import PTInitializingDataLoader
+from nncf.torch.nncf_network import NNCFNetwork
+from torch import nn
+from torch.utils.data.dataloader import DataLoader
+
+if TYPE_CHECKING:
+    from collections.abc import Iterator
+
+
+logger = logging.getLogger(name="NNCF compression")
+
+
+class InitLoader(PTInitializingDataLoader):
+    """Initializing data loader for NNCF to be used with unsupervised training algorithms."""
+
+    def __init__(self, data_loader: DataLoader) -> None:
+        super().__init__(data_loader)
+        self._data_loader_iter: Iterator
+
+    def __iter__(self) -> "InitLoader":
+        """Create iterator for dataloader."""
+        self._data_loader_iter = iter(self._data_loader)
+        return self
+
+    def __next__(self) -> torch.Tensor:
+        """Return next item from dataloader iterator."""
+        loaded_item = next(self._data_loader_iter)
+        return loaded_item["image"]
+
+    def get_inputs(self, dataloader_output: dict[str, str | torch.Tensor]) -> tuple[tuple, dict]:
+        """Get input to model.
+
+        Returns:
+            (dataloader_output,), {}: tuple[tuple, dict]: The current model call to be made during
+            the initialization process
+        """
+        return (dataloader_output,), {}
+
+    def get_target(self, _):  # noqa: ANN001, ANN201
+        """Return structure for ground truth in loss criterion based on dataloader output.
+
+        This implementation does not do anything and is a placeholder.
+
+        Returns:
+            None
+        """
+        return
+
+
+def wrap_nncf_model(
+    model: nn.Module,
+    config: dict,
+    dataloader: DataLoader,
+    init_state_dict: dict,
+) -> tuple[CompressionAlgorithmController, NNCFNetwork]:
+    """Wrap model by NNCF.
+
+    :param model: Anomalib model.
+    :param config: NNCF config.
+    :param dataloader: Dataloader for initialization of NNCF model.
+    :param init_state_dict: Opti
+    :return: compression controller, compressed model
+    """
+    nncf_config = NNCFConfig.from_dict(config)
+
+    if not dataloader and not init_state_dict:
+        logger.warning(
+            "Either dataloader or NNCF pre-trained "
+            "model checkpoint should be set. Without this, "
+            "quantizers will not be initialized",
+        )
+
+    compression_state = None
+    resuming_state_dict = None
+    if init_state_dict:
+        resuming_state_dict = init_state_dict.get("model")
+        compression_state = init_state_dict.get("compression_state")
+
+    if dataloader:
+        init_loader = InitLoader(dataloader)
+        nncf_config = register_default_init_args(nncf_config, init_loader)
+
+    nncf_ctrl, nncf_model = create_compressed_model(
+        model=model,
+        config=nncf_config,
+        dump_graphs=False,
+        compression_state=compression_state,
+    )
+
+    if resuming_state_dict:
+        load_state(nncf_model, resuming_state_dict, is_resume=True)
+
+    return nncf_ctrl, nncf_model
+
+
+def is_state_nncf(state: dict) -> bool:
+    """Check if state is the result of NNCF-compressed model."""
+    return bool(state.get("meta", {}).get("nncf_enable_compression", False))
+
+
+def compose_nncf_config(nncf_config: dict, enabled_options: list[str]) -> dict:
+    """Compose NNCF config by selected options.
+
+    :param nncf_config:
+    :param enabled_options:
+    :return: config
+    """
+    optimisation_parts = nncf_config
+    optimisation_parts_to_choose = []
+    if "order_of_parts" in optimisation_parts:
+        # The result of applying the changes from optimisation parts
+        # may depend on the order of applying the changes
+        # (e.g. if for nncf_quantization it is sufficient to have `total_epochs=2`,
+        # but for sparsity it is required `total_epochs=50`)
+        # So, user can define `order_of_parts` in the optimisation_config
+        # to specify the order of applying the parts.
+        order_of_parts = optimisation_parts["order_of_parts"]
+        if not isinstance(order_of_parts, list):
+            msg = 'The field "order_of_parts" in optimization config should be a list'
+            raise TypeError(msg)
+
+        for part in enabled_options:
+            if part not in order_of_parts:
+                msg = f"The part {part} is selected, but it is absent in order_of_parts={order_of_parts}"
+                raise ValueError(msg)
+
+        optimisation_parts_to_choose = [part for part in order_of_parts if part in enabled_options]
+
+    if "base" not in optimisation_parts:
+        msg = 'Error: the optimisation config does not contain the "base" part'
+        raise KeyError(msg)
+    nncf_config_part = optimisation_parts["base"]
+
+    for part in optimisation_parts_to_choose:
+        if part not in optimisation_parts:
+            msg = f'Error: the optimisation config does not contain the part "{part}"'
+            raise KeyError(msg)
+        optimisation_part_dict = optimisation_parts[part]
+        try:
+            nncf_config_part = merge_dicts_and_lists_b_into_a(nncf_config_part, optimisation_part_dict)
+        except AssertionError as cur_error:
+            err_descr = (
+                f"Error during merging the parts of nncf configs:\n"
+                f"the current part={part}, "
+                f"the order of merging parts into base is {optimisation_parts_to_choose}.\n"
+                f"The error is:\n{cur_error}"
+            )
+            raise RuntimeError(err_descr) from None
+
+    return nncf_config_part
+
+
+def merge_dicts_and_lists_b_into_a(
+    a: dict[Any, Any] | list[Any],
+    b: dict[Any, Any] | list[Any],
+) -> dict[Any, Any] | list[Any]:
+    """Merge dict configs.
+
+    Args:
+        a (dict[Any, Any] | list[Any]): First dict or list.
+        b (dict[Any, Any] | list[Any]): Second dict or list.
+
+    Returns:
+        dict[Any, Any] | list[Any]: Merged dict or list.
+    """
+    return _merge_dicts_and_lists_b_into_a(a, b, "")
+
+
+def _merge_dicts_and_lists_b_into_a(
+    a: dict[Any, Any] | list[Any],
+    b: dict[Any, Any] | list[Any],
+    cur_key: int | str | None = None,
+) -> dict[Any, Any] | list[Any]:
+    """Merge dict configs.
+
+    * works with usual dicts and lists and derived types
+    * supports merging of lists (by concatenating the lists)
+    * makes recursive merging for dict + dict case
+    * overwrites when merging scalar into scalar
+    Note that we merge b into a (whereas Config makes merge a into b),
+    since otherwise the order of list merging is counter-intuitive.
+
+    Args:
+        a (dict[Any, Any] | list[Any]): First dict or list.
+        b (dict[Any, Any] | list[Any]): Second dict or list.
+        cur_key (int | str | None, optional): key for current level of recursion. Defaults to None.
+
+    Returns:
+        dict[Any, Any] | list[Any]: Merged dict or list.
+    """
+
+    def _err_str(_a: dict | list, _b: dict | list, _key: int | str | None = None) -> str:
+        _key_str = "of whole structures" if _key is None else f"during merging for key=`{_key}`"
+        return (
+            f"Error in merging parts of config: different types {_key_str},"
+            f" type(a) = {type(_a)},"
+            f" type(b) = {type(_b)}"
+        )
+
+    if not (isinstance(a, dict | list)):
+        msg = f"Can merge only dicts and lists, whereas type(a)={type(a)}"
+        raise TypeError(msg)
+
+    if not (isinstance(b, dict | list)):
+        raise TypeError(_err_str(a, b, cur_key))
+
+    if (isinstance(a, list) and not isinstance(b, list)) or (isinstance(b, list) and not isinstance(a, list)):
+        raise TypeError(_err_str(a, b, cur_key))
+
+    if isinstance(a, list) and isinstance(b, list):
+        # the main diff w.r.t. mmcf.Config -- merging of lists
+        return a + b
+
+    a = copy(a)
+    for k in b:
+        if k not in a:
+            a[k] = copy(b[k])
+            continue
+        new_cur_key = str(cur_key) + "." + k if cur_key else k
+        if isinstance(a[k], dict | list):
+            a[k] = _merge_dicts_and_lists_b_into_a(a[k], b[k], new_cur_key)
+            continue
+
+        if any(isinstance(b[k], t) for t in [dict, list]):
+            raise TypeError(_err_str(a[k], b[k], new_cur_key))
+
+        # suppose here that a[k] and b[k] are scalars, just overwrite
+        a[k] = b[k]
+    return a
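The merge helpers above concatenate lists, recurse into nested dicts, and let scalars from ``b`` overwrite scalars in ``a``. A small usage sketch (assuming the module is importable as ``anomalib.callbacks.nncf.utils`` and that NNCF is installed, since the file imports it at module level):

from anomalib.callbacks.nncf.utils import merge_dicts_and_lists_b_into_a

base = {"compression": [{"algorithm": "quantization"}], "total_epochs": 2}
override = {"compression": [{"algorithm": "magnitude_sparsity"}], "total_epochs": 50}

merged = merge_dicts_and_lists_b_into_a(base, override)
# merged["compression"] is the concatenation of both lists,
# merged["total_epochs"] is overwritten by the value from ``override`` (50).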
anomalib/callbacks/normalization/__init__.py
ADDED
@@ -0,0 +1,12 @@
+"""Normalization callbacks.
+
+Note: These callbacks are used within the Engine.
+"""
+
+# Copyright (C) 2023-2024 Intel Corporation
+# SPDX-License-Identifier: Apache-2.0
+
+from .min_max_normalization import _MinMaxNormalizationCallback
+from .utils import get_normalization_callback
+
+__all__ = ["get_normalization_callback", "_MinMaxNormalizationCallback"]
anomalib/callbacks/normalization/base.py
ADDED
@@ -0,0 +1,29 @@
+"""Base Normalization Callback."""
+
+# Copyright (C) 2024 Intel Corporation
+# SPDX-License-Identifier: Apache-2.0
+
+from abc import ABC, abstractmethod
+
+from lightning.pytorch import Callback
+from lightning.pytorch.utilities.types import STEP_OUTPUT
+
+from anomalib.models.components import AnomalyModule
+
+
+class NormalizationCallback(Callback, ABC):
+    """Base normalization callback."""
+
+    @staticmethod
+    @abstractmethod
+    def _normalize_batch(batch: STEP_OUTPUT, pl_module: AnomalyModule) -> None:
+        """Normalize an output batch.
+
+        Args:
+            batch (dict[str, torch.Tensor]): Output batch.
+            pl_module (AnomalyModule): AnomalyModule instance.
+
+        Returns:
+            dict[str, torch.Tensor]: Normalized batch.
+        """
+        raise NotImplementedError
anomalib/callbacks/normalization/min_max_normalization.py
ADDED
@@ -0,0 +1,109 @@
+"""Anomaly Score Normalization Callback that uses min-max normalization."""
+
+# Copyright (C) 2022-2024 Intel Corporation
+# SPDX-License-Identifier: Apache-2.0
+
+
+from typing import Any
+
+import torch
+from lightning.pytorch import Trainer
+from lightning.pytorch.utilities.types import STEP_OUTPUT
+
+from anomalib.metrics import MinMax
+from anomalib.models.components import AnomalyModule
+from anomalib.utils.normalization.min_max import normalize
+
+from .base import NormalizationCallback
+
+
+class _MinMaxNormalizationCallback(NormalizationCallback):
+    """Callback that normalizes the image-level and pixel-level anomaly scores using min-max normalization.
+
+    Note: This callback is set within the Engine.
+    """
+
+    def setup(self, trainer: Trainer, pl_module: AnomalyModule, stage: str | None = None) -> None:
+        """Add min_max metrics to normalization metrics."""
+        del trainer, stage  # These variables are not used.
+
+        if not hasattr(pl_module, "normalization_metrics"):
+            pl_module.normalization_metrics = MinMax().cpu()
+        elif not isinstance(pl_module.normalization_metrics, MinMax):
+            msg = f"Expected normalization_metrics to be of type MinMax, got {type(pl_module.normalization_metrics)}"
+            raise AttributeError(
+                msg,
+            )
+
+    def on_test_start(self, trainer: Trainer, pl_module: AnomalyModule) -> None:
+        """Call when the test begins."""
+        del trainer  # `trainer` variable is not used.
+
+        for metric in (pl_module.image_metrics, pl_module.pixel_metrics, pl_module.semantic_pixel_metrics):
+            if metric is not None:
+                metric.set_threshold(0.5)
+
+    def on_validation_batch_end(
+        self,
+        trainer: Trainer,
+        pl_module: AnomalyModule,
+        outputs: STEP_OUTPUT,
+        batch: Any,  # noqa: ANN401
+        batch_idx: int,
+        dataloader_idx: int = 0,
+    ) -> None:
+        """Call when the validation batch ends, update the min and max observed values."""
+        del trainer, batch, batch_idx, dataloader_idx  # These variables are not used.
+
+        if "anomaly_maps" in outputs:
+            pl_module.normalization_metrics(outputs["anomaly_maps"])
+        elif "box_scores" in outputs:
+            pl_module.normalization_metrics(torch.cat(outputs["box_scores"]))
+        elif "pred_scores" in outputs:
+            pl_module.normalization_metrics(outputs["pred_scores"])
+        else:
+            msg = "No values found for normalization, provide anomaly maps, bbox scores, or image scores"
+            raise ValueError(msg)
+
+    def on_test_batch_end(
+        self,
+        trainer: Trainer,
+        pl_module: AnomalyModule,
+        outputs: STEP_OUTPUT | None,
+        batch: Any,  # noqa: ANN401
+        batch_idx: int,
+        dataloader_idx: int = 0,
+    ) -> None:
+        """Call when the test batch ends, normalizes the predicted scores and anomaly maps."""
+        del trainer, batch, batch_idx, dataloader_idx  # These variables are not used.
+
+        self._normalize_batch(outputs, pl_module)
+
+    def on_predict_batch_end(
+        self,
+        trainer: Trainer,
+        pl_module: AnomalyModule,
+        outputs: Any,  # noqa: ANN401
+        batch: Any,  # noqa: ANN401
+        batch_idx: int,
+        dataloader_idx: int = 0,
+    ) -> None:
+        """Call when the predict batch ends, normalizes the predicted scores and anomaly maps."""
+        del trainer, batch, batch_idx, dataloader_idx  # These variables are not used.
+
+        self._normalize_batch(outputs, pl_module)
+
+    @staticmethod
+    def _normalize_batch(outputs: Any, pl_module: AnomalyModule) -> None:  # noqa: ANN401
+        """Normalize a batch of predictions."""
+        image_threshold = pl_module.image_threshold.value.cpu()
+        pixel_threshold = pl_module.pixel_threshold.value.cpu()
+        stats = pl_module.normalization_metrics.cpu()
+        if "pred_scores" in outputs:
+            outputs["pred_scores"] = normalize(outputs["pred_scores"], image_threshold, stats.min, stats.max)
+        if "anomaly_maps" in outputs:
+            outputs["anomaly_maps"] = normalize(outputs["anomaly_maps"], pixel_threshold, stats.min, stats.max)
+        if "box_scores" in outputs:
+            outputs["box_scores"] = [
+                normalize(scores, pixel_threshold, stats.min, stats.max) for scores in outputs["box_scores"]
+            ]
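The callback only tracks the observed min/max statistics and the thresholds; the actual rescaling is delegated to ``anomalib.utils.normalization.min_max.normalize``. As a rough standalone illustration of what threshold-centred min-max normalization typically looks like (an assumption for illustration, not necessarily the library's exact formula), the sketch below also shows why ``on_test_start`` resets the metric thresholds to 0.5:

import torch


def minmax_normalize(scores: torch.Tensor, threshold: float, v_min: float, v_max: float) -> torch.Tensor:
    """Rescale scores so the decision threshold maps to 0.5 and values stay in [0, 1]."""
    normalized = (scores - threshold) / (v_max - v_min) + 0.5
    return normalized.clamp(min=0, max=1)


scores = torch.tensor([0.10, 0.35, 0.90])
print(minmax_normalize(scores, threshold=0.35, v_min=0.0, v_max=1.0))
# tensor([0.2500, 0.5000, 1.0000]) -- the raw threshold lands exactly on 0.5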
anomalib/callbacks/normalization/utils.py
ADDED
@@ -0,0 +1,78 @@
+"""Normalization callback utils."""
+
+# Copyright (C) 2022 Intel Corporation
+# SPDX-License-Identifier: Apache-2.0
+
+import importlib
+
+from lightning.pytorch import Callback
+from omegaconf import DictConfig
+
+from anomalib.utils.normalization import NormalizationMethod
+from anomalib.utils.types import NORMALIZATION
+
+from .min_max_normalization import _MinMaxNormalizationCallback
+
+
+def get_normalization_callback(
+    normalization_method: NORMALIZATION = NormalizationMethod.MIN_MAX,
+) -> Callback | None:
+    """Return normalization object.
+
+    If normalization_method is an instance of ``Callback``, it is returned as is.
+
+    If normalization_method is of type ``NormalizationMethod``, then a new class is created based on the type of
+    normalization_method.
+
+    Otherwise it expects a dictionary containing class_path and init_args.
+    normalization_method:
+        class_path: MinMaxNormalizer
+        init_args:
+            -
+            -
+
+    Example:
+        >>> normalizer = get_normalization_callback(NormalizationMethod.MIN_MAX)
+        or
+        >>> normalizer = get_normalization_callback("min_max")
+        or
+        >>> normalizer = get_normalization_callback({"class_path": "MinMaxNormalizationCallback", "init_args": {}})
+        or
+        >>> normalizer = get_normalization_callback(MinMaxNormalizationCallback())
+    """
+    normalizer: Callback | None
+    if isinstance(normalization_method, NormalizationMethod | str):
+        normalizer = _get_normalizer_from_method(NormalizationMethod(normalization_method))
+    elif isinstance(normalization_method, Callback):
+        normalizer = normalization_method
+    elif isinstance(normalization_method, DictConfig):
+        normalizer = _parse_normalizer_config(normalization_method)
+    else:
+        msg = f"Unknown normalizer type {normalization_method}"
+        raise TypeError(msg)
+    return normalizer
+
+
+def _get_normalizer_from_method(normalization_method: NormalizationMethod | str) -> Callback | None:
+    if normalization_method == NormalizationMethod.NONE:
+        normalizer = None
+    elif normalization_method == NormalizationMethod.MIN_MAX:
+        normalizer = _MinMaxNormalizationCallback()
+    else:
+        msg = f"Unknown normalization method {normalization_method}"
+        raise ValueError(msg)
+    return normalizer
+
+
+def _parse_normalizer_config(normalization_method: DictConfig) -> Callback:
+    class_path = normalization_method.class_path
+    init_args = normalization_method.init_args
+
+    if len(class_path.split(".")) == 1:
+        module_path = "anomalib.utils.callbacks.normalization"
+    else:
+        module_path = ".".join(class_path.split(".")[:-1])
+        class_path = class_path.split(".")[-1]
+    module = importlib.import_module(module_path)
+    class_ = getattr(module, class_path)
+    return class_(**init_args)
anomalib/callbacks/post_processor.py
ADDED
@@ -0,0 +1,125 @@
+"""Callback that attaches necessary pre/post-processing to the model."""
+
+# Copyright (C) 2023 Intel Corporation
+# SPDX-License-Identifier: Apache-2.0
+
+
+from typing import Any
+
+import torch
+from lightning import Callback
+from lightning.pytorch import Trainer
+from lightning.pytorch.utilities.types import STEP_OUTPUT
+
+from anomalib.data.utils import boxes_to_anomaly_maps, boxes_to_masks, masks_to_boxes
+from anomalib.models import AnomalyModule
+
+
+class _PostProcessorCallback(Callback):
+    """Applies post-processing to the model outputs.
+
+    Note: This callback is set within the Engine.
+    """
+
+    def __init__(self) -> None:
+        super().__init__()
+
+    def on_validation_batch_end(
+        self,
+        trainer: Trainer,
+        pl_module: AnomalyModule,
+        outputs: STEP_OUTPUT | None,
+        batch: Any,  # noqa: ANN401
+        batch_idx: int,
+        dataloader_idx: int = 0,
+    ) -> None:
+        del batch, batch_idx, dataloader_idx  # Unused arguments.
+
+        if outputs is not None:
+            self.post_process(trainer, pl_module, outputs)
+
+    def on_test_batch_end(
+        self,
+        trainer: Trainer,
+        pl_module: AnomalyModule,
+        outputs: STEP_OUTPUT | None,
+        batch: Any,  # noqa: ANN401
+        batch_idx: int,
+        dataloader_idx: int = 0,
+    ) -> None:
+        del batch, batch_idx, dataloader_idx  # Unused arguments.
+
+        if outputs is not None:
+            self.post_process(trainer, pl_module, outputs)
+
+    def on_predict_batch_end(
+        self,
+        trainer: Trainer,
+        pl_module: AnomalyModule,
+        outputs: Any,  # noqa: ANN401
+        batch: Any,  # noqa: ANN401
+        batch_idx: int,
+        dataloader_idx: int = 0,
+    ) -> None:
+        del batch, batch_idx, dataloader_idx  # Unused arguments.
+
+        if outputs is not None:
+            self.post_process(trainer, pl_module, outputs)
+
+    def post_process(self, trainer: Trainer, pl_module: AnomalyModule, outputs: STEP_OUTPUT) -> None:
+        if isinstance(outputs, dict):
+            self._post_process(outputs)
+            if trainer.predicting or trainer.testing:
+                self._compute_scores_and_labels(pl_module, outputs)
+
+    @staticmethod
+    def _compute_scores_and_labels(
+        pl_module: AnomalyModule,
+        outputs: dict[str, Any],
+    ) -> None:
+        if "pred_scores" in outputs:
+            outputs["pred_labels"] = outputs["pred_scores"] >= pl_module.image_threshold.value
+        if "anomaly_maps" in outputs:
+            outputs["pred_masks"] = outputs["anomaly_maps"] >= pl_module.pixel_threshold.value
+            if "pred_boxes" not in outputs:
+                outputs["pred_boxes"], outputs["box_scores"] = masks_to_boxes(
+                    outputs["pred_masks"],
+                    outputs["anomaly_maps"],
+                )
+                outputs["box_labels"] = [torch.ones(boxes.shape[0]) for boxes in outputs["pred_boxes"]]
+        # apply thresholding to boxes
+        if "box_scores" in outputs and "box_labels" not in outputs:
+            # apply threshold to assign normal/anomalous label to boxes
+            is_anomalous = [scores > pl_module.pixel_threshold.value for scores in outputs["box_scores"]]
+            outputs["box_labels"] = [labels.int() for labels in is_anomalous]
+
+    @staticmethod
+    def _post_process(outputs: STEP_OUTPUT) -> None:
+        """Compute labels based on model predictions."""
+        if isinstance(outputs, dict):
+            if "pred_scores" not in outputs and "anomaly_maps" in outputs:
+                # infer image scores from anomaly maps
+                outputs["pred_scores"] = (
+                    outputs["anomaly_maps"]  # noqa: PD011
+                    .reshape(outputs["anomaly_maps"].shape[0], -1)
+                    .max(dim=1)
+                    .values
+                )
+            elif "pred_scores" not in outputs and "box_scores" in outputs and "label" in outputs:
+                # infer image score from bbox confidence scores
+                outputs["pred_scores"] = torch.zeros_like(outputs["label"]).float()
+                for idx, (boxes, scores) in enumerate(zip(outputs["pred_boxes"], outputs["box_scores"], strict=True)):
+                    if boxes.numel():
+                        outputs["pred_scores"][idx] = scores.max().item()
+
+            if "pred_boxes" in outputs and "anomaly_maps" not in outputs:
+                # create anomaly maps from bbox predictions for thresholding and evaluation
+                image_size: tuple[int, int] = outputs["image"].shape[-2:]
+                pred_boxes: torch.Tensor = outputs["pred_boxes"]
+                box_scores: torch.Tensor = outputs["box_scores"]
+
+                outputs["anomaly_maps"] = boxes_to_anomaly_maps(pred_boxes, box_scores, image_size)
+
+                if "boxes" in outputs:
+                    true_boxes: list[torch.Tensor] = outputs["boxes"]
+                    outputs["mask"] = boxes_to_masks(true_boxes, image_size)
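When only ``anomaly_maps`` are available, the image-level score is simply the per-image maximum of the map. A standalone sketch of that reduction (batch × H × W shapes are assumed):

import torch

batch_size, height, width = 4, 32, 32
anomaly_maps = torch.rand(batch_size, height, width)

# Flatten each map and take the maximum pixel score as the image-level score,
# mirroring the reshape(...).max(dim=1).values reduction used above.
pred_scores = anomaly_maps.reshape(anomaly_maps.shape[0], -1).max(dim=1).values
assert pred_scores.shape == (batch_size,)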
anomalib/callbacks/thresholding.py
ADDED
@@ -0,0 +1,197 @@
"""Thresholding callback."""

# Copyright (C) 2023 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

import importlib
from typing import Any

import torch
from lightning.pytorch import Callback, Trainer
from lightning.pytorch.utilities.types import STEP_OUTPUT
from omegaconf import DictConfig, ListConfig

from anomalib.metrics.threshold import BaseThreshold
from anomalib.models import AnomalyModule
from anomalib.utils.types import THRESHOLD


class _ThresholdCallback(Callback):
    """Setup/apply thresholding.

    Note: This callback is set within the Engine.
    """

    def __init__(
        self,
        threshold: THRESHOLD = "F1AdaptiveThreshold",
    ) -> None:
        super().__init__()
        self._initialize_thresholds(threshold)
        self.image_threshold: BaseThreshold
        self.pixel_threshold: BaseThreshold

    def setup(self, trainer: Trainer, pl_module: AnomalyModule, stage: str) -> None:
        del trainer, stage  # Unused arguments.
        if not hasattr(pl_module, "image_threshold"):
            pl_module.image_threshold = self.image_threshold
        if not hasattr(pl_module, "pixel_threshold"):
            pl_module.pixel_threshold = self.pixel_threshold

    def on_validation_epoch_start(self, trainer: Trainer, pl_module: AnomalyModule) -> None:
        del trainer  # Unused argument.
        self._reset(pl_module)

    def on_validation_batch_end(
        self,
        trainer: Trainer,
        pl_module: AnomalyModule,
        outputs: STEP_OUTPUT | None,
        batch: Any,  # noqa: ANN401
        batch_idx: int,
        dataloader_idx: int = 0,
    ) -> None:
        del trainer, batch, batch_idx, dataloader_idx  # Unused arguments.
        if outputs is not None:
            self._outputs_to_cpu(outputs)
            self._update(pl_module, outputs)

    def on_validation_epoch_end(self, trainer: Trainer, pl_module: AnomalyModule) -> None:
        del trainer  # Unused argument.
        self._compute(pl_module)

    def _initialize_thresholds(
        self,
        threshold: THRESHOLD,
    ) -> None:
        """Initialize ``self.image_threshold`` and ``self.pixel_threshold``.

        Args:
            threshold (THRESHOLD):
                Threshold configuration

        Example:
            >>> _initialize_thresholds(F1AdaptiveThreshold())
            or
            >>> _initialize_thresholds((ManualThreshold(0.5), ManualThreshold(0.5)))
            or configuration

        For more details on configuration see :func:`_load_from_config`

        Raises:
            ValueError: Unknown threshold class or incorrect configuration
        """
        # TODO(djdameln): Add tests for each case
        # CVS-122661
        # When only a single threshold class is passed.
        # This initializes image and pixel thresholds with the same class
        # >>> _initialize_thresholds(F1AdaptiveThreshold())
        if isinstance(threshold, BaseThreshold):
            self.image_threshold = threshold
            self.pixel_threshold = threshold.clone()

        # When a tuple of threshold classes are passed
        # >>> _initialize_thresholds((ManualThreshold(0.5), ManualThreshold(0.5)))
        elif isinstance(threshold, tuple) and isinstance(threshold[0], BaseThreshold):
            self.image_threshold = threshold[0]
            self.pixel_threshold = threshold[1]
        # When the passed threshold is not an instance of a Threshold class.
        elif isinstance(threshold, str | DictConfig | ListConfig | list):
            self._load_from_config(threshold)
        else:
            msg = f"Invalid threshold type {type(threshold)}"
            raise TypeError(msg)

    def _load_from_config(self, threshold: DictConfig | str | ListConfig | list[dict[str, str | float]]) -> None:
        """Load the thresholding class based on the config.

        Example:
            threshold: F1AdaptiveThreshold
            or
            threshold:
                class_path: F1AdaptiveThreshold
                init_args:
                    -
            or
            threshold:
                - F1AdaptiveThreshold
                - F1AdaptiveThreshold
            or
            threshold:
                - class_path: F1AdaptiveThreshold
                  init_args:
                    -
                - class_path: F1AdaptiveThreshold
        """
        if isinstance(threshold, str | DictConfig):
            self.image_threshold = self._get_threshold_from_config(threshold)
            self.pixel_threshold = self.image_threshold.clone()
        elif isinstance(threshold, ListConfig | list):
            self.image_threshold = self._get_threshold_from_config(threshold[0])
            self.pixel_threshold = self._get_threshold_from_config(threshold[1])
        else:
            msg = f"Invalid threshold config {threshold}"
            raise TypeError(msg)

    def _get_threshold_from_config(self, threshold: DictConfig | str | dict[str, str | float]) -> BaseThreshold:
        """Return the instantiated threshold object.

        Example:
            >>> _get_threshold_from_config(F1AdaptiveThreshold)
            or
            >>> config = DictConfig({
            ...     "class_path": "ManualThreshold",
            ...     "init_args": {"default_value": 0.7}
            ... })
            >>> _get_threshold_from_config(config)
            or
            >>> config = DictConfig({
            ...     "class_path": "anomalib.metrics.threshold.F1AdaptiveThreshold"
            ... })
            >>> _get_threshold_from_config(config)

        Returns:
            (BaseThreshold): Instance of threshold object.
        """
        if isinstance(threshold, str):
            threshold = DictConfig({"class_path": threshold})

        class_path = threshold["class_path"]
        init_args = threshold.get("init_args", {})

        if len(class_path.split(".")) == 1:
            module_path = "anomalib.metrics.threshold"
        else:
            module_path = ".".join(class_path.split(".")[:-1])
            class_path = class_path.split(".")[-1]

        module = importlib.import_module(module_path)
        class_ = getattr(module, class_path)
        return class_(**init_args)

    def _reset(self, pl_module: AnomalyModule) -> None:
        pl_module.image_threshold.reset()
        pl_module.pixel_threshold.reset()

    def _outputs_to_cpu(self, output: STEP_OUTPUT) -> STEP_OUTPUT | dict[str, Any]:
        if isinstance(output, dict):
            for key, value in output.items():
                output[key] = self._outputs_to_cpu(value)
        elif isinstance(output, torch.Tensor):
            output = output.cpu()
        return output

    def _update(self, pl_module: AnomalyModule, outputs: STEP_OUTPUT) -> None:
        pl_module.image_threshold.cpu()
        pl_module.image_threshold.update(outputs["pred_scores"], outputs["label"].int())
        if "mask" in outputs and "anomaly_maps" in outputs:
            pl_module.pixel_threshold.cpu()
            pl_module.pixel_threshold.update(outputs["anomaly_maps"], outputs["mask"].int())

    def _compute(self, pl_module: AnomalyModule) -> None:
        pl_module.image_threshold.compute()
        if pl_module.pixel_threshold._update_called:  # noqa: SLF001
            pl_module.pixel_threshold.compute()
        else:
            pl_module.pixel_threshold.value = pl_module.image_threshold.value
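
As a usage note, the `threshold` argument accepted by `_ThresholdCallback` can be a class name, a threshold instance, a tuple of instances, or a config mapping, as described in `_load_from_config` above. A small illustrative sketch; it assumes `F1AdaptiveThreshold` and `ManualThreshold` are exposed by `anomalib.metrics.threshold` (the default module path used by `_get_threshold_from_config`), and note that the callback itself is internal to the Engine:

from omegaconf import DictConfig

from anomalib.callbacks.thresholding import _ThresholdCallback
from anomalib.metrics.threshold import F1AdaptiveThreshold, ManualThreshold

# 1. Single class name: image and pixel thresholds share the same class.
cb_from_str = _ThresholdCallback(threshold="F1AdaptiveThreshold")

# 2. Instance, or tuple of instances: first entry for image, second for pixel thresholding.
cb_from_instance = _ThresholdCallback(threshold=F1AdaptiveThreshold())
cb_from_tuple = _ThresholdCallback(threshold=(ManualThreshold(0.5), ManualThreshold(0.5)))

# 3. Config mapping, resolved via importlib in `_get_threshold_from_config`.
config = DictConfig({"class_path": "ManualThreshold", "init_args": {"default_value": 0.7}})
cb_from_config = _ThresholdCallback(threshold=config)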
anomalib/callbacks/tiler_configuration.py
ADDED
@@ -0,0 +1,74 @@
"""Tiler Callback."""

# Copyright (C) 2022 Intel Corporation
# SPDX-License-Identifier: Apache-2.0


from collections.abc import Sequence

import lightning.pytorch as pl
from lightning.pytorch.callbacks import Callback

from anomalib.data.utils.tiler import ImageUpscaleMode, Tiler
from anomalib.models.components import AnomalyModule

__all__ = ["TilerConfigurationCallback"]


class TilerConfigurationCallback(Callback):
    """Tiler Configuration Callback."""

    def __init__(
        self,
        enable: bool = False,
        tile_size: int | Sequence = 256,
        stride: int | Sequence | None = None,
        remove_border_count: int = 0,
        mode: ImageUpscaleMode = ImageUpscaleMode.PADDING,
    ) -> None:
        """Set tiling configuration from the command line.

        Args:
            enable (bool): Boolean to enable tiling operation.
                Defaults to False.
            tile_size ([int | Sequence]): Tile size.
                Defaults to 256.
            stride ([int | Sequence]): Stride to move tiles on the image.
            remove_border_count (int, optional): Number of pixels to remove from the image before
                tiling. Defaults to 0.
            mode (str, optional): Up-scaling mode when untiling overlapping tiles.
                Defaults to "padding".
            tile_count (SupportsIndex, optional): Number of random tiles to sample from the image.
                Defaults to 4.
        """
        self.enable = enable
        self.tile_size = tile_size
        self.stride = stride
        self.remove_border_count = remove_border_count
        self.mode = mode

    def setup(self, trainer: pl.Trainer, pl_module: pl.LightningModule, stage: str | None = None) -> None:
        """Set Tiler object within Anomalib Model.

        Args:
            trainer (pl.Trainer): PyTorch Lightning Trainer
            pl_module (pl.LightningModule): Anomalib Model that inherits pl LightningModule.
            stage (str | None, optional): fit, validate, test or predict. Defaults to None.

        Raises:
            ValueError: When the Anomalib Model doesn't contain a ``Tiler`` object, meaning the model
                does not support the tiling operation.
        """
        del trainer, stage  # These variables are not used.

        if self.enable:
            if isinstance(pl_module, AnomalyModule) and hasattr(pl_module.model, "tiler"):
                pl_module.model.tiler = Tiler(
                    tile_size=self.tile_size,
                    stride=self.stride,
                    remove_border_count=self.remove_border_count,
                    mode=self.mode,
                )
            else:
                msg = "Model does not support tiling."
                raise ValueError(msg)
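
A short sketch of how this callback might be wired into a run, following the callbacks-list pattern shown in the other callbacks' docstrings; the tile size and stride values are arbitrary examples, not defaults of the library:

from anomalib.callbacks.tiler_configuration import TilerConfigurationCallback
from anomalib.engine import Engine

# Enable tiling with 256x256 tiles moved with a stride of 128 pixels.
tiler_callback = TilerConfigurationCallback(enable=True, tile_size=256, stride=128)

# The callback only takes effect for models whose `model` attribute exposes a `tiler`;
# otherwise `setup()` raises ValueError("Model does not support tiling.").
engine = Engine(callbacks=[tiler_callback])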
anomalib/callbacks/timer.py
ADDED
@@ -0,0 +1,109 @@
"""Callback to measure training and testing time of a PyTorch Lightning module."""

# Copyright (C) 2022 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

import logging
import time

import torch
from lightning.pytorch import Callback, LightningModule, Trainer

logger = logging.getLogger(__name__)


class TimerCallback(Callback):
    """Callback that measures the training and testing time of a PyTorch Lightning module.

    Examples:
        >>> from anomalib.callbacks import TimerCallback
        >>> from anomalib.engine import Engine
        ...
        >>> callbacks = [TimerCallback()]
        >>> engine = Engine(callbacks=callbacks)
    """

    def __init__(self) -> None:
        self.start: float
        self.num_images: int = 0

    def on_fit_start(self, trainer: Trainer, pl_module: LightningModule) -> None:
        """Call when fit begins.

        Sets the start time to the time training started.

        Args:
            trainer (Trainer): PyTorch Lightning trainer.
            pl_module (LightningModule): Current training module.

        Returns:
            None
        """
        del trainer, pl_module  # These variables are not used.

        self.start = time.time()

    def on_fit_end(self, trainer: Trainer, pl_module: LightningModule) -> None:
        """Call when fit ends.

        Prints the time taken for training.

        Args:
            trainer (Trainer): PyTorch Lightning trainer.
            pl_module (LightningModule): Current training module.

        Returns:
            None
        """
        del trainer, pl_module  # Unused arguments.
        logger.info("Training took %5.2f seconds", (time.time() - self.start))

    def on_test_start(self, trainer: Trainer, pl_module: LightningModule) -> None:
        """Call when the test begins.

        Sets the start time to the time testing started.
        Goes over all the test dataloaders and adds the number of images in each.

        Args:
            trainer (Trainer): PyTorch Lightning trainer.
            pl_module (LightningModule): Current training module.

        Returns:
            None
        """
        del pl_module  # Unused argument.

        self.start = time.time()
        self.num_images = 0

        if trainer.test_dataloaders is not None:  # Check to placate Mypy.
            if isinstance(trainer.test_dataloaders, torch.utils.data.dataloader.DataLoader):
                self.num_images += len(trainer.test_dataloaders.dataset)
            else:
                for dataloader in trainer.test_dataloaders:
                    self.num_images += len(dataloader.dataset)

    def on_test_end(self, trainer: Trainer, pl_module: LightningModule) -> None:
        """Call when the test ends.

        Prints the time taken for testing and the throughput in frames per second.

        Args:
            trainer (Trainer): PyTorch Lightning trainer.
            pl_module (LightningModule): Current training module.

        Returns:
            None
        """
        del pl_module  # Unused argument.

        testing_time = time.time() - self.start
        output = f"Testing took {testing_time} seconds\nThroughput "
        if trainer.test_dataloaders is not None:
            if isinstance(trainer.test_dataloaders, torch.utils.data.dataloader.DataLoader):
                test_data_loader = trainer.test_dataloaders
            else:
                test_data_loader = trainer.test_dataloaders[0]
            output += f"(batch_size={test_data_loader.batch_size})"
        output += f" : {self.num_images/testing_time} FPS"
        logger.info(output)
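
The throughput figure logged in `on_test_end` is simply the number of test images divided by the elapsed wall-clock time. A tiny sketch with made-up numbers, just to make the reported unit concrete:

# Assumed example values: 1000 test images processed in 12.5 seconds.
num_images = 1000
testing_time = 12.5
fps = num_images / testing_time  # 80.0 frames per second, the value TimerCallback reports
print(f"Throughput: {fps} FPS")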
anomalib/callbacks/visualizer.py
ADDED
@@ -0,0 +1,182 @@
"""Visualizer Callback.

This is assigned by Anomalib Engine internally.
"""

# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

import logging
from pathlib import Path
from typing import Any, cast

from lightning.pytorch import Callback, Trainer
from lightning.pytorch.utilities.types import STEP_OUTPUT

from anomalib.data.utils.image import save_image, show_image
from anomalib.loggers import AnomalibWandbLogger
from anomalib.loggers.base import ImageLoggerBase
from anomalib.models import AnomalyModule
from anomalib.utils.visualization import (
    BaseVisualizer,
    GeneratorResult,
    VisualizationStep,
)

logger = logging.getLogger(__name__)


class _VisualizationCallback(Callback):
    """Callback for visualization that is used internally by the Engine.

    Args:
        visualizers (BaseVisualizer | list[BaseVisualizer]):
            Visualizer objects that are used for computing the visualizations. Defaults to None.
        save (bool, optional): Save the image. Defaults to False.
        root (Path | None, optional): The path to save the images. Defaults to None.
        log (bool, optional): Log the images into the loggers. Defaults to False.
        show (bool, optional): Show the images. Defaults to False.

    Example:
        >>> visualizers = [ImageVisualizer(), MetricsVisualizer()]
        >>> visualization_callback = _VisualizationCallback(
        ...     visualizers=visualizers,
        ...     save=True,
        ...     root="results/images"
        ... )

        CLI
        $ anomalib train --model Padim --data MVTec \
            --visualization.visualizers ImageVisualizer \
            --visualization.visualizers+=MetricsVisualizer
        or
        $ anomalib train --model Padim --data MVTec \
            --visualization.visualizers '[ImageVisualizer, MetricsVisualizer]'

    Raises:
        ValueError: In case `root` is None and `save` is True.
    """

    def __init__(
        self,
        visualizers: BaseVisualizer | list[BaseVisualizer],
        save: bool = False,
        root: Path | None = None,
        log: bool = False,
        show: bool = False,
    ) -> None:
        self.save = save
        if save and root is None:
            msg = "`root` must be provided if save is True"
            raise ValueError(msg)
        self.root: Path = root if root is not None else Path()  # need this check for mypy
        self.log = log
        self.show = show
        self.generators = visualizers if isinstance(visualizers, list) else [visualizers]

    def on_test_batch_end(
        self,
        trainer: Trainer,
        pl_module: AnomalyModule,
        outputs: STEP_OUTPUT | None,
        batch: Any,  # noqa: ANN401
        batch_idx: int,
        dataloader_idx: int = 0,
    ) -> None:
        for generator in self.generators:
            if generator.visualize_on == VisualizationStep.BATCH:
                for result in generator(
                    trainer=trainer,
                    pl_module=pl_module,
                    outputs=outputs,
                    batch=batch,
                    batch_idx=batch_idx,
                    dataloader_idx=dataloader_idx,
                ):
                    if self.save:
                        if result.file_name is None:
                            msg = "``save`` is set to ``True`` but file name is ``None``"
                            raise ValueError(msg)

                        # Get the filename to save the image.
                        # Filename is split based on the datamodule name and category.
                        # For example, if the filename is `MVTec/bottle/000.png`, then the
                        # filename is split based on `MVTec/bottle` and `000.png` is saved.
                        if trainer.datamodule is not None:
                            filename = str(result.file_name).split(
                                sep=f"{trainer.datamodule.name}/{trainer.datamodule.category}",
                            )[-1]
                        else:
                            filename = Path(result.file_name).name
                        save_image(image=result.image, root=self.root, filename=filename)
                    if self.show:
                        show_image(image=result.image, title=str(result.file_name))
                    if self.log:
                        self._add_to_logger(result, pl_module, trainer)

    def on_test_end(self, trainer: Trainer, pl_module: AnomalyModule) -> None:
        for generator in self.generators:
            if generator.visualize_on == VisualizationStep.STAGE_END:
                for result in generator(trainer=trainer, pl_module=pl_module):
                    if self.save:
                        if result.file_name is None:
                            msg = "``save`` is set to ``True`` but file name is ``None``"
                            raise ValueError(msg)
                        save_image(image=result.image, root=self.root, filename=result.file_name)
                    if self.show:
                        show_image(image=result.image, title=str(result.file_name))
                    if self.log:
                        self._add_to_logger(result, pl_module, trainer)

        for logger in trainer.loggers:
            if isinstance(logger, AnomalibWandbLogger):
                logger.save()

    def on_predict_batch_end(
        self,
        trainer: Trainer,
        pl_module: AnomalyModule,
        outputs: STEP_OUTPUT | None,
        batch: Any,  # noqa: ANN401
        batch_idx: int,
        dataloader_idx: int = 0,
    ) -> None:
        return self.on_test_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx)

    def on_predict_end(self, trainer: Trainer, pl_module: AnomalyModule) -> None:
        return self.on_test_end(trainer, pl_module)

    def _add_to_logger(
        self,
        result: GeneratorResult,
        module: AnomalyModule,
        trainer: Trainer,
    ) -> None:
        """Add image to logger.

        Args:
            result (GeneratorResult): Output from the generators.
            module (AnomalyModule): LightningModule from which the global step is extracted.
            trainer (Trainer): Trainer object.
        """
        # Store names of logger and the logger in a dict
        available_loggers = {
            type(logger).__name__.lower().replace("logger", "").replace("anomalib", ""): logger
            for logger in trainer.loggers
        }
        # save image to respective logger
        if result.file_name is None:
            msg = "File name is None"
            raise ValueError(msg)
        filename = result.file_name
        image = result.image
        for log_to in available_loggers:
            # check if logger object is same as the requested object
            if isinstance(available_loggers[log_to], ImageLoggerBase):
                logger: ImageLoggerBase = cast(ImageLoggerBase, available_loggers[log_to])  # placate mypy
                _name = filename.parent.name + "_" + filename.name if isinstance(filename, Path) else filename
                logger.add_image(
                    image=image,
                    name=_name,
                    global_step=module.global_step,
                )
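
To illustrate the filename handling in `on_test_batch_end`: when a datamodule is attached, the saved path is whatever follows `<datamodule.name>/<datamodule.category>` in the generator's file name; otherwise only the basename is kept. A minimal sketch with hypothetical values (the paths below are illustrative, not produced by the library):

from pathlib import Path

# Hypothetical generator output and datamodule attributes.
file_name = "datasets/MVTec/bottle/broken_large/000.png"
datamodule_name, category = "MVTec", "bottle"

# Same split logic as the callback: keep only the part after "<name>/<category>".
filename = str(file_name).split(sep=f"{datamodule_name}/{category}")[-1]  # "/broken_large/000.png"

# Without a datamodule, only the basename is kept.
fallback = Path(file_name).name  # "000.png"
print(filename, fallback)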
anomalib/cli/__init__.py
ADDED
@@ -0,0 +1,8 @@
"""Anomalib CLI."""

# Copyright (C) 2022 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

from .cli import AnomalibCLI

__all__ = ["AnomalibCLI"]
anomalib/cli/cli.py
ADDED
@@ -0,0 +1,483 @@
"""Anomalib CLI."""

# Copyright (C) 2023-2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

import logging
from collections.abc import Callable, Sequence
from functools import partial
from pathlib import Path
from types import MethodType
from typing import Any

from jsonargparse import ActionConfigFile, ArgumentParser, Namespace
from jsonargparse._actions import _ActionSubCommands
from rich import traceback

from anomalib import TaskType, __version__
from anomalib.cli.utils.help_formatter import CustomHelpFormatter, get_short_docstring
from anomalib.cli.utils.openvino import add_openvino_export_arguments
from anomalib.loggers import configure_logger

traceback.install()
logger = logging.getLogger("anomalib.cli")

_LIGHTNING_AVAILABLE = True
try:
    from lightning.pytorch import Trainer
    from torch.utils.data import DataLoader, Dataset

    from anomalib.data import AnomalibDataModule
    from anomalib.engine import Engine
    from anomalib.metrics.threshold import BaseThreshold
    from anomalib.models import AnomalyModule
    from anomalib.utils.config import update_config

except ImportError:
    _LIGHTNING_AVAILABLE = False


class AnomalibCLI:
    """Implementation of a fully configurable CLI tool for anomalib.

    The advantage of this tool is its flexibility to configure the pipeline
    from both the CLI and a configuration file (.yaml or .json). It is even
    possible to use both the CLI and a configuration file simultaneously.
    For more details, the reader could refer to PyTorch Lightning CLI
    documentation.

    ``save_config_kwargs`` is set to ``overwrite=True`` so that the
    ``SaveConfigCallback`` overwrites the config if it already exists.
    """

    def __init__(self, args: Sequence[str] | None = None) -> None:
        self.parser = self.init_parser()
        self.subcommand_parsers: dict[str, ArgumentParser] = {}
        self.subcommand_method_arguments: dict[str, list[str]] = {}
        self.add_subcommands()
        self.config = self.parser.parse_args(args=args)
        self.subcommand = self.config["subcommand"]
        if _LIGHTNING_AVAILABLE:
            self.before_instantiate_classes()
            self.instantiate_classes()
        self._run_subcommand()

    def init_parser(self, **kwargs) -> ArgumentParser:
        """Method that instantiates the argument parser."""
        kwargs.setdefault("dump_header", [f"anomalib=={__version__}"])
        parser = ArgumentParser(formatter_class=CustomHelpFormatter, **kwargs)
        parser.add_argument(
            "-c",
            "--config",
            action=ActionConfigFile,
            help="Path to a configuration file in json or yaml format.",
        )
        return parser

    @staticmethod
    def subcommands() -> dict[str, set[str]]:
        """Skip predict subcommand as it is added later."""
        return {
            "fit": {"model", "train_dataloaders", "val_dataloaders", "datamodule"},
            "validate": {"model", "dataloaders", "datamodule"},
            "test": {"model", "dataloaders", "datamodule"},
        }

    @staticmethod
    def anomalib_subcommands() -> dict[str, dict[str, str]]:
        """Return a dictionary of subcommands and their description."""
        return {
            "train": {"description": "Fit the model and then call test on the trained model."},
            "predict": {"description": "Run inference on a model."},
            "export": {"description": "Export the model to ONNX or OpenVINO format."},
        }

    def add_subcommands(self, **kwargs) -> None:
        """Initialize base subcommands and add anomalib specific on top of it."""
        parser_subcommands = self.parser.add_subcommands()

        # Extra subcommand: install
        self._set_install_subcommand(parser_subcommands)

        if not _LIGHTNING_AVAILABLE:
            # If environment is not configured to use pl, do not add a subcommand for Engine.
            return

        # Add Trainer subcommands
        for subcommand in self.subcommands():
            sub_parser = self.init_parser(**kwargs)

            fn = getattr(Trainer, subcommand)
            # extract the first line description in the docstring for the subcommand help message
            description = get_short_docstring(fn)
            subparser_kwargs = kwargs.get(subcommand, {})
            subparser_kwargs.setdefault("description", description)

            self.subcommand_parsers[subcommand] = sub_parser
            parser_subcommands.add_subcommand(subcommand, sub_parser, help=description)
            self.add_trainer_arguments(sub_parser, subcommand)

        # Add anomalib subcommands
        for subcommand in self.anomalib_subcommands():
            sub_parser = self.init_parser(**kwargs)

            self.subcommand_parsers[subcommand] = sub_parser
            parser_subcommands.add_subcommand(
                subcommand,
                sub_parser,
                help=self.anomalib_subcommands()[subcommand]["description"],
            )
            # add arguments to subcommand
            getattr(self, f"add_{subcommand}_arguments")(sub_parser)

    def add_arguments_to_parser(self, parser: ArgumentParser) -> None:
        """Extend trainer's arguments to add engine arguments.

        .. note::
            Since ``Engine`` parameters are manually added, any change to the
            ``Engine`` class should be reflected manually.
        """
        from anomalib.callbacks.normalization import get_normalization_callback

        parser.add_function_arguments(get_normalization_callback, "normalization")
        parser.add_argument("--task", type=TaskType | str, default=TaskType.SEGMENTATION)
        parser.add_argument(
            "--metrics.image",
            type=list[str] | str | dict[str, dict[str, Any]] | None,
            default=["F1Score", "AUROC"],
        )
        parser.add_argument(
            "--metrics.pixel",
            type=list[str] | str | dict[str, dict[str, Any]] | None,
            default=None,
            required=False,
        )
        parser.add_argument("--metrics.threshold", type=BaseThreshold | str, default="F1AdaptiveThreshold")
        parser.add_argument("--logging.log_graph", type=bool, help="Log the model to the logger", default=False)
        if hasattr(parser, "subcommand") and parser.subcommand not in ("export", "predict"):
            parser.link_arguments("task", "data.init_args.task")
        parser.add_argument(
            "--default_root_dir",
            type=Path,
            help="Path to save the results.",
            default=Path("./results"),
        )
        parser.link_arguments("default_root_dir", "trainer.default_root_dir")
        # TODO(ashwinvaidya17): Tiling should also be a category of its own
        # CVS-122659

    def add_trainer_arguments(self, parser: ArgumentParser, subcommand: str) -> None:
        """Add trainer arguments to the parser."""
        self._add_default_arguments_to_parser(parser)
        self._add_trainer_arguments_to_parser(parser, add_optimizer=True, add_scheduler=True)
        parser.add_subclass_arguments(
            AnomalyModule,
            "model",
            fail_untyped=False,
            required=True,
        )
        parser.add_subclass_arguments(AnomalibDataModule, "data")
        self.add_arguments_to_parser(parser)
        skip: set[str | int] = set(self.subcommands()[subcommand])
        added = parser.add_method_arguments(
            Trainer,
            subcommand,
            skip=skip,
        )
        self.subcommand_method_arguments[subcommand] = added

    def add_train_arguments(self, parser: ArgumentParser) -> None:
        """Add train arguments to the parser."""
        self._add_default_arguments_to_parser(parser)
        self._add_trainer_arguments_to_parser(parser, add_optimizer=True, add_scheduler=True)
        parser.add_subclass_arguments(
            AnomalyModule,
            "model",
            fail_untyped=False,
            required=True,
        )
        parser.add_subclass_arguments(AnomalibDataModule, "data")
        self.add_arguments_to_parser(parser)
        added = parser.add_method_arguments(
            Engine,
            "train",
            skip={"model", "datamodule", "val_dataloaders", "test_dataloaders", "train_dataloaders"},
        )
        self.subcommand_method_arguments["train"] = added

    def add_predict_arguments(self, parser: ArgumentParser) -> None:
        """Add predict arguments to the parser."""
        self._add_default_arguments_to_parser(parser)
        self._add_trainer_arguments_to_parser(parser)
        parser.add_subclass_arguments(
            AnomalyModule,
            "model",
            fail_untyped=False,
            required=True,
        )
        parser.add_argument(
            "--data",
            type=Dataset | AnomalibDataModule | DataLoader | str | Path,
            required=True,
        )
        added = parser.add_method_arguments(
            Engine,
            "predict",
            skip={"model", "dataloaders", "datamodule", "dataset", "data_path"},
        )
        self.subcommand_method_arguments["predict"] = added
        self.add_arguments_to_parser(parser)

    def add_export_arguments(self, parser: ArgumentParser) -> None:
        """Add export arguments to the parser."""
        self._add_default_arguments_to_parser(parser)
        self._add_trainer_arguments_to_parser(parser)
        parser.add_subclass_arguments(
            AnomalyModule,
            "model",
            fail_untyped=False,
            required=True,
        )
        added = parser.add_method_arguments(
            Engine,
            "export",
            skip={"ov_args", "model"},
        )
        self.subcommand_method_arguments["export"] = added
        add_openvino_export_arguments(parser)
        self.add_arguments_to_parser(parser)

    def _set_install_subcommand(self, action_subcommand: _ActionSubCommands) -> None:
        sub_parser = ArgumentParser(formatter_class=CustomHelpFormatter)
        sub_parser.add_argument(
            "--option",
            help="Install the full or optional-dependencies.",
            default="full",
            type=str,
            choices=["full", "core", "dev", "loggers", "notebooks", "openvino"],
        )
        sub_parser.add_argument(
            "-v",
            "--verbose",
            help="Set Logger level to INFO",
            action="store_true",
        )

        self.subcommand_parsers["install"] = sub_parser
        action_subcommand.add_subcommand(
            "install",
            sub_parser,
            help="Install the full-package for anomalib.",
        )

    def before_instantiate_classes(self) -> None:
        """Modify the configuration to properly instantiate classes and set up the tiler."""
        subcommand = self.config["subcommand"]
        if subcommand in (*self.subcommands(), "train", "predict"):
            self.config[subcommand] = update_config(self.config[subcommand])

    def instantiate_classes(self) -> None:
        """Instantiate classes depending on the subcommand.

        For trainer-related commands it instantiates all the model, datamodule and trainer classes.
        For other subcommands we do not want to instantiate any trainer-specific classes such as the datamodule
        or model, because the subcommand is responsible for instantiating and executing code based on the
        passed config.
        """
        if self.config["subcommand"] in (*self.subcommands(), "predict"):  # trainer commands
            # since all classes are instantiated, the LightningCLI also creates an unused ``Trainer`` object.
            # the minor change here is that engine is instantiated instead of trainer
            self.config_init = self.parser.instantiate_classes(self.config)
            self.datamodule = self._get(self.config_init, "data")
            if isinstance(self.datamodule, Dataset):
                self.datamodule = DataLoader(self.datamodule)
            self.model = self._get(self.config_init, "model")
            self._configure_optimizers_method_to_model()
            self.instantiate_engine()
        else:
            self.config_init = self.parser.instantiate_classes(self.config)
            subcommand = self.config["subcommand"]
            if subcommand in ("train", "export"):
                self.instantiate_engine()
            if "model" in self.config_init[subcommand]:
                self.model = self._get(self.config_init, "model")
            else:
                self.model = None
            if "data" in self.config_init[subcommand]:
                self.datamodule = self._get(self.config_init, "data")
            else:
                self.datamodule = None

    def instantiate_engine(self) -> None:
        """Instantiate the engine.

        .. note::
            Most of the code in this method is taken from ``LightningCLI``'s
            ``instantiate_trainer`` method. Refer to that method for more
            details.
        """
        from lightning.pytorch.cli import SaveConfigCallback

        from anomalib.callbacks import get_callbacks

        engine_args = {
            "normalization": self._get(self.config_init, "normalization.normalization_method"),
            "threshold": self._get(self.config_init, "metrics.threshold"),
            "task": self._get(self.config_init, "task"),
            "image_metrics": self._get(self.config_init, "metrics.image"),
            "pixel_metrics": self._get(self.config_init, "metrics.pixel"),
        }
        trainer_config = {**self._get(self.config_init, "trainer", default={}), **engine_args}
        key = "callbacks"
        if key in trainer_config:
            if trainer_config[key] is None:
                trainer_config[key] = []
            elif not isinstance(trainer_config[key], list):
                trainer_config[key] = [trainer_config[key]]
            if not trainer_config.get("fast_dev_run", False):
                config_callback = SaveConfigCallback(
                    self._parser(self.subcommand),
                    self.config.get(str(self.subcommand), self.config),
                    overwrite=True,
                )
                trainer_config[key].append(config_callback)
            trainer_config[key].extend(get_callbacks(self.config[self.subcommand]))
        self.engine = Engine(**trainer_config)

    def _run_subcommand(self) -> None:
        """Run subcommand depending on the subcommand.

        This overrides the original ``_run_subcommand`` to run the ``Engine``
        method rather than the ``Train`` method.
        """
        if self.subcommand == "install":
            from anomalib.cli.install import anomalib_install

            install_kwargs = self.config.get("install", {})
            anomalib_install(**install_kwargs)
        elif self.config["subcommand"] in (*self.subcommands(), "train", "export", "predict"):
            fn = getattr(self.engine, self.subcommand)
            fn_kwargs = self._prepare_subcommand_kwargs(self.subcommand)
            fn(**fn_kwargs)
        else:
            self.config_init = self.parser.instantiate_classes(self.config)
            getattr(self, f"{self.subcommand}")()

    @property
    def fit(self) -> Callable:
        """Fit the model using engine's fit method."""
        return self.engine.fit

    @property
    def validate(self) -> Callable:
        """Validate the model using engine's validate method."""
        return self.engine.validate

    @property
    def test(self) -> Callable:
        """Test the model using engine's test method."""
        return self.engine.test

    @property
    def predict(self) -> Callable:
        """Predict using engine's predict method."""
        return self.engine.predict

    @property
    def train(self) -> Callable:
        """Train the model using engine's train method."""
        return self.engine.train

    @property
    def export(self) -> Callable:
        """Export the model using engine's export method."""
        return self.engine.export

    def _add_trainer_arguments_to_parser(
        self,
        parser: ArgumentParser,
        add_optimizer: bool = False,
        add_scheduler: bool = False,
    ) -> None:
        """Add trainer arguments to the parser."""
        parser.add_class_arguments(Trainer, "trainer", fail_untyped=False, instantiate=False, sub_configs=True)

        if add_optimizer:
            from torch.optim import Optimizer

            optim_kwargs = {"instantiate": False, "fail_untyped": False, "skip": {"params"}}
            parser.add_subclass_arguments(
                baseclass=(Optimizer,),
                nested_key="optimizer",
                **optim_kwargs,
            )
        if add_scheduler:
            from lightning.pytorch.cli import LRSchedulerTypeTuple

            scheduler_kwargs = {"instantiate": False, "fail_untyped": False, "skip": {"optimizer"}}
            parser.add_subclass_arguments(
                baseclass=LRSchedulerTypeTuple,
                nested_key="lr_scheduler",
                **scheduler_kwargs,
            )

    def _add_default_arguments_to_parser(self, parser: ArgumentParser) -> None:
        """Adds default arguments to the parser."""
        parser.add_argument(
            "--seed_everything",
            type=bool | int,
            default=True,
            help=(
                "Set to an int to run seed_everything with this value before classes instantiation."
                "Set to True to use a random seed."
            ),
        )

    def _get(self, config: Namespace, key: str, default: Any = None) -> Any:  # noqa: ANN401
        """Utility to get a config value which might be inside a subcommand."""
        return config.get(str(self.subcommand), config).get(key, default)

    def _prepare_subcommand_kwargs(self, subcommand: str) -> dict[str, Any]:
        """Prepares the keyword arguments to pass to the subcommand to run."""
        fn_kwargs = {
            k: v for k, v in self.config_init[subcommand].items() if k in self.subcommand_method_arguments[subcommand]
        }
        fn_kwargs["model"] = self.model
        if self.datamodule is not None:
            if isinstance(self.datamodule, AnomalibDataModule):
                fn_kwargs["datamodule"] = self.datamodule
            elif isinstance(self.datamodule, DataLoader):
                fn_kwargs["dataloaders"] = self.datamodule
            elif isinstance(self.datamodule, Path | str):
                fn_kwargs["data_path"] = self.datamodule
        return fn_kwargs

    def _parser(self, subcommand: str | None) -> ArgumentParser:
        if subcommand is None:
            return self.parser
        # return the subcommand parser for the subcommand passed
        return self.subcommand_parsers[subcommand]

    def _configure_optimizers_method_to_model(self) -> None:
        from lightning.pytorch.cli import LightningCLI, instantiate_class

        optimizer_cfg = self._get(self.config_init, "optimizer", None)
        if optimizer_cfg is None:
            return
        lr_scheduler_cfg = self._get(self.config_init, "lr_scheduler", {})

        optimizer = instantiate_class(self.model.parameters(), optimizer_cfg)
        lr_scheduler = instantiate_class(optimizer, lr_scheduler_cfg) if lr_scheduler_cfg else None
        fn = partial(LightningCLI.configure_optimizers, optimizer=optimizer, lr_scheduler=lr_scheduler)

        # override the existing method
        self.model.configure_optimizers = MethodType(fn, self.model)


def main() -> None:
    """Trainer via Anomalib CLI."""
    configure_logger()
    AnomalibCLI()


if __name__ == "__main__":
    main()
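
For context, the CLI can be driven from the command line ("anomalib train --model Padim --data MVTec", matching the pattern in the visualizer docstring) or programmatically by passing the argument list to `AnomalibCLI`. A hedged sketch; `Padim` and `MVTec` are example model/datamodule names taken from that docstring, not requirements of the CLI itself:

from anomalib.cli import AnomalibCLI

# Parsing, class instantiation and the subcommand run all happen inside __init__,
# so constructing the object is equivalent to invoking the command line.
AnomalibCLI(
    args=[
        "train",
        "--model", "Padim",
        "--data", "MVTec",
        "--default_root_dir", "./results",
    ],
)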
anomalib/cli/install.py
ADDED
@@ -0,0 +1,81 @@
"""Anomalib install subcommand code."""

# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

import logging

from pkg_resources import Requirement
from rich.console import Console
from rich.logging import RichHandler

from anomalib.cli.utils.installation import (
    get_requirements,
    get_torch_install_args,
    parse_requirements,
)

logger = logging.getLogger("pip")
logger.setLevel(logging.WARNING)  # setLevel: CRITICAL, ERROR, WARNING, INFO, DEBUG, NOTSET
console = Console()
handler = RichHandler(
    console=console,
    show_level=False,
    show_path=False,
)
logger.addHandler(handler)


def anomalib_install(option: str = "full", verbose: bool = False) -> int:
    """Install Anomalib requirements.

    Args:
        option (str | None): Optional-dependency to install requirements for.
        verbose (bool): Set pip logger level to INFO

    Raises:
        ValueError: When the task is not supported.

    Returns:
        int: Status code of the pip install command.
    """
    from pip._internal.commands import create_command

    requirements_dict = get_requirements("anomalib")

    requirements = []
    if option == "full":
        for extra in requirements_dict:
            requirements.extend(requirements_dict[extra])
    elif option in requirements_dict:
        requirements.extend(requirements_dict[option])
    elif option is not None:
        requirements.append(Requirement.parse(option))

    # Parse requirements into torch and other requirements.
    # This is done to parse the correct version of torch (cpu/cuda).
    torch_requirement, other_requirements = parse_requirements(requirements, skip_torch=option not in ("full", "core"))

    # Get install args for torch to install it from a specific index-url
    install_args: list[str] = []
    torch_install_args = []
    if option in ("full", "core") and torch_requirement is not None:
        torch_install_args = get_torch_install_args(torch_requirement)

    # Combine torch and other requirements.
    install_args = other_requirements + torch_install_args

    # Install requirements.
    with console.status("[bold green]Installing packages...  This may take a few minutes.\n") as status:
        if verbose:
            logger.setLevel(logging.INFO)
            status.stop()
        console.log(f"Installation list: [yellow]{install_args}[/yellow]")
        status_code = create_command("install").main(install_args)
        if status_code == 0:
            console.log(f"Installation Complete: {install_args}")

    if status_code == 0:
        console.print("Anomalib Installation [bold green]Complete.[/bold green]")

    return status_code
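
A quick usage sketch for the install helper, mirroring what the `anomalib install` subcommand does under the hood; the option values come from the choices defined in `_set_install_subcommand` in cli.py:

from anomalib.cli.install import anomalib_install

# Install only the core dependency group and echo pip output at INFO level.
status_code = anomalib_install(option="core", verbose=True)
if status_code != 0:
    msg = f"pip install exited with status {status_code}"
    raise RuntimeError(msg)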
anomalib/cli/utils/__init__.py
ADDED
@@ -0,0 +1,8 @@
"""Anomalib CLI Utils."""

# Copyright (C) 2023 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

from .help_formatter import CustomHelpFormatter

__all__ = ["CustomHelpFormatter"]
anomalib/cli/utils/help_formatter.py
ADDED
@@ -0,0 +1,268 @@
1 |
+
"""Custom Help Formatters for Anomalib CLI."""
|
2 |
+
|
3 |
+
# Copyright (C) 2023 Intel Corporation
|
4 |
+
# SPDX-License-Identifier: Apache-2.0
|
5 |
+
|
6 |
+
import argparse
|
7 |
+
import re
|
8 |
+
import sys
|
9 |
+
from typing import TypeVar
|
10 |
+
|
11 |
+
import docstring_parser
|
12 |
+
from jsonargparse import DefaultHelpFormatter
|
13 |
+
from rich.markdown import Markdown
|
14 |
+
from rich.panel import Panel
|
15 |
+
from rich_argparse import RichHelpFormatter
|
16 |
+
|
17 |
+
REQUIRED_ARGUMENTS = {
|
18 |
+
"train": {"model", "model.help", "data", "data.help", "ckpt_path", "config"},
|
19 |
+
"fit": {"model", "model.help", "data", "data.help", "ckpt_path", "config"},
|
20 |
+
"validate": {"model", "model.help", "data", "data.help", "ckpt_path", "config"},
|
21 |
+
"test": {"model", "model.help", "data", "data.help", "ckpt_path", "config"},
|
22 |
+
"predict": {"model", "model.help", "data", "data.help", "ckpt_path", "config"},
|
23 |
+
"export": {"model", "model.help", "export_type", "ckpt_path", "config"},
|
24 |
+
}
|
25 |
+
|
26 |
+
try:
|
27 |
+
from anomalib.engine import Engine
|
28 |
+
|
29 |
+
DOCSTRING_USAGE = {
|
30 |
+
"train": Engine.train,
|
31 |
+
"fit": Engine.fit,
|
32 |
+
"validate": Engine.validate,
|
33 |
+
"test": Engine.test,
|
34 |
+
"predict": Engine.predict,
|
35 |
+
"export": Engine.export,
|
36 |
+
}
|
37 |
+
except ImportError:
|
38 |
+
print("To use other subcommand using `anomalib install`")
|
39 |
+
|
40 |
+
|
41 |
+
def get_short_docstring(component: TypeVar) -> str:
|
42 |
+
"""Get the short description from the docstring.
|
43 |
+
|
44 |
+
Args:
|
45 |
+
component (TypeVar): The component to get the docstring from
|
46 |
+
|
47 |
+
Returns:
|
48 |
+
str: The short description
|
49 |
+
"""
|
50 |
+
if component.__doc__ is None:
|
51 |
+
return ""
|
52 |
+
docstring = docstring_parser.parse(component.__doc__)
|
53 |
+
return docstring.short_description
|
54 |
+
|
55 |
+
|
56 |
+
def get_verbosity_subcommand() -> dict:
|
57 |
+
"""Parse command line arguments and returns a dictionary of key-value pairs.
|
58 |
+
|
59 |
+
Returns:
|
60 |
+
A dictionary containing the parsed command line arguments.
|
61 |
+
|
62 |
+
Examples:
|
63 |
+
>>> import sys
|
64 |
+
>>> sys.argv = ['anomalib', 'train', '-h', '-v']
|
65 |
+
>>> get_verbosity_subcommand()
|
66 |
+
{'subcommand': 'train', 'help': True, 'verbosity': 1}
|
67 |
+
"""
|
68 |
+
arguments: dict = {"subcommand": None, "help": False, "verbosity": 2}
|
69 |
+
if len(sys.argv) >= 2 and sys.argv[1] not in ("--help", "-h"):
|
70 |
+
arguments["subcommand"] = sys.argv[1]
|
71 |
+
if "--help" in sys.argv or "-h" in sys.argv:
|
72 |
+
arguments["help"] = True
|
73 |
+
if arguments["subcommand"] in REQUIRED_ARGUMENTS:
|
74 |
+
arguments["verbosity"] = 0
|
75 |
+
if "-v" in sys.argv or "--verbose" in sys.argv:
|
76 |
+
arguments["verbosity"] = 1
|
77 |
+
if "-vv" in sys.argv:
|
78 |
+
arguments["verbosity"] = 2
|
79 |
+
return arguments
|
80 |
+
|
81 |
+
|
82 |
+
def get_intro() -> Markdown:
|
83 |
+
"""Return a Markdown object containing the introduction text for Anomalib CLI Guide.
|
84 |
+
|
85 |
+
The introduction text includes a brief description of the guide and links to the Github repository and documentation
|
86 |
+
|
87 |
+
Returns:
|
88 |
+
A Markdown object containing the introduction text for Anomalib CLI Guide.
|
89 |
+
"""
|
90 |
+
intro_markdown = (
|
91 |
+
"# Anomalib CLI Guide\n\n"
|
92 |
+
"Github Repository: [https://github.com/openvinotoolkit/anomalib](https://github.com/openvinotoolkit/anomalib)."
|
93 |
+
"\n\n"
|
94 |
+
"A better guide is provided by the [documentation](https://anomalib.readthedocs.io/en/latest/index.html)."
|
95 |
+
)
|
96 |
+
return Markdown(intro_markdown)
|
97 |
+
|
98 |
+
|
99 |
+
def get_verbose_usage(subcommand: str = "train") -> str:
|
100 |
+
"""Return a string containing verbose usage information for the specified subcommand.
|
101 |
+
|
102 |
+
Args:
|
103 |
+
----
|
104 |
+
subcommand (str): The name of the subcommand to get verbose usage information for. Defaults to "train".
|
105 |
+
|
106 |
+
Returns:
|
107 |
+
-------
|
108 |
+
str: A string containing verbose usage information for the specified subcommand.
|
109 |
+
"""
|
110 |
+
return (
|
111 |
+
"To get more overridable argument information, run the command below.\n"
|
112 |
+
"```python\n"
|
113 |
+
"# Verbosity Level 1\n"
|
114 |
+
f"anomalib {subcommand} [optional_arguments] -h -v\n"
|
115 |
+
"# Verbosity Level 2\n"
|
116 |
+
f"anomalib {subcommand} [optional_arguments] -h -vv\n"
|
117 |
+
"```"
|
118 |
+
)
|
119 |
+
|
120 |
+
|
121 |
+
def get_cli_usage_docstring(component: object | None) -> str | None:
|
122 |
+
r"""Get the cli usage from the docstring.
|
123 |
+
|
124 |
+
Args:
|
125 |
+
----
|
126 |
+
component (Optional[object]): The component to get the docstring from
|
127 |
+
|
128 |
+
Returns:
|
129 |
+
-------
|
130 |
+
Optional[str]: The quick-start guide as Markdown format.
|
131 |
+
|
132 |
+
Example:
|
133 |
+
-------
|
134 |
+
component.__doc__ = '''
|
135 |
+
<Prev Section>
|
136 |
+
|
137 |
+
CLI Usage:
|
138 |
+
1. First Step.
|
139 |
+
2. Second Step.
|
140 |
+
|
141 |
+
<Next Section>
|
142 |
+
'''
|
143 |
+
>>> get_cli_usage_docstring(component)
|
144 |
+
"1. First Step.\n2. Second Step."
|
145 |
+
"""
|
146 |
+
if component is None or component.__doc__ is None or "CLI Usage" not in component.__doc__:
|
147 |
+
return None
|
148 |
+
|
149 |
+
pattern = r"CLI Usage:(.*?)(?=\n{2,}|\Z)"
|
150 |
+
match = re.search(pattern, component.__doc__, re.DOTALL)
|
151 |
+
|
152 |
+
if match:
|
153 |
+
contents = match.group(1).strip().split("\n")
|
154 |
+
return "\n".join([content.strip() for content in contents])
|
155 |
+
return None
|
156 |
+
|
157 |
+
|
158 |
+
def render_guide(subcommand: str | None = None) -> list:
|
159 |
+
"""Render a guide for the specified subcommand.
|
160 |
+
|
161 |
+
Args:
|
162 |
+
----
|
163 |
+
subcommand (Optional[str]): The subcommand to render the guide for.
|
164 |
+
|
165 |
+
Returns:
|
166 |
+
-------
|
167 |
+
list: A list of contents to be displayed in the guide.
|
168 |
+
"""
|
169 |
+
if subcommand is None or subcommand not in DOCSTRING_USAGE:
|
170 |
+
return []
|
171 |
+
contents = [get_intro()]
|
172 |
+
target_command = DOCSTRING_USAGE[subcommand]
|
173 |
+
cli_usage = get_cli_usage_docstring(target_command)
|
174 |
+
if cli_usage is not None:
|
175 |
+
cli_usage += f"\n{get_verbose_usage(subcommand)}"
|
176 |
+
quick_start = Panel(Markdown(cli_usage), border_style="dim", title="Quick-Start", title_align="left")
|
177 |
+
contents.append(quick_start)
|
178 |
+
return contents
|
179 |
+
|
180 |
+
|
181 |
+
class CustomHelpFormatter(RichHelpFormatter, DefaultHelpFormatter):
|
182 |
+
"""A custom help formatter for Anomalib CLI.
|
183 |
+
|
184 |
+
This formatter extends the RichHelpFormatter and DefaultHelpFormatter classes to provide
|
185 |
+
a more detailed and customizable help output for Anomalib CLI.
|
186 |
+
|
187 |
+
Attributes:
|
188 |
+
verbosity_level : int
|
189 |
+
The level of verbosity for the help output.
|
190 |
+
subcommand : str | None
|
191 |
+
The subcommand to render the guide for.
|
192 |
+
|
193 |
+
Methods:
|
194 |
+
add_usage(usage, actions, *args, **kwargs)
|
195 |
+
Add usage information to the help output.
|
196 |
+
add_argument(action)
|
197 |
+
Add an argument to the help output.
|
198 |
+
format_help()
|
199 |
+
Format the help output.
|
200 |
+
"""
|
201 |
+
|
202 |
+
verbosity_dict = get_verbosity_subcommand()
|
203 |
+
verbosity_level = verbosity_dict["verbosity"]
|
204 |
+
subcommand = verbosity_dict["subcommand"]
|
205 |
+
|
206 |
+
def add_usage(self, usage: str | None, actions: list, *args, **kwargs) -> None:
|
207 |
+
"""Add usage information to the formatter.
|
208 |
+
|
209 |
+
Args:
|
210 |
+
----
|
211 |
+
usage (str | None): A string describing the usage of the program.
|
212 |
+
actions (list): A list of argparse.Action objects.
|
213 |
+
*args (Any): Additional positional arguments to pass to the superclass method.
|
214 |
+
**kwargs (Any): Additional keyword arguments to pass to the superclass method.
|
215 |
+
|
216 |
+
Returns:
|
217 |
+
-------
|
218 |
+
None
|
219 |
+
"""
|
220 |
+
if self.subcommand in REQUIRED_ARGUMENTS:
|
221 |
+
if self.verbosity_level == 0:
|
222 |
+
actions = []
|
223 |
+
elif self.verbosity_level == 1:
|
224 |
+
actions = [action for action in actions if action.dest in REQUIRED_ARGUMENTS[self.subcommand]]
|
225 |
+
|
226 |
+
super().add_usage(usage, actions, *args, **kwargs)
|
227 |
+
|
228 |
+
def add_argument(self, action: argparse.Action) -> None:
|
229 |
+
"""Add an argument to the help formatter.
|
230 |
+
|
231 |
+
If the verbose level is set to 0, the argument is not added.
|
232 |
+
If the verbose level is set to 1 and the argument is not in the non-skip list, the argument is not added.
|
233 |
+
|
234 |
+
Args:
|
235 |
+
----
|
236 |
+
action (argparse.Action): The action to add to the help formatter.
|
237 |
+
"""
|
238 |
+
if self.subcommand in REQUIRED_ARGUMENTS:
|
239 |
+
if self.verbosity_level == 0:
|
240 |
+
return
|
241 |
+
if self.verbosity_level == 1 and action.dest not in REQUIRED_ARGUMENTS[self.subcommand]:
|
242 |
+
return
|
243 |
+
super().add_argument(action)
|
244 |
+
|
245 |
+
def format_help(self) -> str:
|
246 |
+
"""Format the help message for the current command and returns it as a string.
|
247 |
+
|
248 |
+
The help message includes information about the command's arguments and options,
|
249 |
+
as well as any additional information provided by the command's help guide.
|
250 |
+
|
251 |
+
Returns:
|
252 |
+
str: A string containing the formatted help message.
|
253 |
+
"""
|
254 |
+
with self.console.capture() as capture:
|
255 |
+
section = self._root_section
|
256 |
+
if self.subcommand in REQUIRED_ARGUMENTS and self.verbosity_level in (0, 1) and len(section.rich_items) > 1:
|
257 |
+
contents = render_guide(self.subcommand)
|
258 |
+
for content in contents:
|
259 |
+
self.console.print(content)
|
260 |
+
if self.verbosity_level > 0:
|
261 |
+
if len(section.rich_items) > 1:
|
262 |
+
section = Panel(section, border_style="dim", title="Arguments", title_align="left")
|
263 |
+
self.console.print(section, highlight=False, soft_wrap=True)
|
264 |
+
help_msg = capture.get()
|
265 |
+
|
266 |
+
if help_msg:
|
267 |
+
help_msg = self._long_break_matcher.sub("\n\n", help_msg).rstrip() + "\n"
|
268 |
+
return help_msg
|
anomalib/cli/utils/installation.py
ADDED
@@ -0,0 +1,430 @@
1 |
+
"""Anomalib installation util functions."""
|
2 |
+
|
3 |
+
# Copyright (C) 2024 Intel Corporation
|
4 |
+
# SPDX-License-Identifier: Apache-2.0
|
5 |
+
|
6 |
+
from __future__ import annotations
|
7 |
+
|
8 |
+
import json
|
9 |
+
import os
|
10 |
+
import platform
|
11 |
+
import re
|
12 |
+
from importlib.metadata import requires
|
13 |
+
from pathlib import Path
|
14 |
+
from warnings import warn
|
15 |
+
|
16 |
+
from pkg_resources import Requirement
|
17 |
+
|
18 |
+
AVAILABLE_TORCH_VERSIONS = {
|
19 |
+
"2.0.0": {"torchvision": "0.15.1", "cuda": ("11.7", "11.8")},
|
20 |
+
"2.0.1": {"torchvision": "0.15.2", "cuda": ("11.7", "11.8")},
|
21 |
+
"2.1.1": {"torchvision": "0.16.1", "cuda": ("11.8", "12.1")},
|
22 |
+
"2.1.2": {"torchvision": "0.16.2", "cuda": ("11.8", "12.1")},
|
23 |
+
"2.2.0": {"torchvision": "0.16.2", "cuda": ("11.8", "12.1")},
|
24 |
+
}
|
25 |
+
|
26 |
+
|
27 |
+
def get_requirements(module: str = "anomalib") -> dict[str, list[Requirement]]:
|
28 |
+
"""Get requirements of module from importlib.metadata.
|
29 |
+
|
30 |
+
This function returns a list of required packages from importlib.metadata.
|
31 |
+
|
32 |
+
Example:
|
33 |
+
>>> get_requirements("anomalib")
|
34 |
+
{
|
35 |
+
"base": ["jsonargparse==4.27.1", ...],
|
36 |
+
"core": ["torch==2.1.1", ...],
|
37 |
+
...
|
38 |
+
}
|
39 |
+
|
40 |
+
Returns:
|
41 |
+
dict[str, list[Requirement]]: List of required packages for each optional-extras.
|
42 |
+
"""
|
43 |
+
requirement_list: list[str] | None = requires(module)
|
44 |
+
extra_requirement: dict[str, list[Requirement]] = {}
|
45 |
+
if requirement_list is None:
|
46 |
+
return extra_requirement
|
47 |
+
for requirement in requirement_list:
|
48 |
+
extra = "core"
|
49 |
+
requirement_extra: list[str] = requirement.replace(" ", "").split(";")
|
50 |
+
if isinstance(requirement_extra, list) and len(requirement_extra) > 1:
|
51 |
+
extra = requirement_extra[-1].split("==")[-1].strip("'\"")
|
52 |
+
_requirement_name = requirement_extra[0]
|
53 |
+
_requirement = Requirement.parse(_requirement_name)
|
54 |
+
if extra in extra_requirement:
|
55 |
+
extra_requirement[extra].append(_requirement)
|
56 |
+
else:
|
57 |
+
extra_requirement[extra] = [_requirement]
|
58 |
+
return extra_requirement
|
59 |
+
|
60 |
+
|
61 |
+
def parse_requirements(
|
62 |
+
requirements: list[Requirement],
|
63 |
+
skip_torch: bool = False,
|
64 |
+
) -> tuple[str | None, list[str]]:
|
65 |
+
"""Parse requirements and returns torch and other requirements.
|
66 |
+
|
67 |
+
Args:
|
68 |
+
requirements (list[Requirement]): List of requirements.
|
69 |
+
skip_torch (bool): Whether to skip torch requirement. Defaults to False.
|
70 |
+
|
71 |
+
Raises:
|
72 |
+
ValueError: If torch requirement is not found.
|
73 |
+
|
74 |
+
Examples:
|
75 |
+
>>> requirements = [
|
76 |
+
... Requirement.parse("torch==1.13.0"),
|
77 |
+
... Requirement.parse("onnx>=1.8.1"),
|
78 |
+
... ]
|
79 |
+
>>> parse_requirements(requirements=requirements)
|
80 |
+
(Requirement.parse("torch==1.13.0"),
|
81 |
+
Requirement.parse("onnx>=1.8.1"))
|
82 |
+
|
83 |
+
Returns:
|
84 |
+
tuple[str, list[str], list[str]]: Tuple of torch and other requirements.
|
85 |
+
"""
|
86 |
+
torch_requirement: str | None = None
|
87 |
+
other_requirements: list[str] = []
|
88 |
+
|
89 |
+
for requirement in requirements:
|
90 |
+
if requirement.unsafe_name == "torch":
|
91 |
+
torch_requirement = str(requirement)
|
92 |
+
if len(requirement.specs) > 1:
|
93 |
+
warn(
|
94 |
+
"requirements.txt contains. Please remove other versions of torch from requirements.",
|
95 |
+
stacklevel=2,
|
96 |
+
)
|
97 |
+
|
98 |
+
# Rest of the requirements are task requirements.
|
99 |
+
# Other torch-related requirements such as `torchvision` are to be excluded.
|
100 |
+
# This is because torch-related requirements are already handled in torch_requirement.
|
101 |
+
else:
|
102 |
+
# if not requirement.unsafe_name.startswith("torch"):
|
103 |
+
other_requirements.append(str(requirement))
|
104 |
+
|
105 |
+
if not skip_torch and not torch_requirement:
|
106 |
+
msg = "Could not find torch requirement. Anoamlib depends on torch. Please add torch to your requirements."
|
107 |
+
raise ValueError(msg)
|
108 |
+
|
109 |
+
# Get the unique list of the requirements.
|
110 |
+
other_requirements = list(set(other_requirements))
|
111 |
+
|
112 |
+
return torch_requirement, other_requirements
|
113 |
+
|
114 |
+
|
115 |
+
def get_cuda_version() -> str | None:
|
116 |
+
"""Get CUDA version installed on the system.
|
117 |
+
|
118 |
+
Examples:
|
119 |
+
>>> # Assume that CUDA version is 11.2
|
120 |
+
>>> get_cuda_version()
|
121 |
+
"11.2"
|
122 |
+
|
123 |
+
>>> # Assume that CUDA is not installed on the system
|
124 |
+
>>> get_cuda_version()
|
125 |
+
None
|
126 |
+
|
127 |
+
Returns:
|
128 |
+
str | None: CUDA version installed on the system.
|
129 |
+
"""
|
130 |
+
# 1. Check CUDA_HOME Environment variable
|
131 |
+
cuda_home = os.environ.get("CUDA_HOME", "/usr/local/cuda")
|
132 |
+
|
133 |
+
if Path(cuda_home).exists():
|
134 |
+
# Check $CUDA_HOME/version.json file.
|
135 |
+
version_file = Path(cuda_home) / "version.json"
|
136 |
+
if version_file.is_file():
|
137 |
+
with Path(version_file).open() as file:
|
138 |
+
data = json.load(file)
|
139 |
+
cuda_version = data.get("cuda", {}).get("version", None)
|
140 |
+
if cuda_version is not None:
|
141 |
+
cuda_version_parts = cuda_version.split(".")
|
142 |
+
return ".".join(cuda_version_parts[:2])
|
143 |
+
# 2. 'nvcc --version' check & without version.json case
|
144 |
+
try:
|
145 |
+
result = os.popen(cmd="nvcc --version")
|
146 |
+
output = result.read()
|
147 |
+
|
148 |
+
cuda_version_pattern = r"cuda_(\d+\.\d+)"
|
149 |
+
cuda_version_match = re.search(cuda_version_pattern, output)
|
150 |
+
|
151 |
+
if cuda_version_match is not None:
|
152 |
+
return cuda_version_match.group(1)
|
153 |
+
except OSError:
|
154 |
+
msg = "Could not find cuda-version. Instead, the CPU version of torch will be installed."
|
155 |
+
warn(msg, stacklevel=2)
|
156 |
+
return None
|
157 |
+
|
158 |
+
|
159 |
+
def update_cuda_version_with_available_torch_cuda_build(cuda_version: str, torch_version: str) -> str:
|
160 |
+
"""Update the installed CUDA version with the highest supported CUDA version by PyTorch.
|
161 |
+
|
162 |
+
Args:
|
163 |
+
cuda_version (str): The installed CUDA version.
|
164 |
+
torch_version (str): The PyTorch version.
|
165 |
+
|
166 |
+
Raises:
|
167 |
+
Warning: If the installed CUDA version is not supported by PyTorch.
|
168 |
+
|
169 |
+
Examples:
|
170 |
+
>>> update_cuda_version_with_available_torch_cuda_builds("11.1", "1.13.0")
|
171 |
+
"11.6"
|
172 |
+
|
173 |
+
>>> update_cuda_version_with_available_torch_cuda_builds("11.7", "1.13.0")
|
174 |
+
"11.7"
|
175 |
+
|
176 |
+
>>> update_cuda_version_with_available_torch_cuda_builds("11.8", "1.13.0")
|
177 |
+
"11.7"
|
178 |
+
|
179 |
+
>>> update_cuda_version_with_available_torch_cuda_builds("12.1", "2.0.1")
|
180 |
+
"11.8"
|
181 |
+
|
182 |
+
Returns:
|
183 |
+
str: The updated CUDA version.
|
184 |
+
"""
|
185 |
+
max_supported_cuda = max(AVAILABLE_TORCH_VERSIONS[torch_version]["cuda"])
|
186 |
+
min_supported_cuda = min(AVAILABLE_TORCH_VERSIONS[torch_version]["cuda"])
|
187 |
+
bounded_cuda_version = max(min(cuda_version, max_supported_cuda), min_supported_cuda)
|
188 |
+
|
189 |
+
if cuda_version != bounded_cuda_version:
|
190 |
+
warn(
|
191 |
+
f"Installed CUDA version is v{cuda_version}. \n"
|
192 |
+
f"v{min_supported_cuda} <= Supported CUDA version <= v{max_supported_cuda}.\n"
|
193 |
+
f"This script will use CUDA v{bounded_cuda_version}.\n"
|
194 |
+
f"However, this may not be safe, and you are advised to install the correct version of CUDA.\n"
|
195 |
+
f"For more details, refer to https://pytorch.org/get-started/locally/",
|
196 |
+
stacklevel=2,
|
197 |
+
)
|
198 |
+
cuda_version = bounded_cuda_version
|
199 |
+
|
200 |
+
return cuda_version
|
201 |
+
|
202 |
+
|
203 |
+
def get_cuda_suffix(cuda_version: str) -> str:
|
204 |
+
"""Get CUDA suffix for PyTorch versions.
|
205 |
+
|
206 |
+
Args:
|
207 |
+
cuda_version (str): CUDA version installed on the system.
|
208 |
+
|
209 |
+
Note:
|
210 |
+
The CUDA version of PyTorch is not always the same as the CUDA version
|
211 |
+
that is installed on the system. For example, the latest PyTorch
|
212 |
+
version (1.10.0) supports CUDA 11.3, but the latest CUDA version
|
213 |
+
that is available for download is 11.2. Therefore, we need to use
|
214 |
+
the latest available CUDA version for PyTorch instead of the CUDA
|
215 |
+
version that is installed on the system. Therefore, this function
|
216 |
+
should be regularly updated to reflect the latest available CUDA.
|
217 |
+
|
218 |
+
Examples:
|
219 |
+
>>> get_cuda_suffix(cuda_version="11.2")
|
220 |
+
"cu112"
|
221 |
+
|
222 |
+
>>> get_cuda_suffix(cuda_version="11.8")
|
223 |
+
"cu118"
|
224 |
+
|
225 |
+
Returns:
|
226 |
+
str: CUDA suffix for PyTorch or mmX version.
|
227 |
+
"""
|
228 |
+
return f"cu{cuda_version.replace('.', '')}"
|
229 |
+
|
230 |
+
|
231 |
+
def get_hardware_suffix(with_available_torch_build: bool = False, torch_version: str | None = None) -> str:
|
232 |
+
"""Get hardware suffix for PyTorch or mmX versions.
|
233 |
+
|
234 |
+
Args:
|
235 |
+
with_available_torch_build (bool): Whether to use the latest available
|
236 |
+
PyTorch build or not. If True, the latest available PyTorch build
|
237 |
+
will be used. If False, the installed PyTorch build will be used.
|
238 |
+
Defaults to False.
|
239 |
+
torch_version (str | None): PyTorch version. This is only used when the
|
240 |
+
``with_available_torch_build`` is True.
|
241 |
+
|
242 |
+
Examples:
|
243 |
+
>>> # Assume that CUDA version is 11.2
|
244 |
+
>>> get_hardware_suffix()
|
245 |
+
"cu112"
|
246 |
+
|
247 |
+
>>> # Assume that CUDA is not installed on the system
|
248 |
+
>>> get_hardware_suffix()
|
249 |
+
"cpu"
|
250 |
+
|
251 |
+
Assume that the installed CUDA version is 12.1.
|
252 |
+
However, the latest available CUDA version for PyTorch v2.0 is 11.8.
|
253 |
+
Therefore, we use 11.8 instead of 12.1. This is because PyTorch does not
|
254 |
+
support CUDA 12.1 yet. In this case, we could correct the CUDA version
|
255 |
+
by setting `with_available_torch_build` to True.
|
256 |
+
|
257 |
+
>>> cuda_version = get_cuda_version()
|
258 |
+
"12.1"
|
259 |
+
>>> get_hardware_suffix(with_available_torch_build=True, torch_version="2.0.1")
|
260 |
+
"cu118"
|
261 |
+
|
262 |
+
Returns:
|
263 |
+
str: Hardware suffix for PyTorch or mmX version.
|
264 |
+
"""
|
265 |
+
cuda_version = get_cuda_version()
|
266 |
+
if cuda_version:
|
267 |
+
if with_available_torch_build:
|
268 |
+
if torch_version is None:
|
269 |
+
msg = "``torch_version`` must be provided when with_available_torch_build is True."
|
270 |
+
raise ValueError(msg)
|
271 |
+
cuda_version = update_cuda_version_with_available_torch_cuda_build(cuda_version, torch_version)
|
272 |
+
hardware_suffix = get_cuda_suffix(cuda_version)
|
273 |
+
else:
|
274 |
+
hardware_suffix = "cpu"
|
275 |
+
|
276 |
+
return hardware_suffix
|
277 |
+
|
278 |
+
|
279 |
+
def add_hardware_suffix_to_torch(
|
280 |
+
requirement: Requirement,
|
281 |
+
hardware_suffix: str | None = None,
|
282 |
+
with_available_torch_build: bool = False,
|
283 |
+
) -> str:
|
284 |
+
"""Add hardware suffix to the torch requirement.
|
285 |
+
|
286 |
+
Args:
|
287 |
+
requirement (Requirement): Requirement object comprising requirement
|
288 |
+
details.
|
289 |
+
hardware_suffix (str | None): Hardware suffix. If None, it will be set
|
290 |
+
to the correct hardware suffix. Defaults to None.
|
291 |
+
with_available_torch_build (bool): To check whether the installed
|
292 |
+
CUDA version is supported by the latest available PyTorch build.
|
293 |
+
Defaults to False.
|
294 |
+
|
295 |
+
Examples:
|
296 |
+
>>> from pkg_resources import Requirement
|
297 |
+
>>> req = "torch>=1.13.0, <=2.0.1"
|
298 |
+
>>> requirement = Requirement.parse(req)
|
299 |
+
>>> requirement.name, requirement.specs
|
300 |
+
('torch', [('>=', '1.13.0'), ('<=', '2.0.1')])
|
301 |
+
|
302 |
+
>>> add_hardware_suffix_to_torch(requirement)
|
303 |
+
'torch>=1.13.0+cu121, <=2.0.1+cu121'
|
304 |
+
|
305 |
+
``with_available_torch_build=True`` will use the latest available PyTorch build.
|
306 |
+
>>> req = "torch==2.0.1"
|
307 |
+
>>> requirement = Requirement.parse(req)
|
308 |
+
>>> add_hardware_suffix_to_torch(requirement, with_available_torch_build=True)
|
309 |
+
'torch==2.0.1+cu118'
|
310 |
+
|
311 |
+
It is possible to pass the ``hardware_suffix`` manually.
|
312 |
+
>>> req = "torch==2.0.1"
|
313 |
+
>>> requirement = Requirement.parse(req)
|
314 |
+
>>> add_hardware_suffix_to_torch(requirement, hardware_suffix="cu121")
|
315 |
+
'torch==2.0.1+cu111'
|
316 |
+
|
317 |
+
Raises:
|
318 |
+
ValueError: When the requirement has more than two version criterion.
|
319 |
+
|
320 |
+
Returns:
|
321 |
+
str: Updated torch package with the right cuda suffix.
|
322 |
+
"""
|
323 |
+
name = requirement.unsafe_name
|
324 |
+
updated_specs: list[str] = []
|
325 |
+
|
326 |
+
for operator, version in requirement.specs:
|
327 |
+
hardware_suffix = hardware_suffix or get_hardware_suffix(with_available_torch_build, version)
|
328 |
+
updated_version = version + f"+{hardware_suffix}" if not version.startswith(("2.1", "2.2")) else version
|
329 |
+
|
330 |
+
# ``specs`` contains operators and versions as follows:
|
331 |
+
# These are to be concatenated again for the updated version.
|
332 |
+
updated_specs.append(operator + updated_version)
|
333 |
+
|
334 |
+
updated_requirement: str = ""
|
335 |
+
|
336 |
+
if updated_specs:
|
337 |
+
# This is the case when specs are e.g. ['<=1.9.1+cu111']
|
338 |
+
if len(updated_specs) == 1:
|
339 |
+
updated_requirement = name + updated_specs[0]
|
340 |
+
# This is the case when specs are e.g., ['<=1.9.1+cu111', '>=1.8.1+cu111']
|
341 |
+
elif len(updated_specs) == 2:
|
342 |
+
updated_requirement = name + updated_specs[0] + ", " + updated_specs[1]
|
343 |
+
else:
|
344 |
+
msg = (
|
345 |
+
"Requirement version can be a single value or a range. \n"
|
346 |
+
"For example it could be torch>=1.8.1 "
|
347 |
+
"or torch>=1.8.1, <=1.9.1\n"
|
348 |
+
f"Got {updated_specs} instead."
|
349 |
+
)
|
350 |
+
raise ValueError(msg)
|
351 |
+
return updated_requirement
|
352 |
+
|
353 |
+
|
354 |
+
def get_torch_install_args(requirement: str | Requirement) -> list[str]:
|
355 |
+
"""Get the install arguments for Torch requirement.
|
356 |
+
|
357 |
+
This function will return the install arguments for the Torch requirement
|
358 |
+
and its corresponding torchvision requirement.
|
359 |
+
|
360 |
+
Args:
|
361 |
+
requirement (str | Requirement): The torch requirement.
|
362 |
+
|
363 |
+
Raises:
|
364 |
+
RuntimeError: If the OS is not supported.
|
365 |
+
|
366 |
+
Example:
|
367 |
+
>>> from pkg_resources import Requirement
|
368 |
+
>>> requirement = "torch>=1.13.0"
|
369 |
+
>>> get_torch_install_args(requirement)
|
370 |
+
['--extra-index-url', 'https://download.pytorch.org/whl/cpu',
|
371 |
+
'torch==1.13.0+cpu', 'torchvision==0.14.0+cpu']
|
372 |
+
|
373 |
+
Returns:
|
374 |
+
list[str]: The install arguments.
|
375 |
+
"""
|
376 |
+
if isinstance(requirement, str):
|
377 |
+
requirement = Requirement.parse(requirement)
|
378 |
+
|
379 |
+
# NOTE: This does not take into account if the requirement has multiple versions
|
380 |
+
# such as torch<2.0.1,>=1.13.0
|
381 |
+
if len(requirement.specs) < 1:
|
382 |
+
return [str(requirement)]
|
383 |
+
select_spec_idx = 0
|
384 |
+
for i, spec in enumerate(requirement.specs):
|
385 |
+
if "=" in spec[0]:
|
386 |
+
select_spec_idx = i
|
387 |
+
break
|
388 |
+
operator, version = requirement.specs[select_spec_idx]
|
389 |
+
if version not in AVAILABLE_TORCH_VERSIONS:
|
390 |
+
version = max(AVAILABLE_TORCH_VERSIONS.keys())
|
391 |
+
warn(
|
392 |
+
f"Torch Version will be selected as {version}.",
|
393 |
+
stacklevel=2,
|
394 |
+
)
|
395 |
+
install_args: list[str] = []
|
396 |
+
|
397 |
+
if platform.system() in ("Linux", "Windows"):
|
398 |
+
# Get the hardware suffix (eg., +cpu, +cu116 and +cu118 etc.)
|
399 |
+
hardware_suffix = get_hardware_suffix(with_available_torch_build=True, torch_version=version)
|
400 |
+
|
401 |
+
# Create the PyTorch Index URL to download the correct wheel.
|
402 |
+
index_url = f"https://download.pytorch.org/whl/{hardware_suffix}"
|
403 |
+
|
404 |
+
# Create the PyTorch version depending on the CUDA version. For example,
|
405 |
+
# If CUDA version is 11.2, then the PyTorch version is 1.8.0+cu112.
|
406 |
+
# If CUDA version is None, then the PyTorch version is 1.8.0+cpu.
|
407 |
+
torch_version = add_hardware_suffix_to_torch(requirement, hardware_suffix, with_available_torch_build=True)
|
408 |
+
|
409 |
+
# Get the torchvision version depending on the torch version.
|
410 |
+
torchvision_version = AVAILABLE_TORCH_VERSIONS[version]["torchvision"]
|
411 |
+
torchvision_requirement = f"torchvision{operator}{torchvision_version}"
|
412 |
+
if isinstance(torchvision_version, str) and not torchvision_version.startswith("0.16"):
|
413 |
+
torchvision_requirement += f"+{hardware_suffix}"
|
414 |
+
|
415 |
+
# Return the install arguments.
|
416 |
+
install_args += [
|
417 |
+
"--extra-index-url",
|
418 |
+
# "--index-url",
|
419 |
+
index_url,
|
420 |
+
torch_version,
|
421 |
+
torchvision_requirement,
|
422 |
+
]
|
423 |
+
elif platform.system() in ("macos", "Darwin"):
|
424 |
+
torch_version = str(requirement)
|
425 |
+
install_args += [torch_version]
|
426 |
+
else:
|
427 |
+
msg = f"Unsupported OS: {platform.system()}"
|
428 |
+
raise RuntimeError(msg)
|
429 |
+
|
430 |
+
return install_args
|
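
To show how the installation helpers above fit together, here is a hedged sketch that assembles pip install arguments from the package metadata. It only prints the arguments; whether the "core" extra actually contains a pinned torch requirement depends on the installed anomalib metadata.

```python
# Hedged sketch: chain get_requirements -> parse_requirements -> get_torch_install_args.
# The printed command is illustrative; no installation is performed here.
from anomalib.cli.utils.installation import (
    get_requirements,
    get_torch_install_args,
    parse_requirements,
)

core_requirements = get_requirements("anomalib").get("core", [])
torch_requirement, other_requirements = parse_requirements(core_requirements)

install_args = get_torch_install_args(torch_requirement) if torch_requirement else []
install_args += other_requirements

# e.g. ['--extra-index-url', 'https://download.pytorch.org/whl/cu118', 'torch==2.1.1', ...]
print(["pip", "install", *install_args])
```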
anomalib/cli/utils/openvino.py
ADDED
@@ -0,0 +1,32 @@
1 |
+
"""Utils for OpenVINO parser."""
|
2 |
+
|
3 |
+
# Copyright (C) 2023 Intel Corporation
|
4 |
+
# SPDX-License-Identifier: Apache-2.0
|
5 |
+
|
6 |
+
import logging
|
7 |
+
|
8 |
+
from jsonargparse import ArgumentParser
|
9 |
+
|
10 |
+
from anomalib.utils.exceptions import try_import
|
11 |
+
|
12 |
+
logger = logging.getLogger(__name__)
|
13 |
+
|
14 |
+
|
15 |
+
if try_import("openvino"):
|
16 |
+
from openvino.tools.ovc.cli_parser import get_common_cli_parser
|
17 |
+
else:
|
18 |
+
get_common_cli_parser = None
|
19 |
+
|
20 |
+
|
21 |
+
def add_openvino_export_arguments(parser: ArgumentParser) -> None:
|
22 |
+
"""Add OpenVINO arguments to parser under --mo key."""
|
23 |
+
if get_common_cli_parser is not None:
|
24 |
+
group = parser.add_argument_group("OpenVINO Model Optimizer arguments (optional)")
|
25 |
+
ov_parser = get_common_cli_parser()
|
26 |
+
# remove redundant keys from mo keys
|
27 |
+
for arg in ov_parser._actions: # noqa: SLF001
|
28 |
+
if arg.dest in ("help", "input_model", "output_dir"):
|
29 |
+
continue
|
30 |
+
group.add_argument(f"--ov_args.{arg.dest}", type=arg.type, default=arg.default, help=arg.help)
|
31 |
+
else:
|
32 |
+
logger.info("OpenVINO is possibly not installed in the environment. Skipping adding it to parser.")
|
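
A minimal usage sketch for the helper above; the parser description is illustrative. When OpenVINO is not installed, the call simply logs a message and leaves the parser unchanged.

```python
# Minimal sketch: register the optional --ov_args.* group on a jsonargparse parser.
from jsonargparse import ArgumentParser

from anomalib.cli.utils.openvino import add_openvino_export_arguments

parser = ArgumentParser(description="anomalib export")  # illustrative parser
add_openvino_export_arguments(parser)  # adds OpenVINO Model Optimizer options, or logs and skips
```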
anomalib/data/__init__.py
ADDED
@@ -0,0 +1,72 @@
1 |
+
"""Anomalib Datasets."""
|
2 |
+
|
3 |
+
# Copyright (C) 2022-2024 Intel Corporation
|
4 |
+
# SPDX-License-Identifier: Apache-2.0
|
5 |
+
|
6 |
+
|
7 |
+
import importlib
|
8 |
+
import logging
|
9 |
+
from enum import Enum
|
10 |
+
from itertools import chain
|
11 |
+
|
12 |
+
from omegaconf import DictConfig, ListConfig
|
13 |
+
|
14 |
+
from anomalib.utils.config import to_tuple
|
15 |
+
|
16 |
+
from .base import AnomalibDataModule, AnomalibDataset
|
17 |
+
from .depth import DepthDataFormat, Folder3D, MVTec3D
|
18 |
+
from .image import BTech, Folder, ImageDataFormat, Kolektor, MVTec, MVTecLoco, Visa
|
19 |
+
from .predict import PredictDataset
|
20 |
+
from .utils import LabelName
|
21 |
+
from .video import Avenue, ShanghaiTech, UCSDped, VideoDataFormat
|
22 |
+
|
23 |
+
logger = logging.getLogger(__name__)
|
24 |
+
|
25 |
+
|
26 |
+
DataFormat = Enum( # type: ignore[misc]
|
27 |
+
"DataFormat",
|
28 |
+
{i.name: i.value for i in chain(DepthDataFormat, ImageDataFormat, VideoDataFormat)},
|
29 |
+
)
|
30 |
+
|
31 |
+
|
32 |
+
def get_datamodule(config: DictConfig | ListConfig) -> AnomalibDataModule:
|
33 |
+
"""Get Anomaly Datamodule.
|
34 |
+
|
35 |
+
Args:
|
36 |
+
config (DictConfig | ListConfig): Configuration of the anomaly model.
|
37 |
+
|
38 |
+
Returns:
|
39 |
+
PyTorch Lightning DataModule
|
40 |
+
"""
|
41 |
+
logger.info("Loading the datamodule")
|
42 |
+
|
43 |
+
module = importlib.import_module(".".join(config.data.class_path.split(".")[:-1]))
|
44 |
+
dataclass = getattr(module, config.data.class_path.split(".")[-1])
|
45 |
+
init_args = {**config.data.get("init_args", {})} # get dict
|
46 |
+
if "image_size" in init_args:
|
47 |
+
init_args["image_size"] = to_tuple(init_args["image_size"])
|
48 |
+
|
49 |
+
return dataclass(**init_args)
|
50 |
+
|
51 |
+
|
52 |
+
__all__ = [
|
53 |
+
"AnomalibDataset",
|
54 |
+
"AnomalibDataModule",
|
55 |
+
"DepthDataFormat",
|
56 |
+
"ImageDataFormat",
|
57 |
+
"VideoDataFormat",
|
58 |
+
"get_datamodule",
|
59 |
+
"BTech",
|
60 |
+
"Folder",
|
61 |
+
"Folder3D",
|
62 |
+
"PredictDataset",
|
63 |
+
"Kolektor",
|
64 |
+
"MVTec",
|
65 |
+
"MVTec3D",
|
66 |
+
"MVTecLoco",
|
67 |
+
"Avenue",
|
68 |
+
"UCSDped",
|
69 |
+
"ShanghaiTech",
|
70 |
+
"Visa",
|
71 |
+
"LabelName",
|
72 |
+
]
|
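
As a usage sketch for `get_datamodule`, the configuration below is illustrative (the class path and init_args are example values, not a canonical config):

```python
# Hedged sketch: build a datamodule from a DictConfig, as get_datamodule expects.
from omegaconf import OmegaConf

from anomalib.data import get_datamodule

config = OmegaConf.create(
    {
        "data": {
            "class_path": "anomalib.data.MVTec",
            "init_args": {"category": "bottle", "image_size": [256, 256]},  # example values
        },
    },
)
datamodule = get_datamodule(config)  # imports anomalib.data.MVTec and instantiates it with init_args
```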
anomalib/data/base/__init__.py
ADDED
@@ -0,0 +1,18 @@
1 |
+
"""Base classes for custom dataset and datamodules."""
|
2 |
+
|
3 |
+
# Copyright (C) 2022 Intel Corporation
|
4 |
+
# SPDX-License-Identifier: Apache-2.0
|
5 |
+
|
6 |
+
|
7 |
+
from .datamodule import AnomalibDataModule
|
8 |
+
from .dataset import AnomalibDataset
|
9 |
+
from .depth import AnomalibDepthDataset
|
10 |
+
from .video import AnomalibVideoDataModule, AnomalibVideoDataset
|
11 |
+
|
12 |
+
__all__ = [
|
13 |
+
"AnomalibDataset",
|
14 |
+
"AnomalibDataModule",
|
15 |
+
"AnomalibVideoDataset",
|
16 |
+
"AnomalibVideoDataModule",
|
17 |
+
"AnomalibDepthDataset",
|
18 |
+
]
|
anomalib/data/base/datamodule.py
ADDED
@@ -0,0 +1,305 @@
1 |
+
"""Anomalib datamodule base class."""
|
2 |
+
|
3 |
+
# Copyright (C) 2022-2024 Intel Corporation
|
4 |
+
# SPDX-License-Identifier: Apache-2.0
|
5 |
+
|
6 |
+
|
7 |
+
import logging
|
8 |
+
from abc import ABC, abstractmethod
|
9 |
+
from typing import TYPE_CHECKING, Any
|
10 |
+
|
11 |
+
from lightning.pytorch import LightningDataModule
|
12 |
+
from lightning.pytorch.trainer.states import TrainerFn
|
13 |
+
from lightning.pytorch.utilities.types import EVAL_DATALOADERS, TRAIN_DATALOADERS
|
14 |
+
from torch.utils.data.dataloader import DataLoader, default_collate
|
15 |
+
from torchvision.transforms.v2 import Resize, Transform
|
16 |
+
|
17 |
+
from anomalib.data.utils import TestSplitMode, ValSplitMode, random_split, split_by_label
|
18 |
+
from anomalib.data.utils.synthetic import SyntheticAnomalyDataset
|
19 |
+
|
20 |
+
if TYPE_CHECKING:
|
21 |
+
from pandas import DataFrame
|
22 |
+
|
23 |
+
from anomalib.data.base.dataset import AnomalibDataset
|
24 |
+
|
25 |
+
logger = logging.getLogger(__name__)
|
26 |
+
|
27 |
+
|
28 |
+
def collate_fn(batch: list) -> dict[str, Any]:
|
29 |
+
"""Collate bounding boxes as lists.
|
30 |
+
|
31 |
+
Bounding boxes and `masks` (not `mask`) are collated as a list of tensors. If `masks` exists,
|
32 |
+
the `mask_path` is also collated as a list since each element in the batch could be unequal.
|
33 |
+
For all other entries, the default collate function is used.
|
34 |
+
|
35 |
+
Args:
|
36 |
+
batch (List): list of items in the batch where len(batch) is equal to the batch size.
|
37 |
+
|
38 |
+
Returns:
|
39 |
+
dict[str, Any]: Dictionary containing the collated batch information.
|
40 |
+
"""
|
41 |
+
elem = batch[0] # sample an element from the batch to check the type.
|
42 |
+
out_dict = {}
|
43 |
+
if isinstance(elem, dict):
|
44 |
+
if "boxes" in elem:
|
45 |
+
# collate boxes as list
|
46 |
+
out_dict["boxes"] = [item.pop("boxes") for item in batch]
|
47 |
+
if "semantic_mask" in elem:
|
48 |
+
# semantic masks have a variable number of channels, so we collate them as a list
|
49 |
+
out_dict["semantic_mask"] = [item.pop("semantic_mask") for item in batch]
|
50 |
+
if "mask_path" in elem and isinstance(elem["mask_path"], list):
|
51 |
+
# collate mask paths as list
|
52 |
+
out_dict["mask_path"] = [item.pop("mask_path") for item in batch]
|
53 |
+
# collate other data normally
|
54 |
+
out_dict.update({key: default_collate([item[key] for item in batch]) for key in elem})
|
55 |
+
return out_dict
|
56 |
+
return default_collate(batch)
|
57 |
+
|
58 |
+
|
59 |
+
class AnomalibDataModule(LightningDataModule, ABC):
|
60 |
+
"""Base Anomalib data module.
|
61 |
+
|
62 |
+
Args:
|
63 |
+
train_batch_size (int): Batch size used by the train dataloader.
|
64 |
+
eval_batch_size (int): Batch size used by the val and test dataloaders.
|
65 |
+
num_workers (int): Number of workers used by the train, val and test dataloaders.
|
66 |
+
val_split_mode (ValSplitMode): Determines how the validation split is obtained.
|
67 |
+
Options: [none, same_as_test, from_test, synthetic]
|
68 |
+
val_split_ratio (float): Fraction of the train or test images held out for validation.
|
69 |
+
test_split_mode (Optional[TestSplitMode], optional): Determines how the test split is obtained.
|
70 |
+
Options: [none, from_dir, synthetic].
|
71 |
+
Defaults to ``None``.
|
72 |
+
test_split_ratio (float): Fraction of the train images held out for testing.
|
73 |
+
Defaults to ``None``.
|
74 |
+
image_size (tuple[int, int], optional): Size to which input images should be resized.
|
75 |
+
Defaults to ``None``.
|
76 |
+
transform (Transform, optional): Transforms that should be applied to the input images.
|
77 |
+
Defaults to ``None``.
|
78 |
+
train_transform (Transform, optional): Transforms that should be applied to the input images during training.
|
79 |
+
Defaults to ``None``.
|
80 |
+
eval_transform (Transform, optional): Transforms that should be applied to the input images during evaluation.
|
81 |
+
Defaults to ``None``.
|
82 |
+
seed (int | None, optional): Seed used during random subset splitting.
|
83 |
+
Defaults to ``None``.
|
84 |
+
"""
|
85 |
+
|
86 |
+
def __init__(
|
87 |
+
self,
|
88 |
+
train_batch_size: int,
|
89 |
+
eval_batch_size: int,
|
90 |
+
num_workers: int,
|
91 |
+
val_split_mode: ValSplitMode | str,
|
92 |
+
val_split_ratio: float,
|
93 |
+
test_split_mode: TestSplitMode | str | None = None,
|
94 |
+
test_split_ratio: float | None = None,
|
95 |
+
image_size: tuple[int, int] | None = None,
|
96 |
+
transform: Transform | None = None,
|
97 |
+
train_transform: Transform | None = None,
|
98 |
+
eval_transform: Transform | None = None,
|
99 |
+
seed: int | None = None,
|
100 |
+
) -> None:
|
101 |
+
super().__init__()
|
102 |
+
self.train_batch_size = train_batch_size
|
103 |
+
self.eval_batch_size = eval_batch_size
|
104 |
+
self.num_workers = num_workers
|
105 |
+
self.test_split_mode = TestSplitMode(test_split_mode) if test_split_mode else TestSplitMode.NONE
|
106 |
+
self.test_split_ratio = test_split_ratio
|
107 |
+
self.val_split_mode = ValSplitMode(val_split_mode)
|
108 |
+
self.val_split_ratio = val_split_ratio
|
109 |
+
self.image_size = image_size
|
110 |
+
self.seed = seed
|
111 |
+
|
112 |
+
# set transforms
|
113 |
+
if bool(train_transform) != bool(eval_transform):
|
114 |
+
msg = "Only one of train_transform and eval_transform was specified. This is not recommended because \
|
115 |
+
it could lead to unexpected behaviour. Please ensure training and eval transforms have the same \
|
116 |
+
reshape and normalization characteristics."
|
117 |
+
logger.warning(msg)
|
118 |
+
self._train_transform = train_transform or transform
|
119 |
+
self._eval_transform = eval_transform or transform
|
120 |
+
|
121 |
+
self.train_data: AnomalibDataset
|
122 |
+
self.val_data: AnomalibDataset
|
123 |
+
self.test_data: AnomalibDataset
|
124 |
+
|
125 |
+
self._samples: DataFrame | None = None
|
126 |
+
self._category: str = ""
|
127 |
+
|
128 |
+
self._is_setup = False # flag to track if setup has been called from the trainer
|
129 |
+
|
130 |
+
@property
|
131 |
+
def name(self) -> str:
|
132 |
+
"""Name of the datamodule."""
|
133 |
+
return self.__class__.__name__
|
134 |
+
|
135 |
+
def setup(self, stage: str | None = None) -> None:
|
136 |
+
"""Set up train, validation and test data.
|
137 |
+
|
138 |
+
Args:
|
139 |
+
stage (str | None): Train/Val/Test stages.
|
140 |
+
Defaults to ``None``.
|
141 |
+
"""
|
142 |
+
has_subset = any(hasattr(self, subset) for subset in ["train_data", "val_data", "test_data"])
|
143 |
+
if not has_subset or not self._is_setup:
|
144 |
+
self._setup(stage)
|
145 |
+
self._create_test_split()
|
146 |
+
self._create_val_split()
|
147 |
+
if isinstance(stage, TrainerFn):
|
148 |
+
# only set the flag if the stage is a TrainerFn, which means the setup has been called from a trainer
|
149 |
+
self._is_setup = True
|
150 |
+
|
151 |
+
@abstractmethod
|
152 |
+
def _setup(self, _stage: str | None = None) -> None:
|
153 |
+
"""Set up the datasets and perform dynamic subset splitting.
|
154 |
+
|
155 |
+
This method may be overridden in subclass for custom splitting behaviour.
|
156 |
+
|
157 |
+
Note:
|
158 |
+
The stage argument is not used here. This is because, for a given instance of an AnomalibDataModule
|
159 |
+
subclass, all three subsets are created at the first call of setup(). This is to accommodate the subset
|
160 |
+
splitting behaviour of anomaly tasks, where the validation set is usually extracted from the test set, and
|
161 |
+
the test set must therefore be created as early as the `fit` stage.
|
162 |
+
|
163 |
+
"""
|
164 |
+
raise NotImplementedError
|
165 |
+
|
166 |
+
@property
|
167 |
+
def category(self) -> str:
|
168 |
+
"""Get the category of the datamodule."""
|
169 |
+
return self._category
|
170 |
+
|
171 |
+
@category.setter
|
172 |
+
def category(self, category: str) -> None:
|
173 |
+
"""Set the category of the datamodule."""
|
174 |
+
self._category = category
|
175 |
+
|
176 |
+
def _create_test_split(self) -> None:
|
177 |
+
"""Obtain the test set based on the settings in the config."""
|
178 |
+
if self.test_data.has_normal:
|
179 |
+
# split the test data into normal and anomalous so these can be processed separately
|
180 |
+
normal_test_data, self.test_data = split_by_label(self.test_data)
|
181 |
+
elif self.test_split_mode != TestSplitMode.NONE:
|
182 |
+
# when the user did not provide any normal images for testing, we sample some from the training set,
|
183 |
+
# except when the user explicitly requested no test splitting.
|
184 |
+
logger.info(
|
185 |
+
"No normal test images found. Sampling from training set using a split ratio of %0.2f",
|
186 |
+
self.test_split_ratio,
|
187 |
+
)
|
188 |
+
if self.test_split_ratio is not None:
|
189 |
+
self.train_data, normal_test_data = random_split(self.train_data, self.test_split_ratio, seed=self.seed)
|
190 |
+
|
191 |
+
if self.test_split_mode == TestSplitMode.FROM_DIR:
|
192 |
+
self.test_data += normal_test_data
|
193 |
+
elif self.test_split_mode == TestSplitMode.SYNTHETIC:
|
194 |
+
self.test_data = SyntheticAnomalyDataset.from_dataset(normal_test_data)
|
195 |
+
elif self.test_split_mode != TestSplitMode.NONE:
|
196 |
+
msg = f"Unsupported Test Split Mode: {self.test_split_mode}"
|
197 |
+
raise ValueError(msg)
|
198 |
+
|
199 |
+
def _create_val_split(self) -> None:
|
200 |
+
"""Obtain the validation set based on the settings in the config."""
|
201 |
+
if self.val_split_mode == ValSplitMode.FROM_TRAIN:
|
202 |
+
# randomly sampled from train set
|
203 |
+
self.train_data, self.val_data = random_split(
|
204 |
+
self.train_data,
|
205 |
+
self.val_split_ratio,
|
206 |
+
label_aware=True,
|
207 |
+
seed=self.seed,
|
208 |
+
)
|
209 |
+
elif self.val_split_mode == ValSplitMode.FROM_TEST:
|
210 |
+
# randomly sampled from test set
|
211 |
+
self.test_data, self.val_data = random_split(
|
212 |
+
self.test_data,
|
213 |
+
self.val_split_ratio,
|
214 |
+
label_aware=True,
|
215 |
+
seed=self.seed,
|
216 |
+
)
|
217 |
+
elif self.val_split_mode == ValSplitMode.SAME_AS_TEST:
|
218 |
+
# equal to test set
|
219 |
+
self.val_data = self.test_data
|
220 |
+
elif self.val_split_mode == ValSplitMode.SYNTHETIC:
|
221 |
+
# converted from random training sample
|
222 |
+
self.train_data, normal_val_data = random_split(self.train_data, self.val_split_ratio, seed=self.seed)
|
223 |
+
self.val_data = SyntheticAnomalyDataset.from_dataset(normal_val_data)
|
224 |
+
elif self.val_split_mode == ValSplitMode.FROM_DIR:
|
225 |
+
# the val_data is prepared in subclass
|
226 |
+
assert hasattr(
|
227 |
+
self,
|
228 |
+
"val_data",
|
229 |
+
), f"FROM_DIR is not supported for {self.__class__.__name__} which does not assign val_data in _setup."
|
230 |
+
elif self.val_split_mode != ValSplitMode.NONE:
|
231 |
+
msg = f"Unknown validation split mode: {self.val_split_mode}"
|
232 |
+
raise ValueError(msg)
|
233 |
+
|
234 |
+
def train_dataloader(self) -> TRAIN_DATALOADERS:
|
235 |
+
"""Get train dataloader."""
|
236 |
+
return DataLoader(
|
237 |
+
dataset=self.train_data,
|
238 |
+
shuffle=True,
|
239 |
+
batch_size=self.train_batch_size,
|
240 |
+
num_workers=self.num_workers,
|
241 |
+
)
|
242 |
+
|
243 |
+
def val_dataloader(self) -> EVAL_DATALOADERS:
|
244 |
+
"""Get validation dataloader."""
|
245 |
+
return DataLoader(
|
246 |
+
dataset=self.val_data,
|
247 |
+
shuffle=False,
|
248 |
+
batch_size=self.eval_batch_size,
|
249 |
+
num_workers=self.num_workers,
|
250 |
+
collate_fn=collate_fn,
|
251 |
+
)
|
252 |
+
|
253 |
+
def test_dataloader(self) -> EVAL_DATALOADERS:
|
254 |
+
"""Get test dataloader."""
|
255 |
+
return DataLoader(
|
256 |
+
dataset=self.test_data,
|
257 |
+
shuffle=False,
|
258 |
+
batch_size=self.eval_batch_size,
|
259 |
+
num_workers=self.num_workers,
|
260 |
+
collate_fn=collate_fn,
|
261 |
+
)
|
262 |
+
|
263 |
+
def predict_dataloader(self) -> EVAL_DATALOADERS:
|
264 |
+
"""Use the test dataloader for inference unless overridden."""
|
265 |
+
return self.test_dataloader()
|
266 |
+
|
267 |
+
@property
|
268 |
+
def transform(self) -> Transform:
|
269 |
+
"""Property that returns the user-specified transform for the datamodule, if any.
|
270 |
+
|
271 |
+
This property is accessed by the engine to set the transform for the model. The eval_transform takes precedence
|
272 |
+
over the train_transform, because the transform that we store in the model is the one that should be used during
|
273 |
+
inference.
|
274 |
+
"""
|
275 |
+
if self._eval_transform:
|
276 |
+
return self._eval_transform
|
277 |
+
return None
|
278 |
+
|
279 |
+
@property
|
280 |
+
def train_transform(self) -> Transform:
|
281 |
+
"""Get the transforms that will be passed to the train dataset.
|
282 |
+
|
283 |
+
If the train_transform is not set, the engine will request the transform from the model.
|
284 |
+
"""
|
285 |
+
if self._train_transform:
|
286 |
+
return self._train_transform
|
287 |
+
if getattr(self, "trainer", None) and self.trainer.model and self.trainer.model.transform:
|
288 |
+
return self.trainer.model.transform
|
289 |
+
if self.image_size:
|
290 |
+
return Resize(self.image_size, antialias=True)
|
291 |
+
return None
|
292 |
+
|
293 |
+
@property
|
294 |
+
def eval_transform(self) -> Transform:
|
295 |
+
"""Get the transform that will be passed to the val/test/predict datasets.
|
296 |
+
|
297 |
+
If the eval_transform is not set, the engine will request the transform from the model.
|
298 |
+
"""
|
299 |
+
if self._eval_transform:
|
300 |
+
return self._eval_transform
|
301 |
+
if getattr(self, "trainer", None) and self.trainer.model and self.trainer.model.transform:
|
302 |
+
return self.trainer.model.transform
|
303 |
+
if self.image_size:
|
304 |
+
return Resize(self.image_size, antialias=True)
|
305 |
+
return None
|
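
The custom `collate_fn` above is easiest to see on a toy batch; the following sketch (tensor shapes are arbitrary) shows "boxes" being gathered into a list while same-shaped tensors go through `default_collate`:

```python
# Hedged sketch of collate_fn behaviour: "boxes" become a per-sample list,
# while "image" and "label" are stacked by default_collate.
import torch

from anomalib.data.base.datamodule import collate_fn

batch = [
    {"image": torch.zeros(3, 8, 8), "label": torch.tensor(0), "boxes": torch.empty(0, 4)},
    {"image": torch.ones(3, 8, 8), "label": torch.tensor(1), "boxes": torch.tensor([[0.0, 0.0, 4.0, 4.0]])},
]
collated = collate_fn(batch)
print(collated["image"].shape)  # torch.Size([2, 3, 8, 8])
print(type(collated["boxes"]))  # <class 'list'>, one tensor per sample
```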
anomalib/data/base/dataset.py
ADDED
@@ -0,0 +1,208 @@
1 |
+
"""Anomalib dataset base class."""
|
2 |
+
|
3 |
+
# Copyright (C) 2022-2024 Intel Corporation
|
4 |
+
# SPDX-License-Identifier: Apache-2.0
|
5 |
+
|
6 |
+
import copy
|
7 |
+
import logging
|
8 |
+
from abc import ABC
|
9 |
+
from collections.abc import Sequence
|
10 |
+
from pathlib import Path
|
11 |
+
|
12 |
+
import pandas as pd
|
13 |
+
import torch
|
14 |
+
from pandas import DataFrame
|
15 |
+
from torch.utils.data import Dataset
|
16 |
+
from torchvision.transforms.v2 import Transform
|
17 |
+
from torchvision.tv_tensors import Mask
|
18 |
+
|
19 |
+
from anomalib import TaskType
|
20 |
+
from anomalib.data.utils import LabelName, masks_to_boxes, read_image, read_mask
|
21 |
+
|
22 |
+
_EXPECTED_COLUMNS_CLASSIFICATION = ["image_path", "split"]
|
23 |
+
_EXPECTED_COLUMNS_SEGMENTATION = [*_EXPECTED_COLUMNS_CLASSIFICATION, "mask_path"]
|
24 |
+
_EXPECTED_COLUMNS_PERTASK = {
|
25 |
+
"classification": _EXPECTED_COLUMNS_CLASSIFICATION,
|
26 |
+
"segmentation": _EXPECTED_COLUMNS_SEGMENTATION,
|
27 |
+
"detection": _EXPECTED_COLUMNS_SEGMENTATION,
|
28 |
+
}
|
29 |
+
|
30 |
+
logger = logging.getLogger(__name__)
|
31 |
+
|
32 |
+
|
33 |
+
class AnomalibDataset(Dataset, ABC):
|
34 |
+
"""Anomalib dataset.
|
35 |
+
|
36 |
+
The dataset is based on a dataframe that contains the information needed by the dataloader to load each of
|
37 |
+
the dataset items into memory.
|
38 |
+
|
39 |
+
The samples dataframe must be set from the subclass using the setter of the `samples` property.
|
40 |
+
|
41 |
+
The DataFrame must, at least, include the following columns:
|
42 |
+
- `split` (str): The subset to which the dataset item is assigned (e.g., 'train', 'test').
|
43 |
+
- `image_path` (str): Path to the file system location where the image is stored.
|
44 |
+
- `label_index` (int): Index of the anomaly label, typically 0 for 'normal' and 1 for 'anomalous'.
|
45 |
+
- `mask_path` (str, optional): Path to the ground truth masks (for the anomalous images only).
|
46 |
+
Required if task is 'segmentation'.
|
47 |
+
|
48 |
+
Example DataFrame:
|
49 |
+
+---+-------------------+-----------+-------------+------------------+-------+
|
50 |
+
| | image_path | label | label_index | mask_path | split |
|
51 |
+
+---+-------------------+-----------+-------------+------------------+-------+
|
52 |
+
| 0 | path/to/image.png | anomalous | 1 | path/to/mask.png | train |
|
53 |
+
+---+-------------------+-----------+-------------+------------------+-------+
|
54 |
+
|
55 |
+
Note:
|
56 |
+
The example above is illustrative and may need to be adjusted based on the specific dataset structure.
|
57 |
+
|
58 |
+
Args:
|
59 |
+
task (str): Task type, either 'classification' or 'segmentation'
|
60 |
+
transform (Transform, optional): Transforms that should be applied to the input images.
|
61 |
+
Defaults to ``None``.
|
62 |
+
"""
|
63 |
+
|
64 |
+
def __init__(self, task: TaskType | str, transform: Transform | None = None) -> None:
|
65 |
+
super().__init__()
|
66 |
+
self.task = TaskType(task)
|
67 |
+
self.transform = transform
|
68 |
+
self._samples: DataFrame | None = None
|
69 |
+
self._category: str | None = None
|
70 |
+
|
71 |
+
@property
|
72 |
+
def name(self) -> str:
|
73 |
+
"""Name of the dataset."""
|
74 |
+
class_name = self.__class__.__name__
|
75 |
+
|
76 |
+
# Remove the `_dataset` suffix from the class name
|
77 |
+
if class_name.endswith("Dataset"):
|
78 |
+
class_name = class_name[:-7]
|
79 |
+
|
80 |
+
return class_name
|
81 |
+
|
82 |
+
def __len__(self) -> int:
|
83 |
+
"""Get length of the dataset."""
|
84 |
+
return len(self.samples)
|
85 |
+
|
86 |
+
def subsample(self, indices: Sequence[int], inplace: bool = False) -> "AnomalibDataset":
|
87 |
+
"""Subsamples the dataset at the provided indices.
|
88 |
+
|
89 |
+
Args:
|
90 |
+
indices (Sequence[int]): Indices at which the dataset is to be subsampled.
|
91 |
+
inplace (bool): When true, the subsampling will be performed on the instance itself.
|
92 |
+
Defaults to ``False``.
|
93 |
+
"""
|
94 |
+
if len(set(indices)) != len(indices):
|
95 |
+
msg = "No duplicates allowed in indices."
|
96 |
+
raise ValueError(msg)
|
97 |
+
dataset = self if inplace else copy.deepcopy(self)
|
98 |
+
dataset.samples = self.samples.iloc[indices].reset_index(drop=True)
|
99 |
+
return dataset
|
100 |
+
|
101 |
+
@property
|
102 |
+
def samples(self) -> DataFrame:
|
103 |
+
"""Get the samples dataframe."""
|
104 |
+
if self._samples is None:
|
105 |
+
msg = (
|
106 |
+
"Dataset does not have a samples dataframe. Ensure that a dataframe has been assigned to "
|
107 |
+
"`dataset.samples`."
|
108 |
+
)
|
109 |
+
raise RuntimeError(msg)
|
110 |
+
return self._samples
|
111 |
+
|
112 |
+
@samples.setter
|
113 |
+
def samples(self, samples: DataFrame) -> None:
|
114 |
+
"""Overwrite the samples with a new dataframe.
|
115 |
+
|
116 |
+
Args:
|
117 |
+
samples (DataFrame): DataFrame with new samples.
|
118 |
+
"""
|
119 |
+
# validate the passed samples by checking the
|
120 |
+
if not isinstance(samples, DataFrame):
|
121 |
+
msg = f"samples must be a pandas.DataFrame, found {type(samples)}"
|
122 |
+
raise TypeError(msg)
|
123 |
+
|
124 |
+
expected_columns = _EXPECTED_COLUMNS_PERTASK[self.task]
|
125 |
+
if not all(col in samples.columns for col in expected_columns):
|
126 |
+
msg = f"samples must have (at least) columns {expected_columns}, found {samples.columns}"
|
127 |
+
raise ValueError(msg)
|
128 |
+
|
129 |
+
if not samples["image_path"].apply(lambda p: Path(p).exists()).all():
|
130 |
+
msg = "missing file path(s) in samples"
|
131 |
+
raise FileNotFoundError(msg)
|
132 |
+
|
133 |
+
self._samples = samples.sort_values(by="image_path", ignore_index=True)
|
134 |
+
|
135 |
+
@property
|
136 |
+
def category(self) -> str | None:
|
137 |
+
"""Get the category of the dataset."""
|
138 |
+
return self._category
|
139 |
+
|
140 |
+
@category.setter
|
141 |
+
def category(self, category: str) -> None:
|
142 |
+
"""Set the category of the dataset."""
|
143 |
+
self._category = category
|
144 |
+
|
145 |
+
@property
|
146 |
+
def has_normal(self) -> bool:
|
147 |
+
"""Check if the dataset contains any normal samples."""
|
148 |
+
return LabelName.NORMAL in list(self.samples.label_index)
|
149 |
+
|
150 |
+
@property
|
151 |
+
def has_anomalous(self) -> bool:
|
152 |
+
"""Check if the dataset contains any anomalous samples."""
|
153 |
+
return LabelName.ABNORMAL in list(self.samples.label_index)
|
154 |
+
|
155 |
+
def __getitem__(self, index: int) -> dict[str, str | torch.Tensor]:
|
156 |
+
"""Get dataset item for the index ``index``.
|
157 |
+
|
158 |
+
Args:
|
159 |
+
index (int): Index to get the item.
|
160 |
+
|
161 |
+
Returns:
|
162 |
+
dict[str, str | torch.Tensor]: Dict of image tensor during training. Otherwise, Dict containing image path,
|
163 |
+
target path, image tensor, label and transformed bounding box.
|
164 |
+
"""
|
165 |
+
image_path = self.samples.iloc[index].image_path
|
166 |
+
mask_path = self.samples.iloc[index].mask_path
|
167 |
+
label_index = self.samples.iloc[index].label_index
|
168 |
+
|
169 |
+
image = read_image(image_path, as_tensor=True)
|
170 |
+
item = {"image_path": image_path, "label": label_index}
|
171 |
+
|
172 |
+
if self.task == TaskType.CLASSIFICATION:
|
173 |
+
item["image"] = self.transform(image) if self.transform else image
|
174 |
+
elif self.task in (TaskType.DETECTION, TaskType.SEGMENTATION):
|
175 |
+
# Only Anomalous (1) images have masks in anomaly datasets
|
176 |
+
# Therefore, create empty mask for Normal (0) images.
|
177 |
+
mask = (
|
178 |
+
Mask(torch.zeros(image.shape[-2:])).to(torch.uint8)
|
179 |
+
if label_index == LabelName.NORMAL
|
180 |
+
else read_mask(mask_path, as_tensor=True)
|
181 |
+
)
|
182 |
+
item["image"], item["mask"] = self.transform(image, mask) if self.transform else (image, mask)
|
183 |
+
|
184 |
+
if self.task == TaskType.DETECTION:
|
185 |
+
# create boxes from masks for detection task
|
186 |
+
boxes, _ = masks_to_boxes(item["mask"])
|
187 |
+
item["boxes"] = boxes[0]
|
188 |
+
else:
|
189 |
+
msg = f"Unknown task type: {self.task}"
|
190 |
+
raise ValueError(msg)
|
191 |
+
|
192 |
+
return item
|
193 |
+
|
194 |
+
def __add__(self, other_dataset: "AnomalibDataset") -> "AnomalibDataset":
|
195 |
+
"""Concatenate this dataset with another dataset.
|
196 |
+
|
197 |
+
Args:
|
198 |
+
other_dataset (AnomalibDataset): Dataset to concatenate with.
|
199 |
+
|
200 |
+
Returns:
|
201 |
+
AnomalibDataset: Concatenated dataset.
|
202 |
+
"""
|
203 |
+
if not isinstance(other_dataset, self.__class__):
|
204 |
+
msg = "Cannot concatenate datasets that are not of the same type."
|
205 |
+
raise TypeError(msg)
|
206 |
+
dataset = copy.deepcopy(self)
|
207 |
+
dataset.samples = pd.concat([self.samples, other_dataset.samples], ignore_index=True)
|
208 |
+
return dataset
|
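
The samples contract and the `subsample`/`__add__` API introduced above can be exercised with a throwaway subclass; the sketch below generates a dummy image so the path check in the samples setter passes (all names and paths are illustrative):

```python
# Hedged sketch: minimal concrete AnomalibDataset with a hand-built samples frame.
import tempfile
from pathlib import Path

from pandas import DataFrame
from PIL import Image

from anomalib import TaskType
from anomalib.data.base.dataset import AnomalibDataset


class _ToyDataset(AnomalibDataset):
    """Concrete subclass used only for this illustration."""


tmp_dir = Path(tempfile.mkdtemp())
image_path = tmp_dir / "normal.png"
Image.new("RGB", (32, 32)).save(image_path)

dataset = _ToyDataset(task=TaskType.CLASSIFICATION)
dataset.samples = DataFrame(
    {
        "image_path": [str(image_path)] * 4,
        "label": ["normal"] * 4,
        "label_index": [0] * 4,
        "split": ["train"] * 4,
    },
)

subset = dataset.subsample([0, 2])          # new dataset; the original is untouched
combined = subset + dataset.subsample([1])  # concatenation via __add__
print(len(dataset), len(subset), len(combined))  # 4 2 3
```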
anomalib/data/base/depth.py
ADDED
@@ -0,0 +1,76 @@
1 |
+
"""Base Depth Dataset."""
|
2 |
+
|
3 |
+
# Copyright (C) 2023-2024 Intel Corporation
|
4 |
+
# SPDX-License-Identifier: Apache-2.0
|
5 |
+
|
6 |
+
from abc import ABC
|
7 |
+
|
8 |
+
import torch
|
9 |
+
from PIL import Image
|
10 |
+
from torchvision.transforms.functional import to_tensor
|
11 |
+
from torchvision.transforms.v2 import Transform
|
12 |
+
from torchvision.tv_tensors import Mask
|
13 |
+
|
14 |
+
from anomalib import TaskType
|
15 |
+
from anomalib.data.base.dataset import AnomalibDataset
|
16 |
+
from anomalib.data.utils import LabelName, masks_to_boxes, read_depth_image
|
17 |
+
|
18 |
+
|
19 |
+
class AnomalibDepthDataset(AnomalibDataset, ABC):
|
20 |
+
"""Base depth anomalib dataset class.
|
21 |
+
|
22 |
+
Args:
|
23 |
+
task (str): Task type, either 'classification' or 'segmentation'
|
24 |
+
transform (Transform, optional): Transforms that should be applied to the input images.
|
25 |
+
Defaults to ``None``.
|
26 |
+
"""
|
27 |
+
|
28 |
+
def __init__(self, task: TaskType, transform: Transform | None = None) -> None:
|
29 |
+
super().__init__(task, transform)
|
30 |
+
|
31 |
+
self.transform = transform
|
32 |
+
|
33 |
+
def __getitem__(self, index: int) -> dict[str, str | torch.Tensor]:
|
34 |
+
"""Return rgb image, depth image and mask.
|
35 |
+
|
36 |
+
Args:
|
37 |
+
index (int): Index of the item to be returned.
|
38 |
+
|
39 |
+
Returns:
|
40 |
+
dict[str, str | torch.Tensor]: Dictionary containing the image, depth image and mask.
|
41 |
+
"""
|
42 |
+
image_path = self.samples.iloc[index].image_path
|
43 |
+
mask_path = self.samples.iloc[index].mask_path
|
44 |
+
label_index = self.samples.iloc[index].label_index
|
45 |
+
depth_path = self.samples.iloc[index].depth_path
|
46 |
+
|
47 |
+
image = to_tensor(Image.open(image_path))
|
48 |
+
depth_image = to_tensor(read_depth_image(depth_path))
|
49 |
+
item = {"image_path": image_path, "depth_path": depth_path, "label": label_index}
|
50 |
+
|
51 |
+
if self.task == TaskType.CLASSIFICATION:
|
52 |
+
item["image"], item["depth_image"] = (
|
53 |
+
self.transform(image, depth_image) if self.transform else (image, depth_image)
|
54 |
+
)
|
55 |
+
elif self.task in (TaskType.DETECTION, TaskType.SEGMENTATION):
|
56 |
+
# Only Anomalous (1) images have masks in anomaly datasets
|
57 |
+
# Therefore, create empty mask for Normal (0) images.
|
58 |
+
mask = (
|
59 |
+
Mask(torch.zeros(image.shape[-2:]))
|
60 |
+
if label_index == LabelName.NORMAL
|
61 |
+
else Mask(to_tensor(Image.open(mask_path)).squeeze())
|
62 |
+
)
|
63 |
+
item["image"], item["depth_image"], item["mask"] = (
|
64 |
+
self.transform(image, depth_image, mask) if self.transform else (image, depth_image, mask)
|
65 |
+
)
|
66 |
+
item["mask_path"] = mask_path
|
67 |
+
|
68 |
+
if self.task == TaskType.DETECTION:
|
69 |
+
# create boxes from masks for detection task
|
70 |
+
boxes, _ = masks_to_boxes(item["mask"])
|
71 |
+
item["boxes"] = boxes[0]
|
72 |
+
else:
|
73 |
+
msg = f"Unknown task type: {self.task}"
|
74 |
+
raise ValueError(msg)
|
75 |
+
|
76 |
+
return item
|
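A hedged sketch of what `AnomalibDepthDataset.__getitem__` returns through a concrete subclass (MVTec3DDataset, added further down in this commit); the dataset root is an assumption:

# Illustrative sketch: inspecting a single depth item for the segmentation task.
from anomalib import TaskType
from anomalib.data.depth.mvtec_3d import MVTec3DDataset  # defined later in this diff

dataset = MVTec3DDataset(task=TaskType.SEGMENTATION, root="./datasets/MVTec3D", category="bagel", split="test")
item = dataset[0]
# Expected keys for segmentation: image_path, depth_path, label, image, depth_image, mask, mask_path.
print(sorted(item.keys()))
print(item["image"].shape, item["depth_image"].shape, item["mask"].shape)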
anomalib/data/base/video.py
ADDED
@@ -0,0 +1,213 @@
1 |
+
"""Base Video Dataset."""
|
2 |
+
|
3 |
+
# Copyright (C) 2023-2024 Intel Corporation
|
4 |
+
# SPDX-License-Identifier: Apache-2.0
|
5 |
+
|
6 |
+
from abc import ABC
|
7 |
+
from enum import Enum
|
8 |
+
from typing import TYPE_CHECKING, Any
|
9 |
+
|
10 |
+
import torch
|
11 |
+
from pandas import DataFrame
|
12 |
+
from torchvision.transforms.v2 import Transform
|
13 |
+
from torchvision.transforms.v2.functional import to_dtype_video
|
14 |
+
from torchvision.tv_tensors import Mask
|
15 |
+
|
16 |
+
from anomalib import TaskType
|
17 |
+
from anomalib.data.base.datamodule import AnomalibDataModule
|
18 |
+
from anomalib.data.base.dataset import AnomalibDataset
|
19 |
+
from anomalib.data.utils import ValSplitMode, masks_to_boxes
|
20 |
+
from anomalib.data.utils.video import ClipsIndexer
|
21 |
+
|
22 |
+
if TYPE_CHECKING:
|
23 |
+
from collections.abc import Callable
|
24 |
+
|
25 |
+
|
26 |
+
class VideoTargetFrame(str, Enum):
|
27 |
+
"""Target frame for a video-clip.
|
28 |
+
|
29 |
+
Used in multi-frame models to determine which frame's ground truth information will be used.
|
30 |
+
"""
|
31 |
+
|
32 |
+
FIRST = "first"
|
33 |
+
LAST = "last"
|
34 |
+
MID = "mid"
|
35 |
+
ALL = "all"
|
36 |
+
|
37 |
+
|
38 |
+
class AnomalibVideoDataset(AnomalibDataset, ABC):
|
39 |
+
"""Base video anomalib dataset class.
|
40 |
+
|
41 |
+
Args:
|
42 |
+
task (str): Task type, either 'classification' or 'segmentation'
|
43 |
+
clip_length_in_frames (int): Number of video frames in each clip.
|
44 |
+
frames_between_clips (int): Number of frames between each consecutive video clip.
|
45 |
+
transform (Transform, optional): Transforms that should be applied to the input clips.
|
46 |
+
Defaults to ``None``.
|
47 |
+
target_frame (VideoTargetFrame): Specifies the target frame in the video clip, used for ground truth retrieval.
|
48 |
+
Defaults to ``VideoTargetFrame.LAST``.
|
49 |
+
"""
|
50 |
+
|
51 |
+
def __init__(
|
52 |
+
self,
|
53 |
+
task: TaskType,
|
54 |
+
clip_length_in_frames: int,
|
55 |
+
frames_between_clips: int,
|
56 |
+
transform: Transform | None = None,
|
57 |
+
target_frame: VideoTargetFrame = VideoTargetFrame.LAST,
|
58 |
+
) -> None:
|
59 |
+
super().__init__(task, transform)
|
60 |
+
|
61 |
+
self.clip_length_in_frames = clip_length_in_frames
|
62 |
+
self.frames_between_clips = frames_between_clips
|
63 |
+
self.transform = transform
|
64 |
+
|
65 |
+
self.indexer: ClipsIndexer | None = None
|
66 |
+
self.indexer_cls: Callable | None = None
|
67 |
+
|
68 |
+
self.target_frame = target_frame
|
69 |
+
|
70 |
+
def __len__(self) -> int:
|
71 |
+
"""Get length of the dataset."""
|
72 |
+
if not isinstance(self.indexer, ClipsIndexer):
|
73 |
+
msg = "self.indexer must be an instance of ClipsIndexer."
|
74 |
+
raise TypeError(msg)
|
75 |
+
return self.indexer.num_clips()
|
76 |
+
|
77 |
+
@property
|
78 |
+
def samples(self) -> DataFrame:
|
79 |
+
"""Get the samples dataframe."""
|
80 |
+
return super().samples
|
81 |
+
|
82 |
+
@samples.setter
|
83 |
+
def samples(self, samples: DataFrame) -> None:
|
84 |
+
"""Overwrite samples and re-index subvideos.
|
85 |
+
|
86 |
+
Args:
|
87 |
+
samples (DataFrame): DataFrame with new samples.
|
88 |
+
|
89 |
+
Raises:
|
90 |
+
ValueError: If the indexer class is not set.
|
91 |
+
"""
|
92 |
+
super(AnomalibVideoDataset, self.__class__).samples.fset(self, samples) # type: ignore[attr-defined]
|
93 |
+
self._setup_clips()
|
94 |
+
|
95 |
+
def _setup_clips(self) -> None:
|
96 |
+
"""Compute the video and frame indices of the subvideos.
|
97 |
+
|
98 |
+
Should be called after each change to self._samples
|
99 |
+
"""
|
100 |
+
if not callable(self.indexer_cls):
|
101 |
+
msg = "self.indexer_cls must be callable."
|
102 |
+
raise TypeError(msg)
|
103 |
+
self.indexer = self.indexer_cls( # pylint: disable=not-callable
|
104 |
+
video_paths=list(self.samples.image_path),
|
105 |
+
mask_paths=list(self.samples.mask_path),
|
106 |
+
clip_length_in_frames=self.clip_length_in_frames,
|
107 |
+
frames_between_clips=self.frames_between_clips,
|
108 |
+
)
|
109 |
+
|
110 |
+
def _select_targets(self, item: dict[str, Any]) -> dict[str, Any]:
|
111 |
+
"""Select the target frame from the clip.
|
112 |
+
|
113 |
+
Args:
|
114 |
+
item (dict[str, Any]): Item containing the clip information.
|
115 |
+
|
116 |
+
Raises:
|
117 |
+
ValueError: If the target frame is not one of the supported options.
|
118 |
+
|
119 |
+
Returns:
|
120 |
+
dict[str, Any]: Selected item from the clip.
|
121 |
+
"""
|
122 |
+
if self.target_frame == VideoTargetFrame.FIRST:
|
123 |
+
idx = 0
|
124 |
+
elif self.target_frame == VideoTargetFrame.LAST:
|
125 |
+
idx = -1
|
126 |
+
elif self.target_frame == VideoTargetFrame.MID:
|
127 |
+
idx = int(self.clip_length_in_frames / 2)
|
128 |
+
else:
|
129 |
+
msg = f"Unknown video target frame: {self.target_frame}"
|
130 |
+
raise ValueError(msg)
|
131 |
+
|
132 |
+
if item.get("mask") is not None:
|
133 |
+
item["mask"] = item["mask"][idx, ...]
|
134 |
+
if item.get("boxes") is not None:
|
135 |
+
item["boxes"] = item["boxes"][idx]
|
136 |
+
if item.get("label") is not None:
|
137 |
+
item["label"] = item["label"][idx]
|
138 |
+
if item.get("original_image") is not None:
|
139 |
+
item["original_image"] = item["original_image"][idx]
|
140 |
+
if item.get("frames") is not None:
|
141 |
+
item["frames"] = item["frames"][idx]
|
142 |
+
return item
|
143 |
+
|
144 |
+
def __getitem__(self, index: int) -> dict[str, str | torch.Tensor]:
|
145 |
+
"""Get the dataset item for the index ``index``.
|
146 |
+
|
147 |
+
Args:
|
148 |
+
index (int): Index of the item to be returned.
|
149 |
+
|
150 |
+
Returns:
|
151 |
+
dict[str, str | torch.Tensor]: Dictionary containing the mask, clip and file system information.
|
152 |
+
"""
|
153 |
+
if not isinstance(self.indexer, ClipsIndexer):
|
154 |
+
msg = "self.indexer must be an instance of ClipsIndexer."
|
155 |
+
raise TypeError(msg)
|
156 |
+
item = self.indexer.get_item(index)
|
157 |
+
item["image"] = to_dtype_video(video=item["image"], scale=True)
|
158 |
+
# include the untransformed image for visualization
|
159 |
+
item["original_image"] = item["image"].to(torch.uint8)
|
160 |
+
|
161 |
+
# apply transforms
|
162 |
+
if item.get("mask") is not None:
|
163 |
+
if self.transform:
|
164 |
+
item["image"], item["mask"] = self.transform(item["image"], Mask(item["mask"]))
|
165 |
+
item["label"] = torch.Tensor([1 in frame for frame in item["mask"]]).int().squeeze(0)
|
166 |
+
if self.task == TaskType.DETECTION:
|
167 |
+
item["boxes"], _ = masks_to_boxes(item["mask"])
|
168 |
+
item["boxes"] = item["boxes"][0] if len(item["boxes"]) == 1 else item["boxes"]
|
169 |
+
elif self.transform:
|
170 |
+
item["image"] = self.transform(item["image"])
|
171 |
+
|
172 |
+
# squeeze temporal dimensions in case clip length is 1
|
173 |
+
item["image"] = item["image"].squeeze(0)
|
174 |
+
|
175 |
+
# include only target frame in gt
|
176 |
+
if self.clip_length_in_frames > 1 and self.target_frame != VideoTargetFrame.ALL:
|
177 |
+
item = self._select_targets(item)
|
178 |
+
|
179 |
+
if item["mask"] is None:
|
180 |
+
item.pop("mask")
|
181 |
+
|
182 |
+
return item
|
183 |
+
|
184 |
+
|
185 |
+
class AnomalibVideoDataModule(AnomalibDataModule):
|
186 |
+
"""Base class for video data modules."""
|
187 |
+
|
188 |
+
def _create_test_split(self) -> None:
|
189 |
+
"""Video datamodules do not support dynamic assignment of the test split."""
|
190 |
+
|
191 |
+
def _setup(self, _stage: str | None = None) -> None:
|
192 |
+
"""Set up the datasets and perform dynamic subset splitting.
|
193 |
+
|
194 |
+
This method may be overridden in subclass for custom splitting behaviour.
|
195 |
+
|
196 |
+
Video datamodules are not compatible with synthetic anomaly generation.
|
197 |
+
"""
|
198 |
+
if self.train_data is None:
|
199 |
+
msg = "self.train_data cannot be None."
|
200 |
+
raise ValueError(msg)
|
201 |
+
|
202 |
+
if self.test_data is None:
|
203 |
+
msg = "self.test_data cannot be None."
|
204 |
+
raise ValueError(msg)
|
205 |
+
|
206 |
+
self.train_data.setup()
|
207 |
+
self.test_data.setup()
|
208 |
+
|
209 |
+
if self.val_split_mode == ValSplitMode.SYNTHETIC:
|
210 |
+
msg = f"Val split mode {self.test_split_mode} not supported for video datasets."
|
211 |
+
raise ValueError(msg)
|
212 |
+
|
213 |
+
self._create_val_split()
|
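The `_select_targets` method above collapses clip-level ground truth onto one frame according to `VideoTargetFrame`. A small standalone sketch of the same index selection, assuming a clip length of 4 frames:

# Illustrative sketch: how VideoTargetFrame maps onto a frame index within a clip.
clip_length_in_frames = 4

def target_index(target_frame: str) -> int:
    if target_frame == "first":
        return 0
    if target_frame == "last":
        return -1
    if target_frame == "mid":
        return int(clip_length_in_frames / 2)
    msg = f"Unknown video target frame: {target_frame}"
    raise ValueError(msg)

print([target_index(t) for t in ("first", "last", "mid")])  # [0, -1, 2]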
anomalib/data/depth/__init__.py
ADDED
@@ -0,0 +1,20 @@
+"""Anomalib Depth Datasets."""
+
+# Copyright (C) 2023 Intel Corporation
+# SPDX-License-Identifier: Apache-2.0
+
+
+from enum import Enum
+
+from .folder_3d import Folder3D
+from .mvtec_3d import MVTec3D
+
+
+class DepthDataFormat(str, Enum):
+    """Supported Depth Dataset Types."""
+
+    MVTEC_3D = "mvtec_3d"
+    FOLDER_3D = "folder_3d"
+
+
+__all__ = ["Folder3D", "MVTec3D"]
anomalib/data/depth/folder_3d.py
ADDED
@@ -0,0 +1,433 @@
1 |
+
"""Custom Folder Dataset.
|
2 |
+
|
3 |
+
This script creates a custom dataset from a folder.
|
4 |
+
"""
|
5 |
+
|
6 |
+
# Copyright (C) 2022 Intel Corporation
|
7 |
+
# SPDX-License-Identifier: Apache-2.0
|
8 |
+
|
9 |
+
|
10 |
+
from pathlib import Path
|
11 |
+
|
12 |
+
from pandas import DataFrame, isna
|
13 |
+
from torchvision.transforms.v2 import Transform
|
14 |
+
|
15 |
+
from anomalib import TaskType
|
16 |
+
from anomalib.data.base import AnomalibDataModule, AnomalibDepthDataset
|
17 |
+
from anomalib.data.errors import MisMatchError
|
18 |
+
from anomalib.data.utils import (
|
19 |
+
DirType,
|
20 |
+
LabelName,
|
21 |
+
Split,
|
22 |
+
TestSplitMode,
|
23 |
+
ValSplitMode,
|
24 |
+
)
|
25 |
+
from anomalib.data.utils.path import _prepare_files_labels, validate_and_resolve_path
|
26 |
+
|
27 |
+
|
28 |
+
def make_folder3d_dataset( # noqa: C901
|
29 |
+
normal_dir: str | Path,
|
30 |
+
root: str | Path | None = None,
|
31 |
+
abnormal_dir: str | Path | None = None,
|
32 |
+
normal_test_dir: str | Path | None = None,
|
33 |
+
mask_dir: str | Path | None = None,
|
34 |
+
normal_depth_dir: str | Path | None = None,
|
35 |
+
abnormal_depth_dir: str | Path | None = None,
|
36 |
+
normal_test_depth_dir: str | Path | None = None,
|
37 |
+
split: str | Split | None = None,
|
38 |
+
extensions: tuple[str, ...] | None = None,
|
39 |
+
) -> DataFrame:
|
40 |
+
"""Make Folder Dataset.
|
41 |
+
|
42 |
+
Args:
|
43 |
+
normal_dir (str | Path): Path to the directory containing normal images.
|
44 |
+
root (str | Path | None): Path to the root directory of the dataset.
|
45 |
+
Defaults to ``None``.
|
46 |
+
abnormal_dir (str | Path | None, optional): Path to the directory containing abnormal images.
|
47 |
+
Defaults to ``None``.
|
48 |
+
normal_test_dir (str | Path | None, optional): Path to the directory containing normal images for the test
|
49 |
+
dataset. Normal test images will be a split of `normal_dir` if `None`.
|
50 |
+
Defaults to ``None``.
|
51 |
+
mask_dir (str | Path | None, optional): Path to the directory containing the mask annotations.
|
52 |
+
Defaults to ``None``.
|
53 |
+
normal_depth_dir (str | Path | None, optional): Path to the directory containing
|
54 |
+
normal depth images for the test dataset. Normal test depth images will be a split of `normal_dir`
|
55 |
+
Defaults to ``None``.
|
56 |
+
abnormal_depth_dir (str | Path | None, optional): Path to the directory containing abnormal depth images for
|
57 |
+
the test dataset.
|
58 |
+
Defaults to ``None``.
|
59 |
+
normal_test_depth_dir (str | Path | None, optional): Path to the directory containing normal depth images for
|
60 |
+
the test dataset. Normal test images will be a split of `normal_dir` if `None`.
|
61 |
+
Defaults to ``None``.
|
62 |
+
split (str | Split | None, optional): Dataset split (ie., Split.FULL, Split.TRAIN or Split.TEST).
|
63 |
+
Defaults to ``None``.
|
64 |
+
extensions (tuple[str, ...] | None, optional): Type of the image extensions to read from the directory.
|
65 |
+
Defaults to ``None``.
|
66 |
+
|
67 |
+
Returns:
|
68 |
+
DataFrame: an output dataframe containing samples for the requested split (ie., train or test)
|
69 |
+
"""
|
70 |
+
normal_dir = validate_and_resolve_path(normal_dir, root)
|
71 |
+
abnormal_dir = validate_and_resolve_path(abnormal_dir, root) if abnormal_dir else None
|
72 |
+
normal_test_dir = validate_and_resolve_path(normal_test_dir, root) if normal_test_dir else None
|
73 |
+
mask_dir = validate_and_resolve_path(mask_dir, root) if mask_dir else None
|
74 |
+
normal_depth_dir = validate_and_resolve_path(normal_depth_dir, root) if normal_depth_dir else None
|
75 |
+
abnormal_depth_dir = validate_and_resolve_path(abnormal_depth_dir, root) if abnormal_depth_dir else None
|
76 |
+
normal_test_depth_dir = validate_and_resolve_path(normal_test_depth_dir, root) if normal_test_depth_dir else None
|
77 |
+
|
78 |
+
if not normal_dir.is_dir():
|
79 |
+
msg = "A folder location must be provided in normal_dir."
|
80 |
+
raise ValueError(msg)
|
81 |
+
|
82 |
+
filenames = []
|
83 |
+
labels = []
|
84 |
+
dirs = {DirType.NORMAL: normal_dir}
|
85 |
+
|
86 |
+
if abnormal_dir:
|
87 |
+
dirs[DirType.ABNORMAL] = abnormal_dir
|
88 |
+
|
89 |
+
if normal_test_dir:
|
90 |
+
dirs[DirType.NORMAL_TEST] = normal_test_dir
|
91 |
+
|
92 |
+
if normal_depth_dir:
|
93 |
+
dirs[DirType.NORMAL_DEPTH] = normal_depth_dir
|
94 |
+
|
95 |
+
if abnormal_depth_dir:
|
96 |
+
dirs[DirType.ABNORMAL_DEPTH] = abnormal_depth_dir
|
97 |
+
|
98 |
+
if normal_test_depth_dir:
|
99 |
+
dirs[DirType.NORMAL_TEST_DEPTH] = normal_test_depth_dir
|
100 |
+
|
101 |
+
if mask_dir:
|
102 |
+
dirs[DirType.MASK] = mask_dir
|
103 |
+
|
104 |
+
for dir_type, path in dirs.items():
|
105 |
+
filename, label = _prepare_files_labels(path, dir_type, extensions)
|
106 |
+
filenames += filename
|
107 |
+
labels += label
|
108 |
+
|
109 |
+
samples = DataFrame({"image_path": filenames, "label": labels})
|
110 |
+
samples = samples.sort_values(by="image_path", ignore_index=True)
|
111 |
+
|
112 |
+
# Create label index for normal (0) and abnormal (1) images.
|
113 |
+
samples.loc[
|
114 |
+
(samples.label == DirType.NORMAL) | (samples.label == DirType.NORMAL_TEST),
|
115 |
+
"label_index",
|
116 |
+
] = LabelName.NORMAL
|
117 |
+
samples.loc[(samples.label == DirType.ABNORMAL), "label_index"] = LabelName.ABNORMAL
|
118 |
+
samples.label_index = samples.label_index.astype("Int64")
|
119 |
+
|
120 |
+
# If the depth directories are provided, add the corresponding depth paths to the samples dataframe.
|
121 |
+
if normal_depth_dir:
|
122 |
+
samples.loc[samples.label == DirType.NORMAL, "depth_path"] = samples.loc[
|
123 |
+
samples.label == DirType.NORMAL_DEPTH
|
124 |
+
].image_path.to_numpy()
|
125 |
+
samples.loc[samples.label == DirType.ABNORMAL, "depth_path"] = samples.loc[
|
126 |
+
samples.label == DirType.ABNORMAL_DEPTH
|
127 |
+
].image_path.to_numpy()
|
128 |
+
|
129 |
+
if normal_test_dir:
|
130 |
+
samples.loc[samples.label == DirType.NORMAL_TEST, "depth_path"] = samples.loc[
|
131 |
+
samples.label == DirType.NORMAL_TEST_DEPTH
|
132 |
+
].image_path.to_numpy()
|
133 |
+
|
134 |
+
# make sure every rgb image has a corresponding depth image and that the file exists
|
135 |
+
mismatch = (
|
136 |
+
samples.loc[samples.label_index == LabelName.ABNORMAL]
|
137 |
+
.apply(lambda x: Path(x.image_path).stem in Path(x.depth_path).stem, axis=1)
|
138 |
+
.all()
|
139 |
+
)
|
140 |
+
if not mismatch:
|
141 |
+
msg = """Mismatch between anomalous images and depth images. Make sure the mask files
|
142 |
+
in 'xyz' folder follow the same naming convention as the anomalous images in the dataset
|
143 |
+
(e.g. image: '000.png', depth: '000.tiff')."""
|
144 |
+
raise MisMatchError(msg)
|
145 |
+
|
146 |
+
missing_depth_files = samples.depth_path.apply(
|
147 |
+
lambda x: Path(x).exists() if not isna(x) else True,
|
148 |
+
).all()
|
149 |
+
if not missing_depth_files:
|
150 |
+
msg = "Missing depth image files."
|
151 |
+
raise FileNotFoundError(msg)
|
152 |
+
|
153 |
+
samples = samples.astype({"depth_path": "str"})
|
154 |
+
|
155 |
+
# If a path to mask is provided, add it to the sample dataframe.
|
156 |
+
if mask_dir and abnormal_dir:
|
157 |
+
samples.loc[samples.label == DirType.ABNORMAL, "mask_path"] = samples.loc[
|
158 |
+
samples.label == DirType.MASK
|
159 |
+
].image_path.to_numpy()
|
160 |
+
samples["mask_path"] = samples["mask_path"].fillna("")
|
161 |
+
samples = samples.astype({"mask_path": "str"})
|
162 |
+
|
163 |
+
# make sure all the files exist
|
164 |
+
if not samples.mask_path.apply(
|
165 |
+
lambda x: Path(x).exists() if x != "" else True,
|
166 |
+
).all():
|
167 |
+
msg = f"Missing mask files. mask_dir={mask_dir}"
|
168 |
+
raise FileNotFoundError(msg)
|
169 |
+
else:
|
170 |
+
samples["mask_path"] = ""
|
171 |
+
|
172 |
+
# remove all the rows with temporal image samples that have already been assigned
|
173 |
+
samples = samples.loc[
|
174 |
+
(samples.label == DirType.NORMAL) | (samples.label == DirType.ABNORMAL) | (samples.label == DirType.NORMAL_TEST)
|
175 |
+
]
|
176 |
+
|
177 |
+
# Ensure the pathlib objects are converted to str.
|
178 |
+
# This is because torch dataloader doesn't like pathlib.
|
179 |
+
samples = samples.astype({"image_path": "str"})
|
180 |
+
|
181 |
+
# Create train/test split.
|
182 |
+
# By default, all the normal samples are assigned as train.
|
183 |
+
# and all the abnormal samples are test.
|
184 |
+
samples.loc[(samples.label == DirType.NORMAL), "split"] = Split.TRAIN
|
185 |
+
samples.loc[(samples.label == DirType.ABNORMAL) | (samples.label == DirType.NORMAL_TEST), "split"] = Split.TEST
|
186 |
+
|
187 |
+
# Get the data frame for the split.
|
188 |
+
if split:
|
189 |
+
samples = samples[samples.split == split]
|
190 |
+
samples = samples.reset_index(drop=True)
|
191 |
+
|
192 |
+
return samples
|
193 |
+
|
194 |
+
|
195 |
+
class Folder3DDataset(AnomalibDepthDataset):
|
196 |
+
"""Folder dataset.
|
197 |
+
|
198 |
+
Args:
|
199 |
+
name (str): Name of the dataset.
|
200 |
+
task (TaskType): Task type. (``classification``, ``detection`` or ``segmentation``).
|
201 |
+
transform (Transform): Transforms that should be applied to the input images.
|
202 |
+
normal_dir (str | Path): Path to the directory containing normal images.
|
203 |
+
root (str | Path | None): Root folder of the dataset.
|
204 |
+
Defaults to ``None``.
|
205 |
+
abnormal_dir (str | Path | None, optional): Path to the directory containing abnormal images.
|
206 |
+
Defaults to ``None``.
|
207 |
+
normal_test_dir (str | Path | None, optional): Path to the directory containing
|
208 |
+
normal images for the test dataset.
|
209 |
+
Defaults to ``None``.
|
210 |
+
mask_dir (str | Path | None, optional): Path to the directory containing
|
211 |
+
the mask annotations.
|
212 |
+
Defaults to ``None``.
|
213 |
+
normal_depth_dir (str | Path | None, optional): Path to the directory containing
|
214 |
+
normal depth images for the test dataset. Normal test depth images will be a split of `normal_dir`
|
215 |
+
Defaults to ``None``.
|
216 |
+
abnormal_depth_dir (str | Path | None, optional): Path to the directory containing abnormal depth images for
|
217 |
+
the test dataset.
|
218 |
+
Defaults to ``None``.
|
219 |
+
normal_test_depth_dir (str | Path | None, optional): Path to the directory containing
|
220 |
+
normal depth images for the test dataset. Normal test images will be a split of `normal_dir` if `None`.
|
221 |
+
Defaults to ``None``.
|
222 |
+
transform (Transform, optional): Transforms that should be applied to the input images.
|
223 |
+
Defaults to ``None``.
|
224 |
+
split (str | Split | None): Fixed subset split that follows from folder structure on file system.
|
225 |
+
Choose from [Split.FULL, Split.TRAIN, Split.TEST]
|
226 |
+
Defaults to ``None``.
|
227 |
+
extensions (tuple[str, ...] | None, optional): Type of the image extensions to read from the directory.
|
228 |
+
Defaults to ``None``.
|
229 |
+
|
230 |
+
Raises:
|
231 |
+
ValueError: When task is set to classification and `mask_dir` is provided. When `mask_dir` is
|
232 |
+
provided, `task` should be set to `segmentation`.
|
233 |
+
"""
|
234 |
+
|
235 |
+
def __init__(
|
236 |
+
self,
|
237 |
+
name: str,
|
238 |
+
task: TaskType,
|
239 |
+
normal_dir: str | Path,
|
240 |
+
root: str | Path | None = None,
|
241 |
+
abnormal_dir: str | Path | None = None,
|
242 |
+
normal_test_dir: str | Path | None = None,
|
243 |
+
mask_dir: str | Path | None = None,
|
244 |
+
normal_depth_dir: str | Path | None = None,
|
245 |
+
abnormal_depth_dir: str | Path | None = None,
|
246 |
+
normal_test_depth_dir: str | Path | None = None,
|
247 |
+
transform: Transform | None = None,
|
248 |
+
split: str | Split | None = None,
|
249 |
+
extensions: tuple[str, ...] | None = None,
|
250 |
+
) -> None:
|
251 |
+
super().__init__(task, transform)
|
252 |
+
|
253 |
+
self._name = name
|
254 |
+
self.split = split
|
255 |
+
self.root = root
|
256 |
+
self.normal_dir = normal_dir
|
257 |
+
self.abnormal_dir = abnormal_dir
|
258 |
+
self.normal_test_dir = normal_test_dir
|
259 |
+
self.mask_dir = mask_dir
|
260 |
+
self.normal_depth_dir = normal_depth_dir
|
261 |
+
self.abnormal_depth_dir = abnormal_depth_dir
|
262 |
+
self.normal_test_depth_dir = normal_test_depth_dir
|
263 |
+
self.extensions = extensions
|
264 |
+
|
265 |
+
self.samples = make_folder3d_dataset(
|
266 |
+
root=self.root,
|
267 |
+
normal_dir=self.normal_dir,
|
268 |
+
abnormal_dir=self.abnormal_dir,
|
269 |
+
normal_test_dir=self.normal_test_dir,
|
270 |
+
mask_dir=self.mask_dir,
|
271 |
+
normal_depth_dir=self.normal_depth_dir,
|
272 |
+
abnormal_depth_dir=self.abnormal_depth_dir,
|
273 |
+
normal_test_depth_dir=self.normal_test_depth_dir,
|
274 |
+
split=self.split,
|
275 |
+
extensions=self.extensions,
|
276 |
+
)
|
277 |
+
|
278 |
+
@property
|
279 |
+
def name(self) -> str:
|
280 |
+
"""Name of the dataset.
|
281 |
+
|
282 |
+
Folder3D dataset overrides the name property to provide a custom name.
|
283 |
+
"""
|
284 |
+
return self._name
|
285 |
+
|
286 |
+
|
287 |
+
class Folder3D(AnomalibDataModule):
|
288 |
+
"""Folder DataModule.
|
289 |
+
|
290 |
+
Args:
|
291 |
+
name (str): Name of the dataset. This is used to name the datamodule, especially when logging/saving.
|
292 |
+
normal_dir (str | Path): Name of the directory containing normal images.
|
293 |
+
root (str | Path | None): Path to the root folder containing normal and abnormal dirs.
|
294 |
+
Defaults to ``None``.
|
295 |
+
abnormal_dir (str | Path | None): Name of the directory containing abnormal images.
|
296 |
+
Defaults to ``abnormal``.
|
297 |
+
normal_test_dir (str | Path | None, optional): Path to the directory containing normal images for the test
|
298 |
+
dataset.
|
299 |
+
Defaults to ``None``.
|
300 |
+
mask_dir (str | Path | None, optional): Path to the directory containing the mask annotations.
|
301 |
+
Defaults to ``None``.
|
302 |
+
normal_depth_dir (str | Path | None, optional): Path to the directory containing
|
303 |
+
normal depth images for the test dataset. Normal test depth images will be a split of `normal_dir`
|
304 |
+
abnormal_depth_dir (str | Path | None, optional): Path to the directory containing
|
305 |
+
abnormal depth images for the test dataset.
|
306 |
+
normal_test_depth_dir (str | Path | None, optional): Path to the directory containing
|
307 |
+
normal depth images for the test dataset. Normal test images will be a split of `normal_dir`
|
308 |
+
if `None`. Defaults to None.
|
309 |
+
normal_split_ratio (float, optional): Ratio to split normal training images and add to the
|
310 |
+
test set in case test set doesn't contain any normal images.
|
311 |
+
Defaults to 0.2.
|
312 |
+
extensions (tuple[str, ...] | None, optional): Type of the image extensions to read from the
|
313 |
+
directory. Defaults to None.
|
314 |
+
train_batch_size (int, optional): Training batch size.
|
315 |
+
Defaults to ``32``.
|
316 |
+
eval_batch_size (int, optional): Test batch size.
|
317 |
+
Defaults to ``32``.
|
318 |
+
num_workers (int, optional): Number of workers.
|
319 |
+
Defaults to ``8``.
|
320 |
+
task (TaskType, optional): Task type. Could be ``classification``, ``detection`` or ``segmentation``.
|
321 |
+
Defaults to ``TaskType.SEGMENTATION``.
|
322 |
+
image_size (tuple[int, int], optional): Size to which input images should be resized.
|
323 |
+
Defaults to ``None``.
|
324 |
+
transform (Transform, optional): Transforms that should be applied to the input images.
|
325 |
+
Defaults to ``None``.
|
326 |
+
train_transform (Transform, optional): Transforms that should be applied to the input images during training.
|
327 |
+
Defaults to ``None``.
|
328 |
+
eval_transform (Transform, optional): Transforms that should be applied to the input images during evaluation.
|
329 |
+
Defaults to ``None``.
|
330 |
+
test_split_mode (TestSplitMode): Setting that determines how the testing subset is obtained.
|
331 |
+
Defaults to ``TestSplitMode.FROM_DIR``.
|
332 |
+
test_split_ratio (float): Fraction of images from the train set that will be reserved for testing.
|
333 |
+
Defaults to ``0.2``.
|
334 |
+
val_split_mode (ValSplitMode): Setting that determines how the validation subset is obtained.
|
335 |
+
Defaults to ``ValSplitMode.FROM_TEST``.
|
336 |
+
val_split_ratio (float): Fraction of train or test images that will be reserved for validation.
|
337 |
+
Defaults to ``0.5``.
|
338 |
+
seed (int | None, optional): Seed used during random subset splitting.
|
339 |
+
Defaults to ``None``.
|
340 |
+
"""
|
341 |
+
|
342 |
+
def __init__(
|
343 |
+
self,
|
344 |
+
name: str,
|
345 |
+
normal_dir: str | Path,
|
346 |
+
root: str | Path,
|
347 |
+
abnormal_dir: str | Path | None = None,
|
348 |
+
normal_test_dir: str | Path | None = None,
|
349 |
+
mask_dir: str | Path | None = None,
|
350 |
+
normal_depth_dir: str | Path | None = None,
|
351 |
+
abnormal_depth_dir: str | Path | None = None,
|
352 |
+
normal_test_depth_dir: str | Path | None = None,
|
353 |
+
extensions: tuple[str] | None = None,
|
354 |
+
train_batch_size: int = 32,
|
355 |
+
eval_batch_size: int = 32,
|
356 |
+
num_workers: int = 8,
|
357 |
+
task: TaskType | str = TaskType.SEGMENTATION,
|
358 |
+
image_size: tuple[int, int] | None = None,
|
359 |
+
transform: Transform | None = None,
|
360 |
+
train_transform: Transform | None = None,
|
361 |
+
eval_transform: Transform | None = None,
|
362 |
+
test_split_mode: TestSplitMode | str = TestSplitMode.FROM_DIR,
|
363 |
+
test_split_ratio: float = 0.2,
|
364 |
+
val_split_mode: ValSplitMode | str = ValSplitMode.FROM_TEST,
|
365 |
+
val_split_ratio: float = 0.5,
|
366 |
+
seed: int | None = None,
|
367 |
+
) -> None:
|
368 |
+
super().__init__(
|
369 |
+
train_batch_size=train_batch_size,
|
370 |
+
eval_batch_size=eval_batch_size,
|
371 |
+
num_workers=num_workers,
|
372 |
+
image_size=image_size,
|
373 |
+
transform=transform,
|
374 |
+
train_transform=train_transform,
|
375 |
+
eval_transform=eval_transform,
|
376 |
+
test_split_mode=test_split_mode,
|
377 |
+
test_split_ratio=test_split_ratio,
|
378 |
+
val_split_mode=val_split_mode,
|
379 |
+
val_split_ratio=val_split_ratio,
|
380 |
+
seed=seed,
|
381 |
+
)
|
382 |
+
self._name = name
|
383 |
+
self.task = TaskType(task)
|
384 |
+
self.root = Path(root)
|
385 |
+
self.normal_dir = normal_dir
|
386 |
+
self.abnormal_dir = abnormal_dir
|
387 |
+
self.normal_test_dir = normal_test_dir
|
388 |
+
self.mask_dir = mask_dir
|
389 |
+
self.normal_depth_dir = normal_depth_dir
|
390 |
+
self.abnormal_depth_dir = abnormal_depth_dir
|
391 |
+
self.normal_test_depth_dir = normal_test_depth_dir
|
392 |
+
self.extensions = extensions
|
393 |
+
|
394 |
+
def _setup(self, _stage: str | None = None) -> None:
|
395 |
+
self.train_data = Folder3DDataset(
|
396 |
+
name=self.name,
|
397 |
+
task=self.task,
|
398 |
+
transform=self.train_transform,
|
399 |
+
split=Split.TRAIN,
|
400 |
+
root=self.root,
|
401 |
+
normal_dir=self.normal_dir,
|
402 |
+
abnormal_dir=self.abnormal_dir,
|
403 |
+
normal_test_dir=self.normal_test_dir,
|
404 |
+
mask_dir=self.mask_dir,
|
405 |
+
normal_depth_dir=self.normal_depth_dir,
|
406 |
+
abnormal_depth_dir=self.abnormal_depth_dir,
|
407 |
+
normal_test_depth_dir=self.normal_test_depth_dir,
|
408 |
+
extensions=self.extensions,
|
409 |
+
)
|
410 |
+
|
411 |
+
self.test_data = Folder3DDataset(
|
412 |
+
name=self.name,
|
413 |
+
task=self.task,
|
414 |
+
transform=self.eval_transform,
|
415 |
+
split=Split.TEST,
|
416 |
+
root=self.root,
|
417 |
+
normal_dir=self.normal_dir,
|
418 |
+
abnormal_dir=self.abnormal_dir,
|
419 |
+
normal_test_dir=self.normal_test_dir,
|
420 |
+
normal_depth_dir=self.normal_depth_dir,
|
421 |
+
abnormal_depth_dir=self.abnormal_depth_dir,
|
422 |
+
normal_test_depth_dir=self.normal_test_depth_dir,
|
423 |
+
mask_dir=self.mask_dir,
|
424 |
+
extensions=self.extensions,
|
425 |
+
)
|
426 |
+
|
427 |
+
@property
|
428 |
+
def name(self) -> str:
|
429 |
+
"""Name of the datamodule.
|
430 |
+
|
431 |
+
Folder3D datamodule overrides the name property to provide a custom name.
|
432 |
+
"""
|
433 |
+
return self._name
|
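A hedged usage sketch for the Folder3D datamodule above; every directory name here is an assumption about a custom layout, not something the code prescribes:

# Illustrative sketch: wiring Folder3D to a custom RGB + depth folder layout.
from anomalib import TaskType
from anomalib.data.depth.folder_3d import Folder3D

datamodule = Folder3D(
    name="my_product",
    root="./datasets/my_product",    # assumed root
    normal_dir="train/good/rgb",     # assumed sub-directories
    abnormal_dir="test/defect/rgb",
    normal_test_dir="test/good/rgb",
    mask_dir="ground_truth/defect",
    normal_depth_dir="train/good/xyz",
    abnormal_depth_dir="test/defect/xyz",
    normal_test_depth_dir="test/good/xyz",
    task=TaskType.SEGMENTATION,
)
datamodule.setup()
batch = next(iter(datamodule.train_dataloader()))
print(batch["image"].shape, batch["depth_image"].shape)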
anomalib/data/depth/mvtec_3d.py
ADDED
@@ -0,0 +1,302 @@
1 |
+
"""MVTec 3D-AD Dataset (CC BY-NC-SA 4.0).
|
2 |
+
|
3 |
+
Description:
|
4 |
+
This script contains PyTorch Dataset, Dataloader and PyTorch Lightning DataModule for the MVTec 3D-AD dataset.
|
5 |
+
If the dataset is not on the file system, the script downloads and extracts the dataset and creates PyTorch data
|
6 |
+
objects.
|
7 |
+
|
8 |
+
License:
|
9 |
+
MVTec 3D-AD dataset is released under the Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International
|
10 |
+
License (CC BY-NC-SA 4.0)(https://creativecommons.org/licenses/by-nc-sa/4.0/).
|
11 |
+
|
12 |
+
Reference:
|
13 |
+
- Paul Bergmann, Xin Jin, David Sattlegger, Carsten Steger: The MVTec 3D-AD Dataset for Unsupervised 3D Anomaly
|
14 |
+
Detection and Localization in: Proceedings of the 17th International Joint Conference on Computer Vision,
|
15 |
+
Imaging and Computer Graphics Theory and Applications - Volume 5: VISAPP, 202-213, 2022, DOI: 10.5220/
|
16 |
+
0010865000003124.
|
17 |
+
"""
|
18 |
+
|
19 |
+
# Copyright (C) 2022 Intel Corporation
|
20 |
+
# SPDX-License-Identifier: Apache-2.0
|
21 |
+
|
22 |
+
|
23 |
+
import logging
|
24 |
+
from collections.abc import Sequence
|
25 |
+
from pathlib import Path
|
26 |
+
|
27 |
+
from pandas import DataFrame
|
28 |
+
from torchvision.transforms.v2 import Transform
|
29 |
+
|
30 |
+
from anomalib import TaskType
|
31 |
+
from anomalib.data.base import AnomalibDataModule, AnomalibDepthDataset
|
32 |
+
from anomalib.data.errors import MisMatchError
|
33 |
+
from anomalib.data.utils import (
|
34 |
+
DownloadInfo,
|
35 |
+
LabelName,
|
36 |
+
Split,
|
37 |
+
TestSplitMode,
|
38 |
+
ValSplitMode,
|
39 |
+
download_and_extract,
|
40 |
+
validate_path,
|
41 |
+
)
|
42 |
+
|
43 |
+
logger = logging.getLogger(__name__)
|
44 |
+
|
45 |
+
|
46 |
+
IMG_EXTENSIONS = [".png", ".PNG", ".tiff"]
|
47 |
+
|
48 |
+
DOWNLOAD_INFO = DownloadInfo(
|
49 |
+
name="mvtec_3d",
|
50 |
+
url="https://www.mydrive.ch/shares/45920/dd1eb345346df066c63b5c95676b961b/download/428824485-1643285832"
|
51 |
+
"/mvtec_3d_anomaly_detection.tar.xz",
|
52 |
+
hashsum="d8bb2800fbf3ac88e798da6ae10dc819",
|
53 |
+
)
|
54 |
+
|
55 |
+
CATEGORIES = ("bagel", "cable_gland", "carrot", "cookie", "dowel", "foam", "peach", "potato", "rope", "tire")
|
56 |
+
|
57 |
+
|
58 |
+
def make_mvtec_3d_dataset(
|
59 |
+
root: str | Path,
|
60 |
+
split: str | Split | None = None,
|
61 |
+
extensions: Sequence[str] | None = None,
|
62 |
+
) -> DataFrame:
|
63 |
+
"""Create MVTec 3D-AD samples by parsing the MVTec AD data file structure.
|
64 |
+
|
65 |
+
The files are expected to follow this structure:
|
66 |
+
- `path/to/dataset/split/category/image_filename.png`
|
67 |
+
- `path/to/dataset/ground_truth/category/mask_filename.png`
|
68 |
+
|
69 |
+
This function creates a DataFrame to store the parsed information. The DataFrame follows this format:
|
70 |
+
|
71 |
+
+---+---------------+-------+---------+---------------+---------------------------------------+-------------+
|
72 |
+
| | path | split | label | image_path | mask_path | label_index |
|
73 |
+
+---+---------------+-------+---------+---------------+---------------------------------------+-------------+
|
74 |
+
| 0 | datasets/name | test | defect | filename.png | ground_truth/defect/filename_mask.png | 1 |
|
75 |
+
+---+---------------+-------+---------+---------------+---------------------------------------+-------------+
|
76 |
+
|
77 |
+
Args:
|
78 |
+
root (Path): Path to the dataset.
|
79 |
+
split (str | Split | None, optional): Dataset split (e.g., 'train' or 'test').
|
80 |
+
Defaults to ``None``.
|
81 |
+
extensions (Sequence[str] | None, optional): List of file extensions to be included in the dataset.
|
82 |
+
Defaults to ``None``.
|
83 |
+
|
84 |
+
Examples:
|
85 |
+
The following example shows how to get training samples from the MVTec 3D-AD 'bagel' category:
|
86 |
+
|
87 |
+
>>> from pathlib import Path
|
88 |
+
>>> root = Path('./MVTec3D')
|
89 |
+
>>> category = 'bagel'
|
90 |
+
>>> path = root / category
|
91 |
+
>>> print(path)
|
92 |
+
PosixPath('MVTec3D/bagel')
|
93 |
+
|
94 |
+
>>> samples = make_mvtec_3d_dataset(path, split='train')
|
95 |
+
>>> print(samples.head())
|
96 |
+
path split label image_path mask_path label_index
|
97 |
+
MVTec3D/bagel train good MVTec3D/bagel/train/good/rgb/105.png MVTec3D/bagel/ground_truth/good/gt/105.png 0
|
98 |
+
MVTec3D/bagel train good MVTec3D/bagel/train/good/rgb/017.png MVTec3D/bagel/ground_truth/good/gt/017.png 0
|
99 |
+
|
100 |
+
Returns:
|
101 |
+
DataFrame: An output DataFrame containing the samples of the dataset.
|
102 |
+
"""
|
103 |
+
if extensions is None:
|
104 |
+
extensions = IMG_EXTENSIONS
|
105 |
+
|
106 |
+
root = validate_path(root)
|
107 |
+
samples_list = [(str(root),) + f.parts[-4:] for f in root.glob(r"**/*") if f.suffix in extensions]
|
108 |
+
if not samples_list:
|
109 |
+
msg = f"Found 0 images in {root}"
|
110 |
+
raise RuntimeError(msg)
|
111 |
+
|
112 |
+
samples = DataFrame(samples_list, columns=["path", "split", "label", "type", "file_name"])
|
113 |
+
|
114 |
+
# Modify image_path column by converting to absolute path
|
115 |
+
samples.loc[(samples.type == "rgb"), "image_path"] = (
|
116 |
+
samples.path + "/" + samples.split + "/" + samples.label + "/" + "rgb/" + samples.file_name
|
117 |
+
)
|
118 |
+
samples.loc[(samples.type == "rgb"), "depth_path"] = (
|
119 |
+
samples.path
|
120 |
+
+ "/"
|
121 |
+
+ samples.split
|
122 |
+
+ "/"
|
123 |
+
+ samples.label
|
124 |
+
+ "/"
|
125 |
+
+ "xyz/"
|
126 |
+
+ samples.file_name.str.split(".").str[0]
|
127 |
+
+ ".tiff"
|
128 |
+
)
|
129 |
+
|
130 |
+
# Create label index for normal (0) and anomalous (1) images.
|
131 |
+
samples.loc[(samples.label == "good"), "label_index"] = LabelName.NORMAL
|
132 |
+
samples.loc[(samples.label != "good"), "label_index"] = LabelName.ABNORMAL
|
133 |
+
samples.label_index = samples.label_index.astype(int)
|
134 |
+
|
135 |
+
# separate masks from samples
|
136 |
+
mask_samples = samples.loc[((samples.split == "test") & (samples.type == "rgb"))].sort_values(
|
137 |
+
by="image_path",
|
138 |
+
ignore_index=True,
|
139 |
+
)
|
140 |
+
samples = samples.sort_values(by="image_path", ignore_index=True)
|
141 |
+
|
142 |
+
# assign mask paths to all test images
|
143 |
+
samples.loc[((samples.split == "test") & (samples.type == "rgb")), "mask_path"] = (
|
144 |
+
mask_samples.path + "/" + samples.split + "/" + samples.label + "/" + "gt/" + samples.file_name
|
145 |
+
)
|
146 |
+
samples = samples.dropna(subset=["image_path"])
|
147 |
+
samples = samples.astype({"image_path": "str", "mask_path": "str", "depth_path": "str"})
|
148 |
+
|
149 |
+
# assert that the right mask files are associated with the right test images
|
150 |
+
mismatch_masks = (
|
151 |
+
samples.loc[samples.label_index == LabelName.ABNORMAL]
|
152 |
+
.apply(lambda x: Path(x.image_path).stem in Path(x.mask_path).stem, axis=1)
|
153 |
+
.all()
|
154 |
+
)
|
155 |
+
if not mismatch_masks:
|
156 |
+
msg = """Mismatch between anomalous images and ground truth masks. Make sure the mask files
|
157 |
+
in 'ground_truth' folder follow the same naming convention as the anomalous images in
|
158 |
+
the dataset (e.g. image: '000.png', mask: '000.png' or '000_mask.png')."""
|
159 |
+
raise MisMatchError(msg)
|
160 |
+
|
161 |
+
mismatch_depth = (
|
162 |
+
samples.loc[samples.label_index == LabelName.ABNORMAL]
|
163 |
+
.apply(lambda x: Path(x.image_path).stem in Path(x.depth_path).stem, axis=1)
|
164 |
+
.all()
|
165 |
+
)
|
166 |
+
if not mismatch_depth:
|
167 |
+
msg = """Mismatch between anomalous images and depth images. Make sure the mask files in
|
168 |
+
'xyz' folder follow the same naming convention as the anomalous images in the dataset
|
169 |
+
(e.g. image: '000.png', depth: '000.tiff')."""
|
170 |
+
raise MisMatchError(msg)
|
171 |
+
|
172 |
+
if split:
|
173 |
+
samples = samples[samples.split == split].reset_index(drop=True)
|
174 |
+
|
175 |
+
return samples
|
176 |
+
|
177 |
+
|
178 |
+
class MVTec3DDataset(AnomalibDepthDataset):
|
179 |
+
"""MVTec 3D dataset class.
|
180 |
+
|
181 |
+
Args:
|
182 |
+
task (TaskType): Task type, ``classification``, ``detection`` or ``segmentation``
|
183 |
+
root (Path | str): Path to the root of the dataset
|
184 |
+
Defaults to ``"./datasets/MVTec3D"``.
|
185 |
+
category (str): Sub-category of the dataset, e.g. 'bagel'
|
186 |
+
Defaults to ``"bagel"``.
|
187 |
+
transform (Transform, optional): Transforms that should be applied to the input images.
|
188 |
+
Defaults to ``None``.
|
189 |
+
split (str | Split | None): Split of the dataset, usually Split.TRAIN or Split.TEST
|
190 |
+
Defaults to ``None``.
|
191 |
+
"""
|
192 |
+
|
193 |
+
def __init__(
|
194 |
+
self,
|
195 |
+
task: TaskType,
|
196 |
+
root: Path | str = "./datasets/MVTec3D",
|
197 |
+
category: str = "bagel",
|
198 |
+
transform: Transform | None = None,
|
199 |
+
split: str | Split | None = None,
|
200 |
+
) -> None:
|
201 |
+
super().__init__(task=task, transform=transform)
|
202 |
+
|
203 |
+
self.root_category = Path(root) / Path(category)
|
204 |
+
self.split = split
|
205 |
+
self.samples = make_mvtec_3d_dataset(self.root_category, split=self.split, extensions=IMG_EXTENSIONS)
|
206 |
+
|
207 |
+
|
208 |
+
class MVTec3D(AnomalibDataModule):
|
209 |
+
"""MVTec Datamodule.
|
210 |
+
|
211 |
+
Args:
|
212 |
+
root (Path | str): Path to the root of the dataset
|
213 |
+
Defaults to ``"./datasets/MVTec3D"``.
|
214 |
+
category (str): Category of the MVTec 3D-AD dataset (e.g. "bagel" or "cable_gland").
|
215 |
+
Defaults to ``bagel``.
|
216 |
+
train_batch_size (int, optional): Training batch size.
|
217 |
+
Defaults to ``32``.
|
218 |
+
eval_batch_size (int, optional): Test batch size.
|
219 |
+
Defaults to ``32``.
|
220 |
+
num_workers (int, optional): Number of workers.
|
221 |
+
Defaults to ``8``.
|
222 |
+
task (TaskType): Task type, 'classification', 'detection' or 'segmentation'
|
223 |
+
Defaults to ``TaskType.SEGMENTATION``.
|
224 |
+
image_size (tuple[int, int], optional): Size to which input images should be resized.
|
225 |
+
Defaults to ``None``.
|
226 |
+
transform (Transform, optional): Transforms that should be applied to the input images.
|
227 |
+
Defaults to ``None``.
|
228 |
+
train_transform (Transform, optional): Transforms that should be applied to the input images during training.
|
229 |
+
Defaults to ``None``.
|
230 |
+
eval_transform (Transform, optional): Transforms that should be applied to the input images during evaluation.
|
231 |
+
Defaults to ``None``.
|
232 |
+
test_split_mode (TestSplitMode): Setting that determines how the testing subset is obtained.
|
233 |
+
Defaults to ``TestSplitMode.FROM_DIR``.
|
234 |
+
test_split_ratio (float): Fraction of images from the train set that will be reserved for testing.
|
235 |
+
Defaults to ``0.2``.
|
236 |
+
val_split_mode (ValSplitMode): Setting that determines how the validation subset is obtained.
|
237 |
+
Defaults to ``ValSplitMode.SAME_AS_TEST``.
|
238 |
+
val_split_ratio (float): Fraction of train or test images that will be reserved for validation.
|
239 |
+
Defaults to ``0.5``.
|
240 |
+
seed (int | None, optional): Seed which may be set to a fixed value for reproducibility.
|
241 |
+
Defaults to ``None``.
|
242 |
+
"""
|
243 |
+
|
244 |
+
def __init__(
|
245 |
+
self,
|
246 |
+
root: Path | str = "./datasets/MVTec3D",
|
247 |
+
category: str = "bagel",
|
248 |
+
train_batch_size: int = 32,
|
249 |
+
eval_batch_size: int = 32,
|
250 |
+
num_workers: int = 8,
|
251 |
+
task: TaskType | str = TaskType.SEGMENTATION,
|
252 |
+
image_size: tuple[int, int] | None = None,
|
253 |
+
transform: Transform | None = None,
|
254 |
+
train_transform: Transform | None = None,
|
255 |
+
eval_transform: Transform | None = None,
|
256 |
+
test_split_mode: TestSplitMode | str = TestSplitMode.FROM_DIR,
|
257 |
+
test_split_ratio: float = 0.2,
|
258 |
+
val_split_mode: ValSplitMode | str = ValSplitMode.SAME_AS_TEST,
|
259 |
+
val_split_ratio: float = 0.5,
|
260 |
+
seed: int | None = None,
|
261 |
+
) -> None:
|
262 |
+
super().__init__(
|
263 |
+
train_batch_size=train_batch_size,
|
264 |
+
eval_batch_size=eval_batch_size,
|
265 |
+
num_workers=num_workers,
|
266 |
+
image_size=image_size,
|
267 |
+
transform=transform,
|
268 |
+
train_transform=train_transform,
|
269 |
+
eval_transform=eval_transform,
|
270 |
+
test_split_mode=test_split_mode,
|
271 |
+
test_split_ratio=test_split_ratio,
|
272 |
+
val_split_mode=val_split_mode,
|
273 |
+
val_split_ratio=val_split_ratio,
|
274 |
+
seed=seed,
|
275 |
+
)
|
276 |
+
|
277 |
+
self.task = TaskType(task)
|
278 |
+
self.root = Path(root)
|
279 |
+
self.category = category
|
280 |
+
|
281 |
+
def _setup(self, _stage: str | None = None) -> None:
|
282 |
+
self.train_data = MVTec3DDataset(
|
283 |
+
task=self.task,
|
284 |
+
transform=self.train_transform,
|
285 |
+
split=Split.TRAIN,
|
286 |
+
root=self.root,
|
287 |
+
category=self.category,
|
288 |
+
)
|
289 |
+
self.test_data = MVTec3DDataset(
|
290 |
+
task=self.task,
|
291 |
+
transform=self.eval_transform,
|
292 |
+
split=Split.TEST,
|
293 |
+
root=self.root,
|
294 |
+
category=self.category,
|
295 |
+
)
|
296 |
+
|
297 |
+
def prepare_data(self) -> None:
|
298 |
+
"""Download the dataset if not available."""
|
299 |
+
if (self.root / self.category).is_dir():
|
300 |
+
logger.info("Found the dataset.")
|
301 |
+
else:
|
302 |
+
download_and_extract(self.root, DOWNLOAD_INFO)
|
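A brief usage sketch for the MVTec3D datamodule above (category and root are examples; `prepare_data` triggers the download described by DOWNLOAD_INFO on first use):

# Illustrative sketch: fetching a training batch from MVTec 3D-AD.
from anomalib import TaskType
from anomalib.data.depth.mvtec_3d import MVTec3D

datamodule = MVTec3D(root="./datasets/MVTec3D", category="bagel", task=TaskType.SEGMENTATION)
datamodule.prepare_data()  # downloads and extracts the dataset if it is not on disk
datamodule.setup()
batch = next(iter(datamodule.train_dataloader()))
print(batch["image"].shape, batch["depth_image"].shape)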
anomalib/data/errors.py
ADDED
@@ -0,0 +1,19 @@
+"""Custom Exception Class for Mismatch Detection (MisMatchError)."""
+
+# Copyright (C) 2024 Intel Corporation
+# SPDX-License-Identifier: Apache-2.0
+
+
+class MisMatchError(Exception):
+    """Exception raised when a mismatch is detected.
+
+    Attributes:
+        message (str): Explanation of the error.
+    """
+
+    def __init__(self, message: str = "") -> None:
+        if message:
+            self.message = message
+        else:
+            self.message = "Mismatch detected."
+        super().__init__(self.message)
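A tiny sketch of how MisMatchError behaves with and without a custom message (illustrative only):

# Illustrative sketch: default vs. custom MisMatchError message.
from anomalib.data.errors import MisMatchError

try:
    raise MisMatchError
except MisMatchError as err:
    print(err)  # Mismatch detected.

try:
    raise MisMatchError("Image and mask filenames do not match.")
except MisMatchError as err:
    print(err)  # Image and mask filenames do not match.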
anomalib/data/image/__init__.py
ADDED
@@ -0,0 +1,33 @@
+"""Anomalib Image Datasets.
+
+This module contains the supported image datasets for Anomalib.
+"""
+
+# Copyright (C) 2024 Intel Corporation
+# SPDX-License-Identifier: Apache-2.0
+
+
+from enum import Enum
+
+from .btech import BTech
+from .folder import Folder
+from .kolektor import Kolektor
+from .mvtec import MVTec
+from .mvtec_loco import MVTecLoco
+from .visa import Visa
+
+
+class ImageDataFormat(str, Enum):
+    """Supported Image Dataset Types."""
+
+    MVTEC = "mvtec"
+    MVTEC_3D = "mvtec_3d"
+    MVTEC_LOCO = "mvtec_loco"
+    BTECH = "btech"
+    KOLEKTOR = "kolektor"
+    FOLDER = "folder"
+    FOLDER_3D = "folder_3d"
+    VISA = "visa"
+
+
+__all__ = ["BTech", "Folder", "Kolektor", "MVTec", "MVTecLoco", "Visa"]
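A short sketch of consuming this module's public API; the dataset root is an assumption:

# Illustrative sketch: enumerating supported image dataset formats and picking a datamodule.
from anomalib.data.image import ImageDataFormat, MVTec

print([fmt.value for fmt in ImageDataFormat])  # mvtec, mvtec_3d, mvtec_loco, btech, kolektor, folder, folder_3d, visa
datamodule = MVTec(root="./datasets/MVTec", category="bottle")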
anomalib/data/image/btech.py
ADDED
@@ -0,0 +1,362 @@
1 |
+
"""BTech Dataset.
|
2 |
+
|
3 |
+
This script contains PyTorch Lightning DataModule for the BTech dataset.
|
4 |
+
|
5 |
+
If the dataset is not on the file system, the script downloads and
|
6 |
+
extracts the dataset and creates PyTorch data objects.
|
7 |
+
"""
|
8 |
+
|
9 |
+
# Copyright (C) 2022-2024 Intel Corporation
|
10 |
+
# SPDX-License-Identifier: Apache-2.0
|
11 |
+
|
12 |
+
import logging
|
13 |
+
import shutil
|
14 |
+
from pathlib import Path
|
15 |
+
|
16 |
+
import cv2
|
17 |
+
import pandas as pd
|
18 |
+
from pandas.core.frame import DataFrame
|
19 |
+
from torchvision.transforms.v2 import Transform
|
20 |
+
from tqdm import tqdm
|
21 |
+
|
22 |
+
from anomalib import TaskType
|
23 |
+
from anomalib.data.base import AnomalibDataModule, AnomalibDataset
|
24 |
+
from anomalib.data.utils import (
|
25 |
+
DownloadInfo,
|
26 |
+
LabelName,
|
27 |
+
Split,
|
28 |
+
TestSplitMode,
|
29 |
+
ValSplitMode,
|
30 |
+
download_and_extract,
|
31 |
+
validate_path,
|
32 |
+
)
|
33 |
+
|
34 |
+
logger = logging.getLogger(__name__)
|
35 |
+
|
36 |
+
DOWNLOAD_INFO = DownloadInfo(
|
37 |
+
name="btech",
|
38 |
+
url="https://avires.dimi.uniud.it/papers/btad/btad.zip",
|
39 |
+
hashsum="461c9387e515bfed41ecaae07c50cf6b10def647b36c9e31d239ab2736b10d2a",
|
40 |
+
)
|
41 |
+
|
42 |
+
CATEGORIES = ("01", "02", "03")
|
43 |
+
|
44 |
+
|
45 |
+
def make_btech_dataset(path: Path, split: str | Split | None = None) -> DataFrame:
|
46 |
+
"""Create BTech samples by parsing the BTech data file structure.
|
47 |
+
|
48 |
+
The files are expected to follow the structure:
|
49 |
+
|
50 |
+
.. code-block:: bash
|
51 |
+
|
52 |
+
path/to/dataset/split/category/image_filename.png
|
53 |
+
path/to/dataset/ground_truth/category/mask_filename.png
|
54 |
+
|
55 |
+
Args:
|
56 |
+
path (Path): Path to dataset
|
57 |
+
split (str | Split | None, optional): Dataset split (ie., either train or test).
|
58 |
+
Defaults to ``None``.
|
59 |
+
|
60 |
+
Example:
|
61 |
+
The following example shows how to get training samples from BTech 01 category:
|
62 |
+
|
63 |
+
.. code-block:: python
|
64 |
+
|
65 |
+
>>> root = Path('./BTech')
|
66 |
+
>>> category = '01'
|
67 |
+
>>> path = root / category
|
68 |
+
>>> path
|
69 |
+
PosixPath('BTech/01')
|
70 |
+
|
71 |
+
>>> samples = make_btech_dataset(path, split='train')
|
72 |
+
>>> samples.head()
|
73 |
+
path split label image_path mask_path label_index
|
74 |
+
0 BTech/01 train 01 BTech/01/train/ok/105.bmp BTech/01/ground_truth/ok/105.png 0
|
75 |
+
1 BTech/01 train 01 BTech/01/train/ok/017.bmp BTech/01/ground_truth/ok/017.png 0
|
76 |
+
...
|
77 |
+
|
78 |
+
Returns:
|
79 |
+
DataFrame: an output dataframe containing samples for the requested split (ie., train or test)
|
80 |
+
"""
|
81 |
+
path = validate_path(path)
|
82 |
+
|
83 |
+
samples_list = [
|
84 |
+
(str(path),) + filename.parts[-3:] for filename in path.glob("**/*") if filename.suffix in (".bmp", ".png")
|
85 |
+
]
|
86 |
+
if not samples_list:
|
87 |
+
msg = f"Found 0 images in {path}"
|
88 |
+
raise RuntimeError(msg)
|
89 |
+
|
90 |
+
samples = pd.DataFrame(samples_list, columns=["path", "split", "label", "image_path"])
|
91 |
+
samples = samples[samples.split != "ground_truth"]
|
92 |
+
|
93 |
+
# Create mask_path column
|
94 |
+
# (safely handles cases where non-mask image_paths end with either .png or .bmp)
|
95 |
+
samples["mask_path"] = (
|
96 |
+
samples.path
|
97 |
+
+ "/ground_truth/"
|
98 |
+
+ samples.label
|
99 |
+
+ "/"
|
100 |
+
+ samples.image_path.str.rstrip("png").str.rstrip(".").str.rstrip("bmp").str.rstrip(".")
|
101 |
+
+ ".png"
|
102 |
+
)
|
103 |
+
|
104 |
+
# Modify image_path column by converting to absolute path
|
105 |
+
samples["image_path"] = samples.path + "/" + samples.split + "/" + samples.label + "/" + samples.image_path
|
106 |
+
|
107 |
+
# Good images don't have mask
|
108 |
+
samples.loc[(samples.split == "test") & (samples.label == "ok"), "mask_path"] = ""
|
109 |
+
|
110 |
+
# Create label index for normal (0) and anomalous (1) images.
|
111 |
+
samples.loc[(samples.label == "ok"), "label_index"] = LabelName.NORMAL
|
112 |
+
samples.loc[(samples.label != "ok"), "label_index"] = LabelName.ABNORMAL
|
113 |
+
samples.label_index = samples.label_index.astype(int)
|
114 |
+
|
115 |
+
# Get the data frame for the split.
|
116 |
+
if split:
|
117 |
+
samples = samples[samples.split == split]
|
118 |
+
samples = samples.reset_index(drop=True)
|
119 |
+
|
120 |
+
return samples
|
121 |
+
|
122 |
+
|
123 |
+
class BTechDataset(AnomalibDataset):
|
124 |
+
"""Btech Dataset class.
|
125 |
+
|
126 |
+
Args:
|
127 |
+
root: Path to the BTech dataset
|
128 |
+
category: Name of the BTech category.
|
129 |
+
transform (Transform, optional): Transforms that should be applied to the input images.
|
130 |
+
Defaults to ``None``.
|
131 |
+
split: 'train', 'val' or 'test'
|
132 |
+
task: ``classification``, ``detection`` or ``segmentation``
|
133 |
+
create_validation_set: Create a validation subset in addition to the train and test subsets
|
134 |
+
|
135 |
+
Examples:
|
136 |
+
>>> from anomalib.data.image.btech import BTechDataset
|
137 |
+
>>> from anomalib.data.utils.transforms import get_transforms
|
138 |
+
>>> transform = get_transforms(image_size=256)
|
139 |
+
>>> dataset = BTechDataset(
|
140 |
+
... task="classification",
|
141 |
+
... transform=transform,
|
142 |
+
... root='./datasets/BTech',
|
143 |
+
... category='01',
|
144 |
+
... )
|
145 |
+
>>> dataset.setup()
|
146 |
+
>>> dataset[0].keys()
|
147 |
+
dict_keys(['image'])
|
148 |
+
|
149 |
+
>>> dataset.split = "test"
|
150 |
+
>>> dataset[0].keys()
|
151 |
+
dict_keys(['image', 'image_path', 'label'])
|
152 |
+
|
153 |
+
>>> dataset.task = "segmentation"
|
154 |
+
>>> dataset.split = "train"
|
155 |
+
>>> dataset[0].keys()
|
156 |
+
dict_keys(['image'])
|
157 |
+
|
158 |
+
>>> dataset.split = "test"
|
159 |
+
>>> dataset[0].keys()
|
160 |
+
dict_keys(['image_path', 'label', 'mask_path', 'image', 'mask'])
|
161 |
+
|
162 |
+
>>> dataset[0]["image"].shape, dataset[0]["mask"].shape
|
163 |
+
(torch.Size([3, 256, 256]), torch.Size([256, 256]))
|
164 |
+
"""
|
165 |
+
|
166 |
+
def __init__(
|
167 |
+
self,
|
168 |
+
root: str | Path,
|
169 |
+
category: str,
|
170 |
+
transform: Transform | None = None,
|
171 |
+
split: str | Split | None = None,
|
172 |
+
task: TaskType | str = TaskType.SEGMENTATION,
|
173 |
+
) -> None:
|
174 |
+
super().__init__(task, transform)
|
175 |
+
|
176 |
+
self.root_category = Path(root) / category
|
177 |
+
self.split = split
|
178 |
+
self.samples = make_btech_dataset(path=self.root_category, split=self.split)
|
179 |
+
|
180 |
+
|
181 |
+
class BTech(AnomalibDataModule):
|
182 |
+
"""BTech Lightning Data Module.
|
183 |
+
|
184 |
+
Args:
|
185 |
+
root (Path | str): Path to the BTech dataset.
|
186 |
+
Defaults to ``"./datasets/BTech"``.
|
187 |
+
category (str): Name of the BTech category.
|
188 |
+
Defaults to ``"01"``.
|
189 |
+
train_batch_size (int, optional): Training batch size.
|
190 |
+
Defaults to ``32``.
|
191 |
+
eval_batch_size (int, optional): Eval batch size.
|
192 |
+
Defaults to ``32``.
|
193 |
+
num_workers (int, optional): Number of workers.
|
194 |
+
Defaults to ``8``.
|
195 |
+
task (TaskType, optional): Task type.
|
196 |
+
Defaults to ``TaskType.SEGMENTATION``.
|
197 |
+
image_size (tuple[int, int], optional): Size to which input images should be resized.
|
198 |
+
Defaults to ``None``.
|
199 |
+
transform (Transform, optional): Transforms that should be applied to the input images.
|
200 |
+
Defaults to ``None``.
|
201 |
+
train_transform (Transform, optional): Transforms that should be applied to the input images during training.
|
202 |
+
Defaults to ``None``.
|
203 |
+
eval_transform (Transform, optional): Transforms that should be applied to the input images during evaluation.
|
204 |
+
Defaults to ``None``.
|
205 |
+
test_split_mode (TestSplitMode, optional): Setting that determines how the testing subset is obtained.
|
206 |
+
Defaults to ``TestSplitMode.FROM_DIR``.
|
207 |
+
test_split_ratio (float, optional): Fraction of images from the train set that will be reserved for testing.
|
208 |
+
Defaults to ``0.2``.
|
209 |
+
val_split_mode (ValSplitMode, optional): Setting that determines how the validation subset is obtained.
|
210 |
+
Defaults to ``ValSplitMode.SAME_AS_TEST``.
|
211 |
+
val_split_ratio (float, optional): Fraction of train or test images that will be reserved for validation.
|
212 |
+
Defaults to ``0.5``.
|
213 |
+
seed (int | None, optional): Seed which may be set to a fixed value for reproducibility.
|
214 |
+
Defaults to ``None``.
|
215 |
+
|
216 |
+
Examples:
|
217 |
+
To create the BTech datamodule, we need to instantiate the class, and call the ``setup`` method.
|
218 |
+
|
219 |
+
>>> from anomalib.data import BTech
|
220 |
+
>>> datamodule = BTech(
|
221 |
+
... root="./datasets/BTech",
|
222 |
+
... category="01",
|
223 |
+
... image_size=256,
|
224 |
+
... train_batch_size=32,
|
225 |
+
... eval_batch_size=32,
|
226 |
+
... num_workers=8,
|
227 |
+
... transform_config_train=None,
|
228 |
+
... transform_config_eval=None,
|
229 |
+
... )
|
230 |
+
>>> datamodule.setup()
|
231 |
+
|
232 |
+
To get the train dataloader and the first batch of data:
|
233 |
+
|
234 |
+
>>> i, data = next(enumerate(datamodule.train_dataloader()))
|
235 |
+
>>> data.keys()
|
236 |
+
dict_keys(['image'])
|
237 |
+
>>> data["image"].shape
|
238 |
+
torch.Size([32, 3, 256, 256])
|
239 |
+
|
240 |
+
To access the validation dataloader and the first batch of data:
|
241 |
+
|
242 |
+
>>> i, data = next(enumerate(datamodule.val_dataloader()))
|
243 |
+
>>> data.keys()
|
244 |
+
dict_keys(['image_path', 'label', 'mask_path', 'image', 'mask'])
|
245 |
+
>>> data["image"].shape, data["mask"].shape
|
246 |
+
(torch.Size([32, 3, 256, 256]), torch.Size([32, 256, 256]))
|
247 |
+
|
248 |
+
Similarly, to access the test dataloader and the first batch of data:
|
249 |
+
|
250 |
+
>>> i, data = next(enumerate(datamodule.test_dataloader()))
|
251 |
+
>>> data.keys()
|
252 |
+
dict_keys(['image_path', 'label', 'mask_path', 'image', 'mask'])
|
253 |
+
>>> data["image"].shape, data["mask"].shape
|
254 |
+
(torch.Size([32, 3, 256, 256]), torch.Size([32, 256, 256]))
|
255 |
+
"""
|
256 |
+
|
257 |
+
def __init__(
|
258 |
+
self,
|
259 |
+
root: Path | str = "./datasets/BTech",
|
260 |
+
category: str = "01",
|
261 |
+
train_batch_size: int = 32,
|
262 |
+
eval_batch_size: int = 32,
|
263 |
+
num_workers: int = 8,
|
264 |
+
task: TaskType | str = TaskType.SEGMENTATION,
|
265 |
+
image_size: tuple[int, int] | None = None,
|
266 |
+
transform: Transform | None = None,
|
267 |
+
train_transform: Transform | None = None,
|
268 |
+
eval_transform: Transform | None = None,
|
269 |
+
test_split_mode: TestSplitMode | str = TestSplitMode.FROM_DIR,
|
270 |
+
test_split_ratio: float = 0.2,
|
271 |
+
val_split_mode: ValSplitMode | str = ValSplitMode.SAME_AS_TEST,
|
272 |
+
val_split_ratio: float = 0.5,
|
273 |
+
seed: int | None = None,
|
274 |
+
) -> None:
|
275 |
+
super().__init__(
|
276 |
+
train_batch_size=train_batch_size,
|
277 |
+
eval_batch_size=eval_batch_size,
|
278 |
+
num_workers=num_workers,
|
279 |
+
image_size=image_size,
|
280 |
+
transform=transform,
|
281 |
+
train_transform=train_transform,
|
282 |
+
eval_transform=eval_transform,
|
283 |
+
test_split_mode=test_split_mode,
|
284 |
+
test_split_ratio=test_split_ratio,
|
285 |
+
val_split_mode=val_split_mode,
|
286 |
+
val_split_ratio=val_split_ratio,
|
287 |
+
seed=seed,
|
288 |
+
)
|
289 |
+
|
290 |
+
self.root = Path(root)
|
291 |
+
self.category = category
|
292 |
+
self.task = TaskType(task)
|
293 |
+
|
294 |
+
def _setup(self, _stage: str | None = None) -> None:
|
295 |
+
self.train_data = BTechDataset(
|
296 |
+
task=self.task,
|
297 |
+
transform=self.train_transform,
|
298 |
+
split=Split.TRAIN,
|
299 |
+
root=self.root,
|
300 |
+
category=self.category,
|
301 |
+
)
|
302 |
+
self.test_data = BTechDataset(
|
303 |
+
task=self.task,
|
304 |
+
transform=self.eval_transform,
|
305 |
+
split=Split.TEST,
|
306 |
+
root=self.root,
|
307 |
+
category=self.category,
|
308 |
+
)
|
309 |
+
|
310 |
+
def prepare_data(self) -> None:
|
311 |
+
"""Download the dataset if not available.
|
312 |
+
|
313 |
+
This method checks if the specified dataset is available in the file system.
|
314 |
+
If not, it downloads and extracts the dataset into the appropriate directory.
|
315 |
+
|
316 |
+
Example:
|
317 |
+
Assume the dataset is not available on the file system.
|
318 |
+
Here's how the directory structure looks before and after calling the
|
319 |
+
`prepare_data` method:
|
320 |
+
|
321 |
+
Before:
|
322 |
+
|
323 |
+
.. code-block:: bash
|
324 |
+
|
325 |
+
$ tree datasets
|
326 |
+
datasets
|
327 |
+
├── dataset1
|
328 |
+
└── dataset2
|
329 |
+
|
330 |
+
Calling the method:
|
331 |
+
|
332 |
+
.. code-block:: python
|
333 |
+
|
334 |
+
>> datamodule = BTech(root="./datasets/BTech", category="01")
|
335 |
+
>> datamodule.prepare_data()
|
336 |
+
|
337 |
+
After:
|
338 |
+
|
339 |
+
.. code-block:: bash
|
340 |
+
|
341 |
+
$ tree datasets
|
342 |
+
datasets
|
343 |
+
├── dataset1
|
344 |
+
├── dataset2
|
345 |
+
└── BTech
|
346 |
+
├── 01
|
347 |
+
├── 02
|
348 |
+
└── 03
|
349 |
+
"""
|
350 |
+
if (self.root / self.category).is_dir():
|
351 |
+
logger.info("Found the dataset.")
|
352 |
+
else:
|
353 |
+
download_and_extract(self.root.parent, DOWNLOAD_INFO)
|
354 |
+
|
355 |
+
# rename folder and convert images
|
356 |
+
logger.info("Renaming the dataset directory")
|
357 |
+
shutil.move(src=str(self.root.parent / "BTech_Dataset_transformed"), dst=str(self.root))
|
358 |
+
logger.info("Convert the bmp formats to png to have consistent image extensions")
|
359 |
+
for filename in tqdm(self.root.glob("**/*.bmp"), desc="Converting bmp to png"):
|
360 |
+
image = cv2.imread(str(filename))
|
361 |
+
cv2.imwrite(str(filename.with_suffix(".png")), image)
|
362 |
+
filename.unlink()
|
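The chained ``str.rstrip`` calls in ``make_btech_dataset`` above are what make the mask-path construction tolerant of both ``.bmp`` and ``.png`` image names. A minimal sketch of the same idea on plain strings (the filenames are illustrative, not taken from the dataset):

# Sketch only: mirrors the suffix handling used in make_btech_dataset above.
# rstrip() strips a *character set* from the end, so the rstrip("png")/rstrip(".")
# and rstrip("bmp")/rstrip(".") pairs remove either extension before ".png" is appended.
for name in ("017.bmp", "017.png"):
    stem = name.rstrip("png").rstrip(".").rstrip("bmp").rstrip(".")
    print(name, "->", stem + ".png")  # both print "017 -> 017.png" style output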
anomalib/data/image/folder.py
ADDED
@@ -0,0 +1,478 @@
"""Custom Folder Dataset.

This script creates a custom dataset from a folder.
"""

# Copyright (C) 2022-2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

from collections.abc import Sequence
from pathlib import Path

from pandas import DataFrame
from torchvision.transforms.v2 import Transform

from anomalib import TaskType
from anomalib.data.base import AnomalibDataModule, AnomalibDataset
from anomalib.data.errors import MisMatchError
from anomalib.data.utils import (
    DirType,
    LabelName,
    Split,
    TestSplitMode,
    ValSplitMode,
)
from anomalib.data.utils.path import _prepare_files_labels, validate_and_resolve_path


def make_folder_dataset(
    normal_dir: str | Path | Sequence[str | Path],
    root: str | Path | None = None,
    abnormal_dir: str | Path | Sequence[str | Path] | None = None,
    normal_test_dir: str | Path | Sequence[str | Path] | None = None,
    mask_dir: str | Path | Sequence[str | Path] | None = None,
    split: str | Split | None = None,
    extensions: tuple[str, ...] | None = None,
) -> DataFrame:
    """Make Folder Dataset.

    Args:
        normal_dir (str | Path | Sequence): Path to the directory containing normal images.
        root (str | Path | None): Path to the root directory of the dataset.
            Defaults to ``None``.
        abnormal_dir (str | Path | Sequence | None, optional): Path to the directory containing abnormal images.
            Defaults to ``None``.
        normal_test_dir (str | Path | Sequence | None, optional): Path to the directory containing normal images for
            the test dataset. Normal test images will be a split of `normal_dir` if `None`.
            Defaults to ``None``.
        mask_dir (str | Path | Sequence | None, optional): Path to the directory containing the mask annotations.
            Defaults to ``None``.
        split (str | Split | None, optional): Dataset split (ie., Split.FULL, Split.TRAIN or Split.TEST).
            Defaults to ``None``.
        extensions (tuple[str, ...] | None, optional): Type of the image extensions to read from the directory.
            Defaults to ``None``.

    Returns:
        DataFrame: an output dataframe containing samples for the requested split (ie., train or test).

    Examples:
        Assume that we would like to use this ``make_folder_dataset`` to create a dataset from a folder.
        We could then create the dataset as follows,

        .. code-block:: python

            folder_df = make_folder_dataset(
                normal_dir=dataset_root / "good",
                abnormal_dir=dataset_root / "crack",
                split="train",
            )
            folder_df.head()

        .. code-block:: bash

                      image_path           label  label_index mask_path        split
            0  ./toy/good/00.jpg  DirType.NORMAL            0            Split.TRAIN
            1  ./toy/good/01.jpg  DirType.NORMAL            0            Split.TRAIN
            2  ./toy/good/02.jpg  DirType.NORMAL            0            Split.TRAIN
            3  ./toy/good/03.jpg  DirType.NORMAL            0            Split.TRAIN
            4  ./toy/good/04.jpg  DirType.NORMAL            0            Split.TRAIN
    """

    def _resolve_path_and_convert_to_list(path: str | Path | Sequence[str | Path] | None) -> list[Path]:
        """Convert path to list of paths.

        Args:
            path (str | Path | Sequence | None): Path to replace with Sequence[str | Path].

        Examples:
            >>> _resolve_path_and_convert_to_list("dir")
            [Path("path/to/dir")]
            >>> _resolve_path_and_convert_to_list(["dir1", "dir2"])
            [Path("path/to/dir1"), Path("path/to/dir2")]

        Returns:
            list[Path]: The result of path replaced by Sequence[str | Path].
        """
        if isinstance(path, Sequence) and not isinstance(path, str):
            return [validate_and_resolve_path(dir_path, root) for dir_path in path]
        return [validate_and_resolve_path(path, root)] if path is not None else []

    # All paths are changed to the List[Path] type and used.
    normal_dir = _resolve_path_and_convert_to_list(normal_dir)
    abnormal_dir = _resolve_path_and_convert_to_list(abnormal_dir)
    normal_test_dir = _resolve_path_and_convert_to_list(normal_test_dir)
    mask_dir = _resolve_path_and_convert_to_list(mask_dir)
    if len(normal_dir) == 0:
        msg = "A folder location must be provided in normal_dir."
        raise ValueError(msg)

    filenames = []
    labels = []
    dirs = {DirType.NORMAL: normal_dir}

    if abnormal_dir:
        dirs[DirType.ABNORMAL] = abnormal_dir

    if normal_test_dir:
        dirs[DirType.NORMAL_TEST] = normal_test_dir

    if mask_dir:
        dirs[DirType.MASK] = mask_dir

    for dir_type, paths in dirs.items():
        for path in paths:
            filename, label = _prepare_files_labels(path, dir_type, extensions)
            filenames += filename
            labels += label

    samples = DataFrame({"image_path": filenames, "label": labels})
    samples = samples.sort_values(by="image_path", ignore_index=True)

    # Create label index for normal (0) and abnormal (1) images.
    samples.loc[
        (samples.label == DirType.NORMAL) | (samples.label == DirType.NORMAL_TEST),
        "label_index",
    ] = LabelName.NORMAL
    samples.loc[(samples.label == DirType.ABNORMAL), "label_index"] = LabelName.ABNORMAL
    samples.label_index = samples.label_index.astype("Int64")

    # If a path to mask is provided, add it to the sample dataframe.

    if len(mask_dir) > 0 and len(abnormal_dir) > 0:
        samples.loc[samples.label == DirType.ABNORMAL, "mask_path"] = samples.loc[
            samples.label == DirType.MASK
        ].image_path.to_numpy()
        samples["mask_path"] = samples["mask_path"].fillna("")
        samples = samples.astype({"mask_path": "str"})

        # make sure every rgb image has a corresponding mask image.
        if not (
            samples.loc[samples.label_index == LabelName.ABNORMAL]
            .apply(lambda x: Path(x.image_path).stem in Path(x.mask_path).stem, axis=1)
            .all()
        ):
            msg = """Mismatch between anomalous images and mask images. Make sure the mask files "
            "folder follow the same naming convention as the anomalous images in the dataset "
            "(e.g. image: '000.png', mask: '000.png')."""
            raise MisMatchError(msg)

    else:
        samples["mask_path"] = ""

    # remove all the rows with temporal image samples that have already been assigned
    samples = samples.loc[
        (samples.label == DirType.NORMAL) | (samples.label == DirType.ABNORMAL) | (samples.label == DirType.NORMAL_TEST)
    ]

    # Ensure the pathlib objects are converted to str.
    # This is because torch dataloader doesn't like pathlib.
    samples = samples.astype({"image_path": "str"})

    # Create train/test split.
    # By default, all the normal samples are assigned as train.
    # and all the abnormal samples are test.
    samples.loc[(samples.label == DirType.NORMAL), "split"] = Split.TRAIN
    samples.loc[(samples.label == DirType.ABNORMAL) | (samples.label == DirType.NORMAL_TEST), "split"] = Split.TEST

    # Get the data frame for the split.
    if split:
        samples = samples[samples.split == split]
        samples = samples.reset_index(drop=True)

    return samples


class FolderDataset(AnomalibDataset):
    """Folder dataset.

    This class is used to create a dataset from a folder. The class utilizes the Torch Dataset class.

    Args:
        name (str): Name of the dataset. This is used to name the datamodule, especially when logging/saving.
        task (TaskType): Task type. (``classification``, ``detection`` or ``segmentation``).
        transform (Transform, optional): Transforms that should be applied to the input images.
            Defaults to ``None``.
        normal_dir (str | Path | Sequence): Path to the directory containing normal images.
        root (str | Path | None): Root folder of the dataset.
            Defaults to ``None``.
        abnormal_dir (str | Path | Sequence | None, optional): Path to the directory containing abnormal images.
            Defaults to ``None``.
        normal_test_dir (str | Path | Sequence | None, optional): Path to the directory containing
            normal images for the test dataset.
            Defaults to ``None``.
        mask_dir (str | Path | Sequence | None, optional): Path to the directory containing
            the mask annotations.
            Defaults to ``None``.
        split (str | Split | None): Fixed subset split that follows from folder structure on file system.
            Choose from [Split.FULL, Split.TRAIN, Split.TEST]
            Defaults to ``None``.
        extensions (tuple[str, ...] | None, optional): Type of the image extensions to read from the directory.
            Defaults to ``None``.

    Raises:
        ValueError: When task is set to classification and `mask_dir` is provided. When `mask_dir` is
            provided, `task` should be set to `segmentation`.

    Examples:
        Assume that we would like to use this ``FolderDataset`` to create a dataset from a folder for a classification
        task. We could first create the transforms,

        >>> from anomalib.data.utils import InputNormalizationMethod, get_transforms
        >>> transform = get_transforms(image_size=256, normalization=InputNormalizationMethod.NONE)

        We could then create the dataset as follows,

        .. code-block:: python

            folder_dataset_classification_train = FolderDataset(
                normal_dir=dataset_root / "good",
                abnormal_dir=dataset_root / "crack",
                split="train",
                transform=transform,
                task=TaskType.CLASSIFICATION,
            )

    """

    def __init__(
        self,
        name: str,
        task: TaskType,
        normal_dir: str | Path | Sequence[str | Path],
        transform: Transform | None = None,
        root: str | Path | None = None,
        abnormal_dir: str | Path | Sequence[str | Path] | None = None,
        normal_test_dir: str | Path | Sequence[str | Path] | None = None,
        mask_dir: str | Path | Sequence[str | Path] | None = None,
        split: str | Split | None = None,
        extensions: tuple[str, ...] | None = None,
    ) -> None:
        super().__init__(task, transform)

        self._name = name
        self.split = split
        self.root = root
        self.normal_dir = normal_dir
        self.abnormal_dir = abnormal_dir
        self.normal_test_dir = normal_test_dir
        self.mask_dir = mask_dir
        self.extensions = extensions

        self.samples = make_folder_dataset(
            root=self.root,
            normal_dir=self.normal_dir,
            abnormal_dir=self.abnormal_dir,
            normal_test_dir=self.normal_test_dir,
            mask_dir=self.mask_dir,
            split=self.split,
            extensions=self.extensions,
        )

    @property
    def name(self) -> str:
        """Name of the dataset.

        Folder dataset overrides the name property to provide a custom name.
        """
        return self._name


class Folder(AnomalibDataModule):
    """Folder DataModule.

    Args:
        name (str): Name of the dataset. This is used to name the datamodule, especially when logging/saving.
        normal_dir (str | Path | Sequence): Name of the directory containing normal images.
        root (str | Path | None): Path to the root folder containing normal and abnormal dirs.
            Defaults to ``None``.
        abnormal_dir (str | Path | None | Sequence): Name of the directory containing abnormal images.
            Defaults to ``None``.
        normal_test_dir (str | Path | Sequence | None, optional): Path to the directory containing
            normal images for the test dataset.
            Defaults to ``None``.
        mask_dir (str | Path | Sequence | None, optional): Path to the directory containing
            the mask annotations.
            Defaults to ``None``.
        normal_split_ratio (float, optional): Ratio to split normal training images and add to the
            test set in case test set doesn't contain any normal images.
            Defaults to 0.2.
        extensions (tuple[str, ...] | None, optional): Type of the image extensions to read from the
            directory.
            Defaults to ``None``.
        train_batch_size (int, optional): Training batch size.
            Defaults to ``32``.
        eval_batch_size (int, optional): Validation, test and predict batch size.
            Defaults to ``32``.
        num_workers (int, optional): Number of workers.
            Defaults to ``8``.
        task (TaskType, optional): Task type. Could be ``classification``, ``detection`` or ``segmentation``.
            Defaults to ``segmentation``.
        image_size (tuple[int, int], optional): Size to which input images should be resized.
            Defaults to ``None``.
        transform (Transform, optional): Transforms that should be applied to the input images.
            Defaults to ``None``.
        train_transform (Transform, optional): Transforms that should be applied to the input images during training.
            Defaults to ``None``.
        eval_transform (Transform, optional): Transforms that should be applied to the input images during evaluation.
            Defaults to ``None``.
        test_split_mode (TestSplitMode): Setting that determines how the testing subset is obtained.
            Defaults to ``TestSplitMode.FROM_DIR``.
        test_split_ratio (float): Fraction of images from the train set that will be reserved for testing.
            Defaults to ``0.2``.
        val_split_mode (ValSplitMode): Setting that determines how the validation subset is obtained.
            Defaults to ``ValSplitMode.FROM_TEST``.
        val_split_ratio (float): Fraction of train or test images that will be reserved for validation.
            Defaults to ``0.5``.
        seed (int | None, optional): Seed used during random subset splitting.
            Defaults to ``None``.

    Examples:
        The following code demonstrates how to use the ``Folder`` datamodule. Assume that the dataset is structured
        as follows:

        .. code-block:: bash

            $ tree sample_dataset
            sample_dataset
            ├── colour
            │   ├── 00.jpg
            │   ├── ...
            │   └── x.jpg
            ├── crack
            │   ├── 00.jpg
            │   ├── ...
            │   └── y.jpg
            ├── good
            │   ├── ...
            │   └── z.jpg
            ├── LICENSE
            └── mask
                ├── colour
                │   ├── ...
                │   └── x.jpg
                └── crack
                    ├── ...
                    └── y.jpg

        .. code-block:: python

            folder_datamodule = Folder(
                root=dataset_root,
                normal_dir="good",
                abnormal_dir="crack",
                task=TaskType.SEGMENTATION,
                mask_dir=dataset_root / "mask" / "crack",
                image_size=256,
                normalization=InputNormalizationMethod.NONE,
            )
            folder_datamodule.setup()

        To access the training images,

        .. code-block:: python

            >> i, data = next(enumerate(folder_datamodule.train_dataloader()))
            >> print(data.keys(), data["image"].shape)

        To access the test images,

        .. code-block:: python

            >> i, data = next(enumerate(folder_datamodule.test_dataloader()))
            >> print(data.keys(), data["image"].shape)
    """

    def __init__(
        self,
        name: str,
        normal_dir: str | Path | Sequence[str | Path],
        root: str | Path | None = None,
        abnormal_dir: str | Path | Sequence[str | Path] | None = None,
        normal_test_dir: str | Path | Sequence[str | Path] | None = None,
        mask_dir: str | Path | Sequence[str | Path] | None = None,
        normal_split_ratio: float = 0.2,
        extensions: tuple[str] | None = None,
        train_batch_size: int = 32,
        eval_batch_size: int = 32,
        num_workers: int = 8,
        task: TaskType | str = TaskType.SEGMENTATION,
        image_size: tuple[int, int] | None = None,
        transform: Transform | None = None,
        train_transform: Transform | None = None,
        eval_transform: Transform | None = None,
        test_split_mode: TestSplitMode | str = TestSplitMode.FROM_DIR,
        test_split_ratio: float = 0.2,
        val_split_mode: ValSplitMode | str = ValSplitMode.FROM_TEST,
        val_split_ratio: float = 0.5,
        seed: int | None = None,
    ) -> None:
        self._name = name
        self.root = root
        self.normal_dir = normal_dir
        self.abnormal_dir = abnormal_dir
        self.normal_test_dir = normal_test_dir
        self.mask_dir = mask_dir
        self.task = TaskType(task)
        self.extensions = extensions
        test_split_mode = TestSplitMode(test_split_mode)
        val_split_mode = ValSplitMode(val_split_mode)
        super().__init__(
            train_batch_size=train_batch_size,
            eval_batch_size=eval_batch_size,
            num_workers=num_workers,
            test_split_mode=test_split_mode,
            test_split_ratio=test_split_ratio,
            val_split_mode=val_split_mode,
            val_split_ratio=val_split_ratio,
            image_size=image_size,
            transform=transform,
            train_transform=train_transform,
            eval_transform=eval_transform,
            seed=seed,
        )

        if task == TaskType.SEGMENTATION and test_split_mode == TestSplitMode.FROM_DIR and mask_dir is None:
            msg = (
                f"Segmentation task requires mask directory if test_split_mode is {test_split_mode}. "
                "You could set test_split_mode to {TestSplitMode.NONE} or provide a mask directory."
            )
            raise ValueError(
                msg,
            )

        self.normal_split_ratio = normal_split_ratio

    def _setup(self, _stage: str | None = None) -> None:
        self.train_data = FolderDataset(
            name=self.name,
            task=self.task,
            transform=self.train_transform,
            split=Split.TRAIN,
            root=self.root,
            normal_dir=self.normal_dir,
            abnormal_dir=self.abnormal_dir,
            normal_test_dir=self.normal_test_dir,
            mask_dir=self.mask_dir,
            extensions=self.extensions,
        )

        self.test_data = FolderDataset(
            name=self.name,
            task=self.task,
            transform=self.eval_transform,
            split=Split.TEST,
            root=self.root,
            normal_dir=self.normal_dir,
            abnormal_dir=self.abnormal_dir,
            normal_test_dir=self.normal_test_dir,
            mask_dir=self.mask_dir,
            extensions=self.extensions,
        )

    @property
    def name(self) -> str:
        """Name of the datamodule.

        Folder datamodule overrides the name property to provide a custom name.
        """
        return self._name
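Note that the ``Folder`` docstring example above still shows the older ``normalization``/``InputNormalizationMethod`` arguments, which the ``__init__`` added in this file does not accept. A minimal usage sketch against the constructor as defined here (the dataset root and directory names are illustrative, not real paths):

from anomalib import TaskType
from anomalib.data import Folder

# Assumed layout: <root>/good, <root>/crack, <root>/mask/crack (hypothetical example dataset).
datamodule = Folder(
    name="sample_dataset",
    root="./datasets/sample_dataset",
    normal_dir="good",
    abnormal_dir="crack",
    mask_dir="mask/crack",
    task=TaskType.SEGMENTATION,
    image_size=(256, 256),
)
datamodule.setup()
i, data = next(enumerate(datamodule.train_dataloader()))
print(data.keys(), data["image"].shape)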
anomalib/data/image/kolektor.py
ADDED
@@ -0,0 +1,342 @@
"""Kolektor Surface-Defect Dataset (CC BY-NC-SA 4.0).

Description:
    This script provides a PyTorch Dataset, DataLoader, and PyTorch Lightning DataModule for the Kolektor
    Surface-Defect dataset. The dataset can be accessed at `Kolektor Surface-Defect Dataset <https://www.vicos.si/resources/kolektorsdd/>`_.

License:
    The Kolektor Surface-Defect dataset is released under the Creative Commons Attribution-NonCommercial-ShareAlike
    4.0 International License (CC BY-NC-SA 4.0). For more details, visit
    `Creative Commons License <https://creativecommons.org/licenses/by-nc-sa/4.0/>`_.

Reference:
    Tabernik, Domen, Samo Šela, Jure Skvarč, and Danijel Skočaj. "Segmentation-based deep-learning approach
    for surface-defect detection." Journal of Intelligent Manufacturing 31, no. 3 (2020): 759-776.
"""

# Copyright (C) 2023-2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

import logging
from pathlib import Path

import numpy as np
from cv2 import imread
from pandas import DataFrame
from sklearn.model_selection import train_test_split
from torchvision.transforms.v2 import Transform

from anomalib import TaskType
from anomalib.data.base import AnomalibDataModule, AnomalibDataset
from anomalib.data.errors import MisMatchError
from anomalib.data.utils import (
    DownloadInfo,
    Split,
    TestSplitMode,
    ValSplitMode,
    download_and_extract,
    validate_path,
)

__all__ = ["Kolektor", "KolektorDataset", "make_kolektor_dataset"]

logger = logging.getLogger(__name__)

DOWNLOAD_INFO = DownloadInfo(
    name="kolektor",
    url="https://go.vicos.si/kolektorsdd",
    hashsum="65dc621693418585de9c4467d1340ea7958a6181816f0dc2883a1e8b61f9d4dc",
    filename="KolektorSDD.zip",
)


def is_mask_anomalous(path: str) -> int:
    """Check if a mask shows defects.

    Args:
        path (str): Path to the mask file.

    Returns:
        int: 1 if the mask shows defects, 0 otherwise.

    Example:
        Assume that the following image is a mask for a defective image.
        Then the function will return 1.

        >>> from anomalib.data.image.kolektor import is_mask_anomalous
        >>> path = './KolektorSDD/kos01/Part0_label.bmp'
        >>> is_mask_anomalous(path)
        1
    """
    img_arr = imread(path)
    if np.all(img_arr == 0):
        return 0
    return 1


def make_kolektor_dataset(
    root: str | Path,
    train_split_ratio: float = 0.8,
    split: str | Split | None = None,
) -> DataFrame:
    """Create Kolektor samples by parsing the Kolektor data file structure.

    The files are expected to follow this structure:
        - Image files: `path/to/dataset/item/image_filename.jpg`, `path/to/dataset/kos01/Part0.jpg`
        - Mask files: `path/to/dataset/item/mask_filename.bmp`, `path/to/dataset/kos01/Part0_label.bmp`

    This function creates a DataFrame to store the parsed information in the following format:

    +---+-------------------+--------+-------+---------+-----------------------+------------------------+-------------+
    |   | path              | item   | split | label   | image_path            | mask_path              | label_index |
    +---+-------------------+--------+-------+---------+-----------------------+------------------------+-------------+
    | 0 | KolektorSDD       | kos01  | test  | Bad     | /path/to/image_file   | /path/to/mask_file     | 1           |
    +---+-------------------+--------+-------+---------+-----------------------+------------------------+-------------+

    Args:
        root (Path): Path to the dataset.
        train_split_ratio (float, optional): Ratio for splitting good images into train/test sets.
            Defaults to ``0.8``.
        split (str | Split | None, optional): Dataset split (either 'train' or 'test').
            Defaults to ``None``.

    Returns:
        pandas.DataFrame: An output DataFrame containing the samples of the dataset.

    Example:
        The following example shows how to get training samples from the Kolektor Dataset:

        >>> from pathlib import Path
        >>> root = Path('./KolektorSDD/')
        >>> samples = make_kolektor_dataset(root, train_split_ratio=0.8)
        >>> samples.head()
           path        item   split  label  image_path                   mask_path                          label_index
        0  KolektorSDD kos01  train  Good   KolektorSDD/kos01/Part0.jpg  KolektorSDD/kos01/Part0_label.bmp  0
        1  KolektorSDD kos01  train  Good   KolektorSDD/kos01/Part1.jpg  KolektorSDD/kos01/Part1_label.bmp  0
        2  KolektorSDD kos01  train  Good   KolektorSDD/kos01/Part2.jpg  KolektorSDD/kos01/Part2_label.bmp  0
        3  KolektorSDD kos01  test   Good   KolektorSDD/kos01/Part3.jpg  KolektorSDD/kos01/Part3_label.bmp  0
        4  KolektorSDD kos01  train  Good   KolektorSDD/kos01/Part4.jpg  KolektorSDD/kos01/Part4_label.bmp  0
    """
    root = validate_path(root)

    # Get list of images and masks
    samples_list = [(str(root),) + f.parts[-2:] for f in root.glob(r"**/*") if f.suffix == ".jpg"]
    masks_list = [(str(root),) + f.parts[-2:] for f in root.glob(r"**/*") if f.suffix == ".bmp"]

    if not samples_list:
        msg = f"Found 0 images in {root}"
        raise RuntimeError(msg)

    # Create dataframes
    samples = DataFrame(samples_list, columns=["path", "item", "image_path"])
    masks = DataFrame(masks_list, columns=["path", "item", "image_path"])

    # Modify image_path column by converting to absolute path
    samples["image_path"] = samples.path + "/" + samples.item + "/" + samples.image_path
    masks["image_path"] = masks.path + "/" + masks.item + "/" + masks.image_path

    # Sort samples by image path
    samples = samples.sort_values(by="image_path", ignore_index=True)
    masks = masks.sort_values(by="image_path", ignore_index=True)

    # Add mask paths for sample images
    samples["mask_path"] = masks.image_path.to_numpy()

    # Use is_good func to configure the label_index
    samples["label_index"] = samples["mask_path"].apply(is_mask_anomalous)
    samples.label_index = samples.label_index.astype(int)

    # Use label indexes to label data
    samples.loc[(samples.label_index == 0), "label"] = "Good"
    samples.loc[(samples.label_index == 1), "label"] = "Bad"

    # Add all 'Bad' samples to test set
    samples.loc[(samples.label == "Bad"), "split"] = "test"

    # Divide 'good' images to train/test on 0.8/0.2 ratio
    train_samples, test_samples = train_test_split(
        samples[samples.label == "Good"],
        train_size=train_split_ratio,
        random_state=42,
    )
    samples.loc[train_samples.index, "split"] = "train"
    samples.loc[test_samples.index, "split"] = "test"

    # Reorder columns
    samples = samples[["path", "item", "split", "label", "image_path", "mask_path", "label_index"]]

    # assert that the right mask files are associated with the right test images
    if not (
        samples.loc[samples.label_index == 1]
        .apply(lambda x: Path(x.image_path).stem in Path(x.mask_path).stem, axis=1)
        .all()
    ):
        msg = """Mismatch between anomalous images and ground truth masks. Make sure the mask files
        follow the same naming convention as the anomalous images in the dataset
        (e.g. image: 'Part0.jpg', mask: 'Part0_label.bmp')."""
        raise MisMatchError(msg)

    # Get the dataframe for the required split
    if split:
        samples = samples[samples.split == split].reset_index(drop=True)

    return samples


class KolektorDataset(AnomalibDataset):
    """Kolektor dataset class.

    Args:
        task (TaskType): Task type, ``classification``, ``detection`` or ``segmentation``
        root (Path | str): Path to the root of the dataset
            Defaults to ``./datasets/kolektor``.
        transform (Transform, optional): Transforms that should be applied to the input images.
            Defaults to ``None``.
        split (str | Split | None): Split of the dataset, usually Split.TRAIN or Split.TEST
            Defaults to ``None``.
    """

    def __init__(
        self,
        task: TaskType,
        root: Path | str = "./datasets/kolektor",
        transform: Transform | None = None,
        split: str | Split | None = None,
    ) -> None:
        super().__init__(task=task, transform=transform)

        self.root = root
        self.split = split
        self.samples = make_kolektor_dataset(self.root, train_split_ratio=0.8, split=self.split)


class Kolektor(AnomalibDataModule):
    """Kolektor Datamodule.

    Args:
        root (Path | str): Path to the root of the dataset
        train_batch_size (int, optional): Training batch size.
            Defaults to ``32``.
        eval_batch_size (int, optional): Test batch size.
            Defaults to ``32``.
        num_workers (int, optional): Number of workers.
            Defaults to ``8``.
        task (TaskType): Task type, 'classification', 'detection' or 'segmentation'
            Defaults to ``TaskType.SEGMENTATION``.
        image_size (tuple[int, int], optional): Size to which input images should be resized.
            Defaults to ``None``.
        transform (Transform, optional): Transforms that should be applied to the input images.
            Defaults to ``None``.
        train_transform (Transform, optional): Transforms that should be applied to the input images during training.
            Defaults to ``None``.
        eval_transform (Transform, optional): Transforms that should be applied to the input images during evaluation.
            Defaults to ``None``.
        test_split_mode (TestSplitMode): Setting that determines how the testing subset is obtained.
            Defaults to ``TestSplitMode.FROM_DIR``
        test_split_ratio (float): Fraction of images from the train set that will be reserved for testing.
            Defaults to ``0.2``
        val_split_mode (ValSplitMode): Setting that determines how the validation subset is obtained.
            Defaults to ``ValSplitMode.SAME_AS_TEST``
        val_split_ratio (float): Fraction of train or test images that will be reserved for validation.
            Defaults to ``0.5``
        seed (int | None, optional): Seed which may be set to a fixed value for reproducibility.
            Defaults to ``None``.
    """

    def __init__(
        self,
        root: Path | str = "./datasets/kolektor",
        train_batch_size: int = 32,
        eval_batch_size: int = 32,
        num_workers: int = 8,
        task: TaskType | str = TaskType.SEGMENTATION,
        image_size: tuple[int, int] | None = None,
        transform: Transform | None = None,
        train_transform: Transform | None = None,
        eval_transform: Transform | None = None,
        test_split_mode: TestSplitMode | str = TestSplitMode.FROM_DIR,
        test_split_ratio: float = 0.2,
        val_split_mode: ValSplitMode | str = ValSplitMode.SAME_AS_TEST,
        val_split_ratio: float = 0.5,
        seed: int | None = None,
    ) -> None:
        super().__init__(
            train_batch_size=train_batch_size,
            eval_batch_size=eval_batch_size,
            num_workers=num_workers,
            image_size=image_size,
            transform=transform,
            train_transform=train_transform,
            eval_transform=eval_transform,
            test_split_mode=test_split_mode,
            test_split_ratio=test_split_ratio,
            val_split_mode=val_split_mode,
            val_split_ratio=val_split_ratio,
            seed=seed,
        )

        self.task = TaskType(task)
        self.root = Path(root)

    def _setup(self, _stage: str | None = None) -> None:
        self.train_data = KolektorDataset(
            task=self.task,
            transform=self.train_transform,
            split=Split.TRAIN,
            root=self.root,
        )
        self.test_data = KolektorDataset(
            task=self.task,
            transform=self.eval_transform,
            split=Split.TEST,
            root=self.root,
        )

    def prepare_data(self) -> None:
        """Download the dataset if not available.

        This method checks if the specified dataset is available in the file system.
        If not, it downloads and extracts the dataset into the appropriate directory.

        Example:
            Assume the dataset is not available on the file system.
            Here's how the directory structure looks before and after calling the
            `prepare_data` method:

            Before:

            .. code-block:: bash

                $ tree datasets
                datasets
                ├── dataset1
                └── dataset2

            Calling the method:

            .. code-block:: python

                >> datamodule = Kolektor(root="./datasets/kolektor")
                >> datamodule.prepare_data()

            After:

            .. code-block:: bash

                $ tree datasets
                datasets
                ├── dataset1
                ├── dataset2
                └── kolektor
                    ├── kolektorsdd
                    ├── kos01
                    ├── ...
                    └── kos50
                        ├── Part0.jpg
                        ├── Part0_label.bmp
                        └── ...
        """
        if (self.root).is_dir():
            logger.info("Found the dataset.")
        else:
            download_and_extract(self.root, DOWNLOAD_INFO)
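The good/defective split used by ``make_kolektor_dataset`` above (all 'Bad' images go to test, 'Good' images are split with a fixed seed) can be seen in isolation on a toy frame. A small sketch with made-up rows, not real KolektorSDD data:

import pandas as pd
from sklearn.model_selection import train_test_split

# Toy frame standing in for the parsed samples; filenames are illustrative.
samples = pd.DataFrame({
    "image_path": [f"Part{i}.jpg" for i in range(10)],
    "label": ["Bad"] * 2 + ["Good"] * 8,
})

# Defective images always go to the test split.
samples.loc[samples.label == "Bad", "split"] = "test"

# Good images are divided train/test with the same fixed random_state as above.
train_good, test_good = train_test_split(samples[samples.label == "Good"], train_size=0.8, random_state=42)
samples.loc[train_good.index, "split"] = "train"
samples.loc[test_good.index, "split"] = "test"
print(samples["split"].value_counts())  # e.g. 6 train / 4 test for this toy frame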
anomalib/data/image/mvtec.py
ADDED
@@ -0,0 +1,414 @@
1 |
+
"""MVTec AD Dataset (CC BY-NC-SA 4.0).
|
2 |
+
|
3 |
+
Description:
|
4 |
+
This script contains PyTorch Dataset, Dataloader and PyTorch Lightning
|
5 |
+
DataModule for the MVTec AD dataset. If the dataset is not on the file system,
|
6 |
+
the script downloads and extracts the dataset and create PyTorch data objects.
|
7 |
+
|
8 |
+
License:
|
9 |
+
MVTec AD dataset is released under the Creative Commons
|
10 |
+
Attribution-NonCommercial-ShareAlike 4.0 International License
|
11 |
+
(CC BY-NC-SA 4.0)(https://creativecommons.org/licenses/by-nc-sa/4.0/).
|
12 |
+
|
13 |
+
References:
|
14 |
+
- Paul Bergmann, Kilian Batzner, Michael Fauser, David Sattlegger, Carsten Steger:
|
15 |
+
The MVTec Anomaly Detection Dataset: A Comprehensive Real-World Dataset for
|
16 |
+
Unsupervised Anomaly Detection; in: International Journal of Computer Vision
|
17 |
+
129(4):1038-1059, 2021, DOI: 10.1007/s11263-020-01400-4.
|
18 |
+
|
19 |
+
- Paul Bergmann, Michael Fauser, David Sattlegger, Carsten Steger: MVTec AD —
|
20 |
+
A Comprehensive Real-World Dataset for Unsupervised Anomaly Detection;
|
21 |
+
in: IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR),
|
22 |
+
9584-9592, 2019, DOI: 10.1109/CVPR.2019.00982.
|
23 |
+
"""
|
24 |
+
|
25 |
+
# Copyright (C) 2022-2024 Intel Corporation
|
26 |
+
# SPDX-License-Identifier: Apache-2.0
|
27 |
+
|
28 |
+
import logging
|
29 |
+
from collections.abc import Sequence
|
30 |
+
from pathlib import Path
|
31 |
+
|
32 |
+
from pandas import DataFrame
|
33 |
+
from torchvision.transforms.v2 import Transform
|
34 |
+
|
35 |
+
from anomalib import TaskType
|
36 |
+
from anomalib.data.base import AnomalibDataModule, AnomalibDataset
|
37 |
+
from anomalib.data.errors import MisMatchError
|
38 |
+
from anomalib.data.utils import (
|
39 |
+
DownloadInfo,
|
40 |
+
LabelName,
|
41 |
+
Split,
|
42 |
+
TestSplitMode,
|
43 |
+
ValSplitMode,
|
44 |
+
download_and_extract,
|
45 |
+
validate_path,
|
46 |
+
)
|
47 |
+
|
48 |
+
logger = logging.getLogger(__name__)
|
49 |
+
|
50 |
+
|
51 |
+
IMG_EXTENSIONS = (".png", ".PNG")
|
52 |
+
|
53 |
+
DOWNLOAD_INFO = DownloadInfo(
|
54 |
+
name="mvtec",
|
55 |
+
url="https://www.mydrive.ch/shares/38536/3830184030e49fe74747669442f0f282/download/420938113-1629952094"
|
56 |
+
"/mvtec_anomaly_detection.tar.xz",
|
57 |
+
hashsum="cf4313b13603bec67abb49ca959488f7eedce2a9f7795ec54446c649ac98cd3d",
|
58 |
+
)
|
59 |
+
|
60 |
+
CATEGORIES = (
|
61 |
+
"bottle",
|
62 |
+
"cable",
|
63 |
+
"capsule",
|
64 |
+
"carpet",
|
65 |
+
"grid",
|
66 |
+
"hazelnut",
|
67 |
+
"leather",
|
68 |
+
"metal_nut",
|
69 |
+
"pill",
|
70 |
+
"screw",
|
71 |
+
"tile",
|
72 |
+
"toothbrush",
|
73 |
+
"transistor",
|
74 |
+
"wood",
|
75 |
+
"zipper",
|
76 |
+
)
|
77 |
+
|
78 |
+
|
79 |
+
def make_mvtec_dataset(
|
80 |
+
root: str | Path,
|
81 |
+
split: str | Split | None = None,
|
82 |
+
extensions: Sequence[str] | None = None,
|
83 |
+
) -> DataFrame:
|
84 |
+
"""Create MVTec AD samples by parsing the MVTec AD data file structure.
|
85 |
+
|
86 |
+
The files are expected to follow the structure:
|
87 |
+
path/to/dataset/split/category/image_filename.png
|
88 |
+
path/to/dataset/ground_truth/category/mask_filename.png
|
89 |
+
|
90 |
+
This function creates a dataframe to store the parsed information based on the following format:
|
91 |
+
|
92 |
+
+---+---------------+-------+---------+---------------+---------------------------------------+-------------+
|
93 |
+
| | path | split | label | image_path | mask_path | label_index |
|
94 |
+
+===+===============+=======+=========+===============+=======================================+=============+
|
95 |
+
| 0 | datasets/name | test | defect | filename.png | ground_truth/defect/filename_mask.png | 1 |
|
96 |
+
+---+---------------+-------+---------+---------------+---------------------------------------+-------------+
|
97 |
+
|
98 |
+
Args:
|
99 |
+
root (Path): Path to dataset
|
100 |
+
split (str | Split | None, optional): Dataset split (ie., either train or test).
|
101 |
+
Defaults to ``None``.
|
102 |
+
extensions (Sequence[str] | None, optional): List of file extensions to be included in the dataset.
|
103 |
+
Defaults to ``None``.
|
104 |
+
|
105 |
+
Examples:
|
106 |
+
The following example shows how to get training samples from MVTec AD bottle category:
|
107 |
+
|
108 |
+
>>> root = Path('./MVTec')
|
109 |
+
>>> category = 'bottle'
|
110 |
+
>>> path = root / category
|
111 |
+
>>> path
|
112 |
+
PosixPath('MVTec/bottle')
|
113 |
+
|
114 |
+
>>> samples = make_mvtec_dataset(path, split='train', split_ratio=0.1, seed=0)
|
115 |
+
>>> samples.head()
|
116 |
+
path split label image_path mask_path label_index
|
117 |
+
0 MVTec/bottle train good MVTec/bottle/train/good/105.png MVTec/bottle/ground_truth/good/105_mask.png 0
|
118 |
+
1 MVTec/bottle train good MVTec/bottle/train/good/017.png MVTec/bottle/ground_truth/good/017_mask.png 0
|
119 |
+
2 MVTec/bottle train good MVTec/bottle/train/good/137.png MVTec/bottle/ground_truth/good/137_mask.png 0
|
120 |
+
3 MVTec/bottle train good MVTec/bottle/train/good/152.png MVTec/bottle/ground_truth/good/152_mask.png 0
|
121 |
+
4 MVTec/bottle train good MVTec/bottle/train/good/109.png MVTec/bottle/ground_truth/good/109_mask.png 0
|
122 |
+
|
123 |
+
Returns:
|
124 |
+
DataFrame: an output dataframe containing the samples of the dataset.
|
125 |
+
"""
|
126 |
+
if extensions is None:
|
127 |
+
extensions = IMG_EXTENSIONS
|
128 |
+
|
129 |
+
root = validate_path(root)
|
130 |
+
samples_list = [(str(root),) + f.parts[-3:] for f in root.glob(r"**/*") if f.suffix in extensions]
|
131 |
+
if not samples_list:
|
132 |
+
msg = f"Found 0 images in {root}"
|
133 |
+
raise RuntimeError(msg)
|
134 |
+
|
135 |
+
samples = DataFrame(samples_list, columns=["path", "split", "label", "image_path"])
|
136 |
+
|
137 |
+
# Modify image_path column by converting to absolute path
|
138 |
+
samples["image_path"] = samples.path + "/" + samples.split + "/" + samples.label + "/" + samples.image_path
|
139 |
+
|
140 |
+
# Create label index for normal (0) and anomalous (1) images.
|
141 |
+
samples.loc[(samples.label == "good"), "label_index"] = LabelName.NORMAL
|
142 |
+
samples.loc[(samples.label != "good"), "label_index"] = LabelName.ABNORMAL
|
143 |
+
samples.label_index = samples.label_index.astype(int)
|
144 |
+
|
145 |
+
# separate masks from samples
|
146 |
+
mask_samples = samples.loc[samples.split == "ground_truth"].sort_values(by="image_path", ignore_index=True)
|
147 |
+
samples = samples[samples.split != "ground_truth"].sort_values(by="image_path", ignore_index=True)
|
148 |
+
|
149 |
+
# assign mask paths to anomalous test images
|
150 |
+
samples["mask_path"] = ""
|
151 |
+
samples.loc[
|
152 |
+
(samples.split == "test") & (samples.label_index == LabelName.ABNORMAL),
|
153 |
+
"mask_path",
|
154 |
+
] = mask_samples.image_path.to_numpy()
|
155 |
+
|
156 |
+
# assert that the right mask files are associated with the right test images
|
157 |
+
abnormal_samples = samples.loc[samples.label_index == LabelName.ABNORMAL]
|
158 |
+
if (
|
159 |
+
len(abnormal_samples)
|
160 |
+
and not abnormal_samples.apply(lambda x: Path(x.image_path).stem in Path(x.mask_path).stem, axis=1).all()
|
161 |
+
):
|
162 |
+
msg = """Mismatch between anomalous images and ground truth masks. Make sure t
|
163 |
+
he mask files in 'ground_truth' folder follow the same naming convention as the
|
164 |
+
anomalous images in the dataset (e.g. image: '000.png', mask: '000.png' or '000_mask.png')."""
|
165 |
+
raise MisMatchError(msg)
|
166 |
+
|
167 |
+
if split:
|
168 |
+
samples = samples[samples.split == split].reset_index(drop=True)
|
169 |
+
|
170 |
+
return samples
|
171 |
+
|
172 |
+
|
173 |
+
class MVTecDataset(AnomalibDataset):
|
174 |
+
"""MVTec dataset class.
|
175 |
+
|
176 |
+
Args:
|
177 |
+
task (TaskType): Task type, ``classification``, ``detection`` or ``segmentation``.
|
178 |
+
root (Path | str): Path to the root of the dataset.
|
179 |
+
Defaults to ``./datasets/MVTec``.
|
180 |
+
category (str): Sub-category of the dataset, e.g. 'bottle'
|
181 |
+
Defaults to ``bottle``.
|
182 |
+
transform (Transform, optional): Transforms that should be applied to the input images.
|
183 |
+
Defaults to ``None``.
|
184 |
+
split (str | Split | None): Split of the dataset, usually Split.TRAIN or Split.TEST
|
185 |
+
Defaults to ``None``.
|
186 |
+
|
187 |
+
Examples:
|
188 |
+
.. code-block:: python
|
189 |
+
|
190 |
+
from anomalib.data.image.mvtec import MVTecDataset
|
191 |
+
from anomalib.data.utils.transforms import get_transforms
|
192 |
+
|
193 |
+
transform = get_transforms(image_size=256)
|
194 |
+
dataset = MVTecDataset(
|
195 |
+
task="classification",
|
196 |
+
transform=transform,
|
197 |
+
root='./datasets/MVTec',
|
198 |
+
category='zipper',
|
199 |
+
)
|
200 |
+
dataset.setup()
|
201 |
+
print(dataset[0].keys())
|
202 |
+
# Output: dict_keys(['image_path', 'label', 'image'])
|
203 |
+
|
204 |
+
When the task is segmentation, the dataset will also contain the mask:
|
205 |
+
|
206 |
+
.. code-block:: python
|
207 |
+
|
208 |
+
dataset.task = "segmentation"
|
209 |
+
dataset.setup()
|
210 |
+
print(dataset[0].keys())
|
211 |
+
# Output: dict_keys(['image_path', 'label', 'image', 'mask_path', 'mask'])
|
212 |
+
|
213 |
+
The image is a torch tensor of shape (C, H, W) and the mask is a torch tensor of shape (H, W).
|
214 |
+
|
215 |
+
.. code-block:: python
|
216 |
+
|
217 |
+
print(dataset[0]["image"].shape, dataset[0]["mask"].shape)
|
218 |
+
# Output: (torch.Size([3, 256, 256]), torch.Size([256, 256]))
|
219 |
+
|
220 |
+
"""
|
221 |
+
|
222 |
+
def __init__(
|
223 |
+
self,
|
224 |
+
task: TaskType,
|
225 |
+
root: Path | str = "./datasets/MVTec",
|
226 |
+
category: str = "bottle",
|
227 |
+
transform: Transform | None = None,
|
228 |
+
split: str | Split | None = None,
|
229 |
+
) -> None:
|
230 |
+
super().__init__(task=task, transform=transform)
|
231 |
+
|
232 |
+
self.root_category = Path(root) / Path(category)
|
233 |
+
self.category = category
|
234 |
+
self.split = split
|
235 |
+
self.samples = make_mvtec_dataset(self.root_category, split=self.split, extensions=IMG_EXTENSIONS)
|
236 |
+
|
237 |
+
|
238 |
+
class MVTec(AnomalibDataModule):
|
239 |
+
"""MVTec Datamodule.
|
240 |
+
|
241 |
+
Args:
|
242 |
+
root (Path | str): Path to the root of the dataset.
|
243 |
+
Defaults to ``"./datasets/MVTec"``.
|
244 |
+
category (str): Category of the MVTec dataset (e.g. "bottle" or "cable").
|
245 |
+
Defaults to ``"bottle"``.
|
246 |
+
train_batch_size (int, optional): Training batch size.
|
247 |
+
Defaults to ``32``.
|
248 |
+
eval_batch_size (int, optional): Test batch size.
|
249 |
+
Defaults to ``32``.
|
250 |
+
num_workers (int, optional): Number of workers.
|
251 |
+
Defaults to ``8``.
|
252 |
+
task (TaskType): Task type, 'classification', 'detection' or 'segmentation'
|
253 |
+
Defaults to ``TaskType.SEGMENTATION``.
|
254 |
+
image_size (tuple[int, int], optional): Size to which input images should be resized.
|
255 |
+
Defaults to ``None``.
|
256 |
+
transform (Transform, optional): Transforms that should be applied to the input images.
|
257 |
+
Defaults to ``None``.
|
258 |
+
train_transform (Transform, optional): Transforms that should be applied to the input images during training.
|
259 |
+
Defaults to ``None``.
|
260 |
+
eval_transform (Transform, optional): Transforms that should be applied to the input images during evaluation.
|
261 |
+
Defaults to ``None``.
|
262 |
+
test_split_mode (TestSplitMode): Setting that determines how the testing subset is obtained.
|
263 |
+
Defaults to ``TestSplitMode.FROM_DIR``.
|
264 |
+
test_split_ratio (float): Fraction of images from the train set that will be reserved for testing.
|
265 |
+
Defaults to ``0.2``.
|
266 |
+
val_split_mode (ValSplitMode): Setting that determines how the validation subset is obtained.
|
267 |
+
Defaults to ``ValSplitMode.SAME_AS_TEST``.
|
268 |
+
val_split_ratio (float): Fraction of train or test images that will be reserved for validation.
|
269 |
+
Defaults to ``0.5``.
|
270 |
+
seed (int | None, optional): Seed which may be set to a fixed value for reproducibility.
|
271 |
+
Defaults to ``None``.
|
272 |
+
|
273 |
+
Examples:
|
274 |
+
To create an MVTec AD datamodule with default settings:
|
275 |
+
|
276 |
+
>>> datamodule = MVTec()
|
277 |
+
>>> datamodule.setup()
|
278 |
+
>>> i, data = next(enumerate(datamodule.train_dataloader()))
|
279 |
+
>>> data.keys()
|
280 |
+
dict_keys(['image_path', 'label', 'image', 'mask_path', 'mask'])
|
281 |
+
|
282 |
+
>>> data["image"].shape
|
283 |
+
torch.Size([32, 3, 256, 256])
|
284 |
+
|
285 |
+
To change the category of the dataset:
|
286 |
+
|
287 |
+
>>> datamodule = MVTec(category="cable")
|
288 |
+
|
289 |
+
To change the image and batch size:
|
290 |
+
|
291 |
+
>>> datamodule = MVTec(image_size=(512, 512), train_batch_size=16, eval_batch_size=8)
|
292 |
+
|
293 |
+
MVTec AD dataset does not provide a validation set. If you would like
|
294 |
+
to use a separate validation set, you can use the ``val_split_mode`` and
|
295 |
+
``val_split_ratio`` arguments to create a validation set.
|
296 |
+
|
297 |
+
>>> datamodule = MVTec(val_split_mode=ValSplitMode.FROM_TEST, val_split_ratio=0.1)
|
298 |
+
|
299 |
+
This will subsample the test set by 10% and use it as the validation set.
|
300 |
+
If you would like to create a validation set synthetically that would
|
301 |
+
not change the test set, you can use the ``ValSplitMode.SYNTHETIC`` option.
|
302 |
+
|
303 |
+
>>> datamodule = MVTec(val_split_mode=ValSplitMode.SYNTHETIC, val_split_ratio=0.2)
|
304 |
+
|
305 |
+
"""
|
306 |
+
|
307 |
+
def __init__(
|
308 |
+
self,
|
309 |
+
root: Path | str = "./datasets/MVTec",
|
310 |
+
category: str = "bottle",
|
311 |
+
train_batch_size: int = 32,
|
312 |
+
eval_batch_size: int = 32,
|
313 |
+
num_workers: int = 8,
|
314 |
+
task: TaskType | str = TaskType.SEGMENTATION,
|
315 |
+
image_size: tuple[int, int] | None = None,
|
316 |
+
transform: Transform | None = None,
|
317 |
+
train_transform: Transform | None = None,
|
318 |
+
eval_transform: Transform | None = None,
|
319 |
+
test_split_mode: TestSplitMode | str = TestSplitMode.FROM_DIR,
|
320 |
+
test_split_ratio: float = 0.2,
|
321 |
+
val_split_mode: ValSplitMode | str = ValSplitMode.SAME_AS_TEST,
|
322 |
+
val_split_ratio: float = 0.5,
|
323 |
+
seed: int | None = None,
|
324 |
+
) -> None:
|
325 |
+
super().__init__(
|
326 |
+
train_batch_size=train_batch_size,
|
327 |
+
eval_batch_size=eval_batch_size,
|
328 |
+
image_size=image_size,
|
329 |
+
transform=transform,
|
330 |
+
train_transform=train_transform,
|
331 |
+
eval_transform=eval_transform,
|
332 |
+
num_workers=num_workers,
|
333 |
+
test_split_mode=test_split_mode,
|
334 |
+
test_split_ratio=test_split_ratio,
|
335 |
+
val_split_mode=val_split_mode,
|
336 |
+
val_split_ratio=val_split_ratio,
|
337 |
+
seed=seed,
|
338 |
+
)
|
339 |
+
|
340 |
+
self.task = TaskType(task)
|
341 |
+
self.root = Path(root)
|
342 |
+
self.category = category
|
343 |
+
|
344 |
+
def _setup(self, _stage: str | None = None) -> None:
|
345 |
+
"""Set up the datasets and perform dynamic subset splitting.
|
346 |
+
|
347 |
+
This method may be overridden in subclass for custom splitting behaviour.
|
348 |
+
|
349 |
+
Note:
|
350 |
+
The stage argument is not used here. This is because, for a given instance of an AnomalibDataModule
|
351 |
+
subclass, all three subsets are created at the first call of setup(). This is to accommodate the subset
|
352 |
+
splitting behaviour of anomaly tasks, where the validation set is usually extracted from the test set, and
|
353 |
+
the test set must therefore be created as early as the `fit` stage.
|
354 |
+
|
355 |
+
"""
|
356 |
+
self.train_data = MVTecDataset(
|
357 |
+
task=self.task,
|
358 |
+
transform=self.train_transform,
|
359 |
+
split=Split.TRAIN,
|
360 |
+
root=self.root,
|
361 |
+
category=self.category,
|
362 |
+
)
|
363 |
+
self.test_data = MVTecDataset(
|
364 |
+
task=self.task,
|
365 |
+
transform=self.eval_transform,
|
366 |
+
split=Split.TEST,
|
367 |
+
root=self.root,
|
368 |
+
category=self.category,
|
369 |
+
)
|
370 |
+
|
371 |
+
def prepare_data(self) -> None:
|
372 |
+
"""Download the dataset if not available.
|
373 |
+
|
374 |
+
This method checks if the specified dataset is available in the file system.
|
375 |
+
If not, it downloads and extracts the dataset into the appropriate directory.
|
376 |
+
|
377 |
+
Example:
|
378 |
+
Assume the dataset is not available on the file system.
|
379 |
+
Here's how the directory structure looks before and after calling the
|
380 |
+
`prepare_data` method:
|
381 |
+
|
382 |
+
Before:
|
383 |
+
|
384 |
+
.. code-block:: bash
|
385 |
+
|
386 |
+
$ tree datasets
|
387 |
+
datasets
|
388 |
+
├── dataset1
|
389 |
+
└── dataset2
|
390 |
+
|
391 |
+
Calling the method:
|
392 |
+
|
393 |
+
.. code-block:: python
|
394 |
+
|
395 |
+
>> datamodule = MVTec(root="./datasets/MVTec", category="bottle")
|
396 |
+
>> datamodule.prepare_data()
|
397 |
+
|
398 |
+
After:
|
399 |
+
|
400 |
+
.. code-block:: bash
|
401 |
+
|
402 |
+
$ tree datasets
|
403 |
+
datasets
|
404 |
+
├── dataset1
|
405 |
+
├── dataset2
|
406 |
+
└── MVTec
|
407 |
+
├── bottle
|
408 |
+
├── ...
|
409 |
+
└── zipper
|
410 |
+
"""
|
411 |
+
if (self.root / self.category).is_dir():
|
412 |
+
logger.info("Found the dataset.")
|
413 |
+
else:
|
414 |
+
download_and_extract(self.root, DOWNLOAD_INFO)
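A minimal usage sketch for the datamodule above; the root directory and batch size are illustrative assumptions (``prepare_data`` will download the dataset on first use if it is missing):

    from anomalib.data.image.mvtec import MVTec

    datamodule = MVTec(root="./datasets/MVTec", category="bottle", train_batch_size=16)
    datamodule.prepare_data()  # downloads and extracts the dataset if it is not on disk
    datamodule.setup()         # builds the train/val/test subsets
    batch = next(iter(datamodule.train_dataloader()))
    print(batch["image"].shape)  # e.g. torch.Size([16, 3, H, W])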
|
anomalib/data/image/mvtec_loco.py
ADDED
@@ -0,0 +1,480 @@
|
1 |
+
"""MVTec LOCO AD Dataset (CC BY-NC-SA 4.0).
|
2 |
+
|
3 |
+
Description:
|
4 |
+
This script contains PyTorch Dataset, Dataloader and PyTorch Lightning
|
5 |
+
DataModule for the MVTec LOCO AD dataset. If the dataset is not on the file system,
|
6 |
+
the script downloads and extracts the dataset and creates PyTorch data objects.
|
7 |
+
|
8 |
+
License:
|
9 |
+
MVTec LOCO AD dataset is released under the Creative Commons
|
10 |
+
Attribution-NonCommercial-ShareAlike 4.0 International License
|
11 |
+
(CC BY-NC-SA 4.0)(https://creativecommons.org/licenses/by-nc-sa/4.0/).
|
12 |
+
|
13 |
+
References:
|
14 |
+
- Paul Bergmann, Kilian Batzner, Michael Fauser, David Sattlegger, and Carsten Steger:
|
15 |
+
Beyond Dents and Scratches: Logical Constraints in Unsupervised Anomaly Detection and Localization;
|
16 |
+
in: International Journal of Computer Vision (IJCV) 130, 947-969, 2022, DOI: 10.1007/s11263-022-01578-9
|
17 |
+
"""
|
18 |
+
|
19 |
+
# Copyright (C) 2024 Intel Corporation
|
20 |
+
# SPDX-License-Identifier: Apache-2.0
|
21 |
+
|
22 |
+
import logging
|
23 |
+
from collections.abc import Sequence
|
24 |
+
from pathlib import Path
|
25 |
+
|
26 |
+
import torch
|
27 |
+
from pandas import DataFrame
|
28 |
+
from PIL import Image
|
29 |
+
from torchvision.transforms.v2 import Transform
|
30 |
+
from torchvision.transforms.v2.functional import to_image
|
31 |
+
from torchvision.tv_tensors import Mask
|
32 |
+
|
33 |
+
from anomalib import TaskType
|
34 |
+
from anomalib.data.base import AnomalibDataModule, AnomalibDataset
|
35 |
+
from anomalib.data.utils import (
|
36 |
+
DownloadInfo,
|
37 |
+
LabelName,
|
38 |
+
Split,
|
39 |
+
TestSplitMode,
|
40 |
+
ValSplitMode,
|
41 |
+
download_and_extract,
|
42 |
+
masks_to_boxes,
|
43 |
+
read_image,
|
44 |
+
validate_path,
|
45 |
+
)
|
46 |
+
|
47 |
+
logger = logging.getLogger(__name__)
|
48 |
+
|
49 |
+
|
50 |
+
IMG_EXTENSIONS = (".png", ".PNG")
|
51 |
+
|
52 |
+
DOWNLOAD_INFO = DownloadInfo(
|
53 |
+
name="mvtec_loco",
|
54 |
+
url="https://www.mydrive.ch/shares/48237/1b9106ccdfbb09a0c414bd49fe44a14a/download/430647091-1646842701"
|
55 |
+
"/mvtec_loco_anomaly_detection.tar.xz",
|
56 |
+
hashsum="9e7c84dba550fd2e59d8e9e231c929c45ba737b6b6a6d3814100f54d63aae687",
|
57 |
+
)
|
58 |
+
|
59 |
+
CATEGORIES = (
|
60 |
+
"breakfast_box",
|
61 |
+
"juice_bottle",
|
62 |
+
"pushpins",
|
63 |
+
"screw_bag",
|
64 |
+
"splicing_connectors",
|
65 |
+
)
|
66 |
+
|
67 |
+
|
68 |
+
def make_mvtec_loco_dataset(
|
69 |
+
root: str | Path,
|
70 |
+
split: str | Split | None = None,
|
71 |
+
extensions: Sequence[str] = IMG_EXTENSIONS,
|
72 |
+
) -> DataFrame:
|
73 |
+
"""Create MVTec LOCO AD samples by parsing the original MVTec LOCO AD data file structure.
|
74 |
+
|
75 |
+
The files are expected to follow the structure:
|
76 |
+
path/to/dataset/split/category/image_filename.png
|
77 |
+
path/to/dataset/ground_truth/category/image_filename/000.png
|
78 |
+
|
79 |
+
where there can be multiple ground-truth masks for the corresponding anomalous images.
|
80 |
+
|
81 |
+
This function creates a dataframe to store the parsed information based on the following format:
|
82 |
+
|
83 |
+
+---+---------------+-------+---------+-------------------------+-----------------------------+-------------+
|
84 |
+
| | path | split | label | image_path | mask_path | label_index |
|
85 |
+
+===+===============+=======+=========+===============+=======================================+=============+
|
86 |
+
| 0 | datasets/name | test | defect | path/to/image/file.png | [path/to/masks/file.png] | 1 |
|
87 |
+
+---+---------------+-------+---------+-------------------------+-----------------------------+-------------+
|
88 |
+
|
89 |
+
Args:
|
90 |
+
root (str | Path): Path to dataset
|
91 |
+
split (str | Split | None): Dataset split (ie., either train or test).
|
92 |
+
Defaults to ``None``.
|
93 |
+
extensions (Sequence[str]): List of file extensions to be included in the dataset.
|
94 |
+
Defaults to ``IMG_EXTENSIONS``.
|
95 |
+
|
96 |
+
Returns:
|
97 |
+
DataFrame: an output dataframe containing the samples of the dataset.
|
98 |
+
|
99 |
+
Examples:
|
100 |
+
The following example shows how to get test samples from MVTec LOCO AD pushpins category:
|
101 |
+
|
102 |
+
>>> root = Path('./MVTec_LOCO')
|
103 |
+
>>> category = 'pushpins'
|
104 |
+
>>> path = root / category
|
105 |
+
>>> samples = make_mvtec_loco_dataset(path, split='test')
|
106 |
+
"""
|
107 |
+
root = validate_path(root)
|
108 |
+
|
109 |
+
# Retrieve the image and mask files
|
110 |
+
samples_list = []
|
111 |
+
for f in root.glob("**/*"):
|
112 |
+
if f.suffix in extensions:
|
113 |
+
parts = f.parts
|
114 |
+
# 'ground_truth' and non-'ground_truth' paths have a different structure
|
115 |
+
if "ground_truth" not in parts:
|
116 |
+
split_folder, label_folder, image_file = parts[-3:]
|
117 |
+
image_path = f"{root}/{split_folder}/{label_folder}/{image_file}"
|
118 |
+
samples_list.append((str(root), split_folder, label_folder, "", image_path))
|
119 |
+
else:
|
120 |
+
split_folder, label_folder, image_folder, image_file = parts[-4:]
|
121 |
+
image_path = f"{root}/{split_folder}/{label_folder}/{image_folder}/{image_file}"
|
122 |
+
samples_list.append((str(root), split_folder, label_folder, image_folder, image_path))
|
123 |
+
|
124 |
+
if not samples_list:
|
125 |
+
msg = f"Found 0 images in {root}"
|
126 |
+
raise RuntimeError(msg)
|
127 |
+
|
128 |
+
samples = DataFrame(samples_list, columns=["path", "split", "label", "image_folder", "image_path"])
|
129 |
+
|
130 |
+
# Replace validation to Split.VAL.value in the split column
|
131 |
+
samples["split"] = samples["split"].replace("validation", Split.VAL.value)
|
132 |
+
|
133 |
+
# Create label index for normal (0) and anomalous (1) images.
|
134 |
+
samples.loc[(samples.label == "good"), "label_index"] = LabelName.NORMAL
|
135 |
+
samples.loc[(samples.label != "good"), "label_index"] = LabelName.ABNORMAL
|
136 |
+
samples.label_index = samples.label_index.astype(int)
|
137 |
+
|
138 |
+
# separate ground-truth masks from samples
|
139 |
+
mask_samples = samples.loc[samples.split == "ground_truth"].sort_values(by="image_path", ignore_index=True)
|
140 |
+
samples = samples[samples.split != "ground_truth"].sort_values(by="image_path", ignore_index=True)
|
141 |
+
|
142 |
+
# Group masks and aggregate the path into a list
|
143 |
+
mask_samples = (
|
144 |
+
mask_samples.groupby(["path", "split", "label", "image_folder"])["image_path"]
|
145 |
+
.agg(list)
|
146 |
+
.reset_index()
|
147 |
+
.rename(columns={"image_path": "mask_path"})
|
148 |
+
)
|
149 |
+
|
150 |
+
# assign mask paths to anomalous test images
|
151 |
+
samples["mask_path"] = ""
|
152 |
+
samples.loc[
|
153 |
+
(samples.split == "test") & (samples.label_index == LabelName.ABNORMAL),
|
154 |
+
"mask_path",
|
155 |
+
] = mask_samples.mask_path.to_numpy()
|
156 |
+
|
157 |
+
# validate that the right mask files are associated with the right test images
|
158 |
+
if len(samples.loc[samples.label_index == LabelName.ABNORMAL]):
|
159 |
+
image_stems = samples.loc[samples.label_index == LabelName.ABNORMAL]["image_path"].apply(lambda x: Path(x).stem)
|
160 |
+
mask_parent_stems = samples.loc[samples.label_index == LabelName.ABNORMAL]["mask_path"].apply(
|
161 |
+
lambda x: {Path(mask_path).parent.stem for mask_path in x},
|
162 |
+
)
|
163 |
+
|
164 |
+
if not all(
|
165 |
+
next(iter(mask_stems)) == image_stem
|
166 |
+
for image_stem, mask_stems in zip(image_stems, mask_parent_stems, strict=True)
|
167 |
+
):
|
168 |
+
error_message = (
|
169 |
+
"Mismatch between anomalous images and ground truth masks. "
|
170 |
+
"Make sure the parent folder of the mask files in 'ground_truth' folder "
|
171 |
+
"follows the same naming convention as the anomalous images in the dataset "
|
172 |
+
"(e.g., image: '005.png', mask: '005/000.png')."
|
173 |
+
)
|
174 |
+
raise ValueError(error_message)
|
175 |
+
|
176 |
+
if split:
|
177 |
+
samples = samples[samples.split == split].reset_index(drop=True)
|
178 |
+
|
179 |
+
return samples
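A small illustration of the MVTec LOCO naming convention that the validation block above enforces, where each anomalous image owns a folder of masks; the paths are toy values:

    from pathlib import Path

    image_path = "pushpins/test/logical_anomalies/005.png"
    mask_paths = [
        "pushpins/ground_truth/logical_anomalies/005/000.png",
        "pushpins/ground_truth/logical_anomalies/005/001.png",
    ]

    # Same condition as above: the parent folder of every mask shares the image stem.
    assert all(Path(p).parent.stem == Path(image_path).stem for p in mask_paths)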
|
180 |
+
|
181 |
+
|
182 |
+
class MVTecLocoDataset(AnomalibDataset):
|
183 |
+
"""MVTec LOCO dataset class.
|
184 |
+
|
185 |
+
Args:
|
186 |
+
task (TaskType): Task type, ``classification``, ``detection`` or ``segmentation``.
|
187 |
+
root (Path | str): Path to the root of the dataset.
|
188 |
+
Defaults to ``./datasets/MVTec_LOCO``.
|
189 |
+
category (str): Sub-category of the dataset, e.g. 'breakfast_box'
|
190 |
+
Defaults to ``breakfast_box``.
|
191 |
+
transform (Transform, optional): Transforms that should be applied to the input images.
|
192 |
+
Defaults to ``None``.
|
193 |
+
split (str | Split | None): Split of the dataset, Split.TRAIN, Split.VAL, or Split.TEST
|
194 |
+
Defaults to ``None``.
|
195 |
+
|
196 |
+
Examples:
|
197 |
+
.. code-block:: python
|
198 |
+
|
199 |
+
from anomalib.data.image.mvtec_loco import MVTecLocoDataset
|
200 |
+
from anomalib.data.utils.transforms import get_transforms
|
201 |
+
from torchvision.transforms.v2 import Resize
|
202 |
+
|
203 |
+
transform = Resize((256, 256))
|
204 |
+
dataset = MVTecLocoDataset(
|
205 |
+
task="classification",
|
206 |
+
transform=transform,
|
207 |
+
root='./datasets/MVTec_LOCO',
|
208 |
+
category='breakfast_box',
|
209 |
+
)
|
210 |
+
dataset.setup()
|
211 |
+
print(dataset[0].keys())
|
212 |
+
# Output: dict_keys(['image_path', 'label', 'image'])
|
213 |
+
|
214 |
+
When the task is segmentation, the dataset will also contain the mask:
|
215 |
+
|
216 |
+
.. code-block:: python
|
217 |
+
|
218 |
+
dataset.task = "segmentation"
|
219 |
+
dataset.setup()
|
220 |
+
print(dataset[0].keys())
|
221 |
+
# Output: dict_keys(['image_path', 'label', 'image', 'mask_path', 'mask'])
|
222 |
+
|
223 |
+
The image is a torch tensor of shape (C, H, W) and the mask is a torch tensor of shape (H, W).
|
224 |
+
|
225 |
+
.. code-block:: python
|
226 |
+
|
227 |
+
print(dataset[0]["image"].shape, dataset[0]["mask"].shape)
|
228 |
+
# Output: (torch.Size([3, 256, 256]), torch.Size([256, 256]))
|
229 |
+
"""
|
230 |
+
|
231 |
+
def __init__(
|
232 |
+
self,
|
233 |
+
task: TaskType,
|
234 |
+
root: Path | str = "./datasets/MVTec_LOCO",
|
235 |
+
category: str = "breakfast_box",
|
236 |
+
transform: Transform | None = None,
|
237 |
+
split: str | Split | None = None,
|
238 |
+
) -> None:
|
239 |
+
super().__init__(task=task, transform=transform)
|
240 |
+
|
241 |
+
self.root_category = Path(root) / category
|
242 |
+
self.split = split
|
243 |
+
self.samples = make_mvtec_loco_dataset(
|
244 |
+
self.root_category,
|
245 |
+
split=self.split,
|
246 |
+
extensions=IMG_EXTENSIONS,
|
247 |
+
)
|
248 |
+
|
249 |
+
@staticmethod
|
250 |
+
def _read_mask(mask_path: str | Path) -> Mask:
|
251 |
+
image = Image.open(mask_path).convert("L")
|
252 |
+
return Mask(to_image(image).squeeze(), dtype=torch.uint8)
|
253 |
+
|
254 |
+
def __getitem__(self, index: int) -> dict[str, str | torch.Tensor]:
|
255 |
+
"""Get dataset item for the index ``index``.
|
256 |
+
|
257 |
+
This method is mostly based on the superclass implementation, with some differences as follows:
|
258 |
+
- Using 'torch.where' to make sure the 'mask' in the return item is binarized
|
259 |
+
- An additional 'masks' is added, the non-binary masks with original size for the SPRO metric calculation
|
260 |
+
Args:
|
261 |
+
index (int): Index to get the item.
|
262 |
+
|
263 |
+
Returns:
|
264 |
+
dict[str, str | torch.Tensor]: Dict of image tensor during training. Otherwise, Dict containing image path,
|
265 |
+
target path, image tensor, label and transformed bounding box.
|
266 |
+
"""
|
267 |
+
image_path = self.samples.iloc[index].image_path
|
268 |
+
mask_path = self.samples.iloc[index].mask_path
|
269 |
+
label_index = self.samples.iloc[index].label_index
|
270 |
+
|
271 |
+
image = read_image(image_path, as_tensor=True)
|
272 |
+
item = {"image_path": image_path, "label": label_index}
|
273 |
+
|
274 |
+
if self.task == TaskType.CLASSIFICATION:
|
275 |
+
item["image"] = self.transform(image) if self.transform else image
|
276 |
+
elif self.task in (TaskType.DETECTION, TaskType.SEGMENTATION):
|
277 |
+
# Only Anomalous (1) images have masks in anomaly datasets
|
278 |
+
# Therefore, create empty mask for Normal (0) images.
|
279 |
+
if isinstance(mask_path, str):
|
280 |
+
mask_path = [mask_path]
|
281 |
+
semantic_mask = (
|
282 |
+
Mask(torch.zeros(image.shape[-2:])).to(torch.uint8)
|
283 |
+
if label_index == LabelName.NORMAL
|
284 |
+
else Mask(torch.stack([self._read_mask(path) for path in mask_path]))
|
285 |
+
)
|
286 |
+
|
287 |
+
binary_mask = Mask(semantic_mask.view(-1, *semantic_mask.shape[-2:]).int().any(dim=0).to(torch.uint8))
|
288 |
+
item["image"], item["mask"] = self.transform(image, binary_mask) if self.transform else (image, binary_mask)
|
289 |
+
|
290 |
+
item["mask_path"] = mask_path
|
291 |
+
# List of masks with the original size for saturation based metrics calculation
|
292 |
+
item["semantic_mask"] = semantic_mask
|
293 |
+
|
294 |
+
if self.task == TaskType.DETECTION:
|
295 |
+
# create boxes from masks for detection task
|
296 |
+
boxes, _ = masks_to_boxes(item["mask"])
|
297 |
+
item["boxes"] = boxes[0]
|
298 |
+
else:
|
299 |
+
msg = f"Unknown task type: {self.task}"
|
300 |
+
raise ValueError(msg)
|
301 |
+
|
302 |
+
return item
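A standalone sketch of the mask handling in ``__getitem__`` above: the per-defect masks are stacked and collapsed into a single binary mask. The tensor sizes are arbitrary:

    import torch

    semantic_mask = torch.zeros(3, 4, 4, dtype=torch.uint8)  # three per-defect masks
    semantic_mask[0, 1, 1] = 255
    semantic_mask[2, 2, 3] = 128

    # Collapse along the mask dimension, as done in the dataset class.
    binary_mask = semantic_mask.view(-1, *semantic_mask.shape[-2:]).int().any(dim=0).to(torch.uint8)
    print(binary_mask.shape)  # torch.Size([4, 4]); values are 0 or 1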
|
303 |
+
|
304 |
+
|
305 |
+
class MVTecLoco(AnomalibDataModule):
|
306 |
+
"""MVTec LOCO Datamodule.
|
307 |
+
|
308 |
+
Args:
|
309 |
+
root (Path | str): Path to the root of the dataset.
|
310 |
+
Defaults to ``"./datasets/MVTec_LOCO"``.
|
311 |
+
category (str): Category of the MVTec LOCO dataset (e.g. "breakfast_box").
|
312 |
+
Defaults to ``"breakfast_box"``.
|
313 |
+
train_batch_size (int, optional): Training batch size.
|
314 |
+
Defaults to ``32``.
|
315 |
+
eval_batch_size (int, optional): Test batch size.
|
316 |
+
Defaults to ``32``.
|
317 |
+
num_workers (int, optional): Number of workers.
|
318 |
+
Defaults to ``8``.
|
319 |
+
task (TaskType): Task type, 'classification', 'detection' or 'segmentation'
|
320 |
+
Defaults to ``TaskType.SEGMENTATION``.
|
321 |
+
image_size (tuple[int, int], optional): Size to which input images should be resized.
|
322 |
+
Defaults to ``None``.
|
323 |
+
transform (Transform, optional): Transforms that should be applied to the input images.
|
324 |
+
Defaults to ``None``.
|
325 |
+
train_transform (Transform, optional): Transforms that should be applied to the input images during training.
|
326 |
+
Defaults to ``None``.
|
327 |
+
eval_transform (Transform, optional): Transforms that should be applied to the input images during evaluation.
|
328 |
+
Defaults to ``None``.
|
329 |
+
test_split_mode (TestSplitMode): Setting that determines how the testing subset is obtained.
|
330 |
+
Defaults to ``TestSplitMode.FROM_DIR``.
|
331 |
+
test_split_ratio (float): Fraction of images from the train set that will be reserved for testing.
|
332 |
+
Defaults to ``0.2``.
|
333 |
+
val_split_mode (ValSplitMode): Setting that determines how the validation subset is obtained.
|
334 |
+
Defaults to ``ValSplitMode.FROM_DIR``.
|
335 |
+
val_split_ratio (float): Fraction of train or test images that will be reserved for validation.
|
336 |
+
Defaults to ``0.5``.
|
337 |
+
seed (int | None, optional): Seed which may be set to a fixed value for reproducibility.
|
338 |
+
Defaults to ``None``.
|
339 |
+
|
340 |
+
Examples:
|
341 |
+
To create an MVTec LOCO AD datamodule with default settings:
|
342 |
+
|
343 |
+
>>> datamodule = MVTecLoco(root="anomalib/datasets/MVTec_LOCO")
|
344 |
+
>>> datamodule.setup()
|
345 |
+
>>> i, data = next(enumerate(datamodule.train_dataloader()))
|
346 |
+
>>> data.keys()
|
347 |
+
dict_keys(['image_path', 'label', 'image', 'mask_path', 'mask'])
|
348 |
+
|
349 |
+
>>> data["image"].shape
|
350 |
+
torch.Size([32, 3, 256, 256])
|
351 |
+
|
352 |
+
To change the category of the dataset:
|
353 |
+
|
354 |
+
>>> datamodule = MVTecLoco(category="pushpins")
|
355 |
+
|
356 |
+
To change the image and batch size:
|
357 |
+
|
358 |
+
>>> datamodule = MVTecLoco(image_size=(512, 512), train_batch_size=16, eval_batch_size=8)
|
359 |
+
|
360 |
+
The MVTec LOCO AD dataset provides an independent validation set with normal images only in the 'validation' folder.
|
361 |
+
If you would like to use a different validation set split from the train or test set,
|
362 |
+
you can use the ``val_split_mode`` and ``val_split_ratio`` arguments to create a new validation set.
|
363 |
+
|
364 |
+
>>> datamodule = MVTecLoco(val_split_mode=ValSplitMode.FROM_TEST, val_split_ratio=0.1)
|
365 |
+
|
366 |
+
This will subsample the test set by 10% and use it as the validation set.
|
367 |
+
If you would like to create a validation set synthetically that would
|
368 |
+
not change the test set, you can use the ``ValSplitMode.SYNTHETIC`` option.
|
369 |
+
|
370 |
+
>>> datamodule = MVTecLoco(val_split_mode=ValSplitMode.SYNTHETIC, val_split_ratio=0.2)
|
371 |
+
"""
|
372 |
+
|
373 |
+
def __init__(
|
374 |
+
self,
|
375 |
+
root: Path | str = "./datasets/MVTec_LOCO",
|
376 |
+
category: str = "breakfast_box",
|
377 |
+
train_batch_size: int = 32,
|
378 |
+
eval_batch_size: int = 32,
|
379 |
+
num_workers: int = 8,
|
380 |
+
task: TaskType = TaskType.SEGMENTATION,
|
381 |
+
image_size: tuple[int, int] | None = None,
|
382 |
+
transform: Transform | None = None,
|
383 |
+
train_transform: Transform | None = None,
|
384 |
+
eval_transform: Transform | None = None,
|
385 |
+
test_split_mode: TestSplitMode = TestSplitMode.FROM_DIR,
|
386 |
+
test_split_ratio: float = 0.2,
|
387 |
+
val_split_mode: ValSplitMode = ValSplitMode.FROM_DIR,
|
388 |
+
val_split_ratio: float = 0.5,
|
389 |
+
seed: int | None = None,
|
390 |
+
) -> None:
|
391 |
+
super().__init__(
|
392 |
+
train_batch_size=train_batch_size,
|
393 |
+
eval_batch_size=eval_batch_size,
|
394 |
+
image_size=image_size,
|
395 |
+
transform=transform,
|
396 |
+
train_transform=train_transform,
|
397 |
+
eval_transform=eval_transform,
|
398 |
+
num_workers=num_workers,
|
399 |
+
test_split_mode=test_split_mode,
|
400 |
+
test_split_ratio=test_split_ratio,
|
401 |
+
val_split_mode=val_split_mode,
|
402 |
+
val_split_ratio=val_split_ratio,
|
403 |
+
seed=seed,
|
404 |
+
)
|
405 |
+
self.task = task
|
406 |
+
self.root = Path(root)
|
407 |
+
self.category = category
|
408 |
+
|
409 |
+
def _setup(self, _stage: str | None = None) -> None:
|
410 |
+
"""Set up the datasets, configs, and perform dynamic subset splitting.
|
411 |
+
|
412 |
+
This method overrides the parent class's method to also setup the val dataset.
|
413 |
+
The MVTec LOCO dataset provides an independent validation subset.
|
414 |
+
"""
|
415 |
+
self.train_data = MVTecLocoDataset(
|
416 |
+
task=self.task,
|
417 |
+
transform=self.train_transform,
|
418 |
+
split=Split.TRAIN,
|
419 |
+
root=self.root,
|
420 |
+
category=self.category,
|
421 |
+
)
|
422 |
+
self.val_data = MVTecLocoDataset(
|
423 |
+
task=self.task,
|
424 |
+
transform=self.eval_transform,
|
425 |
+
split=Split.VAL,
|
426 |
+
root=self.root,
|
427 |
+
category=self.category,
|
428 |
+
)
|
429 |
+
self.test_data = MVTecLocoDataset(
|
430 |
+
task=self.task,
|
431 |
+
transform=self.eval_transform,
|
432 |
+
split=Split.TEST,
|
433 |
+
root=self.root,
|
434 |
+
category=self.category,
|
435 |
+
)
|
436 |
+
|
437 |
+
def prepare_data(self) -> None:
|
438 |
+
"""Download the dataset if not available.
|
439 |
+
|
440 |
+
This method checks if the specified dataset is available in the file system.
|
441 |
+
If not, it downloads and extracts the dataset into the appropriate directory.
|
442 |
+
|
443 |
+
Example:
|
444 |
+
Assume the dataset is not available on the file system.
|
445 |
+
Here's how the directory structure looks before and after calling the
|
446 |
+
`prepare_data` method:
|
447 |
+
|
448 |
+
Before:
|
449 |
+
|
450 |
+
.. code-block:: bash
|
451 |
+
|
452 |
+
$ tree datasets
|
453 |
+
datasets
|
454 |
+
├── dataset1
|
455 |
+
└── dataset2
|
456 |
+
|
457 |
+
Calling the method:
|
458 |
+
|
459 |
+
.. code-block:: python
|
460 |
+
|
461 |
+
>> datamodule = MVTecLoco(root="./datasets/MVTec_LOCO", category="breakfast_box")
|
462 |
+
>> datamodule.prepare_data()
|
463 |
+
|
464 |
+
After:
|
465 |
+
|
466 |
+
.. code-block:: bash
|
467 |
+
|
468 |
+
$ tree datasets
|
469 |
+
datasets
|
470 |
+
├── dataset1
|
471 |
+
├── dataset2
|
472 |
+
└── MVTec_LOCO
|
473 |
+
├── breakfast_box
|
474 |
+
├── ...
|
475 |
+
└── splicing_connectors
|
476 |
+
"""
|
477 |
+
if (self.root / self.category).is_dir():
|
478 |
+
logger.info("Found the dataset.")
|
479 |
+
else:
|
480 |
+
download_and_extract(self.root, DOWNLOAD_INFO)
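A minimal usage sketch for the MVTecLoco datamodule above; the root path is an illustrative assumption. With the default ``ValSplitMode.FROM_DIR``, the validation loader is served from the dataset's own 'validation' folder:

    from anomalib.data.image.mvtec_loco import MVTecLoco

    datamodule = MVTecLoco(root="./datasets/MVTec_LOCO", category="breakfast_box")
    datamodule.prepare_data()
    datamodule.setup()
    val_batch = next(iter(datamodule.val_dataloader()))
    print(val_batch["image"].shape)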
|
anomalib/data/image/visa.py
ADDED
@@ -0,0 +1,364 @@
|
1 |
+
"""Visual Anomaly (VisA) Dataset (CC BY-NC-SA 4.0).
|
2 |
+
|
3 |
+
Description:
|
4 |
+
This script contains PyTorch Dataset, Dataloader and PyTorch
|
5 |
+
Lightning DataModule for the Visual Anomaly (VisA) dataset.
|
6 |
+
If the dataset is not on the file system, the script downloads and
|
7 |
+
extracts the dataset and creates PyTorch data objects.
|
8 |
+
License:
|
9 |
+
The VisA dataset is released under the Creative Commons
|
10 |
+
Attribution-NonCommercial-ShareAlike 4.0 International License
|
11 |
+
(CC BY-NC-SA 4.0)(https://creativecommons.org/licenses/by-nc-sa/4.0/).
|
12 |
+
Reference:
|
13 |
+
- Zou, Y., Jeong, J., Pemula, L., Zhang, D., & Dabeer, O. (2022). SPot-the-Difference
|
14 |
+
Self-supervised Pre-training for Anomaly Detection and Segmentation. In European
|
15 |
+
Conference on Computer Vision (pp. 392-408). Springer, Cham.
|
16 |
+
"""
|
17 |
+
|
18 |
+
# Copyright (C) 2022-2024 Intel Corporation
|
19 |
+
# SPDX-License-Identifier: Apache-2.0
|
20 |
+
|
21 |
+
# Subset splitting code adapted from https://github.com/amazon-science/spot-diff
|
22 |
+
# Original licence: Apache-2.0
|
23 |
+
|
24 |
+
|
25 |
+
import csv
|
26 |
+
import logging
|
27 |
+
import shutil
|
28 |
+
from pathlib import Path
|
29 |
+
|
30 |
+
import cv2
|
31 |
+
from torchvision.transforms.v2 import Transform
|
32 |
+
|
33 |
+
from anomalib import TaskType
|
34 |
+
from anomalib.data.base import AnomalibDataModule, AnomalibDataset
|
35 |
+
from anomalib.data.utils import (
|
36 |
+
DownloadInfo,
|
37 |
+
Split,
|
38 |
+
TestSplitMode,
|
39 |
+
ValSplitMode,
|
40 |
+
download_and_extract,
|
41 |
+
)
|
42 |
+
|
43 |
+
from .mvtec import make_mvtec_dataset
|
44 |
+
|
45 |
+
logger = logging.getLogger(__name__)
|
46 |
+
|
47 |
+
EXTENSIONS = (".png", ".jpg", ".JPG")
|
48 |
+
|
49 |
+
DOWNLOAD_INFO = DownloadInfo(
|
50 |
+
name="VisA",
|
51 |
+
url="https://amazon-visual-anomaly.s3.us-west-2.amazonaws.com/VisA_20220922.tar",
|
52 |
+
hashsum="2eb8690c803ab37de0324772964100169ec8ba1fa3f7e94291c9ca673f40f362",
|
53 |
+
)
|
54 |
+
|
55 |
+
CATEGORIES = (
|
56 |
+
"candle",
|
57 |
+
"capsules",
|
58 |
+
"cashew",
|
59 |
+
"chewinggum",
|
60 |
+
"fryum",
|
61 |
+
"macaroni1",
|
62 |
+
"macaroni2",
|
63 |
+
"pcb1",
|
64 |
+
"pcb2",
|
65 |
+
"pcb3",
|
66 |
+
"pcb4",
|
67 |
+
"pipe_fryum",
|
68 |
+
)
|
69 |
+
|
70 |
+
|
71 |
+
class VisaDataset(AnomalibDataset):
|
72 |
+
"""VisA dataset class.
|
73 |
+
|
74 |
+
Args:
|
75 |
+
task (TaskType): Task type, ``classification``, ``detection`` or ``segmentation``
|
76 |
+
root (str | Path): Path to the root of the dataset
|
77 |
+
category (str): Sub-category of the dataset, e.g. 'candle'
|
78 |
+
transform (Transform, optional): Transforms that should be applied to the input images.
|
79 |
+
Defaults to ``None``.
|
80 |
+
split (str | Split | None): Split of the dataset, usually Split.TRAIN or Split.TEST
|
81 |
+
Defaults to ``None``.
|
82 |
+
|
83 |
+
Examples:
|
84 |
+
To create a Visa dataset for classification:
|
85 |
+
|
86 |
+
.. code-block:: python
|
87 |
+
|
88 |
+
from anomalib.data.image.visa import VisaDataset
|
89 |
+
from anomalib.data.utils.transforms import get_transforms
|
90 |
+
|
91 |
+
transform = get_transforms(image_size=256)
|
92 |
+
dataset = VisaDataset(
|
93 |
+
task="classification",
|
94 |
+
transform=transform,
|
95 |
+
split="train",
|
96 |
+
root="./datasets/visa/visa_pytorch/",
|
97 |
+
category="candle",
|
98 |
+
)
|
99 |
+
dataset.setup()
|
100 |
+
dataset[0].keys()
|
101 |
+
|
102 |
+
# Output
|
103 |
+
dict_keys(['image_path', 'label', 'image'])
|
104 |
+
|
105 |
+
If you want to use the dataset for segmentation, you can use the same
|
106 |
+
code as above, with the task set to ``segmentation``. The dataset will
|
107 |
+
then have a ``mask`` key in the output dictionary.
|
108 |
+
|
109 |
+
.. code-block:: python
|
110 |
+
|
111 |
+
from anomalib.data.image.visa import VisaDataset
|
112 |
+
from anomalib.data.utils.transforms import get_transforms
|
113 |
+
|
114 |
+
transform = get_transforms(image_size=256)
|
115 |
+
dataset = VisaDataset(
|
116 |
+
task="segmentation",
|
117 |
+
transform=transform,
|
118 |
+
split="train",
|
119 |
+
root="./datasets/visa/visa_pytorch/",
|
120 |
+
category="candle",
|
121 |
+
)
|
122 |
+
dataset.setup()
|
123 |
+
dataset[0].keys()
|
124 |
+
|
125 |
+
# Output
|
126 |
+
dict_keys(['image_path', 'label', 'image', 'mask_path', 'mask'])
|
127 |
+
|
128 |
+
"""
|
129 |
+
|
130 |
+
def __init__(
|
131 |
+
self,
|
132 |
+
task: TaskType,
|
133 |
+
root: str | Path,
|
134 |
+
category: str,
|
135 |
+
transform: Transform | None = None,
|
136 |
+
split: str | Split | None = None,
|
137 |
+
) -> None:
|
138 |
+
super().__init__(task=task, transform=transform)
|
139 |
+
|
140 |
+
self.root_category = Path(root) / category
|
141 |
+
self.split = split
|
142 |
+
self.samples = make_mvtec_dataset(self.root_category, split=self.split, extensions=EXTENSIONS)
|
143 |
+
|
144 |
+
|
145 |
+
class Visa(AnomalibDataModule):
|
146 |
+
"""VisA Datamodule.
|
147 |
+
|
148 |
+
Args:
|
149 |
+
root (Path | str): Path to the root of the dataset
|
150 |
+
Defaults to ``"./datasets/visa"``.
|
151 |
+
category (str): Category of the Visa dataset such as ``candle``.
|
152 |
+
Defaults to ``"candle"``.
|
153 |
+
train_batch_size (int, optional): Training batch size.
|
154 |
+
Defaults to ``32``.
|
155 |
+
eval_batch_size (int, optional): Test batch size.
|
156 |
+
Defaults to ``32``.
|
157 |
+
num_workers (int, optional): Number of workers.
|
158 |
+
Defaults to ``8``.
|
159 |
+
task (TaskType): Task type, 'classification', 'detection' or 'segmentation'
|
160 |
+
Defaults to ``TaskType.SEGMENTATION``.
|
161 |
+
image_size (tuple[int, int], optional): Size to which input images should be resized.
|
162 |
+
Defaults to ``None``.
|
163 |
+
transform (Transform, optional): Transforms that should be applied to the input images.
|
164 |
+
Defaults to ``None``.
|
165 |
+
train_transform (Transform, optional): Transforms that should be applied to the input images during training.
|
166 |
+
Defaults to ``None``.
|
167 |
+
eval_transform (Transform, optional): Transforms that should be applied to the input images during evaluation.
|
168 |
+
Defaults to ``None``.
|
169 |
+
test_split_mode (TestSplitMode): Setting that determines how the testing subset is obtained.
|
170 |
+
Defaults to ``TestSplitMode.FROM_DIR``.
|
171 |
+
test_split_ratio (float): Fraction of images from the train set that will be reserved for testing.
|
172 |
+
Defaults to ``0.2``.
|
173 |
+
val_split_mode (ValSplitMode): Setting that determines how the validation subset is obtained.
|
174 |
+
Defaults to ``ValSplitMode.SAME_AS_TEST``.
|
175 |
+
val_split_ratio (float): Fraction of train or test images that will be reserved for validation.
|
176 |
+
Defaults to ``0.5``.
|
177 |
+
seed (int | None, optional): Seed which may be set to a fixed value for reproducibility.
|
178 |
+
Defaults to ``None``.
|
179 |
+
"""
|
180 |
+
|
181 |
+
def __init__(
|
182 |
+
self,
|
183 |
+
root: Path | str = "./datasets/visa",
|
184 |
+
category: str = "capsules",
|
185 |
+
train_batch_size: int = 32,
|
186 |
+
eval_batch_size: int = 32,
|
187 |
+
num_workers: int = 8,
|
188 |
+
task: TaskType | str = TaskType.SEGMENTATION,
|
189 |
+
image_size: tuple[int, int] | None = None,
|
190 |
+
transform: Transform | None = None,
|
191 |
+
train_transform: Transform | None = None,
|
192 |
+
eval_transform: Transform | None = None,
|
193 |
+
test_split_mode: TestSplitMode | str = TestSplitMode.FROM_DIR,
|
194 |
+
test_split_ratio: float = 0.2,
|
195 |
+
val_split_mode: ValSplitMode | str = ValSplitMode.SAME_AS_TEST,
|
196 |
+
val_split_ratio: float = 0.5,
|
197 |
+
seed: int | None = None,
|
198 |
+
) -> None:
|
199 |
+
super().__init__(
|
200 |
+
train_batch_size=train_batch_size,
|
201 |
+
eval_batch_size=eval_batch_size,
|
202 |
+
num_workers=num_workers,
|
203 |
+
image_size=image_size,
|
204 |
+
transform=transform,
|
205 |
+
train_transform=train_transform,
|
206 |
+
eval_transform=eval_transform,
|
207 |
+
test_split_mode=test_split_mode,
|
208 |
+
test_split_ratio=test_split_ratio,
|
209 |
+
val_split_mode=val_split_mode,
|
210 |
+
val_split_ratio=val_split_ratio,
|
211 |
+
seed=seed,
|
212 |
+
)
|
213 |
+
|
214 |
+
self.task = TaskType(task)
|
215 |
+
self.root = Path(root)
|
216 |
+
self.split_root = self.root / "visa_pytorch"
|
217 |
+
self.category = category
|
218 |
+
|
219 |
+
def _setup(self, _stage: str | None = None) -> None:
|
220 |
+
self.train_data = VisaDataset(
|
221 |
+
task=self.task,
|
222 |
+
transform=self.train_transform,
|
223 |
+
split=Split.TRAIN,
|
224 |
+
root=self.split_root,
|
225 |
+
category=self.category,
|
226 |
+
)
|
227 |
+
self.test_data = VisaDataset(
|
228 |
+
task=self.task,
|
229 |
+
transform=self.eval_transform,
|
230 |
+
split=Split.TEST,
|
231 |
+
root=self.split_root,
|
232 |
+
category=self.category,
|
233 |
+
)
|
234 |
+
|
235 |
+
def prepare_data(self) -> None:
|
236 |
+
"""Download the dataset if not available.
|
237 |
+
|
238 |
+
This method checks if the specified dataset is available in the file system.
|
239 |
+
If not, it downloads and extracts the dataset into the appropriate directory.
|
240 |
+
|
241 |
+
Example:
|
242 |
+
Assume the dataset is not available on the file system.
|
243 |
+
Here's how the directory structure looks before and after calling the
|
244 |
+
`prepare_data` method:
|
245 |
+
|
246 |
+
Before:
|
247 |
+
|
248 |
+
.. code-block:: bash
|
249 |
+
|
250 |
+
$ tree datasets
|
251 |
+
datasets
|
252 |
+
├── dataset1
|
253 |
+
└── dataset2
|
254 |
+
|
255 |
+
Calling the method:
|
256 |
+
|
257 |
+
.. code-block:: python
|
258 |
+
|
259 |
+
>> datamodule = Visa()
|
260 |
+
>> datamodule.prepare_data()
|
261 |
+
|
262 |
+
After:
|
263 |
+
|
264 |
+
.. code-block:: bash
|
265 |
+
|
266 |
+
$ tree datasets
|
267 |
+
datasets
|
268 |
+
├── dataset1
|
269 |
+
├── dataset2
|
270 |
+
└── visa
|
271 |
+
├── candle
|
272 |
+
├── ...
|
273 |
+
├── pipe_fryum
|
274 |
+
│ ├── Data
|
275 |
+
│ └── image_anno.csv
|
276 |
+
├── split_csv
|
277 |
+
│ ├── 1cls.csv
|
278 |
+
│ ├── 2cls_fewshot.csv
|
279 |
+
│ └── 2cls_highshot.csv
|
280 |
+
├── VisA_20220922.tar
|
281 |
+
└── visa_pytorch
|
282 |
+
├── candle
|
283 |
+
├── ...
|
284 |
+
├── pcb4
|
285 |
+
└── pipe_fryum
|
286 |
+
|
287 |
+
``prepare_data`` ensures that the dataset is converted to MVTec
|
288 |
+
format. ``visa_pytorch`` is the directory that contains the dataset
|
289 |
+
in the MVTec format. ``visa`` is the directory that contains the
|
290 |
+
original dataset.
|
291 |
+
"""
|
292 |
+
if (self.split_root / self.category).is_dir():
|
293 |
+
# dataset is available, and split has been applied
|
294 |
+
logger.info("Found the dataset and train/test split.")
|
295 |
+
elif (self.root / self.category).is_dir():
|
296 |
+
# dataset is available, but split has not yet been applied
|
297 |
+
logger.info("Found the dataset. Applying train/test split.")
|
298 |
+
self.apply_cls1_split()
|
299 |
+
else:
|
300 |
+
# dataset is not available
|
301 |
+
download_and_extract(self.root, DOWNLOAD_INFO)
|
302 |
+
logger.info("Downloaded the dataset. Applying train/test split.")
|
303 |
+
self.apply_cls1_split()
|
304 |
+
|
305 |
+
def apply_cls1_split(self) -> None:
|
306 |
+
"""Apply the 1-class subset splitting using the fixed split in the csv file.
|
307 |
+
|
308 |
+
adapted from https://github.com/amazon-science/spot-diff
|
309 |
+
"""
|
310 |
+
logger.info("preparing data")
|
311 |
+
categories = [
|
312 |
+
"candle",
|
313 |
+
"capsules",
|
314 |
+
"cashew",
|
315 |
+
"chewinggum",
|
316 |
+
"fryum",
|
317 |
+
"macaroni1",
|
318 |
+
"macaroni2",
|
319 |
+
"pcb1",
|
320 |
+
"pcb2",
|
321 |
+
"pcb3",
|
322 |
+
"pcb4",
|
323 |
+
"pipe_fryum",
|
324 |
+
]
|
325 |
+
|
326 |
+
split_file = self.root / "split_csv" / "1cls.csv"
|
327 |
+
|
328 |
+
for category in categories:
|
329 |
+
train_folder = self.split_root / category / "train"
|
330 |
+
test_folder = self.split_root / category / "test"
|
331 |
+
mask_folder = self.split_root / category / "ground_truth"
|
332 |
+
|
333 |
+
train_img_good_folder = train_folder / "good"
|
334 |
+
test_img_good_folder = test_folder / "good"
|
335 |
+
test_img_bad_folder = test_folder / "bad"
|
336 |
+
test_mask_bad_folder = mask_folder / "bad"
|
337 |
+
|
338 |
+
train_img_good_folder.mkdir(parents=True, exist_ok=True)
|
339 |
+
test_img_good_folder.mkdir(parents=True, exist_ok=True)
|
340 |
+
test_img_bad_folder.mkdir(parents=True, exist_ok=True)
|
341 |
+
test_mask_bad_folder.mkdir(parents=True, exist_ok=True)
|
342 |
+
|
343 |
+
with split_file.open(encoding="utf-8") as file:
|
344 |
+
csvreader = csv.reader(file)
|
345 |
+
next(csvreader)
|
346 |
+
for row in csvreader:
|
347 |
+
category, split, label, image_path, mask_path = row
|
348 |
+
label = "good" if label == "normal" else "bad"
|
349 |
+
image_name = image_path.split("/")[-1]
|
350 |
+
mask_name = mask_path.split("/")[-1]
|
351 |
+
|
352 |
+
img_src_path = self.root / image_path
|
353 |
+
msk_src_path = self.root / mask_path
|
354 |
+
img_dst_path = self.split_root / category / split / label / image_name
|
355 |
+
msk_dst_path = self.split_root / category / "ground_truth" / label / mask_name
|
356 |
+
|
357 |
+
shutil.copyfile(img_src_path, img_dst_path)
|
358 |
+
if split == "test" and label == "bad":
|
359 |
+
mask = cv2.imread(str(msk_src_path))
|
360 |
+
|
361 |
+
# binarize mask
|
362 |
+
mask[mask != 0] = 255
|
363 |
+
|
364 |
+
cv2.imwrite(str(msk_dst_path), mask)
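The mask binarization step above, shown in isolation on a synthetic array so it can be run without the dataset files:

    import numpy as np

    mask = np.zeros((8, 8), dtype=np.uint8)
    mask[2:4, 2:4] = 17  # any non-zero value marks a defective pixel

    # Same operation as in apply_cls1_split: collapse all labels to a single 255 foreground value.
    mask[mask != 0] = 255
    print(np.unique(mask))  # [  0 255]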
|
anomalib/data/predict.py
ADDED
@@ -0,0 +1,52 @@
|
1 |
+
"""Inference Dataset."""
|
2 |
+
|
3 |
+
# Copyright (C) 2022-2024 Intel Corporation
|
4 |
+
# SPDX-License-Identifier: Apache-2.0
|
5 |
+
|
6 |
+
|
7 |
+
from pathlib import Path
|
8 |
+
from typing import Any
|
9 |
+
|
10 |
+
from torch.utils.data.dataset import Dataset
|
11 |
+
from torchvision.transforms.v2 import Transform
|
12 |
+
|
13 |
+
from anomalib.data.utils import get_image_filenames, read_image
|
14 |
+
|
15 |
+
|
16 |
+
class PredictDataset(Dataset):
|
17 |
+
"""Inference Dataset to perform prediction.
|
18 |
+
|
19 |
+
Args:
|
20 |
+
path (str | Path): Path to an image or image-folder.
|
21 |
+
transform (Transform | None, optional): Transform object describing the transforms that are
|
22 |
+
applied to the inputs.
|
23 |
+
image_size (int | tuple[int, int] | None, optional): Target image size
|
24 |
+
to resize the original image. Defaults to ``(256, 256)``.
|
25 |
+
"""
|
26 |
+
|
27 |
+
def __init__(
|
28 |
+
self,
|
29 |
+
path: str | Path,
|
30 |
+
transform: Transform | None = None,
|
31 |
+
image_size: int | tuple[int, int] = (256, 256),
|
32 |
+
) -> None:
|
33 |
+
super().__init__()
|
34 |
+
|
35 |
+
self.image_filenames = get_image_filenames(path)
|
36 |
+
self.transform = transform
|
37 |
+
self.image_size = image_size
|
38 |
+
|
39 |
+
def __len__(self) -> int:
|
40 |
+
"""Get the number of images in the given path."""
|
41 |
+
return len(self.image_filenames)
|
42 |
+
|
43 |
+
def __getitem__(self, index: int) -> dict[str, Any]:
|
44 |
+
"""Get the image based on the `index`."""
|
45 |
+
image_filename = self.image_filenames[index]
|
46 |
+
image = read_image(image_filename, as_tensor=True)
|
47 |
+
if self.transform:
|
48 |
+
image = self.transform(image)
|
49 |
+
pre_processed = {"image": image}
|
50 |
+
pre_processed["image_path"] = str(image_filename)
|
51 |
+
|
52 |
+
return pre_processed
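A minimal sketch of batched inference with PredictDataset; the image folder path is an illustrative assumption, and a Resize transform is added so that images of different sizes can be collated into one batch:

    from torch.utils.data import DataLoader
    from torchvision.transforms.v2 import Resize

    from anomalib.data.predict import PredictDataset

    dataset = PredictDataset(path="./datasets/MVTec/bottle/test", transform=Resize((256, 256)))
    dataloader = DataLoader(dataset, batch_size=8, shuffle=False)
    batch = next(iter(dataloader))
    print(batch["image"].shape, batch["image_path"][0])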
|
anomalib/data/transforms/__init__.py
ADDED
@@ -0,0 +1,8 @@
|
1 |
+
"""Custom input transforms for Anomalib."""
|
2 |
+
|
3 |
+
# Copyright (C) 2024 Intel Corporation
|
4 |
+
# SPDX-License-Identifier: Apache-2.0
|
5 |
+
|
6 |
+
from .center_crop import ExportableCenterCrop
|
7 |
+
|
8 |
+
__all__ = ["ExportableCenterCrop"]
|
anomalib/data/transforms/center_crop.py
ADDED
@@ -0,0 +1,87 @@
|
1 |
+
"""Custom Torchvision transforms for Anomalib."""
|
2 |
+
|
3 |
+
# Original Code
|
4 |
+
# Copyright (c) Soumith Chintala 2016
|
5 |
+
# https://github.com/pytorch/vision/blob/v0.16.1/torchvision/transforms/v2/functional/_geometry.py
|
6 |
+
# SPDX-License-Identifier: BSD-3-Clause
|
7 |
+
#
|
8 |
+
# Modified
|
9 |
+
# Copyright (C) 2024 Intel Corporation
|
10 |
+
# SPDX-License-Identifier: Apache-2.0
|
11 |
+
|
12 |
+
from typing import Any
|
13 |
+
|
14 |
+
import torch
|
15 |
+
from torch.nn.functional import pad
|
16 |
+
from torchvision.transforms.v2 import Transform
|
17 |
+
from torchvision.transforms.v2.functional._geometry import (
|
18 |
+
_center_crop_compute_padding,
|
19 |
+
_center_crop_parse_output_size,
|
20 |
+
_parse_pad_padding,
|
21 |
+
)
|
22 |
+
|
23 |
+
|
24 |
+
def _center_crop_compute_crop_anchor(
|
25 |
+
crop_height: int,
|
26 |
+
crop_width: int,
|
27 |
+
image_height: int,
|
28 |
+
image_width: int,
|
29 |
+
) -> tuple[int, int]:
|
30 |
+
"""Compute the anchor point for center-cropping.
|
31 |
+
|
32 |
+
This function is a modified version of the torchvision.transforms.functional._center_crop_compute_crop_anchor
|
33 |
+
function. The original function uses `round` to compute the anchor point, which is not compatible with ONNX.
|
34 |
+
|
35 |
+
Args:
|
36 |
+
crop_height (int): Desired height of the crop.
|
37 |
+
crop_width (int): Desired width of the crop.
|
38 |
+
image_height (int): Height of the input image.
|
39 |
+
image_width (int): Width of the input image.
|
40 |
+
"""
|
41 |
+
crop_top = torch.tensor((image_height - crop_height) / 2.0).round().int().item()
|
42 |
+
crop_left = torch.tensor((image_width - crop_width) / 2.0).round().int().item()
|
43 |
+
return crop_top, crop_left
|
44 |
+
|
45 |
+
|
46 |
+
def center_crop_image(image: torch.Tensor, output_size: list[int]) -> torch.Tensor:
|
47 |
+
"""Apply center-cropping to an input image.
|
48 |
+
|
49 |
+
Uses the modified anchor point computation function to compute the anchor point for center-cropping.
|
50 |
+
|
51 |
+
Args:
|
52 |
+
image (torch.Tensor): Input image to be center-cropped.
|
53 |
+
output_size (list[int]): Desired output size of the crop.
|
54 |
+
"""
|
55 |
+
crop_height, crop_width = _center_crop_parse_output_size(output_size)
|
56 |
+
shape = image.shape
|
57 |
+
if image.numel() == 0:
|
58 |
+
return image.reshape(shape[:-2] + (crop_height, crop_width))
|
59 |
+
image_height, image_width = shape[-2:]
|
60 |
+
|
61 |
+
if crop_height > image_height or crop_width > image_width:
|
62 |
+
padding_ltrb = _center_crop_compute_padding(crop_height, crop_width, image_height, image_width)
|
63 |
+
image = pad(image, _parse_pad_padding(padding_ltrb), value=0.0)
|
64 |
+
|
65 |
+
image_height, image_width = image.shape[-2:]
|
66 |
+
if crop_width == image_width and crop_height == image_height:
|
67 |
+
return image
|
68 |
+
|
69 |
+
crop_top, crop_left = _center_crop_compute_crop_anchor(crop_height, crop_width, image_height, image_width)
|
70 |
+
return image[..., crop_top : (crop_top + crop_height), crop_left : (crop_left + crop_width)]
|
71 |
+
|
72 |
+
|
73 |
+
class ExportableCenterCrop(Transform):
|
74 |
+
"""Transform that applies center-cropping to an input image and allows to be exported to ONNX.
|
75 |
+
|
76 |
+
Args:
|
77 |
+
size (int | tuple[int, int]): Desired output size of the crop.
|
78 |
+
"""
|
79 |
+
|
80 |
+
def __init__(self, size: int | tuple[int, int]) -> None:
|
81 |
+
super().__init__()
|
82 |
+
self.size = list(size) if isinstance(size, tuple) else [size, size]
|
83 |
+
|
84 |
+
def _transform(self, inpt: torch.Tensor, params: dict[str, Any]) -> torch.Tensor:
|
85 |
+
"""Apply the transform."""
|
86 |
+
del params
|
87 |
+
return center_crop_image(inpt, output_size=self.size)
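A quick sketch exercising ExportableCenterCrop on a random tensor; the sizes are arbitrary:

    import torch

    from anomalib.data.transforms import ExportableCenterCrop

    crop = ExportableCenterCrop(size=224)
    image = torch.rand(1, 3, 256, 300)
    print(crop(image).shape)  # torch.Size([1, 3, 224, 224])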
|
anomalib/data/utils/__init__.py
ADDED
@@ -0,0 +1,56 @@
|
1 |
+
"""Helper utilities for data."""
|
2 |
+
|
3 |
+
# Copyright (C) 2022-2024 Intel Corporation
|
4 |
+
# SPDX-License-Identifier: Apache-2.0
|
5 |
+
|
6 |
+
from .augmenter import Augmenter
|
7 |
+
from .boxes import boxes_to_anomaly_maps, boxes_to_masks, masks_to_boxes
|
8 |
+
from .download import DownloadInfo, download_and_extract
|
9 |
+
from .generators import random_2d_perlin
|
10 |
+
from .image import (
|
11 |
+
generate_output_image_filename,
|
12 |
+
get_image_filenames,
|
13 |
+
get_image_height_and_width,
|
14 |
+
read_depth_image,
|
15 |
+
read_image,
|
16 |
+
read_mask,
|
17 |
+
)
|
18 |
+
from .label import LabelName
|
19 |
+
from .path import (
|
20 |
+
DirType,
|
21 |
+
_check_and_convert_path,
|
22 |
+
_prepare_files_labels,
|
23 |
+
resolve_path,
|
24 |
+
validate_and_resolve_path,
|
25 |
+
validate_path,
|
26 |
+
)
|
27 |
+
from .split import Split, TestSplitMode, ValSplitMode, concatenate_datasets, random_split, split_by_label
|
28 |
+
|
29 |
+
__all__ = [
|
30 |
+
"generate_output_image_filename",
|
31 |
+
"get_image_filenames",
|
32 |
+
"get_image_height_and_width",
|
33 |
+
"random_2d_perlin",
|
34 |
+
"read_image",
|
35 |
+
"read_mask",
|
36 |
+
"read_depth_image",
|
37 |
+
"random_split",
|
38 |
+
"split_by_label",
|
39 |
+
"concatenate_datasets",
|
40 |
+
"Split",
|
41 |
+
"ValSplitMode",
|
42 |
+
"TestSplitMode",
|
43 |
+
"LabelName",
|
44 |
+
"DirType",
|
45 |
+
"Augmenter",
|
46 |
+
"masks_to_boxes",
|
47 |
+
"boxes_to_masks",
|
48 |
+
"boxes_to_anomaly_maps",
|
49 |
+
"download_and_extract",
|
50 |
+
"DownloadInfo",
|
51 |
+
"_check_and_convert_path",
|
52 |
+
"_prepare_files_labels",
|
53 |
+
"resolve_path",
|
54 |
+
"validate_path",
|
55 |
+
"validate_and_resolve_path",
|
56 |
+
]
|
anomalib/data/utils/augmenter.py
ADDED
@@ -0,0 +1,172 @@
|
1 |
+
"""Augmenter module to generates out-of-distribution samples for the DRAEM implementation."""
|
2 |
+
|
3 |
+
# Original Code
|
4 |
+
# Copyright (c) 2021 VitjanZ
|
5 |
+
# https://github.com/VitjanZ/DRAEM.
|
6 |
+
# SPDX-License-Identifier: MIT
|
7 |
+
#
|
8 |
+
# Modified
|
9 |
+
# Copyright (C) 2022-2024 Intel Corporation
|
10 |
+
# SPDX-License-Identifier: Apache-2.0
|
11 |
+
|
12 |
+
|
13 |
+
import math
|
14 |
+
import random
|
15 |
+
from pathlib import Path
|
16 |
+
|
17 |
+
import cv2
|
18 |
+
import imgaug.augmenters as iaa
|
19 |
+
import numpy as np
|
20 |
+
import torch
|
21 |
+
from PIL import Image
|
22 |
+
from torchvision.datasets.folder import IMG_EXTENSIONS
|
23 |
+
|
24 |
+
from anomalib.data.utils.generators.perlin import random_2d_perlin
|
25 |
+
|
26 |
+
|
27 |
+
def nextpow2(value: int) -> int:
|
28 |
+
"""Return the smallest power of 2 greater than or equal to the input value."""
|
29 |
+
return 2 ** (math.ceil(math.log(value, 2)))
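A couple of illustrative values for the helper above, which rounds an image dimension up to the next power of two before Perlin noise generation:

    # nextpow2 is the function defined just above.
    assert nextpow2(700) == 1024
    assert nextpow2(900) == 1024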
|
30 |
+
|
31 |
+
|
32 |
+
class Augmenter:
|
33 |
+
"""Class that generates noisy augmentations of input images.
|
34 |
+
|
35 |
+
Args:
|
36 |
+
anomaly_source_path (str | None): Path to a folder of images that will be used as source of the anomalous
|
37 |
+
noise. If not specified, random noise will be used instead.
|
38 |
+
p_anomalous (float): Probability that the anomalous perturbation will be applied to a given image.
|
39 |
+
beta (float): Parameter that determines the opacity of the noise mask.
|
40 |
+
"""
|
41 |
+
|
42 |
+
def __init__(
|
43 |
+
self,
|
44 |
+
anomaly_source_path: str | None = None,
|
45 |
+
p_anomalous: float = 0.5,
|
46 |
+
beta: float | tuple[float, float] = (0.2, 1.0),
|
47 |
+
) -> None:
|
48 |
+
self.p_anomalous = p_anomalous
|
49 |
+
self.beta = beta
|
50 |
+
|
51 |
+
self.anomaly_source_paths: list[Path] = []
|
52 |
+
if anomaly_source_path is not None:
|
53 |
+
for img_ext in IMG_EXTENSIONS:
|
54 |
+
self.anomaly_source_paths.extend(Path(anomaly_source_path).rglob("*" + img_ext))
|
55 |
+
|
56 |
+
self.augmenters = [
|
57 |
+
iaa.GammaContrast((0.5, 2.0), per_channel=True),
|
58 |
+
iaa.MultiplyAndAddToBrightness(mul=(0.8, 1.2), add=(-30, 30)),
|
59 |
+
iaa.pillike.EnhanceSharpness(),
|
60 |
+
iaa.AddToHueAndSaturation((-50, 50), per_channel=True),
|
61 |
+
iaa.Solarize(0.5, threshold=(32, 128)),
|
62 |
+
iaa.Posterize(),
|
63 |
+
iaa.Invert(),
|
64 |
+
iaa.pillike.Autocontrast(),
|
65 |
+
iaa.pillike.Equalize(),
|
66 |
+
iaa.Affine(rotate=(-45, 45)),
|
67 |
+
]
|
68 |
+
self.rot = iaa.Sequential([iaa.Affine(rotate=(-90, 90))])
|
69 |
+
|
70 |
+
def rand_augmenter(self) -> iaa.Sequential:
|
71 |
+
"""Select 3 random transforms that will be applied to the anomaly source images.
|
72 |
+
|
73 |
+
Returns:
|
74 |
+
A selection of 3 transforms.
|
75 |
+
"""
|
76 |
+
aug_ind = np.random.default_rng().choice(np.arange(len(self.augmenters)), 3, replace=False)
|
77 |
+
return iaa.Sequential([self.augmenters[aug_ind[0]], self.augmenters[aug_ind[1]], self.augmenters[aug_ind[2]]])
|
78 |
+
|
79 |
+
def generate_perturbation(
|
80 |
+
self,
|
81 |
+
height: int,
|
82 |
+
width: int,
|
83 |
+
anomaly_source_path: Path | str | None = None,
|
84 |
+
) -> tuple[np.ndarray, np.ndarray]:
|
85 |
+
"""Generate an image containing a random anomalous perturbation using a source image.
|
86 |
+
|
87 |
+
Args:
|
88 |
+
height (int): height of the generated image.
|
89 |
+
width: (int): width of the generated image.
|
90 |
+
anomaly_source_path (Path | str | None): Path to an image file. If not provided, random noise will be used
|
91 |
+
instead.
|
92 |
+
|
93 |
+
Returns:
|
94 |
+
Image containing a random anomalous perturbation, and the corresponding ground truth anomaly mask.
|
95 |
+
"""
|
96 |
+
# Generate random perlin noise
|
97 |
+
perlin_scale = 6
|
98 |
+
min_perlin_scale = 0
|
99 |
+
|
100 |
+
perlin_scalex = 2 ** np.random.default_rng().integers(min_perlin_scale, perlin_scale)
|
101 |
+
perlin_scaley = 2 ** np.random.default_rng().integers(min_perlin_scale, perlin_scale)
|
102 |
+
|
103 |
+
perlin_noise = random_2d_perlin((nextpow2(height), nextpow2(width)), (perlin_scalex, perlin_scaley))[
|
104 |
+
:height,
|
105 |
+
:width,
|
106 |
+
]
|
107 |
+
perlin_noise = self.rot(image=perlin_noise)
|
108 |
+
|
109 |
+
# Create mask from perlin noise
|
110 |
+
mask = np.where(perlin_noise > 0.5, np.ones_like(perlin_noise), np.zeros_like(perlin_noise))
|
111 |
+
mask = np.expand_dims(mask, axis=2).astype(np.float32)
|
112 |
+
|
113 |
+
# Load anomaly source image
|
114 |
+
if anomaly_source_path:
|
115 |
+
anomaly_source_img = np.array(Image.open(anomaly_source_path))
|
116 |
+
anomaly_source_img = cv2.resize(anomaly_source_img, dsize=(width, height))
|
117 |
+
else: # if no anomaly source is specified, we use the perlin noise as anomalous source
|
118 |
+
anomaly_source_img = np.expand_dims(perlin_noise, 2).repeat(3, 2)
|
119 |
+
anomaly_source_img = (anomaly_source_img * 255).astype(np.uint8)
|
120 |
+
|
121 |
+
# Augment anomaly source image
|
122 |
+
aug = self.rand_augmenter()
|
123 |
+
anomaly_img_augmented = aug(image=anomaly_source_img)
|
124 |
+
|
125 |
+
# Create anomalous perturbation that we will apply to the image
|
126 |
+
perturbation = anomaly_img_augmented.astype(np.float32) * mask / 255.0
|
127 |
+
|
128 |
+
return perturbation, mask
|
129 |
+
|
130 |
+
def augment_batch(self, batch: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
|
131 |
+
"""Generate anomalous augmentations for a batch of input images.
|
132 |
+
|
133 |
+
Args:
|
134 |
+
batch (torch.Tensor): Batch of input images
|
135 |
+
|
136 |
+
Returns:
|
137 |
+
- Augmented image to which anomalous perturbations have been added.
|
138 |
+
- Ground truth masks corresponding to the anomalous perturbations.
|
139 |
+
"""
|
140 |
+
batch_size, channels, height, width = batch.shape
|
141 |
+
|
142 |
+
# Collect perturbations
|
143 |
+
perturbations_list = []
|
144 |
+
masks_list = []
|
145 |
+
for _ in range(batch_size):
|
146 |
+
if torch.rand(1) > self.p_anomalous: # include normal samples
|
147 |
+
perturbations_list.append(torch.zeros((channels, height, width)))
|
148 |
+
masks_list.append(torch.zeros((1, height, width)))
|
149 |
+
else:
|
150 |
+
anomaly_source_path = (
|
151 |
+
random.sample(self.anomaly_source_paths, 1)[0] if len(self.anomaly_source_paths) > 0 else None
|
152 |
+
)
|
153 |
+
perturbation, mask = self.generate_perturbation(height, width, anomaly_source_path)
|
154 |
+
perturbations_list.append(torch.Tensor(perturbation).permute((2, 0, 1)))
|
155 |
+
masks_list.append(torch.Tensor(mask).permute((2, 0, 1)))
|
156 |
+
|
157 |
+
perturbations = torch.stack(perturbations_list).to(batch.device)
|
158 |
+
masks = torch.stack(masks_list).to(batch.device)
|
159 |
+
|
160 |
+
# Apply perturbations batch wise
|
161 |
+
if isinstance(self.beta, float):
|
162 |
+
beta = self.beta
|
163 |
+
elif isinstance(self.beta, tuple):
|
164 |
+
beta = torch.rand(batch_size) * (self.beta[1] - self.beta[0]) + self.beta[0]
|
165 |
+
beta = beta.view(batch_size, 1, 1, 1).expand_as(batch).to(batch.device) # type: ignore[attr-defined]
|
166 |
+
else:
|
167 |
+
msg = "Beta must be either float or tuple of floats"
|
168 |
+
raise TypeError(msg)
|
169 |
+
|
170 |
+
augmented_batch = batch * (1 - masks) + (beta) * perturbations + (1 - beta) * batch * (masks)
|
171 |
+
|
172 |
+
return augmented_batch, masks
|
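As a quick illustration of the API added above (not part of the commit, and assuming imgaug and the Perlin generator added later in this commit are available), Augmenter can be exercised on a random batch:

import torch

from anomalib.data.utils.augmenter import Augmenter

# No anomaly_source_path is given, so the Perlin noise itself is used as the anomaly source.
augmenter = Augmenter(p_anomalous=0.5, beta=(0.2, 1.0))

# Dummy batch of four RGB images in [0, 1]; roughly half will receive a perturbation.
batch = torch.rand(4, 3, 256, 256)
augmented_batch, masks = augmenter.augment_batch(batch)

assert augmented_batch.shape == batch.shape   # (4, 3, 256, 256)
assert masks.shape == (4, 1, 256, 256)        # binary ground-truth masks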
anomalib/data/utils/boxes.py
ADDED
@@ -0,0 +1,117 @@
"""Helper functions for processing bounding box detections and annotations."""

# Copyright (C) 2022-2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0


import torch

from anomalib.utils.cv import connected_components_cpu, connected_components_gpu


def masks_to_boxes(
    masks: torch.Tensor,
    anomaly_maps: torch.Tensor | None = None,
) -> tuple[list[torch.Tensor], list[torch.Tensor]]:
    """Convert a batch of segmentation masks to bounding box coordinates.

    Args:
        masks (torch.Tensor): Input tensor of shape (B, 1, H, W), (B, H, W) or (H, W).
        anomaly_maps (Tensor | None, optional): Anomaly maps of shape (B, 1, H, W), (B, H, W) or (H, W) which are
            used to determine an anomaly score for the converted bounding boxes.

    Returns:
        list[torch.Tensor]: A list of length B where each element is a tensor of shape (N, 4)
            containing the bounding box coordinates of the objects in the masks in xyxy format.
        list[torch.Tensor]: A list of length B where each element is a tensor of length (N)
            containing an anomaly score for each of the converted boxes.
    """
    height, width = masks.shape[-2:]
    masks = masks.view((-1, 1, height, width)).float()  # reshape to (B, 1, H, W) and cast to float
    if anomaly_maps is not None:
        anomaly_maps = anomaly_maps.view((-1,) + masks.shape[-2:])

    if masks.is_cpu:
        batch_comps = connected_components_cpu(masks).squeeze(1)
    else:
        batch_comps = connected_components_gpu(masks).squeeze(1)

    batch_boxes = []
    batch_scores = []
    for im_idx, im_comps in enumerate(batch_comps):
        labels = torch.unique(im_comps)
        im_boxes = []
        im_scores = []
        for label in labels[labels != 0]:
            y_loc, x_loc = torch.where(im_comps == label)
            # add box
            box = torch.Tensor([torch.min(x_loc), torch.min(y_loc), torch.max(x_loc), torch.max(y_loc)]).to(
                masks.device,
            )
            im_boxes.append(box)
            if anomaly_maps is not None:
                im_scores.append(torch.max(anomaly_maps[im_idx, y_loc, x_loc]))
        batch_boxes.append(torch.stack(im_boxes) if im_boxes else torch.empty((0, 4), device=masks.device))
        batch_scores.append(torch.stack(im_scores) if im_scores else torch.empty(0, device=masks.device))

    return batch_boxes, batch_scores


def boxes_to_masks(boxes: list[torch.Tensor], image_size: tuple[int, int]) -> torch.Tensor:
    """Convert bounding boxes to segmentation masks.

    Args:
        boxes (list[torch.Tensor]): A list of length B where each element is a tensor of shape (N, 4)
            containing the bounding box coordinates of the regions of interest in xyxy format.
        image_size (tuple[int, int]): Image size of the output masks in (H, W) format.

    Returns:
        Tensor: torch.Tensor of shape (B, H, W) in which each slice is a binary mask showing the pixels contained by a
            bounding box.
    """
    masks = torch.zeros((len(boxes), *image_size)).to(boxes[0].device)
    for im_idx, im_boxes in enumerate(boxes):
        for box in im_boxes:
            x_1, y_1, x_2, y_2 = box.int()
            masks[im_idx, y_1 : y_2 + 1, x_1 : x_2 + 1] = 1
    return masks


def boxes_to_anomaly_maps(boxes: torch.Tensor, scores: torch.Tensor, image_size: tuple[int, int]) -> torch.Tensor:
    """Convert bounding box coordinates to anomaly heatmaps.

    Args:
        boxes (list[torch.Tensor]): A list of length B where each element is a tensor of shape (N, 4)
            containing the bounding box coordinates of the regions of interest in xyxy format.
        scores (list[torch.Tensor]): A list of length B where each element is a 1D tensor of length N
            containing the anomaly scores for each region of interest.
        image_size (tuple[int, int]): Image size of the output masks in (H, W) format.

    Returns:
        Tensor: torch.Tensor of shape (B, H, W). The pixel locations within each bounding box are collectively
            assigned the anomaly score of the bounding box. In the case of overlapping bounding boxes,
            the highest score is used.
    """
    anomaly_maps = torch.zeros((len(boxes), *image_size)).to(boxes[0].device)
    for im_idx, (im_boxes, im_scores) in enumerate(zip(boxes, scores, strict=False)):
        im_map = torch.zeros((im_boxes.shape[0], *image_size))
        for box_idx, (box, score) in enumerate(zip(im_boxes, im_scores, strict=True)):
            x_1, y_1, x_2, y_2 = box.int()
            im_map[box_idx, y_1 : y_2 + 1, x_1 : x_2 + 1] = score
        anomaly_maps[im_idx], _ = im_map.max(dim=0)
    return anomaly_maps


def scale_boxes(boxes: torch.Tensor, image_size: torch.Size, new_size: torch.Size) -> torch.Tensor:
    """Scale bbox coordinates to a new image size.

    Args:
        boxes (torch.Tensor): Boxes of shape (N, 4) - (x1, y1, x2, y2).
        image_size (Size): Size of the original image in which the bbox coordinates were retrieved.
        new_size (Size): New image size to which the bbox coordinates will be scaled.

    Returns:
        Tensor: Updated boxes of shape (N, 4) - (x1, y1, x2, y2).
    """
    scale = torch.Tensor([*new_size]) / torch.Tensor([*image_size])
    return boxes * scale.repeat(2).to(boxes.device)
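A small round-trip sketch for the helpers above (illustrative only; masks_to_boxes relies on the connected-components utilities in anomalib.utils.cv):

import torch

from anomalib.data.utils.boxes import boxes_to_masks, masks_to_boxes

# Batch of two 8x8 masks: the first contains one rectangular region, the second is empty.
masks = torch.zeros(2, 8, 8)
masks[0, 2:5, 3:6] = 1

boxes, _ = masks_to_boxes(masks)
print(boxes[0])  # expected: tensor([[3., 2., 5., 4.]]) in xyxy format
print(boxes[1])  # expected: an empty (0, 4) tensor

# Convert the boxes back into (B, H, W) binary masks.
reconstructed = boxes_to_masks(boxes, image_size=(8, 8))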
anomalib/data/utils/download.py
ADDED
@@ -0,0 +1,364 @@
"""Helper to show progress bars with `urlretrieve`, check hash of file."""

# Copyright (C) 2022-2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0


import hashlib
import io
import logging
import os
import re
import tarfile
from collections.abc import Iterable
from dataclasses import dataclass
from pathlib import Path
from tarfile import TarFile, TarInfo
from urllib.request import urlretrieve
from zipfile import ZipFile

from tqdm import tqdm

logger = logging.getLogger(__name__)


@dataclass
class DownloadInfo:
    """Info needed to download a dataset from a url."""

    name: str
    url: str
    hashsum: str
    filename: str | None = None


class DownloadProgressBar(tqdm):
    """Create progress bar for urlretrieve. Subclasses `tqdm`.

    For information about the parameters in constructor, refer to `tqdm`'s documentation.

    Args:
        iterable (Iterable | None): Iterable to decorate with a progressbar.
            Leave blank to manually manage the updates.
        desc (str | None): Prefix for the progressbar.
        total (int | float | None): The number of expected iterations. If unspecified,
            len(iterable) is used if possible. If float("inf") or as a last
            resort, only basic progress statistics are displayed
            (no ETA, no progressbar).
            If `gui` is True and this parameter needs subsequent updating,
            specify an initial arbitrary large positive number,
            e.g. 9e9.
        leave (bool | None): upon termination of iteration. If `None`, will leave only if `position` is `0`.
        file (io.TextIOWrapper | io.StringIO | None): Specifies where to output the progress messages
            (default: sys.stderr). Uses `file.write(str)` and
            `file.flush()` methods. For encoding, see
            `write_bytes`.
        ncols (int | None): The width of the entire output message. If specified,
            dynamically resizes the progressbar to stay within this bound.
            If unspecified, attempts to use environment width. The
            fallback is a meter width of 10 and no limit for the counter and
            statistics. If 0, will not print any meter (only stats).
        mininterval (float | None): Minimum progress display update interval [default: 0.1] seconds.
        maxinterval (float | None): Maximum progress display update interval [default: 10] seconds.
            Automatically adjusts `miniters` to correspond to `mininterval`
            after long display update lag. Only works if `dynamic_miniters`
            or monitor thread is enabled.
        miniters (int | float | None): Minimum progress display update interval, in iterations.
            If 0 and `dynamic_miniters`, will automatically adjust to equal
            `mininterval` (more CPU efficient, good for tight loops).
            If > 0, will skip display of specified number of iterations.
            Tweak this and `mininterval` to get very efficient loops.
            If your progress is erratic with both fast and slow iterations
            (network, skipping items, etc) you should set miniters=1.
        use_ascii (str | bool | None): If unspecified or False, use unicode (smooth blocks) to fill
            the meter. The fallback is to use ASCII characters " 123456789#".
        disable (bool | None): Whether to disable the entire progressbar wrapper
            [default: False]. If set to None, disable on non-TTY.
        unit (str | None): String that will be used to define the unit of each iteration
            [default: it].
        unit_scale (int | float | bool): If 1 or True, the number of iterations will be reduced/scaled
            automatically and a metric prefix following the
            International System of Units standard will be added
            (kilo, mega, etc.) [default: False]. If any other non-zero
            number, will scale `total` and `n`.
        dynamic_ncols (bool | None): If set, constantly alters `ncols` and `nrows` to the
            environment (allowing for window resizes) [default: False].
        smoothing (float | None): Exponential moving average smoothing factor for speed estimates
            (ignored in GUI mode). Ranges from 0 (average speed) to 1
            (current/instantaneous speed) [default: 0.3].
        bar_format (str | None): Specify a custom bar string formatting. May impact performance.
            [default: '{l_bar}{bar}{r_bar}'], where
            l_bar='{desc}: {percentage:3.0f}%|' and
            r_bar='| {n_fmt}/{total_fmt} [{elapsed}<{remaining}, '
            '{rate_fmt}{postfix}]'
            Possible vars: l_bar, bar, r_bar, n, n_fmt, total, total_fmt,
            percentage, elapsed, elapsed_s, ncols, nrows, desc, unit,
            rate, rate_fmt, rate_noinv, rate_noinv_fmt,
            rate_inv, rate_inv_fmt, postfix, unit_divisor,
            remaining, remaining_s, eta.
            Note that a trailing ": " is automatically removed after {desc}
            if the latter is empty.
        initial (int | float | None): The initial counter value. Useful when restarting a progress
            bar [default: 0]. If using float, consider specifying `{n:.3f}`
            or similar in `bar_format`, or specifying `unit_scale`.
        position (int | None): Specify the line offset to print this bar (starting from 0)
            Automatic if unspecified.
            Useful to manage multiple bars at once (eg, from threads).
        postfix (dict | None): Specify additional stats to display at the end of the bar.
            Calls `set_postfix(**postfix)` if possible (dict).
        unit_divisor (float | None): [default: 1000], ignored unless `unit_scale` is True.
        write_bytes (bool | None): If (default: None) and `file` is unspecified,
            bytes will be written in Python 2. If `True` will also write
            bytes. In all other cases will default to unicode.
        lock_args (tuple | None): Passed to `refresh` for intermediate output
            (initialisation, iterating, and updating).
        nrows (int | None): The screen height. If specified, hides nested bars
            outside this bound. If unspecified, attempts to use environment height.
            The fallback is 20.
        colour (str | None): Bar colour (e.g. 'green', '#00ff00').
        delay (float | None): Don't display until [default: 0] seconds have elapsed.
        gui (bool | None): WARNING: internal parameter - do not use.
            Use tqdm.gui.tqdm(...) instead. If set, will attempt to use
            matplotlib animations for a graphical output [default: False].


    Example:
        >>> with DownloadProgressBar(unit='B', unit_scale=True, miniters=1, desc=url.split('/')[-1]) as p_bar:
        >>>     urllib.request.urlretrieve(url, filename=output_path, reporthook=p_bar.update_to)
    """

    def __init__(
        self,
        iterable: Iterable | None = None,
        desc: str | None = None,
        total: int | float | None = None,
        leave: bool | None = True,
        file: io.TextIOWrapper | io.StringIO | None = None,
        ncols: int | None = None,
        mininterval: float | None = 0.1,
        maxinterval: float | None = 10.0,
        miniters: int | float | None = None,
        use_ascii: bool | str | None = None,
        disable: bool | None = False,
        unit: str | None = "it",
        unit_scale: bool | int | float | None = False,
        dynamic_ncols: bool | None = False,
        smoothing: float | None = 0.3,
        bar_format: str | None = None,
        initial: int | float | None = 0,
        position: int | None = None,
        postfix: dict | None = None,
        unit_divisor: float | None = 1000,
        write_bytes: bool | None = None,
        lock_args: tuple | None = None,
        nrows: int | None = None,
        colour: str | None = None,
        delay: float | None = 0,
        gui: bool | None = False,
        **kwargs,
    ) -> None:
        super().__init__(
            iterable=iterable,
            desc=desc,
            total=total,
            leave=leave,
            file=file,
            ncols=ncols,
            mininterval=mininterval,
            maxinterval=maxinterval,
            miniters=miniters,
            ascii=use_ascii,
            disable=disable,
            unit=unit,
            unit_scale=unit_scale,
            dynamic_ncols=dynamic_ncols,
            smoothing=smoothing,
            bar_format=bar_format,
            initial=initial,
            position=position,
            postfix=postfix,
            unit_divisor=unit_divisor,
            write_bytes=write_bytes,
            lock_args=lock_args,
            nrows=nrows,
            colour=colour,
            delay=delay,
            gui=gui,
            **kwargs,
        )
        self.total: int | float | None

    def update_to(self, chunk_number: int = 1, max_chunk_size: int = 1, total_size: int | None = None) -> None:
        """Progress bar hook for tqdm.

        Based on https://stackoverflow.com/a/53877507
        The implementor does not have to bother about passing parameters to this as it gets them from urlretrieve.
        However the context needs a few parameters. Refer to the example.

        Args:
            chunk_number (int, optional): The current chunk being processed. Defaults to 1.
            max_chunk_size (int, optional): Maximum size of each chunk. Defaults to 1.
            total_size (int, optional): Total download size. Defaults to None.
        """
        if total_size is not None:
            self.total = total_size
        self.update(chunk_number * max_chunk_size - self.n)


def is_file_potentially_dangerous(file_name: str) -> bool:
    """Check if a file is potentially dangerous.

    Args:
        file_name (str): Filename.

    Returns:
        bool: True if the member is potentially dangerous, False otherwise.

    """
    # Some example criteria. We could expand this.
    unsafe_patterns = ["/etc/", "/root/"]
    return any(re.search(pattern, file_name) for pattern in unsafe_patterns)


def safe_extract(tar_file: TarFile, root: Path, members: list[TarInfo]) -> None:
    """Extract safe members from a tar archive.

    Args:
        tar_file (TarFile): TarFile object.
        root (Path): Root directory where the dataset will be stored.
        members (List[TarInfo]): List of safe members to be extracted.

    """
    for member in members:
        tar_file.extract(member, root)


def generate_hash(file_path: str | Path, algorithm: str = "sha256") -> str:
    """Generate a hash of a file using the specified algorithm.

    Args:
        file_path (str | Path): Path to the file to hash.
        algorithm (str): The hashing algorithm to use (e.g., 'sha256', 'sha3_512').

    Returns:
        str: The hexadecimal hash string of the file.

    Raises:
        ValueError: If the specified hashing algorithm is not supported.
    """
    # Get the hashing algorithm.
    try:
        hasher = getattr(hashlib, algorithm)()
    except AttributeError as err:
        msg = f"Unsupported hashing algorithm: {algorithm}"
        raise ValueError(msg) from err

    # Read the file in chunks to avoid loading it all into memory
    with Path(file_path).open("rb") as file:
        for chunk in iter(lambda: file.read(4096), b""):
            hasher.update(chunk)

    # Return the computed hash value in hexadecimal format
    return hasher.hexdigest()


def check_hash(file_path: Path, expected_hash: str, algorithm: str = "sha256") -> None:
    """Raise value error if hash does not match the calculated hash of the file.

    Args:
        file_path (Path): Path to file.
        expected_hash (str): Expected hash of the file.
        algorithm (str): Hashing algorithm to use ('sha256', 'sha3_512', etc.).
    """
    # Compare the calculated hash with the expected hash
    calculated_hash = generate_hash(file_path, algorithm)
    if calculated_hash != expected_hash:
        msg = (
            f"Calculated hash {calculated_hash} of downloaded file {file_path} does not match the required hash "
            f"{expected_hash}."
        )
        raise ValueError(msg)


def extract(file_name: Path, root: Path) -> None:
    """Extract a dataset.

    Args:
        file_name (Path): Path of the file to be extracted.
        root (Path): Root directory where the dataset will be stored.

    """
    logger.info("Extracting dataset into root folder.")

    # Safely extract zip files
    if file_name.suffix == ".zip":
        with ZipFile(file_name, "r") as zip_file:
            for file_info in zip_file.infolist():
                if not is_file_potentially_dangerous(file_info.filename):
                    zip_file.extract(file_info, root)

    # Safely extract tar files.
    elif file_name.suffix in (".tar", ".gz", ".xz", ".tgz"):
        with tarfile.open(file_name) as tar_file:
            members = tar_file.getmembers()
            safe_members = [member for member in members if not is_file_potentially_dangerous(member.name)]
            safe_extract(tar_file, root, safe_members)

    else:
        msg = f"Unrecognized file format: {file_name}"
        raise ValueError(msg)

    logger.info("Cleaning up files.")
    file_name.unlink()


def download_and_extract(root: Path, info: DownloadInfo) -> None:
    """Download and extract a dataset.

    Args:
        root (Path): Root directory where the dataset will be stored.
        info (DownloadInfo): Info needed to download the dataset.
    """
    root.mkdir(parents=True, exist_ok=True)

    # save the compressed file in the specified root directory, using the same file name as on the server
    downloaded_file_path = root / info.filename if info.filename else root / info.url.split("/")[-1]

    if downloaded_file_path.exists():
        logger.info("Existing dataset archive found. Skipping download stage.")
    else:
        logger.info("Downloading the %s dataset.", info.name)
        # audit url. allowing only http:// or https://
        if info.url.startswith("http://") or info.url.startswith("https://"):
            with DownloadProgressBar(unit="B", unit_scale=True, miniters=1, desc=info.name) as progress_bar:
                urlretrieve(  # noqa: S310  # nosec B310
                    url=f"{info.url}",
                    filename=downloaded_file_path,
                    reporthook=progress_bar.update_to,
                )
            logger.info("Checking the hash of the downloaded file.")
            check_hash(downloaded_file_path, info.hashsum)
        else:
            msg = f"Invalid URL to download dataset. Supported 'http://' or 'https://' but '{info.url}' is requested"
            raise RuntimeError(msg)

    extract(downloaded_file_path, root)


def is_within_directory(directory: Path, target: Path) -> bool:
    """Check if a target path is located within a given directory.

    Args:
        directory (Path): path of the parent directory
        target (Path): path of the target

    Returns:
        (bool): True if the target is within the directory, False otherwise
    """
    abs_directory = directory.resolve()
    abs_target = target.resolve()

    # TODO(djdameln): Replace with pathlib is_relative_to after switching to Python 3.10
    # CVS-122655
    prefix = os.path.commonprefix([abs_directory, abs_target])
    return prefix == str(abs_directory)
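To show how the pieces above fit together, a typical dataset datamodule would call download_and_extract roughly as follows (illustrative only; the URL and hash are placeholders, not a real dataset):

from pathlib import Path

from anomalib.data.utils.download import DownloadInfo, download_and_extract

# Placeholder values - substitute a real archive URL and its SHA-256 digest.
info = DownloadInfo(
    name="example_dataset",
    url="https://example.com/example_dataset.tar.gz",
    hashsum="<sha256 hex digest of the archive>",
)

# Downloads the archive with a progress bar, verifies the hash,
# extracts only non-dangerous members, and deletes the archive afterwards.
download_and_extract(root=Path("datasets/example_dataset"), info=info)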
anomalib/data/utils/generators/__init__.py
ADDED
@@ -0,0 +1,8 @@
"""Utilities to generate synthetic data."""

# Copyright (C) 2022 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

from .perlin import random_2d_perlin

__all__ = ["random_2d_perlin"]
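Finally, a short sketch of random_2d_perlin as it is used by Augmenter.generate_perturbation above; the call pattern (output shape first, per-axis scales second) is inferred from that call site rather than stated in this commit:

from anomalib.data.utils.generators import random_2d_perlin

# 256x256 noise map with power-of-two scales, mirroring the Augmenter usage.
noise = random_2d_perlin((256, 256), (4, 8))
print(noise.shape)  # (256, 256)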