File size: 13,561 Bytes
9a67fbe
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
"""
Single-dataset statistical comparison: control vs competitors via
Nadeau–Bengio corrected t-tests on outer folds (k-fold CV), with Holm FWER control.
No Wilcoxon. No Friedman. Test metrics are shown for context only.
"""

import argparse
import re
from pathlib import Path
from typing import Dict, List, Optional, Tuple

import numpy as np
import pandas as pd
from math import sqrt

from statsmodels.stats.multitest import multipletests  # Holm
from scipy.stats import t as tdist  # t distribution for NB p-values

# Fixed RNG seed for reproducibility. NOTE(review): nothing in this file
# draws from np.random, so the seed appears to have no effect locally —
# confirm whether any imported/downstream code relies on the global state.
np.random.seed(42)

# -----------------------------
# parse helpers
# -----------------------------


def _parse_list_in_brackets(inner: str) -> Optional[List[float]]:
    parts = [p.strip() for p in inner.split(",") if p.strip()]
    out: List[float] = []
    for p in parts:
        try:
            out.append(float(p))
        except Exception:
            return None
    return out


def parse_log_file(
    log_path: Path, expected_k: int = 5
) -> Tuple[Optional[List[float]], Optional[List[float]]]:
    """
    Extract the K outer-fold validation RMSE and MAE arrays from a log.

    The log is expected to contain lines of the form:
      VAL FOLD RMSEs: [v1, ..., vK]
      VAL FOLD MAEs:  [v1, ..., vK]

    Returns (rmse_list, mae_list), or (None, None) when the file is
    unreadable, either line is missing, a list fails to parse, or either
    list does not have exactly *expected_k* entries.
    """
    try:
        text = log_path.read_text()
    except Exception:
        return None, None

    # Pull the bracketed lists for both metrics; bail out if either is absent.
    patterns = (
        r"VAL\s+FOLD\s+RMSEs:\s*\[(.*?)\]",
        r"VAL\s+FOLD\s+MAEs:\s*\[(.*?)\]",
    )
    parsed: List[Optional[List[float]]] = []
    for pattern in patterns:
        match = re.search(pattern, text, flags=re.IGNORECASE)
        if match is None:
            return None, None
        parsed.append(_parse_list_in_brackets(match.group(1)))

    rmse_scores, mae_scores = parsed
    if rmse_scores is None or mae_scores is None:
        return None, None
    if len(rmse_scores) != expected_k or len(mae_scores) != expected_k:
        return None, None
    return rmse_scores, mae_scores


def parse_csv_file(csv_path: Path) -> Optional[Dict[str, float]]:
    """
    Parse a *_final_results.csv summary file.

    Expected columns:
      val_rmse_mean, val_rmse_std, val_mae_mean, val_mae_std,
      test_rmse_mean, test_rmse_ci_low, test_rmse_ci_high,
      test_mae_mean,  test_mae_ci_low,  test_mae_ci_high

    Returns a dict mapping each expected column to the float value from
    the first row, or None when the file is unreadable or any expected
    column is missing.
    """
    required = (
        "val_rmse_mean",
        "val_rmse_std",
        "val_mae_mean",
        "val_mae_std",
        "test_rmse_mean",
        "test_rmse_ci_low",
        "test_rmse_ci_high",
        "test_mae_mean",
        "test_mae_ci_low",
        "test_mae_ci_high",
    )
    try:
        frame = pd.read_csv(csv_path)
    except Exception:
        return None
    if any(name not in frame.columns for name in required):
        return None
    first_row = frame.iloc[0]
    return {name: float(first_row[name]) for name in required}


def clean_base_and_exp_id(base_name: str) -> Tuple[str, str]:
    """
    Strip a trailing _YYYYMMDD_HHMMSS timestamp from *base_name*.

    Returns (clean_name, exp_id). exp_id is the first two
    underscore-separated tokens (e.g. 'gat_ecfp' or
    'polyatomic_polyatomic'); if the clean name has fewer than two
    tokens, the whole clean name is used as exp_id.
    """
    clean = re.sub(r"_\d{8}_\d{6}$", "", base_name)
    tokens = clean.split("_")
    if len(tokens) >= 2:
        return clean, f"{tokens[0]}_{tokens[1]}"
    return clean, clean


def discover_dataset(results_dir: Path, expected_k: int = 5) -> Dict[str, Dict]:
    """
    Scan one dataset directory for (log, csv) result pairs.

    For each model-family subfolder, every *.txt log with a matching
    *_final_results.csv sibling is parsed. The first entry seen per
    exp_id wins, EXCEPT that an entry whose fold arrays failed to parse
    is upgraded when a later run for the same exp_id provides complete
    fold arrays (the original kept the fold-less entry and silently
    discarded the later, complete one).

    Returns:
      data[exp_id] = {
        'rmse_folds': [K] or None,   # outer-fold validation RMSEs
        'mae_folds' : [K] or None,   # outer-fold validation MAEs
        'summary'   : {...},         # parsed *_final_results.csv row
        'log_path'  : str,
        'csv_path'  : str
      }
    """
    data: Dict[str, Dict] = {}
    for model_dir in results_dir.iterdir():
        if not model_dir.is_dir():
            continue
        for log_file in sorted(model_dir.glob("*.txt")):
            clean, exp_id = clean_base_and_exp_id(log_file.stem)
            csv_file = log_file.with_name(f"{clean}_final_results.csv")
            if not csv_file.exists():
                continue
            summary_stats = parse_csv_file(csv_file)
            if summary_stats is None:
                continue
            rmse_scores, mae_scores = parse_log_file(log_file, expected_k=expected_k)
            entry = {
                "rmse_folds": rmse_scores,
                "mae_folds": mae_scores,
                "summary": summary_stats,
                "log_path": str(log_file),
                "csv_path": str(csv_file),
            }
            existing = data.get(exp_id)
            if existing is None:
                data[exp_id] = entry
            elif (
                existing["rmse_folds"] is None
                and rmse_scores is not None
                and mae_scores is not None
            ):
                # Upgrade: the earlier run lacked fold arrays; this one
                # has complete folds, so it is strictly more useful.
                data[exp_id] = entry
    return data


# -----------------------------
# Nadeau–Bengio corrected t (within dataset)
# -----------------------------


def nb_corrected_t(diffs: np.ndarray, k: int) -> Tuple[float, float, float]:
    """
    Nadeau–Bengio corrected resampled t statistic for k-fold CV differences.

    diffs: length-k array of (competitor - control) per-fold losses;
           a negative mean favors the control model.

    Returns (t_stat, SE_NB, mean_diff). The correction uses
    rho0 = 1/(k-1), giving SE_NB = sqrt((1/k + 1/(k-1)) * s^2) where
    s^2 is the unbiased sample variance across folds.

    Raises ValueError when len(diffs) != k. With k <= 1 the variance is
    undefined, so (nan, nan, mean) is returned.
    """
    fold_diffs = np.asarray(diffs, dtype=float)
    if fold_diffs.size != k:
        raise ValueError("diffs length must equal k")
    mean_d = float(fold_diffs.mean())
    if k <= 1:
        return np.nan, np.nan, mean_d
    var_unbiased = float(np.var(fold_diffs, ddof=1))
    se_nb = sqrt((1.0 / k + 1.0 / (k - 1.0)) * var_unbiased)
    if se_nb > 0:
        t_stat = mean_d / se_nb
    elif mean_d < 0:
        # Zero spread but a nonzero mean: the statistic diverges.
        t_stat = -np.inf
    elif mean_d > 0:
        t_stat = np.inf
    else:
        t_stat = 0.0
    return t_stat, se_nb, mean_d


# -----------------------------
# main
# -----------------------------


def main():
    """CLI entry point.

    Discovers (fold-array, summary) results for one dataset directory,
    runs one-sided Nadeau–Bengio corrected t-tests of the control model
    against every competitor on outer-fold RMSE/MAE differences, applies
    Holm adjustment across competitors per metric family, and writes a
    plain-text report to <output_dir>/<exp_name>_NB_single_dataset.txt.
    """
    ap = argparse.ArgumentParser(
        description="Single-dataset comparison: NB-corrected t on outer folds (control vs all) with Holm; prints Test metrics for context."
    )
    ap.add_argument(
        "--results_dir",
        type=str,
        required=True,
        help="Dataset directory (e.g., ./logs_hyperparameter/qm9) with model-family subfolders.",
    )
    ap.add_argument(
        "--output_dir", type=str, required=True, help="Directory to save the report."
    )
    ap.add_argument(
        "--exp_name",
        type=str,
        required=True,
        help="Base name for the output report file.",
    )
    ap.add_argument(
        "--control_model",
        type=str,
        required=True,
        help="Control exp_id (e.g., 'polyatomic_polyatomic'). Unique substring allowed.",
    )
    ap.add_argument(
        "--k", type=int, default=5, help="Expected number of outer folds (default: 5)."
    )
    ap.add_argument(
        "--alpha",
        type=float,
        default=0.05,
        help="FWER level for Holm; also used for NB CIs if --print_ci.",
    )
    ap.add_argument(
        "--print_ci",
        action="store_true",
        help="Also print NB-style CIs for the mean fold difference.",
    )
    args = ap.parse_args()

    results_dir = Path(args.results_dir)
    out_dir = Path(args.output_dir)
    out_dir.mkdir(parents=True, exist_ok=True)
    report_path = out_dir / f"{args.exp_name}_NB_single_dataset.txt"

    # Gather fold arrays + summary stats for every model in the dataset dir.
    data = discover_dataset(results_dir, expected_k=args.k)
    if not data:
        raise SystemExit(
            f"No models found under {results_dir} with valid logs + *_final_results.csv"
        )

    exp_ids = sorted(data.keys())

    # match control (exact or unique substring)
    ctrl = args.control_model
    if ctrl not in data:
        matches = [e for e in exp_ids if ctrl.lower() in e.lower()]
        if len(matches) == 1:
            ctrl = matches[0]
        else:
            # Zero matches or an ambiguous substring (multiple matches): abort.
            raise SystemExit(
                f"Control '{args.control_model}' not found. Available exp_ids: {exp_ids}"
            )

    with open(report_path, "w") as f:
        f.write("=" * 110 + "\n")
        f.write(
            f"Dataset: {results_dir.name} — Control vs competitors (NB-corrected t on outer folds; Holm across competitors)\n"
        )
        f.write("=" * 110 + "\n\n")
        f.write(f"Control exp_id: {ctrl}\n")
        f.write(f"k folds: {args.k}, alpha: {args.alpha}\n\n")

        # context table
        header = f"{'Model (exp_id)':<26} | {'Test RMSE (95% CI)':<30} | {'Test MAE (95% CI)':<30} | {'Val RMSE mean±sd':<22} | {'Val MAE mean±sd':<22}\n"
        f.write(header)
        f.write("-" * len(header) + "\n")
        for exp_id in exp_ids:
            s = data[exp_id]["summary"]
            line = (
                f"{exp_id:<26} | "
                f"{s['test_rmse_mean']:.6f} [{s['test_rmse_ci_low']:.6f}, {s['test_rmse_ci_high']:.6f}]  | "
                f"{s['test_mae_mean']:.6f}  [{s['test_mae_ci_low']:.6f},  {s['test_mae_ci_high']:.6f}]  | "
                f"{s['val_rmse_mean']:.6f} ± {s['val_rmse_std']:.6f}    | "
                f"{s['val_mae_mean']:.6f} ± {s['val_mae_std']:.6f}\n"
            )
            f.write(line)
        f.write("\n")

        # NB tests per competitor (RMSE and MAE on outer folds)
        comps = [e for e in exp_ids if e != ctrl]
        rows = []
        pvals_rmse, labels_rmse = [], []
        pvals_mae, labels_mae = [], []

        # Control fold arrays; an empty array signals missing/unparsed folds.
        ctrl_rmse = np.array(
            data[ctrl]["rmse_folds"] if data[ctrl]["rmse_folds"] is not None else [],
            dtype=float,
        )
        ctrl_mae = np.array(
            data[ctrl]["mae_folds"] if data[ctrl]["mae_folds"] is not None else [],
            dtype=float,
        )

        if ctrl_rmse.size != args.k or ctrl_mae.size != args.k:
            # Without complete control folds no paired differences exist;
            # the Holm tables below will then report "(no p-values)".
            f.write(
                "WARNING: Control model missing complete fold arrays; NB testing cannot proceed.\n"
            )
        else:
            for comp in comps:
                comp_rmse = data[comp]["rmse_folds"]
                comp_mae = data[comp]["mae_folds"]
                if comp_rmse is None or comp_mae is None:
                    # Competitor log lacked fold arrays: skip silently.
                    continue
                comp_rmse = np.array(comp_rmse, dtype=float)
                comp_mae = np.array(comp_mae, dtype=float)
                if comp_rmse.size != args.k or comp_mae.size != args.k:
                    continue

                # diffs = competitor - control (losses); control better ⇒ positive mean
                d_rmse = comp_rmse - ctrl_rmse
                d_mae = comp_mae - ctrl_mae

                t_rmse, se_rmse, mean_rmse = nb_corrected_t(d_rmse, k=args.k)
                t_mae, se_mae, mean_mae = nb_corrected_t(d_mae, k=args.k)

                df = args.k - 1  # degrees of freedom for the NB t-test
                # RIGHT (upper tail): control better ⇢ mean(comp - ctrl) > 0
                p_rmse = float(tdist.sf(t_rmse, df=df))  # sf = 1 - cdf
                p_mae = float(tdist.sf(t_mae, df=df))

                pvals_rmse.append(p_rmse)
                labels_rmse.append(f"{ctrl} vs {comp}")
                pvals_mae.append(p_mae)
                labels_mae.append(f"{ctrl} vs {comp}")

                row = {
                    "comparison": f"{ctrl} vs {comp}",
                    "mean_diff_RMSE(comp-ctrl)": mean_rmse,
                    "t_NB_RMSE": t_rmse,
                    "p_one_sided_RMSE": p_rmse,
                    "mean_diff_MAE(comp-ctrl)": mean_mae,
                    "t_NB_MAE": t_mae,
                    "p_one_sided_MAE": p_mae,
                }

                if args.print_ci:
                    # Two-sided (1 - alpha) CI for the mean fold difference
                    # using the NB standard error.
                    tcrit = float(tdist.ppf(1 - args.alpha / 2, df=df))
                    row["NB_CI_RMSE_low"] = mean_rmse - tcrit * se_rmse
                    row["NB_CI_RMSE_high"] = mean_rmse + tcrit * se_rmse
                    row["NB_CI_MAE_low"] = mean_mae - tcrit * se_mae
                    row["NB_CI_MAE_high"] = mean_mae + tcrit * se_mae

                rows.append(row)

        df_rows = pd.DataFrame(rows)
        if not df_rows.empty:
            f.write("--- NB-corrected t (outer folds) per competitor ---\n")
            f.write(
                df_rows.to_string(index=False, float_format=lambda x: f"{x:.6f}")
                + "\n\n"
            )
        else:
            f.write(
                "No comparable competitors with complete fold arrays. Nothing to test.\n\n"
            )

        # Holm across competitors per metric family
        def holm_table(pvals: List[float], labels: List[str], title: str):
            """Write a Holm-adjusted p-value table (sorted by raw p) to the
            report file. Closes over `f` and `args`."""
            if not pvals:
                f.write(f"{title}\n(no p-values)\n\n")
                return
            reject, p_adj, _, _ = multipletests(
                pvals=pvals, alpha=args.alpha, method="holm"
            )
            df = pd.DataFrame(
                {
                    "comparison": labels,
                    "p_raw": pvals,
                    "p_holm": p_adj,
                    "Significant": reject.astype(bool),
                }
            )
            f.write(title + "\n")
            f.write(
                df.sort_values("p_raw").to_string(
                    index=False, float_format=lambda x: f"{x:.6f}"
                )
                + "\n\n"
            )

        holm_table(
            pvals_rmse, labels_rmse, "--- Holm-adjusted p-values (RMSE family) ---"
        )
        holm_table(
            pvals_mae, labels_mae, "--- Holm-adjusted p-values (MAE family)  ---"
        )

        f.write("=" * 110 + "\nNotes:\n")
        f.write(
            "• Tests are within-dataset, one-sided for control superiority, on outer-fold differences with Nadeau–Bengio SE correction (df = k-1).\n"
        )
        f.write(
            "• Holm controls family-wise error across competitors per metric family.\n"
        )
        f.write(
            "• Held-out Test metrics above are for context only; no fold-based omnibus tests are used.\n"
        )

    print(f"Report successfully written to: {report_path}")


# Script entry point: parse CLI args and write the report.
if __name__ == "__main__":
    main()