|
|
""" |
|
|
Convert a plain-text performance summary into a LaTeX table.
|
|
|
|
|
Features |
|
|
- Parses blocks like: --- Dataset: QM9 --- |
|
|
- Reads Val metrics as "mean ± std" |
|
|
- Reads Test metrics as either "mean [low, high]" (CI) or "mean ± std" |
|
|
- Option --test-intervals {ci, pm}: |
|
|
ci -> keep "mean [low, high]" strings (uses text columns for the 2 Test cols) |
|
|
pm -> convert CI to ± half-width to match siunitx S columns |
|
|
- Multirow per dataset; booktabs rules; siunitx S columns; optional renaming & bolding |
|
|
|
|
|
Usage |
|
|
python latex_table_from_txt.py \ |
|
|
--input results.txt --output table.tex \ |
|
|
--test-intervals pm \ |
|
|
--rename "polyatomic=PACTNet (ECC)" \ |
|
|
--bold-contains "PACTNet" \ |
|
|
--val-dec 3 --test-dec 4 |
|
|
""" |
|
|
|
|
|
import argparse |
|
|
import re |
|
|
from pathlib import Path |
|
|
import pandas as pd |
|
|
|
|
|
|
|
|
def parse_args(argv=None):
    """Parse the command-line options of the TXT -> LaTeX converter.

    Args:
        argv: Optional list of argument strings. When ``None`` (the default)
            argparse falls back to ``sys.argv[1:]``, which is the behaviour of
            a plain ``parse_args()`` call. Passing an explicit list allows the
            parser to be driven programmatically.

    Returns:
        The populated ``argparse.Namespace``.
    """
    p = argparse.ArgumentParser(
        description="Convert TXT performance summary to LaTeX table (booktabs + siunitx + multirow)."
    )

    # Input / output files.
    p.add_argument("--input", "-i", type=Path, required=True, help="Input TXT file")
    p.add_argument("--output", "-o", type=Path, required=True, help="Output .tex file")

    # Table metadata.
    p.add_argument(
        "--caption",
        default="Comprehensive performance comparison across all datasets and models.",
        help="LaTeX caption",
    )
    p.add_argument("--label", default="tab:full_results", help="LaTeX label")

    # Number formatting.
    p.add_argument(
        "--val-dec",
        type=int,
        default=3,
        help="Decimal places for Val metrics (mean & std)",
    )
    p.add_argument(
        "--test-dec",
        type=int,
        default=4,
        help="Decimal places for Test metrics (mean & std)",
    )
    p.add_argument(
        "--no-fixed-decimals",
        action="store_true",
        help="Use raw decimals as provided (don't round to fixed places)",
    )
    p.add_argument(
        "--table-formats",
        nargs=4,
        default=["2.3(4)", "2.3(4)", "2.4(4)", "2.4(4)"],
        help="siunitx table-format for Val RMSE, Val MAE, Test RMSE, Test MAE",
    )

    # Layout.
    p.add_argument(
        "--font-size", default="\\small", help="LaTeX font size inside the table"
    )
    p.add_argument("--width", default="\\textwidth", help="Width for \\resizebox")
    p.add_argument(
        "--no-resize", action="store_true", help="Disable \\resizebox wrapper"
    )
    # --booktabs / --no-booktabs share one destination; booktabs is on by default.
    p.add_argument(
        "--booktabs", action="store_true", default=True, help="Use booktabs rules"
    )
    p.add_argument(
        "--no-booktabs",
        dest="booktabs",
        action="store_false",
        help="Disable booktabs rules",
    )

    # Content options.
    p.add_argument(
        "--test-intervals",
        choices=["ci", "pm"],
        default="pm",
        help="For Test metrics with CIs: keep CIs (ci) or convert to ± half-width (pm)",
    )
    p.add_argument(
        "--bold-contains",
        default=None,
        help="Regex to bold any row where model/rep cell matches",
    )
    p.add_argument(
        "--rename",
        nargs="*",
        default=[],
        help='Rename patterns like old=new (regex on the "Model (Rep.)" cell)',
    )

    # Ordering.
    p.add_argument("--dataset-order", nargs="*", help="Optional explicit dataset order")
    p.add_argument(
        "--sort-by",
        nargs="*",
        default=None,
        help="Sort keys within each dataset, e.g., --sort-by model representation",
    )
    p.add_argument(
        "--ascending",
        nargs="*",
        type=int,
        help="Ascending flags matching --sort-by, e.g. 1 0",
    )

    return p.parse_args(argv)
|
|
|
|
|
|
|
|
def fmt_unc(mean, std, fixed_decimals: bool, dec_places: int) -> str:
    """Format a value with its uncertainty as a LaTeX/siunitx cell.

    Args:
        mean: Central value; ``None``/NaN means the metric is missing.
        std: Uncertainty; ``None``/NaN means no uncertainty is available.
        fixed_decimals: When true, both numbers are rounded to ``dec_places``
            decimals; otherwise they are printed as given with trailing
            zeros trimmed.
        dec_places: Number of decimals used when ``fixed_decimals`` is true.

    Returns:
        ``"mean \\pm std"`` when both values are present, the mean alone when
        only the uncertainty is missing, or a braced em dash placeholder when
        the mean itself is missing.
    """
    if pd.isna(mean):
        # Braces keep siunitx from trying to parse the placeholder as a number
        # inside an S column; they are harmless in plain text columns.
        return r"{\textemdash}"

    def tidy(x):
        """Render ``x`` as given, trimming trailing zeros of plain decimals."""
        s = f"{x}"
        if "e" in s or "E" in s:
            # Leave scientific notation untouched: stripping zeros would
            # corrupt the exponent.
            return s
        if "." in s:
            s = s.rstrip("0").rstrip(".")
        return s

    def render(x):
        if fixed_decimals:
            return f"{float(x):.{dec_places}f}"
        return tidy(x)

    if pd.isna(std):
        # A bare mean is still a valid result; do not discard it just because
        # no uncertainty was reported.
        return render(mean)
    return f"{render(mean)} \\pm {render(std)}"
|
|
|
|
|
|
|
|
def build_model_rep(name: str) -> tuple[str, str, str]:
    """Derive display names from a raw ``model_representation`` identifier.

    The identifier is split at its first underscore, e.g.
    ``'gat_ecfp' -> ('GAT', 'ECFP', 'GAT (ECFP)')``. Purely alphabetic model
    names are upper-cased, other model names are kept verbatim, and the
    representation is always upper-cased. Without an underscore the
    representation is ``''`` and the label is just the model name.

    Returns:
        ``(model, representation, label)``.
    """
    base, _, suffix = name.strip().partition("_")

    model_fmt = base.upper() if base.isalpha() else base
    rep_fmt = suffix.upper()

    if not rep_fmt:
        return model_fmt, rep_fmt, model_fmt
    return model_fmt, rep_fmt, f"{model_fmt} ({rep_fmt})"
|
|
|
|
|
|
|
|
def apply_renames(s: str, mapping: dict) -> str:
    """Apply every ``pattern -> replacement`` regex substitution to ``s``.

    Substitutions run in the mapping's iteration order, each one operating on
    the result of the previous one.
    """
    renamed = s
    for pattern in mapping:
        renamed = re.sub(pattern, mapping[pattern], renamed)
    return renamed
|
|
|
|
|
|
|
|
# Signed decimal number with optional leading-dot form and optional exponent.
_NUMBER = r"[+-]?(?:\d+(?:\.\d+)?|\.\d+)(?:[eE][+-]?\d+)?"
_PM_RE = re.compile(rf"({_NUMBER})\s*±\s*({_NUMBER})")
_CI_RE = re.compile(rf"({_NUMBER})\s*\[\s*({_NUMBER})\s*,\s*({_NUMBER})\s*\]")
_PLAIN_RE = re.compile(rf"({_NUMBER})$")


def parse_metric(cell: str):
    """Parse one metric cell of the text summary.

    Accepted forms (surrounding whitespace is ignored):
      - ``'1.234 ± 0.056'``         -> mean and std
      - ``'1.234 [1.111, 1.345]'``  -> mean and confidence interval
      - ``'1.234'``                 -> mean only

    Numbers may be signed and may use scientific notation (``1.2e-3``) or a
    leading-dot decimal (``.5``).

    Returns:
        A dict with the keys ``mean``, ``std``, ``ci_low`` and ``ci_high``;
        fields that are absent from the cell are ``None``. A cell that matches
        none of the accepted forms yields ``None`` for every key.
    """
    s = cell.strip()
    result = {"mean": None, "std": None, "ci_low": None, "ci_high": None}

    m = _PM_RE.match(s)
    if m:
        result["mean"] = float(m.group(1))
        result["std"] = float(m.group(2))
        return result

    m = _CI_RE.match(s)
    if m:
        result["mean"] = float(m.group(1))
        result["ci_low"] = float(m.group(2))
        result["ci_high"] = float(m.group(3))
        return result

    # A bare number must span the whole cell so arbitrary text that merely
    # starts with a digit is not mistaken for a metric.
    m = _PLAIN_RE.match(s)
    if m:
        result["mean"] = float(m.group(1))
    return result
|
|
|
|
|
|
|
|
# Column layout of the frame returned by parse_txt; fixed so that an input
# without any table still yields a frame with the expected columns.
_RESULT_COLUMNS = [
    "dataset",
    "model",
    "representation",
    "label",
    "val_rmse_mean",
    "val_rmse_std",
    "val_mae_mean",
    "val_mae_std",
    "test_rmse_mean",
    "test_rmse_std",
    "test_rmse_ci_low",
    "test_rmse_ci_high",
    "test_mae_mean",
    "test_mae_std",
    "test_mae_ci_low",
    "test_mae_ci_high",
]


def parse_txt(path: Path) -> pd.DataFrame:
    """Parse a plain-text performance summary into a tidy DataFrame.

    The file is expected to contain blocks introduced by a line such as
    ``--- Dataset: QM9 ---``. Inside each block, the table starts after the
    header line containing ``| Val RMSE`` and ends at the first blank line or
    at a line starting with ``--- Statistical``. Each table row is a
    ``|``-separated line holding the model name followed by Val RMSE, Val MAE,
    Test RMSE and Test MAE.

    Args:
        path: Location of the text summary (read as UTF-8; undecodable bytes
            are dropped).

    Returns:
        One row per (dataset, model) with the parsed mean/std/CI fields. The
        frame always carries the full set of columns, even when no table row
        was found.
    """
    text = path.read_text(encoding="utf-8", errors="ignore")

    # (offset, dataset name) of every block header, in file order.
    headers = [
        (m.start(), m.group(1).strip())
        for m in re.finditer(r"---\s*Dataset:\s*(.+?)\s*---", text)
    ]

    rows = []
    for i, (start, dataset) in enumerate(headers):
        # A block extends up to the next dataset header (or end of file).
        end = headers[i + 1][0] if i + 1 < len(headers) else len(text)

        after_header = False
        for line in text[start:end].splitlines():
            if re.search(r"\|\s*Val RMSE", line):
                after_header = True
                continue
            if not after_header:
                continue

            # The table ends at the statistics section or the first blank line.
            if line.strip().startswith("--- Statistical") or not line.strip():
                break
            # Skip horizontal rule lines made only of dashes and spaces.
            if re.match(r"[-\s]{5,}$", line.replace("|", "")):
                continue
            if "|" not in line:
                continue

            parts = [p.strip() for p in line.split("|")]
            # A row written with a leading pipe ("| name | ... |") produces an
            # empty first cell; drop it so the columns are not shifted by one.
            if parts[0] == "" and len(parts) >= 6:
                parts = parts[1:]
            if len(parts) < 5:
                continue

            val_rmse = parse_metric(parts[1])
            val_mae = parse_metric(parts[2])
            test_rmse = parse_metric(parts[3])
            test_mae = parse_metric(parts[4])
            model, rep, label = build_model_rep(parts[0])
            rows.append(
                {
                    "dataset": dataset,
                    "model": model,
                    "representation": rep,
                    "label": label,
                    "val_rmse_mean": val_rmse["mean"],
                    "val_rmse_std": val_rmse["std"],
                    "val_mae_mean": val_mae["mean"],
                    "val_mae_std": val_mae["std"],
                    "test_rmse_mean": test_rmse["mean"],
                    "test_rmse_std": test_rmse["std"],
                    "test_rmse_ci_low": test_rmse["ci_low"],
                    "test_rmse_ci_high": test_rmse["ci_high"],
                    "test_mae_mean": test_mae["mean"],
                    "test_mae_std": test_mae["std"],
                    "test_mae_ci_low": test_mae["ci_low"],
                    "test_mae_ci_high": test_mae["ci_high"],
                }
            )

    return pd.DataFrame(rows, columns=_RESULT_COLUMNS)
|
|
|
|
|
|
|
|
def _column_spec(table_formats, keep_ci: bool) -> str:
    """Build the tabular column spec: two text columns plus four metric columns."""
    cols = [f"S[table-format={tf}]" for tf in table_formats]
    if keep_ci:
        # "mean [low, high]" strings are not numbers, so the two Test columns
        # must be plain text columns instead of siunitx S columns.
        cols[2:] = ["l", "l"]
    return "@{}ll " + " ".join(cols) + "@{}"


def _format_test_metric(
    mean, std, lo, hi, *, keep_ci: bool, fixed_decimals: bool, dec_places: int
) -> str:
    """Format a Test metric either as 'mean [low, high]' or as 'mean ± std'."""
    # Values come from DataFrame rows, where a missing entry may be None or
    # NaN depending on the column dtype; pd.isna covers both.
    has_ci = not (pd.isna(lo) or pd.isna(hi))
    if keep_ci and has_ci:
        return f"{mean} [{lo}, {hi}]"
    if pd.isna(std) and has_ci:
        # Approximate the uncertainty by the half-width of the interval.
        std = (hi - lo) / 2.0
    return fmt_unc(mean, std, fixed_decimals=fixed_decimals, dec_places=dec_places)


def main():
    """Command-line entry point: read the TXT summary and write the LaTeX table.

    Raises:
        SystemExit: If the input contains no table rows or a ``--rename``
            entry is not of the form ``old=new``.
    """
    args = parse_args()
    df = parse_txt(args.input)
    if df.empty:
        raise SystemExit(f"No dataset tables found in {args.input}")

    # --- Ordering -----------------------------------------------------------
    if args.sort_by:
        if args.ascending:
            asc = [bool(a) for a in args.ascending]
        else:
            asc = [True] * len(args.sort_by)
        df = df.sort_values(by=args.sort_by, ascending=asc, kind="mergesort")
    if args.dataset_order:
        cat = pd.Categorical(df["dataset"], categories=args.dataset_order, ordered=True)
        # A stable sort keeps the within-dataset order established above.
        df = (
            df.assign(_dataset=cat)
            .sort_values("_dataset", kind="mergesort")
            .drop(columns="_dataset")
        )

    # --- Renaming / bolding -------------------------------------------------
    rename_map = {}
    for kv in args.rename or []:
        old, sep, new = kv.partition("=")
        if not sep:
            raise SystemExit(f"--rename expects old=new, got {kv!r}")
        rename_map[old] = new
    bold_re = re.compile(args.bold_contains) if args.bold_contains else None

    keep_ci = args.test_intervals == "ci"
    fixed = not args.no_fixed_decimals
    colspec = _column_spec(args.table_formats, keep_ci)

    # --- Table preamble -----------------------------------------------------
    lines = [r"\begin{table}[h]", r"\centering"]
    if args.font_size:
        lines.append(f"{args.font_size} % Font size")
    lines.append(r"\caption{" + args.caption + r"}")
    lines.append(r"\label{" + args.label + r"}")
    lines.append(r"% siunitx settings")
    lines.append(r"\sisetup{separate-uncertainty, table-align-text-post=false}")

    inner_begin = r"\begin{tabular}{" + colspec + r"}"
    inner_end = r"\end{tabular}"

    if args.no_resize:
        lines.append(inner_begin)
    else:
        # The brace opened here is closed after \end{tabular} below.
        lines.append(r"\resizebox{" + args.width + r"}{!}{" + inner_begin)

    if args.booktabs:
        lines.append(r"\toprule")
    lines.append(
        r"\textbf{Dataset} & \textbf{Model (Rep.)} & {Val RMSE} & {Val MAE} & {Test RMSE} & {Test MAE} \\"
    )
    if args.booktabs:
        lines.append(r"\midrule")

    # --- Body: one multirow group per dataset -------------------------------
    for dataset, g in df.groupby("dataset", sort=False):
        n = len(g)
        for idx, (_, row) in enumerate(g.iterrows()):
            cell_model = apply_renames(row["label"], rename_map)
            do_bold = bool(bold_re and bold_re.search(cell_model))

            cells = [
                cell_model,
                fmt_unc(
                    row["val_rmse_mean"],
                    row["val_rmse_std"],
                    fixed_decimals=fixed,
                    dec_places=args.val_dec,
                ),
                fmt_unc(
                    row["val_mae_mean"],
                    row["val_mae_std"],
                    fixed_decimals=fixed,
                    dec_places=args.val_dec,
                ),
                _format_test_metric(
                    row["test_rmse_mean"],
                    row["test_rmse_std"],
                    row["test_rmse_ci_low"],
                    row["test_rmse_ci_high"],
                    keep_ci=keep_ci,
                    fixed_decimals=fixed,
                    dec_places=args.test_dec,
                ),
                _format_test_metric(
                    row["test_mae_mean"],
                    row["test_mae_std"],
                    row["test_mae_ci_low"],
                    row["test_mae_ci_high"],
                    keep_ci=keep_ci,
                    fixed_decimals=fixed,
                    dec_places=args.test_dec,
                ),
            ]
            if do_bold:
                cells = [rf"\bfseries {c}" for c in cells]

            # Only the first row of a group carries the dataset name.
            lead = rf"\multirow{{{n}}}{{*}}{{{dataset}}}" if idx == 0 else ""
            lines.append(" & ".join([lead, *cells]) + r" \\")

    # --- Closing ------------------------------------------------------------
    if args.booktabs:
        lines.append(r"\bottomrule")
    lines.append(inner_end)
    if not args.no_resize:
        lines.append("}")
    lines.append(r"\end{table}")

    args.output.write_text("\n".join(lines), encoding="utf-8")
|
|
|
|
|
|
|
|
# Run the converter only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
|
|
|