# MoLA-v0.6-9x4b / modeling_mola_lm.py
# MoLA-LM: Mixture of LoRA Adapters Language Model
import os
import json
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Dict, List, Optional, Union
from transformers import (
AutoConfig, AutoTokenizer, AutoModelForCausalLM,
PretrainedConfig, PreTrainedModel, GenerationMixin
)
from transformers.models.auto import CONFIG_MAPPING, MODEL_FOR_CAUSAL_LM_MAPPING
from peft import PeftModel, LoraConfig, get_peft_model
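# Task labels for the bundled LoRA experts; the router's final linear layer emits one logit per entry.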
EXPERTS_LIST = [
"0",
"1",
"2",
"3",
"4",
"5",
"6",
"7",
"8",
]
class MoLAConfig(PretrainedConfig):
"""Configuration class for MoLA-LM model."""
model_type = "mola_lm"
    def __init__(
        self,
        base_model_name_or_path: str = "Qwen/Qwen3-4B-Thinking-2507",
        task_labels: Optional[List[str]] = None,
        router_config: Optional[Dict] = None,
        lora_configs: Optional[Dict[str, Dict]] = None,
        **kwargs
    ):
super().__init__(**kwargs)
self.base_model_name_or_path = base_model_name_or_path
self.task_labels = task_labels or EXPERTS_LIST
self.router_config = router_config or {}
self.lora_configs = lora_configs or {}
self.num_loras = len(self.task_labels)
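# Example configuration (a sketch; the values shown are illustrative and mirror the defaults above):
#
#   cfg = MoLAConfig(
#       base_model_name_or_path="Qwen/Qwen3-4B-Thinking-2507",
#       task_labels=[str(i) for i in range(9)],
#   )
#   assert cfg.num_loras == 9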
class MoLAForCausalLM(PreTrainedModel, GenerationMixin):
"""
MoLA Language Model for Causal Language Modeling - AutoModel Compatible
"""
config_class = MoLAConfig
base_model_prefix = "mola_model" # Avoid recursion by using unique prefix
supports_gradient_checkpointing = True
def __init__(self, config):
super().__init__(config)
self.config = config
# Store model path for loading resources
self.model_path = getattr(config, '_name_or_path', None)
# Load base model (use base_model_prefix name)
print(f"Loading base model: {self.config.base_model_name_or_path}")
self.mola_model = AutoModelForCausalLM.from_pretrained(
self.config.base_model_name_or_path,
torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
device_map="auto" if torch.cuda.is_available() else None
)
# Load tokenizer
if self.model_path:
self.tokenizer = AutoTokenizer.from_pretrained(self.model_path)
else:
self.tokenizer = AutoTokenizer.from_pretrained(self.config.base_model_name_or_path)
if self.tokenizer.pad_token is None:
self.tokenizer.pad_token = self.tokenizer.eos_token
# Initialize router
self._init_router()
# Initialize current model state (will be updated by _load_lora_adapters)
self._current_lora = None
self._current_adapted_model = self.mola_model
# Load LoRA configurations and adapters (this will update _current_adapted_model)
self._load_lora_adapters()
# Initialize device property (needed for PreTrainedModel compatibility)
self._device = next(self.mola_model.parameters()).device
# Load router weights if available
self._load_router_weights()
print("MoLA-LM initialized successfully!")
def _load_router_weights(self):
"""Load router weights from the saved checkpoint."""
if self.model_path:
try:
# Handle both local and Hub paths for router weights
if os.path.exists(self.model_path):
# Local path
router_weights_path = os.path.join(self.model_path, "router_weights.pth")
if os.path.exists(router_weights_path):
checkpoint = torch.load(router_weights_path, map_location='cpu')
else:
print("⚠️ No router weights found locally")
return
else:
# Hub path - download router weights
try:
from huggingface_hub import hf_hub_download
router_weights_path = hf_hub_download(
repo_id=self.model_path,
filename="router_weights.pth",
local_files_only=False
)
checkpoint = torch.load(router_weights_path, map_location='cpu')
print("📥 Downloaded router weights from Hub")
except Exception as hub_e:
print(f"⚠️ Failed to download router weights from Hub: {hub_e}")
print("🔄 Router will use random initialization (reduced performance)")
return
# Load router decoder weights
router_state_dict = {}
for key, value in checkpoint.items():
if not key.startswith('encoder.'): # Skip encoder weights
router_state_dict[key] = value
if router_state_dict:
self.router_decoder.load_state_dict(router_state_dict, strict=False)
print("✅ Loaded router weights successfully!")
# Verify weights loaded by checking if they're not all zeros
first_layer = next(iter(self.router_decoder.parameters()))
if torch.all(first_layer == 0):
print("⚠️ Warning: Router weights appear to be zero-initialized")
else:
print("🎯 Router weights verified - non-zero values detected")
else:
print("⚠️ No valid router weights found in checkpoint")
except Exception as e:
print(f"❌ Failed to load router weights: {e}")
print("🔄 Router will use random initialization (reduced performance)")
def _init_router(self):
"""Initialize the router model for LoRA selection."""
try:
from transformers import AutoModel
print("Initializing router components...")
# Router components
self.router_tokenizer = AutoTokenizer.from_pretrained("sentence-transformers/all-MiniLM-L6-v2")
self.router_encoder = AutoModel.from_pretrained("sentence-transformers/all-MiniLM-L6-v2")
# Freeze encoder
for param in self.router_encoder.parameters():
param.requires_grad = False
# Router decoder
encoder_dim = self.router_encoder.config.hidden_size
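            # MLP routing head: maps the mean-pooled encoder embedding (384-dim for
            # all-MiniLM-L6-v2) down to one logit per LoRA adapter (see predict_best_lora).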
            self.router_decoder = nn.Sequential(
                nn.Linear(encoder_dim, 768),
                nn.ReLU(),
                nn.Dropout(0.3),
                nn.Linear(768, 768),
                nn.ReLU(),
                nn.Dropout(0.3),
                nn.Linear(768, 768),
                nn.ReLU(),
                nn.Dropout(0.3),
                nn.Linear(768, 384),
                nn.ReLU(),
                nn.Dropout(0.3),
                nn.Linear(384, 192),
                nn.ReLU(),
                nn.Dropout(0.3),
                nn.Linear(192, self.config.num_loras)
            )
# Move router to device
if torch.cuda.is_available():
self.router_encoder = self.router_encoder.cuda()
self.router_decoder = self.router_decoder.cuda()
print("Router initialized successfully!")
except ImportError as e:
raise ImportError(f"Required dependencies not found: {e}")
def _load_lora_adapters(self):
"""Load LoRA adapters using PEFT (single wrapper, multiple adapters)."""
from huggingface_hub import hf_hub_download
if not self.model_path:
print("No model path specified, skipping LoRA loading")
return
print("Loading LoRA adapters (single wrapper)...")
# Get the first adapter to create the initial PEFT wrapper
first_adapter = str(self.config.task_labels[0])
first_lora_path = None
try:
# Handle both local and Hub paths for first adapter
if os.path.exists(self.model_path):
# Local path
first_lora_path = os.path.join(self.model_path, "loras", first_adapter)
if not os.path.exists(first_lora_path):
raise FileNotFoundError(f"First adapter directory not found: {first_lora_path}")
else:
# Hub path - download first adapter
try:
# Download first adapter to get local path
adapter_file = hf_hub_download(
repo_id=self.model_path,
filename=f"loras/{first_adapter}/adapter_model.safetensors"
)
first_lora_path = os.path.dirname(adapter_file)
print(f"Downloaded first adapter to: {first_lora_path}")
except Exception as e:
raise Exception(f"Failed to download first adapter {first_adapter}: {e}")
# Create the initial PEFT wrapper WITHOUT specifying adapter_name to use default
peft_model = PeftModel.from_pretrained(
self.mola_model,
first_lora_path,
torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32
)
print(f"✅ Loaded first LoRA: {first_adapter} (as default)")
# Load remaining adapters into the same wrapper with unique names
for task_name in self.config.task_labels[1:]:
try:
lora_path = None
if os.path.exists(self.model_path):
# Local path
lora_path = os.path.join(self.model_path, "loras", task_name)
if not os.path.exists(lora_path):
print(f"⚠️ LoRA directory not found: {lora_path}")
continue
else:
# Hub path - download adapter
try:
adapter_file = hf_hub_download(
repo_id=self.model_path,
filename=f"loras/{task_name}/adapter_model.safetensors"
)
lora_path = os.path.dirname(adapter_file)
except Exception as e:
print(f"❌ Failed to download LoRA {task_name}: {e}")
continue
# Load adapter into the same PEFT model with unique name
peft_model.load_adapter(lora_path, adapter_name=task_name)
print(f"✅ Loaded LoRA: {task_name}")
except Exception as e:
print(f"❌ Failed to load LoRA {task_name}: {e}")
# Store single PEFT model for all adapters
self.lora_models = {str(name): peft_model for name in self.config.task_labels}
self._current_lora = first_adapter
self._current_adapted_model = peft_model
print(f"Loaded {len(self.config.task_labels)} LoRA adapters into one PEFT model.")
print(f"Available adapter names: {list(peft_model.peft_config.keys())}")
except Exception as e:
print(f"❌ Failed to initialize LoRA loading: {e}")
self.lora_models = {}
self._current_adapted_model = self.mola_model
self._current_lora = None
def predict_best_lora(self, text: str) -> str:
"""Predict the best LoRA adapter for given text."""
# Set models to eval mode
self.router_encoder.eval()
self.router_decoder.eval()
# Encode text
inputs = self.router_tokenizer(
[text],
padding=True,
truncation=True,
max_length=512,
return_tensors="pt"
)
# Move to device
device = next(self.router_decoder.parameters()).device
inputs = {k: v.to(device) for k, v in inputs.items()}
with torch.no_grad():
outputs = self.router_encoder(**inputs)
embeddings = outputs.last_hidden_state.mean(dim=1)
logits = self.router_decoder(embeddings)
# Get best LoRA
best_idx = torch.argmax(logits, dim=-1).item()
predicted_label = self.config.task_labels[best_idx]
# Debug output
# print(f"Debug - Text: {text[:50]}...")
# print(f"Debug - Logits: {logits[0].cpu().numpy()}")
# print(f"Debug - Best idx: {best_idx}, Label: {predicted_label}")
return predicted_label
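    # Example (a sketch; `model` stands for an initialized MoLAForCausalLM instance):
    #
    #   label = model.predict_best_lora("Write a Python function that reverses a string.")
    #   model._apply_lora(label)  # generate() performs both steps automatically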
def _apply_lora(self, lora_name: str):
"""Apply the selected LoRA adapter using set_adapter."""
if hasattr(self, '_current_adapted_model') and isinstance(self._current_adapted_model, PeftModel):
# Map task labels to actual adapter names in PEFT model
# First adapter (task_labels[0]) is loaded as 'default', others keep their names
first_adapter = str(self.config.task_labels[0])
if str(lora_name) == first_adapter:
actual_adapter_name = "default"
else:
actual_adapter_name = str(lora_name)
# Check if the adapter exists in the PEFT model
if actual_adapter_name in self._current_adapted_model.peft_config:
if lora_name != self._current_lora:
self._current_adapted_model.set_adapter(actual_adapter_name)
self._current_lora = str(lora_name)
# print(f"🎯 Applied LoRA: {lora_name} (as {actual_adapter_name})") # Uncomment for debugging
else:
print(f"⚠️ LoRA adapter '{lora_name}' (mapped to '{actual_adapter_name}') not found in PEFT model. Available: {list(self._current_adapted_model.peft_config.keys())}")
# Keep current adapter if requested one doesn't exist
else:
# Fallback to base model if no PEFT model available
self._current_adapted_model = self.mola_model
self._current_lora = None
print(f"⚠️ No PEFT model available, using base model")
def get_available_loras(self) -> List[str]:
"""Get list of available LoRA adapter names."""
if hasattr(self, '_current_adapted_model') and isinstance(self._current_adapted_model, PeftModel):
return list(self._current_adapted_model.peft_config.keys())
else:
return []
def test_adapter_uniqueness(self, layer_name: str = "base_model.model.model.layers.33.mlp.down_proj"):
"""
Regression test to verify that adapters have different weights.
Args:
layer_name: The layer to test (default is a common MLP layer)
Returns:
Dict[str, str]: Mapping of adapter names to their weight hashes
"""
import hashlib
if not hasattr(self, '_current_adapted_model') or not isinstance(self._current_adapted_model, PeftModel):
print("⚠️ No PEFT model available for testing")
return {}
names = self.get_available_loras()
if len(names) <= 1:
print(f"⚠️ Need at least 2 adapters for uniqueness test, found {len(names)}")
return {}
def fused_sha(adapter_name, layer_name):
"""Compute SHA256 hash of fused LoRA weights for given adapter and layer."""
# Switch to the adapter
self._apply_lora(adapter_name)
# Navigate to the specified layer
try:
mod = self._current_adapted_model
for part in layer_name.split("."):
if part:
mod = getattr(mod, part)
# Get LoRA components
if not hasattr(mod, 'lora_A') or not hasattr(mod, 'lora_B'):
print(f"⚠️ Layer {layer_name} doesn't have LoRA components")
return "no_lora"
# Get the currently active adapter key from the PEFT model
active_adapter_key = self._current_adapted_model.active_adapter
# Check if the active adapter key exists
if active_adapter_key not in mod.lora_A:
print(f"⚠️ Active adapter key '{active_adapter_key}' not found. Available: {list(mod.lora_A.keys())}")
return f"missing_{adapter_name}"
A = mod.lora_A[active_adapter_key].weight
B = mod.lora_B[active_adapter_key].weight
s = float(mod.scaling[active_adapter_key])
# Compute fused weights: ΔW = (B @ A) * scaling
dW = (B @ A) * s
# Convert to bytes and hash
tensor_bytes = dW.detach().to("cpu", dtype=torch.float32).contiguous().numpy().tobytes()
return hashlib.sha256(tensor_bytes).hexdigest()[:16]
except Exception as e:
print(f"❌ Error computing hash for {adapter_name}: {e}")
return f"error_{adapter_name}"
print(f"🧪 Testing adapter uniqueness on layer: {layer_name}")
hashes = {}
for adapter_name in names:
hash_val = fused_sha(adapter_name, layer_name)
hashes[adapter_name] = hash_val
print(f" {adapter_name}: {hash_val}")
# Check uniqueness
unique_hashes = set(hashes.values())
if len(unique_hashes) == len(names):
print("✅ All adapters have unique weights!")
else:
print(f"❌ Found duplicate weights! {len(names)} adapters but only {len(unique_hashes)} unique hashes")
# Show which ones are identical
from collections import defaultdict
hash_to_adapters = defaultdict(list)
for adapter, hash_val in hashes.items():
hash_to_adapters[hash_val].append(adapter)
for hash_val, adapter_list in hash_to_adapters.items():
if len(adapter_list) > 1:
print(f" Identical weights (hash {hash_val}): {adapter_list}")
return hashes
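    # Example (a sketch; the hash values shown are purely illustrative):
    #
    #   hashes = model.test_adapter_uniqueness()
    #   # -> {'0': '9f2c...', '1': '41ab...', ...} when every adapter carries distinct weights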
def generate(self, input_ids=None, attention_mask=None, **kwargs):
"""
Standard generate method with automatic LoRA selection.
Works exactly like any other LLM's generate method.
"""
# If we have input_ids, predict and apply the best LoRA
if input_ids is not None and hasattr(self, 'tokenizer'):
try:
# Decode the input to get the text for LoRA prediction
if len(input_ids.shape) > 1:
# Batch input - use first item
text_input = self.tokenizer.decode(input_ids[0], skip_special_tokens=True)
else:
text_input = self.tokenizer.decode(input_ids, skip_special_tokens=True)
# Clean the text thoroughly to remove ALL chat template artifacts
import re
                # Build the chat-template marker strings piecewise so the literal tags don't appear verbatim in this embedded code
start_pattern = '<' + '|im_start|' + '>user'
end_pattern = '<' + '|im_end|' + '>'
# First, try to extract just the user's actual question/prompt
if start_pattern in text_input and end_pattern in text_input:
start_idx = text_input.find(start_pattern) + len(start_pattern)
end_idx = text_input.find(end_pattern, start_idx)
if end_idx > start_idx:
text_input = text_input[start_idx:end_idx].strip()
# Clean up any remaining template artifacts
# Remove special tokens with simple string replacement
text_input = text_input.replace('<|im_start|>', '')
text_input = text_input.replace('<|im_end|>', '')
text_input = text_input.replace('system', '')
text_input = text_input.replace('user', '')
text_input = text_input.replace('assistant', '')
# Remove system message patterns
if 'You are Qwen' in text_input:
lines = text_input.split('\n')
lines = [line for line in lines if 'You are' not in line and 'Alibaba' not in line]
text_input = ' '.join(lines)
# Final cleanup
text_input = re.sub(r'\n+', ' ', text_input) # Replace newlines with spaces
text_input = re.sub(r'\s+', ' ', text_input) # Normalize whitespace
text_input = text_input.strip()
# Debug: print the actual text being classified
# print(f"DEBUG RAW: '{self.tokenizer.decode(input_ids[0], skip_special_tokens=False)}'")
# print(f"DEBUG CLEAN: '{text_input}'")
# Predict and apply best LoRA
best_lora = self.predict_best_lora(text_input)
self._apply_lora(best_lora)
except Exception as e:
# If LoRA prediction fails, use base model
# print(f"DEBUG: LoRA prediction failed: {e}")
self._current_adapted_model = self.mola_model
self._current_lora = None
# Use the currently adapted model for generation
return self._current_adapted_model.generate(
input_ids=input_ids,
attention_mask=attention_mask,
**kwargs
)
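    # Example (a sketch; the prompt and generation settings are illustrative):
    #
    #   messages = [{"role": "user", "content": "Explain binary search."}]
    #   input_ids = model.tokenizer.apply_chat_template(
    #       messages, add_generation_prompt=True, return_tensors="pt"
    #   ).to(model.device)
    #   output_ids = model.generate(input_ids, max_new_tokens=128)
    #   print(model.get_current_lora())  # adapter the router picked for this prompt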
def forward(self, input_ids, attention_mask=None, **kwargs):
"""Forward pass through the model."""
return self._current_adapted_model(input_ids=input_ids, attention_mask=attention_mask, **kwargs)
def __call__(self, *args, **kwargs):
"""Make the model callable."""
return self._current_adapted_model(*args, **kwargs)
def get_input_embeddings(self):
"""Get the input embeddings."""
return self._current_adapted_model.get_input_embeddings()
def set_input_embeddings(self, value):
"""Set the input embeddings."""
self._current_adapted_model.set_input_embeddings(value)
# Also set for base model to keep them in sync
self.mola_model.set_input_embeddings(value)
def get_output_embeddings(self):
"""Get the output embeddings."""
return self._current_adapted_model.get_output_embeddings()
def set_output_embeddings(self, value):
"""Set the output embeddings."""
self._current_adapted_model.set_output_embeddings(value)
# Also set for base model to keep them in sync
self.mola_model.set_output_embeddings(value)
def tie_weights(self):
"""Tie input and output embeddings."""
self._current_adapted_model.tie_weights()
def resize_token_embeddings(self, new_num_tokens):
"""Resize token embeddings."""
return self._current_adapted_model.resize_token_embeddings(new_num_tokens)
@property
def device(self):
"""Get the device of the model."""
return next(self.mola_model.parameters()).device
@classmethod
def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
"""Load model from pretrained path (transformers compatibility)."""
# Load config
config = MoLAConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
# Store the path for resource loading
config._name_or_path = pretrained_model_name_or_path
return cls(config)
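    # Example (a sketch; the repo id is assumed from this repository's path and may differ):
    #
    #   model = MoLAForCausalLM.from_pretrained("AtAndDev/MoLA-v0.6-9x4b")
    #   print(model.get_available_loras())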
def save_pretrained(self, save_directory, **kwargs):
"""Save model using standard transformers approach."""
# Accept standard transformers parameters but use the ones we need
max_shard_size = kwargs.get('max_shard_size', "5GB")
safe_serialization = kwargs.get('safe_serialization', True)
os.makedirs(save_directory, exist_ok=True)
# Save config using transformers method
self.config.save_pretrained(save_directory)
# Save tokenizer if available
if hasattr(self, 'tokenizer'):
self.tokenizer.save_pretrained(save_directory)
# Save the base model with proper sharding if needed
try:
# Use the base model's save_pretrained with the parameters
self.mola_model.save_pretrained(
save_directory,
max_shard_size=max_shard_size,
safe_serialization=safe_serialization
)
except Exception as e:
print(f"Warning: Could not save base model weights: {e}")
# Fallback: just save the config and tokenizer
pass
# Save router weights if they exist
try:
if hasattr(self, 'router_decoder'):
router_state_dict = self.router_decoder.state_dict()
torch.save(router_state_dict, os.path.join(save_directory, "router_weights.pth"))
except Exception as e:
print(f"Warning: Could not save router weights: {e}")
print(f"Model saved to {save_directory}")
def get_current_lora(self) -> str:
"""Get the currently applied LoRA adapter name."""
return self._current_lora or "base_model"
# For transformers AutoModel registration
def _load_mola_model(model_path, **kwargs):
"""Helper function to load MoLA model."""
return MoLAForCausalLM.from_pretrained(model_path, **kwargs)
# Register with transformers AutoModel system
try:
CONFIG_MAPPING.register("mola_lm", MoLAConfig)
MODEL_FOR_CAUSAL_LM_MAPPING.register(MoLAConfig, MoLAForCausalLM)
print("✅ Successfully registered MoLA-LM with AutoModel!")
except Exception as e:
print(f"⚠️ AutoModel registration failed: {e}")
# Try alternative registration for backwards compatibility
try:
from transformers import AutoConfig, AutoModelForCausalLM
AutoConfig.register("mola_lm", MoLAConfig)
AutoModelForCausalLM.register(MoLAConfig, MoLAForCausalLM)
print("✅ Successfully registered MoLA-LM with legacy method!")
except Exception as e2:
print(f"⚠️ Legacy registration also failed: {e2}")
print("Model can still be loaded directly with MoLAForCausalLM.from_pretrained()")