import os
import json
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Dict, List, Optional, Union
from transformers import (
    AutoConfig,
    AutoTokenizer,
    AutoModelForCausalLM,
    PretrainedConfig,
    PreTrainedModel,
    GenerationMixin
)
from transformers.models.auto import CONFIG_MAPPING, MODEL_FOR_CAUSAL_LM_MAPPING
from peft import PeftModel, LoraConfig, get_peft_model

EXPERTS_LIST = [
    "0",
    "1",
    "2",
    "3",
    "4",
    "5",
    "6",
    "7",
    "8",
]


class MoLAConfig(PretrainedConfig):
    """Configuration class for MoLA-LM model."""

    model_type = "mola_lm"

    def __init__(
        self,
        base_model_name_or_path: str = "Qwen/Qwen3-4B-Thinking-2507",
        task_labels: Optional[List[str]] = None,
        router_config: Optional[Dict] = None,
        lora_configs: Optional[Dict[str, Dict]] = None,
        **kwargs
    ):
        super().__init__(**kwargs)
        self.base_model_name_or_path = base_model_name_or_path
        self.task_labels = task_labels or EXPERTS_LIST
        self.router_config = router_config or {}
        self.lora_configs = lora_configs or {}
        self.num_loras = len(self.task_labels)


class MoLAForCausalLM(PreTrainedModel, GenerationMixin):
    """MoLA Language Model for Causal Language Modeling - AutoModel Compatible"""

    config_class = MoLAConfig
    base_model_prefix = "mola_model"  # Avoid recursion by using unique prefix
    supports_gradient_checkpointing = True

    def __init__(self, config):
        super().__init__(config)
        self.config = config

        # Store model path for loading resources
        self.model_path = getattr(config, '_name_or_path', None)

        # Load base model (use base_model_prefix name)
        print(f"Loading base model: {self.config.base_model_name_or_path}")
        self.mola_model = AutoModelForCausalLM.from_pretrained(
            self.config.base_model_name_or_path,
            torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
            device_map="auto" if torch.cuda.is_available() else None
        )

        # Load tokenizer
        if self.model_path:
            self.tokenizer = AutoTokenizer.from_pretrained(self.model_path)
        else:
            self.tokenizer = AutoTokenizer.from_pretrained(self.config.base_model_name_or_path)
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

        # Initialize router
        self._init_router()

        # Initialize current model state (will be updated by _load_lora_adapters)
        self._current_lora = None
        self._current_adapted_model = self.mola_model

        # Load LoRA configurations and adapters (this will update _current_adapted_model)
        self._load_lora_adapters()

        # Initialize device property (needed for PreTrainedModel compatibility)
        self._device = next(self.mola_model.parameters()).device

        # Load router weights if available
        self._load_router_weights()

        print("MoLA-LM initialized successfully!")

    def _load_router_weights(self):
        """Load router weights from the saved checkpoint."""
        if self.model_path:
            try:
                # Handle both local and Hub paths for router weights
                if os.path.exists(self.model_path):
                    # Local path
                    router_weights_path = os.path.join(self.model_path, "router_weights.pth")
                    if os.path.exists(router_weights_path):
                        checkpoint = torch.load(router_weights_path, map_location='cpu')
                    else:
                        print("⚠️ No router weights found locally")
                        return
                else:
                    # Hub path - download router weights
                    try:
                        from huggingface_hub import hf_hub_download
                        router_weights_path = hf_hub_download(
                            repo_id=self.model_path,
                            filename="router_weights.pth",
                            local_files_only=False
                        )
                        checkpoint = torch.load(router_weights_path, map_location='cpu')
                        print("📥 Downloaded router weights from Hub")
                    except Exception as hub_e:
                        print(f"⚠️ Failed to download router weights from Hub: {hub_e}")
                        print("🔄 Router will use random initialization (reduced performance)")
                        return
performance)") return # Load router decoder weights router_state_dict = {} for key, value in checkpoint.items(): if not key.startswith('encoder.'): # Skip encoder weights router_state_dict[key] = value if router_state_dict: self.router_decoder.load_state_dict(router_state_dict, strict=False) print("โœ… Loaded router weights successfully!") # Verify weights loaded by checking if they're not all zeros first_layer = next(iter(self.router_decoder.parameters())) if torch.all(first_layer == 0): print("โš ๏ธ Warning: Router weights appear to be zero-initialized") else: print("๐ŸŽฏ Router weights verified - non-zero values detected") else: print("โš ๏ธ No valid router weights found in checkpoint") except Exception as e: print(f"โŒ Failed to load router weights: {e}") print("๐Ÿ”„ Router will use random initialization (reduced performance)") def _init_router(self): """Initialize the router model for LoRA selection.""" try: from transformers import AutoModel print("Initializing router components...") # Router components self.router_tokenizer = AutoTokenizer.from_pretrained("sentence-transformers/all-MiniLM-L6-v2") self.router_encoder = AutoModel.from_pretrained("sentence-transformers/all-MiniLM-L6-v2") # Freeze encoder for param in self.router_encoder.parameters(): param.requires_grad = False # Router decoder encoder_dim = self.router_encoder.config.hidden_size self.router_decoder = nn.Sequential( nn.Linear(encoder_dim, 768), nn.ReLU(), nn.Dropout(.3), nn.Linear(768, 768), nn.ReLU(), nn.Dropout(.3), nn.Linear(768, 768), nn.ReLU(), nn.Dropout(.3), nn.Linear(768, 384), nn.ReLU(), nn.Dropout(.3), nn.Linear(384, 192), nn.ReLU(), nn.Dropout(.3), nn.Linear(192, self.config.num_loras) ) # Move router to device if torch.cuda.is_available(): self.router_encoder = self.router_encoder.cuda() self.router_decoder = self.router_decoder.cuda() print("Router initialized successfully!") except ImportError as e: raise ImportError(f"Required dependencies not found: {e}") def _load_lora_adapters(self): """Load LoRA adapters using PEFT (single wrapper, multiple adapters).""" from huggingface_hub import hf_hub_download if not self.model_path: print("No model path specified, skipping LoRA loading") return print("Loading LoRA adapters (single wrapper)...") # Get the first adapter to create the initial PEFT wrapper first_adapter = str(self.config.task_labels[0]) first_lora_path = None try: # Handle both local and Hub paths for first adapter if os.path.exists(self.model_path): # Local path first_lora_path = os.path.join(self.model_path, "loras", first_adapter) if not os.path.exists(first_lora_path): raise FileNotFoundError(f"First adapter directory not found: {first_lora_path}") else: # Hub path - download first adapter try: # Download first adapter to get local path adapter_file = hf_hub_download( repo_id=self.model_path, filename=f"loras/{first_adapter}/adapter_model.safetensors" ) first_lora_path = os.path.dirname(adapter_file) print(f"Downloaded first adapter to: {first_lora_path}") except Exception as e: raise Exception(f"Failed to download first adapter {first_adapter}: {e}") # Create the initial PEFT wrapper WITHOUT specifying adapter_name to use default peft_model = PeftModel.from_pretrained( self.mola_model, first_lora_path, torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32 ) print(f"โœ… Loaded first LoRA: {first_adapter} (as default)") # Load remaining adapters into the same wrapper with unique names for task_name in self.config.task_labels[1:]: try: lora_path = None if 
                    if os.path.exists(self.model_path):
                        # Local path
                        lora_path = os.path.join(self.model_path, "loras", task_name)
                        if not os.path.exists(lora_path):
                            print(f"⚠️ LoRA directory not found: {lora_path}")
                            continue
                    else:
                        # Hub path - download adapter
                        try:
                            adapter_file = hf_hub_download(
                                repo_id=self.model_path,
                                filename=f"loras/{task_name}/adapter_model.safetensors"
                            )
                            lora_path = os.path.dirname(adapter_file)
                        except Exception as e:
                            print(f"❌ Failed to download LoRA {task_name}: {e}")
                            continue

                    # Load adapter into the same PEFT model with unique name
                    peft_model.load_adapter(lora_path, adapter_name=task_name)
                    print(f"✅ Loaded LoRA: {task_name}")
                except Exception as e:
                    print(f"❌ Failed to load LoRA {task_name}: {e}")

            # Store single PEFT model for all adapters
            self.lora_models = {str(name): peft_model for name in self.config.task_labels}
            self._current_lora = first_adapter
            self._current_adapted_model = peft_model

            print(f"Loaded {len(self.config.task_labels)} LoRA adapters into one PEFT model.")
            print(f"Available adapter names: {list(peft_model.peft_config.keys())}")
        except Exception as e:
            print(f"❌ Failed to initialize LoRA loading: {e}")
            self.lora_models = {}
            self._current_adapted_model = self.mola_model
            self._current_lora = None

    def predict_best_lora(self, text: str) -> str:
        """Predict the best LoRA adapter for given text."""
        # Set models to eval mode
        self.router_encoder.eval()
        self.router_decoder.eval()

        # Encode text
        inputs = self.router_tokenizer(
            [text],
            padding=True,
            truncation=True,
            max_length=512,
            return_tensors="pt"
        )

        # Move to device
        device = next(self.router_decoder.parameters()).device
        inputs = {k: v.to(device) for k, v in inputs.items()}

        with torch.no_grad():
            outputs = self.router_encoder(**inputs)
            embeddings = outputs.last_hidden_state.mean(dim=1)
            logits = self.router_decoder(embeddings)

        # Get best LoRA
        best_idx = torch.argmax(logits, dim=-1).item()
        predicted_label = self.config.task_labels[best_idx]

        # Debug output
        # print(f"Debug - Text: {text[:50]}...")
        # print(f"Debug - Logits: {logits[0].cpu().numpy()}")
        # print(f"Debug - Best idx: {best_idx}, Label: {predicted_label}")

        return predicted_label

    def _apply_lora(self, lora_name: str):
        """Apply the selected LoRA adapter using set_adapter."""
        if hasattr(self, '_current_adapted_model') and isinstance(self._current_adapted_model, PeftModel):
            # Map task labels to actual adapter names in PEFT model
            # First adapter (task_labels[0]) is loaded as 'default', others keep their names
            first_adapter = str(self.config.task_labels[0])
            if str(lora_name) == first_adapter:
                actual_adapter_name = "default"
            else:
                actual_adapter_name = str(lora_name)

            # Check if the adapter exists in the PEFT model
            if actual_adapter_name in self._current_adapted_model.peft_config:
                if lora_name != self._current_lora:
                    self._current_adapted_model.set_adapter(actual_adapter_name)
                    self._current_lora = str(lora_name)
                    # print(f"🎯 Applied LoRA: {lora_name} (as {actual_adapter_name})")  # Uncomment for debugging
            else:
                print(f"⚠️ LoRA adapter '{lora_name}' (mapped to '{actual_adapter_name}') not found in PEFT model. Available: {list(self._current_adapted_model.peft_config.keys())}")
                # Keep current adapter if requested one doesn't exist
        else:
            # Fallback to base model if no PEFT model available
            self._current_adapted_model = self.mola_model
            self._current_lora = None
            print(f"⚠️ No PEFT model available, using base model")

    def test_adapter_uniqueness(self, layer_name: str = "base_model.model.model.layers.33.mlp.down_proj"):
        """
        Regression test to verify that adapters have different weights.

        Args:
            layer_name: The layer to test (default is a common MLP layer)

        Returns:
            Dict[str, str]: Mapping of adapter names to their weight hashes
        """
        import hashlib

        if not hasattr(self, '_current_adapted_model') or not isinstance(self._current_adapted_model, PeftModel):
            print("⚠️ No PEFT model available for testing")
            return {}

        names = self.get_available_loras()
        if len(names) <= 1:
            print(f"⚠️ Need at least 2 adapters for uniqueness test, found {len(names)}")
            return {}

        def fused_sha(adapter_name, layer_name):
            """Compute SHA256 hash of fused LoRA weights for given adapter and layer."""
            # Switch to the adapter
            self._apply_lora(adapter_name)

            # Navigate to the specified layer
            try:
                mod = self._current_adapted_model
                for part in layer_name.split("."):
                    if part:
                        mod = getattr(mod, part)

                # Get LoRA components
                if not hasattr(mod, 'lora_A') or not hasattr(mod, 'lora_B'):
                    print(f"⚠️ Layer {layer_name} doesn't have LoRA components")
                    return "no_lora"

                # Get the currently active adapter key from the PEFT model
                active_adapter_key = self._current_adapted_model.active_adapter

                # Check if the active adapter key exists
                if active_adapter_key not in mod.lora_A:
                    print(f"⚠️ Active adapter key '{active_adapter_key}' not found. Available: {list(mod.lora_A.keys())}")
                    return f"missing_{adapter_name}"

                A = mod.lora_A[active_adapter_key].weight
                B = mod.lora_B[active_adapter_key].weight
                s = float(mod.scaling[active_adapter_key])

                # Compute fused weights: ΔW = (B @ A) * scaling
                dW = (B @ A) * s

                # Convert to bytes and hash
                tensor_bytes = dW.detach().to("cpu", dtype=torch.float32).contiguous().numpy().tobytes()
                return hashlib.sha256(tensor_bytes).hexdigest()[:16]
            except Exception as e:
                print(f"❌ Error computing hash for {adapter_name}: {e}")
                return f"error_{adapter_name}"

        print(f"🧪 Testing adapter uniqueness on layer: {layer_name}")
        hashes = {}
        for adapter_name in names:
            hash_val = fused_sha(adapter_name, layer_name)
            hashes[adapter_name] = hash_val
            print(f"  {adapter_name}: {hash_val}")

        # Check uniqueness
        unique_hashes = set(hashes.values())
        if len(unique_hashes) == len(names):
            print("✅ All adapters have unique weights!")
        else:
            print(f"❌ Found duplicate weights! {len(names)} adapters but only {len(unique_hashes)} unique hashes")
            # Show which ones are identical
            from collections import defaultdict
            hash_to_adapters = defaultdict(list)
            for adapter, hash_val in hashes.items():
                hash_to_adapters[hash_val].append(adapter)
            for hash_val, adapter_list in hash_to_adapters.items():
                if len(adapter_list) > 1:
                    print(f"  Identical weights (hash {hash_val}): {adapter_list}")

        return hashes

    def generate(self, input_ids=None, attention_mask=None, **kwargs):
        """
        Standard generate method with automatic LoRA selection.
        Works exactly like any other LLM's generate method.
""" # If we have input_ids, predict and apply the best LoRA if input_ids is not None and hasattr(self, 'tokenizer'): try: # Decode the input to get the text for LoRA prediction if len(input_ids.shape) > 1: # Batch input - use first item text_input = self.tokenizer.decode(input_ids[0], skip_special_tokens=True) else: text_input = self.tokenizer.decode(input_ids, skip_special_tokens=True) # Clean the text thoroughly to remove ALL chat template artifacts import re # Build regex patterns to avoid escape issues in embedded code start_pattern = '<' + '|im_start|' + '>user' end_pattern = '<' + '|im_end|' + '>' # First, try to extract just the user's actual question/prompt if start_pattern in text_input and end_pattern in text_input: start_idx = text_input.find(start_pattern) + len(start_pattern) end_idx = text_input.find(end_pattern, start_idx) if end_idx > start_idx: text_input = text_input[start_idx:end_idx].strip() # Clean up any remaining template artifacts # Remove special tokens with simple string replacement text_input = text_input.replace('<|im_start|>', '') text_input = text_input.replace('<|im_end|>', '') text_input = text_input.replace('system', '') text_input = text_input.replace('user', '') text_input = text_input.replace('assistant', '') # Remove system message patterns if 'You are Qwen' in text_input: lines = text_input.split('\n') lines = [line for line in lines if 'You are' not in line and 'Alibaba' not in line] text_input = ' '.join(lines) # Final cleanup text_input = re.sub(r'\n+', ' ', text_input) # Replace newlines with spaces text_input = re.sub(r'\s+', ' ', text_input) # Normalize whitespace text_input = text_input.strip() # Debug: print the actual text being classified # print(f"DEBUG RAW: '{self.tokenizer.decode(input_ids[0], skip_special_tokens=False)}'") # print(f"DEBUG CLEAN: '{text_input}'") # Predict and apply best LoRA best_lora = self.predict_best_lora(text_input) self._apply_lora(best_lora) except Exception as e: # If LoRA prediction fails, use base model # print(f"DEBUG: LoRA prediction failed: {e}") self._current_adapted_model = self.mola_model self._current_lora = None # Use the currently adapted model for generation return self._current_adapted_model.generate( input_ids=input_ids, attention_mask=attention_mask, **kwargs ) def forward(self, input_ids, attention_mask=None, **kwargs): """Forward pass through the model.""" return self._current_adapted_model(input_ids=input_ids, attention_mask=attention_mask, **kwargs) def __call__(self, *args, **kwargs): """Make the model callable.""" return self._current_adapted_model(*args, **kwargs) def get_input_embeddings(self): """Get the input embeddings.""" return self._current_adapted_model.get_input_embeddings() def set_input_embeddings(self, value): """Set the input embeddings.""" self._current_adapted_model.set_input_embeddings(value) # Also set for base model to keep them in sync self.mola_model.set_input_embeddings(value) def get_output_embeddings(self): """Get the output embeddings.""" return self._current_adapted_model.get_output_embeddings() def set_output_embeddings(self, value): """Set the output embeddings.""" self._current_adapted_model.set_output_embeddings(value) # Also set for base model to keep them in sync self.mola_model.set_output_embeddings(value) def tie_weights(self): """Tie input and output embeddings.""" self._current_adapted_model.tie_weights() def resize_token_embeddings(self, new_num_tokens): """Resize token embeddings.""" return 
        return self._current_adapted_model.resize_token_embeddings(new_num_tokens)

    @property
    def device(self):
        """Get the device of the model."""
        return next(self.mola_model.parameters()).device

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
        """Load model from pretrained path (transformers compatibility)."""
        # Load config
        config = MoLAConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)

        # Store the path for resource loading
        config._name_or_path = pretrained_model_name_or_path

        return cls(config)

    def save_pretrained(self, save_directory, **kwargs):
        """Save model using standard transformers approach."""
        # Accept standard transformers parameters but use the ones we need
        max_shard_size = kwargs.get('max_shard_size', "5GB")
        safe_serialization = kwargs.get('safe_serialization', True)

        os.makedirs(save_directory, exist_ok=True)

        # Save config using transformers method
        self.config.save_pretrained(save_directory)

        # Save tokenizer if available
        if hasattr(self, 'tokenizer'):
            self.tokenizer.save_pretrained(save_directory)

        # Save the base model with proper sharding if needed
        try:
            # Use the base model's save_pretrained with the parameters
            self.mola_model.save_pretrained(
                save_directory,
                max_shard_size=max_shard_size,
                safe_serialization=safe_serialization
            )
        except Exception as e:
            print(f"Warning: Could not save base model weights: {e}")
            # Fallback: just save the config and tokenizer
            pass

        # Save router weights if they exist
        try:
            if hasattr(self, 'router_decoder'):
                router_state_dict = self.router_decoder.state_dict()
                torch.save(router_state_dict, os.path.join(save_directory, "router_weights.pth"))
        except Exception as e:
            print(f"Warning: Could not save router weights: {e}")

        print(f"Model saved to {save_directory}")

    def get_current_lora(self) -> str:
        """Get the currently applied LoRA adapter name."""
        return self._current_lora or "base_model"

    def get_available_loras(self) -> List[str]:
        """Get list of available LoRA adapters."""
        return list(self.lora_models.keys())


# For transformers AutoModel registration
def _load_mola_model(model_path, **kwargs):
    """Helper function to load MoLA model."""
    return MoLAForCausalLM.from_pretrained(model_path, **kwargs)


# Register with transformers AutoModel system
try:
    CONFIG_MAPPING.register("mola_lm", MoLAConfig)
    MODEL_FOR_CAUSAL_LM_MAPPING.register(MoLAConfig, MoLAForCausalLM)
    print("✅ Successfully registered MoLA-LM with AutoModel!")
except Exception as e:
    print(f"⚠️ AutoModel registration failed: {e}")
    # Try alternative registration for backwards compatibility
    try:
        from transformers import AutoConfig, AutoModelForCausalLM
        AutoConfig.register("mola_lm", MoLAConfig)
        AutoModelForCausalLM.register(MoLAConfig, MoLAForCausalLM)
        print("✅ Successfully registered MoLA-LM with legacy method!")
    except Exception as e2:
        print(f"⚠️ Legacy registration also failed: {e2}")
        print("Model can still be loaded directly with MoLAForCausalLM.from_pretrained()")
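

# Minimal usage sketch (added for illustration; "your-username/MoLA-LM" is a
# placeholder repo id, not something shipped with this module). It exercises the
# flow defined above: from_pretrained() pulls the base model, router weights, and
# LoRA adapters, and generate() routes the prompt to the router's predicted adapter.
if __name__ == "__main__":
    model = MoLAForCausalLM.from_pretrained("your-username/MoLA-LM")  # hypothetical repo id

    messages = [{"role": "user", "content": "Write a Python function that checks whether a number is prime."}]
    prompt = model.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    inputs = model.tokenizer(prompt, return_tensors="pt").to(model.device)

    output_ids = model.generate(**inputs, max_new_tokens=256)
    print(f"Router selected LoRA: {model.get_current_lora()}")
    print(model.tokenizer.decode(output_ids[0], skip_special_tokens=True))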