"""Sample evaluation script for track 2."""

import os
from datetime import datetime
from pathlib import Path

# Set cache directories to use the checkpoint folder for model downloads.
# These assignments deliberately precede the torch / anomalib imports below so
# the libraries already see these locations when they are imported.
os.environ['TORCH_HOME'] = './checkpoint'
os.environ['HF_HOME'] = './checkpoint/huggingface'
os.environ['TRANSFORMERS_CACHE'] = './checkpoint/huggingface/transformers'
os.environ['HF_HUB_CACHE'] = './checkpoint/huggingface/hub'

# Create checkpoint subdirectories if they don't exist
os.makedirs('./checkpoint/huggingface/transformers', exist_ok=True)
os.makedirs('./checkpoint/huggingface/hub', exist_ok=True)

import argparse
import importlib
import importlib.util
import logging

import torch
from torch import nn

# NOTE: The following MVTecLoco import is not available in anomalib v1.0.1.
# It will be available in v1.1.0 which will be released on April 29th, 2024.
# If you are using an earlier version of anomalib, you could install anomalib
# from the anomalib source code from the following branch:
# https://github.com/openvinotoolkit/anomalib/tree/feature/mvtec-loco
from anomalib.data import MVTecLoco
from anomalib.metrics.f1_max import F1Max
from anomalib.metrics.auroc import AUROC
from tabulate import tabulate
import numpy as np

FEW_SHOT_SAMPLES = [0, 1, 2, 3]

# Keys of the per-category result columns, in the order they appear in the
# markdown summary table (after the leading "Category" column).
_RESULT_COLUMNS = (
    'k_shots',
    'f1_image',
    'auroc_image',
    'f1_logical',
    'auroc_logical',
    'f1_structural',
    'auroc_structural',
)

# Metric groups reported per category, in table order.
_METRIC_GROUPS = ('image', 'logical', 'structural')


def _is_separator_cell(cell: str) -> bool:
    """Return True if *cell* is a markdown table alignment cell such as ``---`` or ``:--:``."""
    return bool(cell) and set(cell) <= set('-:')


def _read_existing_results(combined_file: Path) -> dict:
    """Parse the summary table of a previously written results file.

    Args:
        combined_file (Path): Markdown file written by ``write_results_to_markdown``.

    Returns:
        dict: Mapping ``category -> {column key: cell text}``. Empty when the file
        does not exist. Header and separator rows are skipped.
    """
    existing_results = {}
    if not combined_file.exists():
        return existing_results

    with open(combined_file, 'r', encoding='utf-8') as f:
        for line in f:
            # A full table row has one pipe per column boundary; anything with
            # fewer pipes is prose or a malformed row.
            if line.count('|') < len(_RESULT_COLUMNS) + 1:
                continue
            parts = [p.strip() for p in line.split('|')]
            row_category = parts[1]
            # Skip the header row and the |-----|-----| alignment row. The old
            # exact-match check against '-----' let the alignment row through,
            # turning it into a bogus "----------" category.
            if row_category == 'Category' or _is_separator_cell(row_category):
                continue
            existing_results[row_category] = dict(
                zip(_RESULT_COLUMNS, parts[2:2 + len(_RESULT_COLUMNS)])
            )
    return existing_results


def write_results_to_markdown(category: str, results_data: dict, module_path: str) -> None:
    """Write evaluation results to markdown file.

    Results for all categories evaluated under the same protocol are kept in a
    single table in ``results/<protocol>_results.md``; the row for ``category``
    is added or replaced, other rows are preserved.

    Args:
        category (str): Dataset category name
        results_data (dict): Dictionary containing all metrics (keys ``f1_image``,
            ``auroc_image``, ``f1_logical``, ``auroc_logical``, ``f1_structural``,
            ``auroc_structural``)
        module_path (str): Model module path (for protocol identification)
    """
    # Determine protocol type from module path
    is_few_shot = "few_shot" in module_path
    protocol = "Few-shot" if is_few_shot else "Full-data"
    protocol_suffix = "few_shot" if is_few_shot else "full_data"

    # Create results directory
    results_dir = Path("results")
    results_dir.mkdir(exist_ok=True)
    combined_file = results_dir / f"{protocol_suffix}_results.md"

    existing_results = _read_existing_results(combined_file)

    # Add (or overwrite) the current category's row.
    existing_results[category] = {
        'k_shots': str(len(FEW_SHOT_SAMPLES)),
        **{key: f"{results_data[key]:.2f}" for key in _RESULT_COLUMNS[1:]},
    }

    # Write combined results
    with open(combined_file, 'w', encoding='utf-8') as f:
        f.write(f"# All Categories - {protocol} Protocol Results\n\n")
        f.write(f"**Last Updated:** {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n\n")
        f.write(f"**Protocol:** {protocol}\n")
        f.write(f"**Available Categories:** {', '.join(sorted(existing_results.keys()))}\n\n")
        f.write("## Summary Table\n\n")
        f.write("| Category | K-shots | F1-Max (Image) | AUROC (Image) | F1-Max (Logical) | AUROC (Logical) | F1-Max (Structural) | AUROC (Structural) |\n")
        f.write("|----------|---------|----------------|---------------|------------------|-----------------|---------------------|-------------------|\n")
        # Sort categories alphabetically
        for cat in sorted(existing_results):
            row = existing_results[cat]
            cells = ' | '.join(row[key] for key in _RESULT_COLUMNS)
            f.write(f"| {cat} | {cells} |\n")

    print("\n✓ Results saved to:")
    print(f"  - Combined: {combined_file}")


def parse_args() -> argparse.Namespace:
    """Parse command line arguments.

    Returns:
        argparse.Namespace: Parsed arguments.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--module_path", type=str, required=True,
                        help="Importable module path containing the model class.")
    parser.add_argument("--class_name", default='MyModel', type=str, required=False,
                        help="Name of the model class inside --module_path.")
    parser.add_argument("--weights_path", type=str, required=False,
                        help="Optional state_dict file to load into the model.")
    parser.add_argument("--dataset_path", default='/home/bhu/Project/datasets/mvtec_loco_anomaly_detection/',
                        type=str, required=False, help="Root directory of the MVTec LOCO dataset.")
    parser.add_argument("--category", type=str, required=True,
                        help="Dataset category to evaluate.")
    parser.add_argument("--viz", action='store_true', default=False,
                        help="Passed to model.set_viz().")
    return parser.parse_args()


def load_model(module_path: str, class_name: str, weights_path: str) -> nn.Module:
    """Load model.

    Args:
        module_path (str): Path to the module containing the model class.
        class_name (str): Name of the model class.
        weights_path (str): Path to the model weights (skipped when empty/None).

    Returns:
        nn.Module: Loaded model (on CPU; the caller moves it to the target device).
    """
    # get model class
    model_class = getattr(importlib.import_module(module_path), class_name)
    # instantiate model
    model = model_class()
    # load weights
    if weights_path:
        # map_location="cpu" lets checkpoints saved on a GPU load on CPU-only
        # machines; run() moves the model to the evaluation device afterwards.
        state_dict = torch.load(weights_path, map_location="cpu")
        model.load_state_dict(state_dict)
    return model


def _configure_result_logger() -> logging.Logger:
    """Silence library logging and return a logger that prints results to the console."""
    # Disable verbose logging from all libraries
    logging.getLogger().setLevel(logging.ERROR)
    for noisy_logger in ('anomalib', 'lightning', 'pytorch_lightning'):
        logging.getLogger(noisy_logger).setLevel(logging.ERROR)

    # Set up our own logger for results only
    logger = logging.getLogger('evaluation')
    logger.handlers.clear()
    logger.setLevel(logging.INFO)
    # Propagated records are handed to ancestor *handlers* regardless of the
    # ancestor logger's level, so a handler on the root logger would print
    # every result a second time; keep output to our own handler.
    logger.propagate = False
    formatter = logging.Formatter('%(asctime)s.%(msecs)03d - %(levelname)s: %(message)s',
                                  datefmt='%y-%m-%d %H:%M:%S')
    console_handler = logging.StreamHandler()
    console_handler.setFormatter(formatter)
    logger.addHandler(console_handler)
    return logger


def run(module_path: str, class_name: str, weights_path: str, dataset_path: str, category: str, viz: bool) -> None:
    """Run the evaluation script.

    Args:
        module_path (str): Path to the module containing the model class.
        class_name (str): Name of the model class.
        weights_path (str): Path to the model weights.
        dataset_path (str): Path to the dataset.
        category (str): Category of the dataset.
        viz (bool): Value forwarded to ``model.set_viz``.
    """
    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

    # Instantiate the model and load it from checkpoint.
    model = load_model(module_path, class_name, weights_path)
    model.to(device)

    # Create the dataset
    datamodule = MVTecLoco(root=dataset_path, eval_batch_size=1, image_size=(448, 448), category=category)
    datamodule.setup()

    model.set_viz(viz)

    # One (F1Max, AUROC) pair per reported metric group.
    metrics = {group: (F1Max(), AUROC()) for group in _METRIC_GROUPS}

    # Pass few-shot images and dataset category to the model. Each sample is
    # fetched once so the image and its path come from the same dataset item.
    few_shot_items = [datamodule.train_data[idx] for idx in FEW_SHOT_SAMPLES]
    setup_data = {
        "few_shot_samples": torch.stack([item["image"] for item in few_shot_items]).to(device),
        "few_shot_samples_path": [item["image_path"] for item in few_shot_items],
        "dataset_category": category,
    }
    model.setup(setup_data)

    # Loop over the test set and compute the metrics
    with torch.no_grad():
        for data in datamodule.test_dataloader():
            output = model(data["image"].to(device), data['image_path'])
            pred_score = output["pred_score"].cpu()
            label = data["label"]

            # eval_batch_size=1, so the batch holds a single image path. Only the
            # image's parent directory is inspected (expected to name the anomaly
            # type, e.g. logical_anomalies / structural_anomalies), so words like
            # "logical" elsewhere in the dataset path cannot misroute samples.
            anomaly_dir = Path(data['image_path'][0]).parent.name
            groups = ['image']
            # Structural metrics: every sample that is not a logical anomaly.
            if 'logical' not in anomaly_dir:
                groups.append('structural')
            # Logical metrics: every sample that is not a structural anomaly.
            if 'structural' not in anomaly_dir:
                groups.append('logical')

            for group in groups:
                for metric in metrics[group]:
                    metric.update(pred_score, label)

    logger = _configure_result_logger()

    # Compute every metric once (percent, 2 decimals), keyed as expected by
    # write_results_to_markdown: f1_image, auroc_image, f1_logical, ...
    results_data = {}
    for group, (f1_metric, auroc_metric) in metrics.items():
        results_data[f'f1_{group}'] = np.round(f1_metric.compute().item() * 100, decimals=2)
        results_data[f'auroc_{group}'] = np.round(auroc_metric.compute().item() * 100, decimals=2)

    table_ls = [[category, str(len(FEW_SHOT_SAMPLES)), *(str(value) for value in results_data.values())]]
    results = tabulate(table_ls,
                       headers=['category', 'K-shots', 'F1-Max(image)', 'AUROC(image)',
                                'F1-Max (logical)', 'AUROC (logical)',
                                'F1-Max (structural)', 'AUROC (structural)'],
                       tablefmt="pipe")
    logger.info("\n%s", results)

    # Save results to markdown
    write_results_to_markdown(category, results_data, module_path)


if __name__ == "__main__":
    args = parse_args()
    run(args.module_path, args.class_name, args.weights_path, args.dataset_path, args.category, args.viz)