""" Simplified MoR (Mixture-of-Recursions) model implementation for Hugging Face Hub. This provides basic inference capabilities while maintaining compatibility with the full MoR framework. """ import torch import torch.nn as nn from typing import Optional, Tuple, Union from transformers import LlamaForCausalLM, LlamaConfig from transformers.modeling_outputs import CausalLMOutputWithPast from transformers.utils import add_start_docstrings_to_model_forward, replace_return_docstrings class MoRConfig(LlamaConfig): """ Configuration class for MoR model. Extends LlamaConfig with MoR-specific parameters. """ def __init__( self, mor_enabled=True, num_recursions=3, routing_strategy="expert_choice", kv_sharing=None, **kwargs ): super().__init__(**kwargs) # MoR-specific configurations self.mor_enabled = mor_enabled self.num_recursions = num_recursions self.routing_strategy = routing_strategy self.kv_sharing = kv_sharing class MoRLlamaForCausalLM(LlamaForCausalLM): """ Simplified MoR model for Hugging Face Hub. This implementation provides basic inference capabilities while maintaining compatibility with the original MoR training framework. For full MoR features including dynamic routing and recursion-wise KV caching, use the complete implementation from the original repository. """ config_class = MoRConfig def __init__(self, config): super().__init__(config) # Store MoR-specific config self.mor_config = config # For simplified inference, we'll use the standard forward pass # Full MoR capabilities require the complete training framework @add_start_docstrings_to_model_forward("Standard forward pass with simplified MoR compatibility") def forward( self, input_ids: torch.LongTensor = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, inputs_embeds: Optional[torch.FloatTensor] = None, labels: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, **kwargs ) -> Union[Tuple, CausalLMOutputWithPast]: """ Forward pass for simplified MoR model. For basic inference, this behaves like a standard LLaMA model. Advanced MoR features require the complete training framework. """ # Use standard LLaMA forward pass for simplified inference return super().forward( input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, labels=labels, use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, **kwargs ) @classmethod def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs): """ Load MoR model from pretrained checkpoint. This method handles loading the model weights while maintaining compatibility with both the simplified and full MoR implementations. """ # Load the model using the parent class method model = super().from_pretrained( pretrained_model_name_or_path, *model_args, **kwargs ) return model def generate_with_mor( self, input_ids: torch.LongTensor, attention_mask: Optional[torch.Tensor] = None, max_length: int = 100, temperature: float = 1.0, do_sample: bool = True, **kwargs ): """ Generate text with MoR-aware settings. This is a convenience method that provides optimized generation settings for MoR models. """ return self.generate( input_ids=input_ids, attention_mask=attention_mask, max_length=max_length, temperature=temperature, do_sample=do_sample, pad_token_id=self.config.eos_token_id, **kwargs ) # Register the model for auto-loading try: from transformers import AutoConfig, AutoModelForCausalLM AutoConfig.register("mor_llama", MoRConfig) AutoModelForCausalLM.register(MoRConfig, MoRLlamaForCausalLM) except: # Registration may fail in some environments, but the model can still be used directly pass