AndrewMayesPrezzee committed on
Commit 5b68b61 · 1 Parent(s): f87cb91

Feat - Meta Data Added

Files changed (3)
  1. README.md +49 -23
  2. configuration_autoencoder.py +25 -3
  3. modeling_autoencoder.py +359 -21
README.md CHANGED
@@ -1,3 +1,17 @@
1
  # Autoencoder Implementation for Hugging Face Transformers
2
 
3
  A complete autoencoder implementation that integrates seamlessly with the Hugging Face Transformers ecosystem, providing all the standard functionality you expect from transformer models.
@@ -11,28 +25,13 @@ A complete autoencoder implementation that integrates seamlessly with the Huggin
11
  - **Multiple Loss Functions**: Support for MSE, BCE, L1, Huber, Smooth L1, KL Divergence, Cosine, Focal, Dice, Tversky, SSIM, and Perceptual loss
12
  - **Multiple Autoencoder Types (7)**: Classic, Variational (VAE), Beta-VAE, Denoising, Sparse, Contractive, and Recurrent autoencoders
13
  - **Extended Activation Functions**: 18+ activation functions including ReLU, GELU, Swish, Mish, ELU, and more
14
- - **Learnable Preprocessing**: Neural Scaler and Normalizing Flow preprocessors (2D and 3D tensors)
15
  - **Extensible Design**: Easy to extend for new autoencoder variants and custom loss functions
16
  - **Production Ready**: Proper serialization, checkpointing, and inference support
17
 
18
- ## 📦 Installation
19
-
20
- ```bash
21
- uv sync # or: pip install -e .
22
- ```
23
-
24
- Dependencies (see pyproject.toml):
25
- - `torch>=2.8.0`
26
- - `transformers>=4.55.2`
27
- - `numpy>=2.3.2`
28
- - `scikit-learn>=1.7.1`
29
- - `datasets>=4.0.0`
30
- - `accelerate>=1.10.0`
31
 
32
  ## 🏗️ Architecture
33
 
34
- Note: This repository has been trimmed to essentials for easy reuse and distribution. Example scripts and tests were removed by request.
35
-
36
  The implementation consists of three main components:
37
 
38
  ### 1. AutoencoderConfig
@@ -72,7 +71,7 @@ config = AutoencoderConfig(
72
  autoencoder_type="classic", # Autoencoder type (7 types available)
73
  # Optional learnable preprocessing
74
  use_learnable_preprocessing=True,
75
- preprocessing_type="neural_scaler", # or "normalizing_flow"
76
  )
77
 
78
  # Create model
@@ -96,10 +95,10 @@ from torch.utils.data import Dataset
96
  class AutoencoderDataset(Dataset):
97
  def __init__(self, data):
98
  self.data = torch.FloatTensor(data)
99
-
100
  def __len__(self):
101
  return len(self.data)
102
-
103
  def __getitem__(self, idx):
104
  return {
105
  "input_values": self.data[idx],
@@ -263,7 +262,11 @@ config = AutoencoderConfig(
263
  ## 📊 Model Outputs
264
 
265
  ### AutoencoderOutput
266
  ```python
 
267
  @dataclass
268
  class AutoencoderOutput(ModelOutput):
269
  last_hidden_state: torch.FloatTensor = None # Latent representation
@@ -346,6 +349,33 @@ outputs = model(input_values=data)
346
  print(f"Preprocessing loss: {outputs.preprocessing_loss}")
347
  ```
348
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
349
  ### Variational Autoencoder Extension
350
 
351
  The configuration supports variational autoencoders:
@@ -376,10 +406,6 @@ trainer = Trainer(
376
  )
377
  ```
378
 
379
- ## 🧪 Testing
380
-
381
- This repository has been trimmed to essential files. Example scripts and test files were removed by request. You can create your own quick checks using the Quick Start snippet above.
382
-
383
  ## 📁 Project Structure
384
 
385
  ```
 
1
+ ---
2
+ # Metadata for Hugging Face repo card
3
+ library_name: transformers
4
+ pipeline_tag: feature-extraction
5
+ license: apache-2.0
6
+ tags:
7
+ - autoencoder
8
+ - pytorch
9
+ - reconstruction
10
+ - preprocessing
11
+ - normalizing-flow
12
+ - scaler
13
+ ---
14
+
15
  # Autoencoder Implementation for Hugging Face Transformers
16
 
17
  A complete autoencoder implementation that integrates seamlessly with the Hugging Face Transformers ecosystem, providing all the standard functionality you expect from transformer models.
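The new front matter is Hub repo-card metadata (library, pipeline tag, license, tags). Because the architecture is custom, loading from the Hub still goes through `trust_remote_code`; a minimal sketch with a hypothetical repo id, assuming `config.json` maps the custom classes via `auto_map` (not shown in this commit):

```python
from transformers import AutoConfig, AutoModel

repo_id = "your-namespace/autoencoder"  # hypothetical repo id
config = AutoConfig.from_pretrained(repo_id, trust_remote_code=True)
model = AutoModel.from_pretrained(repo_id, trust_remote_code=True)
print(type(model).__name__)  # AutoencoderModel, if auto_map points at modeling_autoencoder.py
```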
 
25
  - **Multiple Loss Functions**: Support for MSE, BCE, L1, Huber, Smooth L1, KL Divergence, Cosine, Focal, Dice, Tversky, SSIM, and Perceptual loss
26
  - **Multiple Autoencoder Types (7)**: Classic, Variational (VAE), Beta-VAE, Denoising, Sparse, Contractive, and Recurrent autoencoders
27
  - **Extended Activation Functions**: 18+ activation functions including ReLU, GELU, Swish, Mish, ELU, and more
28
+ - **Learnable Preprocessing**: Neural Scaler, Normalizing Flow, MinMax Scaler (learnable), Robust Scaler (learnable), and Yeo-Johnson preprocessors (2D and 3D tensors)
29
  - **Extensible Design**: Easy to extend for new autoencoder variants and custom loss functions
30
  - **Production Ready**: Proper serialization, checkpointing, and inference support
31
32
 
33
  ## 🏗️ Architecture
34
35
  The implementation consists of three main components:
36
 
37
  ### 1. AutoencoderConfig
 
71
  autoencoder_type="classic", # Autoencoder type (7 types available)
72
  # Optional learnable preprocessing
73
  use_learnable_preprocessing=True,
74
+ preprocessing_type="neural_scaler", # or "normalizing_flow", "minmax_scaler", "robust_scaler", "yeo_johnson"
75
  )
76
 
77
  # Create model
 
95
  class AutoencoderDataset(Dataset):
96
  def __init__(self, data):
97
  self.data = torch.FloatTensor(data)
98
+
99
  def __len__(self):
100
  return len(self.data)
101
+
102
  def __getitem__(self, idx):
103
  return {
104
  "input_values": self.data[idx],
 
262
  ## 📊 Model Outputs
263
 
264
  ### AutoencoderOutput
265
+
266
+ The base model `AutoencoderModel` returns the following output:
268
  ```python
269
+
270
  @dataclass
271
  class AutoencoderOutput(ModelOutput):
272
  last_hidden_state: torch.FloatTensor = None # Latent representation
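Because `AutoencoderOutput` subclasses `ModelOutput`, fields can be read as attributes or by key. A minimal sketch, assuming the local module names used in this repo and leaving the remaining config fields at their defaults:

```python
import torch
from configuration_autoencoder import AutoencoderConfig  # assumed local import paths
from modeling_autoencoder import AutoencoderModel

config = AutoencoderConfig(input_dim=20, latent_dim=10, autoencoder_type="classic")
model = AutoencoderModel(config)

batch = torch.randn(4, 20)
outputs = model(input_values=batch)

# ModelOutput allows attribute access and key access interchangeably
latent = outputs.last_hidden_state        # latent codes, shape (4, 10)
assert torch.equal(latent, outputs["last_hidden_state"])
```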
 
349
  print(f"Preprocessing loss: {outputs.preprocessing_loss}")
350
  ```
351
 
352
+ ```python
353
+ # Learnable MinMax Scaler - scales to [0, 1] with learnable bounds
354
+ config = AutoencoderConfig(
355
+ input_dim=20,
356
+ latent_dim=10,
357
+ use_learnable_preprocessing=True,
358
+ preprocessing_type="minmax_scaler",
359
+ )
360
+
361
+ # Learnable Robust Scaler - robust to outliers using median/IQR
362
+ config = AutoencoderConfig(
363
+ input_dim=20,
364
+ latent_dim=10,
365
+ use_learnable_preprocessing=True,
366
+ preprocessing_type="robust_scaler",
367
+ )
368
+
369
+ # Learnable Yeo-Johnson - power transform for skewed distributions
370
+ config = AutoencoderConfig(
371
+ input_dim=20,
372
+ latent_dim=10,
373
+ use_learnable_preprocessing=True,
374
+ preprocessing_type="yeo_johnson",
375
+ )
376
+ ```
377
+
378
+
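Combined with the quick-start snippet, any of the new preprocessors can be exercised end to end through the model. A minimal sketch (same assumed local imports, other config fields at their defaults) that routes a batch through the learnable Yeo-Johnson preprocessor and reads back the auxiliary loss:

```python
import torch
from configuration_autoencoder import AutoencoderConfig  # assumed local import paths
from modeling_autoencoder import AutoencoderModel

config = AutoencoderConfig(
    input_dim=20,
    latent_dim=10,
    autoencoder_type="classic",
    use_learnable_preprocessing=True,
    preprocessing_type="yeo_johnson",
    learn_inverse_preprocessing=True,
)
model = AutoencoderModel(config)

data = torch.randn(8, 20)
outputs = model(input_values=data)
print(outputs.last_hidden_state.shape)  # latent codes, (8, 10)
print(outputs.preprocessing_loss)       # regularizer contributed by the preprocessor
```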
379
  ### Variational Autoencoder Extension
380
 
381
  The configuration supports variational autoencoders:
 
406
  )
407
  ```
408
409
  ## 📁 Project Structure
410
 
411
  ```
configuration_autoencoder.py CHANGED
@@ -43,7 +43,7 @@ class AutoencoderConfig(PretrainedConfig):
43
  Defaults to 0.5.
44
  use_learnable_preprocessing (bool, optional): Whether to use learnable preprocessing. Defaults to False.
45
  preprocessing_type (str, optional): Type of learnable preprocessing. Options: "none", "neural_scaler",
46
- "normalizing_flow". Defaults to "none".
47
  preprocessing_hidden_dim (int, optional): Hidden dimension for preprocessing networks. Defaults to 64.
48
  preprocessing_num_layers (int, optional): Number of layers in preprocessing networks. Defaults to 2.
49
  learn_inverse_preprocessing (bool, optional): Whether to learn inverse preprocessing for reconstruction.
@@ -147,7 +147,14 @@ class AutoencoderConfig(PretrainedConfig):
147
  raise ValueError(f"`sequence_length` must be positive when specified, got {sequence_length}.")
148
 
149
  # Preprocessing validation
150
- valid_preprocessing = ["none", "neural_scaler", "normalizing_flow"]
151
  if preprocessing_type not in valid_preprocessing:
152
  raise ValueError(
153
  f"`preprocessing_type` must be one of {valid_preprocessing}, got {preprocessing_type}."
@@ -244,7 +251,22 @@ class AutoencoderConfig(PretrainedConfig):
244
  def is_normalizing_flow(self) -> bool:
245
  """Check if using normalizing flow preprocessing."""
246
  return self.preprocessing_type == "normalizing_flow"
247
-
248
  def to_dict(self):
249
  """
250
  Serializes this instance to a Python dictionary.
 
43
  Defaults to 0.5.
44
  use_learnable_preprocessing (bool, optional): Whether to use learnable preprocessing. Defaults to False.
45
  preprocessing_type (str, optional): Type of learnable preprocessing. Options: "none", "neural_scaler",
46
+ "normalizing_flow", "minmax_scaler", "robust_scaler", "yeo_johnson". Defaults to "none".
47
  preprocessing_hidden_dim (int, optional): Hidden dimension for preprocessing networks. Defaults to 64.
48
  preprocessing_num_layers (int, optional): Number of layers in preprocessing networks. Defaults to 2.
49
  learn_inverse_preprocessing (bool, optional): Whether to learn inverse preprocessing for reconstruction.
 
147
  raise ValueError(f"`sequence_length` must be positive when specified, got {sequence_length}.")
148
 
149
  # Preprocessing validation
150
+ valid_preprocessing = [
151
+ "none",
152
+ "neural_scaler",
153
+ "normalizing_flow",
154
+ "minmax_scaler",
155
+ "robust_scaler",
156
+ "yeo_johnson",
157
+ ]
158
  if preprocessing_type not in valid_preprocessing:
159
  raise ValueError(
160
  f"`preprocessing_type` must be one of {valid_preprocessing}, got {preprocessing_type}."
 
251
  def is_normalizing_flow(self) -> bool:
252
  """Check if using normalizing flow preprocessing."""
253
  return self.preprocessing_type == "normalizing_flow"
254
+
255
+ @property
256
+ def is_minmax_scaler(self) -> bool:
257
+ """Check if using learnable MinMax scaler preprocessing."""
258
+ return self.preprocessing_type == "minmax_scaler"
259
+
260
+ @property
261
+ def is_robust_scaler(self) -> bool:
262
+ """Check if using learnable Robust scaler preprocessing."""
263
+ return self.preprocessing_type == "robust_scaler"
264
+
265
+ @property
266
+ def is_yeo_johnson(self) -> bool:
267
+ """Check if using learnable Yeo-Johnson power transform preprocessing."""
268
+ return self.preprocessing_type == "yeo_johnson"
269
+
270
  def to_dict(self):
271
  """
272
  Serializes this instance to a Python dictionary.
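The new helper properties are thin wrappers around `preprocessing_type`, which keeps the dispatch in the modeling code readable. A small check, with the local import path assumed:

```python
from configuration_autoencoder import AutoencoderConfig  # assumed local import path

config = AutoencoderConfig(
    input_dim=20,
    latent_dim=10,
    use_learnable_preprocessing=True,
    preprocessing_type="robust_scaler",
)
assert config.is_robust_scaler
assert not (config.is_minmax_scaler or config.is_yeo_johnson or config.is_normalizing_flow)

# Values outside the whitelist are rejected by the config validation
try:
    AutoencoderConfig(
        input_dim=20,
        latent_dim=10,
        use_learnable_preprocessing=True,
        preprocessing_type="zscore",
    )
except ValueError as err:
    print(err)
```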
modeling_autoencoder.py CHANGED
@@ -143,6 +143,338 @@ class NeuralScaler(nn.Module):
143
  return x, torch.tensor(0.0, device=x.device)
144
 
145
146
  class CouplingLayer(nn.Module):
147
  """Coupling layer for normalizing flows."""
148
 
@@ -306,6 +638,12 @@ class LearnablePreprocessor(nn.Module):
306
  self.preprocessor = NeuralScaler(config)
307
  elif config.is_normalizing_flow:
308
  self.preprocessor = NormalizingFlowPreprocessor(config)
309
  else:
310
  raise ValueError(f"Unknown preprocessing type: {config.preprocessing_type}")
311
 
@@ -399,7 +737,7 @@ class AutoencoderEncoder(nn.Module):
399
  else:
400
  # Standard encoder output
401
  self.fc_out = nn.Linear(input_dim, config.latent_dim)
402
-
403
  def _get_activation(self, activation: str) -> nn.Module:
404
  """Get activation function by name."""
405
  activations = {
@@ -423,7 +761,7 @@ class AutoencoderEncoder(nn.Module):
423
  "threshold": nn.Threshold(threshold=0.1, value=0),
424
  }
425
  return activations[activation]
426
-
427
  def forward(self, x: torch.Tensor) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]:
428
  """Forward pass through encoder."""
429
  # Add noise for denoising autoencoders
@@ -461,37 +799,37 @@ class AutoencoderEncoder(nn.Module):
461
 
462
  class AutoencoderDecoder(nn.Module):
463
  """Decoder part of the autoencoder."""
464
-
465
  def __init__(self, config: AutoencoderConfig):
466
  super().__init__()
467
  self.config = config
468
-
469
  # Build decoder layers (reverse of encoder)
470
  layers = []
471
  input_dim = config.latent_dim
472
  decoder_dims = config.decoder_dims + [config.input_dim]
473
-
474
  for i, hidden_dim in enumerate(decoder_dims):
475
  layers.append(nn.Linear(input_dim, hidden_dim))
476
-
477
  # Don't add batch norm, activation, or dropout to the final layer
478
  if i < len(decoder_dims) - 1:
479
  if config.use_batch_norm:
480
  layers.append(nn.BatchNorm1d(hidden_dim))
481
-
482
  layers.append(self._get_activation(config.activation))
483
-
484
  if config.dropout_rate > 0:
485
  layers.append(nn.Dropout(config.dropout_rate))
486
  else:
487
  # Final layer - add appropriate activation based on reconstruction loss
488
  if config.reconstruction_loss == "bce":
489
  layers.append(nn.Sigmoid())
490
-
491
  input_dim = hidden_dim
492
-
493
  self.decoder = nn.Sequential(*layers)
494
-
495
  def _get_activation(self, activation: str) -> nn.Module:
496
  """Get activation function by name."""
497
  activations = {
@@ -515,7 +853,7 @@ class AutoencoderDecoder(nn.Module):
515
  "threshold": nn.Threshold(threshold=0.1, value=0),
516
  }
517
  return activations[activation]
518
-
519
  def forward(self, x: torch.Tensor) -> torch.Tensor:
520
  """Forward pass through decoder."""
521
  return self.decoder(x)
@@ -753,19 +1091,19 @@ class RecurrentDecoder(nn.Module):
753
  class AutoencoderModel(PreTrainedModel):
754
  """
755
  The bare Autoencoder Model transformer outputting raw hidden-states without any specific head on top.
756
-
757
  This model inherits from PreTrainedModel. Check the superclass documentation for the generic methods the
758
  library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
759
  etc.)
760
-
761
  This model is also a PyTorch torch.nn.Module subclass. Use it as a regular PyTorch Module and refer to the
762
  PyTorch documentation for all matter related to general usage and behavior.
763
  """
764
-
765
  config_class = AutoencoderConfig
766
  base_model_prefix = "autoencoder"
767
  supports_gradient_checkpointing = False
768
-
769
  def __init__(self, config: AutoencoderConfig):
770
  super().__init__(config)
771
  self.config = config
@@ -787,23 +1125,23 @@ class AutoencoderModel(PreTrainedModel):
787
  # Tie weights if specified
788
  if config.tie_weights:
789
  self._tie_weights()
790
-
791
  # Initialize weights
792
  self.post_init()
793
-
794
  def _tie_weights(self):
795
  """Tie encoder and decoder weights (transpose relationship)."""
796
  # This is a simplified weight tying - in practice, you might want more sophisticated tying
797
  pass
798
-
799
  def get_input_embeddings(self):
800
  """Get input embeddings (not applicable for basic autoencoder)."""
801
  return None
802
-
803
  def set_input_embeddings(self, value):
804
  """Set input embeddings (not applicable for basic autoencoder)."""
805
  pass
806
-
807
  def forward(
808
  self,
809
  input_values: torch.Tensor,
 
143
  return x, torch.tensor(0.0, device=x.device)
144
 
145
 
146
+
147
+ class LearnableMinMaxScaler(nn.Module):
148
+ """Learnable MinMax scaler that adapts bounds during training.
149
+
150
+ Scales features to [0, 1] using batch min/range with learnable adjustments and
151
+ a learnable affine transform. Supports 2D (B, F) and 3D (B, T, F) inputs.
152
+ """
153
+
154
+ def __init__(self, config: AutoencoderConfig):
155
+ super().__init__()
156
+ self.config = config
157
+ input_dim = config.input_dim
158
+ hidden_dim = config.preprocessing_hidden_dim
159
+
160
+ # Networks to learn adjustments to batch min and range
161
+ self.min_estimator = nn.Sequential(
162
+ nn.Linear(input_dim, hidden_dim),
163
+ nn.ReLU(),
164
+ nn.Linear(hidden_dim, hidden_dim),
165
+ nn.ReLU(),
166
+ nn.Linear(hidden_dim, input_dim),
167
+ )
168
+ self.range_estimator = nn.Sequential(
169
+ nn.Linear(input_dim, hidden_dim),
170
+ nn.ReLU(),
171
+ nn.Linear(hidden_dim, hidden_dim),
172
+ nn.ReLU(),
173
+ nn.Linear(hidden_dim, input_dim),
174
+ nn.Softplus(), # Ensure positive adjustment to range
175
+ )
176
+
177
+ # Learnable affine transformation parameters
178
+ self.weight = nn.Parameter(torch.ones(input_dim))
179
+ self.bias = nn.Parameter(torch.zeros(input_dim))
180
+
181
+ # Running statistics for inference
182
+ self.register_buffer("running_min", torch.zeros(input_dim))
183
+ self.register_buffer("running_range", torch.ones(input_dim))
184
+ self.register_buffer("num_batches_tracked", torch.tensor(0, dtype=torch.long))
185
+
186
+ self.momentum = 0.1
187
+
188
+ def forward(self, x: torch.Tensor, inverse: bool = False) -> Tuple[torch.Tensor, torch.Tensor]:
189
+ if inverse:
190
+ return self._inverse_transform(x)
191
+
192
+ original_shape = x.shape
193
+ if x.dim() == 3:
194
+ x = x.view(-1, x.size(-1))
195
+
196
+ eps = 1e-8
197
+ if self.training:
198
+ batch_min = x.min(dim=0, keepdim=True).values
199
+ batch_max = x.max(dim=0, keepdim=True).values
200
+ batch_range = (batch_max - batch_min).clamp_min(eps)
201
+
202
+ # Learn adjustments
203
+ learned_min_adj = self.min_estimator(batch_min)
204
+ learned_range_adj = self.range_estimator(batch_range)
205
+
206
+ effective_min = batch_min + learned_min_adj
207
+ effective_range = batch_range + learned_range_adj + eps
208
+
209
+ # Update running stats with raw batch min/range for stable inversion
210
+ with torch.no_grad():
211
+ self.num_batches_tracked += 1
212
+ if self.num_batches_tracked == 1:
213
+ self.running_min.copy_(batch_min.squeeze())
214
+ self.running_range.copy_(batch_range.squeeze())
215
+ else:
216
+ self.running_min.mul_(1 - self.momentum).add_(batch_min.squeeze(), alpha=self.momentum)
217
+ self.running_range.mul_(1 - self.momentum).add_(batch_range.squeeze(), alpha=self.momentum)
218
+ else:
219
+ effective_min = self.running_min.unsqueeze(0)
220
+ effective_range = self.running_range.unsqueeze(0)
221
+
222
+ # Scale to [0, 1]
223
+ scaled = (x - effective_min) / effective_range
224
+
225
+ # Learnable affine transform
226
+ transformed = scaled * self.weight + self.bias
227
+
228
+ if len(original_shape) == 3:
229
+ transformed = transformed.view(original_shape)
230
+
231
+ # Regularization: encourage non-degenerate range and modest affine params
232
+ reg_loss = 0.01 * (self.weight.var() + self.bias.var())
233
+ if self.training:
234
+ reg_loss = reg_loss + 0.001 * (1.0 / effective_range.clamp_min(1e-3)).mean()
235
+
236
+ return transformed, reg_loss
237
+
238
+ def _inverse_transform(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
239
+ if not self.config.learn_inverse_preprocessing:
240
+ return x, torch.tensor(0.0, device=x.device)
241
+
242
+ original_shape = x.shape
243
+ if x.dim() == 3:
244
+ x = x.view(-1, x.size(-1))
245
+
246
+ # Reverse affine
247
+ x = (x - self.bias) / (self.weight + 1e-8)
248
+ # Reverse MinMax using running stats
249
+ x = x * self.running_range.unsqueeze(0) + self.running_min.unsqueeze(0)
250
+
251
+ if len(original_shape) == 3:
252
+ x = x.view(original_shape)
253
+
254
+ return x, torch.tensor(0.0, device=x.device)
255
+
256
+
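A standalone sketch of the scaler's forward/inverse contract (illustrative use outside the full model; local import paths assumed):

```python
import torch
from configuration_autoencoder import AutoencoderConfig  # assumed local import paths
from modeling_autoencoder import LearnableMinMaxScaler

config = AutoencoderConfig(
    input_dim=8,
    latent_dim=4,
    use_learnable_preprocessing=True,
    preprocessing_type="minmax_scaler",
    learn_inverse_preprocessing=True,
)
scaler = LearnableMinMaxScaler(config)

x = torch.randn(32, 8) * 5.0 + 3.0          # arbitrary scale and offset
scaler.train()
scaled, reg_loss = scaler(x)                 # learned min/range plus affine; updates running stats
scaler.eval()
restored, _ = scaler(scaled, inverse=True)   # inverse uses running stats, so it is only approximate
print(scaled.shape, float(reg_loss), float((restored - x).abs().mean()))
```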
257
+ class LearnableRobustScaler(nn.Module):
258
+ """Learnable Robust scaler using median and IQR with learnable adjustments.
259
+
260
+ Normalizes as (x - median) / IQR with learnable adjustments and an affine head.
261
+ Supports 2D (B, F) and 3D (B, T, F) inputs.
262
+ """
263
+
264
+ def __init__(self, config: AutoencoderConfig):
265
+ super().__init__()
266
+ self.config = config
267
+ input_dim = config.input_dim
268
+ hidden_dim = config.preprocessing_hidden_dim
269
+
270
+ self.median_estimator = nn.Sequential(
271
+ nn.Linear(input_dim, hidden_dim),
272
+ nn.ReLU(),
273
+ nn.Linear(hidden_dim, hidden_dim),
274
+ nn.ReLU(),
275
+ nn.Linear(hidden_dim, input_dim),
276
+ )
277
+ self.iqr_estimator = nn.Sequential(
278
+ nn.Linear(input_dim, hidden_dim),
279
+ nn.ReLU(),
280
+ nn.Linear(hidden_dim, hidden_dim),
281
+ nn.ReLU(),
282
+ nn.Linear(hidden_dim, input_dim),
283
+ nn.Softplus(), # Ensure positive IQR adjustment
284
+ )
285
+
286
+ self.weight = nn.Parameter(torch.ones(input_dim))
287
+ self.bias = nn.Parameter(torch.zeros(input_dim))
288
+
289
+ self.register_buffer("running_median", torch.zeros(input_dim))
290
+ self.register_buffer("running_iqr", torch.ones(input_dim))
291
+ self.register_buffer("num_batches_tracked", torch.tensor(0, dtype=torch.long))
292
+
293
+ self.momentum = 0.1
294
+
295
+ def forward(self, x: torch.Tensor, inverse: bool = False) -> Tuple[torch.Tensor, torch.Tensor]:
296
+ if inverse:
297
+ return self._inverse_transform(x)
298
+
299
+ original_shape = x.shape
300
+ if x.dim() == 3:
301
+ x = x.view(-1, x.size(-1))
302
+
303
+ eps = 1e-8
304
+ if self.training:
305
+ qs = torch.quantile(x, torch.tensor([0.25, 0.5, 0.75], device=x.device), dim=0)
306
+ q25, med, q75 = qs[0:1, :], qs[1:2, :], qs[2:3, :]
307
+ iqr = (q75 - q25).clamp_min(eps)
308
+
309
+ learned_med_adj = self.median_estimator(med)
310
+ learned_iqr_adj = self.iqr_estimator(iqr)
311
+
312
+ effective_median = med + learned_med_adj
313
+ effective_iqr = iqr + learned_iqr_adj + eps
314
+
315
+ with torch.no_grad():
316
+ self.num_batches_tracked += 1
317
+ if self.num_batches_tracked == 1:
318
+ self.running_median.copy_(med.squeeze())
319
+ self.running_iqr.copy_(iqr.squeeze())
320
+ else:
321
+ self.running_median.mul_(1 - self.momentum).add_(med.squeeze(), alpha=self.momentum)
322
+ self.running_iqr.mul_(1 - self.momentum).add_(iqr.squeeze(), alpha=self.momentum)
323
+ else:
324
+ effective_median = self.running_median.unsqueeze(0)
325
+ effective_iqr = self.running_iqr.unsqueeze(0)
326
+
327
+ normalized = (x - effective_median) / effective_iqr
328
+ transformed = normalized * self.weight + self.bias
329
+
330
+ if len(original_shape) == 3:
331
+ transformed = transformed.view(original_shape)
332
+
333
+ reg_loss = 0.01 * (self.weight.var() + self.bias.var())
334
+ if self.training:
335
+ reg_loss = reg_loss + 0.001 * (1.0 / effective_iqr.clamp_min(1e-3)).mean()
336
+
337
+ return transformed, reg_loss
338
+
339
+ def _inverse_transform(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
340
+ if not self.config.learn_inverse_preprocessing:
341
+ return x, torch.tensor(0.0, device=x.device)
342
+
343
+ original_shape = x.shape
344
+ if x.dim() == 3:
345
+ x = x.view(-1, x.size(-1))
346
+
347
+ x = (x - self.bias) / (self.weight + 1e-8)
348
+ x = x * self.running_iqr.unsqueeze(0) + self.running_median.unsqueeze(0)
349
+
350
+ if len(original_shape) == 3:
351
+ x = x.view(original_shape)
352
+
353
+ return x, torch.tensor(0.0, device=x.device)
354
+
355
+
356
+ class LearnableYeoJohnsonPreprocessor(nn.Module):
357
+ """Learnable Yeo-Johnson power transform with per-feature λ and affine head.
358
+
359
+ Applies Yeo-Johnson transform elementwise with learnable lambda per feature,
360
+ followed by standardization and a learnable affine transform. Supports 2D and 3D inputs.
361
+ """
362
+
363
+ def __init__(self, config: AutoencoderConfig):
364
+ super().__init__()
365
+ self.config = config
366
+ input_dim = config.input_dim
367
+
368
+ # Learnable lambda per feature (unconstrained). Initialize around 1.0
369
+ self.lmbda = nn.Parameter(torch.ones(input_dim))
370
+
371
+ # Learnable affine parameters after standardization
372
+ self.weight = nn.Parameter(torch.ones(input_dim))
373
+ self.bias = nn.Parameter(torch.zeros(input_dim))
374
+
375
+ # Running stats for transformed data
376
+ self.register_buffer("running_mean", torch.zeros(input_dim))
377
+ self.register_buffer("running_std", torch.ones(input_dim))
378
+ self.register_buffer("num_batches_tracked", torch.tensor(0, dtype=torch.long))
379
+ self.momentum = 0.1
380
+
381
+ def _yeo_johnson(self, x: torch.Tensor, lmbda: torch.Tensor) -> torch.Tensor:
382
+ eps = 1e-6
383
+ lmbda = lmbda.unsqueeze(0) # broadcast over batch
384
+ pos = x >= 0
385
+ # For x >= 0
386
+ if_part = torch.where(
387
+ torch.abs(lmbda) > eps,
388
+ ((x + 1.0).clamp_min(eps) ** lmbda - 1.0) / lmbda,
389
+ torch.log((x + 1.0).clamp_min(eps)),
390
+ )
391
+ # For x < 0
392
+ two_minus_lambda = 2.0 - lmbda
393
+ else_part = torch.where(
394
+ torch.abs(two_minus_lambda) > eps,
395
+ -(((1.0 - x).clamp_min(eps)) ** two_minus_lambda - 1.0) / two_minus_lambda,
396
+ -torch.log((1.0 - x).clamp_min(eps)),
397
+ )
398
+ return torch.where(pos, if_part, else_part)
399
+
400
+ def _yeo_johnson_inverse(self, y: torch.Tensor, lmbda: torch.Tensor) -> torch.Tensor:
401
+ eps = 1e-6
402
+ lmbda = lmbda.unsqueeze(0)
403
+ pos = y >= 0
404
+ # Inverse for y >= 0
405
+ x_pos = torch.where(
406
+ torch.abs(lmbda) > eps,
407
+ (y * lmbda + 1.0).clamp_min(eps) ** (1.0 / lmbda) - 1.0,
408
+ torch.exp(y) - 1.0,
409
+ )
410
+ # Inverse for y < 0
411
+ two_minus_lambda = 2.0 - lmbda
412
+ x_neg = torch.where(
413
+ torch.abs(two_minus_lambda) > eps,
414
+ 1.0 - (1.0 - y * two_minus_lambda).clamp_min(eps) ** (1.0 / two_minus_lambda),
415
+ 1.0 - torch.exp(-y),
416
+ )
417
+ return torch.where(pos, x_pos, x_neg)
418
+
419
+ def forward(self, x: torch.Tensor, inverse: bool = False) -> Tuple[torch.Tensor, torch.Tensor]:
420
+ if inverse:
421
+ return self._inverse_transform(x)
422
+
423
+ orig_shape = x.shape
424
+ if x.dim() == 3:
425
+ x = x.view(-1, x.size(-1))
426
+
427
+ # Apply Yeo-Johnson
428
+ y = self._yeo_johnson(x, self.lmbda)
429
+
430
+ # Batch stats and running stats on transformed data
431
+ if self.training:
432
+ batch_mean = y.mean(dim=0, keepdim=True)
433
+ batch_std = y.std(dim=0, keepdim=True).clamp_min(1e-6)
434
+ with torch.no_grad():
435
+ self.num_batches_tracked += 1
436
+ if self.num_batches_tracked == 1:
437
+ self.running_mean.copy_(batch_mean.squeeze())
438
+ self.running_std.copy_(batch_std.squeeze())
439
+ else:
440
+ self.running_mean.mul_(1 - self.momentum).add_(batch_mean.squeeze(), alpha=self.momentum)
441
+ self.running_std.mul_(1 - self.momentum).add_(batch_std.squeeze(), alpha=self.momentum)
442
+ mean = batch_mean
443
+ std = batch_std
444
+ else:
445
+ mean = self.running_mean.unsqueeze(0)
446
+ std = self.running_std.unsqueeze(0)
447
+
448
+ y_norm = (y - mean) / std
449
+ out = y_norm * self.weight + self.bias
450
+
451
+ if len(orig_shape) == 3:
452
+ out = out.view(orig_shape)
453
+
454
+ # Regularize lambda to avoid extreme values; encourage identity around 1
455
+ reg = 0.001 * (self.lmbda - 1.0).pow(2).mean() + 0.01 * (self.weight.var() + self.bias.var())
456
+ return out, reg
457
+
458
+ def _inverse_transform(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
459
+ if not self.config.learn_inverse_preprocessing:
460
+ return x, torch.tensor(0.0, device=x.device)
461
+
462
+ orig_shape = x.shape
463
+ if x.dim() == 3:
464
+ x = x.view(-1, x.size(-1))
465
+
466
+ # Reverse affine and normalization with running stats
467
+ y = (x - self.bias) / (self.weight + 1e-8)
468
+ y = y * self.running_std.unsqueeze(0) + self.running_mean.unsqueeze(0)
469
+
470
+ # Inverse Yeo-Johnson
471
+ out = self._yeo_johnson_inverse(y, self.lmbda)
472
+
473
+ if len(orig_shape) == 3:
474
+ out = out.view(orig_shape)
475
+
476
+ return out, torch.tensor(0.0, device=x.device)
477
+
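The helpers follow the textbook Yeo-Johnson definition, so λ = 1 reduces to the identity and `_yeo_johnson_inverse` undoes `_yeo_johnson` for any λ. A quick numeric check of the two private methods (illustrative only; local import paths assumed):

```python
import torch
from configuration_autoencoder import AutoencoderConfig  # assumed local import paths
from modeling_autoencoder import LearnableYeoJohnsonPreprocessor

config = AutoencoderConfig(
    input_dim=5,
    latent_dim=3,
    use_learnable_preprocessing=True,
    preprocessing_type="yeo_johnson",
)
yj = LearnableYeoJohnsonPreprocessor(config)
x = torch.randn(16, 5)

# lambda is initialised to 1, where Yeo-Johnson is the identity:
# ((x + 1)^1 - 1) / 1 = x for x >= 0, and -((1 - x)^1 - 1) / 1 = x for x < 0
assert torch.allclose(yj._yeo_johnson(x, yj.lmbda), x, atol=1e-5)

# For other lambdas the inverse recovers the input
yj.lmbda.data.fill_(0.5)
y = yj._yeo_johnson(x, yj.lmbda)
assert torch.allclose(yj._yeo_johnson_inverse(y, yj.lmbda), x, atol=1e-4)
```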
478
  class CouplingLayer(nn.Module):
479
  """Coupling layer for normalizing flows."""
480
 
 
638
  self.preprocessor = NeuralScaler(config)
639
  elif config.is_normalizing_flow:
640
  self.preprocessor = NormalizingFlowPreprocessor(config)
641
+ elif getattr(config, "is_minmax_scaler", False):
642
+ self.preprocessor = LearnableMinMaxScaler(config)
643
+ elif getattr(config, "is_robust_scaler", False):
644
+ self.preprocessor = LearnableRobustScaler(config)
645
+ elif getattr(config, "is_yeo_johnson", False):
646
+ self.preprocessor = LearnableYeoJohnsonPreprocessor(config)
647
  else:
648
  raise ValueError(f"Unknown preprocessing type: {config.preprocessing_type}")
649
 
 
737
  else:
738
  # Standard encoder output
739
  self.fc_out = nn.Linear(input_dim, config.latent_dim)
740
+
741
  def _get_activation(self, activation: str) -> nn.Module:
742
  """Get activation function by name."""
743
  activations = {
 
761
  "threshold": nn.Threshold(threshold=0.1, value=0),
762
  }
763
  return activations[activation]
764
+
765
  def forward(self, x: torch.Tensor) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]:
766
  """Forward pass through encoder."""
767
  # Add noise for denoising autoencoders
 
799
 
800
  class AutoencoderDecoder(nn.Module):
801
  """Decoder part of the autoencoder."""
802
+
803
  def __init__(self, config: AutoencoderConfig):
804
  super().__init__()
805
  self.config = config
806
+
807
  # Build decoder layers (reverse of encoder)
808
  layers = []
809
  input_dim = config.latent_dim
810
  decoder_dims = config.decoder_dims + [config.input_dim]
811
+
812
  for i, hidden_dim in enumerate(decoder_dims):
813
  layers.append(nn.Linear(input_dim, hidden_dim))
814
+
815
  # Don't add batch norm, activation, or dropout to the final layer
816
  if i < len(decoder_dims) - 1:
817
  if config.use_batch_norm:
818
  layers.append(nn.BatchNorm1d(hidden_dim))
819
+
820
  layers.append(self._get_activation(config.activation))
821
+
822
  if config.dropout_rate > 0:
823
  layers.append(nn.Dropout(config.dropout_rate))
824
  else:
825
  # Final layer - add appropriate activation based on reconstruction loss
826
  if config.reconstruction_loss == "bce":
827
  layers.append(nn.Sigmoid())
828
+
829
  input_dim = hidden_dim
830
+
831
  self.decoder = nn.Sequential(*layers)
832
+
833
  def _get_activation(self, activation: str) -> nn.Module:
834
  """Get activation function by name."""
835
  activations = {
 
853
  "threshold": nn.Threshold(threshold=0.1, value=0),
854
  }
855
  return activations[activation]
856
+
857
  def forward(self, x: torch.Tensor) -> torch.Tensor:
858
  """Forward pass through decoder."""
859
  return self.decoder(x)
 
1091
  class AutoencoderModel(PreTrainedModel):
1092
  """
1093
  The bare Autoencoder Model transformer outputting raw hidden-states without any specific head on top.
1094
+
1095
  This model inherits from PreTrainedModel. Check the superclass documentation for the generic methods the
1096
  library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
1097
  etc.)
1098
+
1099
  This model is also a PyTorch torch.nn.Module subclass. Use it as a regular PyTorch Module and refer to the
1100
  PyTorch documentation for all matter related to general usage and behavior.
1101
  """
1102
+
1103
  config_class = AutoencoderConfig
1104
  base_model_prefix = "autoencoder"
1105
  supports_gradient_checkpointing = False
1106
+
1107
  def __init__(self, config: AutoencoderConfig):
1108
  super().__init__(config)
1109
  self.config = config
 
1125
  # Tie weights if specified
1126
  if config.tie_weights:
1127
  self._tie_weights()
1128
+
1129
  # Initialize weights
1130
  self.post_init()
1131
+
1132
  def _tie_weights(self):
1133
  """Tie encoder and decoder weights (transpose relationship)."""
1134
  # This is a simplified weight tying - in practice, you might want more sophisticated tying
1135
  pass
1136
+
1137
  def get_input_embeddings(self):
1138
  """Get input embeddings (not applicable for basic autoencoder)."""
1139
  return None
1140
+
1141
  def set_input_embeddings(self, value):
1142
  """Set input embeddings (not applicable for basic autoencoder)."""
1143
  pass
1144
+
1145
  def forward(
1146
  self,
1147
  input_values: torch.Tensor,