"""
custom_modeling.py – model-agnostic toxicity wrapper
----------------------------------------------------
Place in repo root together with:
• toxic.keras
Add to config.json:
"auto_map": { "AutoModelForCausalLM": "custom_modeling.SafeGenerationModel" }
"""
import importlib
from functools import lru_cache
import torch
import transformers
import tensorflow as tf
from huggingface_hub import hf_hub_download


# ------------------------------------------------------------------ #
# 1) MIXIN – toxicity filtering logic                                 #
# ------------------------------------------------------------------ #
class _SafeGenerationMixin:
    """Mixin that screens prompts and generations with a Keras toxicity classifier."""

    _toxicity_model = None   # lazily loaded tf.keras model (toxic.keras)
    _tox_threshold = 0.6     # probability above which text is treated as toxic
    # Separate refusal messages for toxic prompts vs. toxic generations
    _safe_in_msg = "Sorry, I can’t help with that request."
    _safe_out_msg = "I’m sorry, but I can’t continue with that."
    _tokenizer = None        # lazily resolved tokenizer for encode/decode

    # ---- helpers ----------------------------------------------------
    def _device(self):
        return next(self.parameters()).device

    @property
    def _tox_model(self):
        # Download and cache the Keras toxicity classifier on first use.
        if self._toxicity_model is None:
            path = hf_hub_download(
                repo_id=self.config.name_or_path,
                filename="toxic.keras",
            )
            self._toxicity_model = tf.keras.models.load_model(path, compile=False)
        return self._toxicity_model

    def _ensure_tokenizer(self):
        # Best-effort tokenizer load; generation still works without it,
        # but toxicity screening is skipped.
        if self._tokenizer is None:
            try:
                self._tokenizer = transformers.AutoTokenizer.from_pretrained(
                    self.config.name_or_path, trust_remote_code=True
                )
            except Exception:
                pass

    def _is_toxic(self, text: str) -> bool:
        # The classifier is expected to take raw strings and return a toxicity
        # probability in [0, 1].
        if not text.strip():
            return False
        inputs = tf.constant([text], dtype=tf.string)
        prob = float(self._tox_model.predict(inputs)[0, 0])
        return prob >= self._tox_threshold

    def _safe_ids(self, message: str, length: int | None = None):
        """Encode *message* and pad/truncate to *length* tokens (if given)."""
        self._ensure_tokenizer()
        if self._tokenizer is None:
            raise RuntimeError("Tokenizer unavailable for safe-message encoding.")
        ids = self._tokenizer(message, return_tensors="pt")["input_ids"][0]
        if length is not None:
            # Pad with EOS (falling back to PAD, then 0) so the refusal matches
            # the requested sequence length.
            pad_id = (
                self.config.eos_token_id
                if self.config.eos_token_id is not None
                else (self.config.pad_token_id or 0)
            )
            if isinstance(pad_id, (list, tuple)):  # some configs store a list of EOS ids
                pad_id = pad_id[0]
            if ids.size(0) < length:
                ids = torch.cat(
                    [ids, ids.new_full((length - ids.size(0),), pad_id)], dim=0
                )
            else:
                ids = ids[:length]
        return ids.to(self._device())

    # ---- main override ---------------------------------------------
    def generate(self, *args, **kwargs):
        self._ensure_tokenizer()

        # 1) prompt toxicity: decode the first sequence of the batch and refuse
        #    up front if it trips the classifier.
        prompt_txt = None
        if self._tokenizer is not None:
            if "input_ids" in kwargs:
                prompt_txt = self._tokenizer.decode(
                    kwargs["input_ids"][0].tolist(), skip_special_tokens=True
                )
            elif args:
                prompt_txt = self._tokenizer.decode(
                    args[0][0].tolist(), skip_special_tokens=True
                )
        if prompt_txt and self._is_toxic(prompt_txt):
            return self._safe_ids(self._safe_in_msg).unsqueeze(0)

        # 2) normal generation
        outputs = super().generate(*args, **kwargs)

        # 3) output toxicity: replace any toxic sequence with the refusal message,
        #    padded/truncated to the same length so the batch can be restacked.
        if self._tokenizer is None or not isinstance(outputs, torch.Tensor):
            # Without a tokenizer (or with dict-style generate outputs) the
            # post-generation check is skipped.
            return outputs
        new_seqs = []
        for seq in outputs.detach().cpu():
            txt = self._tokenizer.decode(seq.tolist(), skip_special_tokens=True)
            if self._is_toxic(txt):
                # Keep on CPU so it stacks with the other (CPU) sequences.
                new_seqs.append(self._safe_ids(self._safe_out_msg, length=seq.size(0)).cpu())
            else:
                new_seqs.append(seq)
        return torch.stack(new_seqs, dim=0).to(self._device())


# ------------------------------------------------------------------ #
# 2) utilities: resolve base class & cache subclass                   #
# ------------------------------------------------------------------ #
@lru_cache(None)
def _get_base_cls(arch: str):
    """Resolve an architecture name (e.g. "LlamaForCausalLM") to its class."""
    if hasattr(transformers, arch):
        return getattr(transformers, arch)
    # Fallback: assumes the standard transformers module layout
    # (transformers.models.<stem>.modeling_<stem>).
    stem = arch.replace("ForCausalLM", "").lower()
    module = importlib.import_module(f"transformers.models.{stem}.modeling_{stem}")
    return getattr(module, arch)


@lru_cache(None)
def _make_safe_subclass(base_cls):
    # Build (and cache) a subclass that puts the safety mixin in front of the
    # concrete causal-LM class so its generate() is intercepted.
    return type(
        f"SafeGeneration_{base_cls.__name__}",
        (_SafeGenerationMixin, base_cls),
        {},
    )


# ------------------------------------------------------------------ #
# 3) Dispatcher class – referenced by auto_map                        #
# ------------------------------------------------------------------ #
class SafeGenerationModel:
    """Entry point used by AutoModelForCausalLM via the auto_map in config.json."""

    @classmethod
    def from_pretrained(cls, repo_id, *model_args, **kwargs):
        kwargs.setdefault("trust_remote_code", True)
        if kwargs.get("torch_dtype") == "auto":
            kwargs.pop("torch_dtype")
        config = transformers.AutoConfig.from_pretrained(repo_id, **kwargs)
        if not getattr(config, "architectures", None):
            raise ValueError("`config.architectures` missing in config.json.")
        arch_str = config.architectures[0]
        Base = _get_base_cls(arch_str)
        Safe = _make_safe_subclass(Base)
        kwargs.pop("config", None)  # avoid passing a duplicate config
        return Safe.from_pretrained(repo_id, *model_args, config=config, **kwargs)
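

# Minimal local smoke test, as a sketch: assumes the given hub repo contains
# this file, toxic.keras, and the base model weights. It is not required for
# the auto_map path above and only runs when the module is executed directly.
if __name__ == "__main__":
    import sys

    if len(sys.argv) < 2:
        sys.exit("usage: python custom_modeling.py <hub-repo-id>")
    repo = sys.argv[1]
    tok = transformers.AutoTokenizer.from_pretrained(repo, trust_remote_code=True)
    model = SafeGenerationModel.from_pretrained(repo)
    prompt = tok("Hello there!", return_tensors="pt")
    out = model.generate(**prompt, max_new_tokens=32)
    print(tok.decode(out[0], skip_special_tokens=True))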