# AReUReDi / smiles / moo.py
# Tong Chen
# add files
# 295b1cd
import argparse
import re
import random
from collections import Counter
import csv
import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm
from train import MDLMLightningModule, PeptideAnalyzer
from smiles_tokenizer.my_tokenizers import SMILES_SPE_Tokenizer
import numpy as np
from smiles_classifiers import Analyzer, Hemolysis, Nonfouling, Solubility, BindingAffinity
# (compiled regex, bond type) pairs, tried in this order: a match is kept only when none
# of its characters were claimed by an earlier kept match, so order sets precedence.
_PEPTIDE_BOND_PATTERNS = [
    (re.compile(r'OC\(=O\)'), 'ester'),
    (re.compile(r'N\(C\)C\(=O\)'), 'n_methyl'),
    (re.compile(r'N[12]C\(=O\)'), 'peptide'),  # Pro peptide bonds
    (re.compile(r'NC\(=O\)'), 'peptide'),  # Regular peptide bonds
    (re.compile(r'C\(=O\)N\(C\)'), 'n_methyl'),
    (re.compile(r'C\(=O\)N[12]?'), 'peptide'),
]


def peptide_bond_mask(smiles_list, max_seq_length=1035):
    """Mark the characters of each SMILES string that belong to a bond motif.

    Patterns from ``_PEPTIDE_BOND_PATTERNS`` are scanned in order. A match is kept
    only if none of its characters were already claimed by a previously kept match.

    Args:
        smiles_list: SMILES strings, one per batch row.
        max_seq_length: minimum width of the mask. The width grows to the length of
            the longest SMILES when that is larger, so motifs beyond the default width
            are no longer silently dropped.

    Returns:
        ``torch.int`` tensor of shape ``(len(smiles_list), width)``, with 1 at every
        matched character position and 0 elsewhere.
    """
    longest = max((len(smiles) for smiles in smiles_list), default=0)
    width = max(max_seq_length, longest)
    mask = torch.zeros((len(smiles_list), width), dtype=torch.int)
    for row, smiles in enumerate(smiles_list):
        claimed = set()
        for pattern, _bond_type in _PEPTIDE_BOND_PATTERNS:
            for match in pattern.finditer(smiles):
                span = range(match.start(), match.end())
                # Skip matches overlapping characters already owned by another bond.
                if claimed.isdisjoint(span):
                    claimed.update(span)
                    mask[row, match.start():match.end()] = 1
    return mask
def peptide_token_mask(smiles_list, token_lists):
    """Project the character-level bond mask from ``peptide_bond_mask`` onto tokens.

    A token is flagged when at least one character it covers is flagged. The first
    and last token of each row are never flagged and do not advance the character
    cursor.

    Args:
        smiles_list: SMILES strings, one per batch row.
        token_lists: per-row sequences of token strings. Presumably each is the
            tokenization of the matching SMILES framed by one special token on each
            side (TODO: confirm against the tokenizer).

    Returns:
        ``torch.int`` tensor of shape ``(len(smiles_list), longest token list)``.
    """
    widest = max(len(tokens) for tokens in token_lists)
    token_mask = torch.zeros((len(smiles_list), widest), dtype=torch.int)
    char_masks = peptide_bond_mask(smiles_list)
    for row, char_mask in enumerate(char_masks):
        tokens = token_lists[row]
        last = len(tokens) - 1
        cursor = 0
        for pos, tok in enumerate(tokens):
            if pos == 0 or pos == last:
                continue  # framing tokens have no SMILES characters
            span = len(tok)
            if torch.sum(char_mask[cursor:cursor + span]) >= 1:
                token_mask[row, pos] = 1
            cursor += span
    return token_mask
class MOGGenerator:
    """Multi-objective guided generator on top of a ReDi SMILES model.

    ``generate_x0`` samples an initial batch by iterating the model's flow from random
    tokens. ``generate`` then repeatedly mutates one position per step. The proposal is
    the model's token distribution reweighted by a Barker-style guidance term built
    from the objective scores. The sampled token is accepted only if it does not lower
    the weighted score sum.
    """

    # Upper bound on how many candidate tokens are scored per sample at each step.
    MAX_CANDIDATES = 200

    def __init__(self, model, device, objectives, args):
        """
        Args:
            model: ReDi model, called as ``model(x, t)`` and exposing ``tokenizer``.
            device: torch device for all tensors created here.
            objectives: callables mapping a token tensor to per-sequence scores. Their
                outputs are stacked along a new leading (objective) axis.
            args: namespace with num_samples, gen_len, optimization_steps, weights,
                eta_min, eta_max and top_p.
        """
        self.model = model
        self.device = device
        self.objectives = objectives
        self.args = args
        self.num_objectives = len(objectives)
        self.peptide_analyzer = PeptideAnalyzer()
        # Token ids never proposed as mutations. Presumably special/control tokens of
        # the SMILES vocabulary (TODO: confirm against the tokenizer vocab).
        self.invalid = [0, 1, 2, 3, 4, 585, 586]

    def generate_x0(self, n_steps=16, temperature=1.0):
        """Sample an initial batch of token sequences by iterating the model's flow.

        Starts from uniformly random tokens. At each step the state is replaced by the
        model's argmax prediction, with each position re-noised with probability
        ``1 - t_next``. The last step returns the prediction as-is.

        Args:
            n_steps: number of flow steps.
            temperature: divisor applied to the logits before the argmax.

        Returns:
            Token-id tensor of shape ``(args.num_samples, args.gen_len)``.
        """
        print("Starting initial SMILES generation...")
        self.model.eval()
        num_samples = self.args.num_samples
        vocab_size = self.model.tokenizer.vocab_size
        # Pure noise at t=0.
        x = torch.randint(0, vocab_size, (num_samples, self.args.gen_len), device=self.device)
        time_steps = torch.linspace(0.0, 1.0, n_steps + 1, device=self.device)
        with torch.no_grad():
            for i in tqdm(range(n_steps), desc="Flow Matching Steps"):
                t_curr = time_steps[i]
                t_next = time_steps[i + 1]
                t_tensor = torch.full((num_samples,), t_curr, device=self.device)
                logits = self.model(x, t_tensor) / temperature
                logits[:, :, 586] = -1000  # never predict token 586
                pred_x1 = torch.argmax(logits, dim=-1)
                if i == n_steps - 1:
                    x = pred_x1
                    break
                # Each position is replaced by fresh noise with probability (1 - t_next).
                noise_prob = 1.0 - t_next
                mask = torch.rand(x.shape, device=self.device) < noise_prob
                noise = torch.randint(0, vocab_size, x.shape, device=self.device)
                x = torch.where(mask, noise, pred_x1)
        generated_sequences = self.model.tokenizer.batch_decode(x)
        validities = []
        for seq in generated_sequences:
            validities.append(self.peptide_analyzer.is_peptide(seq))
            print(seq)
        print(f"Initial Sequence Validity: {validities}")
        return x

    def _get_scores(self, x_batch):
        """Return the objective scores for ``x_batch``, stacked as ``(num_objectives, ...)``."""
        scores = []
        for obj_func in self.objectives:
            scores.append(obj_func(x_batch.to(self.device)))
        return torch.stack(scores, dim=0).to(self.device)

    def _barker_g(self, u):
        """Barker balancing function g(u) = u / (1 + u)."""
        return u / (1 + u)

    def validity(self, x):
        """Return 1.0/0.0 per sequence depending on ``peptide_analyzer.is_peptide``."""
        sampled_sequences = self.model.tokenizer.batch_decode(x)
        return [1.0 if self.peptide_analyzer.is_peptide(seq) else 0.0 for seq in sampled_sequences]

    def _resolve_weights(self):
        """Return the objective weights as a float ``(num_objectives, 1)`` tensor.

        Without ``args.weights``, the first objective (the peptide analyzer) gets weight
        1 and the remaining objectives share a total weight of 1 equally.

        Raises:
            ValueError: if the number of weights differs from the number of objectives.
        """
        if self.args.weights is None:
            n_rest = self.num_objectives - 1
            values = [1.0] + ([1.0 / n_rest] * n_rest if n_rest > 0 else [])
        else:
            values = self.args.weights
        weights = torch.tensor(values, dtype=torch.float32, device=self.device).view(-1, 1)
        if len(weights) != self.num_objectives:
            raise ValueError("Number of weights must match number of objectives.")
        return weights

    def _proposal_probs(self, logits, x, mut_idx):
        """Return the model's ``(batch, vocab)`` distribution at ``mut_idx`` with exclusions.

        For each sample, its own current token at ``mut_idx`` and all ``self.invalid``
        tokens get probability 0. The result is not renormalized.
        """
        probs = F.softmax(logits[:, mut_idx, :], dim=-1)
        rows = torch.arange(x.shape[0], device=probs.device)
        # (row, token) pairs: exclude each sample's own token only, not whole columns.
        probs[rows, x[:, mut_idx]] = 0
        probs[:, self.invalid] = 0
        return probs

    def _top_p_candidates(self, pos_probs):
        """Return, per row, the top-p candidate token ids, invalid tokens removed.

        Candidates are ordered by descending probability and capped at
        ``MAX_CANDIDATES``. A row may yield an empty (long) tensor.
        """
        sorted_probs, sorted_indices = torch.sort(pos_probs, descending=True)
        cumulative_probs = torch.cumsum(sorted_probs, dim=-1)
        remove_mask = cumulative_probs > self.args.top_p
        # Shift right so the token that crosses the threshold is still kept.
        remove_mask[..., 1:] = remove_mask[..., :-1].clone()
        remove_mask[..., 0] = False
        invalid = torch.tensor(self.invalid, dtype=sorted_indices.dtype, device=sorted_indices.device)
        candidates = []
        for row_indices, row_remove in zip(sorted_indices, remove_mask):
            kept = row_indices[~row_remove]
            kept = kept[~torch.isin(kept, invalid)]
            candidates.append(kept[: self.MAX_CANDIDATES])
        return candidates

    def _select_token(self, i, x, mut_idx, candidates, pos_probs, current_scores, weights, eta_t):
        """Choose the token for sample ``i`` at ``mut_idx``.

        Samples one candidate with probability proportional to
        ``p_model * g(w_proposal / w_current)``, where ``w = exp(eta_t * min_k(weight_k * score_k))``.
        The sample is accepted only if its weighted score sum is at least the current one.

        Returns:
            ``(token, accepted)``. The current token is returned when the proposal is
            rejected or no usable candidate exists.
        """
        current_token = x[i, mut_idx]
        num_candidates = candidates.numel()
        if num_candidates == 0:
            return current_token, False
        x_prop_batch = x[i].repeat(num_candidates, 1)
        x_prop_batch[:, mut_idx] = candidates
        proposal_scores = self._get_scores(x_prop_batch)
        s_current = torch.min(weights * current_scores[:, i:i + 1], dim=0).values
        s_proposal = torch.min(weights * proposal_scores, dim=0).values
        # g(u) = u / (1 + u) with u = exp(eta * (s_prop - s_cur)) equals
        # sigmoid(eta * (s_prop - s_cur)); this form cannot overflow.
        guidance = torch.sigmoid(eta_t * (s_proposal - s_current))
        tilde_q = pos_probs[i, candidates] * guidance
        total = torch.sum(tilde_q)
        if total <= 0:
            return current_token, False
        index = torch.multinomial(tilde_q / (total + 1e-9), 1).item()
        flat_weights = weights.squeeze(1)
        previous_sum = torch.sum(flat_weights * current_scores[:, i])
        new_sum = torch.sum(flat_weights * proposal_scores[:, index])
        if new_sum >= previous_sum:
            print(f"Previous Weighted Sum: {previous_sum}")
            print(f"Previous Scores: {current_scores[:,i]}")
            print(f"New Weighted Sum: {new_sum}")
            print(f"New Scores: {proposal_scores[:,index]}")
            return candidates[index], True
        return current_token, False

    def generate(self):
        """Sample an initial batch and improve it by guided single-position mutations.

        Returns:
            Token-id tensor of shape ``(args.num_samples, args.gen_len)``.
        """
        x = self.generate_x0()
        print(x)
        weights = self._resolve_weights()
        print(f"Weights: {weights}")
        num_steps = self.args.optimization_steps
        rows = torch.arange(self.args.num_samples, device=x.device)
        with torch.no_grad():
            for t in tqdm(range(num_steps), desc="MOG Generation"):
                improved = False
                # Linearly anneal guidance strength from eta_min to eta_max; a single
                # step uses eta_max.
                frac = t / (num_steps - 1) if num_steps > 1 else 1.0
                eta_t = self.args.eta_min + (self.args.eta_max - self.args.eta_min) * frac
                # Mutate one interior position; the first and last are never touched.
                mut_idx = random.randint(1, self.args.gen_len - 2)
                time_t = torch.full((self.args.num_samples,), t / num_steps, device=self.device)
                logits = self.model(x, time_t)
                pos_probs = self._proposal_probs(logits, x, mut_idx)
                candidate_tokens_list = self._top_p_candidates(pos_probs)
                current_scores = self._get_scores(x)
                if t == 0:
                    print(f"Initial Scores: {current_scores}")
                final_proposal_tokens = []
                for i, candidates in enumerate(candidate_tokens_list):
                    token, accepted = self._select_token(
                        i, x, mut_idx, candidates, pos_probs, current_scores, weights, eta_t
                    )
                    final_proposal_tokens.append(token)
                    improved = improved or accepted
                x[rows, mut_idx] = torch.stack(final_proposal_tokens)
                if improved:
                    print(self.model.tokenizer.batch_decode(x))
        return x
# --- Main Execution ---
def main(args):
    """Build the objectives and ReDi model from ``args`` and run guided generation.

    Runs ``args.num_batches`` rounds of ``MOGGenerator.generate``. After each round it
    prints the decoded sequences and their objective scores. When ``args.output_file``
    is set, each sequence and its per-objective scores are also written there as CSV.
    """
    from contextlib import nullcontext

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")
    target = args.target
    # Tokenizer files: use CLI overrides when provided, otherwise the original paths.
    vocab_path = getattr(args, "vocab_path", None) or '/scratch/pranamlab/tong/ReDi_discrete/smiles/smiles_tokenizer/new_vocab.txt'
    splits_path = getattr(args, "splits_path", None) or '/scratch/pranamlab/tong/ReDi_discrete/smiles/smiles_tokenizer/new_splits.txt'
    tokenizer = SMILES_SPE_Tokenizer(vocab_path, splits_path)
    analyzer_model = Analyzer(tokenizer)
    hemolysis_model = Hemolysis(device)
    nonfouling_model = Nonfouling(device)
    solubility_model = Solubility(device)
    # Permeability is not one of the objectives below and is not imported by this
    # module, so no permeability model is built.
    affinity_model = BindingAffinity(target, device)
    # Objective order matters: the first entry is the peptide analyzer, which the
    # default weights treat specially.
    OBJECTIVE_FUNCTIONS = [analyzer_model, hemolysis_model, nonfouling_model, solubility_model, affinity_model]

    print(f"Loading model from checkpoint: {args.checkpoint}")
    # NOTE: weights_only=False unpickles arbitrary objects; only load trusted checkpoints.
    checkpoint = torch.load(args.checkpoint, map_location=device, weights_only=False)
    model_hparams = checkpoint["hyper_parameters"]["args"]
    model = MDLMLightningModule.load_from_checkpoint(
        args.checkpoint,
        args=model_hparams,
        tokenizer=tokenizer,
        map_location=device,
        strict=False,
    )
    model.to(device)
    print("Model loaded successfully.")
    mog_generator = MOGGenerator(model, device, OBJECTIVE_FUNCTIONS, args)

    output_file = getattr(args, "output_file", None)
    sink = open(output_file, "w", newline="", encoding="utf-8") if output_file else nullcontext()
    with sink as handle:
        writer = csv.writer(handle) if handle is not None else None
        if writer is not None:
            writer.writerow(["sequence"] + [type(obj).__name__ for obj in OBJECTIVE_FUNCTIONS])
        for _ in range(args.num_batches):
            generated_tokens = mog_generator.generate()
            with torch.no_grad():
                final_scores = mog_generator._get_scores(generated_tokens.detach()).detach().cpu().numpy()
            sequence_str = tokenizer.batch_decode(generated_tokens)
            print(sequence_str)
            print(final_scores)
            if writer is not None:
                # Presumably (num_objectives, batch), the layout MOGGenerator indexes
                # with [:, i]; one CSV row per sequence.
                writer.writerows(
                    [seq, *final_scores[:, j].tolist()] for j, seq in enumerate(sequence_str)
                )
                handle.flush()
    print("Generation complete.")
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Multi-Objective Generation with LBP-MOG-ReDi (Single Mutation).")
parser.add_argument("--checkpoint", type=str, required=True, help="Path to the trained ReDi model checkpoint.")
parser.add_argument("--num_samples", type=int, default=10, help="Number of samples to generate.")
parser.add_argument("--num_batches", type=int, default=10, help="Number of samples to generate.")
parser.add_argument("--output_file", type=str, default="./smiles.txt", help="File to save the generated sequences.")
parser.add_argument("--gen_len", type=int, default=50, help="Length of the sequences to generate.")
parser.add_argument("--optimization_steps", type=int, default=16, help="Number of passes over the sequence.")
parser.add_argument("--weights", type=float, nargs='+', required=False, help="Weights for the objectives (e.g., 0.5 0.5).")
parser.add_argument("--eta_min", type=float, default=1.0, help="Minimum guidance strength for annealing.")
parser.add_argument("--eta_max", type=float, default=20.0, help="Maximum guidance strength for annealing.")
parser.add_argument("--top_p", type=float, default=0.9, help="Top-p for pruning candidate tokens.")
parser.add_argument("--target", type=str, required=True)
args = parser.parse_args()
main(args)