File size: 3,373 Bytes
3de7bf6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
"""Anomalib Model Checkpoint Callback."""

# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

from lightning.pytorch import Trainer
from lightning.pytorch.callbacks import ModelCheckpoint as LightningCheckpoint
from lightning.pytorch.trainer.states import TrainerFn

from anomalib import LearningType


class ModelCheckpoint(LightningCheckpoint):
    """Anomalib Model Checkpoint Callback.

    This class overrides the Lightning ModelCheckpoint callback to enable saving checkpoints without running any
    training steps. This is useful for zero-/few-shot models, where the fit sequence only consists of validation.

    To enable saving checkpoints without running any training steps, we need to override two checks which are being
    called in the ``on_validation_end`` method of the parent class:

    - ``_should_save_on_train_epoch_end``: This method checks whether the checkpoint should be saved at the end of a
      training epoch, or at the end of the validation sequence. We modify this method to default to saving at the end
      of the validation sequence when the model is of zero- or few-shot type, unless ``save_on_train_epoch_end`` is
      specifically set by the user.
    - ``_should_skip_saving_checkpoint``: This method checks whether the checkpoint should be saved at all. We modify
      this method to allow saving during both the ``FITTING`` and ``VALIDATING`` states. In addition, we allow saving
      if the global step has not changed since the last checkpoint, but only for zero- and few-shot models. This is
      needed because both the last global step and the last checkpoint remain unchanged during zero-/few-shot
      training, which would otherwise prevent saving checkpoints during validation.
    """

    @staticmethod
    def _is_zero_or_few_shot(trainer: Trainer) -> bool:
        """Return whether the module attached to ``trainer`` is a zero- or few-shot model.

        The learning type is read from ``trainer.lightning_module`` instead of ``trainer.model``: the latter may be
        the strategy-wrapped module (e.g. ``DistributedDataParallel``), which does not expose the attributes of the
        underlying ``LightningModule``.

        Args:
            trainer: The trainer whose attached module is inspected.

        Returns:
            ``True`` if the module's ``learning_type`` is ``ZERO_SHOT`` or ``FEW_SHOT``, ``False`` otherwise.
        """
        return trainer.lightning_module.learning_type in (LearningType.ZERO_SHOT, LearningType.FEW_SHOT)

    def _should_skip_saving_checkpoint(self, trainer: Trainer) -> bool:
        """Check whether saving the checkpoint should be skipped.

        Overrides the parent method to allow saving during both the ``FITTING`` and ``VALIDATING`` states, and to allow
        saving when the global step has not changed since the last save (only for zero-/few-shot models).

        Args:
            trainer: The trainer that is driving the current run.

        Returns:
            ``True`` if no checkpoint should be written at this point, ``False`` otherwise.
        """
        is_zero_or_few_shot = self._is_zero_or_few_shot(trainer)
        return (
            bool(trainer.fast_dev_run)  # disable checkpointing with fast_dev_run
            or trainer.state.fn not in [TrainerFn.FITTING, TrainerFn.VALIDATING]  # don't save anything during non-fit
            or trainer.sanity_checking  # don't save anything during sanity check
            # An unchanged global step normally means "already saved"; zero-/few-shot models never advance the
            # global step, so for them this condition must not block saving.
            or (self._last_global_step_saved == trainer.global_step and not is_zero_or_few_shot)
        )

    def _should_save_on_train_epoch_end(self, trainer: Trainer) -> bool:
        """Check whether the checkpoint should be saved at the end of a training epoch or validation sequence.

        Overrides the parent method to default to saving at the end of the validation sequence when the model is of
        zero- or few-shot type, unless ``save_on_train_epoch_end`` is specifically set by the user.

        Args:
            trainer: The trainer that is driving the current run.

        Returns:
            ``True`` to save at the end of the training epoch, ``False`` to save at the end of validation.
        """
        # An explicit setting always takes precedence over the learning-type based default.
        if self._save_on_train_epoch_end is not None:
            return self._save_on_train_epoch_end

        if self._is_zero_or_few_shot(trainer):
            return False

        return super()._should_save_on_train_epoch_end(trainer)