"""SMILES property classifiers for AReUReDi (smiles/smiles_classifiers.py)."""
import sys
sys.path.append('/scratch/pranamlab/tong/ReDi_discrete/smiles')
import xgboost as xgb
import torch
import numpy as np
from transformers import AutoModelForMaskedLM
from smiles_tokenizer.my_tokenizers import SMILES_SPE_Tokenizer
import warnings
import esm
import torch.nn as nn
from rdkit import Chem, rdBase
rdBase.DisableLog('rdApp.error')
warnings.filterwarnings("ignore", category=DeprecationWarning)
warnings.filterwarnings("ignore", category=UserWarning)
warnings.filterwarnings("ignore", category=FutureWarning)
class Analyzer:
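    """Rule-based peptide check: decodes token ids to SMILES and looks for
    peptide-bond SMARTS patterns with RDKit."""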
def __init__(self, tokenizer):
self.tokenizer = tokenizer
    def get_scores(self, x):
        """Return 1 for each decoded SMILES containing a peptide-like backbone, else 0."""
        results = []
        smiles_list = self.tokenizer.batch_decode(x)
        # Compile the SMARTS patterns once, outside the loop:
        # peptide bond NC(=O) and N-methylated peptide bond N(C)C(=O)
        peptide_bond_pattern = Chem.MolFromSmarts('[NH][C](=O)')
        n_methyl_pattern = Chem.MolFromSmarts('[N;H0;$(NC)](C)[C](=O)')
        for smiles in smiles_list:
            mol = Chem.MolFromSmiles(smiles)
            if mol is None:
                results.append(0)
                continue
            if mol.HasSubstructMatch(peptide_bond_pattern) or mol.HasSubstructMatch(n_methyl_pattern):
                results.append(1)
            else:
                results.append(0)
        return torch.tensor(results)
    def __call__(self, x):
        # get_scores already returns a tensor; wrapping it in torch.tensor again
        # would copy it and raise a UserWarning
        return self.get_scores(x)
class Hemolysis:
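    """XGBoost classifier over mean-pooled PeptideCLM embeddings;
    returns the probability of each peptide being non-hemolytic."""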
def __init__(self, device):
self.predictor = xgb.Booster(model_file='/scratch/pranamlab/tong/ReDi_discrete/smiles/scoring/checkpoints/hemolysis-xgboost.json')
self.emb_model = AutoModelForMaskedLM.from_pretrained('aaronfeller/PeptideCLM-23M-all').roformer.to(device)
    def get_scores(self, x):
        scores = np.ones(len(x))
        with torch.no_grad():
            features = self.emb_model(input_ids=x).last_hidden_state.mean(dim=1).cpu().numpy()
        if len(features) == 0:
            return scores
        features = np.nan_to_num(features, nan=0.)
        features = np.clip(features, np.finfo(np.float32).min, np.finfo(np.float32).max)
        features = xgb.DMatrix(features)
        probs = self.predictor.predict(features)
        # return the probability of being non-hemolytic (1 - P(hemolytic))
        return scores - probs
def __call__(self, x):
scores = self.get_scores(x)
return torch.tensor(scores)
class Nonfouling:
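    """XGBoost classifier over mean-pooled PeptideCLM embeddings;
    returns the probability of each peptide being non-fouling."""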
def __init__(self, device):
self.predictor = xgb.Booster(model_file='/scratch/pranamlab/tong/ReDi_discrete/smiles/scoring/checkpoints/nonfouling-xgboost.json')
self.emb_model = AutoModelForMaskedLM.from_pretrained('aaronfeller/PeptideCLM-23M-all').roformer.to(device)
    def get_scores(self, x):
        scores = np.zeros(len(x))
        with torch.no_grad():
            features = self.emb_model(input_ids=x).last_hidden_state.mean(dim=1).cpu().numpy()
        if len(features) == 0:
            return scores
        features = np.nan_to_num(features, nan=0.)
        features = np.clip(features, np.finfo(np.float32).min, np.finfo(np.float32).max)
        features = xgb.DMatrix(features)
        scores = self.predictor.predict(features)
        # return the probability of being non-fouling
        return scores
def __call__(self, x):
scores = self.get_scores(x)
return torch.tensor(scores)
class Solubility:
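    """XGBoost classifier over mean-pooled PeptideCLM embeddings;
    returns the probability of each peptide being soluble."""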
def __init__(self, device):
self.predictor = xgb.Booster(model_file='/scratch/pranamlab/tong/ReDi_discrete/smiles/scoring/checkpoints/solubility-xgboost.json')
self.emb_model = AutoModelForMaskedLM.from_pretrained('aaronfeller/PeptideCLM-23M-all').roformer.to(device)
    def get_scores(self, x):
        scores = np.zeros(len(x))
        with torch.no_grad():
            features = self.emb_model(input_ids=x).last_hidden_state.mean(dim=1).cpu().numpy()
        if len(features) == 0:
            return scores
        features = np.nan_to_num(features, nan=0.)
        features = np.clip(features, np.finfo(np.float32).min, np.finfo(np.float32).max)
        features = xgb.DMatrix(features)
        # return the predicted probability of being soluble
        scores = self.predictor.predict(features)
        return scores
def __call__(self, x):
scores = self.get_scores(x)
return torch.tensor(scores)
class ImprovedBindingPredictor(nn.Module):
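    """Cross-attention model over ESM-2 protein embeddings and PeptideCLM SMILES
    embeddings, with two heads: a regression head for binding affinity and a
    3-way classification head (tight / medium / weak binding)."""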
def __init__(self,
esm_dim=1280,
smiles_dim=768,
hidden_dim=512,
n_heads=8,
n_layers=3,
dropout=0.1):
super().__init__()
        # Binding thresholds on the -log10 affinity (pKd/pKi/pIC50) scale
        self.tight_threshold = 7.5  # Kd/Ki/IC50 ≤ ~30 nM
        self.weak_threshold = 6.0  # Kd/Ki/IC50 > 1 μM
# Project to same dimension
self.smiles_projection = nn.Linear(smiles_dim, hidden_dim)
self.protein_projection = nn.Linear(esm_dim, hidden_dim)
self.protein_norm = nn.LayerNorm(hidden_dim)
self.smiles_norm = nn.LayerNorm(hidden_dim)
# Cross attention blocks with layer norm
self.cross_attention_layers = nn.ModuleList([
nn.ModuleDict({
'attention': nn.MultiheadAttention(hidden_dim, n_heads, dropout=dropout),
'norm1': nn.LayerNorm(hidden_dim),
'ffn': nn.Sequential(
nn.Linear(hidden_dim, hidden_dim * 4),
nn.ReLU(),
nn.Dropout(dropout),
nn.Linear(hidden_dim * 4, hidden_dim)
),
'norm2': nn.LayerNorm(hidden_dim)
}) for _ in range(n_layers)
])
# Prediction heads
self.shared_head = nn.Sequential(
nn.Linear(hidden_dim * 2, hidden_dim),
nn.ReLU(),
nn.Dropout(dropout),
)
# Regression head
self.regression_head = nn.Linear(hidden_dim, 1)
# Classification head (3 classes: tight, medium, loose binding)
self.classification_head = nn.Linear(hidden_dim, 3)
def get_binding_class(self, affinity):
"""Convert affinity values to class indices
0: tight binding (>= 7.5)
1: medium binding (6.0-7.5)
2: weak binding (< 6.0)
"""
if isinstance(affinity, torch.Tensor):
tight_mask = affinity >= self.tight_threshold
weak_mask = affinity < self.weak_threshold
medium_mask = ~(tight_mask | weak_mask)
classes = torch.zeros_like(affinity, dtype=torch.long)
classes[medium_mask] = 1
classes[weak_mask] = 2
return classes
else:
if affinity >= self.tight_threshold:
return 0 # tight binding
elif affinity < self.weak_threshold:
return 2 # weak binding
else:
return 1 # medium binding
def forward(self, protein_emb, smiles_emb):
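        """Cross-attend protein and SMILES embeddings and predict affinity.

        Note: nn.MultiheadAttention is used with its default batch_first=False,
        so inputs are interpreted sequence-first, e.g. unbatched (seq_len, hidden)
        as in BindingAffinity below. Returns (regression_output, classification_logits).
        """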
protein = self.protein_norm(self.protein_projection(protein_emb))
smiles = self.smiles_norm(self.smiles_projection(smiles_emb))
# Cross attention layers
for layer in self.cross_attention_layers:
# Protein attending to SMILES
attended_protein = layer['attention'](
protein, smiles, smiles
)[0]
protein = layer['norm1'](protein + attended_protein)
protein = layer['norm2'](protein + layer['ffn'](protein))
# SMILES attending to protein
attended_smiles = layer['attention'](
smiles, protein, protein
)[0]
smiles = layer['norm1'](smiles + attended_smiles)
smiles = layer['norm2'](smiles + layer['ffn'](smiles))
# Get sequence-level representations
protein_pool = torch.mean(protein, dim=0)
smiles_pool = torch.mean(smiles, dim=0)
# Concatenate both representations
combined = torch.cat([protein_pool, smiles_pool], dim=-1)
# Shared features
shared_features = self.shared_head(combined)
regression_output = self.regression_head(shared_features)
classification_logits = self.classification_head(shared_features)
return regression_output, classification_logits
class BindingAffinity:
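    """Binding-affinity scorer against a fixed protein target.

    The target sequence is embedded once with ESM-2 (layer 33, mean-pooled over
    residues); each peptide is embedded with PeptideCLM and scored with
    ImprovedBindingPredictor. Scores are divided by 10 to bring the predicted
    affinity into roughly [0, 1]."""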
def __init__(self, prot_seq, device, model_type='PeptideCLM'):
self.pep_model = AutoModelForMaskedLM.from_pretrained('aaronfeller/PeptideCLM-23M-all').roformer.to(device)
self.model = ImprovedBindingPredictor(smiles_dim=768).to(device)
checkpoint = torch.load('/scratch/pranamlab/tong/ReDi_discrete/smiles/scoring/checkpoints/binding-affinity.pt', weights_only=False)
self.model.load_state_dict(checkpoint['model_state_dict'])
self.model.eval()
self.esm_model, alphabet = esm.pretrained.esm2_t33_650M_UR50D() # load ESM-2 model
self.esm_model.to(device)
        self.prot_tokenizer = alphabet.get_batch_converter()  # ESM batch converter (tokenizer)
data = [("target", prot_seq)]
# get tokenized protein
_, _, prot_tokens = self.prot_tokenizer(data)
prot_tokens = prot_tokens.to(device)
with torch.no_grad():
            results = self.esm_model(prot_tokens, repr_layers=[33])  # layer-33 representations
        prot_emb = results["representations"][33]
        self.prot_emb = prot_emb[0]  # (seq_len, esm_dim) for the single target sequence
        self.prot_emb = torch.mean(self.prot_emb, dim=0, keepdim=True).to(device)  # mean-pool over residues -> (1, esm_dim)
def forward(self, x):
with torch.no_grad():
scores = []
            pep_emb = self.pep_model(input_ids=x, output_hidden_states=True).last_hidden_state.mean(dim=1, keepdim=True)
            for pep in pep_emb:
                score, logits = self.model(self.prot_emb, pep)
                # divide by 10 to map the predicted affinity (pKd/pKi/pIC50 scale) into roughly [0, 1]
                scores.append(score.item() / 10)
return torch.tensor(scores)
def __call__(self, x):
return self.forward(x)
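
# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only; runs only where the checkpoint and
# tokenizer paths resolve). Assumptions: the SPE vocab/splits paths below are
# hypothetical placeholders, and SMILES_SPE_Tokenizer follows the Hugging Face
# tokenizer __call__ API (consistent with the batch_decode call in Analyzer).
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    tokenizer = SMILES_SPE_Tokenizer('path/to/vocab.txt', 'path/to/splits.txt')  # hypothetical paths
    smiles = ['CC(=O)NCC(=O)O']  # N-acetylglycine: one amide (peptide-like) bond
    x = tokenizer(smiles, return_tensors='pt', padding=True)['input_ids'].to(device)
    print('peptide backbone:', Analyzer(tokenizer)(x))
    print('P(non-hemolytic):', Hemolysis(device)(x))
    print('P(non-fouling):', Nonfouling(device)(x))
    print('P(soluble):', Solubility(device)(x))
    # binding = BindingAffinity(prot_seq='<target protein sequence>', device=device)
    # print('affinity score:', binding(x))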