|
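"""Sample evaluation script for track 2.

Evaluates a user-supplied model on a single MVTec LOCO category and reports
image-level F1-Max and AUROC on all test images as well as on the logical and
structural anomaly subsets.
"""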
|
|
|
import os |
|
from datetime import datetime |
|
from pathlib import Path |
|
|
|
|
|
# Keep every downloaded model/weight file under ./checkpoint. These variables
# are set before torch / anomalib are imported further down, presumably because
# some of those libraries resolve their cache locations at import time.
_CACHE_ENV_VARS = {
    'TORCH_HOME': './checkpoint',
    'HF_HOME': './checkpoint/huggingface',
    # NOTE(review): recent transformers releases deprecate TRANSFORMERS_CACHE
    # in favour of HF_HOME; kept here for older versions — confirm still needed.
    'TRANSFORMERS_CACHE': './checkpoint/huggingface/transformers',
    'HF_HUB_CACHE': './checkpoint/huggingface/hub',
}
os.environ.update(_CACHE_ENV_VARS)

# Pre-create the Hugging Face cache directories (TORCH_HOME's parent is created
# implicitly because these live underneath it).
for _cache_dir in ('./checkpoint/huggingface/transformers', './checkpoint/huggingface/hub'):
    os.makedirs(_cache_dir, exist_ok=True)
|
|
|
import argparse |
|
import importlib |
|
import importlib.util |
|
|
|
import torch |
|
import logging |
|
from torch import nn |
|
|
|
|
|
|
|
|
|
|
|
|
|
from anomalib.data import MVTecLoco |
|
from anomalib.metrics.f1_max import F1Max |
|
from anomalib.metrics.auroc import AUROC |
|
from tabulate import tabulate |
|
import numpy as np |
|
|
|
# Indices into ``datamodule.train_data`` of the reference images handed to
# ``model.setup``; the list length is what gets reported as "K-shots".
FEW_SHOT_SAMPLES = list(range(4))
|
|
|
def write_results_to_markdown(category, results_data, module_path, results_dir="results", k_shots=None):
    """Write evaluation results to markdown file.

    All categories of one protocol share a single table stored in
    ``<results_dir>/{few_shot|full_data}_results.md``. Rows already present in
    that file are kept, the row for ``category`` is added or overwritten, and
    the whole file is rewritten with rows sorted by category name.

    Args:
        category (str): Dataset category name (the row key in the table).
        results_data (dict): Metric values keyed by ``f1_image``, ``auroc_image``,
            ``f1_logical``, ``auroc_logical``, ``f1_structural`` and
            ``auroc_structural``; each is written with two decimals.
        module_path (str): Model module path (for protocol identification):
            "Few-shot" when it contains ``"few_shot"``, otherwise "Full-data".
        results_dir (str | Path): Directory holding the markdown file; created
            if missing. Defaults to ``"results"``.
        k_shots (int | None): Value for the K-shots column. Defaults to
            ``len(FEW_SHOT_SAMPLES)``.
    """
    is_few_shot = "few_shot" in module_path
    protocol = "Few-shot" if is_few_shot else "Full-data"
    protocol_suffix = "few_shot" if is_few_shot else "full_data"
    if k_shots is None:
        k_shots = len(FEW_SHOT_SAMPLES)

    results_dir = Path(results_dir)
    results_dir.mkdir(parents=True, exist_ok=True)
    combined_file = results_dir / f"{protocol_suffix}_results.md"

    metric_keys = ('f1_image', 'auroc_image', 'f1_logical', 'auroc_logical',
                   'f1_structural', 'auroc_structural')
    row_keys = ('k_shots',) + metric_keys

    # Re-read previously written rows so results for other categories survive.
    existing_results = {}
    if combined_file.exists():
        for line in combined_file.read_text(encoding='utf-8').splitlines():
            parts = [p.strip() for p in line.split('|')]
            # A row "| c1 | ... | c8 |" needs at least 8 pipes (>= 9 parts).
            if len(parts) < 9:
                continue
            row_category = parts[1]
            # Skip the header row and the "|-----|---|" separator row, whose
            # cells are runs of dashes of arbitrary length.
            if row_category == 'Category' or not row_category.strip('-'):
                continue
            existing_results[row_category] = dict(zip(row_keys, parts[2:9]))

    existing_results[category] = {
        'k_shots': str(k_shots),
        **{key: f"{results_data[key]:.2f}" for key in metric_keys},
    }

    categories = sorted(existing_results)
    with open(combined_file, 'w', encoding='utf-8') as f:
        f.write(f"# All Categories - {protocol} Protocol Results\n\n")
        f.write(f"**Last Updated:** {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n\n")
        f.write(f"**Protocol:** {protocol}\n")
        f.write(f"**Available Categories:** {', '.join(categories)}\n\n")

        f.write("## Summary Table\n\n")
        f.write("| Category | K-shots | F1-Max (Image) | AUROC (Image) | F1-Max (Logical) | AUROC (Logical) | F1-Max (Structural) | AUROC (Structural) |\n")
        f.write("|----------|---------|----------------|---------------|------------------|-----------------|---------------------|-------------------|\n")

        for cat in categories:
            data = existing_results[cat]
            f.write(f"| {cat} | {data['k_shots']} | {data['f1_image']} | {data['auroc_image']} | {data['f1_logical']} | {data['auroc_logical']} | {data['f1_structural']} | {data['auroc_structural']} |\n")

    print("\n✓ Results saved to:")
    print(f" - Combined: {combined_file}")
|
|
|
def parse_args() -> argparse.Namespace:
    """Build the command-line interface and parse ``sys.argv``.

    Flags: ``--module_path`` and ``--category`` are mandatory; ``--class_name``
    defaults to ``MyModel``; ``--weights_path`` is optional (``None`` when
    omitted); ``--dataset_path`` has a built-in default; ``--viz`` is a switch.

    Returns:
        argparse.Namespace: Parsed arguments.
    """
    # (flag, keyword arguments for add_argument) in registration order.
    argument_table = (
        ("--module_path", {"type": str, "required": True}),
        ("--class_name", {"default": 'MyModel', "type": str, "required": False}),
        ("--weights_path", {"type": str, "required": False}),
        ("--dataset_path", {"default": '/home/bhu/Project/datasets/mvtec_loco_anomaly_detection/', "type": str, "required": False}),
        ("--category", {"type": str, "required": True}),
        ("--viz", {"action": 'store_true', "default": False}),
    )

    cli = argparse.ArgumentParser()
    for flag, options in argument_table:
        cli.add_argument(flag, **options)

    namespace = cli.parse_args()
    return namespace
|
|
|
|
|
def load_model(module_path: str, class_name: str, weights_path: str) -> nn.Module:
    """Load model.

    Args:
        module_path (str): Dotted import path of the module containing the
            model class (passed to ``importlib.import_module``).
        class_name (str): Name of the model class; it is instantiated with no
            arguments.
        weights_path (str): Path to a ``state_dict`` checkpoint. When falsy
            (``None`` or ``""``), the freshly constructed model is returned.

    Returns:
        nn.Module: Loaded model. The caller is responsible for moving it to
        the target device.
    """
    model_class = getattr(importlib.import_module(module_path), class_name)
    model = model_class()

    if weights_path:
        # map_location="cpu": a checkpoint saved from a GPU would otherwise try
        # to restore its tensors onto CUDA and fail on CPU-only hosts; the
        # values are copied into the model's own parameters either way.
        # NOTE: torch.load unpickles the file — only pass trusted checkpoints.
        state_dict = torch.load(weights_path, map_location="cpu")
        model.load_state_dict(state_dict)
    return model
|
|
|
|
|
def _configure_evaluation_logger() -> logging.Logger:
    """Quiet library loggers and return the console logger used for the results table."""
    # Root logger plus the anomalib / Lightning loggers: errors only.
    for noisy_name in (None, 'anomalib', 'lightning', 'pytorch_lightning'):
        logging.getLogger(noisy_name).setLevel(logging.ERROR)

    logger = logging.getLogger('evaluation')
    # Clear first so repeated run() calls in one process do not stack handlers.
    logger.handlers.clear()
    logger.setLevel(logging.INFO)
    # Our own handler prints the records; without this they would also reach
    # any handler installed on the root logger and be printed twice.
    logger.propagate = False
    formatter = logging.Formatter('%(asctime)s.%(msecs)03d - %(levelname)s: %(message)s', datefmt='%y-%m-%d %H:%M:%S')
    console_handler = logging.StreamHandler()
    console_handler.setFormatter(formatter)
    logger.addHandler(console_handler)
    return logger


def run(module_path: str, class_name: str, weights_path: str, dataset_path: str, category: str, viz: bool) -> None:
    """Run the evaluation script.

    Loads the model, hands it the few-shot training samples via
    ``model.setup``, scores every test image of ``category`` and reports
    image-level F1-Max and AUROC on all test images and on the logical /
    structural subsets. Results are logged as a table and merged into the
    markdown summary written by ``write_results_to_markdown``.

    Args:
        module_path (str): Path to the module containing the model class.
        class_name (str): Name of the model class.
        weights_path (str): Path to the model weights (optional).
        dataset_path (str): Path to the dataset.
        category (str): Category of the dataset.
        viz (bool): Forwarded to ``model.set_viz``.
    """
    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

    model = load_model(module_path, class_name, weights_path)
    model.to(device)
    # Evaluation only: switch dropout / batch-norm layers to inference behaviour.
    model.eval()

    datamodule = MVTecLoco(root=dataset_path, eval_batch_size=1, image_size=(448, 448), category=category)
    datamodule.setup()

    model.set_viz(viz)

    # Image-level metrics over all test images, plus two subsets: the logical
    # subset drops structural anomalies and the structural subset drops logical
    # anomalies (images in neither anomaly directory, e.g. good ones, are in both).
    image_metric = F1Max()
    image_metric_auroc = AUROC()
    image_metric_logical = F1Max()
    image_metric_auroc_logical = AUROC()
    image_metric_structure = F1Max()
    image_metric_auroc_structure = AUROC()

    # The few-shot reference images (by index into the training split) are
    # given to the model once, before any test image is scored.
    setup_data = {
        "few_shot_samples": torch.stack([datamodule.train_data[idx]["image"] for idx in FEW_SHOT_SAMPLES]).to(device),
        "few_shot_samples_path": [datamodule.train_data[idx]["image_path"] for idx in FEW_SHOT_SAMPLES],
        "dataset_category": category,
    }
    model.setup(setup_data)

    with torch.no_grad():
        for data in datamodule.test_dataloader():
            output = model(data["image"].to(device), data["image_path"])
            pred_score = output["pred_score"].cpu()
            label = data["label"]

            # eval_batch_size=1, so the batch holds one path. Only the image's
            # parent directory is inspected so that "logical"/"structural"
            # appearing elsewhere in dataset_path cannot misroute samples.
            defect_dir = Path(data["image_path"][0]).parent.name

            image_metric.update(pred_score, label)
            image_metric_auroc.update(pred_score, label)
            if "logical" not in defect_dir:
                image_metric_structure.update(pred_score, label)
                image_metric_auroc_structure.update(pred_score, label)
            if "structural" not in defect_dir:
                image_metric_logical.update(pred_score, label)
                image_metric_auroc_logical.update(pred_score, label)

    logger = _configure_evaluation_logger()

    # compute() each metric exactly once; the same rounded percentages feed
    # both the console table and the markdown summary.
    named_metrics = (
        ('f1_image', image_metric),
        ('auroc_image', image_metric_auroc),
        ('f1_logical', image_metric_logical),
        ('auroc_logical', image_metric_auroc_logical),
        ('f1_structural', image_metric_structure),
        ('auroc_structural', image_metric_auroc_structure),
    )
    results_data = {
        key: np.round(metric.compute().item() * 100, decimals=2)
        for key, metric in named_metrics
    }

    table_ls = [[category, str(len(FEW_SHOT_SAMPLES)), *(str(value) for value in results_data.values())]]
    results = tabulate(table_ls, headers=['category', 'K-shots', 'F1-Max(image)', 'AUROC(image)', 'F1-Max (logical)', 'AUROC (logical)', 'F1-Max (structural)', 'AUROC (structural)'], tablefmt="pipe")
    logger.info("\n%s", results)

    write_results_to_markdown(category, results_data, module_path)
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Script entry point: parse CLI flags and forward them to run() by name.
    cli_args = parse_args()
    run(
        module_path=cli_args.module_path,
        class_name=cli_args.class_name,
        weights_path=cli_args.weights_path,
        dataset_path=cli_args.dataset_path,
        category=cli_args.category,
        viz=cli_args.viz,
    )
|
|