File size: 1,247 Bytes
3de7bf6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
"""Callback that loads model weights from the state dict."""

# Copyright (C) 2022 Intel Corporation
# SPDX-License-Identifier: Apache-2.0


import logging

import torch
from lightning.pytorch import Callback, Trainer

from anomalib.models.components import AnomalyModule

logger = logging.getLogger(__name__)


class LoadModelCallback(Callback):
    """Callback that loads the model weights from a checkpoint file.

    The file is read with ``torch.load`` inside the Lightning ``setup`` hook and the
    weights are copied into the module with ``load_state_dict``. Both Lightning-style
    checkpoints (weights nested under a ``"state_dict"`` key) and bare state dicts
    are accepted.

    Args:
        weights_path (str): Path to the checkpoint file to load.

    Examples:
        >>> from anomalib.callbacks import LoadModelCallback
        >>> from anomalib.engine import Engine
        ...
        >>> callbacks = [LoadModelCallback(weights_path="path/to/weights.pt")]
        >>> engine = Engine(callbacks=callbacks)
    """

    def __init__(self, weights_path: str) -> None:
        super().__init__()
        self.weights_path = weights_path

    def setup(self, trainer: Trainer, pl_module: AnomalyModule, stage: str | None = None) -> None:
        """Load the weights from ``weights_path`` into ``pl_module``.

        Lightning invokes this hook when a stage (fit, validate, test, predict) begins.
        There is no guard here, so the weights are reloaded from disk every time the
        hook runs.

        Args:
            trainer (Trainer): Unused.
            pl_module (AnomalyModule): Module that receives the loaded weights. Its
                ``device`` is passed as ``map_location`` so the loaded tensors are
                placed on the module's device.
            stage (str | None): Unused.
        """
        del trainer, stage  # These variables are not used.

        logger.info("Loading the model from %s", self.weights_path)
        # NOTE(security): ``torch.load`` unpickles the file, which can execute arbitrary
        # code. Only point ``weights_path`` at checkpoints from a trusted source.
        checkpoint = torch.load(self.weights_path, map_location=pl_module.device)
        # Lightning checkpoints nest the weights under "state_dict"; anything else is
        # treated as an already-bare state dict (e.g. ``torch.save(model.state_dict())``).
        if isinstance(checkpoint, dict) and "state_dict" in checkpoint:
            checkpoint = checkpoint["state_dict"]
        pl_module.load_state_dict(checkpoint)