hcoops committed on
Commit
89aeb83
·
verified ·
1 Parent(s): 3756173

Upload signalseeker.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. signalseeker.py +203 -0
signalseeker.py ADDED
@@ -0,0 +1,203 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ SignalSeeker: Protein Signal Peptide Prediction
3
+ """
4
+
5
+ import os
6
+ import json
7
+ import joblib
8
+ import numpy as np
9
+ import torch
10
+ import re
11
+ from transformers import BertModel, BertTokenizer
12
+ from typing import Dict, List, Union, Any
13
+
14
class SignalSeekerPredictor:
    """SignalSeeker protein signal peptide predictor.

    Each sequence is reduced to a fixed-length N-terminal window, embedded
    with ProtBERT (``Rostlab/prot_bert``), mean-pooled into one vector and
    scored by every classifier found in ``<model_dir>/models``.  The reported
    probability is the unweighted mean of the classifiers'
    ``predict_proba(X)[:, 1]`` values.

    Files read from ``model_dir`` (each one is optional)::

        config.json    # may define "n_terminal_length" (default 50)
        scaler.pkl     # object with .transform(X), applied before scoring
        models/*.pkl   # objects exposing .predict_proba(X)
    """

    # Any character outside this alphabet is replaced by "X" before embedding.
    _INVALID_RESIDUE_RE = re.compile(r"[^ACDEFGHIKLMNPQRSTVWXY]")

    # Number of sequences pushed through ProtBERT per forward pass.
    _EMBED_BATCH_SIZE = 8

    def __init__(self, model_dir: str):
        """Load every model component found under ``model_dir``.

        Args:
            model_dir: Local directory holding ``config.json``,
                ``scaler.pkl`` and a ``models/`` sub-directory.

        Raises:
            RuntimeError: If ProtBERT cannot be loaded.
        """
        self.model_dir = model_dir
        self.models = {}                 # display name -> unpickled classifier
        self.scaler = None               # optional feature scaler
        self.protbert_model = None
        self.protbert_tokenizer = None
        self.device = None               # torch.device chosen in _load_protbert
        self.config = {}                 # parsed config.json (may stay empty)

        self._load_models()

    @classmethod
    def from_pretrained(cls, model_name_or_path: str):
        """Alternate constructor mirroring the Hugging Face naming style.

        The argument is passed straight to the constructor and is only ever
        used as a local directory path; this class performs no Hub download
        for the SignalSeeker files themselves.
        """
        return cls(model_name_or_path)

    def _load_models(self):
        """Load config, scaler, ensemble classifiers and ProtBERT."""
        import warnings

        # Load config
        config_path = os.path.join(self.model_dir, "config.json")
        if os.path.exists(config_path):
            with open(config_path, 'r', encoding='utf-8') as f:
                self.config = json.load(f)

        # NOTE(security): joblib.load unpickles arbitrary Python objects and
        # can execute code while doing so.  Only point model_dir at files
        # obtained from a trusted source.
        scaler_path = os.path.join(self.model_dir, "scaler.pkl")
        if os.path.exists(scaler_path):
            self.scaler = joblib.load(scaler_path)

        # Load ensemble models.  Sorting makes the load order (and therefore
        # the iteration order of self.models) independent of the filesystem.
        models_dir = os.path.join(self.model_dir, "models")
        suffix = '.pkl'
        if os.path.isdir(models_dir):
            for model_file in sorted(os.listdir(models_dir)):
                if not model_file.endswith(suffix):
                    continue
                # Strip only the trailing extension, then prettify the name.
                model_name = model_file[:-len(suffix)].replace('_', ' ')
                model_path = os.path.join(models_dir, model_file)
                self.models[model_name] = joblib.load(model_path)

        if not self.models:
            # Without classifiers every prediction comes back as "Error";
            # make that visible instead of failing silently later.
            warnings.warn(
                f"No ensemble models (*.pkl) found in {models_dir!r}; "
                "predictions will report confidence 'Error'.",
                RuntimeWarning,
                stacklevel=2,
            )

        # Load ProtBERT
        self._load_protbert()

    def _load_protbert(self):
        """Load the ProtBERT tokenizer and encoder used for feature extraction.

        The encoder is moved to CUDA when available (CPU otherwise) and put
        in eval mode.

        Raises:
            RuntimeError: If the tokenizer or model cannot be loaded; the
                original exception is chained as ``__cause__``.
        """
        try:
            self.protbert_tokenizer = BertTokenizer.from_pretrained(
                "Rostlab/prot_bert", do_lower_case=False
            )
            self.protbert_model = BertModel.from_pretrained("Rostlab/prot_bert")
            self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
            self.protbert_model = self.protbert_model.to(self.device)
            self.protbert_model.eval()
        except Exception as e:
            raise RuntimeError(f"Failed to load ProtBERT: {e}") from e

    @classmethod
    def _prepare_n_terminal(cls, sequence: str, n_length: int) -> str:
        """Return the cleaned N-terminal window of ``sequence``.

        Whitespace is removed and letters are upper-cased first, so a
        lowercase or line-wrapped sequence is not turned into a run of "X".
        Remaining characters outside the accepted alphabet become "X".  The
        result is truncated to ``n_length`` characters, or right-padded with
        "X" when the sequence is shorter.

        Raises:
            TypeError: If ``sequence`` is not a string.
        """
        if not isinstance(sequence, str):
            raise TypeError(
                f"sequence must be str, got {type(sequence).__name__}"
            )
        cleaned = "".join(sequence.split()).upper()
        cleaned = cls._INVALID_RESIDUE_RE.sub("X", cleaned)
        return cleaned[:n_length].ljust(n_length, "X")

    @staticmethod
    def _confidence_label(probability: float) -> str:
        """Map an ensemble probability to "High", "Medium" or "Low".

        The label reflects the distance from the 0.5 decision boundary:
        outside [0.2, 0.8] is "High", outside [0.3, 0.7] is "Medium",
        anything else is "Low".
        """
        if probability > 0.8 or probability < 0.2:
            return "High"
        if probability > 0.7 or probability < 0.3:
            return "Medium"
        return "Low"

    @classmethod
    def _summarize(cls, probabilities: List[float], sequence_length: int) -> Dict[str, Any]:
        """Fold per-model probabilities for one sequence into a result dict.

        An empty ``probabilities`` list (no model produced a score) yields
        probability 0.0, a negative call and confidence "Error".
        """
        if probabilities:
            ensemble_prob = float(np.mean(probabilities))
            ensemble_pred = ensemble_prob > 0.5
            confidence = cls._confidence_label(ensemble_prob)
        else:
            ensemble_prob = 0.0
            ensemble_pred = False
            confidence = "Error"

        return {
            "has_signal_peptide": bool(ensemble_pred),
            "probability": ensemble_prob,
            "confidence": confidence,
            "sequence_length": sequence_length,
        }

    def _extract_features(self, sequences: List[str]) -> np.ndarray:
        """Extract one mean-pooled ProtBERT embedding per sequence.

        Args:
            sequences: Raw amino-acid sequences.

        Returns:
            Array with one row per input sequence (an empty array when
            ``sequences`` is empty).
        """
        n_length = self.config.get('n_terminal_length', 50)

        # Prepare N-terminal sequences
        n_terminal_seqs = [
            self._prepare_n_terminal(seq, n_length) for seq in sequences
        ]

        all_embeddings = []
        step = self._EMBED_BATCH_SIZE

        for start in range(0, len(n_terminal_seqs), step):
            batch = n_terminal_seqs[start:start + step]

            # ProtBERT's vocabulary is per-residue, so residues must be
            # separated by spaces before tokenization.
            prepped = [" ".join(seq) for seq in batch]

            inputs = self.protbert_tokenizer(
                prepped,
                add_special_tokens=True,
                padding="longest",
                truncation=True,
                max_length=512,
                return_tensors="pt",
            )

            input_ids = inputs['input_ids'].to(self.device)
            attention_mask = inputs['attention_mask'].to(self.device)

            # Inference only: skip autograd bookkeeping.
            with torch.no_grad():
                outputs = self.protbert_model(
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                )
            hidden = outputs.last_hidden_state

            for j, seq in enumerate(batch):
                # max_length=512 minus the two special tokens leaves at most
                # 510 residue positions.
                seq_length = min(len(seq), 510)
                if seq_length > 0:
                    # Skip the leading special token and average the
                    # residue positions.
                    protein_emb = hidden[j, 1:seq_length + 1].mean(dim=0)
                else:
                    # No residue positions: fall back to the first token.
                    protein_emb = hidden[j, 0]
                all_embeddings.append(protein_emb.cpu().numpy())

        return np.array(all_embeddings)

    def _predict_from_features(self, X: np.ndarray, sequences: List[str]) -> List[Dict[str, Any]]:
        """Score a feature matrix with the ensemble, one result per row.

        Args:
            X: Feature matrix with one row per entry of ``sequences``.
            sequences: The raw input sequences (used for ``sequence_length``).

        Returns:
            Result dicts in the same order as ``sequences``.
        """
        import warnings

        if self.scaler is not None:
            X = self.scaler.transform(X)

        per_model_probs = []
        for model_name, model in self.models.items():
            try:
                # Column 1 is taken as the score of the positive class.
                positive = np.asarray(model.predict_proba(X))[:, 1]
            except Exception as exc:
                # Best effort: one broken or incompatible classifier must not
                # take down the whole ensemble, but the skip is reported.
                warnings.warn(
                    f"SignalSeeker model {model_name!r} skipped: {exc}",
                    RuntimeWarning,
                    stacklevel=2,
                )
                continue
            per_model_probs.append(positive)

        return [
            self._summarize([probs[row] for probs in per_model_probs], len(seq))
            for row, seq in enumerate(sequences)
        ]

    def predict(self, sequence: str) -> Dict[str, Any]:
        """Predict signal peptide for a single sequence.

        Returns:
            Dict with keys ``has_signal_peptide`` (bool), ``probability``
            (float), ``confidence`` ("High" / "Medium" / "Low", or "Error"
            when no model produced a score) and ``sequence_length`` (length
            of the input string as given).
        """
        X = self._extract_features([sequence])
        return self._predict_from_features(X, [sequence])[0]

    def predict_batch(self, sequences: Union[List[str], Dict[str, str]]) -> Dict[str, Dict[str, Any]]:
        """Predict signal peptides for multiple sequences.

        Args:
            sequences: Either a list of sequences (results are keyed
                ``seq_0``, ``seq_1``, ...) or a mapping of name -> sequence.

        Returns:
            Mapping of name -> result dict (see :meth:`predict`); empty when
            no sequences are given.
        """
        if isinstance(sequences, (list, tuple)):
            sequences = {f"seq_{i}": seq for i, seq in enumerate(sequences)}

        if not sequences:
            return {}

        names = list(sequences)
        seqs = [sequences[name] for name in names]

        # Embed everything in one call so ProtBERT actually runs on batches
        # instead of one forward pass per sequence.
        X = self._extract_features(seqs)
        return dict(zip(names, self._predict_from_features(X, seqs)))
191
+
192
# Predictors already built by pipeline(), keyed by absolute model directory.
# Constructing a predictor loads ProtBERT and every pickled classifier, so it
# is done once per directory rather than once per call.
_PREDICTOR_CACHE: Dict[str, Any] = {}


def pipeline(sequences: Union[str, List[str], Dict[str, str]], **kwargs) -> Union[Dict[str, Any], Dict[str, Dict[str, Any]]]:
    """Hugging Face pipeline-style entry point.

    Args:
        sequences: A single sequence string, a list of sequences, or a
            mapping of name -> sequence.
        **kwargs: ``model_dir`` selects the directory holding the model
            files (defaults to the current directory).  Other keys are
            ignored.

    Returns:
        One result dict for a single string; otherwise a mapping of
        name -> result dict as produced by
        ``SignalSeekerPredictor.predict_batch``.
    """
    # This would be automatically set by HF
    model_dir = kwargs.get('model_dir', '.')

    # Resolve to an absolute path so '.' and its full path share one entry.
    cache_key = os.path.abspath(model_dir)
    predictor = _PREDICTOR_CACHE.get(cache_key)
    if predictor is None:
        predictor = SignalSeekerPredictor(model_dir)
        _PREDICTOR_CACHE[cache_key] = predictor

    if isinstance(sequences, str):
        return predictor.predict(sequences)
    return predictor.predict_batch(sequences)