import torch
from torch_geometric.loader import DataLoader
from pathlib import Path
import multiprocessing
from joblib import Parallel, delayed
# --- Assume these are imported from your project ---
from data.loaders import load_dataset
from data.featurize import (
    smiles_to_graph,
    selfies_to_graph,
    ecfp_to_graph,
    smiles_for_gp,
    selfies_for_gp,
    ecfp_for_gp,
)
from data.polyatomic_featurize import compressed_topsignal_graph_from_smiles
from models.gnn import GCN, GIN, GAT, GraphSAGE
from models.polyatomic import PolyatomicNet
# --- Configuration Dictionaries ---
# Representation functions for GNN
REPRESENTATIONS = {
    "smiles": smiles_to_graph,
    "selfies": selfies_to_graph,
    "ecfp": ecfp_to_graph,
    "polyatomic": compressed_topsignal_graph_from_smiles,
}
# Feature functions for GP (numeric vectors)
GP_FEATURIZERS = {
    "smiles": smiles_for_gp,
    "selfies": selfies_for_gp,
    "ecfp": ecfp_for_gp,
}
GNN_MODELS = {
    "gcn": GCN,
    "gin": GIN,
    "gat": GAT,
    "sage": GraphSAGE,
    "polyatomic": PolyatomicNet,  # Custom GNN for polyatomic complexes
}

# Prefer fork so multiprocessing children inherit module state cheaply. The
# guard tolerates a start method that is already set (RuntimeError) as well as
# platforms without fork, e.g. Windows (ValueError).
try:
    multiprocessing.set_start_method("fork")
except (RuntimeError, ValueError):
    pass


def featurize_dataset_parallel(X, y, featurizer, n_jobs=None):
    """Featurize X in parallel, attach targets, and drop failed entries."""
    if n_jobs is None:
        n_jobs = max(1, multiprocessing.cpu_count() - 2)
    # The polyatomic featurizer consumes the target during graph construction;
    # every other featurizer takes the input string alone.
    if featurizer.__name__ == "compressed_topsignal_graph_from_smiles":
        results = Parallel(n_jobs=n_jobs, backend="loky", verbose=10)(
            delayed(featurizer)(xi, yi) for xi, yi in zip(X, y)
        )
    else:
        results = Parallel(n_jobs=n_jobs, backend="loky", verbose=10)(
            delayed(featurizer)(xi) for xi in X
        )
    # Filter out None results (failed featurizations) and attach the original
    # y-value, so a few bad molecules do not abort the whole run.
    data_list = []
    for i, res in enumerate(results):
        if res is not None:
            res.y = torch.tensor([y[i]], dtype=torch.float)
            data_list.append(res)
    return data_list
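
# Example usage (a minimal sketch; the SMILES strings and target values below
# are illustrative placeholders, not project data):
#
#   graphs = featurize_dataset_parallel(["CCO", "CCC"], [0.5, 1.2],
#                                       smiles_to_graph, n_jobs=2)
#   # -> list of torch_geometric Data objects, each with .y attached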


def prepare_and_load_data(args):
    """
    Performs the expensive featurization ONCE and caches the result.
    Subsequent runs load the cached files instead.
    """
    cache_dir = Path(__file__).resolve().parent.parent / "datasets"
    cache_dir.mkdir(parents=True, exist_ok=True)
    train_cache_file = cache_dir / f"{args.rep}_data_{args.dataset}.pt"
    test_cache_file = cache_dir / f"{args.rep}_test_data_{args.dataset}.pt"
    print(f"train cache file is: {train_cache_file}")
    print(f"test cache file is: {test_cache_file}")
    if train_cache_file.exists() and test_cache_file.exists():
        print(
            f"INFO: Loading pre-featurized data from cache for dataset '{args.dataset}'..."
        )
        # weights_only=False because the cache holds full Data objects, not bare tensors.
        train_graphs = torch.load(train_cache_file, weights_only=False)
        test_graphs = torch.load(test_cache_file, weights_only=False)
        return train_graphs, test_graphs
    print("INFO: No cached data found. Starting one-time featurization process...")
    X_train, X_test, y_train, y_test = load_dataset(args.dataset)
    featurizer = REPRESENTATIONS[args.rep]
    print("Featurizing training set (this may take a while)...")
    train_graphs = featurize_dataset_parallel(X_train, y_train, featurizer)
    torch.save(train_graphs, train_cache_file)
    print(f"Saved featurized training data to {train_cache_file}")
    print("Featurizing test set...")
    test_graphs = featurize_dataset_parallel(X_test, y_test, featurizer)
    torch.save(test_graphs, test_cache_file)
    print(f"Saved featurized test data to {test_cache_file}")
    return train_graphs, test_graphs
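
# NOTE: the cache key is only (rep, dataset). After changing a featurizer,
# delete the stale .pt files, or previously cached graphs will load silently.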


def get_model_instance(args, params, train_graphs):
    """Instantiates a model, handling the special case for polyatomic."""
    model_class = GNN_MODELS[args.model]
    sample_graph = train_graphs[0]
    if args.model == "polyatomic":
        from torch_geometric.utils import degree

        print("INFO: Calculating degree vector for polyatomic model...")
        # Accumulate an in-degree histogram over the training set (e.g. for
        # PNA-style aggregators), growing the histogram whenever a batch
        # contains a node of higher degree than seen so far.
        loader = DataLoader(train_graphs, batch_size=params.get("batch_size", 128))
        deg = torch.zeros(32, dtype=torch.long)
        for data in loader:
            d = degree(data.edge_index[1], num_nodes=data.num_nodes, dtype=torch.long)
            bc = torch.bincount(d, minlength=deg.size(0))
            if bc.size(0) > deg.size(0):
                new_deg = torch.zeros(bc.size(0), dtype=torch.long)
                new_deg[: deg.size(0)] = deg
                deg = new_deg
            deg += bc
        return model_class(
            node_feat_dim=sample_graph.x.shape[1],
            edge_feat_dim=sample_graph.edge_attr.shape[1],
            graph_feat_dim=sample_graph.graph_feats.shape[0],
            hidden_dim=params["hidden_dim"],
            deg=deg,
        )
    else:
        return model_class(
            in_channels=sample_graph.num_node_features,
            hidden_channels=params["hidden_dim"],
            out_channels=1,
        )
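

# Minimal end-to-end sketch of how these helpers compose (assumes an
# argparse-style `args`; the defaults and hyperparameters below are
# illustrative, not project defaults, and the actual training loop lives
# elsewhere in the project):
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--dataset", default="esol")  # any name load_dataset accepts
    parser.add_argument("--rep", choices=list(REPRESENTATIONS), default="smiles")
    parser.add_argument("--model", choices=list(GNN_MODELS), default="gcn")
    args = parser.parse_args()

    params = {"hidden_dim": 64, "batch_size": 128}  # illustrative hyperparameters
    train_graphs, test_graphs = prepare_and_load_data(args)
    model = get_model_instance(args, params, train_graphs)
    train_loader = DataLoader(train_graphs, batch_size=params["batch_size"], shuffle=True)
    print(model)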