|  | import torch | 
					
						
						|  | import numpy as np | 
					
						
						|  | from rdkit import Chem | 
					
						
						|  | import networkx as nx | 
					
						
						|  | from collections import defaultdict | 
					
						
						|  | from torch_geometric.data import Data | 
					
						
						|  | from polyatomic_complexes.src.complexes.abstract_complex import AbstractComplex | 
					
						
						|  | from polyatomic_complexes.src.complexes import PolyatomicGeometrySMILE | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | def compressed_topsignal_graph_from_smiles( | 
					
						
						|  | smile: str, y_val: int, topk_lap: int = 5 | 
					
						
						|  | ) -> Data | None: | 
					
						
						|  | try: | 
					
						
						|  |  | 
					
						
						|  | pg = PolyatomicGeometrySMILE(smile=smile, mode="abstract") | 
					
						
						|  | ac = pg.smiles_to_geom_complex() | 
					
						
						|  | assert isinstance(ac, AbstractComplex) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | mol = Chem.MolFromSmiles(smile) | 
					
						
						|  | if mol is None: | 
					
						
						|  | return None | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | chains = ac.get_raw_k_chains() | 
					
						
						|  | chain0 = chains.get("chain_0", []) | 
					
						
						|  | atom_types = [6, 7, 8, 15, 16, 17] | 
					
						
						|  | hyb_types = [ | 
					
						
						|  | Chem.rdchem.HybridizationType.SP, | 
					
						
						|  | Chem.rdchem.HybridizationType.SP2, | 
					
						
						|  | Chem.rdchem.HybridizationType.SP3, | 
					
						
						|  | ] | 
					
						
						|  | node_feats = [] | 
					
						
						|  | for atom in mol.GetAtoms(): | 
					
						
						|  | idx = atom.GetIdx() | 
					
						
						|  |  | 
					
						
						|  | c0 = float(chain0[idx]) if idx < len(chain0) else 0.0 | 
					
						
						|  | feats = [c0] | 
					
						
						|  | feats += one_hot(atom.GetAtomicNum(), atom_types) | 
					
						
						|  | feats += one_hot(atom.GetHybridization(), hyb_types) | 
					
						
						|  | feats += [ | 
					
						
						|  | float(atom.GetDegree()), | 
					
						
						|  | float(atom.GetIsAromatic()), | 
					
						
						|  | float(atom.GetFormalCharge()), | 
					
						
						|  | ] | 
					
						
						|  | node_feats.append(feats) | 
					
						
						|  | x = torch.tensor(node_feats, dtype=torch.float32) | 
					
						
						|  | n = x.size(0) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | sk = ac.get_skeleta().get("molecule_skeleta", [[]])[0] | 
					
						
						|  | zero = next((lst for dim, lst in sk if dim == "0"), []) | 
					
						
						|  | node_ids = [next(iter(fz))[0] for fz in zero] | 
					
						
						|  | atom_map = defaultdict(list) | 
					
						
						|  | for i, nid in enumerate(node_ids): | 
					
						
						|  | symbol = nid.split("_")[0] | 
					
						
						|  | atom_map[symbol].append(i) | 
					
						
						|  |  | 
					
						
						|  | edge_index_list, edge_attr_list = [], [] | 
					
						
						|  | bond_types = [ | 
					
						
						|  | Chem.rdchem.BondType.SINGLE, | 
					
						
						|  | Chem.rdchem.BondType.DOUBLE, | 
					
						
						|  | Chem.rdchem.BondType.TRIPLE, | 
					
						
						|  | Chem.rdchem.BondType.AROMATIC, | 
					
						
						|  | ] | 
					
						
						|  | for a1, a2, (btype, order) in ac.get_bonds(): | 
					
						
						|  | bt_val = getattr(Chem.rdchem.BondType, btype, None) | 
					
						
						|  | for i in atom_map.get(a1, []): | 
					
						
						|  | for j in atom_map.get(a2, []): | 
					
						
						|  | if i < n and j < n: | 
					
						
						|  | edge_index_list += [[i, j], [j, i]] | 
					
						
						|  | attr = one_hot(bt_val, bond_types) + [float(order), 0.0] | 
					
						
						|  | edge_attr_list += [attr, attr] | 
					
						
						|  | if not edge_index_list: | 
					
						
						|  | for bond in mol.GetBonds(): | 
					
						
						|  | i, j = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx() | 
					
						
						|  | edge_index_list += [[i, j], [j, i]] | 
					
						
						|  | attr = one_hot(bond.GetBondType(), bond_types) | 
					
						
						|  | attr += [float(bond.GetIsConjugated()), float(bond.IsInRing())] | 
					
						
						|  | edge_attr_list += [attr, attr] | 
					
						
						|  |  | 
					
						
						|  | edge_index = torch.tensor(edge_index_list, dtype=torch.long).t().contiguous() | 
					
						
						|  | edge_attr = torch.tensor(edge_attr_list, dtype=torch.float32) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | G = nx.Graph() | 
					
						
						|  | G.add_nodes_from(range(n)) | 
					
						
						|  | G.add_edges_from(edge_index_list) | 
					
						
						|  | cent = nx.closeness_centrality(G) | 
					
						
						|  | spd = dict(nx.all_pairs_shortest_path_length(G)) | 
					
						
						|  | cent_vec = [cent.get(i, 0.0) for i in range(n)] | 
					
						
						|  | spd_vec = [ | 
					
						
						|  | sum(d.values()) / max(len(d), 1) for d in (spd.get(i, {}) for i in range(n)) | 
					
						
						|  | ] | 
					
						
						|  | cent_t = torch.tensor(cent_vec, dtype=torch.float32).view(n, 1) | 
					
						
						|  | spd_t = torch.tensor(spd_vec, dtype=torch.float32).view(n, 1) | 
					
						
						|  | x = torch.cat([x, cent_t, spd_t], dim=1) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | g_stats, lap_feats = [], [] | 
					
						
						|  | for k, arr in chains.items(): | 
					
						
						|  | if k == "chain_0": | 
					
						
						|  | continue | 
					
						
						|  | a = np.array(arr, dtype=np.float32) | 
					
						
						|  | g_stats += [a.mean(), a.std()] | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | for grp in ac.get_laplacians().get("molecule_laplacians", []): | 
					
						
						|  | recs = grp if isinstance(grp, list) else [grp] | 
					
						
						|  | for _, mat in recs: | 
					
						
						|  |  | 
					
						
						|  | M = np.array(mat, dtype=np.float32) | 
					
						
						|  |  | 
					
						
						|  | try: | 
					
						
						|  | eigs = np.linalg.eigvalsh(M) | 
					
						
						|  | except Exception: | 
					
						
						|  | eigs = np.zeros(M.shape[0], dtype=np.float32) | 
					
						
						|  |  | 
					
						
						|  | nonzero = eigs[eigs > 1e-6] | 
					
						
						|  | vals = nonzero[:topk_lap] if len(nonzero) >= topk_lap else nonzero | 
					
						
						|  |  | 
					
						
						|  | if len(vals) < topk_lap: | 
					
						
						|  | vals = np.pad(vals, (0, topk_lap - len(vals))) | 
					
						
						|  | lap_feats += list(vals) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | spectral = ac.get_spectral_k_chains() | 
					
						
						|  | spec_feats = [] | 
					
						
						|  | for arr in spectral.values(): | 
					
						
						|  | a = np.array(arr, dtype=np.float32) | 
					
						
						|  | spec_feats += [a.mean(), a.std()] | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | b0 = nx.number_connected_components(G) | 
					
						
						|  | b1 = sum( | 
					
						
						|  | len(nx.cycle_basis(G.subgraph(comp))) for comp in nx.connected_components(G) | 
					
						
						|  | ) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | all_feats = g_stats + lap_feats + spec_feats + [float(b0), float(b1)] | 
					
						
						|  | graph_feats = torch.tensor(all_feats, dtype=torch.float32) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | data = Data(x=x, edge_index=edge_index, edge_attr=edge_attr) | 
					
						
						|  | data.graph_feats = graph_feats | 
					
						
						|  | data.y = torch.tensor([y_val], dtype=torch.float) | 
					
						
						|  |  | 
					
						
						|  | return data | 
					
						
						|  | except Exception as e: | 
					
						
						|  |  | 
					
						
						|  | return None | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | def one_hot(val, choices): | 
					
						
						|  | return [1.0 if val == c else 0.0 for c in choices] | 
					
						
						|  |  |