Mahesh2841 committed (verified)
Commit c23327e · 1 Parent(s): e278675

Update custom_modeling.py

Files changed (1)
  custom_modeling.py  +44 -64
custom_modeling.py CHANGED
@@ -1,13 +1,9 @@
  """
- custom_modeling.py
- ------------------
- A single model-agnostic toxicity wrapper for any causal-LM on the Hugging
- Face Hub.
-
- Keep this in the repo **alongside**:
-     • toxic.keras   -- your TF/Keras classifier file
-
- Make sure config.json contains:
      "auto_map": { "AutoModelForCausalLM": "custom_modeling.SafeGenerationModel" }
  """
 
@@ -21,12 +17,16 @@ from huggingface_hub import hf_hub_download
 
 
  # ------------------------------------------------------------------ #
- # 1) MIXIN   all toxicity filtering lives here                       #
  # ------------------------------------------------------------------ #
  class _SafeGenerationMixin:
      _toxicity_model = None
-     _tox_threshold = 0.6
-     _safe_message = "Response is toxic, please be kind to yourself and others."
      _tokenizer = None
 
      # ---- helpers ----------------------------------------------------
@@ -55,18 +55,17 @@ class _SafeGenerationMixin:
      def _is_toxic(self, text: str) -> bool:
          if not text.strip():
              return False
-         inputs = tf.constant([text], dtype=tf.string)   # <= proper tensor
          prob = float(self._tox_model.predict(inputs)[0, 0])
          return prob >= self._tox_threshold
 
-     def _safe_ids(self, length: int | None = None) -> torch.LongTensor:
-         """Return token IDs for the safe message, padded / truncated to *length*."""
          self._ensure_tokenizer()
          if self._tokenizer is None:
              raise RuntimeError("Tokenizer unavailable for safe-message encoding.")
 
-         ids = self._tokenizer(self._safe_message, return_tensors="pt")["input_ids"][0]
-
          if length is not None:
              pad_id = (
                  self.config.eos_token_id
@@ -79,14 +78,13 @@ class _SafeGenerationMixin:
              )
          else:
              ids = ids[:length]
-
          return ids.to(self._device())
 
      # ---- main override ---------------------------------------------
      def generate(self, *args, **kwargs):
          self._ensure_tokenizer()
 
-         # 1) Prompt toxicity check
          prompt_txt = None
          if self._tokenizer is not None:
              if "input_ids" in kwargs:
@@ -99,81 +97,63 @@ class _SafeGenerationMixin:
                  )
 
          if prompt_txt and self._is_toxic(prompt_txt):
-             return self._safe_ids().unsqueeze(0)
 
-         # 2) Normal generation
          outputs = super().generate(*args, **kwargs)
 
-         # 3) Output toxicity check
          if self._tokenizer is None:
              return outputs
 
-         cleaned_seqs = []
          for seq in outputs.detach().cpu():
              txt = self._tokenizer.decode(seq.tolist(), skip_special_tokens=True)
-             cleaned_seqs.append(
-                 self._safe_ids(length=seq.size(0)) if self._is_toxic(txt) else seq
-             )
-
-         return torch.stack(cleaned_seqs, dim=0).to(self._device())
 
 
  # ------------------------------------------------------------------ #
- # 2) Utilities: resolve & cache the real base class                  #
  # ------------------------------------------------------------------ #
  @lru_cache(None)
- def _get_base_cls(arch_name: str):
-     """Return the actual Transformers class for the given architecture string."""
-     if hasattr(transformers, arch_name):
-         return getattr(transformers, arch_name)
-
-     # Fallback: import based on naming convention
-     stem = arch_name.replace("ForCausalLM", "").lower()
      module = importlib.import_module(f"transformers.models.{stem}.modeling_{stem}")
-     return getattr(module, arch_name)
 
 
  @lru_cache(None)
  def _make_safe_subclass(base_cls):
-     """Create (and cache) SafeGeneration_<Base> = (Mixin, Base)."""
-     return type(f"SafeGeneration_{base_cls.__name__}", (_SafeGenerationMixin, base_cls), {})
 
 
  # ------------------------------------------------------------------ #
- # 3) Dispatcher class referenced in auto_map                         #
  # ------------------------------------------------------------------ #
  class SafeGenerationModel:
-     """
-     Lightweight dispatcher so that `AutoModelForCausalLM` can load the
-     wrapped model transparently.
-     """
-
      @classmethod
-     def from_pretrained(cls, model_name_or_path, *model_args, **kwargs):
-         # Ensure custom code execution is allowed
          kwargs.setdefault("trust_remote_code", True)
-
-         # Remove literal "auto" dtype to avoid downstream dtype bugs
          if kwargs.get("torch_dtype") == "auto":
              kwargs.pop("torch_dtype")
 
-         # Load config first to discover architecture
-         config = transformers.AutoConfig.from_pretrained(model_name_or_path, **kwargs)
          if not getattr(config, "architectures", None):
              raise ValueError("`config.architectures` missing in config.json.")
          arch_str = config.architectures[0]
 
-         # Build / fetch dynamic subclass
-         BaseCLS = _get_base_cls(arch_str)
-         SafeCLS = _make_safe_subclass(BaseCLS)
-
-         # Avoid duplicate 'config' key
-         kwargs.pop("config", None)
 
-         # Delegate full loading to the safe subclass
-         return SafeCLS.from_pretrained(
-             model_name_or_path,
-             *model_args,
-             config=config,
-             **kwargs,
-         )
 
  """
+ custom_modeling.py – model-agnostic toxicity wrapper
+ ----------------------------------------------------
+ Place in repo root together with:
+     toxic.keras
+ Add to config.json:
      "auto_map": { "AutoModelForCausalLM": "custom_modeling.SafeGenerationModel" }
  """
 
 
 
  # ------------------------------------------------------------------ #
+ # 1) MIXIN – toxicity filtering logic                                #
  # ------------------------------------------------------------------ #
  class _SafeGenerationMixin:
      _toxicity_model = None
+     _tox_threshold = 0.6
+
+     # Separate messages
+     _safe_in_msg = "Sorry, I can’t help with that request."
+     _safe_out_msg = "I’m sorry, but I can’t continue with that."
+
      _tokenizer = None
 
      # ---- helpers ----------------------------------------------------
 
      def _is_toxic(self, text: str) -> bool:
          if not text.strip():
              return False
+         inputs = tf.constant([text], dtype=tf.string)
          prob = float(self._tox_model.predict(inputs)[0, 0])
          return prob >= self._tox_threshold
 
+     def _safe_ids(self, message: str, length: int | None = None):
+         """Encode *message* and pad/truncate to *length* tokens (if given)."""
          self._ensure_tokenizer()
          if self._tokenizer is None:
              raise RuntimeError("Tokenizer unavailable for safe-message encoding.")
 
+         ids = self._tokenizer(message, return_tensors="pt")["input_ids"][0]
          if length is not None:
              pad_id = (
                  self.config.eos_token_id

              )
          else:
              ids = ids[:length]
          return ids.to(self._device())
 
      # ---- main override ---------------------------------------------
      def generate(self, *args, **kwargs):
          self._ensure_tokenizer()
 
+         # 1) prompt toxicity
          prompt_txt = None
          if self._tokenizer is not None:
              if "input_ids" in kwargs:
 
                  )
 
          if prompt_txt and self._is_toxic(prompt_txt):
+             return self._safe_ids(self._safe_in_msg).unsqueeze(0)
 
+         # 2) normal generation
          outputs = super().generate(*args, **kwargs)
 
+         # 3) output toxicity
          if self._tokenizer is None:
              return outputs
 
+         new_seqs = []
          for seq in outputs.detach().cpu():
              txt = self._tokenizer.decode(seq.tolist(), skip_special_tokens=True)
+             if self._is_toxic(txt):
+                 new_seqs.append(self._safe_ids(self._safe_out_msg, length=seq.size(0)))
+             else:
+                 new_seqs.append(seq)
+         return torch.stack(new_seqs, dim=0).to(self._device())
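
Note: the helpers this mixin relies on (`_ensure_tokenizer`, `_device`, and the `_tox_model` property behind `_is_toxic`) sit in the unchanged lines 33-54 and are not shown in this diff. A minimal sketch of what that block plausibly contains, assuming the `toxic.keras` classifier is pulled from the model repo with `hf_hub_download` and loaded with Keras; apart from the names already used above, everything here is an assumption rather than the committed code:

    # Hypothetical reconstruction of the unchanged helper block (file lines 33-54);
    # assumes the module-level imports already in custom_modeling.py
    # (torch, tensorflow as tf, transformers, hf_hub_download).
    def _device(self):
        # Device of the first model parameter; CPU as a safe fallback.
        try:
            return next(self.parameters()).device
        except StopIteration:
            return torch.device("cpu")

    def _ensure_tokenizer(self):
        # Lazily load the tokenizer shipped in the same repo.
        if self._tokenizer is None:
            try:
                self._tokenizer = transformers.AutoTokenizer.from_pretrained(
                    self.config._name_or_path
                )
            except Exception:
                self._tokenizer = None   # generate() then skips the checks

    @property
    def _tox_model(self):
        # Lazily download and cache the Keras toxicity classifier.
        if self._toxicity_model is None:
            path = hf_hub_download(self.config._name_or_path, "toxic.keras")
            self._toxicity_model = tf.keras.models.load_model(path)
        return self._toxicity_model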
 
 
  # ------------------------------------------------------------------ #
+ # 2) utilities: resolve base class & cache subclass                  #
  # ------------------------------------------------------------------ #
  @lru_cache(None)
+ def _get_base_cls(arch: str):
+     if hasattr(transformers, arch):
+         return getattr(transformers, arch)
+     stem = arch.replace("ForCausalLM", "").lower()
      module = importlib.import_module(f"transformers.models.{stem}.modeling_{stem}")
+     return getattr(module, arch)
 
 
  @lru_cache(None)
  def _make_safe_subclass(base_cls):
+     return type(
+         f"SafeGeneration_{base_cls.__name__}",
+         (_SafeGenerationMixin, base_cls),
+         {},
+     )
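
Because the mixin is listed before the base class, Python's MRO resolves `generate` on the dynamic subclass to the mixin's override, which in turn reaches the stock implementation via `super().generate(...)`. A quick check, using LlamaForCausalLM purely as an example architecture:

    # Example only: any architecture exposed by transformers works the same way.
    Safe = _make_safe_subclass(transformers.LlamaForCausalLM)
    print([c.__name__ for c in Safe.__mro__[:3]])
    # ['SafeGeneration_LlamaForCausalLM', '_SafeGenerationMixin', 'LlamaForCausalLM']
    print(Safe.generate is _SafeGenerationMixin.generate)   # True: the override wins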
 
 
  # ------------------------------------------------------------------ #
+ # 3) Dispatcher class referenced by auto_map                         #
  # ------------------------------------------------------------------ #
  class SafeGenerationModel:
      @classmethod
+     def from_pretrained(cls, repo_id, *model_args, **kwargs):
          kwargs.setdefault("trust_remote_code", True)
          if kwargs.get("torch_dtype") == "auto":
              kwargs.pop("torch_dtype")
 
+         config = transformers.AutoConfig.from_pretrained(repo_id, **kwargs)
          if not getattr(config, "architectures", None):
              raise ValueError("`config.architectures` missing in config.json.")
          arch_str = config.architectures[0]
 
+         Base = _get_base_cls(arch_str)
+         Safe = _make_safe_subclass(Base)
 
+         kwargs.pop("config", None)   # avoid duplicate
+         return Safe.from_pretrained(repo_id, *model_args, config=config, **kwargs)
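
End-to-end, the new behaviour surfaces to callers as sketched below (repo id and prompt are placeholders): a flagged prompt short-circuits to the encoded `_safe_in_msg`, and any flagged completion is swapped for `_safe_out_msg` padded or truncated to the original sequence length.

    # Illustrative usage of the updated wrapper; repo id and prompt are placeholders.
    from transformers import AutoModelForCausalLM, AutoTokenizer

    repo_id = "Mahesh2841/your-model"
    tok = AutoTokenizer.from_pretrained(repo_id)
    model = AutoModelForCausalLM.from_pretrained(repo_id, trust_remote_code=True)

    inputs = tok("Write a short, friendly note to a new colleague.", return_tensors="pt")
    out = model.generate(**inputs, max_new_tokens=64)

    # A toxic prompt would instead return the encoded _safe_in_msg; a toxic
    # completion is replaced per sequence by _safe_out_msg of matching length.
    print(tok.decode(out[0], skip_special_tokens=True))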