File size: 2,848 Bytes
3de7bf6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
"""Normalization callback utils."""

# Copyright (C) 2022 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

import importlib

from lightning.pytorch import Callback
from omegaconf import DictConfig

from anomalib.utils.normalization import NormalizationMethod
from anomalib.utils.types import NORMALIZATION

from .min_max_normalization import _MinMaxNormalizationCallback


def get_normalization_callback(
    normalization_method: NORMALIZATION = NormalizationMethod.MIN_MAX,
) -> Callback | None:
    """Return the normalization callback described by ``normalization_method``.

    The argument is interpreted according to its type:

    * ``NormalizationMethod`` or ``str``: the value is converted to a ``NormalizationMethod`` and the matching
      built-in callback is created (``None`` is returned for ``NormalizationMethod.NONE``).
    * ``Callback``: the instance is returned as is.
    * ``DictConfig`` or plain ``dict``: the mapping must provide ``class_path`` and ``init_args``; the class is
      imported and instantiated with ``init_args`` as keyword arguments, e.g.::

          normalization_method:
              class_path: MinMaxNormalizer
              init_args:
                  -
                  -

    Args:
        normalization_method: Method, callback instance or class configuration describing the normalizer.
            Defaults to ``NormalizationMethod.MIN_MAX``.

    Returns:
        The normalization callback, or ``None`` when normalization is disabled.

    Raises:
        TypeError: If ``normalization_method`` is of none of the supported types.
        ValueError: If a ``NormalizationMethod``/``str`` value does not name a supported method.

    Example:
        Each of the following calls is accepted:

        >>> normalizer = get_normalization_callback(NormalizationMethod.MIN_MAX)
        >>> normalizer = get_normalization_callback("min_max")
        >>> normalizer = get_normalization_callback({"class_path": "MinMaxNormalizationCallback", "init_args": {}})
        >>> normalizer = get_normalization_callback(MinMaxNormalizationCallback())
    """
    if isinstance(normalization_method, NormalizationMethod | str):
        return _get_normalizer_from_method(NormalizationMethod(normalization_method))
    if isinstance(normalization_method, Callback):
        return normalization_method
    if isinstance(normalization_method, dict):
        # Plain dicts are documented as valid input; wrap them so the config parser can use
        # attribute-style access exactly as it does for a DictConfig.
        return _parse_normalizer_config(DictConfig(normalization_method))
    if isinstance(normalization_method, DictConfig):
        return _parse_normalizer_config(normalization_method)

    msg = f"Unknown normalizer type {normalization_method}"
    raise TypeError(msg)


def _get_normalizer_from_method(normalization_method: NormalizationMethod | str) -> Callback | None:
    """Create the built-in callback that corresponds to ``normalization_method``.

    Args:
        normalization_method: Normalization method to look up.

    Returns:
        ``None`` for ``NormalizationMethod.NONE``, or a new ``_MinMaxNormalizationCallback`` for
        ``NormalizationMethod.MIN_MAX``.

    Raises:
        ValueError: If ``normalization_method`` matches neither of the methods above.
    """
    if normalization_method == NormalizationMethod.NONE:
        return None
    if normalization_method == NormalizationMethod.MIN_MAX:
        return _MinMaxNormalizationCallback()

    msg = f"Unknown normalization method {normalization_method}"
    raise ValueError(msg)


def _parse_normalizer_config(normalization_method: DictConfig) -> Callback:
    """Import and instantiate the class described by a ``class_path``/``init_args`` config.

    Args:
        normalization_method: Config providing ``class_path`` and, optionally, ``init_args``.
            A ``class_path`` without any dot is looked up in ``anomalib.utils.callbacks.normalization``;
            otherwise everything before the last dot is imported as the module and the remainder is
            used as the class name.

    Returns:
        The object obtained by calling the resolved class with ``init_args`` as keyword arguments.
    """
    class_path = normalization_method.class_path
    # ``init_args`` may be absent or left empty (``None``) in a config; both mean "no arguments".
    init_args = normalization_method.get("init_args") or {}

    # Split on the last dot only; ``separator`` is empty when ``class_path`` is a bare class name.
    module_path, separator, class_name = class_path.rpartition(".")
    if not separator:
        module_path = "anomalib.utils.callbacks.normalization"

    module = importlib.import_module(module_path)
    class_ = getattr(module, class_name)
    return class_(**init_args)