|  | import os | 
					
						
						|  | import torch | 
					
						
						|  | from torch_geometric.loader import DataLoader | 
					
						
						|  | from pathlib import Path | 
					
						
						|  | import numpy as np | 
					
						
						|  | import multiprocessing | 
					
						
						|  | from joblib import Parallel, delayed | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | from data.loaders import load_dataset | 
					
						
						|  | from data.featurize import ( | 
					
						
						|  | smiles_to_graph, | 
					
						
						|  | selfies_to_graph, | 
					
						
						|  | ecfp_to_graph, | 
					
						
						|  | smiles_for_gp, | 
					
						
						|  | selfies_for_gp, | 
					
						
						|  | ecfp_for_gp, | 
					
						
						|  | ) | 
					
						
						|  | from data.polyatomic_featurize import compressed_topsignal_graph_from_smiles | 
					
						
						|  | from models.gnn import GCN, GIN, GAT, GraphSAGE | 
					
						
						|  | from models.polyatomic import PolyatomicNet | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | REPRESENTATIONS = { | 
					
						
						|  | "smiles": smiles_to_graph, | 
					
						
						|  | "selfies": selfies_to_graph, | 
					
						
						|  | "ecfp": ecfp_to_graph, | 
					
						
						|  | "polyatomic": compressed_topsignal_graph_from_smiles, | 
					
						
						|  | } | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | GP_FEATURIZERS = { | 
					
						
						|  | "smiles": smiles_for_gp, | 
					
						
						|  | "selfies": selfies_for_gp, | 
					
						
						|  | "ecfp": ecfp_for_gp, | 
					
						
						|  | } | 
					
						
						|  |  | 
					
						
						|  | GNN_MODELS = { | 
					
						
						|  | "gcn": GCN, | 
					
						
						|  | "gin": GIN, | 
					
						
						|  | "gat": GAT, | 
					
						
						|  | "sage": GraphSAGE, | 
					
						
						|  | "polyatomic": PolyatomicNet, | 
					
						
						|  | } | 
					
						
						|  |  | 
					
						
						|  | try: | 
					
						
						|  | multiprocessing.set_start_method("fork") | 
					
						
						|  | except RuntimeError: | 
					
						
						|  | pass | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | def featurize_dataset_parallel(X, y, featurizer, n_jobs=None): | 
					
						
						|  | """Your provided parallel featurization function.""" | 
					
						
						|  | if n_jobs is None: | 
					
						
						|  | n_jobs = max(1, multiprocessing.cpu_count() - 2) | 
					
						
						|  |  | 
					
						
						|  | if featurizer.__name__ == "compressed_topsignal_graph_from_smiles": | 
					
						
						|  | results = Parallel(n_jobs=n_jobs, backend="loky", verbose=10)( | 
					
						
						|  | delayed(featurizer)(xi, yi) for xi, yi in zip(X, y) | 
					
						
						|  | ) | 
					
						
						|  | else: | 
					
						
						|  | results = Parallel(n_jobs=n_jobs, backend="loky", verbose=10)( | 
					
						
						|  | delayed(featurizer)(xi) for xi in X | 
					
						
						|  | ) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | data_list = [] | 
					
						
						|  |  | 
					
						
						|  | for i, res in enumerate(results): | 
					
						
						|  | if res is not None: | 
					
						
						|  | res.y = torch.tensor([y[i]], dtype=torch.float) | 
					
						
						|  | data_list.append(res) | 
					
						
						|  |  | 
					
						
						|  | return data_list | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | def prepare_and_load_data(args): | 
					
						
						|  | """ | 
					
						
						|  | Performs the expensive featurization ONCE and caches the result. | 
					
						
						|  | Subsequent runs will load the cached file instantly. | 
					
						
						|  | """ | 
					
						
						|  | root_dir = Path(__file__).parent.parent.resolve().__str__() | 
					
						
						|  | datasets_dir = root_dir + "/" + "datasets" | 
					
						
						|  | cache_dir = Path(datasets_dir) | 
					
						
						|  | if not os.path.exists(cache_dir): | 
					
						
						|  | cache_dir.mkdir(exist_ok=False) | 
					
						
						|  | train_cache_file = Path(f"{cache_dir}" + "/" + f"{args.rep}_data_{args.dataset}.pt") | 
					
						
						|  | test_cache_file = Path( | 
					
						
						|  | f"{cache_dir}" + "/" + f"{args.rep}_test_data_{args.dataset}.pt" | 
					
						
						|  | ) | 
					
						
						|  |  | 
					
						
						|  | print(f"train cache file is: {train_cache_file}") | 
					
						
						|  | print(f"test cache file is: {test_cache_file}") | 
					
						
						|  |  | 
					
						
						|  | if train_cache_file.exists() and test_cache_file.exists(): | 
					
						
						|  | print( | 
					
						
						|  | f"INFO: Loading pre-featurized data from cache for dataset '{args.dataset}'..." | 
					
						
						|  | ) | 
					
						
						|  | train_graphs = torch.load(train_cache_file, weights_only=False) | 
					
						
						|  | test_graphs = torch.load(test_cache_file, weights_only=False) | 
					
						
						|  | return train_graphs, test_graphs | 
					
						
						|  |  | 
					
						
						|  | print("INFO: No cached data found. Starting one-time featurization process...") | 
					
						
						|  | X_train, X_test, y_train, y_test = load_dataset(args.dataset) | 
					
						
						|  | featurizer = REPRESENTATIONS[args.rep] | 
					
						
						|  |  | 
					
						
						|  | print("Featurizing training set (this may take a while)...") | 
					
						
						|  | train_graphs = featurize_dataset_parallel(X_train, y_train, featurizer) | 
					
						
						|  | torch.save(train_graphs, train_cache_file) | 
					
						
						|  | print(f"Saved featurized training data to {train_cache_file}") | 
					
						
						|  |  | 
					
						
						|  | print("Featurizing test set...") | 
					
						
						|  | test_graphs = featurize_dataset_parallel(X_test, y_test, featurizer) | 
					
						
						|  | torch.save(test_graphs, test_cache_file) | 
					
						
						|  | print(f"Saved featurized test data to {test_cache_file}") | 
					
						
						|  |  | 
					
						
						|  | return train_graphs, test_graphs | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | def get_model_instance(args, params, train_graphs): | 
					
						
						|  | """Instantiates a model, handling the special case for polyatomic.""" | 
					
						
						|  | model_class = GNN_MODELS[args.model] | 
					
						
						|  | sample_graph = train_graphs[0] | 
					
						
						|  |  | 
					
						
						|  | if args.model == "polyatomic": | 
					
						
						|  | from torch_geometric.utils import degree | 
					
						
						|  |  | 
					
						
						|  | print("INFO: Calculating degree vector for polyatomic model...") | 
					
						
						|  | loader = DataLoader(train_graphs, batch_size=params.get("batch_size", 128)) | 
					
						
						|  | deg = torch.zeros(32, dtype=torch.long) | 
					
						
						|  | for data in loader: | 
					
						
						|  | d = degree(data.edge_index[1], num_nodes=data.num_nodes, dtype=torch.long) | 
					
						
						|  | bc = torch.bincount(d, minlength=deg.size(0)) | 
					
						
						|  | if bc.size(0) > deg.size(0): | 
					
						
						|  | new_deg = torch.zeros(bc.size(0), dtype=torch.long) | 
					
						
						|  | new_deg[: deg.size(0)] = deg | 
					
						
						|  | deg = new_deg | 
					
						
						|  | deg += bc | 
					
						
						|  | return model_class( | 
					
						
						|  | node_feat_dim=sample_graph.x.shape[1], | 
					
						
						|  | edge_feat_dim=sample_graph.edge_attr.shape[1], | 
					
						
						|  | graph_feat_dim=sample_graph.graph_feats.shape[0], | 
					
						
						|  | hidden_dim=params["hidden_dim"], | 
					
						
						|  | deg=deg, | 
					
						
						|  | ) | 
					
						
						|  | else: | 
					
						
						|  | return model_class( | 
					
						
						|  | in_channels=sample_graph.num_node_features, | 
					
						
						|  | hidden_channels=params["hidden_dim"], | 
					
						
						|  | out_channels=1, | 
					
						
						|  | ) | 
					
						
						|  |  |