File size: 8,560 Bytes
3de7bf6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
"""Utils for NNCf optimization."""

# Copyright (C) 2022 Intel Corporation
# SPDX-License-Identifier: Apache-2.0


import logging
from copy import copy
from typing import TYPE_CHECKING, Any

import torch
from nncf import NNCFConfig
from nncf.api.compression import CompressionAlgorithmController
from nncf.torch import create_compressed_model, load_state, register_default_init_args
from nncf.torch.initialization import PTInitializingDataLoader
from nncf.torch.nncf_network import NNCFNetwork
from torch import nn
from torch.utils.data.dataloader import DataLoader

if TYPE_CHECKING:
    from collections.abc import Iterator


# Module-wide logger, registered under the fixed name "NNCF compression"
# (not the dotted module path).
logger = logging.getLogger(name="NNCF compression")


class InitLoader(PTInitializingDataLoader):
    """Data loader adapter used for NNCF initialization without ground-truth targets.

    Iterating over an instance yields only the ``"image"`` entry of every item
    produced by the wrapped data loader, and ``get_target`` always yields ``None``.
    """

    def __init__(self, data_loader: DataLoader) -> None:
        super().__init__(data_loader)
        # Only declared here; the iterator is bound on each call to ``__iter__``.
        self._data_loader_iter: Iterator

    def __iter__(self) -> "InitLoader":
        """Start a new pass over the wrapped data loader and return ``self``."""
        # ``self._data_loader`` is presumably stored by the parent constructor -- confirm in NNCF.
        self._data_loader_iter = iter(self._data_loader)
        return self

    def __next__(self) -> torch.Tensor:
        """Fetch the next item and return its ``"image"`` entry."""
        return next(self._data_loader_iter)["image"]

    def get_inputs(self, dataloader_output: dict[str, str | torch.Tensor]) -> tuple[tuple, dict]:
        """Build the model call arguments from one loader output.

        Returns:
            tuple[tuple, dict]: Positional arguments holding just ``dataloader_output``
            and an empty keyword-argument dict.
        """
        positional = (dataloader_output,)
        keyword: dict = {}
        return positional, keyword

    def get_target(self, _):  # noqa: ANN001, ANN201
        """Return the ground truth for the loss criterion.

        Placeholder implementation: the argument is ignored because no target is used.

        Returns:
            None
        """
        return None


def wrap_nncf_model(
    model: nn.Module,
    config: dict,
    dataloader: DataLoader,
    init_state_dict: dict,
) -> tuple[CompressionAlgorithmController, NNCFNetwork]:
    """Wrap model by NNCF.

    :param model: Model to be compressed.
    :param config: NNCF config as a plain dict; converted with ``NNCFConfig.from_dict``.
    :param dataloader: Dataloader for initialization of NNCF model. When falsy, no
        default init args are registered.
    :param init_state_dict: Optional checkpoint dict. Its ``"model"`` entry (if any) is
        loaded into the compressed model and its ``"compression_state"`` entry (if any)
        is passed to ``create_compressed_model``.
    :return: compression controller, compressed model
    """
    nncf_config = NNCFConfig.from_dict(config)

    # Nothing to initialize the quantizers from: warn, but still build the model.
    if not (dataloader or init_state_dict):
        logger.warning(
            "Either dataloader or NNCF pre-trained "
            "model checkpoint should be set. Without this, "
            "quantizers will not be initialized",
        )

    model_weights, compression_state = None, None
    if init_state_dict:
        compression_state = init_state_dict.get("compression_state")
        model_weights = init_state_dict.get("model")

    if dataloader:
        nncf_config = register_default_init_args(nncf_config, InitLoader(dataloader))

    controller, compressed_model = create_compressed_model(
        model=model,
        config=nncf_config,
        dump_graphs=False,
        compression_state=compression_state,
    )

    # Restore previously saved weights only when the checkpoint provided them.
    if model_weights:
        load_state(compressed_model, model_weights, is_resume=True)

    return controller, compressed_model


def is_state_nncf(state: dict) -> bool:
    """Tell whether ``state`` is marked as coming from an NNCF-compressed model.

    The marker is the truthiness of ``state["meta"]["nncf_enable_compression"]``;
    a missing ``"meta"`` entry or a missing flag counts as ``False``.
    """
    meta = state.get("meta", {})
    compression_flag = meta.get("nncf_enable_compression", False)
    return bool(compression_flag)


class _NNCFConfigMergeError(RuntimeError, TypeError):
    """Raised when two parts of an NNCF config cannot be merged.

    It derives from ``RuntimeError`` (the type ``compose_nncf_config`` was written to
    raise) and from ``TypeError`` (the type the merge helper raises, which is what
    callers have observed so far), so either ``except`` clause keeps working.
    """


def compose_nncf_config(nncf_config: dict, enabled_options: list[str]) -> dict:
    """Compose NNCF config by selected options.

    The ``"base"`` part is taken as the starting point and every enabled part is merged
    into it, in the order given by ``nncf_config["order_of_parts"]``. When
    ``"order_of_parts"`` is absent no part is merged and ``"base"`` is returned as is.

    :param nncf_config: Optimisation config holding a ``"base"`` part, optionally
        ``"order_of_parts"`` (a list of part names) and one entry per part.
    :param enabled_options: Names of the parts to merge into ``"base"``.
    :return: config
    :raises TypeError: if ``"order_of_parts"`` is present but is not a list.
    :raises ValueError: if an enabled option is not listed in ``"order_of_parts"``.
    :raises KeyError: if ``"base"`` or a selected part is missing from the config.
    :raises RuntimeError: if a selected part cannot be merged; the raised error is also
        a ``TypeError``.
    """
    optimisation_parts = nncf_config
    optimisation_parts_to_choose = []
    if "order_of_parts" in optimisation_parts:
        # The result of applying the changes from optimisation parts
        # may depend on the order of applying the changes
        # (e.g. if for nncf_quantization it is sufficient to have `total_epochs=2`,
        #  but for sparsity it is required `total_epochs=50`)
        # So, user can define `order_of_parts` in the optimisation_config
        # to specify the order of applying the parts.
        order_of_parts = optimisation_parts["order_of_parts"]
        if not isinstance(order_of_parts, list):
            msg = 'The field "order_of_parts" in optimization config should be a list'
            raise TypeError(msg)

        for part in enabled_options:
            if part not in order_of_parts:
                msg = f"The part {part} is selected, but it is absent in order_of_parts={order_of_parts}"
                raise ValueError(msg)

        optimisation_parts_to_choose = [part for part in order_of_parts if part in enabled_options]

    if "base" not in optimisation_parts:
        msg = 'Error: the optimisation config does not contain the "base" part'
        raise KeyError(msg)
    nncf_config_part = optimisation_parts["base"]

    for part in optimisation_parts_to_choose:
        if part not in optimisation_parts:
            msg = f'Error: the optimisation config does not contain the part "{part}"'
            raise KeyError(msg)
        optimisation_part_dict = optimisation_parts[part]
        try:
            nncf_config_part = merge_dicts_and_lists_b_into_a(nncf_config_part, optimisation_part_dict)
        except (AssertionError, TypeError) as cur_error:
            # The merge helper reports incompatible structures with TypeError, so catching
            # only AssertionError would never add the context below.
            err_descr = (
                f"Error during merging the parts of nncf configs:\n"
                f"the current part={part}, "
                f"the order of merging parts into base is {optimisation_parts_to_choose}.\n"
                f"The error is:\n{cur_error}"
            )
            # The original error text is embedded in the message, hence no chaining.
            raise _NNCFConfigMergeError(err_descr) from None

    return nncf_config_part


def merge_dicts_and_lists_b_into_a(
    a: dict[Any, Any] | list[Any],
    b: dict[Any, Any] | list[Any],
) -> dict[Any, Any] | list[Any]:
    """Merge dict configs.

    Args:
        a (dict[Any, Any] | list[Any]): First dict or list.
        b (dict[Any, Any] | list[Any]): Second dict or list.

    Returns:
        dict[Any, Any] | list[Any]: Merged dict or list.
    """
    return _merge_dicts_and_lists_b_into_a(a, b, "")


def _merge_dicts_and_lists_b_into_a(
    a: dict[Any, Any] | list[Any],
    b: dict[Any, Any] | list[Any],
    cur_key: int | str | None = None,
) -> dict[Any, Any] | list[Any]:
    """Merge dict configs.

        * works with usual dicts and lists and derived types
        * supports merging of lists (by concatenating the lists)
        * makes recursive merging for dict + dict case
        * overwrites when merging scalar into scalar
        Note that we merge b into a (whereas Config makes merge a into b),
        since otherwise the order of list merging is counter-intuitive.

    Args:
        a (dict[Any, Any] | list[Any]): First dict or list.
        b (dict[Any, Any] | list[Any]): Second dict or list.
        cur_key (int | str | None, optional): key for current level of recursion. Defaults to None.

    Returns:
        dict[Any, Any] | list[Any]: Merged dict or list.
    """

    def _err_str(_a: dict | list, _b: dict | list, _key: int | str | None = None) -> str:
        _key_str = "of whole structures" if _key is None else f"during merging for key=`{_key}`"
        return (
            f"Error in merging parts of config: different types {_key_str},"
            f" type(a) = {type(_a)},"
            f" type(b) = {type(_b)}"
        )

    if not (isinstance(a, dict | list)):
        msg = f"Can merge only dicts and lists, whereas type(a)={type(a)}"
        raise TypeError(msg)

    if not (isinstance(b, dict | list)):
        raise TypeError(_err_str(a, b, cur_key))

    if (isinstance(a, list) and not isinstance(b, list)) or (isinstance(b, list) and not isinstance(a, list)):
        raise TypeError(_err_str(a, b, cur_key))

    if isinstance(a, list) and isinstance(b, list):
        # the main diff w.r.t. mmcf.Config -- merging of lists
        return a + b

    a = copy(a)
    for k in b:
        if k not in a:
            a[k] = copy(b[k])
            continue
        new_cur_key = str(cur_key) + "." + k if cur_key else k
        if isinstance(a[k], dict | list):
            a[k] = _merge_dicts_and_lists_b_into_a(a[k], b[k], new_cur_key)
            continue

        if any(isinstance(b[k], t) for t in [dict, list]):
            raise TypeError(_err_str(a[k], b[k], new_cur_key))

        # suppose here that a[k] and b[k] are scalars, just overwrite
        a[k] = b[k]
    return a