PACT-Net / logs /old-res /create_table.py
rk-random's picture
Upload folder using huggingface_hub
9a67fbe verified
raw
history blame
2.55 kB
import os
import re
from collections import defaultdict
LOG_ROOT = "logs"
SUMMARY_FILE = "benchmark_summary.txt"
def parse_ci_section(content):
"""Extracts Test RMSE, MAE, and their 95% CI intervals from log content."""
rmse_match = re.search(
r"Test RMSE: ([\d.]+)\s+\(95 % CI: ([\d.]+)[–-]([\d.]+)\)", content
)
mae_match = re.search(
r"Test MAE\s*: ([\d.]+)\s+\(95 % CI: ([\d.]+)[–-]([\d.]+)\)", content
)
if not (rmse_match and mae_match):
return None
rmse, rmse_lo, rmse_hi = map(float, rmse_match.groups())
mae, mae_lo, mae_hi = map(float, mae_match.groups())
rmse_pm = (rmse_hi - rmse_lo) / 2
mae_pm = (mae_hi - mae_lo) / 2
return {"rmse": rmse, "rmse_pm": rmse_pm, "mae": mae, "mae_pm": mae_pm}
def extract_encoding(fname):
for enc in ["ecfp", "smiles", "selfies"]:
if enc in fname.lower():
return enc
return "unknown"
# Map benchmark name → list of rows
benchmark_tables = defaultdict(list)
for benchmark_dir in os.listdir(LOG_ROOT):
if not benchmark_dir.endswith("-bench"):
continue
benchmark_name = benchmark_dir.replace("-bench", "")
benchmark_path = os.path.join(LOG_ROOT, benchmark_dir)
for model in os.listdir(benchmark_path):
model_path = os.path.join(benchmark_path, model)
if not os.path.isdir(model_path):
continue
for log_file in os.listdir(model_path):
if not log_file.endswith(".txt"):
continue
encoding = extract_encoding(log_file)
with open(os.path.join(model_path, log_file), "r") as f:
content = f.read()
parsed = parse_ci_section(content)
if parsed is None:
continue
benchmark_tables[benchmark_name].append(
{"model": model, "encoding": encoding, **parsed}
)
# Write to file
with open(SUMMARY_FILE, "w") as f:
def log(line=""):
f.write(line + "\n")
print(line)
for benchmark, rows in sorted(benchmark_tables.items()):
log(f"\n{benchmark.upper()} Benchmark Summary:\n")
log(f"{'Model':<13} {'Encoding':<8} {'RMSE ±':<20} {'MAE ±':<20}")
log("-" * 65)
for row in sorted(rows, key=lambda r: (r["model"], r["encoding"])):
log(
f"{row['model']:<13} {row['encoding']:<8} "
f"{row['rmse']:.4f} ± {row['rmse_pm']:.4f} "
f"{row['mae']:.4f} ± {row['mae_pm']:.4f}"
)