|
import argparse |
|
|
|
|
|
from runners import run_gnn_experiment, run_gp_experiment |
|
|
|
|
|
def _positive_int(value):
    """Parse *value* as an integer and require it to be at least 1.

    Used as the argparse ``type`` for ``--n-trials`` so that a zero or
    negative trial count is rejected at parse time, with a usage message,
    instead of being handed to the hyperparameter search.

    Args:
        value: The raw command-line string.

    Returns:
        The parsed integer (always >= 1).

    Raises:
        argparse.ArgumentTypeError: If *value* is not an integer or is < 1.
    """
    try:
        number = int(value)
    except ValueError:
        raise argparse.ArgumentTypeError(
            f"invalid int value: {value!r}"
        ) from None
    if number < 1:
        raise argparse.ArgumentTypeError(
            f"must be a positive integer, got {number}"
        )
    return number


def main(argv=None):
    """Main entry point for running all experiments.

    Parses command-line arguments and rejects model/representation
    combinations that are not supported. It then dispatches to
    ``run_gp_experiment`` for ``--model gp`` and to ``run_gnn_experiment``
    for every other model (including ``polyatomic``).

    Args:
        argv: Argument list to parse, excluding the program name. ``None``
            (the default) makes argparse read ``sys.argv[1:]``, which is the
            normal command-line behavior. Passing an explicit list lets tests
            and other scripts drive the CLI.

    Raises:
        ValueError: If ``--rep polyatomic`` is combined with ``--model gp``,
            or if ``--model polyatomic`` is used with any other
            representation.
        SystemExit: Raised by argparse for missing or invalid arguments,
            including a non-positive ``--n-trials``.
    """
    parser = argparse.ArgumentParser(
        description="Run GNN or GP experiments for molecular property prediction."
    )
    parser.add_argument(
        "--model",
        choices=["gcn", "gin", "gat", "sage", "gp", "polyatomic"],
        required=True,
        help="The model to train and evaluate.",
    )
    parser.add_argument(
        "--rep",
        choices=["smiles", "selfies", "ecfp", "polyatomic"],
        required=True,
        help="The molecular representation to use.",
    )
    parser.add_argument(
        "--dataset",
        choices=[
            "esol",
            "freesolv",
            "lipophil",
            "boilingpoint",
            "qm9",
            "ic50",
            "bindingdb",
        ],
        required=True,
        help="The dataset to use for the experiment.",
    )
    parser.add_argument(
        "--n-trials",
        # A trial count below 1 makes no sense for a hyperparameter search,
        # so reject it here rather than deep inside a runner.
        type=_positive_int,
        default=10,
        help="Number of Optuna trials to run for hyperparameter search in each fold.",
    )

    # argv=None -> argparse falls back to sys.argv[1:] (original behavior).
    args = parser.parse_args(argv)

    # Cross-argument constraints that argparse `choices` cannot express.
    if args.model == "gp" and args.rep == "polyatomic":
        raise ValueError(
            "The 'polyatomic' representation is not compatible with the 'gp' model."
        )
    if args.model == "polyatomic" and args.rep != "polyatomic":
        raise ValueError(
            "The 'polyatomic' model must be used with the 'polyatomic' representation."
        )

    # GP has its own runner; every graph-based model (including the
    # polyatomic one) goes through the GNN runner.
    if args.model == "gp":
        run_gp_experiment(args)
    else:
        run_gnn_experiment(args)
|
|
|
|
|
# Run the experiment CLI only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
|
|