File size: 4,508 Bytes
74acc06
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
"""Track 2 script: build the coreset memory bank for one MVTec LOCO category.

Adapted from the sample evaluation script: instead of evaluating, it feeds all
training images of a category to the model's ``setup`` hook.
"""

import os

# Redirect the torch and Hugging Face caches into ./checkpoint so downloaded
# model files stay inside the project. This runs before the torch / anomalib
# imports further down, presumably so those libraries pick up the variables.
for _env_key, _env_value in (
    ('TORCH_HOME', './checkpoint'),
    ('HF_HOME', './checkpoint/huggingface'),
    ('TRANSFORMERS_CACHE', './checkpoint/huggingface/transformers'),
    ('HF_HUB_CACHE', './checkpoint/huggingface/hub'),
):
    os.environ[_env_key] = _env_value

# Make sure the Hugging Face cache sub-folders exist up front.
for _cache_dir in ('./checkpoint/huggingface/transformers', './checkpoint/huggingface/hub'):
    os.makedirs(_cache_dir, exist_ok=True)

del _env_key, _env_value, _cache_dir

import argparse
import importlib
import importlib.util

import torch
import logging
from torch import nn

# NOTE: The following MVTecLoco import is not available in anomalib v1.0.1.
# It will be available in v1.1.0 which will be released on April 29th, 2024.
# If you are using an earlier version of anomalib, you could install anomalib
# from the anomalib source code from the following branch:
# https://github.com/openvinotoolkit/anomalib/tree/feature/mvtec-loco
from anomalib.data import MVTecLoco
from anomalib.metrics.f1_max import F1Max
from anomalib.metrics.auroc import AUROC
from tabulate import tabulate
import numpy as np

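# Few-shot setting from the original evaluation script, kept for reference;
# this script uses every training image instead: FEW_SHOT_SAMPLES = [0, 1, 2, 3]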

def parse_args(argv: "list[str] | None" = None) -> argparse.Namespace:
    """Parse command line arguments.

    Args:
        argv: Argument list to parse. ``None`` (the default) parses
            ``sys.argv[1:]``, exactly as before; passing a list makes the
            parser usable from tests or other scripts.

    Returns:
        argparse.Namespace: Parsed arguments with attributes ``module_path``,
        ``class_name``, ``weights_path``, ``dataset_path``, ``category`` and
        ``viz``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--module_path", type=str, required=True,
                        help="Module that defines the model class.")
    parser.add_argument("--class_name", default='MyModel', type=str, required=False,
                        help="Name of the model class (default: %(default)s).")
    parser.add_argument("--weights_path", type=str, required=False,
                        help="Optional path to saved model weights.")
    parser.add_argument("--dataset_path", default='/home/bhu/Project/datasets/mvtec_loco_anomaly_detection/',
                        type=str, required=False,
                        help="Root directory of the MVTec LOCO dataset.")
    parser.add_argument("--category", type=str, required=True,
                        help="Dataset category to process.")
    parser.add_argument("--viz", action='store_true', default=False,
                        help="Enable the model's visualization output.")
    return parser.parse_args(argv)


def _import_module_from_file(file_path: str):
    """Import and return the module defined by the ``.py`` file at *file_path*."""
    import sys

    module_name = os.path.splitext(os.path.basename(file_path))[0]
    spec = importlib.util.spec_from_file_location(module_name, file_path)
    if spec is None or spec.loader is None:
        raise ImportError(f"Cannot import a module from {file_path!r}")
    module = importlib.util.module_from_spec(spec)
    # Register before executing (as in the importlib docs recipe) so the module
    # can be looked up by name while its body runs, e.g. by pickle.
    sys.modules[module_name] = module
    spec.loader.exec_module(module)
    return module


def load_model(module_path: str, class_name: str, weights_path: str) -> nn.Module:
    """Instantiate the model class and optionally load its weights.

    Args:
        module_path (str): Either an importable dotted module name
            (e.g. ``"models.my_model"``) or a path to a ``.py`` file that
            defines the model class.
        class_name (str): Name of the model class in that module. It is
            instantiated with no arguments.
        weights_path (str): Optional path to a ``state_dict`` saved with
            ``torch.save``. Falsy values (``None`` / ``""``) skip loading.

    Returns:
        nn.Module: The instantiated model. Checkpoint tensors are loaded onto
        the CPU; the caller is responsible for moving the model to a device.

    Raises:
        ImportError / ModuleNotFoundError: If the module cannot be imported.
        AttributeError: If the module has no attribute ``class_name``.
    """
    # Only treat the argument as a file when it clearly is one, so dotted
    # module names keep going through the regular import system.
    if module_path.endswith(".py") and os.path.isfile(module_path):
        module = _import_module_from_file(module_path)
    else:
        module = importlib.import_module(module_path)

    model_class = getattr(module, class_name)
    model = model_class()

    if weights_path:
        # map_location="cpu": a checkpoint saved from GPU tensors would otherwise
        # fail to load on CPU-only hosts. load_state_dict copies the values into
        # the model's own parameters, so the result is otherwise unchanged.
        state_dict = torch.load(weights_path, map_location="cpu")
        model.load_state_dict(state_dict)
    return model


def run(module_path: str, class_name: str, weights_path: str, dataset_path: str, category: str, viz: bool) -> None:
    """Build the coreset memory bank for one MVTec LOCO category.

    Every training image of ``category`` (not only a few shots) is passed to
    the model's ``setup`` hook with coreset-feature saving enabled. The closing
    message reports that the memory bank was written under ``memory_bank/``.

    Args:
        module_path (str): Module containing the model class (see ``load_model``).
        class_name (str): Name of the model class.
        weights_path (str): Optional path to the model weights.
        dataset_path (str): Root directory of the MVTec LOCO dataset.
        category (str): Dataset category to process.
        viz (bool): Forwarded to ``model.set_viz``.

    Raises:
        RuntimeError: If the category's training split contains no images.
    """
    # Silence verbose logging from the root logger and the training libraries.
    for logger_name in (None, 'anomalib', 'lightning', 'pytorch_lightning'):
        logging.getLogger(logger_name).setLevel(logging.ERROR)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    model = load_model(module_path, class_name, weights_path)
    model.to(device)

    datamodule = MVTecLoco(root=dataset_path, eval_batch_size=1, image_size=(448, 448), category=category)
    datamodule.setup()

    model.set_viz(viz)
    model.set_save_coreset_features(True)

    # Traverse the whole training split to build the coreset. Each item is
    # fetched exactly once: indexing the dataset presumably reads and
    # transforms the image, so collecting "image" and "image_path" in two
    # separate passes (as before) decoded every image twice.
    train_data = datamodule.train_data
    images = []
    image_paths = []
    for idx in range(len(train_data)):
        sample = train_data[idx]
        images.append(sample["image"])
        image_paths.append(sample["image_path"])

    if not images:
        # torch.stack would otherwise fail with an opaque "non-empty TensorList" error.
        raise RuntimeError(f"No training images found for category '{category}' in {dataset_path}")

    # Pass the training images, their paths and the category to the model.
    setup_data = {
        "few_shot_samples": torch.stack(images).to(device),
        "few_shot_samples_path": image_paths,
        "dataset_category": category,
    }
    model.setup(setup_data)

    print(f"✓ Coreset computation completed for {category}")
    print("  Memory bank features saved to memory_bank/ directory")
    


if __name__ == "__main__":
    # Script entry point: parse the CLI and build the memory bank.
    cli_args = parse_args()
    run(
        module_path=cli_args.module_path,
        class_name=cli_args.class_name,
        weights_path=cli_args.weights_path,
        dataset_path=cli_args.dataset_path,
        category=cli_args.category,
        viz=cli_args.viz,
    )