# signalseeker/signalseeker.py
# Uploaded by hcoops with huggingface_hub (commit 89aeb83, verified).
"""
SignalSeeker: Protein Signal Peptide Prediction
"""
import json
import os
import re
import warnings
from typing import Any, Dict, List, Union

import joblib
import numpy as np
import torch
from transformers import BertModel, BertTokenizer
class SignalSeekerPredictor:
    """SignalSeeker protein signal peptide predictor.

    Embeds the N-terminal region of each protein with ProtBERT, optionally
    transforms the embedding with a fitted scaler, and averages the
    positive-class probabilities of an ensemble of pickled classifiers.

    Layout read from ``model_dir`` (every piece is optional):
        config.json   -- settings; ``n_terminal_length`` is read (default 50)
        scaler.pkl    -- object exposing ``transform(X)``
        models/*.pkl  -- classifiers exposing ``predict_proba(X)``
    """

    # ProtBERT checkpoint used for feature extraction.
    PROTBERT_NAME = "Rostlab/prot_bert"
    # Anything outside the 20 standard residues plus 'X' is mapped to 'X'.
    _INVALID_RESIDUES = re.compile(r"[^ACDEFGHIKLMNPQRSTVWXY]")

    def __init__(self, model_dir: str):
        """Load every model component found in *model_dir*.

        Raises:
            RuntimeError: if ProtBERT cannot be loaded.
        """
        self.model_dir = model_dir
        self.models: Dict[str, Any] = {}
        self.scaler = None
        self.protbert_model = None
        self.protbert_tokenizer = None
        self.device = None
        self.config: Dict[str, Any] = {}
        self._load_models()

    @classmethod
    def from_pretrained(cls, model_name_or_path: str) -> "SignalSeekerPredictor":
        """Build a predictor from a local model directory.

        Despite the Hugging Face-style name, *model_name_or_path* is used
        directly as a filesystem path; a Hub repo id must first be
        downloaded to a local directory. A path without ``models/*.pkl``
        files yields a predictor whose results report confidence "Error"
        (a warning is emitted in that case).
        """
        return cls(model_name_or_path)

    def _load_models(self):
        """Populate config, scaler and ensemble models, then load ProtBERT.

        Missing config/scaler/model files are skipped rather than raising.
        """
        config_path = os.path.join(self.model_dir, "config.json")
        if os.path.exists(config_path):
            with open(config_path, "r", encoding="utf-8") as f:
                self.config = json.load(f)

        # NOTE(security): joblib.load unpickles, which can execute arbitrary
        # code. Only point model_dir at artifacts from a trusted source.
        scaler_path = os.path.join(self.model_dir, "scaler.pkl")
        if os.path.exists(scaler_path):
            self.scaler = joblib.load(scaler_path)

        models_dir = os.path.join(self.model_dir, "models")
        if os.path.isdir(models_dir):
            # Sorted so the ensemble is assembled in a reproducible order;
            # os.listdir order is arbitrary.
            for model_file in sorted(os.listdir(models_dir)):
                stem, ext = os.path.splitext(model_file)
                if ext != ".pkl":
                    continue
                model_name = stem.replace("_", " ")
                self.models[model_name] = joblib.load(
                    os.path.join(models_dir, model_file)
                )

        if not self.models:
            warnings.warn(
                f"No ensemble models (*.pkl) found under {models_dir!r}; "
                "every prediction will report confidence 'Error'.",
                RuntimeWarning,
                stacklevel=2,
            )

        self._load_protbert()

    def _load_protbert(self):
        """Load the ProtBERT tokenizer and encoder (on GPU when available).

        Raises:
            RuntimeError: if either component cannot be loaded; the original
                exception is chained as the cause.
        """
        try:
            self.protbert_tokenizer = BertTokenizer.from_pretrained(
                self.PROTBERT_NAME, do_lower_case=False
            )
            self.protbert_model = BertModel.from_pretrained(self.PROTBERT_NAME)
            self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
            self.protbert_model = self.protbert_model.to(self.device)
            self.protbert_model.eval()
        except Exception as e:
            raise RuntimeError(f"Failed to load ProtBERT: {e}") from e

    @classmethod
    def _prepare_sequence(cls, seq: str, n_length: int) -> str:
        """Clean *seq* and cut or right-pad it with 'X' to *n_length* residues.

        Whitespace (e.g. line breaks from multi-line FASTA) is removed and
        the sequence is upper-cased before non-standard residues become 'X';
        the residue whitelist is upper-case only, so a lower-case sequence
        would otherwise turn entirely into 'X'.
        """
        compact = "".join(seq.split()).upper()
        clean_seq = cls._INVALID_RESIDUES.sub("X", compact)
        return clean_seq[:n_length].ljust(n_length, "X")

    def _extract_features(self, sequences: List[str], batch_size: int = 8) -> np.ndarray:
        """Return one mean-pooled ProtBERT embedding row per input sequence.

        Args:
            sequences: raw protein sequences.
            batch_size: number of sequences per ProtBERT forward pass
                (must be positive).
        """
        n_length = self.config.get("n_terminal_length", 50)
        n_terminal_seqs = [self._prepare_sequence(seq, n_length) for seq in sequences]

        all_embeddings = []
        for start in range(0, len(n_terminal_seqs), batch_size):
            batch = n_terminal_seqs[start:start + batch_size]
            # ProtBERT's vocabulary is per residue: residues are space-separated.
            prepped = [" ".join(seq) for seq in batch]
            inputs = self.protbert_tokenizer(
                prepped,
                add_special_tokens=True,
                padding="longest",
                truncation=True,
                max_length=512,
                return_tensors="pt",
            )
            input_ids = inputs["input_ids"].to(self.device)
            attention_mask = inputs["attention_mask"].to(self.device)

            with torch.no_grad():
                hidden = self.protbert_model(
                    input_ids=input_ids, attention_mask=attention_mask
                ).last_hidden_state

            for j, seq in enumerate(batch):
                # 512 tokens minus [CLS] and [SEP] leaves at most 510 residues.
                seq_length = min(len(seq), 510)
                if seq_length > 0:
                    # Average residue positions, skipping [CLS] at index 0.
                    protein_emb = hidden[j, 1:seq_length + 1].mean(dim=0)
                else:
                    protein_emb = hidden[j, 0]
                all_embeddings.append(protein_emb.cpu().numpy())

        return np.array(all_embeddings)

    def _model_probabilities(self, X: np.ndarray) -> List[np.ndarray]:
        """Return one positive-class probability vector per working model.

        The ensemble is best-effort: a member that raises is skipped, but a
        warning names it instead of failing silently.
        """
        per_model = []
        for model_name, model in self.models.items():
            try:
                # Assumes a binary classifier whose column 1 is the
                # "has signal peptide" class -- TODO confirm model.classes_.
                per_model.append(np.asarray(model.predict_proba(X))[:, 1])
            except Exception as exc:
                warnings.warn(
                    f"Ensemble model {model_name!r} failed and was skipped: {exc}",
                    RuntimeWarning,
                    stacklevel=2,
                )
        return per_model

    @staticmethod
    def _confidence(prob: float) -> str:
        """Map an ensemble probability to a coarse confidence label."""
        if prob > 0.8 or prob < 0.2:
            return "High"
        if prob > 0.7 or prob < 0.3:
            return "Medium"
        return "Low"

    @classmethod
    def _summarize(cls, probabilities: List[float], sequence: str) -> Dict[str, Any]:
        """Build the result dict for one sequence from per-model probabilities."""
        if probabilities:
            ensemble_prob = float(np.mean(probabilities))
            ensemble_pred = ensemble_prob > 0.5
            confidence = cls._confidence(ensemble_prob)
        else:
            # No model produced a probability for this sequence.
            ensemble_prob = 0.0
            ensemble_pred = False
            confidence = "Error"
        return {
            "has_signal_peptide": bool(ensemble_pred),
            "probability": float(ensemble_prob),
            "confidence": confidence,
            "sequence_length": len(sequence),
        }

    def _predict_many(self, sequences: List[str]) -> List[Dict[str, Any]]:
        """Score all *sequences* with one feature-extraction pass."""
        if not sequences:
            return []
        X = self._extract_features(sequences)
        if self.scaler is not None:
            X = self.scaler.transform(X)
        per_model = self._model_probabilities(X)
        return [
            self._summarize([probs[i] for probs in per_model], seq)
            for i, seq in enumerate(sequences)
        ]

    def predict(self, sequence: str) -> Dict[str, Any]:
        """Predict whether *sequence* carries a signal peptide.

        Returns:
            dict with ``has_signal_peptide`` (bool), ``probability`` (mean
            positive-class probability over the ensemble members that ran),
            ``confidence`` ("High"/"Medium"/"Low", or "Error" when no member
            produced a probability) and ``sequence_length`` (length of the
            raw input string).
        """
        return self._predict_many([sequence])[0]

    def predict_batch(
        self, sequences: Union[List[str], Dict[str, str]]
    ) -> Dict[str, Dict[str, Any]]:
        """Predict signal peptides for many sequences in one pass.

        A list (or tuple) is keyed ``seq_0``, ``seq_1``, ...; a dict keeps
        its own names. ProtBERT runs over all sequences in mini-batches
        rather than once per sequence.
        """
        if isinstance(sequences, (list, tuple)):
            sequences = {f"seq_{i}": seq for i, seq in enumerate(sequences)}
        names = list(sequences)
        scored = self._predict_many([sequences[name] for name in names])
        return dict(zip(names, scored))
# Building a predictor reloads ProtBERT and every pickled model, so one
# predictor is kept per resolved model directory and reused across calls.
# NOTE: files changed on disk after the first call are not picked up.
_PREDICTOR_CACHE: Dict[str, "SignalSeekerPredictor"] = {}


def pipeline(sequences: Union[str, List[str], Dict[str, str]], **kwargs) -> Union[Dict[str, Any], Dict[str, Dict[str, Any]]]:
    """Hugging Face-style pipeline entry point.

    Args:
        sequences: one sequence (str), a list of sequences, or a mapping of
            name -> sequence.
        **kwargs: only ``model_dir`` is read (default ``"."``).

    Returns:
        For a str, the single-prediction dict from ``predict``; otherwise the
        name -> prediction mapping from ``predict_batch``.
    """
    # model_dir defaults to the working directory; per the original note,
    # the Hugging Face runtime is expected to supply it.
    model_dir = kwargs.get("model_dir", ".")
    cache_key = os.path.abspath(model_dir)
    predictor = _PREDICTOR_CACHE.get(cache_key)
    if predictor is None:
        predictor = SignalSeekerPredictor(model_dir)
        _PREDICTOR_CACHE[cache_key] = predictor

    if isinstance(sequences, str):
        return predictor.predict(sequences)
    return predictor.predict_batch(sequences)