PACT-Net / data /data_handling.py
rk-random's picture
Upload folder using huggingface_hub
9a67fbe verified
raw
history blame
5.22 kB
import os
import torch
from torch_geometric.loader import DataLoader
from pathlib import Path
import numpy as np
import multiprocessing
from joblib import Parallel, delayed
# --- Assume these are imported from your project ---
from data.loaders import load_dataset
from data.featurize import (
smiles_to_graph,
selfies_to_graph,
ecfp_to_graph,
smiles_for_gp,
selfies_for_gp,
ecfp_for_gp,
)
from data.polyatomic_featurize import compressed_topsignal_graph_from_smiles
from models.gnn import GCN, GIN, GAT, GraphSAGE
from models.polyatomic import PolyatomicNet
# --- Lookup tables keyed by user-selected option names ---

# Representation name -> featurizer that builds the graph fed to the GNNs.
REPRESENTATIONS = {
    "smiles": smiles_to_graph,
    "selfies": selfies_to_graph,
    "ecfp": ecfp_to_graph,
    "polyatomic": compressed_topsignal_graph_from_smiles,
}

# Representation name -> featurizer for the GP (numeric vectors).
GP_FEATURIZERS = {
    "smiles": smiles_for_gp,
    "selfies": selfies_for_gp,
    "ecfp": ecfp_for_gp,
}

# Model name -> GNN class to instantiate.
GNN_MODELS = {
    "gcn": GCN,
    "gin": GIN,
    "gat": GAT,
    "sage": GraphSAGE,
    # Custom GNN for polyatomic complexes.
    "polyatomic": PolyatomicNet,
}
# Request the "fork" start method for multiprocessing where it is available.
# set_start_method raises RuntimeError when a start method has already been
# fixed for this interpreter, and ValueError when "fork" does not exist on the
# platform (e.g. Windows). Either way, keep the method already in effect
# instead of failing at import time.
try:
    multiprocessing.set_start_method("fork")
except (RuntimeError, ValueError):
    pass
def featurize_dataset_parallel(X, y, featurizer, n_jobs=None):
    """Featurize every sample in parallel and attach its target.

    Args:
        X: Iterable of raw inputs; each element is passed to ``featurizer``.
        y: Targets aligned positionally with ``X``.
        featurizer: Callable returning a featurized object, or ``None`` when
            a sample cannot be featurized. A callable whose ``__name__`` is
            ``"compressed_topsignal_graph_from_smiles"`` is invoked as
            ``featurizer(x, y)``; any other callable as ``featurizer(x)``.
        n_jobs: Number of joblib workers. Defaults to the CPU count minus
            two, but never fewer than one.

    Returns:
        List of the non-``None`` featurizer results, in input order, each
        with ``.y`` set to ``torch.tensor([target], dtype=torch.float)``.
    """
    if n_jobs is None:
        n_jobs = max(1, multiprocessing.cpu_count() - 2)

    # Materialise the targets once so every lookup below is positional.
    # Indexing the original container with ``y[i]`` would be label-based for
    # a pandas Series, which breaks when the index is not 0..n-1.
    targets = list(y)

    # getattr keeps callables without ``__name__`` (e.g. functools.partial)
    # on the single-argument path instead of raising AttributeError.
    wants_target = (
        getattr(featurizer, "__name__", None)
        == "compressed_topsignal_graph_from_smiles"
    )
    if wants_target:
        tasks = (delayed(featurizer)(xi, yi) for xi, yi in zip(X, targets))
    else:
        tasks = (delayed(featurizer)(xi) for xi in X)
    results = Parallel(n_jobs=n_jobs, backend="loky", verbose=10)(tasks)

    # Drop failed featurizations (None) and attach the matching target.
    # Indexing by position (rather than zipping) still raises if there are
    # more results than targets, so a length mismatch is not hidden.
    data_list = []
    for i, res in enumerate(results):
        if res is not None:
            res.y = torch.tensor([targets[i]], dtype=torch.float)
            data_list.append(res)
    return data_list
def prepare_and_load_data(args):
    """Return featurized train/test data, caching the result on disk.

    The cache directory is ``datasets`` under the parent of this file's
    directory. When both cache files for ``(args.rep, args.dataset)`` exist
    they are loaded and returned; otherwise the dataset is loaded,
    featurized with ``REPRESENTATIONS[args.rep]``, saved, and returned.

    Args:
        args: Object exposing ``rep`` (a key of ``REPRESENTATIONS``) and
            ``dataset`` (the name handed to ``load_dataset``).

    Returns:
        Tuple ``(train_graphs, test_graphs)``.

    Raises:
        KeyError: On a cache miss, if ``args.rep`` is not a key of
            ``REPRESENTATIONS``.
    """
    cache_dir = Path(__file__).parent.parent.resolve() / "datasets"
    # exist_ok=True removes the check-then-create race: two runs starting at
    # the same time no longer fail with FileExistsError.
    cache_dir.mkdir(parents=True, exist_ok=True)

    train_cache_file = cache_dir / f"{args.rep}_data_{args.dataset}.pt"
    test_cache_file = cache_dir / f"{args.rep}_test_data_{args.dataset}.pt"
    print(f"train cache file is: {train_cache_file}")
    print(f"test cache file is: {test_cache_file}")

    if train_cache_file.exists() and test_cache_file.exists():
        print(
            f"INFO: Loading pre-featurized data from cache for dataset '{args.dataset}'..."
        )
        # NOTE(security): weights_only=False unpickles arbitrary objects, so
        # the cache files must come from a trusted source (i.e. this function).
        train_graphs = torch.load(train_cache_file, weights_only=False)
        test_graphs = torch.load(test_cache_file, weights_only=False)
        return train_graphs, test_graphs

    print("INFO: No cached data found. Starting one-time featurization process...")
    X_train, X_test, y_train, y_test = load_dataset(args.dataset)
    featurizer = REPRESENTATIONS[args.rep]

    print("Featurizing training set (this may take a while)...")
    train_graphs = featurize_dataset_parallel(X_train, y_train, featurizer)
    torch.save(train_graphs, train_cache_file)
    print(f"Saved featurized training data to {train_cache_file}")

    print("Featurizing test set...")
    test_graphs = featurize_dataset_parallel(X_test, y_test, featurizer)
    torch.save(test_graphs, test_cache_file)
    print(f"Saved featurized test data to {test_cache_file}")

    return train_graphs, test_graphs
def get_model_instance(args, params, train_graphs):
    """Build the model class registered under ``args.model``.

    Constructor arguments are read from the first element of
    ``train_graphs`` and from ``params["hidden_dim"]``. The "polyatomic"
    model additionally receives a histogram of node degrees (counted over
    ``edge_index[1]``) accumulated across all of ``train_graphs``.
    """
    model_class = GNN_MODELS[args.model]
    sample_graph = train_graphs[0]

    # Every non-polyatomic model needs only the feature count and width.
    if args.model != "polyatomic":
        return model_class(
            in_channels=sample_graph.num_node_features,
            hidden_channels=params["hidden_dim"],
            out_channels=1,
        )

    from torch_geometric.utils import degree

    print("INFO: Calculating degree vector for polyatomic model...")
    batch_size = params.get("batch_size", 128)
    # Start with 32 bins; the histogram is widened below whenever a batch
    # contains a node whose degree exceeds the current last bin.
    deg = torch.zeros(32, dtype=torch.long)
    for batch in DataLoader(train_graphs, batch_size=batch_size):
        node_degrees = degree(
            batch.edge_index[1], num_nodes=batch.num_nodes, dtype=torch.long
        )
        counts = torch.bincount(node_degrees, minlength=deg.size(0))
        missing_bins = counts.size(0) - deg.size(0)
        if missing_bins > 0:
            padding = torch.zeros(missing_bins, dtype=torch.long)
            deg = torch.cat([deg, padding])
        deg = deg + counts

    return model_class(
        node_feat_dim=sample_graph.x.shape[1],
        edge_feat_dim=sample_graph.edge_attr.shape[1],
        graph_feat_dim=sample_graph.graph_feats.shape[0],
        hidden_dim=params["hidden_dim"],
        deg=deg,
    )