|
|
import torch |
|
|
import torch.nn as nn |
|
|
from sklearn.model_selection import KFold, train_test_split |
|
|
from sklearn.preprocessing import StandardScaler |
|
|
import numpy as np |
|
|
import os |
|
|
import tempfile |
|
|
from pathlib import Path |
|
|
from datetime import datetime |
|
|
import optuna |
|
|
import pandas as pd |
|
|
from torch_geometric.loader import DataLoader |
|
|
import copy |
|
|
|
|
|
from data.data_handling import get_model_instance |
|
|
|
|
|
# Repository root: two directory levels above this file. Kept as a plain
# string so existing users of ROOT that expect ``str`` keep working.
ROOT = str(Path(__file__).parent.parent.resolve())

# Directory that collects all hyperparameter-search logs and result files.
LOG_ROOT = Path(ROOT) / "logs_hyperparameter"

# Idempotent creation: a separate "exists?" check followed by
# ``exist_ok=False`` raises FileExistsError when several processes import this
# module at the same time.
LOG_ROOT.mkdir(parents=True, exist_ok=True)
|
|
|
|
|
|
|
|
def setup_log_file(args):
    """Create the per-dataset log directory and return a new log-file path.

    The file name is ``<model>_<rep>_<dataset>_<YYYYmmdd_HHMMSS>.txt`` using
    the current local time. Only the directory is created here; the file
    itself is not opened or written.

    Args:
        args: Object exposing ``model``, ``rep`` and ``dataset`` attributes.

    Returns:
        pathlib.Path: ``<repo root>/logs_hyperparameter/<dataset>/<file name>``,
        where the repo root is two directory levels above this file.
    """
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    fname = f"{args.model}_{args.rep}_{args.dataset}_{timestamp}.txt"

    repo_root = Path(__file__).parent.parent.resolve()
    log_dir = repo_root / "logs_hyperparameter" / f"{args.dataset}"
    # exist_ok=True keeps this safe when several runs for the same dataset
    # start concurrently (check-then-create would raise FileExistsError).
    log_dir.mkdir(parents=True, exist_ok=True)

    log_path = log_dir / fname
    print(f"[Logging] Writing to: {log_path}")
    return log_path
|
|
|
|
|
|
|
|
def write_log(log_file_path, text):
    """Print ``text`` to the console and append it as one line to the log file.

    Args:
        log_file_path: Path of the log file; it is opened in append mode.
        text: Message to emit. A trailing newline is added in the file.
    """
    print(text)
    # Explicit UTF-8: messages written by this module contain non-ASCII
    # characters (e.g. "±"), and the platform default encoding is not
    # guaranteed to be able to represent them.
    with open(log_file_path, "a", encoding="utf-8") as f:
        f.write(text + "\n")
|
|
|
|
|
|
|
|
def train_gnn_model(
    model,
    train_loader,
    val_loader,
    optimizer,
    device,
    loss_fn,
    max_epochs=200,
    patience=20,
):
    """Train a GNN with early stopping and restore the best validation weights.

    Each epoch runs one pass over ``train_loader`` followed by one pass over
    ``val_loader``. The validation score is the mean of the per-batch losses.
    Whenever it improves, the current weights are snapshotted in memory;
    training stops after ``patience`` consecutive epochs without improvement
    or after ``max_epochs`` epochs.

    Args:
        model: Module called as ``model(batch)``; its output is flattened.
        train_loader: Iterable of batches exposing ``to(device)``,
            ``num_edges`` and ``y``.
        val_loader: Sized iterable of validation batches. When it is empty the
            validation score is ``inf`` and no snapshot is ever taken.
        optimizer: Optimizer already bound to ``model``'s parameters.
        device: Device the batches are moved to.
        loss_fn: Callable ``loss_fn(prediction, target)`` returning a scalar
            tensor.
        max_epochs: Upper bound on the number of epochs.
        patience: Number of consecutive non-improving epochs tolerated.

    Returns:
        The same ``model`` instance, loaded with the best snapshot if one was
        taken, otherwise with the weights from the last epoch run.
    """
    best_val_loss = float("inf")
    # The best weights are kept in memory rather than in a PID-named file in
    # the shared temp directory: a stale file from an earlier crashed run
    # could otherwise be loaded, and the file leaked on any exception.
    best_state = None
    epochs_no_improve = 0

    for _ in range(max_epochs):
        model.train()
        for batch in train_loader:
            batch = batch.to(device)
            # Batches that contain no edges are skipped entirely.
            if batch.num_edges == 0:
                continue
            optimizer.zero_grad()
            out = model(batch).view(-1)
            loss = loss_fn(out, batch.y.view(-1))
            loss.backward()
            optimizer.step()

        model.eval()
        val_loss = 0.0
        with torch.no_grad():
            for batch in val_loader:
                batch = batch.to(device)
                val_loss += loss_fn(model(batch).view(-1), batch.y.view(-1)).item()

        num_val_batches = len(val_loader)
        if num_val_batches > 0:
            avg_val_loss = val_loss / num_val_batches
        else:
            avg_val_loss = float("inf")

        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            # Deep copy so later optimizer steps cannot alter the snapshot.
            best_state = copy.deepcopy(model.state_dict())
            epochs_no_improve = 0
        else:
            epochs_no_improve += 1
            if epochs_no_improve >= patience:
                break

    if best_state is not None:
        model.load_state_dict(best_state)

    return model
|
|
|
|
|
|
|
|
def objective(trial, args, train_graphs, val_graphs, device, scaler):
    """Score one Optuna trial and return its validation loss.

    Hyperparameters (``lr``, ``hidden_dim``, ``batch_size``) are sampled from
    ``trial``, a model is trained on ``train_graphs`` with early stopping on
    ``val_graphs``, and the mean per-batch MSE on ``val_graphs`` is returned.
    Targets are transformed with the already-fitted ``scaler``, so the loss is
    expressed in scaled units. Both graph collections are deep-copied first,
    which leaves the caller's objects unmodified.

    Returns ``inf`` when the validation loader yields no batches.
    """
    scaled_train = copy.deepcopy(train_graphs)
    scaled_val = copy.deepcopy(val_graphs)

    # Sampling order is kept as lr -> hidden_dim -> batch_size.
    learning_rate = trial.suggest_float("lr", 5e-4, 1e-3, log=True)
    hidden_dim = trial.suggest_categorical("hidden_dim", [64, 128, 256])
    batch_size = trial.suggest_categorical("batch_size", [32, 64])
    params = {
        "lr": learning_rate,
        "hidden_dim": hidden_dim,
        "batch_size": batch_size,
    }

    # Replace every target by its scaled counterpart, train set first.
    for graph_set in (scaled_train, scaled_val):
        for graph in graph_set:
            scaled_target = scaler.transform(graph.y.reshape(1, -1))
            graph.y = torch.tensor(scaled_target, dtype=torch.float)

    train_loader = DataLoader(scaled_train, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(scaled_val, batch_size=batch_size)

    model = get_model_instance(args, params, scaled_train).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    train_gnn_model(
        model, train_loader, val_loader, optimizer, device, loss_fn=nn.MSELoss()
    )

    # Re-score the returned model on the validation graphs.
    criterion = nn.MSELoss()
    model.eval()
    total_loss = 0
    with torch.no_grad():
        for batch in val_loader:
            batch = batch.to(device)
            prediction = model(batch).view(-1)
            target = batch.y.view(-1)
            total_loss += criterion(prediction, target).item()

    n_batches = len(val_loader)
    if n_batches > 0:
        return total_loss / n_batches
    return float("inf")
|
|
|
|
|
|
|
|
def find_best_hyperparameters(args, train_val_graphs, device, scaler):
    """Run an Optuna study and return the best hyperparameters it found.

    ``train_val_graphs`` is split 80/20 (fixed ``random_state=42``) into an
    inner train and validation set. A TPE-sampled study (seed 42) then
    minimises ``objective`` over ``args.n_trials`` trials, showing a progress
    bar. The result is the study's ``best_params`` dictionary.
    """
    inner_train, inner_val = train_test_split(
        train_val_graphs, test_size=0.2, random_state=42
    )

    def _evaluate(trial):
        # Bind the fixed inputs so Optuna only has to supply the trial.
        return objective(trial, args, inner_train, inner_val, device, scaler)

    sampler = optuna.samplers.TPESampler(seed=42)
    study = optuna.create_study(direction="minimize", sampler=sampler)
    study.optimize(_evaluate, n_trials=args.n_trials, show_progress_bar=True)
    return study.best_params
|
|
|
|
|
|
|
|
def bootstrap_metric(y_true, y_pred, metric_func, n_bootstraps=1000, seed=None):
    """Estimate a metric and its 95% confidence interval by bootstrapping.

    ``n_bootstraps`` resamples of the (target, prediction) pairs are drawn
    with replacement, each the same size as the input, and ``metric_func`` is
    evaluated on every resample.

    Args:
        y_true: Sequence of targets; converted with ``np.asarray``.
        y_pred: Sequence of predictions, same length as ``y_true``.
        metric_func: Callable ``metric_func(y_true, y_pred)`` returning a
            scalar score.
        n_bootstraps: Number of resamples to draw.
        seed: Optional seed (or ``numpy.random.Generator``) for reproducible
            resampling. When ``None`` the global ``np.random`` state is used,
            so code relying on ``np.random.seed`` behaves as before.

    Returns:
        tuple: ``(mean_score, lower_bound, upper_bound)`` where the bounds are
        the 2.5th and 97.5th percentiles of the bootstrapped scores.

    Raises:
        ValueError: If ``y_true`` and ``y_pred`` differ in length.
    """
    # Fancy indexing below needs arrays; plain lists would raise TypeError.
    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred)

    n_samples = len(y_true)
    if len(y_pred) != n_samples:
        raise ValueError(
            f"y_true and y_pred must have the same length, "
            f"got {n_samples} and {len(y_pred)}"
        )

    # The np.random module and a Generator both expose a compatible ``choice``.
    sampler = np.random if seed is None else np.random.default_rng(seed)

    bootstrapped_scores = []
    for _ in range(n_bootstraps):
        indices = sampler.choice(n_samples, n_samples, replace=True)
        bootstrapped_scores.append(metric_func(y_true[indices], y_pred[indices]))

    lower_bound = np.percentile(bootstrapped_scores, 2.5)
    upper_bound = np.percentile(bootstrapped_scores, 97.5)
    mean_score = np.mean(bootstrapped_scores)

    return mean_score, lower_bound, upper_bound
|
|
|
|
|
|
|
|
def rmse_func(y_true, y_pred):
    """Return the root-mean-square error between targets and predictions."""
    squared_errors = (y_true - y_pred) ** 2
    mean_squared_error = np.mean(squared_errors)
    return np.sqrt(mean_squared_error)
|
|
|
|
|
|
|
|
def mae_func(y_true, y_pred):
    """Return the mean absolute error between targets and predictions."""
    absolute_errors = np.abs(y_true - y_pred)
    return np.mean(absolute_errors)
|
|
|
|
|
|
|
|
def _fit_target_scaler(graphs):
    """Fit a StandardScaler on the scalar targets (``g.y``) of ``graphs``."""
    y_raw = np.array([g.y.item() for g in graphs]).reshape(-1, 1)
    return StandardScaler().fit(y_raw)


def _scale_targets(graphs, scaler):
    """Return clones of ``graphs`` whose ``y`` was transformed by ``scaler``."""
    scaled = [g.clone() for g in graphs]
    for g in scaled:
        g.y = torch.tensor(scaler.transform(g.y.reshape(1, -1)), dtype=torch.float)
    return scaled


def _train_with_holdout(
    args, params, scaled_graphs, model_graphs, device, holdout_fraction=0.1
):
    """Train a fresh model, early-stopping on a holdout split of its own data.

    ``scaled_graphs`` is split (fixed ``random_state=42``) into a training
    part and a ``holdout_fraction`` part used only for early stopping.
    ``model_graphs`` is forwarded unchanged to ``get_model_instance``.
    """
    train_subset, stop_subset = train_test_split(
        scaled_graphs, test_size=holdout_fraction, random_state=42
    )
    train_loader = DataLoader(
        train_subset, batch_size=params["batch_size"], shuffle=True
    )
    stop_loader = DataLoader(stop_subset, batch_size=params["batch_size"])

    model = get_model_instance(args, params, model_graphs).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=params["lr"])
    return train_gnn_model(
        model, train_loader, stop_loader, optimizer, device, loss_fn=nn.MSELoss()
    )


def _predict_unscaled(model, loader, scaler, device):
    """Return ``(y_true, y_pred)`` arrays mapped back through ``scaler``."""
    y_true, y_pred = [], []
    model.eval()
    with torch.no_grad():
        for batch in loader:
            batch = batch.to(device)
            out = model(batch).view(-1)
            y_true.extend(
                scaler.inverse_transform(batch.y.cpu().numpy().reshape(-1, 1)).ravel()
            )
            y_pred.extend(
                scaler.inverse_transform(out.cpu().numpy().reshape(-1, 1)).ravel()
            )
    return np.array(y_true), np.array(y_pred)


def k_fold_tuned_eval(args, train_graphs_full, test_graphs):
    """Run nested cross-validation, then train and test a final model.

    Step 1 (outer 5-fold CV, shuffled, ``random_state=42``): for every fold a
    target scaler is fitted on the fold's training graphs, hyperparameters are
    tuned on those graphs only, a model is trained with them, and RMSE/MAE are
    computed on the fold's validation graphs in original target units.
    Early stopping uses a 10% holdout taken from the fold's *training* graphs,
    so the validation fold is used for scoring only and never influences which
    checkpoint is kept.

    Step 2: hyperparameters are tuned on all of ``train_graphs_full`` and a
    final model is trained on it (again with a 10% early-stopping holdout).

    Step 3: the final model is evaluated on ``test_graphs``; RMSE and MAE are
    reported with bootstrapped 95% confidence intervals.

    All messages go through ``write_log`` and a one-row summary CSV named
    ``<model>_<rep>_<dataset>_final_results.csv`` is written next to the log
    file.

    Args:
        args: Object exposing ``model``, ``rep``, ``dataset`` and ``n_trials``;
            it is also forwarded to ``get_model_instance``.
        train_graphs_full: Indexable collection of graphs with a scalar ``y``
            and a ``clone()`` method, used for CV and final training.
        test_graphs: Held-out graphs used only in step 3.

    Returns:
        None.
    """
    log_file_path = setup_log_file(args)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    n_outer_folds = 5
    outer_kf = KFold(n_splits=n_outer_folds, shuffle=True, random_state=42)
    val_fold_rmses = []
    val_fold_maes = []

    train_indices = np.arange(len(train_graphs_full))

    for fold, (train_idx, val_idx) in enumerate(outer_kf.split(train_indices)):
        write_log(log_file_path, f"\n--- OUTER FOLD {fold + 1}/{n_outer_folds} ---")

        train_fold_graphs = [train_graphs_full[i] for i in train_idx]
        val_fold_graphs = [train_graphs_full[i] for i in val_idx]

        # The scaler only ever sees this fold's training targets.
        scaler = _fit_target_scaler(train_fold_graphs)

        best_params_for_fold = find_best_hyperparameters(
            args, train_fold_graphs, device, scaler
        )
        write_log(
            log_file_path,
            f"INFO: Best params for fold {fold + 1}: {best_params_for_fold}",
        )

        train_fold_graphs_scaled = _scale_targets(train_fold_graphs, scaler)
        val_fold_graphs_scaled = _scale_targets(val_fold_graphs, scaler)

        # Early stopping must not look at the validation fold: selecting the
        # best epoch on the data that is scored afterwards would bias the
        # reported validation metrics optimistically.
        model = _train_with_holdout(
            args,
            best_params_for_fold,
            train_fold_graphs_scaled,
            train_fold_graphs,
            device,
        )

        val_loader = DataLoader(
            val_fold_graphs_scaled, batch_size=best_params_for_fold["batch_size"]
        )
        y_true_val, y_pred_val = _predict_unscaled(model, val_loader, scaler, device)

        fold_rmse = rmse_func(y_true_val, y_pred_val)
        fold_mae = mae_func(y_true_val, y_pred_val)
        val_fold_rmses.append(fold_rmse)
        val_fold_maes.append(fold_mae)
        write_log(
            log_file_path,
            f"INFO: Fold {fold + 1} Val RMSE: {fold_rmse:.4f}, MAE: {fold_mae:.4f}",
        )

    mean_val_rmse = np.mean(val_fold_rmses)
    std_val_rmse = np.std(val_fold_rmses)
    mean_val_mae = np.mean(val_fold_maes)
    std_val_mae = np.std(val_fold_maes)
    write_log(log_file_path, "\n------ Nested Cross-Validation Summary ------")
    write_log(
        log_file_path,
        f"Unbiased Validation RMSE: {mean_val_rmse:.4f} ± {std_val_rmse:.4f}",
    )
    write_log(
        log_file_path,
        f"Unbiased Validation MAE: {mean_val_mae:.4f} ± {std_val_mae:.4f}",
    )
    write_log(log_file_path, f"VAL FOLD RMSEs: {val_fold_rmses}")
    write_log(log_file_path, f"VAL FOLD MAEs: {val_fold_maes}")

    write_log(log_file_path, "\n===== STEP 2: Final Model Training & Testing =====")
    write_log(
        log_file_path,
        "INFO: Finding best hyperparameters on the FULL train/val set for final model...",
    )

    # One scaler serves both the final search and the final training: both
    # are fitted on exactly the same targets.
    final_scaler = _fit_target_scaler(train_graphs_full)
    final_best_params = find_best_hyperparameters(
        args, train_graphs_full, device, final_scaler
    )

    write_log(
        log_file_path,
        f"INFO: Optimal hyperparameters for final model: {final_best_params}",
    )

    write_log(log_file_path, "INFO: Training final model...")
    final_train_graphs = _scale_targets(train_graphs_full, final_scaler)
    final_model = _train_with_holdout(
        args, final_best_params, final_train_graphs, final_train_graphs, device
    )

    write_log(log_file_path, "\n===== STEP 3: Final Held-Out Test Evaluation =====")
    final_test_graphs = _scale_targets(test_graphs, final_scaler)
    test_loader = DataLoader(
        final_test_graphs, batch_size=final_best_params["batch_size"]
    )
    y_true_test, y_pred_test = _predict_unscaled(
        final_model, test_loader, final_scaler, device
    )

    rmse_mean, rmse_low, rmse_high = bootstrap_metric(
        y_true_test, y_pred_test, rmse_func
    )
    mae_mean, mae_low, mae_high = bootstrap_metric(y_true_test, y_pred_test, mae_func)

    write_log(
        log_file_path,
        f"Test RMSE: {rmse_mean:.4f} (95% CI: [{rmse_low:.4f}, {rmse_high:.4f}])",
    )
    write_log(
        log_file_path,
        f"Test MAE: {mae_mean:.4f} (95% CI: [{mae_low:.4f}, {mae_high:.4f}])",
    )

    results_data = {
        "val_rmse_mean": [mean_val_rmse],
        "val_rmse_std": [std_val_rmse],
        "val_mae_mean": [mean_val_mae],
        "val_mae_std": [std_val_mae],
        "test_rmse_mean": [rmse_mean],
        "test_rmse_ci_low": [rmse_low],
        "test_rmse_ci_high": [rmse_high],
        "test_mae_mean": [mae_mean],
        "test_mae_ci_low": [mae_low],
        "test_mae_ci_high": [mae_high],
    }
    # The summary lives in the same per-dataset directory as the log file.
    data_res_path = (
        Path(log_file_path).parent
        / f"{args.model}_{args.rep}_{args.dataset}_final_results.csv"
    )
    pd.DataFrame(results_data).to_csv(data_res_path, index=False)
|
|
|