alarv committed
Commit a5c6e6d · verified · 1 Parent(s): 70e9e7f

Upload tbiodeg AttentiveFP model

Files changed (5)
  1. README.md +187 -0
  2. config.json +22 -0
  3. inference.py +127 -0
  4. pytorch_model.pt +3 -0
  5. requirements.txt +4 -0
README.md ADDED
@@ -0,0 +1,187 @@
+ ---
+ license: mit
+ tags:
+ - chemistry
+ - molecular-property-prediction
+ - graph-neural-networks
+ - attentivefp
+ - pytorch-geometric
+ - toxicity-prediction
+ language:
+ - en
+ pipeline_tag: tabular-regression
+ ---
+
+ # Pyrosage tbiodeg AttentiveFP Model
+
+ ## Model Description
+
+ This is an AttentiveFP (attention-based fingerprint) graph neural network from the Pyrosage project, trained for tbiodeg regression. It predicts molecular properties directly from SMILES strings.
+
+ ## Model Details
+
+ - **Model Type**: AttentiveFP (Graph Neural Network)
+ - **Task**: Regression
+ - **Input**: SMILES strings (molecular representations)
+ - **Output**: Continuous numerical value
+ - **Framework**: PyTorch Geometric
+ - **Architecture**: AttentiveFP with enhanced atom and bond features
+
+ ### Hyperparameters
+
+ ```json
+ {
+   "name": "baseline",
+   "hidden_channels": 64,
+   "num_layers": 2,
+   "num_timesteps": 2,
+   "dropout": 0.2,
+   "learning_rate": 0.001,
+   "weight_decay": 1e-05,
+   "batch_size": 32,
+   "epochs": 50,
+   "patience": 10
+ }
+ ```
+
+ ## Usage
+
+ ### Installation
+
+ ```bash
+ pip install torch torch-geometric rdkit-pypi
+ ```
+
+ ### Loading the Model
+
+ ```python
+ import torch
+ from torch_geometric.nn import AttentiveFP
+ from rdkit import Chem
+ from torch_geometric.data import Data
+
+ # Load the model
+ model_dict = torch.load('pytorch_model.pt', map_location='cpu')
+ state_dict = model_dict['model_state_dict']
+ hyperparams = model_dict['hyperparameters']
+
+ # Create model with correct architecture
+ model = AttentiveFP(
+     in_channels=10,  # Enhanced atom features
+     hidden_channels=hyperparams["hidden_channels"],
+     out_channels=1,
+     edge_dim=6,  # Enhanced bond features
+     num_layers=hyperparams["num_layers"],
+     num_timesteps=hyperparams["num_timesteps"],
+     dropout=hyperparams["dropout"],
+ )
+
+ model.load_state_dict(state_dict)
+ model.eval()
+ ```
+
+ ### Making Predictions
+
+ ```python
+ def smiles_to_data(smiles):
+     """Convert SMILES string to PyG Data object"""
+     mol = Chem.MolFromSmiles(smiles)
+     if mol is None:
+         return None
+
+     # Enhanced atom features (10 dimensions)
+     atom_features = []
+     for atom in mol.GetAtoms():
+         features = [
+             atom.GetAtomicNum(),
+             atom.GetTotalDegree(),
+             atom.GetFormalCharge(),
+             atom.GetTotalNumHs(),
+             atom.GetNumRadicalElectrons(),
+             int(atom.GetIsAromatic()),
+             int(atom.IsInRing()),
+             # Hybridization as one-hot (3 dimensions)
+             int(atom.GetHybridization() == Chem.rdchem.HybridizationType.SP),
+             int(atom.GetHybridization() == Chem.rdchem.HybridizationType.SP2),
+             int(atom.GetHybridization() == Chem.rdchem.HybridizationType.SP3)
+         ]
+         atom_features.append(features)
+
+     x = torch.tensor(atom_features, dtype=torch.float)
+
+     # Enhanced bond features (6 dimensions)
+     edges_list = []
+     edge_features = []
+     for bond in mol.GetBonds():
+         i = bond.GetBeginAtomIdx()
+         j = bond.GetEndAtomIdx()
+         edges_list.extend([[i, j], [j, i]])
+
+         features = [
+             # Bond type as one-hot (4 dimensions)
+             int(bond.GetBondType() == Chem.rdchem.BondType.SINGLE),
+             int(bond.GetBondType() == Chem.rdchem.BondType.DOUBLE),
+             int(bond.GetBondType() == Chem.rdchem.BondType.TRIPLE),
+             int(bond.GetBondType() == Chem.rdchem.BondType.AROMATIC),
+             # Additional features (2 dimensions)
+             int(bond.GetIsConjugated()),
+             int(bond.IsInRing())
+         ]
+         edge_features.extend([features, features])
+
+     if not edges_list:
+         return None
+
+     edge_index = torch.tensor(edges_list, dtype=torch.long).t()
+     edge_attr = torch.tensor(edge_features, dtype=torch.float)
+
+     return Data(x=x, edge_index=edge_index, edge_attr=edge_attr)
+
+ def predict(model, smiles):
+     """Make prediction for a SMILES string"""
+     data = smiles_to_data(smiles)
+     if data is None:
+         return None
+
+     batch = torch.zeros(data.num_nodes, dtype=torch.long)
+     with torch.no_grad():
+         output = model(data.x, data.edge_index, data.edge_attr, batch)
+     return output.item()
+
+ # Example usage
+ smiles = "CC(=O)OC1=CC=CC=C1C(=O)O"  # Aspirin
+ prediction = predict(model, smiles)
+ print(f"Prediction for {smiles}: {prediction}")
+ ```
+
+ ## Training Data
+
+ The model was trained on the tbiodeg dataset from the Pyrosage project, which focuses on molecular toxicity and environmental property prediction.
+
+ ## Model Performance
+
+ See the training logs for detailed performance metrics.
+
+ ## Limitations
+
+ - The model was trained on specific chemical datasets and may not generalize to all classes of molecules
+ - Performance may degrade for molecules that differ significantly from the training distribution
+ - Input must be a valid SMILES string; molecules that RDKit cannot parse are not scored
+
+ ## Citation
+
+ If you use this model, please cite the Pyrosage project:
+
+ ```bibtex
+ @misc{pyrosagetbiodeg,
+   title={Pyrosage tbiodeg AttentiveFP Model},
+   author={UPCI NTUA},
+   year={2025},
+   publisher={Hugging Face},
+   url={https://huggingface.co/upci-ntua/pyrosage-tbiodeg-attentivefp}
+ }
+ ```
+
+ ## License
+
+ MIT License - see LICENSE file for details.
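
The `predict` helper in the README above scores one molecule at a time by passing a single-graph batch vector. For many molecules it is usually more convenient to collate graphs with PyTorch Geometric's `DataLoader`. A minimal sketch, assuming `model` and `smiles_to_data` are defined exactly as in the Usage section of the README; `predict_many` is a hypothetical helper, and SMILES that RDKit cannot parse (or that yield no bonds) are skipped, mirroring the `None` returns above:

```python
# Sketch only: batched scoring built on the README's `model` and `smiles_to_data`.
import torch
from torch_geometric.loader import DataLoader


def predict_many(model, smiles_list, batch_size=32):
    """Return a {smiles: prediction-or-None} dict for a list of SMILES."""
    kept, data_list = [], []
    for smi in smiles_list:
        data = smiles_to_data(smi)
        if data is not None:  # unparsable or bond-less molecules are skipped
            kept.append(smi)
            data_list.append(data)

    results = {smi: None for smi in smiles_list}
    loader = DataLoader(data_list, batch_size=batch_size, shuffle=False)
    model.eval()
    idx = 0
    with torch.no_grad():
        for batch in loader:
            out = model(batch.x, batch.edge_index, batch.edge_attr, batch.batch)
            for value in out.view(-1).tolist():
                results[kept[idx]] = value
                idx += 1
    return results


# Example usage
print(predict_many(model, ["CC(=O)OC1=CC=CC=C1C(=O)O", "c1ccccc1O"]))
```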
config.json ADDED
@@ -0,0 +1,22 @@
+ {
+   "model_type": "AttentiveFP",
+   "task_type": "regression",
+   "endpoint": "tbiodeg",
+   "hyperparameters": {
+     "name": "baseline",
+     "hidden_channels": 64,
+     "num_layers": 2,
+     "num_timesteps": 2,
+     "dropout": 0.2,
+     "learning_rate": 0.001,
+     "weight_decay": 1e-05,
+     "batch_size": 32,
+     "epochs": 50,
+     "patience": 10
+   },
+   "input_features": {
+     "atom_features": 10,
+     "bond_features": 6
+   },
+   "framework": "pytorch_geometric"
+ }
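
The README hardcodes the atom and bond feature sizes when rebuilding the network; the same values are recorded in config.json, so the architecture can also be constructed from the config. A minimal sketch, assuming config.json and pytorch_model.pt sit in the working directory:

```python
# Sketch: build the AttentiveFP architecture from config.json rather than
# hardcoding the hyperparameters (both files assumed to be in the working directory).
import json

import torch
from torch_geometric.nn import AttentiveFP

with open("config.json") as f:
    config = json.load(f)

hp = config["hyperparameters"]
model = AttentiveFP(
    in_channels=config["input_features"]["atom_features"],  # 10 atom features
    hidden_channels=hp["hidden_channels"],
    out_channels=1,
    edge_dim=config["input_features"]["bond_features"],  # 6 bond features
    num_layers=hp["num_layers"],
    num_timesteps=hp["num_timesteps"],
    dropout=hp["dropout"],
)

checkpoint = torch.load("pytorch_model.pt", map_location="cpu")
model.load_state_dict(checkpoint["model_state_dict"])
model.eval()
```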
inference.py ADDED
@@ -0,0 +1,127 @@
+ #!/usr/bin/env python3
+ """
+ Standalone inference script for Pyrosage tbiodeg AttentiveFP Model
+ Usage: python inference.py "SMILES_STRING"
+ """
+
+ import sys
+ import torch
+ from torch_geometric.nn import AttentiveFP
+ from rdkit import Chem
+ from torch_geometric.data import Data
+
+
+ def smiles_to_data(smiles):
+     """Convert SMILES string to PyG Data object with enhanced features"""
+     mol = Chem.MolFromSmiles(smiles)
+     if mol is None:
+         return None
+
+     # Enhanced atom features (10 dimensions)
+     atom_features = []
+     for atom in mol.GetAtoms():
+         features = [
+             atom.GetAtomicNum(),
+             atom.GetTotalDegree(),
+             atom.GetFormalCharge(),
+             atom.GetTotalNumHs(),
+             atom.GetNumRadicalElectrons(),
+             int(atom.GetIsAromatic()),
+             int(atom.IsInRing()),
+             # Hybridization as one-hot (3 dimensions)
+             int(atom.GetHybridization() == Chem.rdchem.HybridizationType.SP),
+             int(atom.GetHybridization() == Chem.rdchem.HybridizationType.SP2),
+             int(atom.GetHybridization() == Chem.rdchem.HybridizationType.SP3)
+         ]
+         atom_features.append(features)
+
+     x = torch.tensor(atom_features, dtype=torch.float)
+
+     # Enhanced bond features (6 dimensions)
+     edges_list = []
+     edge_features = []
+     for bond in mol.GetBonds():
+         i = bond.GetBeginAtomIdx()
+         j = bond.GetEndAtomIdx()
+         edges_list.extend([[i, j], [j, i]])
+
+         features = [
+             # Bond type as one-hot (4 dimensions)
+             int(bond.GetBondType() == Chem.rdchem.BondType.SINGLE),
+             int(bond.GetBondType() == Chem.rdchem.BondType.DOUBLE),
+             int(bond.GetBondType() == Chem.rdchem.BondType.TRIPLE),
+             int(bond.GetBondType() == Chem.rdchem.BondType.AROMATIC),
+             # Additional features (2 dimensions)
+             int(bond.GetIsConjugated()),
+             int(bond.IsInRing())
+         ]
+         edge_features.extend([features, features])
+
+     if not edges_list:
+         return None
+
+     edge_index = torch.tensor(edges_list, dtype=torch.long).t()
+     edge_attr = torch.tensor(edge_features, dtype=torch.float)
+
+     return Data(x=x, edge_index=edge_index, edge_attr=edge_attr)
+
+
+ def load_model():
+     """Load the AttentiveFP model"""
+     model_dict = torch.load('pytorch_model.pt', map_location='cpu')
+     state_dict = model_dict['model_state_dict']
+     hyperparams = model_dict['hyperparameters']
+
+     model = AttentiveFP(
+         in_channels=10,  # Enhanced atom features
+         hidden_channels=hyperparams["hidden_channels"],
+         out_channels=1,
+         edge_dim=6,  # Enhanced bond features
+         num_layers=hyperparams["num_layers"],
+         num_timesteps=hyperparams["num_timesteps"],
+         dropout=hyperparams["dropout"],
+     )
+
+     model.load_state_dict(state_dict)
+     model.eval()
+     return model
+
+
+ def predict(model, smiles):
+     """Make prediction for a SMILES string"""
+     data = smiles_to_data(smiles)
+     if data is None:
+         return None
+
+     batch = torch.zeros(data.num_nodes, dtype=torch.long)
+     with torch.no_grad():
+         output = model(data.x, data.edge_index, data.edge_attr, batch)
+     return output.item()
+
+
+ def main():
+     if len(sys.argv) != 2:
+         print("Usage: python inference.py 'SMILES_STRING'")
+         print("Example: python inference.py 'CC(=O)OC1=CC=CC=C1C(=O)O'")
+         sys.exit(1)
+
+     smiles = sys.argv[1]
+     print("Loading tbiodeg AttentiveFP model...")
+
+     try:
+         model = load_model()
+         print(f"Making prediction for: {smiles}")
+
+         prediction = predict(model, smiles)
+         if prediction is not None:
+             print(f'Regression result: {prediction:.4f}')
+         else:
+             print("Error: Could not process SMILES string")
+
+     except Exception as e:
+         print(f"Error: {e}")
+         sys.exit(1)
+
+
+ if __name__ == "__main__":
+     main()
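
Since the entry point is guarded by `if __name__ == "__main__"`, the script can also be imported and reused from Python rather than invoked per-molecule on the command line. A short sketch, assuming it is run from the repository root so the relative path `pytorch_model.pt` used in `load_model()` resolves:

```python
# Sketch: reuse inference.py as a module (run from the repository root so the
# relative 'pytorch_model.pt' path inside load_model() resolves).
from inference import load_model, predict

model = load_model()
for smi in ["CC(=O)OC1=CC=CC=C1C(=O)O", "CCO", "not-a-smiles"]:
    value = predict(model, smi)
    if value is None:
        print(f"{smi}: could not be processed")
    else:
        print(f"{smi}: {value:.4f}")
```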
pytorch_model.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:6e422936990a4d3b3458585b80de6f7475e618120c922307f2c18314b58ae2d4
+ size 383007
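
pytorch_model.pt is tracked with Git LFS, so the three lines above are only the LFS pointer; a clone without LFS support yields this stub instead of the ~383 kB checkpoint. One way to fetch the actual weights is `huggingface_hub` (an extra dependency not listed in requirements.txt); the repo id below is taken from the citation URL in README.md and is an assumption about where the model is hosted:

```python
# Sketch: download the real checkpoint behind the LFS pointer.
# Repo id taken from the README citation URL (assumed hosting location).
from huggingface_hub import hf_hub_download

checkpoint_path = hf_hub_download(
    repo_id="upci-ntua/pyrosage-tbiodeg-attentivefp",
    filename="pytorch_model.pt",
)
print(checkpoint_path)  # local cache path of the downloaded .pt file
```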
requirements.txt ADDED
@@ -0,0 +1,4 @@
+ torch>=1.9.0
+ torch-geometric>=2.0.0
+ rdkit-pypi>=2022.3.0
+ numpy>=1.21.0