File size: 1,910 Bytes
3de7bf6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
"""Callbacks for Anomalib models."""

# Copyright (C) 2022 Intel Corporation
# SPDX-License-Identifier: Apache-2.0


import logging
from importlib import import_module
from pathlib import Path

import yaml
from jsonargparse import Namespace
from lightning.pytorch.callbacks import Callback
from omegaconf import DictConfig, ListConfig, OmegaConf

from .checkpoint import ModelCheckpoint
from .graph import GraphLogger
from .model_loader import LoadModelCallback
from .tiler_configuration import TilerConfigurationCallback
from .timer import TimerCallback

# Public names re-exported by this package (what ``from ... import *`` exposes).
__all__ = [
    "ModelCheckpoint",
    "GraphLogger",
    "LoadModelCallback",
    "TilerConfigurationCallback",
    "TimerCallback",
]


# Module-level logger; ``get_callbacks`` reports progress through it.
logger = logging.getLogger(__name__)


def _get_ckpt_path(config: DictConfig | ListConfig | Namespace) -> str | None:
    """Return the checkpoint path to load model weights from, if one is configured.

    The path is looked up as a top-level ``ckpt_path`` entry first and, as a
    fallback, under the ``trainer`` section of the config. Both lookups are
    guarded with membership tests so a missing key never raises.

    Args:
        config (DictConfig | ListConfig | Namespace): Model config.

    Returns:
        str | None: The configured checkpoint path, or ``None`` when it is absent or unset.
    """
    if "ckpt_path" in config and config.ckpt_path is not None:
        return config.ckpt_path
    # Fallback: some configs carry the checkpoint path inside the trainer section.
    if "trainer" in config and "ckpt_path" in config.trainer:
        return config.trainer.ckpt_path
    return None


def _to_plain_dict(settings: DictConfig | ListConfig | Namespace) -> dict:
    """Convert a (sub-)config section into plain Python containers.

    Args:
        settings (DictConfig | ListConfig | Namespace): Config section to convert.

    Returns:
        dict: The section as plain nested dicts/lists/scalars.
    """
    # OmegaConf.to_yaml only accepts dict/list/OmegaConf/structured-config inputs and
    # raises ValueError for a jsonargparse Namespace, so convert that case natively.
    if isinstance(settings, Namespace):
        return settings.as_dict()
    # Round-trip through YAML (unresolved) to obtain plain containers, as before.
    return yaml.safe_load(OmegaConf.to_yaml(settings))


def get_callbacks(config: DictConfig | ListConfig | Namespace) -> list[Callback]:
    """Return base callbacks for all the lightning models.

    A ``LoadModelCallback`` is added when a checkpoint path is configured, and an
    NNCF callback is added when ``optimization.nncf.apply`` is truthy.

    Args:
        config (DictConfig | ListConfig | Namespace): Model config.

    Returns:
        list[Callback]: List of callbacks.
    """
    logger.info("Loading the callbacks")

    callbacks: list[Callback] = []

    ckpt_path = _get_ckpt_path(config)
    if ckpt_path is not None:
        callbacks.append(LoadModelCallback(ckpt_path))

    if "optimization" in config and "nncf" in config.optimization and config.optimization.nncf.apply:
        # NNCF wraps torch's jit which conflicts with kornia's jit calls.
        # Hence, nncf is imported only when required
        nncf_module = import_module("anomalib.utils.callbacks.nncf.callback")
        nncf_callback = nncf_module.NNCFCallback
        nncf_config = _to_plain_dict(config.optimization.nncf)
        callbacks.append(
            nncf_callback(
                config=nncf_config,
                # Compressed artifacts go to a "compressed" subfolder of the project path.
                export_dir=str(Path(config.project.path) / "compressed"),
            ),
        )

    return callbacks