Upload folder using huggingface_hub

- README.md +9 -9
- config.json +1 -1
- modeling_wisent_qwen.py +70 -58
- vectors/mbpp_plus/steering_vector.safetensors +3 -0
README.md CHANGED

@@ -76,8 +76,8 @@ pip install -r requirements.txt
 from transformers import AutoModelForCausalLM, AutoTokenizer
 
 # Load model - CAA steering is automatically applied!
-model = AutoModelForCausalLM.from_pretrained("./
-tokenizer = AutoTokenizer.from_pretrained("./
+model = AutoModelForCausalLM.from_pretrained("./huggingface_qwen25-7b-coder-caa", trust_remote_code=True)
+tokenizer = AutoTokenizer.from_pretrained("./huggingface_qwen25-7b-coder-caa")
 
 # Generate code
 prompt = "Write a Python function to calculate the factorial of a number"

@@ -123,7 +123,7 @@ The model uses a trait-based organization for steering vectors:
 
 ```
 vectors/
-├── 
+├── mbpp_plus/       # Current: Optimized for MBPP Plus benchmark
 ├── safety/          # Future: Safety-aligned behavior
 ├── creativity/      # Future: Enhanced creative outputs
 ├── helpfulness/     # Future: Improved helpfulness

@@ -147,7 +147,7 @@ To switch traits, simply update the configuration:
 - **Steering Strength (α)**: 0.9
 - **Vector Format**: Safetensors format for efficient loading and HuggingFace compatibility
 - **Vector Dimension**: 3584 (pre-normalized during training)
-- **Storage Path**: `./vectors/
+- **Storage Path**: `./vectors/mbpp_plus/steering_vector.safetensors`
 
 ### How It Works
 

@@ -172,21 +172,21 @@ The CAA parameters were optimized using:
 WisentQwen2ForCausalLM
 ├── Base: Qwen2.5-Coder-7B-Instruct
 ├── CAA Integration: Layer 24
-├── Steering Vector: ./vectors/
+├── Steering Vector: ./vectors/mbpp_plus/steering_vector.safetensors
 └── Auto-applied during generation
 ```
 
 ## File Structure
 
 ```
-
+huggingface_qwen25-7b-coder-caa/
 ├── config.json              # Model configuration with CAA params
 ├── modeling_wisent_qwen.py  # Custom model class
 ├── tokenizer files          # Standard Qwen tokenizer
 ├── wisent_config.json       # Optimization results
 └── vectors/                 # Trait-based steering vectors
-    └── 
-    └── steering_vector.safetensors  # 
+    └── mbpp_plus/
+        └── steering_vector.safetensors  # MBPP Plus optimized steering vector
 ```
 
 ## Evaluation

@@ -202,7 +202,7 @@ The model should be evaluated on the complete MBPP Plus dataset (378 problems) t
 from transformers import AutoModelForCausalLM
 
 model = AutoModelForCausalLM.from_pretrained(
-    "./
+    "./huggingface_qwen25-7b-coder-caa",
     trust_remote_code=True
 )
 
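The README's "To switch traits, simply update the configuration" step can be sketched concretely. A minimal sketch, assuming a local checkout named `huggingface_qwen25-7b-coder-caa`; the `safety/` path is hypothetical, since only `vectors/mbpp_plus/` ships in this commit:

```python
import json

# Hypothetical trait switch: repoint steering_vector_path, then reload the model.
# vectors/safety/ is one of the README's "Future" placeholders, not yet shipped.
config_path = "./huggingface_qwen25-7b-coder-caa/config.json"
with open(config_path) as f:
    config = json.load(f)

config["steering_vector_path"] = "./vectors/safety/steering_vector.safetensors"

with open(config_path, "w") as f:
    json.dump(config, f, indent=2)
```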
config.json CHANGED

@@ -123,5 +123,5 @@
     "timestamp": "20250818_221712",
     "commit_hash": "a2181df6155f0d5d20170f307b61d10e74d31889"
   },
-  "steering_vector_path": "./vectors/
+  "steering_vector_path": "./vectors/mbpp_plus/steering_vector.safetensors"
 }
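The new `steering_vector_path` is relative to the model folder; `from_pretrained` in modeling_wisent_qwen.py (below) joins it against the load path. A sketch of that resolution plus a shape check against the README's documented 3584-dim vector; the folder name is an assumption:

```python
import os
from safetensors.torch import load_file

model_dir = "./huggingface_qwen25-7b-coder-caa"  # assumed local checkout
rel_path = "./vectors/mbpp_plus/steering_vector.safetensors"  # from config.json

# Mirror the resolution in WisentQwen2ForCausalLM.from_pretrained:
# relative paths are joined against the model directory.
vector_path = rel_path if os.path.isabs(rel_path) else os.path.join(model_dir, rel_path)

steering_data = load_file(vector_path)
vector = steering_data["steering_vector"]  # key used by _load_steering_vector_from_file
assert vector.shape[-1] == 3584  # hidden size documented in the README
```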
modeling_wisent_qwen.py CHANGED

@@ -15,9 +15,9 @@ from transformers.cache_utils import Cache
 
 class WisentQwen2Config(Qwen2Config):
     """Extended Qwen2 configuration with CAA steering parameters."""
-
+
     model_type = "wisent_qwen2"
-
+
     def __init__(
         self,
         caa_enabled: bool = True,

@@ -25,7 +25,7 @@ class WisentQwen2Config(Qwen2Config):
         caa_alpha: float = 0.9,
         steering_vector_path: str = "./vectors/coding/steering_vector.safetensors",
         steering_method: str = "caa",
-        **kwargs
+        **kwargs,
     ):
         super().__init__(**kwargs)
         self.caa_enabled = caa_enabled

@@ -38,33 +38,34 @@ class WisentQwen2Config(Qwen2Config):
 class WisentQwen2ForCausalLM(Qwen2ForCausalLM):
     """
     Qwen2 model with integrated CAA steering for improved code generation.
-
+
     This model automatically applies Contrastive Activation Addition (CAA) steering
     during the forward pass, eliminating the need for manual hook management.
     """
-
+
     config_class = WisentQwen2Config
-
+
     def __init__(self, config: WisentQwen2Config):
         super().__init__(config)
-
+
         # CAA steering parameters
         self.caa_enabled = config.caa_enabled
         self.caa_layer_id = config.caa_layer_id
         self.caa_alpha = config.caa_alpha
         self.steering_method = config.steering_method
-
+
         # Load steering vector from file
         self.steering_vector = None
         if self.caa_enabled:
             self._load_steering_vector_from_file(config.steering_vector_path)
-
+
         # Hook handle for cleanup
         self._steering_hook_handle = None
-
+
     def _load_steering_vector_from_file(self, path: str):
         """Load the CAA steering vector from safetensors or pytorch file."""
         import os
+
         try:
             # Try relative path first
             if os.path.exists(path):

@@ -76,77 +77,82 @@ class WisentQwen2ForCausalLM(Qwen2ForCausalLM):
             print(f"Warning: Steering vector not found at {path}, CAA disabled")
             self.caa_enabled = False
             return
-
+
             # Load based on file extension
-            if vector_path.endswith(
+            if vector_path.endswith(".safetensors"):
                 # Load from safetensors format (preferred)
                 try:
                     from safetensors.torch import load_file
+
                     steering_data = load_file(vector_path)
-                    self.steering_vector = steering_data[
+                    self.steering_vector = steering_data["steering_vector"]
                 except ImportError:
                     print("Warning: safetensors not installed, install with: pip install safetensors")
                     self.caa_enabled = False
                     return
             else:
                 # Load from pytorch format (fallback)
-                steering_data = torch.load(vector_path, map_location=
-
+                steering_data = torch.load(vector_path, map_location="cpu")
+
             # Handle different storage formats
             if isinstance(steering_data, dict):
-                if 
-                    self.steering_vector = steering_data[
-                elif 
-                    self.steering_vector = steering_data[
+                if "vector" in steering_data:
+                    self.steering_vector = steering_data["vector"]
+                elif "steering_vector" in steering_data:
+                    self.steering_vector = steering_data["steering_vector"]
                 else:
                     # Assume the dict values are the vectors
                     self.steering_vector = next(iter(steering_data.values()))
             else:
                 self.steering_vector = steering_data
-
+
             # Ensure it's a tensor
             if not isinstance(self.steering_vector, torch.Tensor):
                 self.steering_vector = torch.tensor(self.steering_vector)
-
-            print(
-
+
+            print(
+                f"✅ Loaded CAA steering vector from {vector_path}: shape {self.steering_vector.shape}, norm {torch.norm(self.steering_vector).item():.4f}"
+            )
+
         except Exception as e:
             print(f"Warning: Failed to load steering vector: {e}, CAA disabled")
             self.caa_enabled = False
             self.steering_vector = None
-
+
     def _apply_caa_steering(self, module, input, output):
         """
         Hook function that applies CAA steering to the specified layer.
-
+
         This follows the implementation from wisent_guard/core/steering_methods/caa.py
         and the patterns from wisent_guard/core/optuna/optuna_pipeline.py
         """
         if not self.caa_enabled or self.steering_vector is None:
             return output
-
+
         # Extract hidden states from output
         if isinstance(output, tuple):
             hidden_states = output[0]
         else:
             hidden_states = output
-
+
         # Apply steering to the last token position (standard CAA behavior)
         # This matches the implementation in optuna_pipeline.py lines 744-746
         if hidden_states.dim() == 3:  # [batch, seq, hidden]
             # Move steering vector to the same device and dtype
             steering_vector = self.steering_vector.to(hidden_states.device, hidden_states.dtype)
-
+
             # Apply steering with configured alpha (strength)
             # Steering is applied to the last token position
-            hidden_states[:, -1:, :] = hidden_states[:, -1:, :] + self.caa_alpha * steering_vector.unsqueeze(
-
+            hidden_states[:, -1:, :] = hidden_states[:, -1:, :] + self.caa_alpha * steering_vector.unsqueeze(
+                0
+            ).unsqueeze(0)
+
         # Return modified output
         if isinstance(output, tuple):
             return (hidden_states,) + output[1:]
         else:
             return hidden_states
-
+
     def forward(
         self,
         input_ids: torch.LongTensor = None,

@@ -163,16 +169,16 @@ class WisentQwen2ForCausalLM(Qwen2ForCausalLM):
     ) -> Union[Tuple, CausalLMOutputWithPast]:
         """
         Forward pass with automatic CAA steering application.
-
+
         The steering is applied via a forward hook on the specified layer,
         following the pattern from optuna_pipeline.py.
         """
-
+
         # Register CAA steering hook if enabled and not already registered
         if self.caa_enabled and self.steering_vector is not None and self._steering_hook_handle is None:
             target_layer = self.model.layers[self.caa_layer_id]
             self._steering_hook_handle = target_layer.register_forward_hook(self._apply_caa_steering)
-
+
         # Call parent forward method
         outputs = super().forward(
             input_ids=input_ids,

@@ -185,22 +191,22 @@ class WisentQwen2ForCausalLM(Qwen2ForCausalLM):
             output_attentions=output_attentions,
             output_hidden_states=output_hidden_states,
             return_dict=return_dict,
-            cache_position=cache_position if hasattr(self, 
+            cache_position=cache_position if hasattr(self, "cache_position") else None,
         )
-
+
         return outputs
-
+
     def generate(self, *args, **kwargs):
         """
         Generate method with automatic CAA steering.
-
+
         The steering hook is registered before generation and cleaned up after.
         """
         # Register hook if needed
         if self.caa_enabled and self.steering_vector is not None and self._steering_hook_handle is None:
             target_layer = self.model.layers[self.caa_layer_id]
             self._steering_hook_handle = target_layer.register_forward_hook(self._apply_caa_steering)
-
+
         try:
             # Call parent generate method
             outputs = super().generate(*args, **kwargs)
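An aside for reviewers of the hunks above: the steering mechanics of `_apply_caa_steering` boil down to adding `alpha * v` to the last token position of a hooked layer's output. A self-contained toy sketch of that pattern (toy sizes and a `nn.Linear` stand-in, not the real model; the real hook additionally unwraps tuple outputs):

```python
import torch
import torch.nn as nn

hidden = 8   # toy hidden size; the real model uses 3584
alpha = 0.9  # steering strength, as in the config
steering_vector = torch.randn(hidden)

layer = nn.Linear(hidden, hidden)  # stand-in for model.model.layers[24]

def apply_caa(module, inputs, output):
    # Shift only the last token position, as _apply_caa_steering does.
    output[:, -1:, :] = output[:, -1:, :] + alpha * steering_vector.unsqueeze(0).unsqueeze(0)
    return output

handle = layer.register_forward_hook(apply_caa)
x = torch.randn(2, 5, hidden)  # [batch, seq, hidden]
y = layer(x)                   # last position of y is steered
handle.remove()                # mirror the cleanup in generate()
```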
@@ -209,65 +215,71 @@ class WisentQwen2ForCausalLM(Qwen2ForCausalLM):
             if self._steering_hook_handle is not None:
                 self._steering_hook_handle.remove()
                 self._steering_hook_handle = None
-
+
         return outputs
-
+
     def set_caa_enabled(self, enabled: bool):
         """Enable or disable CAA steering at runtime."""
         self.caa_enabled = enabled
         if not enabled and self._steering_hook_handle is not None:
             self._steering_hook_handle.remove()
             self._steering_hook_handle = None
-
+
     def set_caa_alpha(self, alpha: float):
         """Adjust CAA steering strength at runtime."""
         self.caa_alpha = alpha
-
+
     @classmethod
     def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
         """
         Load model with automatic CAA configuration.
-
+
         This method ensures the steering vector is loaded from the embedded config.
         If no weights are found locally, it loads from the base Qwen model.
         """
         import os
         from pathlib import Path
-
+
         # Check if we have local weights
         local_path = Path(pretrained_model_name_or_path)
         has_weights = any(
-            (local_path / f).exists()
-            for f in [
+            (local_path / f).exists()
+            for f in [
+                "pytorch_model.bin",
+                "model.safetensors",
+                "pytorch_model.bin.index.json",
+                "model.safetensors.index.json",
+            ]
         )
-
+
         if not has_weights and local_path.exists() and (local_path / "config.json").exists():
             # We have config but no weights - load from base model
             print(f"Loading weights from base model: Qwen/Qwen2.5-Coder-7B-Instruct")
-
+
             # First, load config from local path
             from transformers import AutoConfig
+
             config = AutoConfig.from_pretrained(pretrained_model_name_or_path)
-
+
             # Load model with base weights
             # Remove config from kwargs if it exists to avoid conflict
            kwargs_copy = kwargs.copy()
-            kwargs_copy.pop(
-
+            kwargs_copy.pop("config", None)
+
             model = super().from_pretrained(
                 "Qwen/Qwen2.5-Coder-7B-Instruct",
                 *model_args,
                 config=config,  # Use our custom config
-                **kwargs_copy
+                **kwargs_copy,
             )
-
+
             # Initialize CAA components
             model.caa_enabled = config.caa_enabled
             model.caa_layer_id = config.caa_layer_id
             model.caa_alpha = config.caa_alpha
             model.steering_method = config.steering_method
             model._steering_hook_handle = None
-
+
             # Load steering vector from config
             if model.caa_enabled:
                 vector_path = config.steering_vector_path

@@ -277,14 +289,14 @@ class WisentQwen2ForCausalLM(Qwen2ForCausalLM):
         else:
             # Standard loading path
             model = super().from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
-
+
             # Load steering vector from config if not already loaded
             if model.caa_enabled and model.steering_vector is None:
                 vector_path = model.config.steering_vector_path
                 if not os.path.isabs(vector_path):
                     vector_path = os.path.join(pretrained_model_name_or_path, vector_path)
                 model._load_steering_vector_from_file(vector_path)
-
+
         return model
 
 

@@ -292,4 +304,4 @@ class WisentQwen2ForCausalLM(Qwen2ForCausalLM):
 from transformers import AutoModelForCausalLM, AutoConfig
 
 AutoConfig.register("wisent_qwen2", WisentQwen2Config)
-AutoModelForCausalLM.register(WisentQwen2Config, WisentQwen2ForCausalLM)
+AutoModelForCausalLM.register(WisentQwen2Config, WisentQwen2ForCausalLM)
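Putting the new API together, a hedged usage sketch; the local folder name is an assumption, while `set_caa_alpha` and `set_caa_enabled` come from the diff above:

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

path = "./huggingface_qwen25-7b-coder-caa"  # assumed local checkout
model = AutoModelForCausalLM.from_pretrained(path, trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained(path)

inputs = tokenizer(
    "Write a Python function to calculate the factorial of a number",
    return_tensors="pt",
)
steered = model.generate(**inputs, max_new_tokens=128)  # hook registered, then removed

model.set_caa_alpha(0.5)      # weaken steering at runtime
model.set_caa_enabled(False)  # or disable it entirely
baseline = model.generate(**inputs, max_new_tokens=128)
```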
vectors/mbpp_plus/steering_vector.safetensors ADDED

@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:b2e8bc7bbdbdee38910662c28ca924d5e2432a14a873ada066d4eab4db041235
+size 7256
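The 7,256-byte LFS object is consistent with a single 3584-dim vector stored in a 16-bit dtype (3584 × 2 = 7168 bytes of payload plus the safetensors header). A quick inspection sketch, assuming a local clone with LFS files pulled:

```python
from safetensors import safe_open

path = "vectors/mbpp_plus/steering_vector.safetensors"
with safe_open(path, framework="pt") as f:
    for key in f.keys():
        t = f.get_tensor(key)
        # Expect roughly ("steering_vector", (3584,)) with a 16-bit dtype,
        # if the size arithmetic above holds.
        print(key, tuple(t.shape), t.dtype)
```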