File size: 7,279 Bytes
3de7bf6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
"""Visualizer Callback.

This is assigned by Anomalib Engine internally.
"""

# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

import logging
from pathlib import Path
from typing import Any, cast

from lightning.pytorch import Callback, Trainer
from lightning.pytorch.utilities.types import STEP_OUTPUT

from anomalib.data.utils.image import save_image, show_image
from anomalib.loggers import AnomalibWandbLogger
from anomalib.loggers.base import ImageLoggerBase
from anomalib.models import AnomalyModule
from anomalib.utils.visualization import (
    BaseVisualizer,
    GeneratorResult,
    VisualizationStep,
)

# Logger named after this module (standard ``logging`` hierarchy naming).
logger = logging.getLogger(name=__name__)


class _VisualizationCallback(Callback):
    """Callback for visualization that is used internally by the Engine.

    Every result produced by the configured visualizers can be saved to disk, shown,
    and/or forwarded to the image loggers attached to the trainer.

    Args:
        visualizers (BaseVisualizer | list[BaseVisualizer]):
            Visualizer object(s) used to compute the visualizations. A single visualizer
            is wrapped in a list.
        save (bool, optional): Save the images under ``root``. Defaults to False.
        root (Path | str | None, optional): Directory in which the images are saved.
            Required when ``save`` is True. Defaults to None.
        log (bool, optional): Log the images into the image loggers. Defaults to False.
        show (bool, optional): Show the images. Defaults to False.

    Example:
        >>> visualizers = [ImageVisualizer(), MetricsVisualizer()]
        >>> visualization_callback = _VisualizationCallback(
        ...     visualizers=visualizers,
        ...     save=True,
        ...     root="results/images",
        ... )

        CLI
        $ anomalib train --model Padim --data MVTec \
            --visualization.visualizers ImageVisualizer \
            --visualization.visualizers+=MetricsVisualizer
        or
        $ anomalib train --model Padim --data MVTec \
            --visualization.visualizers '[ImageVisualizer, MetricsVisualizer]'

    Raises:
        ValueError: In case ``root`` is None and ``save`` is True.
    """

    def __init__(
        self,
        visualizers: BaseVisualizer | list[BaseVisualizer],
        save: bool = False,
        root: Path | str | None = None,
        log: bool = False,
        show: bool = False,
    ) -> None:
        super().__init__()
        if save and root is None:
            msg = "`root` must be provided if save is True"
            raise ValueError(msg)
        self.save = save
        # ``Path()`` is only a placeholder so the attribute is always a ``Path``; it is never
        # used for saving because ``save`` is False whenever ``root`` is None.
        self.root: Path = Path(root) if root is not None else Path()
        self.log = log
        self.show = show
        self.generators = visualizers if isinstance(visualizers, list) else [visualizers]

    def on_test_batch_end(
        self,
        trainer: Trainer,
        pl_module: AnomalyModule,
        outputs: STEP_OUTPUT | None,
        batch: Any,  # noqa: ANN401
        batch_idx: int,
        dataloader_idx: int = 0,
    ) -> None:
        """Run the batch-level visualizers and handle every result they generate.

        Args:
            trainer (Trainer): Trainer object.
            pl_module (AnomalyModule): Module being tested.
            outputs (STEP_OUTPUT | None): Outputs of the test step.
            batch (Any): Current batch.
            batch_idx (int): Index of the current batch.
            dataloader_idx (int, optional): Index of the dataloader. Defaults to 0.
        """
        for generator in self.generators:
            if generator.visualize_on != VisualizationStep.BATCH:
                continue
            for result in generator(
                trainer=trainer,
                pl_module=pl_module,
                outputs=outputs,
                batch=batch,
                batch_idx=batch_idx,
                dataloader_idx=dataloader_idx,
            ):
                self._handle_result(result, trainer, pl_module, relative_to_dataset=True)

    def on_test_end(self, trainer: Trainer, pl_module: AnomalyModule) -> None:
        """Run the stage-end visualizers, handle their results and save W&B loggers.

        Args:
            trainer (Trainer): Trainer object.
            pl_module (AnomalyModule): Module being tested.
        """
        for generator in self.generators:
            if generator.visualize_on != VisualizationStep.STAGE_END:
                continue
            for result in generator(trainer=trainer, pl_module=pl_module):
                self._handle_result(result, trainer, pl_module, relative_to_dataset=False)

        # AnomalibWandbLogger needs an explicit ``save()`` at the end of the stage
        # (presumably to push the images collected through ``add_image``).
        for trainer_logger in trainer.loggers:
            if isinstance(trainer_logger, AnomalibWandbLogger):
                trainer_logger.save()

    def on_predict_batch_end(
        self,
        trainer: Trainer,
        pl_module: AnomalyModule,
        outputs: STEP_OUTPUT | None,
        batch: Any,  # noqa: ANN401
        batch_idx: int,
        dataloader_idx: int = 0,
    ) -> None:
        """Handle prediction batches exactly like test batches."""
        return self.on_test_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx)

    def on_predict_end(self, trainer: Trainer, pl_module: AnomalyModule) -> None:
        """Handle the end of prediction exactly like the end of testing."""
        return self.on_test_end(trainer, pl_module)

    def _handle_result(
        self,
        result: GeneratorResult,
        trainer: Trainer,
        pl_module: AnomalyModule,
        *,
        relative_to_dataset: bool,
    ) -> None:
        """Save, show and/or log a single visualization result according to the flags.

        Args:
            result (GeneratorResult): Output from a generator.
            trainer (Trainer): Trainer object.
            pl_module (AnomalyModule): Module whose global step is used when logging.
            relative_to_dataset (bool): If True, the saved file name is made relative to the
                ``<datamodule name>/<category>`` part of ``result.file_name``; otherwise
                ``result.file_name`` is used unchanged.

        Raises:
            ValueError: If ``save`` (or ``log``) is enabled and ``result.file_name`` is None.
        """
        if self.save:
            if result.file_name is None:
                msg = "``save`` is set to ``True`` but file name is ``None``"
                raise ValueError(msg)
            filename = (
                self._dataset_relative_filename(trainer, result.file_name)
                if relative_to_dataset
                else result.file_name
            )
            save_image(image=result.image, root=self.root, filename=filename)
        if self.show:
            show_image(image=result.image, title=str(result.file_name))
        if self.log:
            self._add_to_logger(result, pl_module, trainer)

    @staticmethod
    def _dataset_relative_filename(trainer: Trainer, file_name: str | Path) -> str:
        """Return the part of ``file_name`` that follows ``<datamodule name>/<category>``.

        For example, ``.../MVTec/bottle/test/broken/000.png`` becomes ``test/broken/000.png``.
        Without a datamodule only the base name is returned; if the marker is not found the
        file name is returned unchanged (as a string).

        Args:
            trainer (Trainer): Trainer whose ``datamodule`` provides ``name`` and ``category``.
            file_name (str | Path): File name of the generated result.

        Returns:
            str: File name to save the image under.
        """
        datamodule = trainer.datamodule
        if datamodule is None:
            return Path(file_name).name
        marker = f"{datamodule.name}/{datamodule.category}"
        # Match against a POSIX-style path so the ``/`` in ``marker`` also matches on Windows,
        # where ``str(Path)`` uses backslashes.
        _, found, tail = Path(file_name).as_posix().rpartition(marker)
        if not found:
            return str(file_name)
        # Drop the separator left after the marker so the result is a relative path that is
        # placed under ``root`` instead of being interpreted as an absolute path.
        return tail.lstrip("/")

    def _add_to_logger(
        self,
        result: GeneratorResult,
        module: AnomalyModule,
        trainer: Trainer,
    ) -> None:
        """Add the image to every image logger attached to the trainer.

        Args:
            result (GeneratorResult): Output from the generators.
            module (AnomalyModule): LightningModule from which the global step is extracted.
            trainer (Trainer): Trainer object.

        Raises:
            ValueError: If ``result.file_name`` is None.
        """
        if result.file_name is None:
            msg = "File name is None"
            raise ValueError(msg)
        filename = result.file_name
        # Path file names are logged as ``<parent>_<name>``; string names are used verbatim.
        name = f"{filename.parent.name}_{filename.name}" if isinstance(filename, Path) else filename
        # Iterate the loggers directly: keying them by class name would drop all but one
        # logger when several loggers of the same type are attached.
        for trainer_logger in trainer.loggers:
            if isinstance(trainer_logger, ImageLoggerBase):
                trainer_logger.add_image(
                    image=result.image,
                    name=name,
                    global_step=module.global_step,
                )