|
"""Visualizer Callback. |
|
|
|
This is assigned by Anomalib Engine internally. |
|
""" |
|
|
|
|
|
|
|
|
|
import logging |
|
from pathlib import Path |
|
from typing import Any, cast |
|
|
|
from lightning.pytorch import Callback, Trainer |
|
from lightning.pytorch.utilities.types import STEP_OUTPUT |
|
|
|
from anomalib.data.utils.image import save_image, show_image |
|
from anomalib.loggers import AnomalibWandbLogger |
|
from anomalib.loggers.base import ImageLoggerBase |
|
from anomalib.models import AnomalyModule |
|
from anomalib.utils.visualization import ( |
|
BaseVisualizer, |
|
GeneratorResult, |
|
VisualizationStep, |
|
) |
|
|
|
logger = logging.getLogger(__name__) |
|
|
|
|
|
class _VisualizationCallback(Callback):
    """Callback for visualization that is used internally by the Engine.

    For every result yielded by the configured visualizers the callback can
    save the image to disk, show it, and/or forward it to the trainer's
    image-capable loggers.

    Args:
        visualizers (BaseVisualizer | list[BaseVisualizer]):
            Visualizer objects that are used for computing the visualizations.
            A single visualizer is wrapped into a one-element list.
        save (bool, optional): Save the image. Defaults to False.
        root (Path | None, optional): The path to save the images. Defaults to None.
        log (bool, optional): Log the images into the loggers. Defaults to False.
        show (bool, optional): Show the images. Defaults to False.

    Example:
        >>> visualizers = [ImageVisualizer(), MetricsVisualizer()]
        >>> visualization_callback = _VisualizationCallback(
        ...     visualizers=visualizers,
        ...     save=True,
        ...     root="results/images"
        ... )

        CLI
        $ anomalib train --model Padim --data MVTec \
            --visualization.visualizers ImageVisualizer \
            --visualization.visualizers+=MetricsVisualizer
        or
        $ anomalib train --model Padim --data MVTec \
            --visualization.visualizers '[ImageVisualizer, MetricsVisualizer]'

    Raises:
        ValueError: In case `root` is None and `save` is True.
    """

    def __init__(
        self,
        visualizers: BaseVisualizer | list[BaseVisualizer],
        save: bool = False,
        root: Path | None = None,
        log: bool = False,
        show: bool = False,
    ) -> None:
        self.save = save
        if save and root is None:
            msg = "`root` must be provided if save is True"
            raise ValueError(msg)
        # When nothing is saved, ``root`` is never used; an empty Path keeps the attribute typed.
        self.root: Path = root if root is not None else Path()
        self.log = log
        self.show = show
        self.generators = visualizers if isinstance(visualizers, list) else [visualizers]

    def on_test_batch_end(
        self,
        trainer: Trainer,
        pl_module: AnomalyModule,
        outputs: STEP_OUTPUT | None,
        batch: Any,
        batch_idx: int,
        dataloader_idx: int = 0,
    ) -> None:
        """Run every batch-level visualizer and save/show/log each result.

        Args:
            trainer (Trainer): Trainer object; its datamodule and loggers are read.
            pl_module (AnomalyModule): Module forwarded to the visualizers and the loggers.
            outputs (STEP_OUTPUT | None): Step outputs forwarded to the visualizers.
            batch (Any): Current batch forwarded to the visualizers.
            batch_idx (int): Index of the current batch.
            dataloader_idx (int, optional): Index of the dataloader. Defaults to 0.

        Raises:
            ValueError: If saving is enabled and a result has no file name.
        """
        for generator in self.generators:
            if generator.visualize_on != VisualizationStep.BATCH:
                continue
            for result in generator(
                trainer=trainer,
                pl_module=pl_module,
                outputs=outputs,
                batch=batch,
                batch_idx=batch_idx,
                dataloader_idx=dataloader_idx,
            ):
                if self.save:
                    if result.file_name is None:
                        msg = "``save`` is set to ``True`` but file name is ``None``"
                        raise ValueError(msg)

                    # Keep only the part of the path that follows
                    # "<datamodule name>/<category>"; if that marker does not occur,
                    # ``split`` returns the whole string unchanged. Without a
                    # datamodule, fall back to the bare file name.
                    # NOTE(review): presumably this lets ``save_image`` recreate the
                    # dataset's sub-folder layout under ``self.root`` - confirm
                    # against ``save_image``, since the remainder may start with "/".
                    if trainer.datamodule is not None:
                        filename = str(result.file_name).split(
                            sep=f"{trainer.datamodule.name}/{trainer.datamodule.category}",
                        )[-1]
                    else:
                        filename = Path(result.file_name).name
                    save_image(image=result.image, root=self.root, filename=filename)
                self._show_and_log(result, pl_module, trainer)

    def on_test_end(self, trainer: Trainer, pl_module: AnomalyModule) -> None:
        """Run every stage-end visualizer, then save any W&B loggers.

        Args:
            trainer (Trainer): Trainer object; its loggers are read.
            pl_module (AnomalyModule): Module forwarded to the visualizers and the loggers.

        Raises:
            ValueError: If saving is enabled and a result has no file name.
        """
        for generator in self.generators:
            if generator.visualize_on != VisualizationStep.STAGE_END:
                continue
            for result in generator(trainer=trainer, pl_module=pl_module):
                if self.save:
                    if result.file_name is None:
                        msg = "``save`` is set to ``True`` but file name is ``None``"
                        raise ValueError(msg)
                    save_image(image=result.image, root=self.root, filename=result.file_name)
                self._show_and_log(result, pl_module, trainer)

        # Named ``trainer_logger`` so the module-level ``logger`` is not shadowed.
        for trainer_logger in trainer.loggers:
            if isinstance(trainer_logger, AnomalibWandbLogger):
                trainer_logger.save()

    def on_predict_batch_end(
        self,
        trainer: Trainer,
        pl_module: AnomalyModule,
        outputs: STEP_OUTPUT | None,
        batch: Any,
        batch_idx: int,
        dataloader_idx: int = 0,
    ) -> None:
        """Handle a prediction batch exactly like a test batch."""
        return self.on_test_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx)

    def on_predict_end(self, trainer: Trainer, pl_module: AnomalyModule) -> None:
        """Handle the end of prediction exactly like the end of testing."""
        return self.on_test_end(trainer, pl_module)

    def _show_and_log(
        self,
        result: GeneratorResult,
        pl_module: AnomalyModule,
        trainer: Trainer,
    ) -> None:
        """Show and/or log a single result according to the callback flags.

        Args:
            result (GeneratorResult): Output from the generators.
            pl_module (AnomalyModule): Module passed on to ``_add_to_logger``.
            trainer (Trainer): Trainer object passed on to ``_add_to_logger``.
        """
        if self.show:
            show_image(image=result.image, title=str(result.file_name))
        if self.log:
            self._add_to_logger(result, pl_module, trainer)

    def _add_to_logger(
        self,
        result: GeneratorResult,
        module: AnomalyModule,
        trainer: Trainer,
    ) -> None:
        """Add image to every image-capable logger attached to the trainer.

        Args:
            result (GeneratorResult): Output from the generators.
            module (AnomalyModule): LightningModule from which the global step is extracted.
            trainer (Trainer): Trainer object.

        Raises:
            ValueError: If the result has no file name.
        """
        if result.file_name is None:
            msg = "File name is None"
            raise ValueError(msg)
        filename = result.file_name

        # The logged name does not depend on the logger, so build it once.
        # For a Path it is "<parent folder name>_<file name>"; any other value is used as is.
        name = f"{filename.parent.name}_{filename.name}" if isinstance(filename, Path) else filename

        # Iterate the loggers directly instead of keying them by a name derived
        # from their class: such a mapping keeps only one logger per class, so
        # additional loggers of the same class would never receive the image.
        for candidate in trainer.loggers:
            if isinstance(candidate, ImageLoggerBase):
                image_logger: ImageLoggerBase = cast(ImageLoggerBase, candidate)
                image_logger.add_image(
                    image=result.image,
                    name=name,
                    global_step=module.global_step,
                )
|
|