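"""Multi-objective guided generation of peptide SMILES from a trained ReDi/MDLM checkpoint.

Example invocation (script name and paths are illustrative; the flags are defined in the
argument parser at the bottom of this file):

    python mog_generate.py --checkpoint /path/to/model.ckpt --target <target_name> \
        --num_samples 10 --num_batches 10 --gen_len 50 --optimization_steps 16
"""
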
import argparse
import csv
import random
import re
from collections import Counter

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm

from train import MDLMLightningModule, PeptideAnalyzer
from smiles_tokenizer.my_tokenizers import SMILES_SPE_Tokenizer
from smiles_classifiers import Analyzer, Hemolysis, Nonfouling, Solubility, Permeability, BindingAffinity  # Permeability is used in main(); assumed exported by smiles_classifiers


def peptide_bond_mask(smiles_list):
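    """Return a (batch, character) mask flagging the characters of each SMILES string
    that fall inside a backbone-like bond match (amide, N-methyl amide, or ester)."""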
    batch_size = len(smiles_list)
    max_seq_length = 1035  # maximum number of SMILES characters covered by the mask
    mask = torch.zeros((batch_size, max_seq_length), dtype=torch.int)

    bond_patterns = [
        (r'OC\(=O\)', 'ester'),
        (r'N\(C\)C\(=O\)', 'n_methyl'),
        (r'N[12]C\(=O\)', 'peptide'),
        (r'NC\(=O\)', 'peptide'),
        (r'C\(=O\)N\(C\)', 'n_methyl'),
        (r'C\(=O\)N[12]?', 'peptide')
    ]

    for batch_idx, smiles in enumerate(smiles_list):
        positions = []
        used = set()

        for pattern, bond_type in bond_patterns:
            for match in re.finditer(pattern, smiles):
                # Only record a match if it does not overlap a span already claimed
                # by an earlier pattern.
                if not any(p in range(match.start(), match.end()) for p in used):
                    positions.append({
                        'start': match.start(),
                        'end': match.end(),
                        'type': bond_type,
                        'pattern': match.group()
                    })
                    used.update(range(match.start(), match.end()))

        for pos in positions:
            mask[batch_idx, pos['start']:pos['end']] = 1

    return mask


def peptide_token_mask(smiles_list, token_lists):
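    """Lift the character-level bond mask onto the token level: a token is flagged
    when any of its characters lies inside a matched bond span. The first and last
    tokens are skipped (assumed to be special tokens)."""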
    batch_size = len(smiles_list)
    token_seq_length = max(len(tokens) for tokens in token_lists)
    tokenized_masks = torch.zeros((batch_size, token_seq_length), dtype=torch.int)
    atomwise_masks = peptide_bond_mask(smiles_list)

    for batch_idx, atomwise_mask in enumerate(atomwise_masks):
        token_seq = token_lists[batch_idx]
        atom_idx = 0  # running character position in the raw SMILES string

        for token_idx, token in enumerate(token_seq):
            # Skip the first and last tokens, which carry no SMILES characters.
            if token_idx != 0 and token_idx != len(token_seq) - 1:
                if torch.sum(atomwise_mask[atom_idx:atom_idx + len(token)]) >= 1:
                    tokenized_masks[batch_idx][token_idx] = 1
                atom_idx += len(token)

    return tokenized_masks


class MOGGenerator:
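    """Multi-objective guided generator: draws an initial batch of sequences by
    denoising, then refines them with single-token mutations sampled from a locally
    balanced proposal and accepted when the weighted objective score does not decrease."""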
    def __init__(self, model, device, objectives, args):
        self.model = model
        self.device = device
        self.objectives = objectives
        self.args = args
        self.num_objectives = len(objectives)
        self.peptide_analyzer = PeptideAnalyzer()
        # Token ids that are never proposed (assumed special/reserved tokens in the vocabulary).
        self.invalid = [0, 1, 2, 3, 4, 585, 586]

    def generate_x0(self, n_steps=16, temperature=1.0):
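        """Draw initial sequences by iteratively denoising uniformly random tokens."""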
        print("Starting initial SMILES generation...")
        self.model.eval()

        # Start from uniformly random token ids.
        x = torch.randint(
            0,
            self.model.tokenizer.vocab_size,
            (self.args.num_samples, self.args.gen_len),
            device=self.device
        )

        time_steps = torch.linspace(0.0, 1.0, n_steps + 1, device=self.device)

        with torch.no_grad():
            for i in tqdm(range(n_steps), desc="Flow Matching Steps"):
                t_curr = time_steps[i]
                t_next = time_steps[i + 1]

                t_tensor = torch.full((self.args.num_samples,), t_curr, device=self.device)

                logits = self.model(x, t_tensor)
                logits = logits / temperature
                logits[:, :, 586] = -1000  # suppress token id 586 (assumed special/mask token)

                pred_x1 = torch.argmax(logits, dim=-1)

                if i == n_steps - 1:
                    x = pred_x1
                    break

                # Re-noise: keep each predicted token with probability t_next,
                # otherwise replace it with a uniformly random token.
                noise_prob = 1.0 - t_next
                mask = torch.rand(x.shape, device=self.device) < noise_prob

                noise = torch.randint(
                    0,
                    self.model.tokenizer.vocab_size,
                    x.shape,
                    device=self.device
                )

                x = torch.where(mask, noise, pred_x1)

        generated_sequences = self.model.tokenizer.batch_decode(x)

        validities = []
        for seq in generated_sequences:
            validities.append(self.peptide_analyzer.is_peptide(seq))
            print(seq)
        print(f"Initial Sequence Validity: {validities}")
        return x

    def _get_scores(self, x_batch):
        """Calculates the normalized scores for a batch of sequences."""
        scores = []
        for obj_func in self.objectives:
            scores.append(obj_func(x_batch.to(self.device)))
        return torch.stack(scores, dim=0).to(self.device)

    def _barker_g(self, u):
        """Barker balancing function g(u) = u / (1 + u), which satisfies g(u) = u * g(1 / u)."""
        return u / (1 + u)

    def validity(self, x):
        sampled_sequences = self.model.tokenizer.batch_decode(x)
        return [1.0 if self.peptide_analyzer.is_peptide(seq) else 0.0 for seq in sampled_sequences]

    def generate(self):
        """Main generation loop: refine sequences by guided single-token mutations."""
        shape = (self.args.num_samples, self.args.gen_len)
        x = self.generate_x0()
        print(x)

        if self.args.weights is None:
            # Default: full weight on the first objective, with the remaining
            # objectives sharing a total weight of 1.
            weights = torch.tensor([1] + [1 / (self.num_objectives - 1)] * (self.num_objectives - 1), device=self.device).view(-1, 1)
        else:
            weights = torch.tensor(self.args.weights, device=self.device).view(-1, 1)
            if len(weights) != self.num_objectives:
                raise ValueError("Number of weights must match number of objectives.")
        print(f"Weights: {weights}")

        with torch.no_grad():
            for t in tqdm(range(self.args.optimization_steps), desc="MOG Generation"):
                improved = False

                # Linearly anneal the guidance strength from eta_min to eta_max.
                eta_t = self.args.eta_min + (self.args.eta_max - self.args.eta_min) * (t / (self.args.optimization_steps - 1))

                # Choose a single mutation position, excluding the first and last tokens.
                mut_idx = random.randint(1, self.args.gen_len - 2)

                generation_step = t % self.args.optimization_steps
                time_t = torch.full((self.args.num_samples,), generation_step / self.args.optimization_steps, device=self.device)

                logits = self.model(x, time_t)
                probs = F.softmax(logits, dim=-1)
                pos_probs = probs[:, mut_idx, :]
                # Disallow re-proposing each sample's current token and any invalid/special token.
                pos_probs[torch.arange(self.args.num_samples, device=self.device), x[:, mut_idx]] = 0
                pos_probs[:, self.invalid] = 0

                # Top-p (nucleus) pruning of the candidate tokens at the mutation position.
                sorted_probs, sorted_indices = torch.sort(pos_probs, descending=True)
                cumulative_probs = torch.cumsum(sorted_probs, dim=-1)
                remove_mask = cumulative_probs > self.args.top_p
                remove_mask[..., 1:] = remove_mask[..., :-1].clone()
                remove_mask[..., 0] = 0

                candidate_tokens_list = []
                for i in range(self.args.num_samples):
                    sample_mask = remove_mask[i]
                    candidates = sorted_indices[i, ~sample_mask]
                    candidate_tokens_list.append(candidates)

                current_scores = self._get_scores(x)
                w_current = torch.exp(eta_t * torch.min(weights * current_scores, dim=0).values)

                if t == 0:
                    print(f"Initial Scores: {current_scores}")

                final_proposal_tokens = []
                for i in range(self.args.num_samples):
                    candidates = candidate_tokens_list[i]
                    candidates = torch.tensor([token for token in candidates if token not in self.invalid], device=candidates.device)
                    if len(candidates) >= 200:
                        candidates = candidates[:200]  # cap the number of proposals scored per step
                    num_candidates = len(candidates)

                    # Batch of copies of sample i that differ only at the mutation position.
                    x_prop_batch = x[i].repeat(num_candidates, 1)
                    x_prop_batch[:, mut_idx] = candidates

                    proposal_scores = self._get_scores(x_prop_batch)
                    proposal_s_omega = torch.min(weights * proposal_scores, dim=0).values
                    w_proposal = torch.exp(eta_t * proposal_s_omega)

                    redi_probs = pos_probs[i, candidates]

                    # Locally balanced proposal: model probability times the Barker weight ratio.
                    tilde_q = redi_probs * self._barker_g(w_proposal / w_current[i])

                    final_probs = tilde_q / (torch.sum(tilde_q) + 1e-9)

                    index = torch.multinomial(final_probs, 1).item()

                    # Accept the sampled mutation only if the weighted score sum does not decrease
                    # (the first clause is always False, so acceptance is effectively greedy).
                    if (random.uniform(0, 1) > 1 and proposal_scores[:, index][0] == 1) or torch.sum(weights.squeeze(1) * proposal_scores[:, index]) >= torch.sum(weights.squeeze(1) * current_scores[:, i]):
                        final_token = candidates[index]
                        print(f"Previous Weighted Sum: {torch.sum(weights.squeeze(1) * current_scores[:, i])}")
                        print(f"Previous Scores: {current_scores[:, i]}")
                        print(f"New Weighted Sum: {torch.sum(weights.squeeze(1) * proposal_scores[:, index])}")
                        print(f"New Scores: {proposal_scores[:, index]}")
                        improved = True
                    else:
                        final_token = x[i][mut_idx]

                    final_proposal_tokens.append(final_token)

                x[torch.arange(self.args.num_samples), mut_idx] = torch.stack(final_proposal_tokens)
                if improved:
                    print(self.model.tokenizer.batch_decode(x))

        return x


def main(args):
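    """Load the tokenizer, objective models, and ReDi checkpoint, then run batched MOG generation."""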
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    target = args.target
    tokenizer = SMILES_SPE_Tokenizer('/scratch/pranamlab/tong/ReDi_discrete/smiles/smiles_tokenizer/new_vocab.txt', '/scratch/pranamlab/tong/ReDi_discrete/smiles/smiles_tokenizer/new_splits.txt')

    analyzer_model = Analyzer(tokenizer)
    hemolysis_model = Hemolysis(device)
    nonfouling_model = Nonfouling(device)
    solubility_model = Solubility(device)
    permeability_model = Permeability(device)  # instantiated but not included in OBJECTIVE_FUNCTIONS below
    affinity_model = BindingAffinity(target, device)

    OBJECTIVE_FUNCTIONS = [analyzer_model, hemolysis_model, nonfouling_model, solubility_model, affinity_model]

    print(f"Loading model from checkpoint: {args.checkpoint}")
    checkpoint = torch.load(args.checkpoint, map_location=device, weights_only=False)
    model_hparams = checkpoint["hyper_parameters"]["args"]

    model = MDLMLightningModule.load_from_checkpoint(
        args.checkpoint,
        args=model_hparams,
        tokenizer=tokenizer,
        map_location=device,
        strict=False
    )
    model.to(device)
    print("Model loaded successfully.")

    mog_generator = MOGGenerator(model, device, OBJECTIVE_FUNCTIONS, args)

    for _ in range(args.num_batches):
        generated_tokens = mog_generator.generate()
        final_scores = mog_generator._get_scores(generated_tokens.detach()).detach().cpu().numpy()
        sequence_str = tokenizer.batch_decode(generated_tokens)

        print(sequence_str)
        print(final_scores)

    print("Generation complete.")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Multi-Objective Generation with LBP-MOG-ReDi (Single Mutation).")

    parser.add_argument("--checkpoint", type=str, required=True, help="Path to the trained ReDi model checkpoint.")
    parser.add_argument("--num_samples", type=int, default=10, help="Number of samples to generate per batch.")
    parser.add_argument("--num_batches", type=int, default=10, help="Number of batches to generate.")
    parser.add_argument("--output_file", type=str, default="./smiles.txt", help="File to save the generated sequences.")
    parser.add_argument("--gen_len", type=int, default=50, help="Length of the sequences to generate.")
    parser.add_argument("--optimization_steps", type=int, default=16, help="Number of passes over the sequence.")
    parser.add_argument("--weights", type=float, nargs='+', required=False, help="Weights for the objectives (e.g., 0.5 0.5).")
    parser.add_argument("--eta_min", type=float, default=1.0, help="Minimum guidance strength for annealing.")
    parser.add_argument("--eta_max", type=float, default=20.0, help="Maximum guidance strength for annealing.")
    parser.add_argument("--top_p", type=float, default=0.9, help="Top-p for pruning candidate tokens.")
    parser.add_argument("--target", type=str, required=True, help="Target identifier passed to the BindingAffinity objective.")

    args = parser.parse_args()
    main(args)