import sys
import os
import warnings

sys.path.append('/scratch/pranamlab/tong/ReDi_discrete/smiles')

import esm
import numpy as np
import torch
import torch.nn as nn
import xgboost as xgb
from transformers import AutoModelForMaskedLM
from rdkit import Chem, rdBase

from smiles_tokenizer.my_tokenizers import SMILES_SPE_Tokenizer

rdBase.DisableLog('rdApp.error')
warnings.filterwarnings("ignore", category=DeprecationWarning)
warnings.filterwarnings("ignore", category=UserWarning)
warnings.filterwarnings("ignore", category=FutureWarning)


class Analyzer:
    """Checks whether a decoded SMILES string contains a peptide backbone."""

    def __init__(self, tokenizer):
        self.tokenizer = tokenizer
        # Peptide bond: NC(=O) pattern
        self.peptide_bond_pattern = Chem.MolFromSmarts('[NH][C](=O)')
        # N-methylated peptide bond: N(C)C(=O) pattern
        self.n_methyl_pattern = Chem.MolFromSmarts('[N;H0;$(NC)](C)[C](=O)')

    def get_scores(self, x):
        """Return 1 for each sequence whose SMILES contains a peptide bond, else 0."""
        results = []
        smiles_list = self.tokenizer.batch_decode(x)
        for smiles in smiles_list:
            mol = Chem.MolFromSmiles(smiles)
            if mol is None:
                results.append(0)
                continue
            if (mol.HasSubstructMatch(self.peptide_bond_pattern)
                    or mol.HasSubstructMatch(self.n_methyl_pattern)):
                results.append(1)
            else:
                results.append(0)
        return torch.tensor(results)

    def __call__(self, x):
        # get_scores already returns a tensor; no need to re-wrap it
        return self.get_scores(x)


class Hemolysis:
    def __init__(self, device):
        self.predictor = xgb.Booster(
            model_file='/scratch/pranamlab/tong/ReDi_discrete/smiles/scoring/checkpoints/hemolysis-xgboost.json')
        self.emb_model = AutoModelForMaskedLM.from_pretrained(
            'aaronfeller/PeptideCLM-23M-all').roformer.to(device)

    def get_scores(self, x):
        scores = np.ones(len(x))
        # Mean-pooled PeptideCLM embeddings serve as XGBoost features
        features = np.array(self.emb_model(input_ids=x).last_hidden_state.mean(dim=1).detach().cpu())
        if len(features) == 0:
            return scores
        features = np.nan_to_num(features, nan=0.)
        features = np.clip(features, np.finfo(np.float32).min, np.finfo(np.float32).max)
        features = xgb.DMatrix(features)
        probs = self.predictor.predict(features)
        # predictor outputs P(hemolytic); return the probability of being non-hemolytic
        return scores - probs

    def __call__(self, x):
        return torch.tensor(self.get_scores(x))


class Nonfouling:
    def __init__(self, device):
        self.predictor = xgb.Booster(
            model_file='/scratch/pranamlab/tong/ReDi_discrete/smiles/scoring/checkpoints/nonfouling-xgboost.json')
        self.emb_model = AutoModelForMaskedLM.from_pretrained(
            'aaronfeller/PeptideCLM-23M-all').roformer.to(device)

    def get_scores(self, x):
        scores = np.zeros(len(x))
        features = np.array(self.emb_model(input_ids=x).last_hidden_state.mean(dim=1).detach().cpu())
        if len(features) == 0:
            return scores
        features = np.nan_to_num(features, nan=0.)
        features = np.clip(features, np.finfo(np.float32).min, np.finfo(np.float32).max)
        features = xgb.DMatrix(features)
        # return the probability of the peptide being nonfouling
        return self.predictor.predict(features)

    def __call__(self, x):
        return torch.tensor(self.get_scores(x))


class Solubility:
    def __init__(self, device):
        self.predictor = xgb.Booster(
            model_file='/scratch/pranamlab/tong/ReDi_discrete/smiles/scoring/checkpoints/solubility-xgboost.json')
        self.emb_model = AutoModelForMaskedLM.from_pretrained(
            'aaronfeller/PeptideCLM-23M-all').roformer.to(device)

    def get_scores(self, x):
        scores = np.zeros(len(x))
        features = np.array(self.emb_model(input_ids=x).last_hidden_state.mean(dim=1).detach().cpu())
        if len(features) == 0:
            return scores
        features = np.nan_to_num(features, nan=0.)
        features = np.clip(features, np.finfo(np.float32).min, np.finfo(np.float32).max)
        features = xgb.DMatrix(features)
        # return the probability of the peptide being soluble
        return self.predictor.predict(features)

    def __call__(self, x):
        return torch.tensor(self.get_scores(x))
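# --- Usage sketch (illustrative addition, not part of the original pipeline) ---
# All three property scorers above share one pattern: mean-pooled PeptideCLM token
# embeddings fed to a pretrained XGBoost head. This helper only shows the call
# shape; `input_ids` is assumed to come from the repo's SMILES_SPE_Tokenizer, and
# each constructor reloads the embedding model, so share instances in real use.
def _demo_property_scores(input_ids, device='cpu'):
    """Score one batch of tokenized SMILES with the three XGBoost property heads."""
    return {
        'solubility': Solubility(device)(input_ids),
        'nonfouling': Nonfouling(device)(input_ids),
        'non_hemolysis': Hemolysis(device)(input_ids),
    }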
class ImprovedBindingPredictor(nn.Module):
    def __init__(self, esm_dim=1280, smiles_dim=768, hidden_dim=512,
                 n_heads=8, n_layers=3, dropout=0.1):
        super().__init__()
        # Binding thresholds on the pKd/pKi/pIC50 scale
        self.tight_threshold = 7.5  # Kd/Ki/IC50 ≤ ~30 nM
        self.weak_threshold = 6.0   # Kd/Ki/IC50 > 1 μM
        # Project both modalities to a shared dimension
        self.smiles_projection = nn.Linear(smiles_dim, hidden_dim)
        self.protein_projection = nn.Linear(esm_dim, hidden_dim)
        self.protein_norm = nn.LayerNorm(hidden_dim)
        self.smiles_norm = nn.LayerNorm(hidden_dim)
        # Cross-attention blocks with layer norm
        self.cross_attention_layers = nn.ModuleList([
            nn.ModuleDict({
                'attention': nn.MultiheadAttention(hidden_dim, n_heads, dropout=dropout),
                'norm1': nn.LayerNorm(hidden_dim),
                'ffn': nn.Sequential(
                    nn.Linear(hidden_dim, hidden_dim * 4),
                    nn.ReLU(),
                    nn.Dropout(dropout),
                    nn.Linear(hidden_dim * 4, hidden_dim),
                ),
                'norm2': nn.LayerNorm(hidden_dim),
            })
            for _ in range(n_layers)
        ])
        # Prediction heads
        self.shared_head = nn.Sequential(
            nn.Linear(hidden_dim * 2, hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
        )
        # Regression head (continuous affinity)
        self.regression_head = nn.Linear(hidden_dim, 1)
        # Classification head (3 classes: tight, medium, weak binding)
        self.classification_head = nn.Linear(hidden_dim, 3)

    def get_binding_class(self, affinity):
        """Convert affinity values to class indices.

        0: tight binding (>= 7.5)
        1: medium binding (6.0-7.5)
        2: weak binding (< 6.0)
        """
        if isinstance(affinity, torch.Tensor):
            tight_mask = affinity >= self.tight_threshold
            weak_mask = affinity < self.weak_threshold
            medium_mask = ~(tight_mask | weak_mask)
            classes = torch.zeros_like(affinity, dtype=torch.long)
            classes[medium_mask] = 1
            classes[weak_mask] = 2
            return classes
        if affinity >= self.tight_threshold:
            return 0  # tight binding
        elif affinity < self.weak_threshold:
            return 2  # weak binding
        else:
            return 1  # medium binding

    def forward(self, protein_emb, smiles_emb):
        protein = self.protein_norm(self.protein_projection(protein_emb))
        smiles = self.smiles_norm(self.smiles_projection(smiles_emb))
        # Inputs arrive unbatched as (seq_len, hidden); recent PyTorch
        # nn.MultiheadAttention accepts 2-D inputs directly, so no transpose is needed.
        for layer in self.cross_attention_layers:
            # Protein attending to SMILES
            attended_protein = layer['attention'](protein, smiles, smiles)[0]
            protein = layer['norm1'](protein + attended_protein)
            protein = layer['norm2'](protein + layer['ffn'](protein))
            # SMILES attending to protein
            attended_smiles = layer['attention'](smiles, protein, protein)[0]
            smiles = layer['norm1'](smiles + attended_smiles)
            smiles = layer['norm2'](smiles + layer['ffn'](smiles))
        # Sequence-level representations via mean pooling
        protein_pool = torch.mean(protein, dim=0)
        smiles_pool = torch.mean(smiles, dim=0)
        # Concatenate both representations
        combined = torch.cat([protein_pool, smiles_pool], dim=-1)
        # Shared features feed both heads
        shared_features = self.shared_head(combined)
        regression_output = self.regression_head(shared_features)
        classification_logits = self.classification_head(shared_features)
        return regression_output, classification_logits
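# --- Training-loss sketch (an assumption; training code is not part of this file) ---
# ImprovedBindingPredictor returns both a continuous affinity and 3-way class
# logits, and get_binding_class() maps affinities onto those classes, which
# suggests a joint objective of roughly this shape. The weighting `alpha` is a
# hypothetical parameter.
def _joint_binding_loss(model, regression_output, classification_logits, affinity, alpha=0.5):
    """Hypothetical combined MSE + cross-entropy loss for the two heads."""
    reg_loss = nn.functional.mse_loss(regression_output.squeeze(-1), affinity)
    cls_loss = nn.functional.cross_entropy(
        classification_logits, model.get_binding_class(affinity))
    return reg_loss + alpha * cls_loss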
class BindingAffinity:
    def __init__(self, prot_seq, device, model_type='PeptideCLM'):
        self.pep_model = AutoModelForMaskedLM.from_pretrained(
            'aaronfeller/PeptideCLM-23M-all').roformer.to(device)
        self.model = ImprovedBindingPredictor(smiles_dim=768).to(device)
        checkpoint = torch.load(
            '/scratch/pranamlab/tong/ReDi_discrete/smiles/scoring/checkpoints/binding-affinity.pt',
            weights_only=False)
        self.model.load_state_dict(checkpoint['model_state_dict'])
        self.model.eval()
        # Load the ESM-2 model and its batch converter (tokenizer)
        self.esm_model, alphabet = esm.pretrained.esm2_t33_650M_UR50D()
        self.esm_model.to(device)
        self.prot_tokenizer = alphabet.get_batch_converter()
        # Tokenize and embed the target protein once, up front
        data = [("target", prot_seq)]
        _, _, prot_tokens = self.prot_tokenizer(data)
        prot_tokens = prot_tokens.to(device)
        with torch.no_grad():
            results = self.esm_model(prot_tokens, repr_layers=[33])
            prot_emb = results["representations"][33]
        # Mean-pool the per-residue ESM-2 embeddings into a single target vector
        self.prot_emb = torch.mean(prot_emb[0], dim=0, keepdim=True).to(device)

    def forward(self, x):
        with torch.no_grad():
            scores = []
            pep_emb = self.pep_model(
                input_ids=x, output_hidden_states=True).last_hidden_state.mean(dim=1, keepdim=True)
            for pep in pep_emb:
                score, _ = self.model(self.prot_emb, pep)
                # Scale the predicted affinity (pKd-like) down to roughly [0, 1]
                scores.append(score.item() / 10)
        return torch.tensor(scores)

    def __call__(self, x):
        return self.forward(x)
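# --- Smoke-test sketch (illustrative; the target sequence is a hypothetical
# placeholder, and all checkpoint paths above must exist for this to run) ---
if __name__ == '__main__':
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    scorer = BindingAffinity('MKTAYIAKQRQISFVKSHFSRQLEERLGLIEVQ', device)
    # Tokenized SMILES ids from SMILES_SPE_Tokenizer would then be scored with:
    #     scorer(input_ids)  # -> tensor of predicted affinities, scaled by 1/10
    print('BindingAffinity scorer initialized on', device)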