|
|
import numpy as np |
|
|
import pandas as pd |
|
|
from pathlib import Path |
|
|
|
|
|
|
|
|
from data.data_handling import prepare_and_load_data, GP_FEATURIZERS |
|
|
from training.training_pipeline import k_fold_tuned_eval |
|
|
from models.gp import TanimotoGP as GPModel |
|
|
from data.loaders import load_dataset |
|
|
|
|
|
from training.train_eval import ( |
|
|
k_fold_eval as gp_k_fold_eval, |
|
|
train_gp_model, |
|
|
eval_gp_model, |
|
|
) |
|
|
|
|
|
|
|
|
def run_gnn_experiment(args):
    """Run the tuned k-fold GNN experiment described by ``args``.

    Loads the train/test graphs with ``prepare_and_load_data`` and passes
    them, together with ``args``, to ``k_fold_tuned_eval``.

    Args:
        args: Experiment configuration. ``args.model``, ``args.rep`` and
            ``args.dataset`` are read here for the start banner; the whole
            object is forwarded unchanged to both helpers.

    Returns:
        None. Whatever ``k_fold_tuned_eval`` returns is discarded.
    """
    banner = (
        f"--- Starting GNN Experiment: {args.model} | {args.rep} | {args.dataset} ---"
    )
    print(banner)

    graphs_train, graphs_test = prepare_and_load_data(args)

    # NOTE(review): the closing message implies results are persisted inside
    # k_fold_tuned_eval -- confirm against the training pipeline.
    k_fold_tuned_eval(args, graphs_train, graphs_test)

    print("---GNN Experiment Finished. Results saved---")
|
|
|
|
|
|
|
|
def run_gp_experiment(args):
    """Run a k-fold Gaussian-process experiment for one dataset/representation.

    Steps, in order:
      1. Load the raw train/test split with ``load_dataset(args.dataset)``.
      2. Look up the featurizer registered under ``args.rep`` in
         ``GP_FEATURIZERS`` and apply it to every train and test input,
         stacking the results into ``float32`` arrays.
      3. Call ``gp_k_fold_eval`` with train/eval callbacks that wrap
         ``train_gp_model`` and ``eval_gp_model``.
      4. Print the test metrics returned by ``gp_k_fold_eval``.

    Args:
        args: Experiment configuration; ``args.rep`` and ``args.dataset``
            are the only attributes read here.

    Returns:
        None. The metrics are printed, not returned.

    Raises:
        KeyError: If ``args.rep`` has no entry in ``GP_FEATURIZERS``
            (assuming it is a plain mapping).
    """
    print(f"--- Starting GP Experiment: {args.rep} | {args.dataset} ---")

    raw_train, raw_test, y_train, y_test = load_dataset(args.dataset)
    featurize_one = GP_FEATURIZERS[args.rep]

    def _featurize(samples):
        """Featurize each sample and stack the results as a float32 array."""
        stacked = np.stack([featurize_one(sample) for sample in samples])
        return stacked.astype(np.float32)

    train_features = _featurize(raw_train)
    test_features = _featurize(raw_test)

    def _fit_gp(X_tr, y_tr, log_file):
        """Train a freshly constructed GP model on the given data."""
        fresh_model = GPModel()
        return train_gp_model(fresh_model, X_tr, y_tr, log_file)

    def _score_gp(model, X_te, y_te, log_file, scaler, return_preds=False):
        """Forward every argument unchanged to ``eval_gp_model``."""
        return eval_gp_model(
            model,
            X_te,
            y_te,
            log_file,
            scaler,
            return_preds=return_preds,
        )

    # Only the second element of the returned pair (the test metrics) is
    # used; the first is discarded.
    _, test_metrics = gp_k_fold_eval(
        train_fn=_fit_gp,
        eval_fn=_score_gp,
        X_train=train_features,
        y_train=y_train,
        X_test=test_features,
        y_test=y_test,
        model_name="gp",
        rep_name=args.rep,
        dataset_name=args.dataset,
    )

    print(f"--- GP Experiment Finished. Final Test Metrics: {test_metrics} ---")
|
|
|