"""
Wisent-enhanced Qwen2 model with integrated CAA (Contrastive Activation Addition) steering.

This model automatically applies CAA steering during generation without requiring manual hooks.
The steering parameters are optimized using Optuna and stored in the model configuration.
"""

from typing import List, Optional, Tuple, Union

import torch
from transformers import Qwen2Config, Qwen2ForCausalLM
from transformers.modeling_outputs import CausalLMOutputWithPast


class WisentQwen2Config(Qwen2Config):
    """Extended Qwen2 configuration with CAA steering parameters."""

    model_type = "wisent_qwen2"

    def __init__(
        self,
        caa_enabled: bool = True,
        caa_layer_id: int = 24,
        caa_alpha: float = 0.9,
        steering_vector_path: str = "./vectors/coding/steering_vector.safetensors",
        steering_method: str = "caa",
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.caa_enabled = caa_enabled
        self.caa_layer_id = caa_layer_id
        self.caa_alpha = caa_alpha
        self.steering_vector_path = steering_vector_path
        self.steering_method = steering_method
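
# For reference, the corresponding entries in a checkpoint's config.json look
# like this (a sketch; the values mirror the defaults above):
#
#     {
#         "model_type": "wisent_qwen2",
#         "caa_enabled": true,
#         "caa_layer_id": 24,
#         "caa_alpha": 0.9,
#         "steering_vector_path": "./vectors/coding/steering_vector.safetensors",
#         "steering_method": "caa"
#     }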


class WisentQwen2ForCausalLM(Qwen2ForCausalLM):
    """
    Qwen2 model with integrated CAA steering for improved code generation.

    This model automatically applies Contrastive Activation Addition (CAA) steering
    during the forward pass, eliminating the need for manual hook management.
    """

    config_class = WisentQwen2Config

    def __init__(self, config: WisentQwen2Config):
        super().__init__(config)

        # CAA steering parameters
        self.caa_enabled = config.caa_enabled
        self.caa_layer_id = config.caa_layer_id
        self.caa_alpha = config.caa_alpha
        self.steering_method = config.steering_method

        # Load steering vector from file
        self.steering_vector = None
        if self.caa_enabled:
            self._load_steering_vector_from_file(config.steering_vector_path)

        # Hook handle for cleanup
        self._steering_hook_handle = None

    def _load_steering_vector_from_file(self, path: str):
        """Load the CAA steering vector from safetensors or pytorch file."""
        import os

        try:
            # Try relative path first
            if os.path.exists(path):
                vector_path = path
            # Try path relative to model directory
            elif os.path.exists(os.path.join(os.path.dirname(__file__), path)):
                vector_path = os.path.join(os.path.dirname(__file__), path)
            else:
                print(f"Warning: Steering vector not found at {path}, CAA disabled")
                self.caa_enabled = False
                return

            # Load based on file extension
            if vector_path.endswith(".safetensors"):
                # Load from safetensors format (preferred)
                try:
                    from safetensors.torch import load_file

                    steering_data = load_file(vector_path)
                    self.steering_vector = steering_data["steering_vector"]
                except ImportError:
                    print("Warning: safetensors not installed, install with: pip install safetensors")
                    self.caa_enabled = False
                    return
            else:
                # Load from pytorch format (fallback)
                steering_data = torch.load(vector_path, map_location="cpu")

                # Handle different storage formats
                if isinstance(steering_data, dict):
                    if "vector" in steering_data:
                        self.steering_vector = steering_data["vector"]
                    elif "steering_vector" in steering_data:
                        self.steering_vector = steering_data["steering_vector"]
                    else:
                        # Assume the dict values are the vectors
                        self.steering_vector = next(iter(steering_data.values()))
                else:
                    self.steering_vector = steering_data

            # Ensure it's a tensor
            if not isinstance(self.steering_vector, torch.Tensor):
                self.steering_vector = torch.tensor(self.steering_vector)

            print(
                f"✅ Loaded CAA steering vector from {vector_path}: shape {self.steering_vector.shape}, norm {torch.norm(self.steering_vector).item():.4f}"
            )

        except Exception as e:
            print(f"Warning: Failed to load steering vector: {e}, CAA disabled")
            self.caa_enabled = False
            self.steering_vector = None
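
    # A compatible steering-vector file can be produced like this (a sketch;
    # the key "steering_vector" is the one the loader above expects):
    #
    #     from safetensors.torch import save_file
    #     save_file({"steering_vector": vector}, "steering_vector.safetensors")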

    def _apply_caa_steering(self, module, input, output):
        """
        Hook function that applies CAA steering to the specified layer.

        This follows the implementation from wisent_guard/core/steering_methods/caa.py
        and the patterns from wisent_guard/core/optuna/optuna_pipeline.py
        """
        if not self.caa_enabled or self.steering_vector is None:
            return output

        # Extract hidden states from output
        if isinstance(output, tuple):
            hidden_states = output[0]
        else:
            hidden_states = output

        # Apply steering to the last token position (standard CAA behavior)
        # This matches the implementation in optuna_pipeline.py lines 744-746
        if hidden_states.dim() == 3:  # [batch, seq, hidden]
            # Move steering vector to the same device and dtype
            steering_vector = self.steering_vector.to(hidden_states.device, hidden_states.dtype)

            # Apply steering with configured alpha (strength)
            # Steering is applied to the last token position
            steered = self.caa_alpha * steering_vector.unsqueeze(0).unsqueeze(0)
            hidden_states[:, -1:, :] = hidden_states[:, -1:, :] + steered

        # Return modified output
        if isinstance(output, tuple):
            return (hidden_states,) + output[1:]
        return hidden_states

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        """
        Forward pass with automatic CAA steering application.

        The steering is applied via a forward hook on the specified layer,
        following the pattern from optuna_pipeline.py.
        """

        # Register CAA steering hook if enabled and not already registered
        if self.caa_enabled and self.steering_vector is not None and self._steering_hook_handle is None:
            target_layer = self.model.layers[self.caa_layer_id]
            self._steering_hook_handle = target_layer.register_forward_hook(self._apply_caa_steering)

        # Call parent forward method
        outputs = super().forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            labels=labels,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            cache_position=cache_position,
        )

        return outputs

    def generate(self, *args, **kwargs):
        """
        Generate method with automatic CAA steering.

        The steering hook is registered before generation and cleaned up after.
        """
        # Register hook if needed
        if self.caa_enabled and self.steering_vector is not None and self._steering_hook_handle is None:
            target_layer = self.model.layers[self.caa_layer_id]
            self._steering_hook_handle = target_layer.register_forward_hook(self._apply_caa_steering)

        try:
            # Call parent generate method
            outputs = super().generate(*args, **kwargs)
        finally:
            # Clean up hook after generation
            if self._steering_hook_handle is not None:
                self._steering_hook_handle.remove()
                self._steering_hook_handle = None

        return outputs

    def set_caa_enabled(self, enabled: bool):
        """Enable or disable CAA steering at runtime."""
        self.caa_enabled = enabled
        if not enabled and self._steering_hook_handle is not None:
            self._steering_hook_handle.remove()
            self._steering_hook_handle = None

    def set_caa_alpha(self, alpha: float):
        """Adjust CAA steering strength at runtime."""
        self.caa_alpha = alpha

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
        """
        Load model with automatic CAA configuration.

        This method ensures the steering vector is loaded from the embedded config.
        If no weights are found locally, it loads from the base Qwen model.
        """
        import os
        from pathlib import Path

        # Check if we have local weights
        local_path = Path(pretrained_model_name_or_path)
        has_weights = any(
            (local_path / f).exists()
            for f in [
                "pytorch_model.bin",
                "model.safetensors",
                "pytorch_model.bin.index.json",
                "model.safetensors.index.json",
            ]
        )

        if not has_weights and local_path.exists() and (local_path / "config.json").exists():
            # We have config but no weights - load from base model
            print("Loading weights from base model: Qwen/Qwen2.5-Coder-7B-Instruct")

            # First, load config from local path
            from transformers import AutoConfig

            config = AutoConfig.from_pretrained(pretrained_model_name_or_path)

            # Load model with base weights
            # Remove config from kwargs if it exists to avoid conflict
            kwargs_copy = kwargs.copy()
            kwargs_copy.pop("config", None)

            model = super().from_pretrained(
                "Qwen/Qwen2.5-Coder-7B-Instruct",
                *model_args,
                config=config,  # Use our custom config
                **kwargs_copy,
            )

            # Initialize CAA components
            model.caa_enabled = config.caa_enabled
            model.caa_layer_id = config.caa_layer_id
            model.caa_alpha = config.caa_alpha
            model.steering_method = config.steering_method
            model._steering_hook_handle = None

            # Load steering vector from config
            if model.caa_enabled:
                vector_path = config.steering_vector_path
                if not os.path.isabs(vector_path):
                    vector_path = os.path.join(pretrained_model_name_or_path, vector_path)
                model._load_steering_vector_from_file(vector_path)
        else:
            # Standard loading path
            model = super().from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)

            # Load steering vector from config if not already loaded
            if model.caa_enabled and model.steering_vector is None:
                vector_path = model.config.steering_vector_path
                if not os.path.isabs(vector_path):
                    vector_path = os.path.join(pretrained_model_name_or_path, vector_path)
                model._load_steering_vector_from_file(vector_path)

        return model


# Register the model
from transformers import AutoConfig, AutoModelForCausalLM

AutoConfig.register("wisent_qwen2", WisentQwen2Config)
AutoModelForCausalLM.register(WisentQwen2Config, WisentQwen2ForCausalLM)
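

if __name__ == "__main__":
    # Minimal smoke test (a sketch): compares steered and unsteered generation.
    # The checkpoint directory "./wisent-qwen2" is an assumption; substitute a
    # real path containing a config.json for this architecture.
    from transformers import AutoTokenizer

    model = WisentQwen2ForCausalLM.from_pretrained("./wisent-qwen2")
    tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-Coder-7B-Instruct")

    inputs = tokenizer("Write a function that checks whether a number is prime.", return_tensors="pt")

    # Steered generation: the forward hook adds caa_alpha * steering_vector to
    # the last token's hidden state at layer caa_layer_id; generate() registers
    # the hook for the duration of the call and removes it afterwards.
    steered_ids = model.generate(**inputs, max_new_tokens=128)
    print(tokenizer.decode(steered_ids[0], skip_special_tokens=True))

    # Unsteered baseline for comparison.
    model.set_caa_enabled(False)
    baseline_ids = model.generate(**inputs, max_new_tokens=128)
    print(tokenizer.decode(baseline_ids[0], skip_special_tokens=True))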