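"""
SignalSeeker: Protein Signal Peptide Prediction

Predicts whether a protein carries a signal peptide from ProtBERT embeddings
of its N-terminal region, averaging the probabilities of an ensemble of
pre-trained classifiers.
"""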
|
|
|
import json
import logging
import os
import re
from typing import Any, Dict, List, Union

import joblib
import numpy as np
import torch
from transformers import BertModel, BertTokenizer
|
|
|
class SignalSeekerPredictor:
    """SignalSeeker protein signal peptide predictor.

    Embeds the N-terminal region of each sequence with ProtBERT
    (``Rostlab/prot_bert``), optionally rescales the embedding with a saved
    scaler, and averages the positive-class probability of every classifier
    found in ``<model_dir>/models/*.pkl``.

    Files read from ``model_dir`` (each one is optional):
        config.json   -- settings; only ``n_terminal_length`` (default 50) is used
        scaler.pkl    -- joblib file whose object provides ``transform``
        models/*.pkl  -- joblib files whose objects provide ``predict_proba``

    Security: ``.pkl`` files are loaded with ``joblib.load``, which unpickles
    arbitrary objects and can execute code. Only use trusted model directories.
    """

    # Residue letters passed to ProtBERT unchanged; anything else becomes "X".
    _NON_STANDARD_RESIDUE = re.compile(r"[^ACDEFGHIKLMNPQRSTVWXY]")
    _WHITESPACE = re.compile(r"\s+")
    _logger = logging.getLogger(__name__)

    def __init__(self, model_dir: str):
        """Load every model component found in ``model_dir``.

        Args:
            model_dir: local directory holding config, scaler and classifiers.

        Raises:
            RuntimeError: if the ProtBERT tokenizer or encoder cannot be loaded.
        """
        self.model_dir = model_dir
        self.models: Dict[str, Any] = {}  # display name -> fitted classifier
        self.scaler = None  # optional transformer applied to features first
        self.protbert_model = None
        self.protbert_tokenizer = None
        self.device = None
        self.config: Dict[str, Any] = {}

        self._load_models()

    @classmethod
    def from_pretrained(cls, model_name_or_path: str):
        """Create a predictor from a local model directory.

        ``model_name_or_path`` is used directly as a filesystem path; no hub
        download happens here, so fetch/clone the model repository first and
        pass its local directory.
        """
        return cls(model_name_or_path)

    def _load_models(self):
        """Load config, scaler, classifiers and ProtBERT from ``self.model_dir``.

        Missing optional files are skipped: the config stays empty and no
        scaling is applied when ``scaler.pkl`` is absent.
        """
        config_path = os.path.join(self.model_dir, "config.json")
        if os.path.exists(config_path):
            with open(config_path, "r", encoding="utf-8") as f:
                self.config = json.load(f)

        # NOTE: joblib.load unpickles arbitrary objects -- trusted files only.
        scaler_path = os.path.join(self.model_dir, "scaler.pkl")
        if os.path.exists(scaler_path):
            self.scaler = joblib.load(scaler_path)

        models_dir = os.path.join(self.model_dir, "models")
        if os.path.isdir(models_dir):
            # Sorted so the ensemble order is reproducible across platforms;
            # os.listdir returns entries in arbitrary order.
            for model_file in sorted(os.listdir(models_dir)):
                if not model_file.endswith(".pkl"):
                    continue
                # "random_forest.pkl" -> "random forest" (strip only the suffix).
                model_name = model_file[: -len(".pkl")].replace("_", " ")
                model_path = os.path.join(models_dir, model_file)
                self.models[model_name] = joblib.load(model_path)

        if not self.models:
            self._logger.warning(
                "No classifier .pkl files found in %r; every prediction will "
                "report confidence 'Error'.",
                models_dir,
            )

        self._load_protbert()

    def _load_protbert(self):
        """Load the ProtBERT tokenizer/encoder, on GPU when CUDA is available.

        Raises:
            RuntimeError: if loading fails; the original exception is chained
                as ``__cause__``.
        """
        try:
            self.protbert_tokenizer = BertTokenizer.from_pretrained(
                "Rostlab/prot_bert", do_lower_case=False
            )
            self.protbert_model = BertModel.from_pretrained("Rostlab/prot_bert")
            self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
            self.protbert_model = self.protbert_model.to(self.device)
            self.protbert_model.eval()
        except Exception as e:
            raise RuntimeError(f"Failed to load ProtBERT: {e}") from e

    @classmethod
    def _prepare_sequence(cls, seq: str, n_length: int) -> str:
        """Return exactly ``n_length`` normalized N-terminal residues of ``seq``.

        Whitespace (e.g. line breaks from pasted FASTA) is removed and letters
        are upper-cased before filtering, because the residue filter accepts
        only upper-case letters and would otherwise turn lower-case input into
        all "X". Remaining non-standard characters become "X", and sequences
        shorter than ``n_length`` are right-padded with "X".
        """
        compact = cls._WHITESPACE.sub("", seq).upper()
        clean = cls._NON_STANDARD_RESIDUE.sub("X", compact)
        return clean[:n_length].ljust(n_length, "X")

    def _extract_features(self, sequences: List[str], batch_size: int = 8) -> np.ndarray:
        """Embed each sequence's N-terminal region with ProtBERT.

        Args:
            sequences: raw protein sequences.
            batch_size: number of sequences encoded per forward pass.

        Returns:
            One row per input sequence: the mean of the last-hidden-state
            vectors at the residue positions (special tokens excluded).
        """
        n_length = self.config.get("n_terminal_length", 50)
        n_terminal_seqs = [self._prepare_sequence(seq, n_length) for seq in sequences]

        all_embeddings = []
        for start in range(0, len(n_terminal_seqs), batch_size):
            batch = n_terminal_seqs[start:start + batch_size]

            # ProtBERT input is one space-separated token per residue.
            prepped = [" ".join(seq) for seq in batch]
            inputs = self.protbert_tokenizer.batch_encode_plus(
                prepped,
                add_special_tokens=True,
                padding="longest",
                truncation=True,
                max_length=512,
                return_tensors="pt",
            )
            input_ids = inputs["input_ids"].to(self.device)
            attention_mask = inputs["attention_mask"].to(self.device)

            with torch.no_grad():
                outputs = self.protbert_model(
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                )
            hidden = outputs.last_hidden_state

            for j, seq in enumerate(batch):
                # Position 0 holds the leading special token, so residues sit at
                # 1..seq_length; 510 = 512 max tokens minus the two special tokens.
                seq_length = min(len(seq), 510)
                if seq_length > 0:
                    protein_emb = hidden[j, 1:seq_length + 1].mean(dim=0)
                else:
                    protein_emb = hidden[j, 0]
                all_embeddings.append(protein_emb.cpu().numpy())

        return np.array(all_embeddings)

    @staticmethod
    def _confidence(probability: float) -> str:
        """Label how far ``probability`` is from the 0.5 decision boundary."""
        if probability > 0.8 or probability < 0.2:
            return "High"
        if probability > 0.7 or probability < 0.3:
            return "Medium"
        return "Low"

    def _predict_from_features(
        self, features: np.ndarray, sequences: List[str]
    ) -> List[Dict[str, Any]]:
        """Convert feature rows into prediction dicts (row ``i`` <-> ``sequences[i]``).

        The raw ``sequences`` are used only to report ``sequence_length``.
        Models whose ``predict_proba`` raises are logged and left out of the
        ensemble; if none succeed, every result has confidence "Error".
        """
        if self.scaler is not None:
            features = self.scaler.transform(features)

        positive_probs = []
        for model_name, model in self.models.items():
            try:
                # Column 1 is treated as the signal-peptide class -- assumes the
                # models were fit with that class second (TODO confirm).
                positive_probs.append(np.asarray(model.predict_proba(features))[:, 1])
            except Exception:
                # Best effort: one incompatible model must not sink the whole
                # ensemble, but the failure is recorded rather than swallowed.
                self._logger.warning(
                    "Model %r failed and was excluded from the ensemble",
                    model_name,
                    exc_info=True,
                )

        ensemble = np.mean(np.vstack(positive_probs), axis=0) if positive_probs else None

        results = []
        for row, sequence in enumerate(sequences):
            if ensemble is None:
                probability, has_signal_peptide, confidence = 0.0, False, "Error"
            else:
                probability = float(ensemble[row])
                has_signal_peptide = probability > 0.5
                confidence = self._confidence(probability)
            results.append({
                "has_signal_peptide": bool(has_signal_peptide),
                "probability": float(probability),
                "confidence": confidence,
                "sequence_length": len(sequence),
            })
        return results

    def predict(self, sequence: str) -> Dict[str, Any]:
        """Predict whether a single ``sequence`` has a signal peptide.

        Returns:
            Dict with ``has_signal_peptide`` (ensemble probability > 0.5),
            ``probability`` (mean positive-class probability across models),
            ``confidence`` ("High"/"Medium"/"Low", or "Error" when no model
            produced a probability) and ``sequence_length`` (length of the
            input string as given).
        """
        features = self._extract_features([sequence])
        return self._predict_from_features(features, [sequence])[0]

    def predict_batch(self, sequences: Union[List[str], Dict[str, str]]) -> Dict[str, Dict[str, Any]]:
        """Predict signal peptides for several sequences at once.

        Args:
            sequences: a list (results keyed ``seq_0``, ``seq_1``, ...) or a
                mapping of name -> sequence.

        Returns:
            Mapping of name -> the same dict :meth:`predict` returns.
        """
        if isinstance(sequences, list):
            sequences = {f"seq_{i}": seq for i, seq in enumerate(sequences)}
        if not sequences:
            return {}

        names = list(sequences)
        seqs = [sequences[name] for name in names]
        # A single extraction call lets ProtBERT encode sequences in batches
        # instead of running one forward pass per sequence.
        features = self._extract_features(seqs)
        return dict(zip(names, self._predict_from_features(features, seqs)))
|
|
|
def pipeline(
    sequences: Union[str, List[str], Dict[str, str]],
    **kwargs,
) -> Union[Dict[str, Any], Dict[str, Dict[str, Any]]]:
    """Run SignalSeeker on one sequence or on a collection of sequences.

    A new ``SignalSeekerPredictor`` is constructed on every call.

    Args:
        sequences: a single sequence string, a list of sequences, or a
            mapping of name -> sequence.
        **kwargs: only ``model_dir`` (default ``"."``) is read; any other
            keyword is ignored.

    Returns:
        The prediction dict for a single string; otherwise the name ->
        prediction mapping produced by ``predict_batch``.
    """
    predictor = SignalSeekerPredictor(kwargs.get('model_dir', '.'))

    if not isinstance(sequences, str):
        return predictor.predict_batch(sequences)
    return predictor.predict(sequences)
|
|