# modeling_wisent_qwen.py
"""
Wisent-enhanced Qwen2 model with integrated CAA (Contrastive Activation Addition) steering.
This model automatically applies CAA steering during generation, without requiring manual hook management.
The steering parameters are optimized using Optuna and stored in the model configuration.
"""
from typing import List, Optional, Tuple, Union
import torch
from transformers import Qwen2Config, Qwen2ForCausalLM
from transformers.modeling_outputs import CausalLMOutputWithPast
class WisentQwen2Config(Qwen2Config):
"""Extended Qwen2 configuration with CAA steering parameters."""
model_type = "wisent_qwen2"
def __init__(
self,
caa_enabled: bool = True,
caa_layer_id: int = 24,
caa_alpha: float = 0.9,
steering_vector_path: str = "./vectors/coding/steering_vector.safetensors",
steering_method: str = "caa",
**kwargs,
):
super().__init__(**kwargs)
self.caa_enabled = caa_enabled
self.caa_layer_id = caa_layer_id
self.caa_alpha = caa_alpha
self.steering_vector_path = steering_vector_path
self.steering_method = steering_method
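
# Example: constructing the config directly (the values shown are just the
# defaults above; any layer index within the model and any alpha are valid):
#
#     config = WisentQwen2Config(caa_layer_id=24, caa_alpha=0.9)
#     assert config.model_type == "wisent_qwen2"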
class WisentQwen2ForCausalLM(Qwen2ForCausalLM):
"""
Qwen2 model with integrated CAA steering for improved code generation.
This model automatically applies Contrastive Activation Addition (CAA) steering
during the forward pass, eliminating the need for manual hook management.
"""
config_class = WisentQwen2Config
def __init__(self, config: WisentQwen2Config):
super().__init__(config)
# CAA steering parameters
self.caa_enabled = config.caa_enabled
self.caa_layer_id = config.caa_layer_id
self.caa_alpha = config.caa_alpha
self.steering_method = config.steering_method
# Load steering vector from file
self.steering_vector = None
if self.caa_enabled:
self._load_steering_vector_from_file(config.steering_vector_path)
# Hook handle for cleanup
self._steering_hook_handle = None
def _load_steering_vector_from_file(self, path: str):
"""Load the CAA steering vector from safetensors or pytorch file."""
import os
try:
# Try relative path first
if os.path.exists(path):
vector_path = path
# Try path relative to model directory
elif os.path.exists(os.path.join(os.path.dirname(__file__), path)):
vector_path = os.path.join(os.path.dirname(__file__), path)
else:
print(f"Warning: Steering vector not found at {path}, CAA disabled")
self.caa_enabled = False
return
# Load based on file extension
if vector_path.endswith(".safetensors"):
# Load from safetensors format (preferred)
try:
from safetensors.torch import load_file
steering_data = load_file(vector_path)
self.steering_vector = steering_data["steering_vector"]
except ImportError:
print("Warning: safetensors not installed, install with: pip install safetensors")
self.caa_enabled = False
return
else:
# Load from pytorch format (fallback)
steering_data = torch.load(vector_path, map_location="cpu")
# Handle different storage formats
if isinstance(steering_data, dict):
if "vector" in steering_data:
self.steering_vector = steering_data["vector"]
elif "steering_vector" in steering_data:
self.steering_vector = steering_data["steering_vector"]
else:
# Assume the dict values are the vectors
self.steering_vector = next(iter(steering_data.values()))
else:
self.steering_vector = steering_data
# Ensure it's a tensor
if not isinstance(self.steering_vector, torch.Tensor):
self.steering_vector = torch.tensor(self.steering_vector)
print(
f"✅ Loaded CAA steering vector from {vector_path}: shape {self.steering_vector.shape}, norm {torch.norm(self.steering_vector).item():.4f}"
)
except Exception as e:
print(f"Warning: Failed to load steering vector: {e}, CAA disabled")
self.caa_enabled = False
self.steering_vector = None
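
    # For reference, a compatible steering-vector file can be produced with
    # safetensors (a sketch; the key "steering_vector" and a 1-D tensor with
    # hidden_size entries -- 3584 for Qwen2.5-Coder-7B -- are what the loader
    # above expects):
    #
    #     from safetensors.torch import save_file
    #     save_file({"steering_vector": vec}, "steering_vector.safetensors")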
def _apply_caa_steering(self, module, input, output):
"""
Hook function that applies CAA steering to the specified layer.
This follows the implementation from wisent_guard/core/steering_methods/caa.py
and the patterns from wisent_guard/core/optuna/optuna_pipeline.py
"""
if not self.caa_enabled or self.steering_vector is None:
return output
# Extract hidden states from output
if isinstance(output, tuple):
hidden_states = output[0]
else:
hidden_states = output
# Apply steering to the last token position (standard CAA behavior)
# This matches the implementation in optuna_pipeline.py lines 744-746
if hidden_states.dim() == 3: # [batch, seq, hidden]
# Move steering vector to the same device and dtype
steering_vector = self.steering_vector.to(hidden_states.device, hidden_states.dtype)
# Apply steering with configured alpha (strength)
# Steering is applied to the last token position
            hidden_states[:, -1:, :] = hidden_states[:, -1:, :] + self.caa_alpha * steering_vector.reshape(1, 1, -1)
# Return modified output
if isinstance(output, tuple):
return (hidden_states,) + output[1:]
return hidden_states
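
    # Shape walkthrough for the addition above (illustrative sizes, assuming a
    # 1-D steering vector with hidden_size 3584):
    #     steering_vector.reshape(1, 1, -1)  ->  [1, 1, 3584]
    #     hidden_states[:, -1:, :]           ->  [batch, 1, 3584]
    # Broadcasting adds caa_alpha * vector to the final token of each sequence;
    # during incremental decoding (seq_len == 1) that is every newly generated
    # token, so steering stays active for the whole completion.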
def forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
) -> Union[Tuple, CausalLMOutputWithPast]:
"""
Forward pass with automatic CAA steering application.
The steering is applied via a forward hook on the specified layer,
following the pattern from optuna_pipeline.py.
"""
        # Register the CAA steering hook if enabled and not already registered.
        # generate() installs the hook once for the whole decoding loop; a
        # direct forward() call installs it here and removes it afterwards so
        # the hook does not leak across unrelated calls.
        registered_here = False
        if self.caa_enabled and self.steering_vector is not None and self._steering_hook_handle is None:
            target_layer = self.model.layers[self.caa_layer_id]
            self._steering_hook_handle = target_layer.register_forward_hook(self._apply_caa_steering)
            registered_here = True
        try:
            # Call parent forward method
            outputs = super().forward(
                input_ids=input_ids,
                attention_mask=attention_mask,
                position_ids=position_ids,
                past_key_values=past_key_values,
                inputs_embeds=inputs_embeds,
                labels=labels,
                use_cache=use_cache,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
                cache_position=cache_position,
            )
        finally:
            # Remove the hook only if this call installed it
            if registered_here and self._steering_hook_handle is not None:
                self._steering_hook_handle.remove()
                self._steering_hook_handle = None
        return outputs
def generate(self, *args, **kwargs):
"""
Generate method with automatic CAA steering.
The steering hook is registered before generation and cleaned up after.
"""
# Register hook if needed
if self.caa_enabled and self.steering_vector is not None and self._steering_hook_handle is None:
target_layer = self.model.layers[self.caa_layer_id]
self._steering_hook_handle = target_layer.register_forward_hook(self._apply_caa_steering)
try:
# Call parent generate method
outputs = super().generate(*args, **kwargs)
finally:
# Clean up hook after generation
if self._steering_hook_handle is not None:
self._steering_hook_handle.remove()
self._steering_hook_handle = None
return outputs
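
    # Example generation call (a sketch; the prompt, tokenizer variable, and
    # sampling settings are arbitrary placeholders):
    #
    #     inputs = tokenizer("def quicksort(arr):", return_tensors="pt").to(model.device)
    #     output_ids = model.generate(**inputs, max_new_tokens=128, do_sample=False)
    #     print(tokenizer.decode(output_ids[0], skip_special_tokens=True))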
def set_caa_enabled(self, enabled: bool):
"""Enable or disable CAA steering at runtime."""
self.caa_enabled = enabled
if not enabled and self._steering_hook_handle is not None:
self._steering_hook_handle.remove()
self._steering_hook_handle = None
def set_caa_alpha(self, alpha: float):
"""Adjust CAA steering strength at runtime."""
self.caa_alpha = alpha
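
    # Runtime control example: compare steered and unsteered output, or sweep
    # the steering strength (the alpha value here is arbitrary; 0.9 is the
    # Optuna-tuned default):
    #
    #     model.set_caa_enabled(False)  # baseline, no steering
    #     model.set_caa_enabled(True)
    #     model.set_caa_alpha(1.2)      # stronger steering than the default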
@classmethod
def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
"""
Load model with automatic CAA configuration.
This method ensures the steering vector is loaded from the embedded config.
If no weights are found locally, it loads from the base Qwen model.
"""
import os
from pathlib import Path
# Check if we have local weights
local_path = Path(pretrained_model_name_or_path)
has_weights = any(
(local_path / f).exists()
for f in [
"pytorch_model.bin",
"model.safetensors",
"pytorch_model.bin.index.json",
"model.safetensors.index.json",
]
)
if not has_weights and local_path.exists() and (local_path / "config.json").exists():
# We have config but no weights - load from base model
print("Loading weights from base model: Qwen/Qwen2.5-Coder-7B-Instruct")
# First, load config from local path
from transformers import AutoConfig
config = AutoConfig.from_pretrained(pretrained_model_name_or_path)
# Load model with base weights
# Remove config from kwargs if it exists to avoid conflict
kwargs_copy = kwargs.copy()
kwargs_copy.pop("config", None)
model = super().from_pretrained(
"Qwen/Qwen2.5-Coder-7B-Instruct",
*model_args,
config=config, # Use our custom config
**kwargs_copy,
)
# Initialize CAA components
model.caa_enabled = config.caa_enabled
model.caa_layer_id = config.caa_layer_id
model.caa_alpha = config.caa_alpha
model.steering_method = config.steering_method
model._steering_hook_handle = None
# Load steering vector from config
if model.caa_enabled:
vector_path = config.steering_vector_path
if not os.path.isabs(vector_path):
vector_path = os.path.join(pretrained_model_name_or_path, vector_path)
model._load_steering_vector_from_file(vector_path)
        else:
            # Standard loading path
            model = super().from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
            # __init__ may have disabled CAA because the relative vector path
            # did not resolve from the working directory; retry with the path
            # resolved against the model directory before giving up.
            if model.config.caa_enabled and model.steering_vector is None:
                model.caa_enabled = True
                vector_path = model.config.steering_vector_path
                if not os.path.isabs(vector_path):
                    vector_path = os.path.join(pretrained_model_name_or_path, vector_path)
                model._load_steering_vector_from_file(vector_path)
return model
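
    # Loading sketch (the paths here are hypothetical): a local directory
    # holding only config.json plus the vectors/ folder takes the base-weight
    # branch above, while a directory with a full checkpoint takes the
    # standard path:
    #
    #     model = WisentQwen2ForCausalLM.from_pretrained("./qwen2.5-coder-7b-wisent-caa")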
# Register the model
from transformers import AutoConfig, AutoModelForCausalLM
AutoConfig.register("wisent_qwen2", WisentQwen2Config)
AutoModelForCausalLM.register(WisentQwen2Config, WisentQwen2ForCausalLM)
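
if __name__ == "__main__":
    # Smoke test (a sketch, not part of the library surface): build a config
    # with steering disabled so no vector file is needed, and confirm the
    # custom model type is wired into the Auto* registries.
    cfg = WisentQwen2Config(caa_enabled=False)
    print(cfg.model_type)  # -> "wisent_qwen2"
    print(AutoConfig.for_model("wisent_qwen2").model_type)  # -> "wisent_qwen2"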