import multiprocessing
from pathlib import Path

import torch
from joblib import Parallel, delayed
from torch_geometric.loader import DataLoader

# --- Project-local imports ---
from data.loaders import load_dataset
from data.featurize import (
    smiles_to_graph,
    selfies_to_graph,
    ecfp_to_graph,
    smiles_for_gp,
    selfies_for_gp,
    ecfp_for_gp,
)
from data.polyatomic_featurize import compressed_topsignal_graph_from_smiles
from models.gnn import GCN, GIN, GAT, GraphSAGE
from models.polyatomic import PolyatomicNet

# --- Configuration Dictionaries ---
# Representation functions for GNN
REPRESENTATIONS = {
    "smiles": smiles_to_graph,
    "selfies": selfies_to_graph,
    "ecfp": ecfp_to_graph,
    "polyatomic": compressed_topsignal_graph_from_smiles,
}

# Feature functions for GP (numeric vectors)
GP_FEATURIZERS = {
    "smiles": smiles_for_gp,
    "selfies": selfies_for_gp,
    "ecfp": ecfp_for_gp,
}
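
# GP_FEATURIZERS is not consumed in this module; a downstream Gaussian-process
# pipeline is assumed to use it. A minimal sketch of that consumption (the
# variable names and numpy stacking are illustrative, not project code):
#
#   import numpy as np
#   featurize = GP_FEATURIZERS["ecfp"]
#   X_feats = np.stack([featurize(s) for s in X_train])  # (n_samples, n_features)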

GNN_MODELS = {
    "gcn": GCN,
    "gin": GIN,
    "gat": GAT,
    "sage": GraphSAGE,
    "polyatomic": PolyatomicNet,  # Custom GNN for polyatomic complexes
}

# Prefer the "fork" start method so joblib workers inherit the parent's
# memory. set_start_method raises RuntimeError if a method was already set,
# and ValueError on platforms without "fork" (e.g. Windows).
try:
    multiprocessing.set_start_method("fork")
except (RuntimeError, ValueError):
    pass


def featurize_dataset_parallel(X, y, featurizer, n_jobs=None):
    """Featurize X in parallel with joblib; failed featurizations (None) are dropped."""
    if n_jobs is None:
        n_jobs = max(1, multiprocessing.cpu_count() - 2)

    # The polyatomic featurizer also takes the label; compare by identity
    # rather than by __name__ so wrapped or partial functions are not misrouted.
    if featurizer is compressed_topsignal_graph_from_smiles:
        results = Parallel(n_jobs=n_jobs, backend="loky", verbose=10)(
            delayed(featurizer)(xi, yi) for xi, yi in zip(X, y)
        )
    else:
        results = Parallel(n_jobs=n_jobs, backend="loky", verbose=10)(
            delayed(featurizer)(xi) for xi in X
        )

    # Drop failed featurizations (None) and attach the original target value.
    data_list = []
    for i, res in enumerate(results):
        if res is not None:
            res.y = torch.tensor([y[i]], dtype=torch.float)
            data_list.append(res)

    return data_list
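

# A minimal usage sketch for featurize_dataset_parallel. The SMILES strings
# and targets are illustrative placeholders, not project data:
#
#   graphs = featurize_dataset_parallel(
#       X=["CCO", "c1ccccc1"], y=[0.5, 1.2],
#       featurizer=smiles_to_graph, n_jobs=2,
#   )
#   print(len(graphs), graphs[0].y)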


def prepare_and_load_data(args):
    """
    Performs the expensive featurization ONCE and caches the result.
    Subsequent runs will load the cached file instantly.
    """
    root_dir = Path(__file__).parent.parent.resolve().__str__()
    datasets_dir = root_dir + "/" + "datasets"
    cache_dir = Path(datasets_dir)
    if not os.path.exists(cache_dir):
        cache_dir.mkdir(exist_ok=False)
    train_cache_file = Path(f"{cache_dir}" + "/" + f"{args.rep}_data_{args.dataset}.pt")
    test_cache_file = Path(
        f"{cache_dir}" + "/" + f"{args.rep}_test_data_{args.dataset}.pt"
    )

    print(f"train cache file is: {train_cache_file}")
    print(f"test cache file is: {test_cache_file}")

    if train_cache_file.exists() and test_cache_file.exists():
        print(
            f"INFO: Loading pre-featurized data from cache for dataset '{args.dataset}'..."
        )
        # weights_only=False lets torch.load unpickle full PyG Data objects
        train_graphs = torch.load(train_cache_file, weights_only=False)
        test_graphs = torch.load(test_cache_file, weights_only=False)
        return train_graphs, test_graphs

    print("INFO: No cached data found. Starting one-time featurization process...")
    X_train, X_test, y_train, y_test = load_dataset(args.dataset)
    featurizer = REPRESENTATIONS[args.rep]

    print("Featurizing training set (this may take a while)...")
    train_graphs = featurize_dataset_parallel(X_train, y_train, featurizer)
    torch.save(train_graphs, train_cache_file)
    print(f"Saved featurized training data to {train_cache_file}")

    print("Featurizing test set...")
    test_graphs = featurize_dataset_parallel(X_test, y_test, featurizer)
    torch.save(test_graphs, test_cache_file)
    print(f"Saved featurized test data to {test_cache_file}")

    return train_graphs, test_graphs
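

# A minimal sketch of calling prepare_and_load_data. An argparse-style
# namespace with .rep and .dataset attributes is assumed; "esol" is a
# placeholder dataset name, not necessarily one load_dataset supports:
#
#   from types import SimpleNamespace
#   args = SimpleNamespace(rep="smiles", dataset="esol")
#   train_graphs, test_graphs = prepare_and_load_data(args)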


def get_model_instance(args, params, train_graphs):
    """Instantiates a model, handling the special case for polyatomic."""
    model_class = GNN_MODELS[args.model]
    sample_graph = train_graphs[0]

    if args.model == "polyatomic":
        from torch_geometric.utils import degree

        print("INFO: Calculating degree vector for polyatomic model...")
        loader = DataLoader(train_graphs, batch_size=params.get("batch_size", 128))
        deg = torch.zeros(32, dtype=torch.long)
        for data in loader:
            d = degree(data.edge_index[1], num_nodes=data.num_nodes, dtype=torch.long)
            bc = torch.bincount(d, minlength=deg.size(0))
            if bc.size(0) > deg.size(0):
                new_deg = torch.zeros(bc.size(0), dtype=torch.long)
                new_deg[: deg.size(0)] = deg
                deg = new_deg
            deg += bc
        return model_class(
            node_feat_dim=sample_graph.x.shape[1],
            edge_feat_dim=sample_graph.edge_attr.shape[1],
            graph_feat_dim=sample_graph.graph_feats.shape[0],
            hidden_dim=params["hidden_dim"],
            deg=deg,
        )
    else:
        return model_class(
            in_channels=sample_graph.num_node_features,
            hidden_channels=params["hidden_dim"],
            out_channels=1,
        )
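

# A minimal end-to-end sketch tying the helpers together. The params keys
# mirror those read above; args is the same namespace assumed earlier and
# args.model must be a key of GNN_MODELS:
#
#   train_graphs, test_graphs = prepare_and_load_data(args)
#   params = {"hidden_dim": 64, "batch_size": 128}
#   model = get_model_instance(args, params, train_graphs)
#   loader = DataLoader(train_graphs, batch_size=params["batch_size"], shuffle=True)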