import os
import json
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Dict, List, Optional, Union
from transformers import (
    AutoConfig, AutoTokenizer, AutoModelForCausalLM,
    PretrainedConfig, PreTrainedModel, GenerationMixin
)
from transformers.models.auto import CONFIG_MAPPING, MODEL_FOR_CAUSAL_LM_MAPPING
from peft import PeftModel, LoraConfig, get_peft_model

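# MoLA-LM: wraps a base causal LM with a pool of task-specific LoRA adapters and a small
# text router (frozen MiniLM encoder + MLP head) that picks an adapter for each prompt.
# EXPERTS_LIST holds the adapter/task labels used as keys for the LoRA checkpoints.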
EXPERTS_LIST = ["0", "1", "2", "3", "4", "5", "6", "7", "8"]


class MoLAConfig(PretrainedConfig):
    """Configuration class for MoLA-LM model."""

    model_type = "mola_lm"

    def __init__(
        self,
        base_model_name_or_path: str = "Qwen/Qwen3-4B-Thinking-2507",
        task_labels: List[str] = None,
        router_config: Dict = None,
        lora_configs: Dict[str, Dict] = None,
        **kwargs
    ):
        super().__init__(**kwargs)
        self.base_model_name_or_path = base_model_name_or_path
        self.task_labels = task_labels or EXPERTS_LIST
        self.router_config = router_config or {}
        self.lora_configs = lora_configs or {}
        self.num_loras = len(self.task_labels)

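# A minimal configuration sketch (illustrative; these are the defaults shown above):
#   config = MoLAConfig(base_model_name_or_path="Qwen/Qwen3-4B-Thinking-2507",
#                       task_labels=EXPERTS_LIST)
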
class MoLAForCausalLM(PreTrainedModel, GenerationMixin):
    """
    MoLA Language Model for Causal Language Modeling - AutoModel Compatible
    """

    config_class = MoLAConfig
    base_model_prefix = "mola_model"
    supports_gradient_checkpointing = True

    def __init__(self, config):
        super().__init__(config)
        self.config = config

        # Path (local directory or Hub repo id) the MoLA checkpoint was loaded from.
        self.model_path = getattr(config, '_name_or_path', None)

        # Base causal LM that the LoRA adapters are attached to.
        print(f"Loading base model: {self.config.base_model_name_or_path}")
        self.mola_model = AutoModelForCausalLM.from_pretrained(
            self.config.base_model_name_or_path,
            torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
            device_map="auto" if torch.cuda.is_available() else None
        )

        # Prefer the tokenizer shipped with the MoLA checkpoint, fall back to the base model's.
        if self.model_path:
            self.tokenizer = AutoTokenizer.from_pretrained(self.model_path)
        else:
            self.tokenizer = AutoTokenizer.from_pretrained(self.config.base_model_name_or_path)

        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

        # Router: frozen sentence encoder + trainable MLP head over adapter logits.
        self._init_router()

        # Until adapters are loaded, fall back to the plain base model.
        self._current_lora = None
        self._current_adapted_model = self.mola_model

        self._load_lora_adapters()

        self._device = next(self.mola_model.parameters()).device

        self._load_router_weights()

        print("MoLA-LM initialized successfully!")

    def _load_router_weights(self):
        """Load router weights from the saved checkpoint."""
        if self.model_path:
            try:
                if os.path.exists(self.model_path):
                    # Local checkpoint directory.
                    router_weights_path = os.path.join(self.model_path, "router_weights.pth")
                    if os.path.exists(router_weights_path):
                        checkpoint = torch.load(router_weights_path, map_location='cpu')
                    else:
                        print("⚠️ No router weights found locally")
                        return
                else:
                    # Treat model_path as a Hub repo id and download the weights.
                    try:
                        from huggingface_hub import hf_hub_download
                        router_weights_path = hf_hub_download(
                            repo_id=self.model_path,
                            filename="router_weights.pth",
                            local_files_only=False
                        )
                        checkpoint = torch.load(router_weights_path, map_location='cpu')
                        print("📥 Downloaded router weights from Hub")
                    except Exception as hub_e:
                        print(f"⚠️ Failed to download router weights from Hub: {hub_e}")
                        print("🔄 Router will use random initialization (reduced performance)")
                        return

                # Keep only the MLP head parameters; drop any frozen encoder weights
                # that may have been saved alongside them.
                router_state_dict = {}
                for key, value in checkpoint.items():
                    if not key.startswith('encoder.'):
                        router_state_dict[key] = value

                if router_state_dict:
                    self.router_decoder.load_state_dict(router_state_dict, strict=False)
                    print("✅ Loaded router weights successfully!")

                    # Cheap sanity check that something non-trivial was loaded.
                    first_param = next(iter(self.router_decoder.parameters()))
                    if torch.all(first_param == 0):
                        print("⚠️ Warning: Router weights appear to be zero-initialized")
                    else:
                        print("🎯 Router weights verified - non-zero values detected")
                else:
                    print("⚠️ No valid router weights found in checkpoint")

            except Exception as e:
                print(f"❌ Failed to load router weights: {e}")
                print("🔄 Router will use random initialization (reduced performance)")

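    # Note: router_weights.pth stores only the MLP head's state_dict (written by
    # save_pretrained() below); the sentence encoder is reloaded from its pretrained weights.
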
    def _init_router(self):
        """Initialize the router model for LoRA selection."""
        try:
            from transformers import AutoModel

            print("Initializing router components...")

            self.router_tokenizer = AutoTokenizer.from_pretrained("sentence-transformers/all-MiniLM-L6-v2")
            self.router_encoder = AutoModel.from_pretrained("sentence-transformers/all-MiniLM-L6-v2")

            for param in self.router_encoder.parameters():
                param.requires_grad = False

            encoder_dim = self.router_encoder.config.hidden_size
            self.router_decoder = nn.Sequential(
                nn.Linear(encoder_dim, 768),
                nn.ReLU(),
                nn.Dropout(0.3),
                nn.Linear(768, 768),
                nn.ReLU(),
                nn.Dropout(0.3),
                nn.Linear(768, 768),
                nn.ReLU(),
                nn.Dropout(0.3),
                nn.Linear(768, 384),
                nn.ReLU(),
                nn.Dropout(0.3),
                nn.Linear(384, 192),
                nn.ReLU(),
                nn.Dropout(0.3),
                nn.Linear(192, self.config.num_loras)
            )

            if torch.cuda.is_available():
                self.router_encoder = self.router_encoder.cuda()
                self.router_decoder = self.router_decoder.cuda()

            print("Router initialized successfully!")

        except ImportError as e:
            raise ImportError(f"Required dependencies not found: {e}")

    def _load_lora_adapters(self):
        """Load LoRA adapters using PEFT (single wrapper, multiple adapters)."""
        from huggingface_hub import hf_hub_download

        if not self.model_path:
            print("No model path specified, skipping LoRA loading")
            return

        print("Loading LoRA adapters (single wrapper)...")

        first_adapter = str(self.config.task_labels[0])
        first_lora_path = None

        try:
            if os.path.exists(self.model_path):
                first_lora_path = os.path.join(self.model_path, "loras", first_adapter)
                if not os.path.exists(first_lora_path):
                    raise FileNotFoundError(f"First adapter directory not found: {first_lora_path}")
            else:
                # Download the first adapter from the Hub; fetching both files places them
                # in the same cached snapshot directory, which PEFT can then load from.
                try:
                    adapter_file = hf_hub_download(
                        repo_id=self.model_path,
                        filename=f"loras/{first_adapter}/adapter_model.safetensors"
                    )
                    hf_hub_download(
                        repo_id=self.model_path,
                        filename=f"loras/{first_adapter}/adapter_config.json"
                    )
                    first_lora_path = os.path.dirname(adapter_file)
                    print(f"Downloaded first adapter to: {first_lora_path}")
                except Exception as e:
                    raise Exception(f"Failed to download first adapter {first_adapter}: {e}")

            # Wrap the base model once; this registers the first adapter as "default".
            peft_model = PeftModel.from_pretrained(
                self.mola_model,
                first_lora_path,
                torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32
            )
            print(f"✅ Loaded first LoRA: {first_adapter} (as default)")

            # Load the remaining adapters into the same PEFT wrapper under their task names.
            for task_name in self.config.task_labels[1:]:
                try:
                    lora_path = None

                    if os.path.exists(self.model_path):
                        lora_path = os.path.join(self.model_path, "loras", task_name)
                        if not os.path.exists(lora_path):
                            print(f"⚠️ LoRA directory not found: {lora_path}")
                            continue
                    else:
                        try:
                            adapter_file = hf_hub_download(
                                repo_id=self.model_path,
                                filename=f"loras/{task_name}/adapter_model.safetensors"
                            )
                            hf_hub_download(
                                repo_id=self.model_path,
                                filename=f"loras/{task_name}/adapter_config.json"
                            )
                            lora_path = os.path.dirname(adapter_file)
                        except Exception as e:
                            print(f"❌ Failed to download LoRA {task_name}: {e}")
                            continue

                    peft_model.load_adapter(lora_path, adapter_name=task_name)
                    print(f"✅ Loaded LoRA: {task_name}")

                except Exception as e:
                    print(f"❌ Failed to load LoRA {task_name}: {e}")

            self.lora_models = {str(name): peft_model for name in self.config.task_labels}
            self._current_lora = first_adapter
            self._current_adapted_model = peft_model

            print(f"Loaded {len(self.config.task_labels)} LoRA adapters into one PEFT model.")
            print(f"Available adapter names: {list(peft_model.peft_config.keys())}")

        except Exception as e:
            print(f"❌ Failed to initialize LoRA loading: {e}")
            self.lora_models = {}
            self._current_adapted_model = self.mola_model
            self._current_lora = None

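    # Expected checkpoint layout implied by the paths above:
    #   <model_path>/loras/<task_label>/adapter_model.safetensors (+ adapter_config.json)
    #   <model_path>/router_weights.pth
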
    def predict_best_lora(self, text: str) -> str:
        """Predict the best LoRA adapter for given text."""
        self.router_encoder.eval()
        self.router_decoder.eval()

        inputs = self.router_tokenizer(
            [text],
            padding=True,
            truncation=True,
            max_length=512,
            return_tensors="pt"
        )

        device = next(self.router_decoder.parameters()).device
        inputs = {k: v.to(device) for k, v in inputs.items()}

        with torch.no_grad():
            outputs = self.router_encoder(**inputs)
            embeddings = outputs.last_hidden_state.mean(dim=1)
            logits = self.router_decoder(embeddings)

        best_idx = torch.argmax(logits, dim=-1).item()
        predicted_label = self.config.task_labels[best_idx]

        return predicted_label

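    # Illustrative call (hypothetical prompt): predict_best_lora("Write a quicksort in Python")
    # returns one of self.config.task_labels; generate() below invokes this automatically.
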
    def _apply_lora(self, lora_name: str):
        """Apply the selected LoRA adapter using set_adapter."""
        if hasattr(self, '_current_adapted_model') and isinstance(self._current_adapted_model, PeftModel):
            # The first adapter was registered by PeftModel.from_pretrained under the
            # name "default", so map the first task label to that internal name.
            first_adapter = str(self.config.task_labels[0])
            if str(lora_name) == first_adapter:
                actual_adapter_name = "default"
            else:
                actual_adapter_name = str(lora_name)

            if actual_adapter_name in self._current_adapted_model.peft_config:
                if lora_name != self._current_lora:
                    self._current_adapted_model.set_adapter(actual_adapter_name)
                    self._current_lora = str(lora_name)
            else:
                print(f"⚠️ LoRA adapter '{lora_name}' (mapped to '{actual_adapter_name}') not found in PEFT model. Available: {list(self._current_adapted_model.peft_config.keys())}")
        else:
            # No PEFT wrapper was created (e.g. adapter loading failed); use the base model.
            self._current_adapted_model = self.mola_model
            self._current_lora = None
            print("⚠️ No PEFT model available, using base model")

    def get_available_loras(self) -> List[str]:
        """Get list of available LoRA adapter names."""
        if hasattr(self, '_current_adapted_model') and isinstance(self._current_adapted_model, PeftModel):
            return list(self._current_adapted_model.peft_config.keys())
        else:
            return []

    def test_adapter_uniqueness(self, layer_name: str = "base_model.model.model.layers.33.mlp.down_proj"):
        """
        Regression test to verify that adapters have different weights.

        Args:
            layer_name: The layer to test (default is a common MLP layer)

        Returns:
            Dict[str, str]: Mapping of adapter names to their weight hashes
        """
        import hashlib

        if not hasattr(self, '_current_adapted_model') or not isinstance(self._current_adapted_model, PeftModel):
            print("⚠️ No PEFT model available for testing")
            return {}

        names = self.get_available_loras()
        if len(names) <= 1:
            print(f"⚠️ Need at least 2 adapters for uniqueness test, found {len(names)}")
            return {}

        def fused_sha(adapter_name, layer_name):
            """Compute SHA256 hash of fused LoRA weights for given adapter and layer."""
            self._apply_lora(adapter_name)

            try:
                # Walk the attribute path down to the target module.
                mod = self._current_adapted_model
                for part in layer_name.split("."):
                    if part:
                        mod = getattr(mod, part)

                if not hasattr(mod, 'lora_A') or not hasattr(mod, 'lora_B'):
                    print(f"⚠️ Layer {layer_name} doesn't have LoRA components")
                    return "no_lora"

                active_adapter_key = self._current_adapted_model.active_adapter

                if active_adapter_key not in mod.lora_A:
                    print(f"⚠️ Active adapter key '{active_adapter_key}' not found. Available: {list(mod.lora_A.keys())}")
                    return f"missing_{adapter_name}"

                A = mod.lora_A[active_adapter_key].weight
                B = mod.lora_B[active_adapter_key].weight
                s = float(mod.scaling[active_adapter_key])

                # Fused LoRA update: delta_W = (B @ A) * scaling
                dW = (B @ A) * s

                tensor_bytes = dW.detach().to("cpu", dtype=torch.float32).contiguous().numpy().tobytes()
                return hashlib.sha256(tensor_bytes).hexdigest()[:16]

            except Exception as e:
                print(f"❌ Error computing hash for {adapter_name}: {e}")
                return f"error_{adapter_name}"

        print(f"🧪 Testing adapter uniqueness on layer: {layer_name}")
        hashes = {}
        for adapter_name in names:
            hash_val = fused_sha(adapter_name, layer_name)
            hashes[adapter_name] = hash_val
            print(f"  {adapter_name}: {hash_val}")

        unique_hashes = set(hashes.values())
        if len(unique_hashes) == len(names):
            print("✅ All adapters have unique weights!")
        else:
            print(f"❌ Found duplicate weights! {len(names)} adapters but only {len(unique_hashes)} unique hashes")

            from collections import defaultdict
            hash_to_adapters = defaultdict(list)
            for adapter, hash_val in hashes.items():
                hash_to_adapters[hash_val].append(adapter)

            for hash_val, adapter_list in hash_to_adapters.items():
                if len(adapter_list) > 1:
                    print(f"  Identical weights (hash {hash_val}): {adapter_list}")

        return hashes

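    # Quick sanity check after loading (assumes the default layer_name exists in the base model):
    #   hashes = model.test_adapter_uniqueness()
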
    def generate(self, input_ids=None, attention_mask=None, **kwargs):
        """
        Standard generate method with automatic LoRA selection.
        Works exactly like any other LLM's generate method.
        """
        if input_ids is not None and hasattr(self, 'tokenizer'):
            try:
                # Recover the prompt text so the router can score it.
                if len(input_ids.shape) > 1:
                    text_input = self.tokenizer.decode(input_ids[0], skip_special_tokens=True)
                else:
                    text_input = self.tokenizer.decode(input_ids, skip_special_tokens=True)

                import re

                # If the prompt uses the ChatML template, keep only the user turn.
                start_pattern = '<|im_start|>user'
                end_pattern = '<|im_end|>'

                if start_pattern in text_input and end_pattern in text_input:
                    start_idx = text_input.find(start_pattern) + len(start_pattern)
                    end_idx = text_input.find(end_pattern, start_idx)
                    if end_idx > start_idx:
                        text_input = text_input[start_idx:end_idx].strip()

                # Strip any leftover template markers and role tags.
                text_input = text_input.replace('<|im_start|>', '')
                text_input = text_input.replace('<|im_end|>', '')
                text_input = text_input.replace('system', '')
                text_input = text_input.replace('user', '')
                text_input = text_input.replace('assistant', '')

                # Drop the default Qwen system-prompt lines if present.
                if 'You are Qwen' in text_input:
                    lines = text_input.split('\n')
                    lines = [line for line in lines if 'You are' not in line and 'Alibaba' not in line]
                    text_input = ' '.join(lines)

                text_input = re.sub(r'\n+', ' ', text_input)
                text_input = re.sub(r'\s+', ' ', text_input)
                text_input = text_input.strip()

                # Route to the best adapter for this prompt.
                best_lora = self.predict_best_lora(text_input)
                self._apply_lora(best_lora)

            except Exception as e:
                # If routing fails for any reason, fall back to the plain base model.
                print(f"⚠️ LoRA routing failed ({e}); falling back to base model")
                self._current_adapted_model = self.mola_model
                self._current_lora = None

        return self._current_adapted_model.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            **kwargs
        )

    def forward(self, input_ids, attention_mask=None, **kwargs):
        """Forward pass through the model."""
        return self._current_adapted_model(input_ids=input_ids, attention_mask=attention_mask, **kwargs)

    def __call__(self, *args, **kwargs):
        """Make the model callable."""
        return self._current_adapted_model(*args, **kwargs)

    def get_input_embeddings(self):
        """Get the input embeddings."""
        return self._current_adapted_model.get_input_embeddings()

    def set_input_embeddings(self, value):
        """Set the input embeddings."""
        self._current_adapted_model.set_input_embeddings(value)
        self.mola_model.set_input_embeddings(value)

    def get_output_embeddings(self):
        """Get the output embeddings."""
        return self._current_adapted_model.get_output_embeddings()

    def set_output_embeddings(self, value):
        """Set the output embeddings."""
        self._current_adapted_model.set_output_embeddings(value)
        self.mola_model.set_output_embeddings(value)

    def tie_weights(self):
        """Tie input and output embeddings."""
        self._current_adapted_model.tie_weights()

    def resize_token_embeddings(self, new_num_tokens):
        """Resize token embeddings."""
        return self._current_adapted_model.resize_token_embeddings(new_num_tokens)

    @property
    def device(self):
        """Get the device of the model."""
        return next(self.mola_model.parameters()).device

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
        """Load model from pretrained path (transformers compatibility)."""
        config = MoLAConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
        config._name_or_path = pretrained_model_name_or_path
        return cls(config)

    def save_pretrained(self, save_directory, **kwargs):
        """Save model using standard transformers approach."""
        max_shard_size = kwargs.get('max_shard_size', "5GB")
        safe_serialization = kwargs.get('safe_serialization', True)

        os.makedirs(save_directory, exist_ok=True)

        self.config.save_pretrained(save_directory)

        if hasattr(self, 'tokenizer'):
            self.tokenizer.save_pretrained(save_directory)

        try:
            self.mola_model.save_pretrained(
                save_directory,
                max_shard_size=max_shard_size,
                safe_serialization=safe_serialization
            )
        except Exception as e:
            print(f"Warning: Could not save base model weights: {e}")

        try:
            if hasattr(self, 'router_decoder'):
                router_state_dict = self.router_decoder.state_dict()
                torch.save(router_state_dict, os.path.join(save_directory, "router_weights.pth"))
        except Exception as e:
            print(f"Warning: Could not save router weights: {e}")

        print(f"Model saved to {save_directory}")

    def get_current_lora(self) -> str:
        """Get the currently applied LoRA adapter name."""
        return self._current_lora or "base_model"

def _load_mola_model(model_path, **kwargs):
    """Helper function to load MoLA model."""
    return MoLAForCausalLM.from_pretrained(model_path, **kwargs)


# Register MoLA-LM so AutoConfig / AutoModelForCausalLM can resolve the "mola_lm" model type.
try:
    CONFIG_MAPPING.register("mola_lm", MoLAConfig)
    MODEL_FOR_CAUSAL_LM_MAPPING.register(MoLAConfig, MoLAForCausalLM)
    print("✅ Successfully registered MoLA-LM with AutoModel!")
except Exception as e:
    print(f"⚠️ AutoModel registration failed: {e}")

    try:
        AutoConfig.register("mola_lm", MoLAConfig)
        AutoModelForCausalLM.register(MoLAConfig, MoLAForCausalLM)
        print("✅ Successfully registered MoLA-LM via AutoConfig/AutoModelForCausalLM!")
    except Exception as e2:
        print(f"⚠️ Fallback registration also failed: {e2}")
        print("Model can still be loaded directly with MoLAForCausalLM.from_pretrained()")
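
# Example usage (a minimal sketch; the path below is a placeholder for a real MoLA-LM
# checkpoint containing config.json, router_weights.pth and a loras/ directory):
#
#   model = MoLAForCausalLM.from_pretrained("path/to/mola-lm-checkpoint")
#   prompt = "Write a Python function that checks whether a string is a palindrome."
#   inputs = model.tokenizer(prompt, return_tensors="pt").to(model.device)
#   output_ids = model.generate(**inputs, max_new_tokens=256)
#   print(model.tokenizer.decode(output_ids[0], skip_special_tokens=True))
#   print("Routed to adapter:", model.get_current_lora())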