"""
custom_modeling.py – model-agnostic toxicity wrapper
----------------------------------------------------

Place in repo root together with:

    • toxic.keras

Add to config.json:

    "auto_map": { "AutoModelForCausalLM": "custom_modeling.SafeGenerationModel" }
"""

import importlib
from functools import lru_cache

import torch
import transformers
import tensorflow as tf
from huggingface_hub import hf_hub_download


class _SafeGenerationMixin:
    """Mixin that screens prompts and generations with a Keras toxicity classifier."""

    # Lazily-initialised toxicity classifier and decision threshold.
    _toxicity_model = None
    _tox_threshold = 0.6

    # Canned replies substituted for flagged prompts / generations.
    _safe_in_msg = "Sorry, I can’t help with that request."
    _safe_out_msg = "I’m sorry, but I can’t continue with that."

    # Tokenizer used to decode prompts/outputs and to encode the canned replies.
    _tokenizer = None

    def _device(self):
        return next(self.parameters()).device

    @property
    def _tox_model(self):
        # Download and cache the Keras classifier shipped with the repo on first use.
        if self._toxicity_model is None:
            path = hf_hub_download(
                repo_id=self.config.name_or_path,
                filename="toxic.keras",
            )
            self._toxicity_model = tf.keras.models.load_model(path, compile=False)
        return self._toxicity_model

    def _ensure_tokenizer(self):
        if self._tokenizer is None:
            try:
                self._tokenizer = transformers.AutoTokenizer.from_pretrained(
                    self.config.name_or_path, trust_remote_code=True
                )
            except Exception:
                # Without a tokenizer, screening is silently skipped.
                pass

    def _is_toxic(self, text: str) -> bool:
        if not text.strip():
            return False
        # The Keras model is expected to take a batch of raw strings and
        # return a toxicity probability per item.
        inputs = tf.constant([text], dtype=tf.string)
        prob = float(self._tox_model.predict(inputs)[0, 0])
        return prob >= self._tox_threshold

    def _safe_ids(self, message: str, length: int | None = None):
        """Encode *message* and pad/truncate to *length* tokens (if given)."""
        self._ensure_tokenizer()
        if self._tokenizer is None:
            raise RuntimeError("Tokenizer unavailable for safe-message encoding.")

        ids = self._tokenizer(message, return_tensors="pt")["input_ids"][0]
        if length is not None:
            # Pad with EOS if available, otherwise with the pad token (or 0).
            pad_id = (
                self.config.eos_token_id
                if self.config.eos_token_id is not None
                else (self.config.pad_token_id or 0)
            )
            if ids.size(0) < length:
                ids = torch.cat(
                    [ids, ids.new_full((length - ids.size(0),), pad_id)], dim=0
                )
            else:
                ids = ids[:length]
        return ids.to(self._device())
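
    # generate() is the single interception point: the prompt is screened
    # before the base implementation runs, and each returned sequence is
    # screened again afterwards.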
    def generate(self, *args, **kwargs):
        self._ensure_tokenizer()

        # Decode the prompt (positional or keyword `input_ids`) for screening.
        prompt_txt = None
        if self._tokenizer is not None:
            if "input_ids" in kwargs:
                prompt_txt = self._tokenizer.decode(
                    kwargs["input_ids"][0].tolist(), skip_special_tokens=True
                )
            elif args:
                prompt_txt = self._tokenizer.decode(
                    args[0][0].tolist(), skip_special_tokens=True
                )

        # Refuse up front if the prompt itself is flagged.
        if prompt_txt and self._is_toxic(prompt_txt):
            return self._safe_ids(self._safe_in_msg).unsqueeze(0)

        outputs = super().generate(*args, **kwargs)

        # Completions can only be screened when a tokenizer is available and
        # the result is a plain tensor of token ids (not a ModelOutput).
        if self._tokenizer is None or not torch.is_tensor(outputs):
            return outputs

        # Swap any flagged completion for the canned refusal, padded or
        # truncated to the original sequence length.
        new_seqs = []
        for seq in outputs.detach().cpu():
            txt = self._tokenizer.decode(seq.tolist(), skip_special_tokens=True)
            if self._is_toxic(txt):
                new_seqs.append(self._safe_ids(self._safe_out_msg, length=seq.size(0)))
            else:
                new_seqs.append(seq)
        return torch.stack(new_seqs, dim=0).to(self._device())


@lru_cache(None)
def _get_base_cls(arch: str):
    """Resolve the model class named in ``config.architectures``."""
    # Fast path: classes exported at the transformers top level, e.g. "LlamaForCausalLM".
    if hasattr(transformers, arch):
        return getattr(transformers, arch)
    # Fallback: import transformers.models.<stem>.modeling_<stem> directly.
    stem = arch.replace("ForCausalLM", "").lower()
    module = importlib.import_module(f"transformers.models.{stem}.modeling_{stem}")
    return getattr(module, arch)


@lru_cache(None)
def _make_safe_subclass(base_cls):
    # Build (and cache) a subclass whose MRO puts the mixin's generate()
    # ahead of the base model's implementation.
    return type(
        f"SafeGeneration_{base_cls.__name__}",
        (_SafeGenerationMixin, base_cls),
        {},
    )


class SafeGenerationModel:
    """Entry point referenced by ``auto_map``; dispatches to the wrapped architecture."""

    @classmethod
    def from_pretrained(cls, repo_id, *model_args, **kwargs):
        kwargs.setdefault("trust_remote_code", True)
        # A torch_dtype of "auto" is dropped rather than forwarded.
        if kwargs.get("torch_dtype") == "auto":
            kwargs.pop("torch_dtype")

        config = transformers.AutoConfig.from_pretrained(repo_id, **kwargs)
        if not getattr(config, "architectures", None):
            raise ValueError("`config.architectures` missing in config.json.")
        arch_str = config.architectures[0]

        # Wrap the concrete architecture with the safety mixin and delegate.
        Base = _get_base_cls(arch_str)
        Safe = _make_safe_subclass(Base)

        kwargs.pop("config", None)
        return Safe.from_pretrained(repo_id, *model_args, config=config, **kwargs)
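

if __name__ == "__main__":
    # Quick smoke test (sketch only): "your-org/your-model" is a placeholder
    # repo id, not a real repository; requires access to the Hugging Face Hub.
    model = SafeGenerationModel.from_pretrained("your-org/your-model")
    print(type(model).__name__)  # e.g. "SafeGeneration_LlamaForCausalLM"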