File size: 7,279 Bytes
3de7bf6 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 |
"""Visualizer Callback.
This callback is attached internally by the Anomalib Engine.
"""
# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
import logging
from pathlib import Path
from typing import Any, cast
from lightning.pytorch import Callback, Trainer
from lightning.pytorch.utilities.types import STEP_OUTPUT
from anomalib.data.utils.image import save_image, show_image
from anomalib.loggers import AnomalibWandbLogger
from anomalib.loggers.base import ImageLoggerBase
from anomalib.models import AnomalyModule
from anomalib.utils.visualization import (
BaseVisualizer,
GeneratorResult,
VisualizationStep,
)
# Module-scoped logger, named after this module.
logger = logging.getLogger(name=__name__)
class _VisualizationCallback(Callback):
    r"""Callback for visualization that is used internally by the Engine.

    Args:
        visualizers (BaseVisualizer | list[BaseVisualizer]):
            Visualizer objects that are used for computing the visualizations.
        save (bool, optional): Save the image. Defaults to False.
        root (Path | str | None, optional): The path to save the images. Defaults to None.
        log (bool, optional): Log the images into the loggers. Defaults to False.
        show (bool, optional): Show the images. Defaults to False.

    Example:
        >>> visualizers = [ImageVisualizer(), MetricsVisualizer()]
        >>> visualization_callback = _VisualizationCallback(
        ...     visualizers=visualizers,
        ...     save=True,
        ...     root="results/images"
        ... )

        CLI
        $ anomalib train --model Padim --data MVTec \
            --visualization.visualizers ImageVisualizer \
            --visualization.visualizers+=MetricsVisualizer
        or
        $ anomalib train --model Padim --data MVTec \
            --visualization.visualizers '[ImageVisualizer, MetricsVisualizer]'

    Raises:
        ValueError: In case `root` is None and `save` is True.
    """

    def __init__(
        self,
        visualizers: BaseVisualizer | list[BaseVisualizer],
        save: bool = False,
        root: Path | str | None = None,
        log: bool = False,
        show: bool = False,
    ) -> None:
        super().__init__()
        if save and root is None:
            msg = "`root` must be provided if save is True"
            raise ValueError(msg)
        self.save = save
        # ``root`` is only consulted when ``save`` is True. Fall back to the CWD so the
        # attribute always has a concrete ``Path`` type, and coerce ``str`` roots.
        self.root: Path = Path(root) if root is not None else Path()
        self.log = log
        self.show = show
        # Copy so later mutation of the caller's list does not change this callback.
        self.generators: list[BaseVisualizer] = (
            list(visualizers) if isinstance(visualizers, list) else [visualizers]
        )

    def on_test_batch_end(
        self,
        trainer: Trainer,
        pl_module: AnomalyModule,
        outputs: STEP_OUTPUT | None,
        batch: Any,  # noqa: ANN401
        batch_idx: int,
        dataloader_idx: int = 0,
    ) -> None:
        """Run batch-level visualizers and save/show/log every result they yield.

        Args:
            trainer (Trainer): Lightning trainer.
            pl_module (AnomalyModule): Model being evaluated.
            outputs (STEP_OUTPUT | None): Outputs of the test step.
            batch (Any): The current batch.
            batch_idx (int): Index of the batch.
            dataloader_idx (int): Index of the dataloader. Defaults to 0.

        Raises:
            ValueError: If ``save`` is True and a result has no file name.
        """
        for generator in self.generators:
            if generator.visualize_on != VisualizationStep.BATCH:
                continue
            for result in generator(
                trainer=trainer,
                pl_module=pl_module,
                outputs=outputs,
                batch=batch,
                batch_idx=batch_idx,
                dataloader_idx=dataloader_idx,
            ):
                self._process_result(result, trainer, pl_module, relative_to_dataset=True)

    def on_test_end(self, trainer: Trainer, pl_module: AnomalyModule) -> None:
        """Run stage-end visualizers, then call ``save()`` on every W&B logger.

        Args:
            trainer (Trainer): Lightning trainer.
            pl_module (AnomalyModule): Model being evaluated.

        Raises:
            ValueError: If ``save`` is True and a result has no file name.
        """
        for generator in self.generators:
            if generator.visualize_on != VisualizationStep.STAGE_END:
                continue
            for result in generator(trainer=trainer, pl_module=pl_module):
                self._process_result(result, trainer, pl_module, relative_to_dataset=False)

        # AnomalibWandbLogger needs an explicit save() at the end of the stage
        # (presumably to upload buffered images -- confirm against the logger).
        for trainer_logger in trainer.loggers:
            if isinstance(trainer_logger, AnomalibWandbLogger):
                trainer_logger.save()

    def on_predict_batch_end(
        self,
        trainer: Trainer,
        pl_module: AnomalyModule,
        outputs: STEP_OUTPUT | None,
        batch: Any,  # noqa: ANN401
        batch_idx: int,
        dataloader_idx: int = 0,
    ) -> None:
        """Handle a prediction batch exactly like a test batch."""
        return self.on_test_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx)

    def on_predict_end(self, trainer: Trainer, pl_module: AnomalyModule) -> None:
        """Handle the end of prediction exactly like the end of testing."""
        return self.on_test_end(trainer, pl_module)

    def _process_result(
        self,
        result: GeneratorResult,
        trainer: Trainer,
        pl_module: AnomalyModule,
        relative_to_dataset: bool,
    ) -> None:
        """Save, show and/or log one generator result according to the callback flags.

        Args:
            result (GeneratorResult): Output of a visualizer.
            trainer (Trainer): Lightning trainer.
            pl_module (AnomalyModule): Model whose ``global_step`` is used when logging.
            relative_to_dataset (bool): If True, the saved path is made relative to the
                datamodule's ``<name>/<category>`` prefix. Otherwise ``result.file_name``
                is passed to ``save_image`` unchanged.

        Raises:
            ValueError: If ``save`` is True and ``result.file_name`` is None.
        """
        if self.save:
            if result.file_name is None:
                msg = "``save`` is set to ``True`` but file name is ``None``"
                raise ValueError(msg)
            filename = (
                self._relative_filename(trainer, result.file_name)
                if relative_to_dataset
                else result.file_name
            )
            save_image(image=result.image, root=self.root, filename=filename)
        if self.show:
            show_image(image=result.image, title=str(result.file_name))
        if self.log:
            self._add_to_logger(result, pl_module, trainer)

    @staticmethod
    def _relative_filename(trainer: Trainer, file_name: str | Path) -> str:
        """Return ``file_name`` relative to the datamodule's ``<name>/<category>`` prefix.

        For example ``datasets/MVTec/bottle/test/broken_large/000.png`` becomes
        ``test/broken_large/000.png``. Cases:

        * No datamodule: only the file's base name is returned.
        * Prefix absent, or ``name``/``category`` not exposed: the whole path is returned,
          without a leading ``/``.

        Args:
            trainer (Trainer): Lightning trainer whose ``datamodule`` supplies the prefix.
            file_name (str | Path): Path of the source image.

        Returns:
            str: The path to pass to ``save_image``.
        """
        datamodule = trainer.datamodule
        if datamodule is None:
            return Path(file_name).name
        # Forward slashes make the ``name/category`` marker match Windows paths too.
        path_str = Path(file_name).as_posix()
        name = getattr(datamodule, "name", None)
        category = getattr(datamodule, "category", None)
        if name is not None and category is not None:
            # rpartition(...)[2] == split(...)[-1]: the text after the last occurrence,
            # or the whole string if the marker is absent.
            path_str = path_str.rpartition(f"{name}/{category}")[2]
        return path_str.lstrip("/") or Path(file_name).name

    def _add_to_logger(
        self,
        result: GeneratorResult,
        module: AnomalyModule,
        trainer: Trainer,
    ) -> None:
        """Add the result's image to every image-capable logger attached to the trainer.

        Args:
            result (GeneratorResult): Output from the generators.
            module (AnomalyModule): LightningModule from which the global step is extracted.
            trainer (Trainer): Trainer object.

        Raises:
            ValueError: If ``result.file_name`` is None.
        """
        if result.file_name is None:
            msg = "File name is None"
            raise ValueError(msg)
        filename = result.file_name
        # A Path is logged as ``<parent-dir>_<file-name>``; a string is used verbatim.
        name = f"{filename.parent.name}_{filename.name}" if isinstance(filename, Path) else filename
        # Iterate over every logger directly. Keying by class name would drop all but one
        # logger when several of the same type are attached.
        for trainer_logger in trainer.loggers:
            if isinstance(trainer_logger, ImageLoggerBase):
                trainer_logger.add_image(
                    image=result.image,
                    name=name,
                    global_step=module.global_step,
                )
|