""" custom_modeling.py – model-agnostic toxicity wrapper ---------------------------------------------------- Place in repo root together with: • toxic.keras Add to config.json: "auto_map": { "AutoModelForCausalLM": "custom_modeling.SafeGenerationModel" } """ import importlib from functools import lru_cache import torch import transformers import tensorflow as tf from huggingface_hub import hf_hub_download # ------------------------------------------------------------------ # # 1) MIXIN – toxicity filtering logic # # ------------------------------------------------------------------ # class _SafeGenerationMixin: _toxicity_model = None _tox_threshold = 0.6 # Separate messages _safe_in_msg = "Sorry, I can’t help with that request." _safe_out_msg = "I’m sorry, but I can’t continue with that." _tokenizer = None # ---- helpers ---------------------------------------------------- def _device(self): return next(self.parameters()).device @property def _tox_model(self): if self._toxicity_model is None: path = hf_hub_download( repo_id=self.config.name_or_path, filename="toxic.keras", ) self._toxicity_model = tf.keras.models.load_model(path, compile=False) return self._toxicity_model def _ensure_tokenizer(self): if self._tokenizer is None: try: self._tokenizer = transformers.AutoTokenizer.from_pretrained( self.config.name_or_path, trust_remote_code=True ) except Exception: pass def _is_toxic(self, text: str) -> bool: if not text.strip(): return False inputs = tf.constant([text], dtype=tf.string) prob = float(self._tox_model.predict(inputs)[0, 0]) return prob >= self._tox_threshold def _safe_ids(self, message: str, length: int | None = None): """Encode *message* and pad/truncate to *length* tokens (if given).""" self._ensure_tokenizer() if self._tokenizer is None: raise RuntimeError("Tokenizer unavailable for safe-message encoding.") ids = self._tokenizer(message, return_tensors="pt")["input_ids"][0] if length is not None: pad_id = ( self.config.eos_token_id if self.config.eos_token_id is not None else (self.config.pad_token_id or 0) ) if ids.size(0) < length: ids = torch.cat( [ids, ids.new_full((length - ids.size(0),), pad_id)], dim=0 ) else: ids = ids[:length] return ids.to(self._device()) # ---- main override --------------------------------------------- def generate(self, *args, **kwargs): self._ensure_tokenizer() # 1) prompt toxicity prompt_txt = None if self._tokenizer is not None: if "input_ids" in kwargs: prompt_txt = self._tokenizer.decode( kwargs["input_ids"][0].tolist(), skip_special_tokens=True ) elif args: prompt_txt = self._tokenizer.decode( args[0][0].tolist(), skip_special_tokens=True ) if prompt_txt and self._is_toxic(prompt_txt): return self._safe_ids(self._safe_in_msg).unsqueeze(0) # 2) normal generation outputs = super().generate(*args, **kwargs) # 3) output toxicity if self._tokenizer is None: return outputs new_seqs = [] for seq in outputs.detach().cpu(): txt = self._tokenizer.decode(seq.tolist(), skip_special_tokens=True) if self._is_toxic(txt): new_seqs.append(self._safe_ids(self._safe_out_msg, length=seq.size(0))) else: new_seqs.append(seq) return torch.stack(new_seqs, dim=0).to(self._device()) # ------------------------------------------------------------------ # # 2) utilities: resolve base class & cache subclass # # ------------------------------------------------------------------ # @lru_cache(None) def _get_base_cls(arch: str): if hasattr(transformers, arch): return getattr(transformers, arch) stem = arch.replace("ForCausalLM", "").lower() module = 
importlib.import_module(f"transformers.models.{stem}.modeling_{stem}") return getattr(module, arch) @lru_cache(None) def _make_safe_subclass(base_cls): return type( f"SafeGeneration_{base_cls.__name__}", (_SafeGenerationMixin, base_cls), {}, ) # ------------------------------------------------------------------ # # 3) Dispatcher class – referenced by auto_map # # ------------------------------------------------------------------ # class SafeGenerationModel: @classmethod def from_pretrained(cls, repo_id, *model_args, **kwargs): kwargs.setdefault("trust_remote_code", True) if kwargs.get("torch_dtype") == "auto": kwargs.pop("torch_dtype") config = transformers.AutoConfig.from_pretrained(repo_id, **kwargs) if not getattr(config, "architectures", None): raise ValueError("`config.architectures` missing in config.json.") arch_str = config.architectures[0] Base = _get_base_cls(arch_str) Safe = _make_safe_subclass(Base) kwargs.pop("config", None) # avoid duplicate return Safe.from_pretrained(repo_id, *model_args, config=config, **kwargs)
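

# ------------------------------------------------------------------ #
# 4) usage sketch (illustrative only)                                 #
# ------------------------------------------------------------------ #
# A minimal sketch of loading the wrapped model through the Auto API,
# assuming the hosting repo contains this file, toxic.keras, and the
# "auto_map" entry shown in the module docstring. The repo id below is
# a hypothetical placeholder, not part of the wrapper itself.
if __name__ == "__main__":
    from transformers import AutoModelForCausalLM, AutoTokenizer

    repo = "your-org/your-wrapped-model"  # hypothetical repo id
    tokenizer = AutoTokenizer.from_pretrained(repo, trust_remote_code=True)
    # trust_remote_code=True lets AutoModelForCausalLM dispatch to
    # SafeGenerationModel via the auto_map entry.
    model = AutoModelForCausalLM.from_pretrained(repo, trust_remote_code=True)

    prompt = "Write a short poem about the sea."
    input_ids = tokenizer(prompt, return_tensors="pt")["input_ids"]
    output_ids = model.generate(input_ids=input_ids, max_new_tokens=64)
    print(tokenizer.decode(output_ids[0], skip_special_tokens=True))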