|
"""Utils for NNCf optimization.""" |
|
|
|
|
|
|
|
|
|
|
|
import logging |
|
from copy import copy |
|
from typing import TYPE_CHECKING, Any |
|
|
|
import torch |
|
from nncf import NNCFConfig |
|
from nncf.api.compression import CompressionAlgorithmController |
|
from nncf.torch import create_compressed_model, load_state, register_default_init_args |
|
from nncf.torch.initialization import PTInitializingDataLoader |
|
from nncf.torch.nncf_network import NNCFNetwork |
|
from torch import nn |
|
from torch.utils.data.dataloader import DataLoader |
|
|
|
if TYPE_CHECKING: |
|
from collections.abc import Iterator |
|
|
|
|
|
# Module-wide logger for the NNCF helpers; it is registered under a fixed,
# human-readable name rather than the module's ``__name__``.
logger = logging.getLogger("NNCF compression")
|
|
|
|
|
class InitLoader(PTInitializingDataLoader):
    """Data loader wrapper handed to NNCF for initialization with unsupervised training algorithms.

    Iterating over an instance yields only the ``"image"`` entry of every item
    produced by the wrapped dataloader; no ground-truth target is provided.
    """

    def __init__(self, data_loader: DataLoader) -> None:
        super().__init__(data_loader)
        # Annotation only: the iterator itself is created lazily in ``__iter__``.
        self._data_loader_iter: Iterator

    def __iter__(self) -> "InitLoader":
        """Begin a new pass over the wrapped dataloader and return ``self``."""
        # ``self._data_loader`` is presumably stored by the base-class constructor — TODO confirm.
        self._data_loader_iter = iter(self._data_loader)
        return self

    def __next__(self) -> torch.Tensor:
        """Advance the underlying iterator and return the ``"image"`` entry of the item."""
        return next(self._data_loader_iter)["image"]

    def get_inputs(self, dataloader_output: dict[str, str | torch.Tensor]) -> tuple[tuple, dict]:
        """Build the model call arguments from one dataloader output.

        Returns:
            tuple[tuple, dict]: ``(dataloader_output,)`` as positional arguments and an
                empty dict of keyword arguments, i.e. the model call to be made during
                the initialization process.
        """
        positional_args = (dataloader_output,)
        keyword_args: dict = {}
        return positional_args, keyword_args

    def get_target(self, _):
        """Return the ground-truth structure for the loss criterion.

        Placeholder implementation: the dataloader output is ignored.

        Returns:
            None
        """
        return None
|
|
|
|
|
def wrap_nncf_model(
    model: nn.Module,
    config: dict,
    dataloader: DataLoader,
    init_state_dict: dict,
) -> tuple[CompressionAlgorithmController, NNCFNetwork]:
    """Wrap model by NNCF.

    :param model: Anomalib model.
    :param config: NNCF config as a plain dict; it is converted with ``NNCFConfig.from_dict``.
    :param dataloader: Dataloader for initialization of NNCF model. When it evaluates
        to false, no default initialization arguments are registered.
    :param init_state_dict: Checkpoint dict whose ``"model"`` and ``"compression_state"``
        entries (when present) are used to restore the compressed model. When it
        evaluates to false, nothing is restored.
    :return: compression controller, compressed model
    """
    nncf_config = NNCFConfig.from_dict(config)

    # Both sources of quantizer initialization are missing: warn but carry on.
    if not (dataloader or init_state_dict):
        logger.warning(
            "Either dataloader or NNCF pre-trained "
            "model checkpoint should be set. Without this, "
            "quantizers will not be initialized",
        )

    # Missing entries (or a missing checkpoint altogether) resolve to ``None``.
    checkpoint = init_state_dict or {}
    resuming_state_dict = checkpoint.get("model")
    compression_state = checkpoint.get("compression_state")

    if dataloader:
        nncf_config = register_default_init_args(nncf_config, InitLoader(dataloader))

    controller, compressed_model = create_compressed_model(
        model=model,
        config=nncf_config,
        dump_graphs=False,
        compression_state=compression_state,
    )

    # Weights are loaded only after the compressed model has been built.
    if resuming_state_dict:
        load_state(compressed_model, resuming_state_dict, is_resume=True)

    return controller, compressed_model
|
|
|
|
|
def is_state_nncf(state: dict) -> bool:
    """Tell whether ``state`` was produced by an NNCF-compressed model.

    The answer is the truthiness of ``state["meta"]["nncf_enable_compression"]``;
    a missing ``"meta"`` entry or a missing flag both count as ``False``.
    """
    meta = state.get("meta", {})
    compression_flag = meta.get("nncf_enable_compression", False)
    return bool(compression_flag)
|
|
|
|
|
def compose_nncf_config(nncf_config: dict, enabled_options: list[str]) -> dict:
    """Compose NNCF config by selected options.

    The ``"base"`` part is taken as the starting point and every enabled part is
    merged into it, in the order given by ``"order_of_parts"``.

    :param nncf_config: Optimisation config. It must contain a ``"base"`` part, may
        contain an ``"order_of_parts"`` list, and must contain one entry for every
        part that ends up being merged.
    :param enabled_options: Names of the parts to merge into ``"base"``. They are only
        consulted when ``"order_of_parts"`` is present; without it the ``"base"`` part
        is returned unchanged.
    :return: config. When no part is merged this is the ``"base"`` object itself,
        not a copy.
    :raises TypeError: if ``"order_of_parts"`` is not a list, or if two parts hold
        values of incompatible types (the message names the part and the merge order).
    :raises ValueError: if an enabled option is not listed in ``"order_of_parts"``.
    :raises KeyError: if ``"base"`` or a selected part is missing from the config.
    :raises RuntimeError: if the merge step fails with an ``AssertionError``.
    """

    def _describe_failure(part: str, parts_order: list, error: Exception) -> str:
        # One message layout for every kind of merge failure.
        return (
            f"Error during merging the parts of nncf configs:\n"
            f"the current part={part}, "
            f"the order of merging parts into base is {parts_order}.\n"
            f"The error is:\n{error}"
        )

    optimisation_parts = nncf_config
    optimisation_parts_to_choose = []
    if "order_of_parts" in optimisation_parts:
        order_of_parts = optimisation_parts["order_of_parts"]
        if not isinstance(order_of_parts, list):
            msg = 'The field "order_of_parts" in optimization config should be a list'
            raise TypeError(msg)

        for part in enabled_options:
            if part not in order_of_parts:
                msg = f"The part {part} is selected, but it is absent in order_of_parts={order_of_parts}"
                raise ValueError(msg)

        # The merge order follows the config, not the order the options were passed in.
        optimisation_parts_to_choose = [part for part in order_of_parts if part in enabled_options]

    if "base" not in optimisation_parts:
        msg = 'Error: the optimisation config does not contain the "base" part'
        raise KeyError(msg)
    nncf_config_part = optimisation_parts["base"]

    for part in optimisation_parts_to_choose:
        if part not in optimisation_parts:
            msg = f'Error: the optimisation config does not contain the part "{part}"'
            raise KeyError(msg)
        optimisation_part_dict = optimisation_parts[part]
        try:
            nncf_config_part = merge_dicts_and_lists_b_into_a(nncf_config_part, optimisation_part_dict)
        except AssertionError as cur_error:
            err_descr = _describe_failure(part, optimisation_parts_to_choose, cur_error)
            raise RuntimeError(err_descr) from None
        except TypeError as cur_error:
            # The merge helper reports incompatible structures with TypeError, which the
            # AssertionError handler above never sees. Keep the exception type callers
            # already observe, but attach the part/order context to it.
            err_descr = _describe_failure(part, optimisation_parts_to_choose, cur_error)
            raise TypeError(err_descr) from cur_error

    return nncf_config_part
|
|
|
|
|
def merge_dicts_and_lists_b_into_a(
    a: dict[Any, Any] | list[Any],
    b: dict[Any, Any] | list[Any],
) -> dict[Any, Any] | list[Any]:
    """Merge ``b`` into ``a`` and return the merged structure.

    Public entry point: it starts the recursive merge at the root of both
    structures, using an empty key path.

    Args:
        a (dict[Any, Any] | list[Any]): First dict or list.
        b (dict[Any, Any] | list[Any]): Second dict or list.

    Returns:
        dict[Any, Any] | list[Any]: Merged dict or list.
    """
    merged = _merge_dicts_and_lists_b_into_a(a, b, "")
    return merged
|
|
|
|
|
def _merge_dicts_and_lists_b_into_a( |
|
a: dict[Any, Any] | list[Any], |
|
b: dict[Any, Any] | list[Any], |
|
cur_key: int | str | None = None, |
|
) -> dict[Any, Any] | list[Any]: |
|
"""Merge dict configs. |
|
|
|
* works with usual dicts and lists and derived types |
|
* supports merging of lists (by concatenating the lists) |
|
* makes recursive merging for dict + dict case |
|
* overwrites when merging scalar into scalar |
|
Note that we merge b into a (whereas Config makes merge a into b), |
|
since otherwise the order of list merging is counter-intuitive. |
|
|
|
Args: |
|
a (dict[Any, Any] | list[Any]): First dict or list. |
|
b (dict[Any, Any] | list[Any]): Second dict or list. |
|
cur_key (int | str | None, optional): key for current level of recursion. Defaults to None. |
|
|
|
Returns: |
|
dict[Any, Any] | list[Any]: Merged dict or list. |
|
""" |
|
|
|
def _err_str(_a: dict | list, _b: dict | list, _key: int | str | None = None) -> str: |
|
_key_str = "of whole structures" if _key is None else f"during merging for key=`{_key}`" |
|
return ( |
|
f"Error in merging parts of config: different types {_key_str}," |
|
f" type(a) = {type(_a)}," |
|
f" type(b) = {type(_b)}" |
|
) |
|
|
|
if not (isinstance(a, dict | list)): |
|
msg = f"Can merge only dicts and lists, whereas type(a)={type(a)}" |
|
raise TypeError(msg) |
|
|
|
if not (isinstance(b, dict | list)): |
|
raise TypeError(_err_str(a, b, cur_key)) |
|
|
|
if (isinstance(a, list) and not isinstance(b, list)) or (isinstance(b, list) and not isinstance(a, list)): |
|
raise TypeError(_err_str(a, b, cur_key)) |
|
|
|
if isinstance(a, list) and isinstance(b, list): |
|
|
|
return a + b |
|
|
|
a = copy(a) |
|
for k in b: |
|
if k not in a: |
|
a[k] = copy(b[k]) |
|
continue |
|
new_cur_key = str(cur_key) + "." + k if cur_key else k |
|
if isinstance(a[k], dict | list): |
|
a[k] = _merge_dicts_and_lists_b_into_a(a[k], b[k], new_cur_key) |
|
continue |
|
|
|
if any(isinstance(b[k], t) for t in [dict, list]): |
|
raise TypeError(_err_str(a[k], b[k], new_cur_key)) |
|
|
|
|
|
a[k] = b[k] |
|
return a |
|
|