File size: 6,403 Bytes
9a67fbe
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
import torch
import numpy as np
from rdkit import Chem
import networkx as nx
from collections import defaultdict
from torch_geometric.data import Data
from polyatomic_complexes.src.complexes.abstract_complex import AbstractComplex
from polyatomic_complexes.src.complexes import PolyatomicGeometrySMILE


def compressed_topsignal_graph_from_smiles(
    smile: str, y_val: int, topk_lap: int = 5
) -> Data | None:
    """Build a PyTorch Geometric ``Data`` graph for one SMILES string.

    The graph combines RDKit atom/bond descriptors with quantities read from
    the abstract complex returned by
    ``PolyatomicGeometrySMILE(...).smiles_to_geom_complex()``.

    Args:
        smile: SMILES string of the molecule.
        y_val: Target value; stored as a one-element float tensor in ``data.y``.
        topk_lap: Number of smallest non-zero Laplacian eigenvalues kept (and
            zero-padded to) per Laplacian matrix.

    Returns:
        A ``Data`` object with

        * ``x``: one row per RDKit atom with 15 columns (chain_0 value,
          6 atomic-number one-hot slots, 3 hybridization one-hot slots, degree,
          aromatic flag, formal charge, closeness centrality, mean
          shortest-path length),
        * ``edge_index``: shape ``(2, E)``; ``(2, 0)`` when there are no edges,
        * ``edge_attr``: shape ``(E, 6)``,
        * ``graph_feats``: 1-D tensor of chain statistics, Laplacian
          eigenvalues, spectral-chain statistics and the two Betti numbers,
        * ``y``: ``tensor([y_val])``.

        ``None`` is returned when the complex is not an ``AbstractComplex``,
        when RDKit cannot parse ``smile``, or when any step raises.
    """
    # Function-scope import keeps this block independent of the file header.
    import logging

    try:
        # 1) Abstract complex
        pg = PolyatomicGeometrySMILE(smile=smile, mode="abstract")
        ac = pg.smiles_to_geom_complex()
        if not isinstance(ac, AbstractComplex):
            # Explicit check: an `assert` would be stripped under `python -O`.
            return None

        # 2) RDKit molecule
        mol = Chem.MolFromSmiles(smile)  # type: ignore
        if mol is None:
            return None

        # 3) Node features: chain0 value + RDKit descriptors
        chains = ac.get_raw_k_chains()
        node_rows = _atom_feature_rows(mol, chains.get("chain_0", []))
        x = torch.tensor(node_rows, dtype=torch.float32)
        n = x.size(0)  # use number of atoms for all subsequent node counts

        # 4) Edges: abstract bonds + RDKit fallback
        edge_index_list, edge_attr_list = _collect_edges(ac, mol, n)
        if edge_index_list:
            edge_index = (
                torch.tensor(edge_index_list, dtype=torch.long).t().contiguous()
            )
            edge_attr = torch.tensor(edge_attr_list, dtype=torch.float32)
        else:
            # torch.tensor([]) would give shape (0,); keep the 2-D layout so
            # edge-free molecules (e.g. a single atom) stay well-formed.
            edge_index = torch.empty((2, 0), dtype=torch.long)
            edge_attr = torch.empty((0, _EDGE_ATTR_DIM), dtype=torch.float32)

        # 5) Topology features: centrality + SPD
        G = nx.Graph()
        G.add_nodes_from(range(n))
        G.add_edges_from(edge_index_list)
        cent_t, spd_t = _topology_columns(G, n)
        x = torch.cat([x, cent_t, spd_t], dim=1)

        # 6) Graph-level features: chain stats + laplacians
        g_stats = []
        for key, arr in chains.items():
            if key != "chain_0":  # chain_0 is already used per node
                g_stats += _mean_std(arr)

        lap_feats = []
        for grp in ac.get_laplacians().get("molecule_laplacians", []):
            recs = grp if isinstance(grp, list) else [grp]
            for _, mat in recs:
                lap_feats += _smallest_nonzero_eigs(mat, topk_lap)

        # 7) Spectral k-chains stats
        spec_feats = []
        for arr in ac.get_spectral_k_chains().values():
            spec_feats += _mean_std(arr)

        # 8) Betti numbers (components & cycles)
        b0 = nx.number_connected_components(G)
        b1 = sum(
            len(nx.cycle_basis(G.subgraph(comp)))
            for comp in nx.connected_components(G)
        )

        # 9) Assemble graph_feats
        all_feats = g_stats + lap_feats + spec_feats + [float(b0), float(b1)]
        graph_feats = torch.tensor(all_feats, dtype=torch.float32)

        data = Data(x=x, edge_index=edge_index, edge_attr=edge_attr)
        data.graph_feats = graph_feats
        data.y = torch.tensor([y_val], dtype=torch.float)
        return data
    except Exception as exc:
        # Deliberate best-effort boundary: a molecule that cannot be
        # featurized is skipped, but the reason is kept for debugging.
        logging.getLogger(__name__).debug(
            "Failed to featurize %r: %s", smile, exc
        )
        return None


# Width of one edge-attribute row: 4 bond-type one-hot slots + 2 scalar slots.
_EDGE_ATTR_DIM = 6


def _atom_feature_rows(mol, chain0) -> list[list[float]]:
    """Return one 13-value feature row per atom of *mol*."""
    atom_types = [6, 7, 8, 15, 16, 17]
    hyb_types = [
        Chem.rdchem.HybridizationType.SP,
        Chem.rdchem.HybridizationType.SP2,
        Chem.rdchem.HybridizationType.SP3,
    ]
    rows = []
    for atom in mol.GetAtoms():
        idx = atom.GetIdx()
        # fallback if chain0 shorter than atom count
        c0 = float(chain0[idx]) if idx < len(chain0) else 0.0
        rows.append(
            [
                c0,
                *one_hot(atom.GetAtomicNum(), atom_types),
                *one_hot(atom.GetHybridization(), hyb_types),
                float(atom.GetDegree()),
                float(atom.GetIsAromatic()),
                float(atom.GetFormalCharge()),
            ]
        )
    return rows


def _collect_edges(ac, mol, n: int):
    """Return ``(edge_pairs, edge_attrs)`` with both directions of every edge.

    Edges come from ``ac.get_bonds()``; the RDKit bonds of *mol* are used only
    when that yields no edge at all.
    """
    sk = ac.get_skeleta().get("molecule_skeleta", [[]])[0]
    zero = next((lst for dim, lst in sk if dim == "0"), [])
    node_ids = [next(iter(fz))[0] for fz in zero]
    # Group node positions by the part of the node id before the first "_".
    # NOTE(review): every node sharing that prefix is linked by a matching
    # bond below (including i == j) -- confirm this is the intended mapping.
    atom_map = defaultdict(list)
    for i, nid in enumerate(node_ids):
        atom_map[nid.split("_")[0]].append(i)

    bond_types = [
        Chem.rdchem.BondType.SINGLE,
        Chem.rdchem.BondType.DOUBLE,
        Chem.rdchem.BondType.TRIPLE,
        Chem.rdchem.BondType.AROMATIC,
    ]
    edge_pairs, edge_attrs = [], []
    for a1, a2, (btype, order) in ac.get_bonds():
        bt_val = getattr(Chem.rdchem.BondType, btype, None)
        attr = one_hot(bt_val, bond_types) + [float(order), 0.0]
        for i in atom_map.get(a1, []):
            for j in atom_map.get(a2, []):
                if i < n and j < n:
                    edge_pairs += [[i, j], [j, i]]
                    edge_attrs += [attr, attr]

    if not edge_pairs:
        for bond in mol.GetBonds():
            i, j = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()
            attr = one_hot(bond.GetBondType(), bond_types)
            attr += [float(bond.GetIsConjugated()), float(bond.IsInRing())]
            edge_pairs += [[i, j], [j, i]]
            edge_attrs += [attr, attr]
    return edge_pairs, edge_attrs


def _topology_columns(G, n: int):
    """Return ``(closeness, mean shortest-path length)`` as two (n, 1) tensors."""
    cent = nx.closeness_centrality(G)
    spd = dict(nx.all_pairs_shortest_path_length(G))
    cent_vec = [cent.get(i, 0.0) for i in range(n)]
    # Mean distance to every reachable node (the node itself counts, at 0).
    spd_vec = [
        sum(d.values()) / max(len(d), 1) for d in (spd.get(i, {}) for i in range(n))
    ]
    cent_t = torch.tensor(cent_vec, dtype=torch.float32).view(n, 1)
    spd_t = torch.tensor(spd_vec, dtype=torch.float32).view(n, 1)
    return cent_t, spd_t


def _mean_std(values) -> list[float]:
    """Return ``[mean, std]`` of *values* in float32, or zeros when empty."""
    arr = np.asarray(values, dtype=np.float32)
    if arr.size == 0:
        # mean/std of an empty array is NaN and would poison graph_feats
        return [0.0, 0.0]
    return [float(arr.mean()), float(arr.std())]


def _smallest_nonzero_eigs(mat, topk: int) -> list[float]:
    """Return the *topk* smallest eigenvalues above 1e-6, zero-padded."""
    # use dense eigen solver to avoid ARPACK issues
    M = np.array(mat, dtype=np.float32)
    try:
        eigs = np.linalg.eigvalsh(M)  # ascending order
    except np.linalg.LinAlgError:
        eigs = np.zeros(M.shape[0], dtype=np.float32)
    # drop the (near-)zero eigenvalues, keep the smallest remaining ones
    vals = eigs[eigs > 1e-6][:topk]
    # pad to exactly topk
    if len(vals) < topk:
        vals = np.pad(vals, (0, topk - len(vals)))
    return [float(v) for v in vals]


def one_hot(val, choices):
    """Encode *val* against *choices* as a list of floats.

    Each position holds 1.0 when ``val`` equals the choice at that position
    and 0.0 otherwise; a value matching no choice yields all zeros.
    """
    encoding = []
    for candidate in choices:
        if val == candidate:
            encoding.append(1.0)
        else:
            encoding.append(0.0)
    return encoding