# PACT-Net: training/training_pipeline.py
import torch
import torch.nn as nn
from sklearn.model_selection import KFold, train_test_split
from sklearn.preprocessing import StandardScaler
import numpy as np
import os
import tempfile
from pathlib import Path
from datetime import datetime
import optuna
import pandas as pd
from torch_geometric.loader import DataLoader
import copy
from data.data_handling import get_model_instance
# Project root is the parent of this file's directory; the hyperparameter
# logs live directly beneath it.
ROOT = str(Path(__file__).parent.parent.resolve())
LOG_ROOT = Path(ROOT) / "logs_hyperparameter"
# mkdir(exist_ok=True) replaces the old exists()-then-makedirs(exist_ok=False)
# pair, which raced: a directory created between the check and the call
# would raise FileExistsError.
LOG_ROOT.mkdir(parents=True, exist_ok=True)
def setup_log_file(args):
    """Build a timestamped log-file path under logs_hyperparameter/<dataset>.

    Args:
        args: Namespace exposing at least ``model``, ``rep`` and ``dataset``
            (used to name the file) — the directory is keyed by ``dataset``.

    Returns:
        pathlib.Path: full path of the log file (the file itself is not
        created here; ``write_log`` opens it in append mode).
    """
    # Removed the function-local re-imports of Path/datetime that shadowed
    # the module-level imports at the top of the file.
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    fname = f"{args.model}_{args.rep}_{args.dataset}_{timestamp}.txt"
    log_dir = (
        Path(__file__).parent.parent.resolve()
        / "logs_hyperparameter"
        / f"{args.dataset}"
    )
    # exist_ok=True avoids the check-then-create race the old
    # os.path.exists()/os.makedirs(exist_ok=False) pair had.
    log_dir.mkdir(parents=True, exist_ok=True)
    log_path = log_dir / fname
    print(f"[Logging] Writing to: {log_path}")
    return log_path
def write_log(log_file_path, text):
    """Echo *text* to stdout and append it, newline-terminated, to the log file."""
    print(text)
    with open(log_file_path, "a") as log_fh:
        log_fh.write(f"{text}\n")
def train_gnn_model(
    model,
    train_loader,
    val_loader,
    optimizer,
    device,
    loss_fn,
    max_epochs=200,
    patience=20,
):
    """Trains a GNN with early stopping based on a validation set.

    The best (lowest validation loss) weights are checkpointed to a temp
    file and restored before returning, so the returned model reflects the
    best epoch, not the last one.

    Args:
        model: torch module; called as ``model(batch)`` on PyG batches.
        train_loader / val_loader: torch_geometric DataLoaders.
        optimizer: configured torch optimizer over ``model.parameters()``.
        device: target torch device for model inputs.
        loss_fn: elementwise-reducing loss, e.g. ``nn.MSELoss()``.
        max_epochs: hard cap on training epochs.
        patience: epochs without validation improvement before stopping.

    Returns:
        The same ``model`` instance, loaded with the best checkpoint
        (if one was ever saved).
    """
    best_val_loss = float("inf")
    epochs_no_improve = 0
    temp_dir = tempfile.gettempdir()
    # PID in the filename keeps concurrent runs from clobbering each
    # other's checkpoints.
    best_model_path = os.path.join(temp_dir, f"best_model_{os.getpid()}.pt")
    for _ in range(max_epochs):
        model.train()
        for batch in train_loader:
            batch = batch.to(device)
            # Skip edgeless graphs — presumably the message-passing model
            # cannot produce useful output on them. TODO confirm.
            if batch.num_edges == 0:
                continue
            optimizer.zero_grad()
            out = model(batch).view(-1)
            loss = loss_fn(out, batch.y.view(-1))
            loss.backward()
            optimizer.step()
        # Validation pass: average of per-batch losses (not per-sample,
        # so unequal final batches weigh slightly differently).
        model.eval()
        val_loss = 0
        with torch.no_grad():
            for batch in val_loader:
                batch = batch.to(device)
                val_loss += loss_fn(model(batch).view(-1), batch.y.view(-1)).item()
        if len(val_loader) > 0:
            avg_val_loss = val_loss / len(val_loader)
        else:
            # Empty validation set: loss stays inf, so no checkpoint is
            # ever saved and early stopping triggers after `patience` epochs.
            avg_val_loss = float("inf")
        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            torch.save(model.state_dict(), best_model_path)
            epochs_no_improve = 0
        else:
            epochs_no_improve += 1
        if epochs_no_improve >= patience:
            break
    # Restore best weights and clean up the temp checkpoint.
    if os.path.exists(best_model_path):
        model.load_state_dict(torch.load(best_model_path))
        os.remove(best_model_path)
    return model
def objective(trial, args, train_graphs, val_graphs, device, scaler):
    """Optuna objective function. Uses a pre-fitted scaler for consistency.

    Trains a model with trial-suggested hyperparameters and returns the
    mean validation MSE (in scaled-target space) to be minimized.

    Args:
        trial: optuna.Trial used to sample hyperparameters.
        args: run configuration forwarded to ``get_model_instance``.
        train_graphs / val_graphs: PyG graphs with scalar ``.y`` targets.
        device: torch device for training.
        scaler: already-fitted sklearn scaler applied to targets here —
            fitting it outside the objective keeps all trials comparable.
    """
    # Deep-copy so per-trial target scaling does not mutate the caller's
    # graph lists across trials.
    train_graphs = copy.deepcopy(train_graphs)
    val_graphs = copy.deepcopy(val_graphs)
    params = {
        "lr": trial.suggest_float("lr", 5e-4, 1e-3, log=True),
        "hidden_dim": trial.suggest_categorical("hidden_dim", [64, 128, 256]),
        "batch_size": trial.suggest_categorical("batch_size", [32, 64]),
    }
    # Use the provided scaler, do not re-fit
    for g in train_graphs:
        g.y = torch.tensor(scaler.transform(g.y.reshape(1, -1)), dtype=torch.float)
    for g in val_graphs:
        g.y = torch.tensor(scaler.transform(g.y.reshape(1, -1)), dtype=torch.float)
    train_loader = DataLoader(
        train_graphs, batch_size=params["batch_size"], shuffle=True
    )
    val_loader = DataLoader(val_graphs, batch_size=params["batch_size"])
    model = get_model_instance(args, params, train_graphs).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=params["lr"])
    train_gnn_model(
        model, train_loader, val_loader, optimizer, device, loss_fn=nn.MSELoss()
    )
    # Score with per-batch-averaged MSE; empty val set scores as inf so
    # Optuna never prefers such a trial.
    model.eval()
    val_loss = 0
    with torch.no_grad():
        for batch in val_loader:
            batch = batch.to(device)
            val_loss += nn.MSELoss()(model(batch).view(-1), batch.y.view(-1)).item()
    return val_loss / len(val_loader) if len(val_loader) > 0 else float("inf")
def find_best_hyperparameters(args, train_val_graphs, device, scaler):
    """Run a seeded Optuna TPE study over ``objective`` and return the
    best hyperparameter dict found after ``args.n_trials`` trials."""
    # Hold out 20% of the graphs as the inner validation split.
    inner_train, inner_val = train_test_split(
        train_val_graphs, test_size=0.2, random_state=42
    )
    sampler = optuna.samplers.TPESampler(seed=42)
    study = optuna.create_study(direction="minimize", sampler=sampler)
    study.optimize(
        lambda t: objective(t, args, inner_train, inner_val, device, scaler),
        n_trials=args.n_trials,
        show_progress_bar=True,
    )
    return study.best_params
def bootstrap_metric(y_true, y_pred, metric_func, n_bootstraps=1000):
    """Estimate a metric's 95% confidence interval by bootstrapping.

    Resamples (y_true, y_pred) pairs with replacement ``n_bootstraps``
    times and evaluates ``metric_func`` on each resample.

    Returns:
        (mean, lower, upper): mean bootstrapped score plus the 2.5th and
        97.5th percentiles of the score distribution.
    """
    n_samples = len(y_true)
    scores = [
        metric_func(y_true[resample], y_pred[resample])
        for resample in (
            np.random.choice(n_samples, n_samples, replace=True)
            for _ in range(n_bootstraps)
        )
    ]
    return np.mean(scores), np.percentile(scores, 2.5), np.percentile(scores, 97.5)
def rmse_func(y_true, y_pred):
    """Root-mean-squared error of *y_pred* against *y_true*."""
    residuals = y_true - y_pred
    return np.sqrt((residuals ** 2).mean())
def mae_func(y_true, y_pred):
    """Mean absolute error of *y_pred* against *y_true*."""
    return np.abs(y_true - y_pred).mean()
def k_fold_tuned_eval(args, train_graphs_full, test_graphs):
    """
    Orchestrates a rigorous NESTED cross-validation workflow.

    Step 1: 5-fold outer CV over ``train_graphs_full``; each fold runs its
    own inner hyperparameter search (Optuna) with a fold-fitted target
    scaler, then reports unbiased validation RMSE/MAE.
    Step 2: a final hyperparameter search and model fit on the full
    train/val set.
    Step 3: evaluation on the held-out ``test_graphs`` with bootstrapped
    95% confidence intervals; results are logged and written to CSV.

    Args:
        args: run configuration (model/rep/dataset names, n_trials, ...).
        train_graphs_full: PyG graphs with scalar ``.y`` used for CV and
            final training.
        test_graphs: held-out PyG graphs, used only in Step 3.
    """
    log_file_path = setup_log_file(args)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    outer_kf = KFold(n_splits=5, shuffle=True, random_state=42)
    val_fold_rmses = []
    val_fold_maes = []
    train_indices = np.arange(len(train_graphs_full))
    for fold, (train_idx, val_idx) in enumerate(outer_kf.split(train_indices)):
        write_log(log_file_path, f"\n--- OUTER FOLD {fold + 1}/5 ---")
        train_fold_graphs = [train_graphs_full[i] for i in train_idx]
        val_fold_graphs = [train_graphs_full[i] for i in val_idx]
        # Fit the target scaler on this fold's training targets only, so
        # no information leaks from the fold's validation split.
        y_train_fold_raw = np.array([g.y.item() for g in train_fold_graphs]).reshape(
            -1, 1
        )
        scaler = StandardScaler().fit(y_train_fold_raw)
        # Inner hyperparameter search confined to this fold's training data.
        best_params_for_fold = find_best_hyperparameters(
            args, train_fold_graphs, device, scaler
        )
        write_log(
            log_file_path,
            f"INFO: Best params for fold {fold + 1}: {best_params_for_fold}",
        )
        # Clone before scaling targets so the originals stay unscaled for
        # later folds and for the final model.
        train_fold_graphs_scaled = [g.clone() for g in train_fold_graphs]
        val_fold_graphs_scaled = [g.clone() for g in val_fold_graphs]
        for g in train_fold_graphs_scaled:
            g.y = torch.tensor(scaler.transform(g.y.reshape(1, -1)), dtype=torch.float)
        for g in val_fold_graphs_scaled:
            g.y = torch.tensor(scaler.transform(g.y.reshape(1, -1)), dtype=torch.float)
        train_loader = DataLoader(
            train_fold_graphs_scaled,
            batch_size=best_params_for_fold["batch_size"],
            shuffle=True,
        )
        val_loader = DataLoader(
            val_fold_graphs_scaled, batch_size=best_params_for_fold["batch_size"]
        )
        model = get_model_instance(args, best_params_for_fold, train_fold_graphs).to(
            device
        )
        optimizer = torch.optim.Adam(model.parameters(), lr=best_params_for_fold["lr"])
        train_gnn_model(
            model, train_loader, val_loader, optimizer, device, loss_fn=nn.MSELoss()
        )
        # Score the fold in the ORIGINAL target units by inverse-transforming
        # both predictions and labels.
        y_true_val, y_pred_val = [], []
        model.eval()
        with torch.no_grad():
            for batch in val_loader:
                batch = batch.to(device)
                out = model(batch).view(-1)
                y_true_val.extend(
                    scaler.inverse_transform(
                        batch.y.cpu().numpy().reshape(-1, 1)
                    ).ravel()
                )
                y_pred_val.extend(
                    scaler.inverse_transform(out.cpu().numpy().reshape(-1, 1)).ravel()
                )
        fold_rmse = rmse_func(np.array(y_true_val), np.array(y_pred_val))
        fold_mae = mae_func(np.array(y_true_val), np.array(y_pred_val))
        val_fold_rmses.append(fold_rmse)
        val_fold_maes.append(fold_mae)
        write_log(
            log_file_path,
            f"INFO: Fold {fold + 1} Val RMSE: {fold_rmse:.4f}, MAE: {fold_mae:.4f}",
        )
    # Aggregate the unbiased cross-validation estimates.
    mean_val_rmse = np.mean(val_fold_rmses)
    std_val_rmse = np.std(val_fold_rmses)
    mean_val_mae = np.mean(val_fold_maes)
    std_val_mae = np.std(val_fold_maes)
    write_log(log_file_path, "\n------ Nested Cross-Validation Summary ------")
    write_log(
        log_file_path,
        f"Unbiased Validation RMSE: {mean_val_rmse:.4f} ± {std_val_rmse:.4f}",
    )
    write_log(
        log_file_path,
        f"Unbiased Validation MAE: {mean_val_mae:.4f} ± {std_val_mae:.4f}",
    )
    write_log(log_file_path, f"VAL FOLD RMSEs: {val_fold_rmses}")
    write_log(log_file_path, f"VAL FOLD MAEs: {val_fold_maes}")
    write_log(log_file_path, "\n===== STEP 2: Final Model Training & Testing =====")
    write_log(
        log_file_path,
        "INFO: Finding best hyperparameters on the FULL train/val set for final model...",
    )
    # Scaler fitted on ALL training targets, used only for the final HPO.
    final_y_train_full_raw = np.array([g.y.item() for g in train_graphs_full]).reshape(
        -1, 1
    )
    final_hpo_scaler = StandardScaler().fit(final_y_train_full_raw)
    final_best_params = find_best_hyperparameters(
        args, train_graphs_full, device, final_hpo_scaler
    )
    write_log(
        log_file_path,
        f"INFO: Optimal hyperparameters for final model: {final_best_params}",
    )
    write_log(log_file_path, "INFO: Training final model...")
    # Separate (identically-fitted) scaler for final training/testing.
    y_train_full_raw = np.array([g.y.item() for g in train_graphs_full]).reshape(-1, 1)
    final_scaler = StandardScaler().fit(y_train_full_raw)
    final_train_graphs = [g.clone() for g in train_graphs_full]
    for g in final_train_graphs:
        g.y = torch.tensor(
            final_scaler.transform(g.y.reshape(1, -1)), dtype=torch.float
        )
    # Small (10%) validation split drives early stopping of the final model.
    train_subset, val_subset = train_test_split(
        final_train_graphs, test_size=0.1, random_state=42
    )
    final_train_loader = DataLoader(
        train_subset, batch_size=final_best_params["batch_size"], shuffle=True
    )
    final_val_loader = DataLoader(
        val_subset, batch_size=final_best_params["batch_size"]
    )
    final_model = get_model_instance(args, final_best_params, final_train_graphs).to(
        device
    )
    final_optimizer = torch.optim.Adam(
        final_model.parameters(), lr=final_best_params["lr"]
    )
    train_gnn_model(
        final_model,
        final_train_loader,
        final_val_loader,
        final_optimizer,
        device,
        loss_fn=nn.MSELoss(),
    )
    write_log(log_file_path, "\n===== STEP 3: Final Held-Out Test Evaluation =====")
    # Test targets are scaled with the TRAIN-fitted scaler (no test leakage).
    final_test_graphs = [g.clone() for g in test_graphs]
    for g in final_test_graphs:
        g.y = torch.tensor(
            final_scaler.transform(g.y.reshape(1, -1)), dtype=torch.float
        )
    test_loader = DataLoader(
        final_test_graphs, batch_size=final_best_params["batch_size"]
    )
    y_true_test, y_pred_test = [], []
    final_model.eval()
    with torch.no_grad():
        for batch in test_loader:
            batch = batch.to(device)
            out = final_model(batch).view(-1)
            y_true_test.extend(
                final_scaler.inverse_transform(
                    batch.y.cpu().numpy().reshape(-1, 1)
                ).ravel()
            )
            y_pred_test.extend(
                final_scaler.inverse_transform(out.cpu().numpy().reshape(-1, 1)).ravel()
            )
    # Bootstrapped 95% CIs for the held-out metrics (original target units).
    y_true_test, y_pred_test = np.array(y_true_test), np.array(y_pred_test)
    rmse_mean, rmse_low, rmse_high = bootstrap_metric(
        y_true_test, y_pred_test, rmse_func
    )
    mae_mean, mae_low, mae_high = bootstrap_metric(y_true_test, y_pred_test, mae_func)
    write_log(
        log_file_path,
        f"Test RMSE: {rmse_mean:.4f} (95% CI: [{rmse_low:.4f}, {rmse_high:.4f}])",
    )
    write_log(
        log_file_path,
        f"Test MAE: {mae_mean:.4f} (95% CI: [{mae_low:.4f}, {mae_high:.4f}])",
    )
    # Persist all summary metrics as a single-row CSV next to the logs.
    results_data = {
        "val_rmse_mean": [mean_val_rmse],
        "val_rmse_std": [std_val_rmse],
        "val_mae_mean": [mean_val_mae],
        "val_mae_std": [std_val_mae],
        "test_rmse_mean": [rmse_mean],
        "test_rmse_ci_low": [rmse_low],
        "test_rmse_ci_high": [rmse_high],
        "test_mae_mean": [mae_mean],
        "test_mae_ci_low": [mae_low],
        "test_mae_ci_high": [mae_high],
    }
    parent = Path(__file__).parent.parent.resolve().__str__()
    log_dir = Path(
        parent + "/" + "logs_hyperparameter" + "/" + f"{args.dataset}"
    ).__str__()
    data_res_path = (
        log_dir + "/" + f"{args.model}_{args.rep}_{args.dataset}_final_results.csv"
    )
    pd.DataFrame(results_data).to_csv(data_res_path, index=False)