File size: 2,848 Bytes
3de7bf6 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 |
"""Normalization callback utils."""
# Copyright (C) 2022 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
import importlib
from lightning.pytorch import Callback
from omegaconf import DictConfig
from anomalib.utils.normalization import NormalizationMethod
from anomalib.utils.types import NORMALIZATION
from .min_max_normalization import _MinMaxNormalizationCallback
def get_normalization_callback(
    normalization_method: NORMALIZATION = NormalizationMethod.MIN_MAX,
) -> Callback | None:
    """Return the normalization callback described by ``normalization_method``.

    The argument is resolved as follows:

    - ``NormalizationMethod`` or ``str``: converted with ``NormalizationMethod(...)``
      and mapped to a new callback (``None`` when normalization is disabled).
    - ``Callback`` instance: returned as is.
    - ``DictConfig`` or plain ``dict``: must contain ``class_path`` and may contain
      ``init_args``. The class is imported and instantiated with ``init_args`` as
      keyword arguments. A bare class name (no dots) is looked up in
      ``anomalib.utils.callbacks.normalization``.

    Example config (YAML)::

        normalization_method:
          class_path: my_package.callbacks.MyNormalizationCallback
          init_args:
            some_arg: 1

    Example:
        >>> normalizer = get_normalization_callback(NormalizationMethod.MIN_MAX)
        or
        >>> normalizer = get_normalization_callback("min_max")
        or
        >>> normalizer = get_normalization_callback(
        ...     {"class_path": "my_package.callbacks.MyNormalizationCallback", "init_args": {}},
        ... )
        or
        >>> normalizer = get_normalization_callback(existing_callback)

    Args:
        normalization_method: Method name/enum, ready-made callback, or config mapping.
            Defaults to ``NormalizationMethod.MIN_MAX``.

    Returns:
        The normalization callback, or ``None`` if the method maps to no callback.

    Raises:
        TypeError: If ``normalization_method`` is none of the supported types.
        ValueError: If the method is not supported. An unrecognized string is expected
            to fail already in the ``NormalizationMethod(...)`` conversion.
    """
    if isinstance(normalization_method, NormalizationMethod | str):
        return _get_normalizer_from_method(NormalizationMethod(normalization_method))
    if isinstance(normalization_method, Callback):
        return normalization_method
    if isinstance(normalization_method, DictConfig):
        return _parse_normalizer_config(normalization_method)
    if isinstance(normalization_method, dict):
        # Plain dicts (as in the dict example above) are wrapped in a DictConfig so they follow the
        # same parsing path as configs coming from OmegaConf. NOTE: values must therefore be
        # OmegaConf-compatible (primitives, lists, nested dicts).
        return _parse_normalizer_config(DictConfig(normalization_method))
    msg = f"Unknown normalizer type {normalization_method}"
    raise TypeError(msg)
def _get_normalizer_from_method(normalization_method: NormalizationMethod | str) -> Callback | None:
if normalization_method == NormalizationMethod.NONE:
normalizer = None
elif normalization_method == NormalizationMethod.MIN_MAX:
normalizer = _MinMaxNormalizationCallback()
else:
msg = f"Unknown normalization method {normalization_method}"
raise ValueError(msg)
return normalizer
def _parse_normalizer_config(normalization_method: DictConfig) -> Callback:
class_path = normalization_method.class_path
init_args = normalization_method.init_args
if len(class_path.split(".")) == 1:
module_path = "anomalib.utils.callbacks.normalization"
else:
module_path = ".".join(class_path.split(".")[:-1])
class_path = class_path.split(".")[-1]
module = importlib.import_module(module_path)
class_ = getattr(module, class_path)
return class_(**init_args)
|