File size: 3,663 Bytes
9a67fbe
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
import numpy as np
import torch
from rdkit import Chem
from rdkit.Chem.rdFingerprintGenerator import GetMorganGenerator
import selfies as sf
from torch_geometric.data import Data
from torch_geometric.loader import DataLoader

# Module-level Morgan fingerprint generator used by the featurisers in this
# file. Configured for radius 2 and a 2048-bit fingerprint, with chirality
# ignored, bond types and ring membership included, and count simulation off.
mfpgen = GetMorganGenerator(
    radius=2,
    countSimulation=False,
    includeChirality=False,
    useBondTypes=True,
    onlyNonzeroInvariants=False,
    includeRingMembership=True,
    countBounds=None,
    fpSize=2048,
    atomInvariantsGenerator=None,
    bondInvariantsGenerator=None,
    includeRedundantEnvironments=False,
)


def smiles_to_graph(smiles):
    """Parse a SMILES string and convert the resulting molecule to a graph.

    Returns the result of ``mol_to_graph`` for the parsed molecule, or
    ``None`` when RDKit cannot parse ``smiles``.
    """
    parsed = Chem.MolFromSmiles(smiles)  # type: ignore
    return None if parsed is None else mol_to_graph(parsed)


def selfies_to_graph(smiles_string):
    """Build a graph from a SMILES string after a SELFIES round trip.

    The input is encoded to SELFIES, decoded back to SMILES, parsed with
    RDKit and handed to ``mol_to_graph``. If any of those steps fails, the
    original ``smiles_string`` is converted directly via ``smiles_to_graph``
    instead, which may itself return ``None``.
    """
    try:
        round_tripped = sf.decoder(sf.encoder(smiles_string))
        decoded_mol = Chem.MolFromSmiles(round_tripped)  # type: ignore
        if decoded_mol is None:
            raise ValueError("Decoded SELFIES is invalid")
        return mol_to_graph(decoded_mol)
    except Exception:
        # Best-effort: fall back to parsing the caller's SMILES as-is.
        return smiles_to_graph(smiles_string)


def ecfp_to_graph(smiles_str: str, max_bits: int = 2048, k: int = 2) -> Data | None:
    """Turn the active Morgan-fingerprint bits of a molecule into a graph.

    Each set bit of the fingerprint produced by ``mfpgen`` becomes one node,
    in ascending bit order. Node ``i`` is linked (in both directions) to the
    next ``k`` nodes in that order.

    Args:
        smiles_str: SMILES string of the molecule.
        max_bits: Width of the one-hot node feature vectors. An active bit
            index at or beyond ``max_bits`` raises ``IndexError``.
        k: Number of following nodes each node is connected to.

    Returns:
        A ``Data`` object with ``x`` of shape ``(n, max_bits)`` and
        ``edge_index`` of shape ``(2, num_edges)``, or ``None`` when the
        SMILES cannot be parsed or the fingerprint has no set bits.
    """
    mol = Chem.MolFromSmiles(smiles_str)  # type: ignore
    if mol is None:
        return None
    fp = mfpgen.GetFingerprintAsNumPy(mol)
    active_bits = np.nonzero(fp)[0]
    n = len(active_bits)
    if n == 0:
        return None

    edges = []
    for i in range(n):
        for j in range(i + 1, min(i + 1 + k, n)):
            edges.append((i, j))
            edges.append((j, i))

    if edges:
        edge_index = torch.tensor(edges, dtype=torch.long).t().contiguous()
    else:
        # A single node (or k <= 0) yields no edges. torch.tensor([]) would be
        # 1-D, so build the (2, 0) shape that edge_index requires explicitly.
        edge_index = torch.empty((2, 0), dtype=torch.long)

    # One-hot encode every node's bit index in a single indexed assignment.
    x = torch.zeros((n, max_bits), dtype=torch.float)
    x[torch.arange(n), torch.as_tensor(active_bits, dtype=torch.long)] = 1.0
    return Data(x=x, edge_index=edge_index)


def mol_to_graph(mol):
    """Convert an RDKit molecule into a torch_geometric ``Data`` graph.

    Node features (one row per atom, in atom order) are: atomic number,
    degree, formal charge and the atom's index in the molecule. Every bond
    contributes two directed edges, each carrying the bond type as a
    one-element float feature.

    Args:
        mol: RDKit molecule exposing ``GetAtoms()`` and ``GetBonds()``.

    Returns:
        ``Data`` with ``x`` of shape ``(num_atoms, 4)``, ``edge_index`` of
        shape ``(2, 2 * num_bonds)`` and ``edge_attr`` of shape
        ``(2 * num_bonds, 1)``. These shapes hold even when the molecule has
        no bonds or no atoms.
    """
    atom_feats = [
        [
            atom.GetAtomicNum(),
            atom.GetDegree(),
            atom.GetFormalCharge(),
            atom.GetIdx(),
        ]
        for atom in mol.GetAtoms()
    ]
    if atom_feats:
        x = torch.tensor(atom_feats, dtype=torch.float)
    else:
        # torch.tensor([]) is 1-D; keep the feature axis for empty molecules.
        x = torch.empty((0, 4), dtype=torch.float)

    edge_pairs = []
    bond_feats = []
    for bond in mol.GetBonds():
        i = bond.GetBeginAtomIdx()
        j = bond.GetEndAtomIdx()
        btype = bond.GetBondTypeAsDouble()
        # Store both directions so the graph is effectively undirected.
        edge_pairs.append((i, j))
        edge_pairs.append((j, i))
        bond_feats.append([btype])
        bond_feats.append([btype])

    if edge_pairs:
        edge_index = torch.tensor(edge_pairs, dtype=torch.long).t().contiguous()
        edge_attr = torch.tensor(bond_feats, dtype=torch.float)
    else:
        # Bond-free molecules (e.g. a single atom) would otherwise produce
        # shape-(0,) tensors instead of the (2, 0) / (0, 1) layout.
        edge_index = torch.empty((2, 0), dtype=torch.long)
        edge_attr = torch.empty((0, 1), dtype=torch.float)
    return Data(x=x, edge_index=edge_index, edge_attr=edge_attr)


def smiles_for_gp(smiles: str) -> np.ndarray:
    """Return the Morgan fingerprint of ``smiles`` as a float32 vector.

    An unparseable SMILES yields an all-zero float32 vector whose length is
    taken from ``mfpgen.GetNumBits()``.
    """
    molecule = Chem.MolFromSmiles(smiles)  # type: ignore
    if molecule is not None:
        return mfpgen.GetFingerprintAsNumPy(molecule).astype(np.float32)
    # NOTE(review): confirm the generator object actually exposes GetNumBits().
    return np.zeros(mfpgen.GetNumBits(), dtype=np.float32)


def selfies_for_gp(selfies_str, radius=2, n_bits=2048):
    """Return a float32 fingerprint vector for a SELFIES string.

    The SELFIES string is decoded to SMILES and passed to ``smiles_for_gp``.

    Args:
        selfies_str: SELFIES representation of the molecule.
        radius: Accepted for interface compatibility; not used by this
            function.
        n_bits: Length of the all-zero vector returned when decoding or
            featurisation fails.

    Returns:
        The result of ``smiles_for_gp`` on success, otherwise a float32 zero
        vector of length ``n_bits``.
    """
    try:
        smiles = sf.decoder(selfies_str)
        # Explicit check instead of ``assert`` so it survives ``python -O``.
        if not isinstance(smiles, str):
            raise TypeError("SELFIES decoder did not return a string")
        return smiles_for_gp(smiles)
    except Exception:
        # Best-effort featurisation: any failure maps to the zero vector, but
        # KeyboardInterrupt/SystemExit are no longer swallowed. float32 keeps
        # the fallback's dtype consistent with the success path.
        return np.zeros(n_bits, dtype=np.float32)


def ecfp_for_gp(smiles_str: str) -> np.ndarray:
    """Compute the ``mfpgen`` fingerprint of a SMILES string as float32.

    When RDKit cannot parse ``smiles_str`` the result is a float32 zero
    vector sized by ``mfpgen.GetNumBits()``.
    """
    mol = Chem.MolFromSmiles(smiles_str)  # type: ignore
    # NOTE(review): confirm the generator object actually exposes GetNumBits().
    fingerprint = (
        np.zeros(mfpgen.GetNumBits(), dtype=np.float32)
        if mol is None
        else mfpgen.GetFingerprintAsNumPy(mol).astype(np.float32)
    )
    return fingerprint


def graph_native_loader(graph_list, batch_size=32, shuffle=True):
    """Wrap ``graph_list`` in a torch_geometric ``DataLoader``.

    ``batch_size`` and ``shuffle`` are forwarded to the loader unchanged.
    """
    loader_options = {"batch_size": batch_size, "shuffle": shuffle}
    return DataLoader(graph_list, **loader_options)