LogSAD / anomalib /callbacks /visualizer.py
zhiqing0205
Add core libraries: anomalib, dinov2, open_clip_local
3de7bf6
raw
history blame
7.28 kB
"""Visualizer Callback.
This is assigned by Anomalib Engine internally.
"""
# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
import logging
from pathlib import Path
from typing import Any, cast
from lightning.pytorch import Callback, Trainer
from lightning.pytorch.utilities.types import STEP_OUTPUT
from anomalib.data.utils.image import save_image, show_image
from anomalib.loggers import AnomalibWandbLogger
from anomalib.loggers.base import ImageLoggerBase
from anomalib.models import AnomalyModule
from anomalib.utils.visualization import (
BaseVisualizer,
GeneratorResult,
VisualizationStep,
)
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
class _VisualizationCallback(Callback):
    """Callback for visualization that is used internally by the Engine.

    Every configured visualizer is called either after each test/predict batch or
    once at the end of the test/predict stage, depending on its ``visualize_on``
    attribute. Each result it yields can be saved to disk, shown, and/or sent to
    the trainer's image loggers.

    Args:
        visualizers (BaseVisualizer | list[BaseVisualizer]):
            Visualizer object(s) that are used for computing the visualizations.
            A single visualizer is wrapped into a one-element list.
        save (bool, optional): Save the image. Defaults to False.
        root (Path | None, optional): The path to save the images. Defaults to None.
        log (bool, optional): Log the images into the loggers. Defaults to False.
        show (bool, optional): Show the images. Defaults to False.

    Example:
        >>> visualizers = [ImageVisualizer(), MetricsVisualizer()]
        >>> visualization_callback = _VisualizationCallback(
        ...     visualizers=visualizers,
        ...     save=True,
        ...     root="results/images"
        ... )

        CLI
        $ anomalib train --model Padim --data MVTec \
            --visualization.visualizers ImageVisualizer \
            --visualization.visualizers+=MetricsVisualizer
        or
        $ anomalib train --model Padim --data MVTec \
            --visualization.visualizers '[ImageVisualizer, MetricsVisualizer]'

    Raises:
        ValueError: In case `root` is None and `save` is True.
    """

    def __init__(
        self,
        visualizers: BaseVisualizer | list[BaseVisualizer],
        save: bool = False,
        root: Path | None = None,
        log: bool = False,
        show: bool = False,
    ) -> None:
        super().__init__()
        if save and root is None:
            msg = "`root` must be provided if save is True"
            raise ValueError(msg)
        self.save = save
        # Coerce to ``Path`` so that a ``str`` root (as in the class example) honours the
        # annotated attribute type; the ``None`` fallback is also needed for mypy.
        self.root: Path = Path(root) if root is not None else Path()
        self.log = log
        self.show = show
        self.generators = visualizers if isinstance(visualizers, list) else [visualizers]

    def on_test_batch_end(
        self,
        trainer: Trainer,
        pl_module: AnomalyModule,
        outputs: STEP_OUTPUT | None,
        batch: Any,  # noqa: ANN401
        batch_idx: int,
        dataloader_idx: int = 0,
    ) -> None:
        """Run the batch-level visualizers and save/show/log their results.

        Only generators whose ``visualize_on`` equals ``VisualizationStep.BATCH`` are called.

        Args:
            trainer (Trainer): Trainer object, forwarded to the generators and used to
                reach the datamodule and the loggers.
            pl_module (AnomalyModule): Module forwarded to the generators and the loggers.
            outputs (STEP_OUTPUT | None): Step outputs, forwarded to the generators.
            batch (Any): Current batch, forwarded to the generators.
            batch_idx (int): Index of the batch, forwarded to the generators.
            dataloader_idx (int, optional): Index of the dataloader. Defaults to 0.

        Raises:
            ValueError: If ``save`` is True and a result has no file name.
        """
        for generator in self.generators:
            if generator.visualize_on != VisualizationStep.BATCH:
                continue
            for result in generator(
                trainer=trainer,
                pl_module=pl_module,
                outputs=outputs,
                batch=batch,
                batch_idx=batch_idx,
                dataloader_idx=dataloader_idx,
            ):
                if self.save:
                    if result.file_name is None:
                        msg = "``save`` is set to ``True`` but file name is ``None``"
                        raise ValueError(msg)
                    # Get the filename to save the image.
                    # The path is split on `<datamodule name>/<category>` and only the part
                    # after it is kept. For example, `MVTec/bottle/000.png` is split on
                    # `MVTec/bottle`, leaving the trailing `000.png` part.
                    if trainer.datamodule is not None:
                        filename = str(result.file_name).split(
                            sep=f"{trainer.datamodule.name}/{trainer.datamodule.category}",
                        )[-1]
                    else:
                        filename = Path(result.file_name).name
                    save_image(image=result.image, root=self.root, filename=filename)
                self._show_and_log(result, pl_module, trainer)

    def on_test_end(self, trainer: Trainer, pl_module: AnomalyModule) -> None:
        """Run the stage-end visualizers and save/show/log their results.

        Only generators whose ``visualize_on`` equals ``VisualizationStep.STAGE_END`` are
        called. Afterwards ``save()`` is called on every ``AnomalibWandbLogger`` attached
        to the trainer.

        Args:
            trainer (Trainer): Trainer object, forwarded to the generators and used to
                reach the loggers.
            pl_module (AnomalyModule): Module forwarded to the generators and the loggers.

        Raises:
            ValueError: If ``save`` is True and a result has no file name.
        """
        for generator in self.generators:
            if generator.visualize_on != VisualizationStep.STAGE_END:
                continue
            for result in generator(trainer=trainer, pl_module=pl_module):
                if self.save:
                    if result.file_name is None:
                        msg = "``save`` is set to ``True`` but file name is ``None``"
                        raise ValueError(msg)
                    save_image(image=result.image, root=self.root, filename=result.file_name)
                self._show_and_log(result, pl_module, trainer)

        # NOTE(review): presumably this flushes the images queued on the W&B logger --
        # confirm against ``AnomalibWandbLogger.save``.
        for trainer_logger in trainer.loggers:
            if isinstance(trainer_logger, AnomalibWandbLogger):
                trainer_logger.save()

    def on_predict_batch_end(
        self,
        trainer: Trainer,
        pl_module: AnomalyModule,
        outputs: STEP_OUTPUT | None,
        batch: Any,  # noqa: ANN401
        batch_idx: int,
        dataloader_idx: int = 0,
    ) -> None:
        """Handle a prediction batch exactly like a test batch (see ``on_test_batch_end``)."""
        return self.on_test_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx)

    def on_predict_end(self, trainer: Trainer, pl_module: AnomalyModule) -> None:
        """Handle the end of prediction exactly like the end of testing (see ``on_test_end``)."""
        return self.on_test_end(trainer, pl_module)

    def _show_and_log(
        self,
        result: GeneratorResult,
        pl_module: AnomalyModule,
        trainer: Trainer,
    ) -> None:
        """Show and/or log a single result, depending on the ``show`` and ``log`` flags.

        Args:
            result (GeneratorResult): Output from the generators.
            pl_module (AnomalyModule): Module passed on to ``_add_to_logger``.
            trainer (Trainer): Trainer object passed on to ``_add_to_logger``.
        """
        if self.show:
            show_image(image=result.image, title=str(result.file_name))
        if self.log:
            self._add_to_logger(result, pl_module, trainer)

    def _add_to_logger(
        self,
        result: GeneratorResult,
        module: AnomalyModule,
        trainer: Trainer,
    ) -> None:
        """Add image to logger.

        The image is sent to every trainer logger that is an ``ImageLoggerBase``. Loggers
        are visited directly rather than through a dict keyed by class name, so several
        loggers of the same class all receive the image.

        Args:
            result (GeneratorResult): Output from the generators.
            module (AnomalyModule): LightningModule from which the global step is extracted.
            trainer (Trainer): Trainer object.

        Raises:
            ValueError: If the result has no file name.
        """
        if result.file_name is None:
            msg = "File name is None"
            raise ValueError(msg)

        file_name = result.file_name
        # The logged name does not depend on the logger, so build it once. For a ``Path``
        # it is "<parent directory name>_<file name>"; any other value is used as is.
        if isinstance(file_name, Path):
            image_name = file_name.parent.name + "_" + file_name.name
        else:
            image_name = file_name

        for trainer_logger in trainer.loggers:
            if not isinstance(trainer_logger, ImageLoggerBase):
                continue
            image_logger: ImageLoggerBase = cast(ImageLoggerBase, trainer_logger)  # placate mypy
            image_logger.add_image(
                image=result.image,
                name=image_name,
                global_step=module.global_step,
            )