File size: 7,091 Bytes
89aeb83
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
"""
SignalSeeker: Protein Signal Peptide Prediction
"""

import json
import os
import re
import warnings
from collections.abc import Mapping
from typing import Any, Dict, List, Union

import joblib
import numpy as np
import torch
from transformers import BertModel, BertTokenizer

class SignalSeekerPredictor:
    """SignalSeeker protein signal peptide predictor.

    Loads the artefacts found in ``model_dir`` -- an optional ``config.json``,
    an optional ``scaler.pkl`` and every ``models/*.pkl`` ensemble member --
    plus the ``Rostlab/prot_bert`` encoder. It scores the N-terminal window of
    each protein sequence by averaging the members' positive-class
    probabilities over mean-pooled ProtBERT embeddings.

    SECURITY NOTE: the ``.pkl`` artefacts are loaded with ``joblib.load``,
    which unpickles arbitrary objects. Only point ``model_dir`` at trusted
    directories.
    """

    # Anything outside the 20 standard residues plus "X" is mapped to "X"
    # before tokenisation.
    _NON_STANDARD_RESIDUES = re.compile(r"[^ACDEFGHIKLMNPQRSTVWXY]")

    def __init__(self, model_dir: str):
        """Create a predictor and eagerly load all artefacts from ``model_dir``."""
        self.model_dir = model_dir
        self.models = {}
        self.scaler = None
        self.protbert_model = None
        self.protbert_tokenizer = None
        self.device = None
        self.config = {}

        self._load_models()

    @classmethod
    def from_pretrained(cls, model_name_or_path: str):
        """Alternate constructor mirroring the Hugging Face naming convention.

        ``model_name_or_path`` is used directly as a local directory path. No
        download happens here, so a bare Hub repository id only works if it
        also resolves to a local directory.
        """
        return cls(model_name_or_path)

    def _load_models(self):
        """Load config, scaler, ensemble members and ProtBERT from ``model_dir``.

        A missing ``config.json``, ``scaler.pkl`` or ``models/`` is tolerated
        and defaults are used instead. A warning is emitted when no ensemble
        model is found, because every prediction would then be reported with
        confidence "Error".
        """
        config_path = os.path.join(self.model_dir, "config.json")
        if os.path.exists(config_path):
            with open(config_path, "r", encoding="utf-8") as f:
                self.config = json.load(f)

        # joblib.load unpickles arbitrary objects: trusted directories only.
        scaler_path = os.path.join(self.model_dir, "scaler.pkl")
        if os.path.exists(scaler_path):
            self.scaler = joblib.load(scaler_path)

        models_dir = os.path.join(self.model_dir, "models")
        if os.path.isdir(models_dir):
            # Sorted so the ensemble order, and therefore the float summation
            # order of the averaged probability, is deterministic.
            for model_file in sorted(os.listdir(models_dir)):
                if model_file.endswith(".pkl"):
                    # Strip only the suffix, then turn underscores into spaces.
                    model_name = model_file[: -len(".pkl")].replace("_", " ")
                    model_path = os.path.join(models_dir, model_file)
                    self.models[model_name] = joblib.load(model_path)

        if not self.models:
            warnings.warn(
                f"No ensemble models found under {models_dir!r}; "
                "all predictions will report confidence 'Error'.",
                RuntimeWarning,
            )

        self._load_protbert()

    def _load_protbert(self):
        """Load the ProtBERT tokenizer and encoder, moving the encoder to CUDA if available.

        Raises:
            RuntimeError: if either component cannot be loaded. The original
                exception is chained as ``__cause__``.
        """
        try:
            self.protbert_tokenizer = BertTokenizer.from_pretrained(
                "Rostlab/prot_bert", do_lower_case=False
            )
            self.protbert_model = BertModel.from_pretrained("Rostlab/prot_bert")
            self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
            self.protbert_model = self.protbert_model.to(self.device)
            self.protbert_model.eval()
        except Exception as e:
            raise RuntimeError(f"Failed to load ProtBERT: {e}") from e

    @classmethod
    def _n_terminal_window(cls, seq: str, n_length: int) -> str:
        """Return the first ``n_length`` residues of ``seq``, X-padded when shorter.

        Whitespace (e.g. FASTA line wraps) is removed and the sequence is
        upper-cased before unknown residues are mapped to "X". Otherwise a
        lower-case sequence would be encoded as all "X".
        """
        clean = cls._NON_STANDARD_RESIDUES.sub("X", "".join(seq.split()).upper())
        return clean[:n_length].ljust(n_length, "X")

    def _extract_features(self, sequences: List[str], batch_size: int = 8) -> np.ndarray:
        """Return one mean-pooled ProtBERT embedding per sequence.

        Each sequence is reduced to its first ``config["n_terminal_length"]``
        residues (default 50, X-padded when shorter). The windows are encoded
        ``batch_size`` at a time. Row ``i`` of the result corresponds to
        ``sequences[i]``.
        """
        n_length = self.config.get("n_terminal_length", 50)
        n_terminal_seqs = [self._n_terminal_window(seq, n_length) for seq in sequences]

        all_embeddings = []
        for start in range(0, len(n_terminal_seqs), batch_size):
            batch = n_terminal_seqs[start:start + batch_size]

            # ProtBERT tokenises per residue, so residues are space-separated.
            prepped = [" ".join(seq) for seq in batch]
            inputs = self.protbert_tokenizer.batch_encode_plus(
                prepped,
                add_special_tokens=True,
                padding="longest",
                truncation=True,
                max_length=512,
                return_tensors="pt",
            )
            input_ids = inputs["input_ids"].to(self.device)
            attention_mask = inputs["attention_mask"].to(self.device)

            with torch.no_grad():
                outputs = self.protbert_model(
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                )

            for j, seq in enumerate(batch):
                # max_length=512 minus the two special tokens leaves 510 residues.
                seq_length = min(len(seq), 510)
                if seq_length > 0:
                    # Skip position 0 (the leading special token) and average
                    # over the residue positions.
                    seq_emb = outputs.last_hidden_state[j, 1:seq_length + 1]
                    protein_emb = seq_emb.mean(dim=0).cpu().numpy()
                else:
                    # Empty window: fall back to the position-0 embedding.
                    protein_emb = outputs.last_hidden_state[j, 0].cpu().numpy()
                all_embeddings.append(protein_emb)

        return np.array(all_embeddings)

    def _member_probabilities(self, X: np.ndarray) -> List[np.ndarray]:
        """Return, per ensemble member, the column-1 ``predict_proba`` values for every row of ``X``.

        A member whose ``predict_proba`` raises is skipped with a
        ``RuntimeWarning`` instead of being dropped silently. The ensemble is
        best-effort, so one broken member does not sink the others.
        """
        member_probs = []
        for model_name, model in self.models.items():
            try:
                member_probs.append(np.asarray(model.predict_proba(X))[:, 1])
            except Exception as exc:
                warnings.warn(
                    f"Ensemble model {model_name!r} failed and was skipped: {exc}",
                    RuntimeWarning,
                )
        return member_probs

    @staticmethod
    def _confidence(probability: float) -> str:
        """Bucket an ensemble probability by its distance from the 0.5 decision boundary.

        Returns "High" outside [0.2, 0.8], "Medium" outside [0.3, 0.7], and
        "Low" otherwise.
        """
        if probability > 0.8 or probability < 0.2:
            return "High"
        if probability > 0.7 or probability < 0.3:
            return "Medium"
        return "Low"

    def _results_from_features(self, X: np.ndarray, sequences: List[str]) -> List[Dict[str, Any]]:
        """Turn a feature matrix (one row per sequence) into result dicts.

        Applies the scaler when one was loaded, then averages the ensemble
        members' probabilities row by row.
        """
        if self.scaler is not None:
            X = self.scaler.transform(X)

        member_probs = self._member_probabilities(X)

        results = []
        for row, sequence in enumerate(sequences):
            if member_probs:
                probability = float(np.mean([probs[row] for probs in member_probs]))
                results.append({
                    "has_signal_peptide": bool(probability > 0.5),
                    "probability": probability,
                    "confidence": self._confidence(probability),
                    "sequence_length": len(sequence),
                })
            else:
                # No member produced a probability for this input.
                results.append({
                    "has_signal_peptide": False,
                    "probability": 0.0,
                    "confidence": "Error",
                    "sequence_length": len(sequence),
                })
        return results

    def _predict_sequences(self, sequences: List[str]) -> List[Dict[str, Any]]:
        """Score ``sequences`` with a single batched feature-extraction pass."""
        if not sequences:
            return []
        return self._results_from_features(self._extract_features(sequences), sequences)

    def predict(self, sequence: str) -> Dict[str, Any]:
        """Predict whether ``sequence`` carries a signal peptide.

        Returns:
            A dict with four keys:

            - ``has_signal_peptide``: ensemble probability > 0.5.
            - ``probability``: mean column-1 probability over the members
              that succeeded.
            - ``confidence``: "High", "Medium" or "Low", or "Error" when no
              member produced a probability.
            - ``sequence_length``: ``len(sequence)`` of the raw input.
        """
        return self._predict_sequences([sequence])[0]

    def predict_batch(self, sequences: Union[List[str], Dict[str, str]]) -> Dict[str, Dict[str, Any]]:
        """Predict signal peptides for many sequences in one batched pass.

        Args:
            sequences: a mapping of name -> sequence, or any iterable of
                sequences. Unnamed sequences are keyed ``seq_0``, ``seq_1``
                and so on, in order.

        Returns:
            name -> result dict, in input order. Each result has the same
            shape as returned by ``predict``.

        Raises:
            TypeError: if a single string is passed (use ``predict``).
        """
        if isinstance(sequences, str):
            raise TypeError(
                "predict_batch expects a list or dict of sequences; "
                "use predict() for a single sequence"
            )
        if isinstance(sequences, Mapping):
            names = list(sequences.keys())
            seqs = list(sequences.values())
        else:
            seqs = list(sequences)
            names = [f"seq_{i}" for i in range(len(seqs))]

        return dict(zip(names, self._predict_sequences(seqs)))

def pipeline(sequences: Union[str, List[str], Dict[str, str]], **kwargs) -> Union[Dict[str, Any], Dict[str, Dict[str, Any]]]:
    """Hugging Face-style entry point: score sequences with a fresh predictor.

    Args:
        sequences: a single sequence string, a list of sequences, or a
            mapping of name -> sequence.
        **kwargs: only ``model_dir`` is read. It defaults to the current
            directory ``"."`` and is presumably supplied by the hosting
            environment.

    Returns:
        For a string input, a single result dict from
        ``SignalSeekerPredictor.predict``. Otherwise, the name -> result
        mapping from ``SignalSeekerPredictor.predict_batch``.

    Note:
        Every call constructs a new ``SignalSeekerPredictor``, which reloads
        all of its artefacts. Reuse a predictor directly for repeated calls.
    """
    predictor = SignalSeekerPredictor(kwargs.get("model_dir", "."))

    is_single_sequence = isinstance(sequences, str)
    if not is_single_sequence:
        return predictor.predict_batch(sequences)
    return predictor.predict(sequences)