import os
import json
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Dict, List, Optional, Union
from transformers import (
    AutoConfig, AutoTokenizer, AutoModelForCausalLM,
    PretrainedConfig, PreTrainedModel, GenerationMixin
)
from transformers.models.auto import CONFIG_MAPPING, MODEL_FOR_CAUSAL_LM_MAPPING
from peft import PeftModel, LoraConfig, get_peft_model

EXPERTS_LIST = [
    "0",
    "1",
    "2",
    "3",
    "4",
    "5",
    "6",
    "7",
    "8",
]


class MoLAConfig(PretrainedConfig):
    """Configuration class for MoLA-LM model."""
    
    model_type = "mola_lm"
    
    def __init__(
        self,
        base_model_name_or_path: str = "Qwen/Qwen3-4B-Thinking-2507",
        task_labels: List[str] = None,
        router_config: Dict = None,
        lora_configs: Dict[str, Dict] = None,
        **kwargs
    ):
        super().__init__(**kwargs)
        self.base_model_name_or_path = base_model_name_or_path
        self.task_labels = task_labels or EXPERTS_LIST
        self.router_config = router_config or {}
        self.lora_configs = lora_configs or {}
        self.num_loras = len(self.task_labels)
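
# Example (illustrative): MoLAConfig round-trips like any transformers config.
# The task labels below are placeholders; a trained checkpoint ships its own
# label set in config.json.
#
#   cfg = MoLAConfig(
#       base_model_name_or_path="Qwen/Qwen3-4B-Thinking-2507",
#       task_labels=["0", "1", "2"],
#   )
#   cfg.save_pretrained("./mola_config")
#   cfg = MoLAConfig.from_pretrained("./mola_config")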


class MoLAForCausalLM(PreTrainedModel, GenerationMixin):
    """
    MoLA Language Model for Causal Language Modeling - AutoModel Compatible
    """
    
    config_class = MoLAConfig
    base_model_prefix = "mola_model"  # Avoid recursion by using unique prefix
    supports_gradient_checkpointing = True
    
    def __init__(self, config):
        super().__init__(config)
        self.config = config
        
        # Store model path for loading resources  
        self.model_path = getattr(config, '_name_or_path', None)
        
        # Load base model (use base_model_prefix name)
        print(f"Loading base model: {self.config.base_model_name_or_path}")
        self.mola_model = AutoModelForCausalLM.from_pretrained(
            self.config.base_model_name_or_path,
            torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
            device_map="auto" if torch.cuda.is_available() else None
        )
        
        # Load tokenizer
        if self.model_path:
            self.tokenizer = AutoTokenizer.from_pretrained(self.model_path)
        else:
            self.tokenizer = AutoTokenizer.from_pretrained(self.config.base_model_name_or_path)
        
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token
        
        # Initialize router
        self._init_router()
        
        # Initialize current model state (will be updated by _load_lora_adapters)
        self._current_lora = None
        self._current_adapted_model = self.mola_model
        
        # Load LoRA configurations and adapters (this will update _current_adapted_model)
        self._load_lora_adapters()
        
        # Initialize device property (needed for PreTrainedModel compatibility)
        self._device = next(self.mola_model.parameters()).device
        
        # Load router weights if available
        self._load_router_weights()
        
        print("MoLA-LM initialized successfully!")
    
    def _load_router_weights(self):
        """Load router weights from the saved checkpoint."""
        if self.model_path:
            try:
                # Handle both local and Hub paths for router weights
                if os.path.exists(self.model_path):
                    # Local path
                    router_weights_path = os.path.join(self.model_path, "router_weights.pth")
                    if os.path.exists(router_weights_path):
                        checkpoint = torch.load(router_weights_path, map_location='cpu')
                    else:
                        print("⚠️ No router weights found locally")
                        return
                else:
                    # Hub path - download router weights
                    try:
                        from huggingface_hub import hf_hub_download
                        router_weights_path = hf_hub_download(
                            repo_id=self.model_path,
                            filename="router_weights.pth",
                            local_files_only=False
                        )
                        checkpoint = torch.load(router_weights_path, map_location='cpu')
                        print("📥 Downloaded router weights from Hub")
                    except Exception as hub_e:
                        print(f"⚠️ Failed to download router weights from Hub: {hub_e}")
                        print("🔄 Router will use random initialization (reduced performance)")
                        return
                
                # Load router decoder weights
                router_state_dict = {}
                for key, value in checkpoint.items():
                    if not key.startswith('encoder.'):  # Skip encoder weights
                        router_state_dict[key] = value
                
                if router_state_dict:
                    self.router_decoder.load_state_dict(router_state_dict, strict=False)
                    print("✅ Loaded router weights successfully!")
                    
                    # Verify weights loaded by checking if they're not all zeros
                    first_layer = next(iter(self.router_decoder.parameters()))
                    if torch.all(first_layer == 0):
                        print("⚠️ Warning: Router weights appear to be zero-initialized")
                    else:
                        print("🎯 Router weights verified - non-zero values detected")
                else:
                    print("⚠️ No valid router weights found in checkpoint")
                    
            except Exception as e:
                print(f"❌ Failed to load router weights: {e}")
                print("🔄 Router will use random initialization (reduced performance)")
    
    def _init_router(self):
        """Initialize the router model for LoRA selection."""
        try:
            from transformers import AutoModel
            
            print("Initializing router components...")
            # Router components
            self.router_tokenizer = AutoTokenizer.from_pretrained("sentence-transformers/all-MiniLM-L6-v2")
            self.router_encoder = AutoModel.from_pretrained("sentence-transformers/all-MiniLM-L6-v2")
            
            # Freeze encoder
            for param in self.router_encoder.parameters():
                param.requires_grad = False
            
            # Router decoder
            encoder_dim = self.router_encoder.config.hidden_size
            self.router_decoder = nn.Sequential(
                nn.Linear(encoder_dim, 768),
                nn.ReLU(),
                nn.Dropout(.3),
                nn.Linear(768, 768),
                nn.ReLU(),
                nn.Dropout(.3),
                nn.Linear(768, 768),
                nn.ReLU(),
                nn.Dropout(.3),
                nn.Linear(768, 384),
                nn.ReLU(),
                nn.Dropout(.3),
                nn.Linear(384, 192),
                nn.ReLU(),
                nn.Dropout(.3),
                nn.Linear(192, self.config.num_loras)
            )
            
            # Move router to device
            if torch.cuda.is_available():
                self.router_encoder = self.router_encoder.cuda()
                self.router_decoder = self.router_decoder.cuda()
            
            print("Router initialized successfully!")
            
        except ImportError as e:
            raise ImportError(f"Required dependencies not found: {e}")
    
    def _load_lora_adapters(self):
        """Load LoRA adapters using PEFT (single wrapper, multiple adapters)."""
        from huggingface_hub import hf_hub_download
        
        if not self.model_path:
            print("No model path specified, skipping LoRA loading")
            self.lora_models = {}
            return
            
        print("Loading LoRA adapters (single wrapper)...")
        
        # Get the first adapter to create the initial PEFT wrapper
        first_adapter = str(self.config.task_labels[0])
        first_lora_path = None
        
        try:
            # Handle both local and Hub paths for first adapter
            if os.path.exists(self.model_path):
                # Local path
                first_lora_path = os.path.join(self.model_path, "loras", first_adapter)
                if not os.path.exists(first_lora_path):
                    raise FileNotFoundError(f"First adapter directory not found: {first_lora_path}")
            else:
                # Hub path - download first adapter
                try:
                    # Download both required files for first adapter
                    adapter_weights_file = hf_hub_download(
                        repo_id=self.model_path, 
                        filename=f"loras/{first_adapter}/adapter_model.safetensors"
                    )
                    adapter_config_file = hf_hub_download(
                        repo_id=self.model_path, 
                        filename=f"loras/{first_adapter}/adapter_config.json"
                    )
                    first_lora_path = os.path.dirname(adapter_weights_file)
                    print(f"Downloaded first adapter to: {first_lora_path}")
                except Exception as e:
                    raise Exception(f"Failed to download first adapter {first_adapter}: {e}")
            
            # Create the initial PEFT wrapper WITHOUT specifying adapter_name to use default
            peft_model = PeftModel.from_pretrained(
                self.mola_model, 
                first_lora_path,
                torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32
            )
            print(f"✅ Loaded first LoRA: {first_adapter} (as default)")
            
            # Load remaining adapters into the same wrapper with unique names
            for task_name in self.config.task_labels[1:]:
                try:
                    lora_path = None
                    
                    if os.path.exists(self.model_path):
                        # Local path
                        lora_path = os.path.join(self.model_path, "loras", task_name)
                        if not os.path.exists(lora_path):
                            print(f"⚠️ LoRA directory not found: {lora_path}")
                            continue
                    else:
                        # Hub path - download adapter
                        try:
                            adapter_weights_file = hf_hub_download(
                                repo_id=self.model_path, 
                                filename=f"loras/{task_name}/adapter_model.safetensors"
                            )
                            adapter_config_file = hf_hub_download(
                                repo_id=self.model_path, 
                                filename=f"loras/{task_name}/adapter_config.json"
                            )
                            lora_path = os.path.dirname(adapter_weights_file)
                        except Exception as e:
                            print(f"❌ Failed to download LoRA {task_name}: {e}")
                            continue
                    
                    # Load adapter into the same PEFT model with unique name
                    peft_model.load_adapter(lora_path, adapter_name=task_name)
                    print(f"✅ Loaded LoRA: {task_name}")
                    
                except Exception as e:
                    print(f"❌ Failed to load LoRA {task_name}: {e}")
            
            # Store single PEFT model for all adapters
            self.lora_models = {str(name): peft_model for name in self.config.task_labels}
            self._current_lora = first_adapter
            self._current_adapted_model = peft_model
            
            print(f"Loaded {len(self.config.task_labels)} LoRA adapters into one PEFT model.")
            print(f"Available adapter names: {list(peft_model.peft_config.keys())}")
            
        except Exception as e:
            print(f"❌ Failed to initialize LoRA loading: {e}")
            self.lora_models = {}
            self._current_adapted_model = self.mola_model
            self._current_lora = None
    
    def predict_best_lora(self, text: str) -> str:
        """Predict the best LoRA adapter for given text."""
        # Set models to eval mode
        self.router_encoder.eval()
        self.router_decoder.eval()
        
        # Encode text
        inputs = self.router_tokenizer(
            [text],
            padding=True,
            truncation=True,
            max_length=512,
            return_tensors="pt"
        )
        
        # Move to device
        device = next(self.router_decoder.parameters()).device
        inputs = {k: v.to(device) for k, v in inputs.items()}
        
        with torch.no_grad():
            outputs = self.router_encoder(**inputs)
            embeddings = outputs.last_hidden_state.mean(dim=1)
            logits = self.router_decoder(embeddings)
        
        # Get best LoRA
        best_idx = torch.argmax(logits, dim=-1).item()
        predicted_label = self.config.task_labels[best_idx]
        
        # Debug output
        # print(f"Debug - Text: {text[:50]}...")
        # print(f"Debug - Logits: {logits[0].cpu().numpy()}")
        # print(f"Debug - Best idx: {best_idx}, Label: {predicted_label}")
        
        return predicted_label
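
    # Example (illustrative): the router scores the prompt against every label
    # in config.task_labels and returns the argmax, e.g.
    #   model.predict_best_lora("Solve x**2 - 4 = 0 for x")  ->  "3"
    # The label shown here is hypothetical; the mapping depends on how the
    # router was trained.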
    
    def _apply_lora(self, lora_name: str):
        """Apply the selected LoRA adapter using set_adapter."""
        if hasattr(self, '_current_adapted_model') and isinstance(self._current_adapted_model, PeftModel):
            # Map task labels to actual adapter names in PEFT model
            # First adapter (task_labels[0]) is loaded as 'default', others keep their names
            first_adapter = str(self.config.task_labels[0])
            if str(lora_name) == first_adapter:
                actual_adapter_name = "default"
            else:
                actual_adapter_name = str(lora_name)
            
            # Check if the adapter exists in the PEFT model
            if actual_adapter_name in self._current_adapted_model.peft_config:
                if lora_name != self._current_lora:
                    self._current_adapted_model.set_adapter(actual_adapter_name)
                    self._current_lora = str(lora_name)
                    # print(f"🎯 Applied LoRA: {lora_name} (as {actual_adapter_name})")  # Uncomment for debugging
            else:
                print(f"⚠️ LoRA adapter '{lora_name}' (mapped to '{actual_adapter_name}') not found in PEFT model. Available: {list(self._current_adapted_model.peft_config.keys())}")
                # Keep current adapter if requested one doesn't exist
        else:
            # Fallback to base model if no PEFT model available
            self._current_adapted_model = self.mola_model
            self._current_lora = None
            print(f"⚠️ No PEFT model available, using base model")
    
    def get_available_loras(self) -> List[str]:
        """Get list of available LoRA adapter names."""
        if hasattr(self, '_current_adapted_model') and isinstance(self._current_adapted_model, PeftModel):
            return list(self._current_adapted_model.peft_config.keys())
        else:
            return []
    
    def test_adapter_uniqueness(self, layer_name: str = "base_model.model.model.layers.33.mlp.down_proj"):
        """
        Regression test to verify that adapters have different weights.
        
        Args:
            layer_name: The layer to test (default is a common MLP layer)
        
        Returns:
            Dict[str, str]: Mapping of adapter names to their weight hashes
        """
        import hashlib
        
        if not hasattr(self, '_current_adapted_model') or not isinstance(self._current_adapted_model, PeftModel):
            print("⚠️ No PEFT model available for testing")
            return {}
        
        names = self.get_available_loras()
        if len(names) <= 1:
            print(f"⚠️ Need at least 2 adapters for uniqueness test, found {len(names)}")
            return {}
        
        def fused_sha(adapter_name, layer_name):
            """Compute SHA256 hash of fused LoRA weights for given adapter and layer."""
            # Switch to the adapter
            self._apply_lora(adapter_name)
            
            # Navigate to the specified layer
            try:
                mod = self._current_adapted_model
                for part in layer_name.split("."):
                    if part:
                        mod = getattr(mod, part)
                
                # Get LoRA components
                if not hasattr(mod, 'lora_A') or not hasattr(mod, 'lora_B'):
                    print(f"⚠️ Layer {layer_name} doesn't have LoRA components")
                    return "no_lora"
                
                # Get the currently active adapter key from the PEFT model
                active_adapter_key = self._current_adapted_model.active_adapter
                
                # Check if the active adapter key exists
                if active_adapter_key not in mod.lora_A:
                    print(f"⚠️ Active adapter key '{active_adapter_key}' not found. Available: {list(mod.lora_A.keys())}")
                    return f"missing_{adapter_name}"
                
                A = mod.lora_A[active_adapter_key].weight
                B = mod.lora_B[active_adapter_key].weight
                s = float(mod.scaling[active_adapter_key])
                
                # Compute fused weights: ΔW = (B @ A) * scaling
                dW = (B @ A) * s
                
                # Convert to bytes and hash
                tensor_bytes = dW.detach().to("cpu", dtype=torch.float32).contiguous().numpy().tobytes()
                return hashlib.sha256(tensor_bytes).hexdigest()[:16]
                
            except Exception as e:
                print(f"❌ Error computing hash for {adapter_name}: {e}")
                return f"error_{adapter_name}"
        
        print(f"🧪 Testing adapter uniqueness on layer: {layer_name}")
        hashes = {}
        for adapter_name in names:
            hash_val = fused_sha(adapter_name, layer_name)
            hashes[adapter_name] = hash_val
            print(f"  {adapter_name}: {hash_val}")
        
        # Check uniqueness
        unique_hashes = set(hashes.values())
        if len(unique_hashes) == len(names):
            print("✅ All adapters have unique weights!")
        else:
            print(f"❌ Found duplicate weights! {len(names)} adapters but only {len(unique_hashes)} unique hashes")
            # Show which ones are identical
            from collections import defaultdict
            hash_to_adapters = defaultdict(list)
            for adapter, hash_val in hashes.items():
                hash_to_adapters[hash_val].append(adapter)
            
            for hash_val, adapter_list in hash_to_adapters.items():
                if len(adapter_list) > 1:
                    print(f"  Identical weights (hash {hash_val}): {adapter_list}")
        
        return hashes
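
    # Example (illustrative):
    #   hashes = model.test_adapter_uniqueness()
    #   # -> {"default": "a1b2c3d4e5f6a7b8", "1": "9f8e7d6c5b4a3f2e", ...}
    # The hash values above are made up; only their pairwise inequality matters.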
    
    def generate(self, input_ids=None, attention_mask=None, **kwargs):
        """
        Standard generate method with automatic LoRA selection.
        Works exactly like any other LLM's generate method.
        """
        # If we have input_ids, predict and apply the best LoRA
        if input_ids is not None and hasattr(self, 'tokenizer'):
            try:
                # Decode the input to get the text for LoRA prediction
                if len(input_ids.shape) > 1:
                    # Batch input - use first item
                    text_input = self.tokenizer.decode(input_ids[0], skip_special_tokens=True)
                else:
                    text_input = self.tokenizer.decode(input_ids, skip_special_tokens=True)
                
                # Clean the text thoroughly to remove ALL chat template artifacts
                import re
                
                # Chat template markers used to locate the user's actual prompt
                start_pattern = '<|im_start|>user'
                end_pattern = '<|im_end|>'
                
                # First, try to extract just the user's actual question/prompt
                if start_pattern in text_input and end_pattern in text_input:
                    start_idx = text_input.find(start_pattern) + len(start_pattern)
                    end_idx = text_input.find(end_pattern, start_idx)
                    if end_idx > start_idx:
                        text_input = text_input[start_idx:end_idx].strip()
                
                # Clean up any remaining template artifacts
                # Remove special tokens with simple string replacement
                text_input = text_input.replace('<|im_start|>', '')
                text_input = text_input.replace('<|im_end|>', '')
                text_input = text_input.replace('system', '')
                text_input = text_input.replace('user', '')  
                text_input = text_input.replace('assistant', '')
                
                # Remove system message patterns
                if 'You are Qwen' in text_input:
                    lines = text_input.split('\n')
                    lines = [line for line in lines if 'You are' not in line and 'Alibaba' not in line]
                    text_input = ' '.join(lines)
                
                # Final cleanup
                text_input = re.sub(r'\n+', ' ', text_input)  # Replace newlines with spaces
                text_input = re.sub(r'\s+', ' ', text_input)  # Normalize whitespace  
                text_input = text_input.strip()
                
                # Debug: print the actual text being classified
                # print(f"DEBUG RAW: '{self.tokenizer.decode(input_ids[0], skip_special_tokens=False)}'")
                # print(f"DEBUG CLEAN: '{text_input}'")
                
                # Predict and apply best LoRA
                best_lora = self.predict_best_lora(text_input)
                self._apply_lora(best_lora)
                
            except Exception as e:
                # If LoRA prediction fails, use base model
                # print(f"DEBUG: LoRA prediction failed: {e}")
                self._current_adapted_model = self.mola_model
                self._current_lora = None
        
        # Use the currently adapted model for generation
        return self._current_adapted_model.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            **kwargs
        )
    
    def forward(self, input_ids, attention_mask=None, **kwargs):
        """Forward pass through the model."""
        return self._current_adapted_model(input_ids=input_ids, attention_mask=attention_mask, **kwargs)
    
    def __call__(self, *args, **kwargs):
        """Make the model callable."""
        return self._current_adapted_model(*args, **kwargs)
    
    def get_input_embeddings(self):
        """Get the input embeddings."""
        return self._current_adapted_model.get_input_embeddings()
    
    def set_input_embeddings(self, value):
        """Set the input embeddings."""
        self._current_adapted_model.set_input_embeddings(value)
        # Also set for base model to keep them in sync
        self.mola_model.set_input_embeddings(value)
    
    def get_output_embeddings(self):
        """Get the output embeddings.""" 
        return self._current_adapted_model.get_output_embeddings()
    
    def set_output_embeddings(self, value):
        """Set the output embeddings."""
        self._current_adapted_model.set_output_embeddings(value)
        # Also set for base model to keep them in sync
        self.mola_model.set_output_embeddings(value)
    
    def tie_weights(self):
        """Tie input and output embeddings."""
        self._current_adapted_model.tie_weights()
        
    def resize_token_embeddings(self, new_num_tokens):
        """Resize token embeddings."""
        return self._current_adapted_model.resize_token_embeddings(new_num_tokens)
    
    @property
    def device(self):
        """Get the device of the model."""
        return next(self.mola_model.parameters()).device
    
    @classmethod  
    def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
        """Load model from pretrained path (transformers compatibility)."""
        # Load config
        config = MoLAConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
        # Store the path for resource loading
        config._name_or_path = pretrained_model_name_or_path
        return cls(config)
    
    def save_pretrained(self, save_directory, **kwargs):
        """Save model using standard transformers approach."""
        # Accept standard transformers parameters but use the ones we need
        max_shard_size = kwargs.get('max_shard_size', "5GB")
        safe_serialization = kwargs.get('safe_serialization', True)
        
        os.makedirs(save_directory, exist_ok=True)
        
        # Save config using transformers method
        self.config.save_pretrained(save_directory)
        
        # Save tokenizer if available
        if hasattr(self, 'tokenizer'):
            self.tokenizer.save_pretrained(save_directory)
        
        # Save the base model with proper sharding if needed
        try:
            # Use the base model's save_pretrained with the parameters
            self.mola_model.save_pretrained(
                save_directory, 
                max_shard_size=max_shard_size,
                safe_serialization=safe_serialization
            )
        except Exception as e:
            print(f"Warning: Could not save base model weights: {e}")
            # Fallback: just save the config and tokenizer
            pass
        
        # Save router weights if they exist
        try:
            if hasattr(self, 'router_decoder'):
                router_state_dict = self.router_decoder.state_dict()
                torch.save(router_state_dict, os.path.join(save_directory, "router_weights.pth"))
        except Exception as e:
            print(f"Warning: Could not save router weights: {e}")
        
        print(f"Model saved to {save_directory}")
    
    def get_current_lora(self) -> str:
        """Get the currently applied LoRA adapter name."""
        return self._current_lora or "base_model"
    


# For transformers AutoModel registration
def _load_mola_model(model_path, **kwargs):
    """Helper function to load MoLA model."""
    return MoLAForCausalLM.from_pretrained(model_path, **kwargs)


# Register with transformers AutoModel system
try:
    CONFIG_MAPPING.register("mola_lm", MoLAConfig)
    MODEL_FOR_CAUSAL_LM_MAPPING.register(MoLAConfig, MoLAForCausalLM)
    print("✅ Successfully registered MoLA-LM with AutoModel!")
except Exception as e:
    print(f"⚠️ AutoModel registration failed: {e}")
    # Try alternative registration for backwards compatibility
    try:
        from transformers import AutoConfig, AutoModelForCausalLM
        AutoConfig.register("mola_lm", MoLAConfig)
        AutoModelForCausalLM.register(MoLAConfig, MoLAForCausalLM)
        print("✅ Successfully registered MoLA-LM with legacy method!")
    except Exception as e2:
        print(f"⚠️ Legacy registration also failed: {e2}")
        print("Model can still be loaded directly with MoLAForCausalLM.from_pretrained()")