|
"""Build and save the coreset memory bank for one MVTec LOCO category (track 2)."""
|
|
|
import os |
|
|
|
|
|
# Point every model/weights cache at ./checkpoint. This sits above the torch /
# anomalib imports further down, presumably so those libraries already see
# these variables when they are first imported — keep it before them.
# NOTE(review): newer `transformers` releases may warn that TRANSFORMERS_CACHE
# is deprecated in favour of HF_HOME / HF_HUB_CACHE; it is kept for older ones.
os.environ.update({
    'TORCH_HOME': './checkpoint',
    'HF_HOME': './checkpoint/huggingface',
    'TRANSFORMERS_CACHE': './checkpoint/huggingface/transformers',
    'HF_HUB_CACHE': './checkpoint/huggingface/hub',
})

# Pre-create the Hugging Face cache directories (parents are created too).
for _cache_dir in ('./checkpoint/huggingface/transformers', './checkpoint/huggingface/hub'):
    os.makedirs(_cache_dir, exist_ok=True)
del _cache_dir
|
|
|
import argparse |
|
import importlib |
|
import importlib.util |
|
|
|
import torch |
|
import logging |
|
from torch import nn |
|
|
|
|
|
|
|
|
|
|
|
|
|
from anomalib.data import MVTecLoco |
|
from anomalib.metrics.f1_max import F1Max |
|
from anomalib.metrics.auroc import AUROC |
|
from tabulate import tabulate |
|
import numpy as np |
|
|
|
|
|
|
|
def parse_args() -> argparse.Namespace:
    """Parse the command line of this script.

    Options:
        --module_path (required): module that defines the model class.
        --class_name: model class name inside that module (default ``MyModel``).
        --weights_path: optional path to a state dict; ``None`` when omitted.
        --dataset_path: MVTec LOCO dataset root. The default is a
            machine-specific absolute path, so pass it explicitly elsewhere.
        --category (required): dataset category to process.
        --viz: boolean flag, ``False`` unless given.

    Returns:
        argparse.Namespace: The parsed options.
    """
    # (flag, add_argument keyword options) in the order they should appear in --help.
    option_table = [
        ("--module_path", dict(type=str, required=True)),
        ("--class_name", dict(type=str, required=False, default='MyModel')),
        ("--weights_path", dict(type=str, required=False)),
        ("--dataset_path", dict(type=str, required=False,
                                default='/home/bhu/Project/datasets/mvtec_loco_anomaly_detection/')),
        ("--category", dict(type=str, required=True)),
        ("--viz", dict(action='store_true', default=False)),
    ]

    cli = argparse.ArgumentParser()
    for flag, options in option_table:
        cli.add_argument(flag, **options)
    return cli.parse_args()
|
|
|
|
|
def load_model(module_path: str, class_name: str, weights_path: str) -> nn.Module:
    """Instantiate a model class and optionally load a state dict into it.

    Args:
        module_path (str): Either a dotted module name importable from
            ``sys.path`` (e.g. ``models.my_model``) or a filesystem path to an
            existing ``.py`` file that defines the model class.
        class_name (str): Name of the model class; it is instantiated with no
            arguments.
        weights_path (str): Optional path to a state dict saved with
            ``torch.save``. Skipped when empty or ``None``.

    Returns:
        nn.Module: The instantiated model. Weights are loaded onto the CPU;
        the caller is responsible for moving the model to its device.

    Raises:
        ImportError: If the module cannot be imported or loaded from file.
        AttributeError: If the module has no attribute ``class_name``.
    """
    if module_path.endswith(".py") and os.path.isfile(module_path):
        # Load a source file directly; anything else keeps the original
        # dotted-name import behaviour.
        module_name = os.path.splitext(os.path.basename(module_path))[0]
        spec = importlib.util.spec_from_file_location(module_name, module_path)
        if spec is None or spec.loader is None:
            raise ImportError(f"Cannot load module from file: {module_path}")
        module = importlib.util.module_from_spec(spec)
        spec.loader.exec_module(module)
    else:
        module = importlib.import_module(module_path)

    model_class = getattr(module, class_name)
    model = model_class()

    if weights_path:
        # map_location="cpu" lets checkpoints saved on a GPU load on CPU-only
        # hosts; the caller moves the model to the target device afterwards.
        # NOTE: torch.load unpickles its input — only load trusted weight files.
        state_dict = torch.load(weights_path, map_location="cpu")
        model.load_state_dict(state_dict)

    return model
|
|
|
|
|
def _collect_few_shot_samples(train_data, device: torch.device):
    """Index every training sample once; return (stacked images on ``device``, image paths).

    Raises:
        RuntimeError: If ``train_data`` is empty (there is nothing to stack).
    """
    images = []
    paths = []
    # One lookup per sample: each lookup presumably reads and transforms the
    # image, so fetching "image" and "image_path" separately did that twice.
    for idx in range(len(train_data)):
        sample = train_data[idx]
        images.append(sample["image"])
        paths.append(sample["image_path"])
    if not images:
        raise RuntimeError("The training split is empty; there are no few-shot samples to build the memory bank from.")
    return torch.stack(images).to(device), paths


def run(module_path: str, class_name: str, weights_path: str, dataset_path: str, category: str, viz: bool) -> None:
    """Build the coreset memory bank for one MVTec LOCO category.

    Loads the model, enables coreset-feature saving, and passes every training
    image of ``category`` (as the few-shot set) to ``model.setup``.

    Args:
        module_path (str): Module containing the model class (see ``load_model``).
        class_name (str): Name of the model class.
        weights_path (str): Optional path to model weights; skipped when empty/None.
        dataset_path (str): Root directory of the MVTec LOCO dataset.
        category (str): Category of the dataset.
        viz (bool): Forwarded to ``model.set_viz``.

    Raises:
        RuntimeError: If the category's training split is empty.
    """
    # Keep the console quiet: only errors from the root logger and the
    # anomalib / lightning loggers are shown.
    for logger_name in (None, 'anomalib', 'lightning', 'pytorch_lightning'):
        logging.getLogger(logger_name).setLevel(logging.ERROR)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    model = load_model(module_path, class_name, weights_path)
    model.to(device)

    datamodule = MVTecLoco(root=dataset_path, eval_batch_size=1, image_size=(448, 448), category=category)
    datamodule.setup()

    model.set_viz(viz)
    model.set_save_coreset_features(True)

    # The whole training split is used as the few-shot set.
    few_shot_images, few_shot_paths = _collect_few_shot_samples(datamodule.train_data, device)
    setup_data = {
        "few_shot_samples": few_shot_images,
        "few_shot_samples_path": few_shot_paths,
        "dataset_category": category,
    }
    model.setup(setup_data)

    print(f"✓ Coreset computation completed for {category}")
    print(" Memory bank features saved to memory_bank/ directory")
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Script entry point: read CLI options and forward them to run() by name.
    cli_args = parse_args()
    run(
        module_path=cli_args.module_path,
        class_name=cli_args.class_name,
        weights_path=cli_args.weights_path,
        dataset_path=cli_args.dataset_path,
        category=cli_args.category,
        viz=cli_args.viz,
    )
|
|