import copy
import os
import tempfile
from datetime import datetime
from pathlib import Path

import numpy as np
import optuna
import pandas as pd
import torch
import torch.nn as nn
from sklearn.model_selection import KFold, train_test_split
from sklearn.preprocessing import StandardScaler
from torch_geometric.loader import DataLoader

from data.data_handling import get_model_instance

ROOT = str(Path(__file__).parent.parent.resolve())
LOG_ROOT = Path(ROOT) / "logs_hyperparameter"
# exist_ok=True avoids the check-then-create race of exists() + makedirs(exist_ok=False).
os.makedirs(LOG_ROOT, exist_ok=True)

# Fraction of the training data held back purely for early stopping, both in
# every outer CV fold and for the final model, so CV estimates the same
# procedure that produces the final model.
_EARLY_STOPPING_FRACTION = 0.1


def _dataset_log_dir(dataset):
    """Return ``LOG_ROOT/<dataset>``, creating it if needed."""
    log_dir = LOG_ROOT / str(dataset)
    os.makedirs(log_dir, exist_ok=True)
    return log_dir


def setup_log_file(args):
    """Create the per-dataset log directory and return a fresh, timestamped log path.

    Args:
        args: Namespace providing ``model``, ``rep`` and ``dataset`` attributes,
            which are embedded in the file name.

    Returns:
        pathlib.Path of ``<LOG_ROOT>/<dataset>/<model>_<rep>_<dataset>_<timestamp>.txt``.
        The file itself is not created here; ``write_log`` appends to it.
    """
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    fname = f"{args.model}_{args.rep}_{args.dataset}_{timestamp}.txt"
    log_path = _dataset_log_dir(args.dataset) / fname
    print(f"[Logging] Writing to: {log_path}")
    return log_path


def write_log(log_file_path, text):
    """Writes a message to both console and the log file (appending, UTF-8)."""
    print(text)
    # Explicit encoding: messages contain non-ASCII characters such as "±".
    with open(log_file_path, "a", encoding="utf-8") as f:
        f.write(text + "\n")


def _mean_loss(model, loader, device, loss_fn):
    """Return the sample-weighted mean of ``loss_fn`` over ``loader``.

    Returns ``inf`` when the loader yields no targets, so an empty evaluation
    can never count as an improvement for lower-is-better callers.
    """
    model.eval()
    total, count = 0.0, 0
    with torch.no_grad():
        for batch in loader:
            batch = batch.to(device)
            target = batch.y.view(-1)
            pred = model(batch).view(-1)
            # NOTE(review): assumes loss_fn uses mean reduction (nn.MSELoss default,
            # the only loss used in this file). Multiplying by the batch size turns
            # it back into a sum so a short final batch is not over-weighted.
            total += loss_fn(pred, target).item() * target.numel()
            count += target.numel()
    return total / count if count else float("inf")


def train_gnn_model(
    model,
    train_loader,
    val_loader,
    optimizer,
    device,
    loss_fn,
    max_epochs=200,
    patience=20,
):
    """Trains a GNN with early stopping based on a validation set.

    After each epoch the sample-weighted validation loss is computed. Training
    stops after ``patience`` consecutive epochs without improvement, or after
    ``max_epochs``. The weights of the best epoch are restored into ``model``
    before returning. If no epoch ever improved (e.g. empty ``val_loader``), the
    final weights are kept.

    Returns:
        The same ``model`` object, now holding the best-validation weights.
    """
    best_val_loss = float("inf")
    best_state = None
    epochs_no_improve = 0

    for _ in range(max_epochs):
        model.train()
        for batch in train_loader:
            batch = batch.to(device)
            # Batches with no edges at all are skipped for training only.
            if batch.num_edges == 0:
                continue
            optimizer.zero_grad()
            out = model(batch).view(-1)
            loss = loss_fn(out, batch.y.view(-1))
            loss.backward()
            optimizer.step()

        avg_val_loss = _mean_loss(model, val_loader, device, loss_fn)
        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            # Keep the best weights in memory: no temp-file collisions between
            # concurrent trials and nothing left behind if training raises.
            best_state = copy.deepcopy(model.state_dict())
            epochs_no_improve = 0
        else:
            epochs_no_improve += 1
            if epochs_no_improve >= patience:
                break

    if best_state is not None:
        model.load_state_dict(best_state)
    return model


def _fit_target_scaler(graphs):
    """Fit a ``StandardScaler`` on the scalar targets ``g.y.item()`` of ``graphs``."""
    y_raw = np.array([g.y.item() for g in graphs]).reshape(-1, 1)
    return StandardScaler().fit(y_raw)


def _scale_graphs(graphs, scaler):
    """Return clones of ``graphs`` whose ``y`` is replaced by ``scaler.transform(y)``.

    The input graphs are not modified. Each scaled ``y`` is a float tensor of shape (1, 1).
    """
    scaled = []
    for g in graphs:
        g_scaled = g.clone()
        y_raw = g.y.detach().cpu().numpy().reshape(1, -1)
        g_scaled.y = torch.tensor(scaler.transform(y_raw), dtype=torch.float)
        scaled.append(g_scaled)
    return scaled


def _fit_model(args, params, train_graphs, val_graphs, device, model_graphs=None):
    """Build a model from ``params`` and train it with early stopping on ``val_graphs``.

    ``train_graphs`` / ``val_graphs`` must already carry scaled targets.
    ``model_graphs`` is what ``get_model_instance`` receives (defaults to
    ``train_graphs``); presumably it is used to infer input dimensions — TODO confirm.

    Returns:
        ``(model, val_loader)``.
    """
    train_loader = DataLoader(
        train_graphs, batch_size=params["batch_size"], shuffle=True
    )
    val_loader = DataLoader(val_graphs, batch_size=params["batch_size"])
    build_graphs = train_graphs if model_graphs is None else model_graphs
    model = get_model_instance(args, params, build_graphs).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=params["lr"])
    train_gnn_model(
        model, train_loader, val_loader, optimizer, device, loss_fn=nn.MSELoss()
    )
    return model, val_loader


def _fit_with_early_stopping_split(args, params, scaled_graphs, device, model_graphs):
    """Hold back ``_EARLY_STOPPING_FRACTION`` of ``scaled_graphs`` for early stopping and train.

    Used both inside each outer CV fold and for the final model, so the
    evaluation data is never touched during training.
    """
    train_subset, es_subset = train_test_split(
        scaled_graphs, test_size=_EARLY_STOPPING_FRACTION, random_state=42
    )
    model, _ = _fit_model(
        args, params, train_subset, es_subset, device, model_graphs=model_graphs
    )
    return model


def _predict_original_scale(model, graphs, scaler, device, batch_size):
    """Predict on ``graphs``, which carry raw (unscaled) targets.

    Returns:
        ``(y_true, y_pred)`` as 1-D numpy arrays in original target units.
        ``y_true`` is read directly from the raw targets. ``y_pred`` is
        inverse-transformed with ``scaler``. Both are empty if ``graphs`` is empty.
    """
    loader = DataLoader(graphs, batch_size=batch_size)  # no shuffle: keeps order aligned
    model.eval()
    true_parts, pred_parts = [], []
    with torch.no_grad():
        for batch in loader:
            batch = batch.to(device)
            pred_parts.append(model(batch).view(-1).cpu().numpy())
            true_parts.append(batch.y.view(-1).cpu().numpy())
    if not true_parts:
        return np.empty(0), np.empty(0)
    y_true = np.concatenate(true_parts).astype(np.float64)
    y_pred = scaler.inverse_transform(np.concatenate(pred_parts).reshape(-1, 1)).ravel()
    return y_true, y_pred


def objective(trial, args, train_graphs, val_graphs, device, scaler):
    """Optuna objective function. Uses a pre-fitted scaler for consistency.

    Samples ``lr``, ``hidden_dim`` and ``batch_size``. It scales copies of the
    given graphs with ``scaler`` (the scaler is never re-fit), trains with early
    stopping on ``val_graphs`` and returns the sample-weighted validation MSE
    in scaled units. Returns ``inf`` when ``val_graphs`` is empty. The input
    graphs are not modified.
    """
    # Suggestion order is kept stable so seeded TPE runs stay reproducible.
    params = {
        "lr": trial.suggest_float("lr", 5e-4, 1e-3, log=True),
        "hidden_dim": trial.suggest_categorical("hidden_dim", [64, 128, 256]),
        "batch_size": trial.suggest_categorical("batch_size", [32, 64]),
    }

    # Use the provided scaler, do not re-fit.
    train_scaled = _scale_graphs(train_graphs, scaler)
    val_scaled = _scale_graphs(val_graphs, scaler)

    model, val_loader = _fit_model(args, params, train_scaled, val_scaled, device)
    return _mean_loss(model, val_loader, device, nn.MSELoss())


def find_best_hyperparameters(args, train_val_graphs, device, scaler):
    """
    Runs an Optuna study.

    ``train_val_graphs`` is split 80/20 (``random_state=42``) into the tuning
    train and validation sets. A seeded TPE study then minimises ``objective``
    for ``args.n_trials`` trials.

    Returns:
        ``study.best_params``: a dict with keys ``"lr"``, ``"hidden_dim"``
        and ``"batch_size"``.
    """
    train_graphs, val_graphs = train_test_split(
        train_val_graphs, test_size=0.2, random_state=42
    )
    study = optuna.create_study(
        direction="minimize", sampler=optuna.samplers.TPESampler(seed=42)
    )
    study.optimize(
        lambda trial: objective(trial, args, train_graphs, val_graphs, device, scaler),
        n_trials=args.n_trials,
        show_progress_bar=True,
    )
    return study.best_params


def bootstrap_metric(y_true, y_pred, metric_func, n_bootstraps=1000, seed=None):
    """Performs bootstrapping to estimate the confidence interval of a metric.

    Args:
        y_true, y_pred: Equal-length 1-D array-likes.
        metric_func: Callable ``(y_true, y_pred) -> float``.
        n_bootstraps: Number of resamples, drawn with replacement.
        seed: If given, resample with ``np.random.default_rng(seed)`` for
            reproducible intervals. If ``None``, the global ``np.random``
            state is used (the previous behaviour).

    Returns:
        ``(mean, lower, upper)``: the mean of the bootstrap scores and their
        2.5th and 97.5th percentiles.

    Raises:
        ValueError: If the inputs are empty or have different lengths.
    """
    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred)
    n_samples = len(y_true)
    if n_samples == 0:
        raise ValueError("bootstrap_metric requires at least one sample")
    if len(y_pred) != n_samples:
        raise ValueError(
            f"y_true and y_pred differ in length ({n_samples} vs {len(y_pred)})"
        )

    rng = np.random if seed is None else np.random.default_rng(seed)
    scores = np.empty(n_bootstraps)
    for i in range(n_bootstraps):
        indices = rng.choice(n_samples, n_samples, replace=True)
        scores[i] = metric_func(y_true[indices], y_pred[indices])

    lower_bound, upper_bound = np.percentile(scores, [2.5, 97.5])
    return np.mean(scores), lower_bound, upper_bound


def rmse_func(y_true, y_pred):
    """Root-mean-squared error between two equal-length array-likes."""
    return np.sqrt(np.mean((np.asarray(y_true) - np.asarray(y_pred)) ** 2))


def mae_func(y_true, y_pred):
    """Mean absolute error between two equal-length array-likes."""
    return np.mean(np.abs(np.asarray(y_true) - np.asarray(y_pred)))


def _nested_cv_scores(args, train_graphs_full, device, log_file_path, n_splits):
    """Run the outer loop of nested CV. Returns per-fold (RMSEs, MAEs) in original units."""
    outer_kf = KFold(n_splits=n_splits, shuffle=True, random_state=42)
    val_fold_rmses, val_fold_maes = [], []

    for fold, (train_idx, val_idx) in enumerate(
        outer_kf.split(np.arange(len(train_graphs_full)))
    ):
        write_log(log_file_path, f"\n--- OUTER FOLD {fold + 1}/{n_splits} ---")
        train_fold_graphs = [train_graphs_full[i] for i in train_idx]
        val_fold_graphs = [train_graphs_full[i] for i in val_idx]

        # Scaler and hyperparameters see only the outer-training part.
        scaler = _fit_target_scaler(train_fold_graphs)
        best_params_for_fold = find_best_hyperparameters(
            args, train_fold_graphs, device, scaler
        )
        write_log(
            log_file_path,
            f"INFO: Best params for fold {fold + 1}: {best_params_for_fold}",
        )

        # Early stopping uses a slice of the outer-training data, NOT the
        # held-out fold; otherwise the fold score would be optimistically biased.
        model = _fit_with_early_stopping_split(
            args,
            best_params_for_fold,
            _scale_graphs(train_fold_graphs, scaler),
            device,
            model_graphs=train_fold_graphs,
        )

        y_true_val, y_pred_val = _predict_original_scale(
            model, val_fold_graphs, scaler, device, best_params_for_fold["batch_size"]
        )
        fold_rmse = rmse_func(y_true_val, y_pred_val)
        fold_mae = mae_func(y_true_val, y_pred_val)
        val_fold_rmses.append(fold_rmse)
        val_fold_maes.append(fold_mae)
        write_log(
            log_file_path,
            f"INFO: Fold {fold + 1} Val RMSE: {fold_rmse:.4f}, MAE: {fold_mae:.4f}",
        )

    return val_fold_rmses, val_fold_maes


def _train_final_model(args, train_graphs_full, device, log_file_path):
    """Re-tune on all of ``train_graphs_full`` and train the final model.

    Returns:
        ``(final_model, final_scaler, final_best_params)``.
    """
    write_log(
        log_file_path,
        "INFO: Finding best hyperparameters on the FULL train/val set for final model...",
    )
    # A single scaler serves both tuning and final training (previously it was fit twice).
    final_scaler = _fit_target_scaler(train_graphs_full)
    final_best_params = find_best_hyperparameters(
        args, train_graphs_full, device, final_scaler
    )
    write_log(
        log_file_path,
        f"INFO: Optimal hyperparameters for final model: {final_best_params}",
    )

    write_log(log_file_path, "INFO: Training final model...")
    final_train_graphs = _scale_graphs(train_graphs_full, final_scaler)
    final_model = _fit_with_early_stopping_split(
        args, final_best_params, final_train_graphs, device,
        model_graphs=final_train_graphs,
    )
    return final_model, final_scaler, final_best_params


def k_fold_tuned_eval(args, train_graphs_full, test_graphs, n_splits=5):
    """
    Orchestrates a rigorous NESTED cross-validation workflow.

    Step 1: ``n_splits``-fold outer CV over ``train_graphs_full``. Each fold
    fits its own target scaler and tunes hyperparameters (Optuna) on its
    training part only, trains with early stopping on a held-back slice of
    that part, and is scored on the untouched held-out fold.
    Step 2: re-tunes on all of ``train_graphs_full`` and trains the final
    model with the same early-stopping split.
    Step 3: evaluates the final model once on ``test_graphs``, reporting point
    estimates with seeded bootstrap 95% confidence intervals.

    Progress is logged to a timestamped file under ``LOG_ROOT/<args.dataset>/``.
    A one-row summary CSV
    ``<args.model>_<args.rep>_<args.dataset>_final_results.csv`` is written
    to the same directory.

    Args:
        args: Namespace with at least ``model``, ``rep``, ``dataset``,
            ``n_trials`` and whatever ``get_model_instance`` reads.
        train_graphs_full: Graphs used for CV and final training (scalar ``y``).
        test_graphs: Held-out graphs used only in Step 3.
        n_splits: Number of outer CV folds (default 5).
    """
    log_file_path = setup_log_file(args)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    val_fold_rmses, val_fold_maes = _nested_cv_scores(
        args, train_graphs_full, device, log_file_path, n_splits
    )
    mean_val_rmse = np.mean(val_fold_rmses)
    std_val_rmse = np.std(val_fold_rmses)
    mean_val_mae = np.mean(val_fold_maes)
    std_val_mae = np.std(val_fold_maes)

    write_log(log_file_path, "\n------ Nested Cross-Validation Summary ------")
    write_log(
        log_file_path,
        f"Unbiased Validation RMSE: {mean_val_rmse:.4f} ± {std_val_rmse:.4f}",
    )
    write_log(
        log_file_path,
        f"Unbiased Validation MAE: {mean_val_mae:.4f} ± {std_val_mae:.4f}",
    )
    write_log(log_file_path, f"VAL FOLD RMSEs: {val_fold_rmses}")
    write_log(log_file_path, f"VAL FOLD MAEs: {val_fold_maes}")

    write_log(log_file_path, "\n===== STEP 2: Final Model Training & Testing =====")
    final_model, final_scaler, final_best_params = _train_final_model(
        args, train_graphs_full, device, log_file_path
    )

    write_log(log_file_path, "\n===== STEP 3: Final Held-Out Test Evaluation =====")
    y_true_test, y_pred_test = _predict_original_scale(
        final_model, test_graphs, final_scaler, device, final_best_params["batch_size"]
    )

    # Point estimates on the full test set; the bootstrap supplies the intervals.
    test_rmse = rmse_func(y_true_test, y_pred_test)
    test_mae = mae_func(y_true_test, y_pred_test)
    rmse_mean, rmse_low, rmse_high = bootstrap_metric(
        y_true_test, y_pred_test, rmse_func, seed=42
    )
    mae_mean, mae_low, mae_high = bootstrap_metric(
        y_true_test, y_pred_test, mae_func, seed=42
    )
    write_log(
        log_file_path,
        f"Test RMSE: {test_rmse:.4f} (bootstrap mean: {rmse_mean:.4f}, "
        f"95% CI: [{rmse_low:.4f}, {rmse_high:.4f}])",
    )
    write_log(
        log_file_path,
        f"Test MAE: {test_mae:.4f} (bootstrap mean: {mae_mean:.4f}, "
        f"95% CI: [{mae_low:.4f}, {mae_high:.4f}])",
    )

    results_data = {
        "val_rmse_mean": [mean_val_rmse],
        "val_rmse_std": [std_val_rmse],
        "val_mae_mean": [mean_val_mae],
        "val_mae_std": [std_val_mae],
        "test_rmse_mean": [rmse_mean],
        "test_rmse_ci_low": [rmse_low],
        "test_rmse_ci_high": [rmse_high],
        "test_mae_mean": [mae_mean],
        "test_mae_ci_low": [mae_low],
        "test_mae_ci_high": [mae_high],
        # Appended last so existing column positions are unchanged.
        "test_rmse": [test_rmse],
        "test_mae": [test_mae],
    }
    data_res_path = (
        _dataset_log_dir(args.dataset)
        / f"{args.model}_{args.rep}_{args.dataset}_final_results.csv"
    )
    pd.DataFrame(results_data).to_csv(data_res_path, index=False)