|
"""Callbacks for Anomalib models.""" |
|
|
|
|
|
|
|
|
|
|
|
import logging |
|
from importlib import import_module |
|
from pathlib import Path |
|
|
|
import yaml |
|
from jsonargparse import Namespace |
|
from lightning.pytorch.callbacks import Callback |
|
from omegaconf import DictConfig, ListConfig, OmegaConf |
|
|
|
from .checkpoint import ModelCheckpoint |
|
from .graph import GraphLogger |
|
from .model_loader import LoadModelCallback |
|
from .tiler_configuration import TilerConfigurationCallback |
|
from .timer import TimerCallback |
|
|
|
__all__ = [ |
|
"ModelCheckpoint", |
|
"GraphLogger", |
|
"LoadModelCallback", |
|
"TilerConfigurationCallback", |
|
"TimerCallback", |
|
] |
|
|
|
|
|
logger = logging.getLogger(__name__) |
|
|
|
|
|
def get_callbacks(config: DictConfig | ListConfig | Namespace) -> list[Callback]: |
|
"""Return base callbacks for all the lightning models. |
|
|
|
Args: |
|
config (DictConfig | ListConfig | Namespace): Model config |
|
|
|
Return: |
|
(list[Callback]): List of callbacks. |
|
""" |
|
logger.info("Loading the callbacks") |
|
|
|
callbacks: list[Callback] = [] |
|
|
|
if "ckpt_path" in config.trainer and config.ckpt_path is not None: |
|
load_model = LoadModelCallback(config.ckpt_path) |
|
callbacks.append(load_model) |
|
|
|
if "optimization" in config and "nncf" in config.optimization and config.optimization.nncf.apply: |
|
|
|
|
|
nncf_module = import_module("anomalib.utils.callbacks.nncf.callback") |
|
nncf_callback = nncf_module.NNCFCallback |
|
nncf_config = yaml.safe_load(OmegaConf.to_yaml(config.optimization.nncf)) |
|
callbacks.append( |
|
nncf_callback( |
|
config=nncf_config, |
|
export_dir=str(Path(config.project.path) / "compressed"), |
|
), |
|
) |
|
|
|
return callbacks |
|
|