import multiprocessing
import os
from pathlib import Path

import numpy as np
import torch
from joblib import Parallel, delayed
from torch_geometric.loader import DataLoader

from data.loaders import load_dataset
from data.featurize import (
    smiles_to_graph,
    selfies_to_graph,
    ecfp_to_graph,
    smiles_for_gp,
    selfies_for_gp,
    ecfp_for_gp,
)
from data.polyatomic_featurize import compressed_topsignal_graph_from_smiles
from models.gnn import GCN, GIN, GAT, GraphSAGE
from models.polyatomic import PolyatomicNet


# Graph featurizers, keyed by molecular representation.
REPRESENTATIONS = {
    "smiles": smiles_to_graph,
    "selfies": selfies_to_graph,
    "ecfp": ecfp_to_graph,
    "polyatomic": compressed_topsignal_graph_from_smiles,
}

# Flat-vector featurizers for the Gaussian-process baselines.
GP_FEATURIZERS = {
    "smiles": smiles_for_gp,
    "selfies": selfies_for_gp,
    "ecfp": ecfp_for_gp,
}

# Model classes, keyed by the model name carried in args.model.
GNN_MODELS = {
    "gcn": GCN,
    "gin": GIN,
    "gat": GAT,
    "sage": GraphSAGE,
    "polyatomic": PolyatomicNet,
}


# Prefer "fork" so worker processes inherit module state; ignore the
# RuntimeError raised if a start method has already been set.
try:
    multiprocessing.set_start_method("fork")
except RuntimeError:
    pass


def featurize_dataset_parallel(X, y, featurizer, n_jobs=None):
    """Featurize X in parallel, attach targets from y, and drop failed entries."""
    if n_jobs is None:
        n_jobs = max(1, multiprocessing.cpu_count() - 2)

    # The polyatomic featurizer takes the target alongside the input; all
    # other featurizers take the input alone.
    if featurizer.__name__ == "compressed_topsignal_graph_from_smiles":
        results = Parallel(n_jobs=n_jobs, backend="loky", verbose=10)(
            delayed(featurizer)(xi, yi) for xi, yi in zip(X, y)
        )
    else:
        results = Parallel(n_jobs=n_jobs, backend="loky", verbose=10)(
            delayed(featurizer)(xi) for xi in X
        )

    # joblib preserves input order, so results[i] aligns with y[i]; skip
    # molecules the featurizer could not process.
    data_list = []
    for i, res in enumerate(results):
        if res is not None:
            res.y = torch.tensor([y[i]], dtype=torch.float)
            data_list.append(res)

    return data_list


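# A minimal usage sketch (illustrative only; "esol" is a placeholder dataset
# name, assuming load_dataset returns input strings with float targets):
#
#     X_train, X_test, y_train, y_test = load_dataset("esol")
#     graphs = featurize_dataset_parallel(X_train, y_train, REPRESENTATIONS["smiles"])

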
def prepare_and_load_data(args):
    """
    Perform the expensive featurization once and cache the result.
    Subsequent runs load the cached files instead of refeaturizing.
    """
    root_dir = Path(__file__).parent.parent.resolve()
    cache_dir = root_dir / "datasets"
    cache_dir.mkdir(parents=True, exist_ok=True)
    train_cache_file = cache_dir / f"{args.rep}_data_{args.dataset}.pt"
    test_cache_file = cache_dir / f"{args.rep}_test_data_{args.dataset}.pt"

    print(f"train cache file is: {train_cache_file}")
    print(f"test cache file is: {test_cache_file}")

    if train_cache_file.exists() and test_cache_file.exists():
        print(
            f"INFO: Loading pre-featurized data from cache for dataset '{args.dataset}'..."
        )
        train_graphs = torch.load(train_cache_file, weights_only=False)
        test_graphs = torch.load(test_cache_file, weights_only=False)
        return train_graphs, test_graphs

    print("INFO: No cached data found. Starting one-time featurization process...")
    X_train, X_test, y_train, y_test = load_dataset(args.dataset)
    featurizer = REPRESENTATIONS[args.rep]

    print("Featurizing training set (this may take a while)...")
    train_graphs = featurize_dataset_parallel(X_train, y_train, featurizer)
    torch.save(train_graphs, train_cache_file)
    print(f"Saved featurized training data to {train_cache_file}")

    print("Featurizing test set...")
    test_graphs = featurize_dataset_parallel(X_test, y_test, featurizer)
    torch.save(test_graphs, test_cache_file)
    print(f"Saved featurized test data to {test_cache_file}")

    return train_graphs, test_graphs


def get_model_instance(args, params, train_graphs):
    """Instantiate a model, handling the special case for polyatomic."""
    model_class = GNN_MODELS[args.model]
    sample_graph = train_graphs[0]

    if args.model == "polyatomic":
        from torch_geometric.utils import degree

        # Build the in-degree histogram over the training graphs, growing the
        # bin count whenever a batch contains a larger degree than seen so far.
        print("INFO: Calculating degree vector for polyatomic model...")
        loader = DataLoader(train_graphs, batch_size=params.get("batch_size", 128))
        deg = torch.zeros(32, dtype=torch.long)
        for data in loader:
            d = degree(data.edge_index[1], num_nodes=data.num_nodes, dtype=torch.long)
            bc = torch.bincount(d, minlength=deg.size(0))
            if bc.size(0) > deg.size(0):
                new_deg = torch.zeros(bc.size(0), dtype=torch.long)
                new_deg[: deg.size(0)] = deg
                deg = new_deg
            deg += bc
        return model_class(
            node_feat_dim=sample_graph.x.shape[1],
            edge_feat_dim=sample_graph.edge_attr.shape[1],
            graph_feat_dim=sample_graph.graph_feats.shape[0],
            hidden_dim=params["hidden_dim"],
            deg=deg,
        )
    else:
        return model_class(
            in_channels=sample_graph.num_node_features,
            hidden_channels=params["hidden_dim"],
            out_channels=1,
        )
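

# Illustrative end-to-end wiring (a sketch, not part of the training pipeline).
# It assumes args carries dataset/rep/model attributes, as the functions above
# expect; the dataset name and hyperparameter values are placeholders.
if __name__ == "__main__":
    from types import SimpleNamespace

    args = SimpleNamespace(dataset="esol", rep="smiles", model="gcn")
    params = {"hidden_dim": 64, "batch_size": 128}

    train_graphs, test_graphs = prepare_and_load_data(args)
    model = get_model_instance(args, params, train_graphs)
    loader = DataLoader(train_graphs, batch_size=params["batch_size"], shuffle=True)
    print(model)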