"""
Simplified MoR (Mixture-of-Recursions) model implementation for Hugging Face Hub.
This provides basic inference capabilities while maintaining compatibility with the full MoR framework.
"""

import torch
from typing import Optional, Tuple, Union
from transformers import LlamaForCausalLM, LlamaConfig
from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers.utils import add_start_docstrings_to_model_forward

class MoRConfig(LlamaConfig):
    """
    Configuration class for the MoR model.
    Extends LlamaConfig with MoR-specific parameters.
    """

    # A dedicated model type is required for the Auto* registration at the
    # bottom of this file: AutoConfig.register checks that the registered name
    # matches the config's `model_type`.
    model_type = "mor_llama"

    def __init__(
        self,
        mor_enabled=True,
        num_recursions=3,
        routing_strategy="expert_choice",
        kv_sharing=None,
        **kwargs
    ):
        super().__init__(**kwargs)
        
        # MoR-specific configurations
        self.mor_enabled = mor_enabled
        self.num_recursions = num_recursions
        self.routing_strategy = routing_strategy
        self.kv_sharing = kv_sharing

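# Example (illustrative) of building a config by hand. The hyperparameter
# values below are placeholders, and valid non-None settings for `kv_sharing`
# depend on the full MoR framework, so treat this as a sketch:
#
#     config = MoRConfig(
#         num_recursions=4,
#         routing_strategy="expert_choice",
#         kv_sharing=None,
#         hidden_size=2048,
#         num_hidden_layers=16,
#     )
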
class MoRLlamaForCausalLM(LlamaForCausalLM):
    """
    Simplified MoR model for Hugging Face Hub.
    
    This implementation provides basic inference capabilities while maintaining
    compatibility with the original MoR training framework. For full MoR features
    including dynamic routing and recursion-wise KV caching, use the complete
    implementation from the original repository.
    """
    
    config_class = MoRConfig
    
    def __init__(self, config):
        super().__init__(config)
        
        # Store MoR-specific config
        self.mor_config = config
        
        # For simplified inference, we'll use the standard forward pass
        # Full MoR capabilities require the complete training framework
        
    @add_start_docstrings_to_model_forward("Standard forward pass with simplified MoR compatibility")
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **kwargs
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        """
        Forward pass for simplified MoR model.
        
        For basic inference, this behaves like a standard LLaMA model.
        Advanced MoR features require the complete training framework.
        """
        
        # Use standard LLaMA forward pass for simplified inference
        return super().forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            labels=labels,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            **kwargs
        )
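
    # Illustrative call, kept as a comment so nothing executes at import time.
    # Passing the input ids as `labels` yields the standard causal-LM loss:
    #
    #     outputs = model(input_ids=ids, attention_mask=mask, labels=ids)
    #     loss, logits = outputs.loss, outputs.logits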
    
    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
        """
        Load MoR model from pretrained checkpoint.
        
        This method handles loading the model weights while maintaining
        compatibility with both the simplified and full MoR implementations.
        """
        
        # Load the model using the parent class method
        model = super().from_pretrained(
            pretrained_model_name_or_path, 
            *model_args, 
            **kwargs
        )
        
        return model
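
    # Illustrative use; the repository id below is a placeholder rather than a
    # real checkpoint name:
    #
    #     model = MoRLlamaForCausalLM.from_pretrained("<namespace>/<mor-checkpoint>")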
    
    def generate_with_mor(
        self,
        input_ids: torch.LongTensor,
        attention_mask: Optional[torch.Tensor] = None,
        max_length: int = 100,
        temperature: float = 1.0,
        do_sample: bool = True,
        **kwargs
    ):
        """
        Generate text with MoR-aware settings.
        
        This is a convenience method that provides optimized generation
        settings for MoR models.
        """
        
        return self.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            max_length=max_length,
            temperature=temperature,
            do_sample=do_sample,
            pad_token_id=self.config.eos_token_id,
            **kwargs
        )

# Register the model for auto-loading
try:
    from transformers import AutoConfig, AutoModelForCausalLM
    AutoConfig.register("mor_llama", MoRConfig)
    AutoModelForCausalLM.register(MoRConfig, MoRLlamaForCausalLM)
except Exception:
    # Registration may fail in some environments, but the model can still be
    # used directly via MoRConfig and MoRLlamaForCausalLM.
    pass
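
# Minimal smoke test: build a tiny, randomly initialised model and run one
# forward pass plus a short generation. The sizes below are illustrative and do
# not correspond to any released checkpoint; for pretrained weights, call
# MoRLlamaForCausalLM.from_pretrained(...) with the appropriate repository id.
if __name__ == "__main__":
    config = MoRConfig(
        vocab_size=1000,
        hidden_size=64,
        intermediate_size=128,
        num_hidden_layers=2,
        num_attention_heads=4,
        num_key_value_heads=4,
        num_recursions=3,
        routing_strategy="expert_choice",
    )
    model = MoRLlamaForCausalLM(config)
    model.eval()

    dummy_ids = torch.randint(0, config.vocab_size, (1, 8))
    with torch.no_grad():
        outputs = model(input_ids=dummy_ids, labels=dummy_ids)
        generated = model.generate_with_mor(dummy_ids, max_length=16)

    print("logits shape:", tuple(outputs.logits.shape))
    print("loss:", round(outputs.loss.item(), 4))
    print("generated shape:", tuple(generated.shape))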