File size: 1,319 Bytes
3de7bf6 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 |
"""Memory Bank Module."""
# Copyright (C) 2023-2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
from abc import abstractmethod
import torch
from torch import nn
class MemoryBankMixin(nn.Module):
    """Memory Bank Lightning Module.

    This module is used to implement memory bank lightning modules.
    It checks that the model is fitted before validation starts and at the
    end of each training epoch, calling :meth:`fit` only while the
    ``_is_fitted`` flag is still ``False``.

    The fitted state is stored in a registered buffer, so it is included in
    the module's ``state_dict`` and follows device moves of the module.
    """

    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        # A buffer rather than a plain attribute: it is saved/restored with the
        # state_dict and moved along with the module by ``.to(device)``.
        self.register_buffer("_is_fitted", torch.tensor([False]))
        self._is_fitted: torch.Tensor

    @abstractmethod
    def fit(self) -> None:
        """Fit the model to the data.

        Subclasses must override this method. ``abstractmethod`` is only
        enforced by classes whose metaclass is ``abc.ABCMeta``, which
        ``nn.Module`` does not use by default, so the base implementation
        raises at call time instead.

        Raises:
            NotImplementedError: When the subclass does not override ``fit``.
        """
        msg = (
            f"fit method not implemented for {self.__class__.__name__}. "
            "To use a memory-bank module, implement ``fit.``"
        )
        raise NotImplementedError(msg)

    def _fit_if_needed(self) -> None:
        """Call :meth:`fit` if the module is not yet fitted, then mark it fitted."""
        if not self._is_fitted:
            self.fit()
            # Update the registered buffer in place instead of rebinding it to a
            # new CPU tensor: this keeps the flag on the module's device and, when
            # run under ``torch.inference_mode()``, avoids swapping the buffer for
            # an inference tensor that ``load_state_dict`` could not copy into.
            self._is_fitted.fill_(True)

    def on_validation_start(self) -> None:
        """Ensure that the model is fitted before validation starts."""
        self._fit_if_needed()

    def on_train_epoch_end(self) -> None:
        """Ensure that the model is fitted at the end of each training epoch."""
        self._fit_if_needed()
|