hirlimann committed on
Commit ffbcc72 · verified · 1 Parent(s): b826cf4

Upload featurizers.py

Files changed (1)
  1. featurizers.py +495 -0
featurizers.py ADDED
@@ -0,0 +1,495 @@
"""
featurizers.py
==============
Utility classes for defining *invertible* feature spaces on top of a model’s
hidden-state tensors, together with intervention helpers that operate inside
those spaces.

Key ideas
---------

* **Featurizer** – a lightweight wrapper holding:
  • a forward `featurizer` module that maps a tensor **x → (f, error)**,
    where *error* is the reconstruction residual (useful for lossy
    featurizers such as sparse autoencoders);
  • an `inverse_featurizer` that re-assembles the original space,
    **(f, error) → x̂**.

* **Interventions** – three higher-order factory functions that build PyVENE
  interventions operating in the featurized space:
  - *interchange*
  - *collect*
  - *mask* (differential binary masking)

All public classes / functions below carry PEP-257-style docstrings.
"""

from typing import Optional, Tuple

import torch
import pyvene as pv


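# Minimal usage sketch (illustrative only; relies solely on classes defined
# below):
#
#     feat = Featurizer()                        # identity featurizer
#     f, err = feat.featurize(torch.randn(2, 8))
#     x_hat = feat.inverse_featurize(f, err)     # recovers x exactly here
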
# --------------------------------------------------------------------------- #
#                         Basic identity featurizers                           #
# --------------------------------------------------------------------------- #
class IdentityFeaturizerModule(torch.nn.Module):
    """A no-op featurizer: *x → (x, None)*."""

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, None]:
        return x, None


class IdentityInverseFeaturizerModule(torch.nn.Module):
    """Inverse of :class:`IdentityFeaturizerModule`."""

    def forward(self, x: torch.Tensor, error: None) -> torch.Tensor:  # noqa: D401
        return x


# --------------------------------------------------------------------------- #
#                        High-level Featurizer wrapper                         #
# --------------------------------------------------------------------------- #
class Featurizer:
    """Container object holding paired featurizer and inverse modules.

    Parameters
    ----------
    featurizer :
        A `torch.nn.Module` mapping **x → (features, error)**.
    inverse_featurizer :
        A `torch.nn.Module` mapping **(features, error) → x̂**.
    n_features :
        Dimensionality of the feature space. **Required** when you intend to
        build a *mask* intervention; optional otherwise.
    id :
        Human-readable identifier used by the `__str__` methods of the
        generated interventions.
    """

    # --------------------------------------------------------------------- #
    #                    Construction / public accessors                     #
    # --------------------------------------------------------------------- #
    def __init__(
        self,
        featurizer: torch.nn.Module = IdentityFeaturizerModule(),
        inverse_featurizer: torch.nn.Module = IdentityInverseFeaturizerModule(),
        *,
        n_features: Optional[int] = None,
        id: str = "null",
    ):
        self.featurizer = featurizer
        self.inverse_featurizer = inverse_featurizer
        self.n_features = n_features
        self.id = id

    # -------------------- Intervention builders -------------------------- #
    def get_interchange_intervention(self):
        if not hasattr(self, "_interchange_intervention"):
            self._interchange_intervention = build_feature_interchange_intervention(
                self.featurizer, self.inverse_featurizer, self.id
            )
        return self._interchange_intervention

    def get_collect_intervention(self):
        if not hasattr(self, "_collect_intervention"):
            self._collect_intervention = build_feature_collect_intervention(
                self.featurizer, self.id
            )
        return self._collect_intervention

    def get_mask_intervention(self):
        if self.n_features is None:
            raise ValueError(
                "`n_features` must be provided on the Featurizer "
                "to construct a mask intervention."
            )
        if not hasattr(self, "_mask_intervention"):
            self._mask_intervention = build_feature_mask_intervention(
                self.featurizer,
                self.inverse_featurizer,
                self.n_features,
                self.id,
            )
        return self._mask_intervention

    # ------------------------- Convenience I/O --------------------------- #
    def featurize(self, x: torch.Tensor):
        return self.featurizer(x)

    def inverse_featurize(self, x: torch.Tensor, error):
        return self.inverse_featurizer(x, error)

    # --------------------------------------------------------------------- #
    #                      (De)serialisation helpers                         #
    # --------------------------------------------------------------------- #
    def save_modules(self, path: str) -> Tuple[Optional[str], Optional[str]]:
        """Serialise featurizer & inverse to `<path>_{featurizer, inverse}`.

        Notes
        -----
        * **SAE featurizers** are *not* serialised here: they are meant to be
          re-loaded from `sae_lens`, so ``(None, None)`` is returned instead.
        * Existing files will be *silently overwritten*.
        """
        featurizer_class = self.featurizer.__class__.__name__

        if featurizer_class == "SAEFeaturizerModule":
            # SAE featurizers are to be re-loaded from sae_lens, not saved here.
            return None, None

        inverse_featurizer_class = self.inverse_featurizer.__class__.__name__

        # Extra config needed for Subspace featurizers
        additional_config = {}
        if featurizer_class == "SubspaceFeaturizerModule":
            additional_config["rotation_matrix"] = (
                self.featurizer.rotate.weight.detach().clone()
            )
            additional_config["requires_grad"] = (
                self.featurizer.rotate.weight.requires_grad
            )

        model_info = {
            "featurizer_class": featurizer_class,
            "inverse_featurizer_class": inverse_featurizer_class,
            "featurizer_id": self.id,  # consumed by `load_modules`
            "n_features": self.n_features,
            "additional_config": additional_config,
        }

        torch.save(
            {"model_info": model_info, "state_dict": self.featurizer.state_dict()},
            f"{path}_featurizer",
        )
        torch.save(
            {
                "model_info": model_info,
                "state_dict": self.inverse_featurizer.state_dict(),
            },
            f"{path}_inverse_featurizer",
        )
        return f"{path}_featurizer", f"{path}_inverse_featurizer"

    @classmethod
    def load_modules(cls, path: str) -> "Featurizer":
        """Inverse of :meth:`save_modules`.

        Returns
        -------
        Featurizer
            A *new* instance with reconstructed modules and metadata.
        """
        featurizer_data = torch.load(f"{path}_featurizer")
        inverse_data = torch.load(f"{path}_inverse_featurizer")

        model_info = featurizer_data["model_info"]
        featurizer_class = model_info["featurizer_class"]

        if featurizer_class == "SubspaceFeaturizerModule":
            rot = model_info["additional_config"]["rotation_matrix"]
            requires_grad = model_info["additional_config"]["requires_grad"]

            # Re-build a parametrised orthogonal layer with identical shape.
            in_dim, out_dim = rot.shape
            rotate_layer = pv.models.layers.LowRankRotateLayer(
                in_dim, out_dim, init_orth=False
            )
            rotate_layer.weight.data.copy_(rot)
            rotate_layer = torch.nn.utils.parametrizations.orthogonal(rotate_layer)
            rotate_layer.requires_grad_(requires_grad)

            featurizer = SubspaceFeaturizerModule(rotate_layer)
            inverse = SubspaceInverseFeaturizerModule(rotate_layer)

            # Sanity-check weight shape
            assert (
                featurizer.rotate.weight.shape == rot.shape
            ), "Rotation-matrix shape mismatch after deserialisation."
        elif featurizer_class == "IdentityFeaturizerModule":
            featurizer = IdentityFeaturizerModule()
            inverse = IdentityInverseFeaturizerModule()
        else:
            raise ValueError(f"Unknown featurizer class '{featurizer_class}'.")

        featurizer.load_state_dict(featurizer_data["state_dict"])
        inverse.load_state_dict(inverse_data["state_dict"])

        return cls(
            featurizer,
            inverse,
            n_features=model_info["n_features"],
            id=model_info.get("featurizer_id", "loaded"),
        )


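# Round-trip sketch (the path is illustrative):
#
#     feat = SubspaceFeaturizer(shape=(768, 16))
#     feat.save_modules("/tmp/subspace")       # writes /tmp/subspace_featurizer, ...
#     feat2 = Featurizer.load_modules("/tmp/subspace")
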
# --------------------------------------------------------------------------- #
#                        Intervention factory helpers                          #
# --------------------------------------------------------------------------- #
def build_feature_interchange_intervention(
    featurizer: torch.nn.Module,
    inverse_featurizer: torch.nn.Module,
    featurizer_id: str,
):
    """Return a class implementing PyVENE’s TrainableIntervention."""

    class FeatureInterchangeIntervention(
        pv.TrainableIntervention, pv.DistributedRepresentationIntervention
    ):
        """Swap features between *base* and *source* in the featurized space."""

        def __init__(self, **kwargs):
            super().__init__(**kwargs)
            self._featurizer = featurizer
            self._inverse = inverse_featurizer

        def forward(self, base, source, subspaces=None):
            f_base, base_err = self._featurizer(base)
            f_src, _ = self._featurizer(source)

            if subspaces is None or _subspace_is_all_none(subspaces):
                f_out = f_src
            else:
                f_out = pv.models.intervention_utils._do_intervention_by_swap(
                    f_base,
                    f_src,
                    "interchange",
                    self.interchange_dim,
                    subspaces,
                    subspace_partition=self.subspace_partition,
                    use_fast=self.use_fast,
                )
            return self._inverse(f_out, base_err).to(base.dtype)

        def __str__(self):  # noqa: D401
            return f"FeatureInterchangeIntervention(id={featurizer_id})"

    return FeatureInterchangeIntervention
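

# The returned class is handed to pyvene rather than instantiated directly.
# A rough sketch (the `representations` entries are placeholders, not part of
# this module):
#
#     cls = Featurizer().get_interchange_intervention()
#     config = pv.IntervenableConfig(
#         representations=[...],        # layer / component specs go here
#         intervention_types=cls,
#     )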


def build_feature_collect_intervention(
    featurizer: torch.nn.Module, featurizer_id: str
):
    """Return a `CollectIntervention` operating in feature space."""

    class FeatureCollectIntervention(pv.CollectIntervention):
        def __init__(self, **kwargs):
            super().__init__(**kwargs)
            self._featurizer = featurizer

        def forward(self, base, source=None, subspaces=None):
            f_base, _ = self._featurizer(base)
            return pv.models.intervention_utils._do_intervention_by_swap(
                f_base,
                source,
                "collect",
                self.interchange_dim,
                subspaces,
                subspace_partition=self.subspace_partition,
                use_fast=self.use_fast,
            )

        def __str__(self):  # noqa: D401
            return f"FeatureCollectIntervention(id={featurizer_id})"

    return FeatureCollectIntervention
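

# Usage mirrors the interchange builder (sketch):
#
#     collect_cls = Featurizer().get_collect_intervention()
#     # pass `collect_cls` as the intervention type in a pyvene config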


def build_feature_mask_intervention(
    featurizer: torch.nn.Module,
    inverse_featurizer: torch.nn.Module,
    n_features: int,
    featurizer_id: str,
    bottleneck_ratio: int = 2,  # optional argument for controlling capacity
):
    """Return a trainable mask intervention with optional training-time nonlinearity."""

    class FeatureMaskIntervention(pv.TrainableIntervention):
        """Differential binary masking in the featurized space, with an MLP
        nonlinearity applied during training."""

        def __init__(self, **kwargs):
            super().__init__(**kwargs)
            self._featurizer = featurizer
            self._inverse = inverse_featurizer

            # Learnable parameters
            self.mask = torch.nn.Parameter(torch.zeros(n_features), requires_grad=True)
            self.temperature: Optional[torch.Tensor] = None  # must be set by user

            # === Optional non-linear transform ===
            bottleneck_dim = bottleneck_ratio * n_features
            self._nonlinear_up = torch.nn.Linear(n_features, bottleneck_dim)
            self._nonlinear_act = torch.nn.GELU()
            self._nonlinear_down = torch.nn.Linear(bottleneck_dim, n_features)
            self._post_act = torch.nn.Tanh()

        # -------------------- API helpers -------------------- #
        def get_temperature(self) -> torch.Tensor:
            if self.temperature is None:
                raise ValueError("Temperature has not been set.")
            return self.temperature

        def set_temperature(self, temp: float | torch.Tensor):
            self.temperature = (
                torch.as_tensor(temp, dtype=self.mask.dtype).to(self.mask.device)
            )

        def _nonlinear_transform(self, f: torch.Tensor) -> torch.Tensor:
            f = self._nonlinear_up(f)
            f = self._nonlinear_act(f)
            f = self._nonlinear_down(f)
            f = self._post_act(f)
            return f

        # ------------------------- forward ------------------- #
        def forward(self, base, source, subspaces=None):
            if self.temperature is None:
                raise ValueError("Cannot run forward without a temperature.")

            f_base, base_err = self._featurizer(base)
            f_src, _ = self._featurizer(source)

            # Align devices / dtypes
            mask = self.mask.to(f_base.device)
            temp = self.temperature.to(f_base.device)

            f_base = f_base.to(mask.dtype)
            f_src = f_src.to(mask.dtype)

            if self.training:
                gate = torch.sigmoid(mask / temp)
            else:
                gate = (torch.sigmoid(mask) > 0.5).float()

            f_out = (1.0 - gate) * f_base + gate * f_src
            if self.training:
                f_out = self._nonlinear_transform(f_out)
            return self._inverse(f_out.to(base.dtype), base_err).to(base.dtype)

        # ---------------- Sparsity regulariser --------------- #
        def get_sparsity_loss(self) -> torch.Tensor:
            if self.temperature is None:
                raise ValueError("Temperature has not been set.")
            gate = torch.sigmoid(self.mask / self.temperature)
            return torch.norm(gate, p=1)

        def __str__(self):  # noqa: D401
            return f"FeatureMaskIntervention(id={featurizer_id})"

    return FeatureMaskIntervention
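

# Training-loop sketch (the schedule and names like `total_steps` and
# `lambda_sparse` are hypothetical):
#
#     intervention.set_temperature(max(1e-2, 1.0 - step / total_steps))
#     ...
#     loss = task_loss + lambda_sparse * intervention.get_sparsity_loss()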


# --------------------------------------------------------------------------- #
#                     Concrete featurizer implementations                      #
# --------------------------------------------------------------------------- #
class SubspaceFeaturizerModule(torch.nn.Module):
    """Linear projector onto an orthogonal *rotation* sub-space."""

    def __init__(self, rotate_layer: pv.models.layers.LowRankRotateLayer):
        super().__init__()
        self.rotate = rotate_layer

    def forward(self, x: torch.Tensor):
        r = self.rotate.weight.T  # (n_features, embed_dim)
        f = x.to(r.dtype) @ r.T   # project onto the sub-space
        error = x - (f @ r).to(x.dtype)  # reconstruction residual
        return f, error


class SubspaceInverseFeaturizerModule(torch.nn.Module):
    """Inverse of :class:`SubspaceFeaturizerModule`."""

    def __init__(self, rotate_layer: pv.models.layers.LowRankRotateLayer):
        super().__init__()
        self.rotate = rotate_layer

    def forward(self, f, error):
        r = self.rotate.weight.T
        return (f.to(r.dtype) @ r).to(f.dtype) + error.to(f.dtype)


class SubspaceFeaturizer(Featurizer):
    """Orthogonal linear sub-space featurizer."""

    def __init__(
        self,
        *,
        shape: Tuple[int, int] | None = None,
        rotation_subspace: torch.Tensor | None = None,
        trainable: bool = True,
        id: str = "subspace",
    ):
        assert (
            shape is not None or rotation_subspace is not None
        ), "Provide either `shape` or `rotation_subspace`."

        if shape is not None:
            rotate = pv.models.layers.LowRankRotateLayer(*shape, init_orth=True)
        else:
            shape = rotation_subspace.shape
            rotate = pv.models.layers.LowRankRotateLayer(*shape, init_orth=False)
            rotate.weight.data.copy_(rotation_subspace)

        rotate = torch.nn.utils.parametrizations.orthogonal(rotate)
        rotate.requires_grad_(trainable)

        super().__init__(
            SubspaceFeaturizerModule(rotate),
            SubspaceInverseFeaturizerModule(rotate),
            n_features=rotate.weight.shape[1],
            id=id,
        )


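# Construction sketch (dimensions are illustrative): a trainable 16-dim
# subspace over a 768-dim hidden state.
#
#     feat = SubspaceFeaturizer(shape=(768, 16), id="residual_subspace")
#     f, err = feat.featurize(torch.randn(4, 768))   # f has shape (4, 16)

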
class SAEFeaturizerModule(torch.nn.Module):
    """Wrapper around a *sparse autoencoder*’s encode() / decode() pair."""

    def __init__(self, sae):
        super().__init__()
        self.sae = sae

    def forward(self, x):
        features = self.sae.encode(x.to(self.sae.dtype))
        error = x - self.sae.decode(features).to(x.dtype)
        return features.to(x.dtype), error


class SAEInverseFeaturizerModule(torch.nn.Module):
    """Inverse for :class:`SAEFeaturizerModule`."""

    def __init__(self, sae):
        super().__init__()
        self.sae = sae

    def forward(self, features, error):
        return (
            self.sae.decode(features.to(self.sae.dtype)).to(features.dtype)
            + error.to(features.dtype)
        )


class SAEFeaturizer(Featurizer):
    """Featurizer backed by a pre-trained sparse autoencoder.

    Notes
    -----
    Serialisation is *disabled* for SAE featurizers – :meth:`Featurizer.save_modules`
    returns ``(None, None)``; re-load the SAE itself via `sae_lens`.
    """

    def __init__(self, sae, *, trainable: bool = False):
        sae.requires_grad_(trainable)
        super().__init__(
            SAEFeaturizerModule(sae),
            SAEInverseFeaturizerModule(sae),
            n_features=sae.cfg.to_dict()["d_sae"],
            id="sae",
        )


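# Loading sketch (assumes the `sae_lens` package; the release / id strings are
# placeholders):
#
#     from sae_lens import SAE
#     sae, _, _ = SAE.from_pretrained(release="...", sae_id="...")
#     feat = SAEFeaturizer(sae, trainable=False)
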
# --------------------------------------------------------------------------- #
#                              Utility helpers                                 #
# --------------------------------------------------------------------------- #
def _subspace_is_all_none(subspaces) -> bool:
    """Return ``True`` if *every* element of *subspaces* is ``None``."""
    return subspaces is None or all(
        inner is None or all(elem is None for elem in inner) for inner in subspaces
    )