|
|
import os |
|
|
import re |
|
|
from collections import defaultdict |
|
|
|
|
|
# Root directory scanned for "<name>-bench" benchmark sub-directories.
LOG_ROOT = "logs"

# File that receives the aggregated summary tables (also echoed to stdout).
SUMMARY_FILE = "benchmark_summary.txt"
|
|
|
|
|
|
|
|
# A non-negative decimal number, optionally in scientific notation.  This is
# deliberately stricter than a bare "digits and dots" class so that malformed
# tokens such as "1.2.3" are rejected by the regex instead of raising
# ValueError later inside float().
_NUMBER = r"((?:\d+(?:\.\d*)?|\.\d+)(?:[eE][+-]?\d+)?)"


def _metric_pattern(label):
    """Compile the regex for one ``Test <label>: value (95 % CI: lo-hi)`` line."""
    return re.compile(
        r"Test " + label + r"\s*:\s*" + _NUMBER
        + r"\s+\(95\s*%\s*CI:\s*" + _NUMBER
        # The interval separator may be an en dash (U+2013) or a plain hyphen.
        + r"\s*[\u2013-]\s*" + _NUMBER + r"\s*\)"
    )


# Compiled once at import time; both metrics share the same line layout.
_RMSE_PATTERN = _metric_pattern("RMSE")
_MAE_PATTERN = _metric_pattern("MAE")


def parse_ci_section(content):
    """Extract Test RMSE, Test MAE and their 95% CI half-widths from log text.

    Args:
        content: Full text of a log file.

    Returns:
        A dict with keys ``"rmse"``, ``"rmse_pm"``, ``"mae"`` and ``"mae_pm"``
        (all floats), where each ``*_pm`` value is half the width of the
        reported confidence interval, i.e. ``(hi - lo) / 2``.  Returns ``None``
        when either metric line is missing or malformed.
    """
    rmse_match = _RMSE_PATTERN.search(content)
    mae_match = _MAE_PATTERN.search(content)

    # Both metrics are required; a log with only one of them is skipped.
    if rmse_match is None or mae_match is None:
        return None

    rmse, rmse_lo, rmse_hi = map(float, rmse_match.groups())
    mae, mae_lo, mae_hi = map(float, mae_match.groups())

    # Report the interval as a symmetric "plus/minus" half-width.
    rmse_pm = (rmse_hi - rmse_lo) / 2
    mae_pm = (mae_hi - mae_lo) / 2

    return {"rmse": rmse, "rmse_pm": rmse_pm, "mae": mae, "mae_pm": mae_pm}
|
|
|
|
|
|
|
|
def extract_encoding(fname):
    """Return the molecular encoding named in a log file name.

    The file name is searched case-insensitively for the substrings
    ``"ecfp"``, ``"smiles"`` and ``"selfies"``, in that order; the first one
    found is returned.

    Args:
        fname: Log file name (not a full path is required, any string works).

    Returns:
        ``"ecfp"``, ``"smiles"``, ``"selfies"``, or ``"unknown"`` when none of
        them occurs in ``fname``.
    """
    # Lower-case once instead of on every loop iteration.
    lowered = fname.lower()
    for enc in ("ecfp", "smiles", "selfies"):
        if enc in lowered:
            return enc
    return "unknown"
|
|
|
|
|
|
|
|
|
|
|
# Maps benchmark name -> list of result rows, one row per successfully parsed
# log file.  Each row holds "model", "encoding" and the parse_ci_section() keys.
benchmark_tables = defaultdict(list)

_BENCH_SUFFIX = "-bench"

# Layout walked below: LOG_ROOT/<benchmark>-bench/<model>/<log>.txt
# Directory listings are sorted so that rows sharing the same model and
# encoding end up in a reproducible order (os.listdir order is arbitrary).
for benchmark_dir in sorted(os.listdir(LOG_ROOT)):
    if not benchmark_dir.endswith(_BENCH_SUFFIX):
        continue

    benchmark_path = os.path.join(LOG_ROOT, benchmark_dir)
    # A stray *file* whose name ends in "-bench" would make os.listdir() raise.
    if not os.path.isdir(benchmark_path):
        continue

    # Strip only the trailing suffix; str.replace() would also delete any
    # "-bench" that happens to occur in the middle of the name.
    benchmark_name = benchmark_dir[: -len(_BENCH_SUFFIX)]

    for model in sorted(os.listdir(benchmark_path)):
        model_path = os.path.join(benchmark_path, model)
        if not os.path.isdir(model_path):
            continue

        for log_file in sorted(os.listdir(model_path)):
            if not log_file.endswith(".txt"):
                continue

            encoding = extract_encoding(log_file)
            with open(os.path.join(model_path, log_file), "r") as f:
                content = f.read()

            parsed = parse_ci_section(content)
            # Logs without both metric lines contribute no row.
            if parsed is None:
                continue

            benchmark_tables[benchmark_name].append(
                {"model": model, "encoding": encoding, **parsed}
            )
|
|
|
|
|
|
|
|
# Write one table per benchmark to SUMMARY_FILE, echoing every line to stdout.
# The encoding is explicit because the output contains the non-ASCII "±" sign,
# which a locale-dependent default encoding may be unable to represent.
with open(SUMMARY_FILE, "w", encoding="utf-8") as f:

    def log(line=""):
        """Write ``line`` plus a newline to the summary file and print it."""
        f.write(line + "\n")
        print(line)

    for benchmark, rows in sorted(benchmark_tables.items()):
        log(f"\n{benchmark.upper()} Benchmark Summary:\n")
        log(f"{'Model':<13} {'Encoding':<8} {'RMSE ±':<20} {'MAE ±':<20}")
        log("-" * 65)

        for row in sorted(rows, key=lambda r: (r["model"], r["encoding"])):
            rmse_cell = f"{row['rmse']:.4f} ± {row['rmse_pm']:.4f}"
            mae_cell = f"{row['mae']:.4f} ± {row['mae_pm']:.4f}"
            # Pad the RMSE cell to the header's 20-character column width so
            # the MAE values line up underneath the "MAE ±" heading.
            log(
                f"{row['model']:<13} {row['encoding']:<8} "
                f"{rmse_cell:<20} {mae_cell}"
            )
|
|
|