amaye15 committed on
Commit ced6e93 · 0 Parent(s)

Initial commit: AutoEncoder model
.gitignore ADDED
@@ -0,0 +1,10 @@
1
+ # Python-generated files
2
+ __pycache__/
3
+ *.py[oc]
4
+ build/
5
+ dist/
6
+ wheels/
7
+ *.egg-info
8
+
9
+ # Virtual environments
10
+ .venv
README.md ADDED
@@ -0,0 +1,445 @@
1
+ # Autoencoder Implementation for Hugging Face Transformers
2
+
3
+ A complete autoencoder implementation that integrates seamlessly with the Hugging Face Transformers ecosystem, providing all the standard functionality you expect from transformer models.
4
+
5
+ ## 🚀 Features
6
+
7
+ - **Full Hugging Face Integration**: Compatible with `AutoModel`, `AutoConfig`, and `AutoTokenizer` patterns
8
+ - **Standard Training Workflows**: Works with `Trainer`, `TrainingArguments`, and all HF training utilities
9
+ - **Model Hub Compatible**: Save and share models on Hugging Face Hub with `push_to_hub()`
10
+ - **Flexible Architecture**: Configurable encoder-decoder architecture with various activation functions
11
+ - **Multiple Loss Functions**: Support for MSE, BCE, L1, Huber, Smooth L1, KL Divergence, Cosine, Focal, Dice, Tversky, SSIM, and Perceptual loss
12
+ - **Seven Autoencoder Types**: Classic, Variational (VAE), Beta-VAE, Denoising, Sparse, Contractive, and Recurrent autoencoders
13
+ - **Extended Activation Functions**: 18 activation functions, including ReLU, GELU, Swish, Mish, and ELU
14
+ - **Learnable Preprocessing**: Neural Scaler and Normalizing Flow preprocessors (2D and 3D tensors)
15
+ - **Extensible Design**: Easy to extend for new autoencoder variants and custom loss functions
16
+ - **Production Ready**: Proper serialization, checkpointing, and inference support
17
+
18
+ ## 📦 Installation
19
+
20
+ ```bash
21
+ uv sync # or: pip install -e .
22
+ ```
23
+
24
+ Dependencies (see pyproject.toml):
25
+ - `torch>=2.8.0`
26
+ - `transformers>=4.55.2`
27
+ - `numpy>=2.3.2`
28
+ - `scikit-learn>=1.7.1`
29
+ - `datasets>=4.0.0`
30
+ - `accelerate>=1.10.0`
31
+
32
+ ## 🏗️ Architecture
33
+
34
+ Note: This repository has been trimmed to essentials for easy reuse and distribution. Example scripts and tests were removed by request.
35
+
36
+ The implementation consists of three main components:
37
+
38
+ ### 1. AutoencoderConfig
39
+ Configuration class that inherits from `PretrainedConfig`:
40
+ - Defines model architecture parameters
41
+ - Handles validation and serialization
42
+ - Enables `AutoConfig.from_pretrained()` functionality
43
+
44
+ ### 2. AutoencoderModel
45
+ Base model class that inherits from `PreTrainedModel`:
46
+ - Implements encoder-decoder architecture
47
+ - Provides latent space representation
48
+ - Returns structured outputs with `AutoencoderOutput`
49
+
50
+ ### 3. AutoencoderForReconstruction
51
+ Task-specific model for reconstruction:
52
+ - Adds reconstruction loss calculation
53
+ - Compatible with `Trainer` for easy training
54
+ - Returns `AutoencoderForReconstructionOutput` with loss
55
+
56
+ ## 🔧 Quick Start
57
+
58
+ ### Basic Usage
59
+
60
+ ```python
61
+ from configuration_autoencoder import AutoencoderConfig
62
+ from modeling_autoencoder import AutoencoderForReconstruction
63
+ import torch
64
+
65
+ # Create configuration
66
+ config = AutoencoderConfig(
67
+ input_dim=784, # Input dimensionality (e.g., 28x28 images flattened)
68
+ hidden_dims=[512, 256], # Encoder hidden layers
69
+ latent_dim=64, # Latent space dimension
70
+ activation="gelu", # Activation function (18+ options available)
71
+ reconstruction_loss="mse", # Loss function (12+ options available)
72
+ autoencoder_type="classic", # Autoencoder type (7 types available)
73
+ # Optional learnable preprocessing
74
+ use_learnable_preprocessing=True,
75
+ preprocessing_type="neural_scaler", # or "normalizing_flow"
76
+ )
77
+
78
+ # Create model
79
+ model = AutoencoderForReconstruction(config)
80
+
81
+ # Forward pass
82
+ input_data = torch.randn(32, 784) # Batch of 32 samples
83
+ outputs = model(input_values=input_data)
84
+
85
+ print(f"Reconstruction loss: {outputs.loss}")
86
+ print(f"Latent shape: {outputs.last_hidden_state.shape}")
87
+ print(f"Reconstructed shape: {outputs.reconstructed.shape}")
88
+ ```
89
+
90
+ ### Training with Hugging Face Trainer
91
+
92
+ ```python
93
+ from transformers import Trainer, TrainingArguments
94
+ import torch
+ from torch.utils.data import Dataset
95
+
96
+ class AutoencoderDataset(Dataset):
97
+ def __init__(self, data):
98
+ self.data = torch.FloatTensor(data)
99
+
100
+ def __len__(self):
101
+ return len(self.data)
102
+
103
+ def __getitem__(self, idx):
104
+ return {
105
+ "input_values": self.data[idx],
106
+ "labels": self.data[idx] # For autoencoder, input = target
107
+ }
108
+
109
+ # Prepare data
110
+ train_dataset = AutoencoderDataset(your_training_data)
111
+ val_dataset = AutoencoderDataset(your_validation_data)
112
+
113
+ # Training arguments
114
+ training_args = TrainingArguments(
115
+ output_dir="./autoencoder_output",
116
+ num_train_epochs=10,
117
+ per_device_train_batch_size=64,
118
+ per_device_eval_batch_size=64,
119
+ warmup_steps=500,
120
+ weight_decay=0.01,
121
+ logging_dir="./logs",
122
+ eval_strategy="steps",
123
+ eval_steps=500,
124
+ save_steps=1000,
125
+ load_best_model_at_end=True,
126
+ )
127
+
128
+ # Create trainer
129
+ trainer = Trainer(
130
+ model=model,
131
+ args=training_args,
132
+ train_dataset=train_dataset,
133
+ eval_dataset=val_dataset,
134
+ )
135
+
136
+ # Train
137
+ trainer.train()
138
+
139
+ # Save model
140
+ model.save_pretrained("./my_autoencoder")
141
+ config.save_pretrained("./my_autoencoder")
142
+ ```
143
+
144
+ ### Using AutoModel Framework
145
+
146
+ ```python
147
+ from register_autoencoder import register_autoencoder_models
148
+ from transformers import AutoConfig, AutoModel
149
+
150
+ # Register models with AutoModel framework
151
+ register_autoencoder_models()
152
+
153
+ # Now you can use standard HF patterns
154
+ config = AutoConfig.from_pretrained("./my_autoencoder")
155
+ model = AutoModel.from_pretrained("./my_autoencoder")
156
+
157
+ # Use the model
158
+ outputs = model(input_values=your_data)
159
+ ```
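+
+ ### Sharing on the Hugging Face Hub
+
+ Because the model and config inherit from `PreTrainedModel`/`PretrainedConfig`, sharing a trained checkpoint presumably follows the standard `push_to_hub()` pattern (a minimal sketch; the repository id is a placeholder and assumes you are already logged in):
+
+ ```python
+ # Save locally, then push to the Hub under your namespace
+ model.save_pretrained("./my_autoencoder")
+ config.save_pretrained("./my_autoencoder")
+ model.push_to_hub("your-username/my-autoencoder")
+ config.push_to_hub("your-username/my-autoencoder")
+ ```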
160
+
161
+ ## ⚙️ Configuration Options
162
+
163
+ The `AutoencoderConfig` class supports extensive customization:
164
+
165
+ ```python
166
+ config = AutoencoderConfig(
167
+ input_dim=784, # Input dimension
168
+ hidden_dims=[512, 256, 128], # Encoder hidden layers
169
+ latent_dim=64, # Latent space dimension
170
+ activation="gelu", # Activation function (see full list below)
171
+ dropout_rate=0.1, # Dropout rate (0.0 to 1.0)
172
+ use_batch_norm=True, # Use batch normalization
173
+ tie_weights=False, # Tie encoder/decoder weights
174
+ reconstruction_loss="mse", # Loss function (see full list below)
175
+ autoencoder_type="variational", # Autoencoder type (see types below)
176
+ beta=0.5, # Beta parameter for β-VAE
177
+ temperature=1.0, # Temperature for Gumbel softmax
178
+ noise_factor=0.1, # Noise factor for denoising AE
179
+ # Recurrent autoencoder parameters
180
+ rnn_type="lstm", # RNN type: "lstm", "gru", "rnn"
181
+ num_layers=2, # Number of RNN layers
182
+ bidirectional=True, # Bidirectional encoding
183
+ sequence_length=None, # Fixed sequence length (None for variable)
184
+ teacher_forcing_ratio=0.5, # Teacher forcing ratio during training
185
+ # Learnable preprocessing parameters
186
+ use_learnable_preprocessing=False, # Enable learnable preprocessing
187
+ preprocessing_type="none", # "none", "neural_scaler", "normalizing_flow"
188
+ preprocessing_hidden_dim=64, # Hidden dimension for preprocessing networks
189
+ preprocessing_num_layers=2, # Number of layers in preprocessing networks
190
+ learn_inverse_preprocessing=True, # Learn inverse transformation
191
+ flow_coupling_layers=4, # Number of coupling layers for flows
192
+ )
193
+ ```
194
+
195
+ ### 🎛️ Available Activation Functions
196
+
197
+ **Standard Activations:**
198
+ - `relu`, `leaky_relu`, `relu6`, `elu`, `prelu`
199
+ - `tanh`, `sigmoid`, `hardsigmoid`, `hardtanh`
200
+ - `gelu`, `swish`, `silu`, `hardswish`
201
+ - `mish`, `softplus`, `softsign`, `tanhshrink`, `threshold`
202
+
203
+ ### 📊 Available Loss Functions
204
+
205
+ **Regression Losses:**
206
+ - `mse` - Mean Squared Error
207
+ - `l1` - L1/MAE Loss
208
+ - `huber` - Huber Loss
209
+ - `smooth_l1` - Smooth L1 Loss
210
+
211
+ **Classification/Probability Losses:**
212
+ - `bce` - Binary Cross Entropy
213
+ - `kl_div` - KL Divergence
214
+ - `focal` - Focal Loss
215
+
216
+ **Similarity Losses:**
217
+ - `cosine` - Cosine Similarity Loss
218
+ - `ssim` - Structural Similarity Loss
219
+ - `perceptual` - Perceptual Loss
220
+
221
+ **Segmentation Losses:**
222
+ - `dice` - Dice Loss
223
+ - `tversky` - Tversky Loss
224
+
225
+ ### 🏗️ Available Autoencoder Types
226
+
227
+ **Classic Autoencoder (`classic`)**
228
+ - Standard encoder-decoder architecture
229
+ - Direct reconstruction loss minimization
230
+
231
+ **Variational Autoencoder (`variational`)**
232
+ - Probabilistic latent space with mean and variance
233
+ - KL divergence regularization
234
+ - Reparameterization trick for sampling
235
+
236
+ **Beta-VAE (`beta_vae`)**
237
+ - Variational autoencoder with adjustable β parameter
238
+ - Better disentanglement of latent factors
239
+
240
+ **Denoising Autoencoder (`denoising`)**
241
+ - Adds noise to input during training
242
+ - Learns robust representations
243
+ - Configurable noise factor
244
+
245
+ **Sparse Autoencoder (`sparse`)**
246
+ - Encourages sparse latent representations
247
+ - L1 regularization on latent activations
248
+ - Useful for feature selection
249
+
250
+ **Contractive Autoencoder (`contractive`)**
251
+ - Penalizes large gradients of latent w.r.t. input
252
+ - Learns smooth manifold representations
253
+ - Robust to small input perturbations
254
+
255
+ **Recurrent Autoencoder (`recurrent`)**
256
+ - LSTM/GRU/RNN encoder-decoder architecture
257
+ - Bidirectional encoding for better sequence representations
258
+ - Variable length sequence support with padding
259
+ - Teacher forcing during training for stable learning
260
+ - Sequence-to-sequence reconstruction
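+
+ Selecting a type only requires changing the configuration. A minimal sketch using the parameters documented above (values are illustrative):
+
+ ```python
+ # Denoising autoencoder: noise is injected into the inputs during training
+ denoising_config = AutoencoderConfig(
+     input_dim=784,
+     latent_dim=64,
+     autoencoder_type="denoising",
+     noise_factor=0.2,
+ )
+
+ # Sparse autoencoder with a BCE reconstruction loss for inputs in [0, 1]
+ sparse_config = AutoencoderConfig(
+     input_dim=784,
+     latent_dim=64,
+     autoencoder_type="sparse",
+     reconstruction_loss="bce",
+ )
+ ```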
262
+
263
+ ## 📊 Model Outputs
264
+
265
+ ### AutoencoderOutput
266
+ ```python
267
+ @dataclass
268
+ class AutoencoderOutput(ModelOutput):
269
+ last_hidden_state: torch.FloatTensor = None # Latent representation
270
+ reconstructed: torch.FloatTensor = None # Reconstructed input
271
+ hidden_states: Tuple[torch.FloatTensor] = None # Intermediate states
272
+ attentions: Tuple[torch.FloatTensor] = None # Not used
273
+ ```
274
+
275
+ ### AutoencoderForReconstructionOutput
276
+ ```python
277
+ @dataclass
278
+ class AutoencoderForReconstructionOutput(ModelOutput):
279
+ loss: torch.FloatTensor = None # Reconstruction loss
280
+ reconstructed: torch.FloatTensor = None # Reconstructed input
281
+ last_hidden_state: torch.FloatTensor = None # Latent representation
282
+ hidden_states: Tuple[torch.FloatTensor] = None # Intermediate states
283
+ ```
284
+
285
+ ## 🔬 Advanced Usage
286
+
287
+ ### Custom Loss Functions
288
+
289
+ You can easily extend the model with custom loss functions:
290
+
291
+ ```python
292
+ class CustomAutoencoder(AutoencoderForReconstruction):
293
+ def _compute_reconstruction_loss(self, reconstructed, target):
294
+ # Custom loss implementation
295
+ return your_custom_loss(reconstructed, target)
296
+ ```
297
+
298
+ ### Recurrent Autoencoder for Sequences
299
+
300
+ Perfect for time series, text, and sequential data:
301
+
302
+ ```python
303
+ config = AutoencoderConfig(
304
+ input_dim=50, # Feature dimension per timestep
305
+ latent_dim=32, # Compressed representation size
306
+ autoencoder_type="recurrent",
307
+ rnn_type="lstm", # or "gru", "rnn"
308
+ num_layers=2, # Number of RNN layers
309
+ bidirectional=True, # Bidirectional encoding
310
+ teacher_forcing_ratio=0.7, # Teacher forcing during training
311
+ sequence_length=None # Variable length sequences
312
+ )
313
+
314
+ # Usage with sequence data
315
+ model = AutoencoderForReconstruction(config)
316
+ sequence_data = torch.randn(batch_size, seq_len, input_dim)
317
+ outputs = model(input_values=sequence_data)
318
+ ```
319
+
320
+ ### Learnable Preprocessing
321
+
322
+ Deep learning-based data normalization that adapts to your data:
323
+
324
+ ```python
325
+ # Neural Scaler - Learnable alternative to StandardScaler
326
+ config = AutoencoderConfig(
327
+ input_dim=20,
328
+ latent_dim=10,
329
+ use_learnable_preprocessing=True,
330
+ preprocessing_type="neural_scaler",
331
+ preprocessing_hidden_dim=64
332
+ )
333
+
334
+ # Normalizing Flow - Invertible transformations
335
+ config = AutoencoderConfig(
336
+ input_dim=20,
337
+ latent_dim=10,
338
+ use_learnable_preprocessing=True,
339
+ preprocessing_type="normalizing_flow",
340
+ flow_coupling_layers=4
341
+ )
342
+
343
+ # Works with all autoencoder types and sequence data
344
+ model = AutoencoderForReconstruction(config)
345
+ outputs = model(input_values=data)
346
+ print(f"Preprocessing loss: {outputs.preprocessing_loss}")
347
+ ```
348
+
349
+ ### Variational Autoencoder Extension
350
+
351
+ The configuration supports variational autoencoders:
352
+
353
+ ```python
354
+ config = AutoencoderConfig(
355
+ autoencoder_type="variational",
356
+ beta=0.5, # β-VAE parameter
357
+ # ... other parameters
358
+ )
359
+ ```
360
+
361
+ ### Integration with Datasets Library
362
+
363
+ ```python
364
+ from datasets import Dataset
365
+
366
+ # Convert your data to HF Dataset
367
+ dataset = Dataset.from_dict({
368
+ "input_values": your_data_list
369
+ })
370
+
371
+ # Use with Trainer
372
+ trainer = Trainer(
373
+ model=model,
374
+ train_dataset=dataset,
375
+ # ... other arguments
376
+ )
377
+ ```
378
+
379
+ ## 🧪 Testing
380
+
381
+ This repository has been trimmed to essential files. Example scripts and test files were removed by request. You can create your own quick checks using the Quick Start snippet above.
382
+
383
+ ## 📁 Project Structure
384
+
385
+ ```
386
+ autoencoder/
387
+ ├── __init__.py # Package initialization
388
+ ├── configuration_autoencoder.py # Configuration class
389
+ ├── modeling_autoencoder.py # Model implementations
390
+ ├── register_autoencoder.py # AutoModel registration
391
+ ├── example_usage.py # Usage examples
392
+ ├── test_save_load.py # Test suite
393
+ ├── requirements.txt # Dependencies
394
+ └── README.md # This file
395
+ ```
396
+
397
+ ## 🤝 Contributing
398
+
399
+ This implementation follows Hugging Face conventions and can be easily extended:
400
+
401
+ 1. **Adding new architectures**: Extend `AutoencoderModel` or create new model classes
402
+ 2. **Custom configurations**: Add parameters to `AutoencoderConfig` (see the sketch after this list)
403
+ 3. **Task-specific heads**: Create new classes like `AutoencoderForReconstruction`
404
+ 4. **Integration**: Register new models with the AutoModel framework
405
+
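+ For example, item 2 can be done by subclassing the configuration (a sketch; `my_new_option` is a hypothetical parameter, not part of this repository):
+
+ ```python
+ class MyAutoencoderConfig(AutoencoderConfig):
+     model_type = "my_autoencoder"
+
+     def __init__(self, my_new_option: float = 0.5, **kwargs):
+         # Hypothetical extra hyperparameter; all standard parameters pass through **kwargs
+         self.my_new_option = my_new_option
+         super().__init__(**kwargs)
+ ```
+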
406
+ ## 📚 References
407
+
408
+ - [Hugging Face Transformers Documentation](https://huggingface.co/docs/transformers)
409
+ - [Custom Models Guide](https://huggingface.co/docs/transformers/custom_models)
410
+ - [AutoModel Documentation](https://huggingface.co/docs/transformers/model_doc/auto)
411
+
412
+ ## 🎯 Use Cases
413
+
414
+ This autoencoder implementation is well suited to:
415
+
416
+ - **Dimensionality Reduction**: Compress high-dimensional data to lower dimensions
417
+ - **Anomaly Detection**: Identify outliers based on reconstruction error (see the sketch after this list)
418
+ - **Data Denoising**: Remove noise from corrupted data
419
+ - **Feature Learning**: Learn meaningful representations for downstream tasks
420
+ - **Data Generation**: Generate new samples similar to training data
421
+ - **Pretraining**: Initialize encoders for other tasks
422
+
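+ The anomaly-detection and feature-learning cases boil down to reading the reconstruction and latent fields from the model outputs. A minimal sketch, assuming a trained `model` and a 2D `data` tensor as in the Quick Start (the threshold is illustrative):
+
+ ```python
+ import torch
+
+ model.eval()
+ with torch.no_grad():
+     outputs = model(input_values=data)
+
+ # Per-sample reconstruction error as an anomaly score
+ errors = ((outputs.reconstructed - data) ** 2).mean(dim=1)
+ threshold = errors.mean() + 3 * errors.std()  # illustrative cutoff
+ anomalies = errors > threshold
+
+ # Latent codes for downstream tasks (feature learning, dimensionality reduction)
+ latent_features = outputs.last_hidden_state
+ ```
+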
423
+ ## 🔍 Model Comparison
424
+
425
+ | Feature | Standard PyTorch | This Implementation |
426
+ |---------|------------------|-------------------|
427
+ | HF Integration | ❌ | ✅ |
428
+ | AutoModel Support | ❌ | ✅ |
429
+ | Trainer Compatible | ❌ | ✅ |
430
+ | Hub Integration | ❌ | ✅ |
431
+ | Config Management | Manual | ✅ Automatic |
432
+ | Serialization | Manual | ✅ Built-in |
433
+ | Checkpointing | Manual | ✅ Built-in |
434
+
435
+ ## 🚀 Performance Tips
436
+
437
+ 1. **Batch Size**: Use larger batch sizes for better GPU utilization
438
+ 2. **Learning Rate**: Start with 1e-3 and adjust based on convergence
439
+ 3. **Architecture**: Gradually decrease hidden dimensions for better compression
440
+ 4. **Regularization**: Use dropout and batch normalization for better generalization
441
+ 5. **Loss Function**: Choose appropriate loss based on your data type
442
+
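+ As a concrete starting point for tips 1 and 2, a training setup might look like this (a sketch; the values are illustrative, not tuned results):
+
+ ```python
+ training_args = TrainingArguments(
+     output_dir="./autoencoder_output",
+     per_device_train_batch_size=256,  # tip 1: larger batches for better GPU utilization
+     learning_rate=1e-3,               # tip 2: common starting point, adjust on convergence
+     num_train_epochs=20,
+ )
+ ```
+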
443
+ ## 📄 License
444
+
445
+ This implementation is provided as an example and follows the same license terms as Hugging Face Transformers.
__init__.py ADDED
@@ -0,0 +1,19 @@
1
+ """
2
+ Autoencoder models for Hugging Face Transformers.
3
+ """
4
+
5
+ from configuration_autoencoder import AutoencoderConfig
6
+ from modeling_autoencoder import (
7
+ AutoencoderModel,
8
+ AutoencoderForReconstruction,
9
+ AutoencoderOutput,
10
+ AutoencoderForReconstructionOutput,
11
+ )
12
+
13
+ __all__ = [
14
+ "AutoencoderConfig",
15
+ "AutoencoderModel",
16
+ "AutoencoderForReconstruction",
17
+ "AutoencoderOutput",
18
+ "AutoencoderForReconstructionOutput",
19
+ ]
configuration_autoencoder.py ADDED
@@ -0,0 +1,253 @@
1
+ """
2
+ Autoencoder configuration for Hugging Face Transformers.
3
+ """
4
+
5
+ from transformers import PretrainedConfig
6
+ from typing import List, Optional
7
+
8
+
9
+ class AutoencoderConfig(PretrainedConfig):
10
+ """
11
+ Configuration class for Autoencoder models.
12
+
13
+ This configuration class stores the configuration of an autoencoder model. It is used to instantiate
14
+ an autoencoder model according to the specified arguments, defining the model architecture.
15
+
16
+ Args:
17
+ input_dim (int, optional): Dimensionality of the input data. Defaults to 784.
18
+ hidden_dims (List[int], optional): List of hidden layer dimensions for the encoder.
19
+ The decoder will use the reverse of this list. Defaults to [512, 256, 128].
20
+ latent_dim (int, optional): Dimensionality of the latent space. Defaults to 64.
21
+ activation (str, optional): Activation function to use. Options: "relu", "tanh", "sigmoid",
22
+ "leaky_relu", "gelu", "swish", "silu", "elu", "prelu", "relu6", "hardtanh",
23
+ "hardsigmoid", "hardswish", "mish", "softplus", "softsign", "tanhshrink", "threshold".
24
+ Defaults to "relu".
25
+ dropout_rate (float, optional): Dropout rate for regularization. Defaults to 0.1.
26
+ use_batch_norm (bool, optional): Whether to use batch normalization. Defaults to True.
27
+ tie_weights (bool, optional): Whether to tie encoder and decoder weights. Defaults to False.
28
+ reconstruction_loss (str, optional): Type of reconstruction loss. Options: "mse", "bce", "l1",
29
+ "huber", "smooth_l1", "kl_div", "cosine", "focal", "dice", "tversky", "ssim", "perceptual".
30
+ Defaults to "mse".
31
+ autoencoder_type (str, optional): Type of autoencoder architecture. Options: "classic",
32
+ "variational", "beta_vae", "denoising", "sparse", "contractive", "recurrent". Defaults to "classic".
33
+ beta (float, optional): Beta parameter for beta-VAE. Defaults to 1.0.
34
+ temperature (float, optional): Temperature parameter for Gumbel softmax or other operations. Defaults to 1.0.
35
+ noise_factor (float, optional): Noise factor for denoising autoencoders. Defaults to 0.1.
36
+ rnn_type (str, optional): Type of RNN cell for recurrent autoencoders. Options: "lstm", "gru", "rnn".
37
+ Defaults to "lstm".
38
+ num_layers (int, optional): Number of RNN layers for recurrent autoencoders. Defaults to 2.
39
+ bidirectional (bool, optional): Whether to use bidirectional RNN for encoding. Defaults to True.
40
+ sequence_length (int, optional): Fixed sequence length. If None, supports variable length sequences.
41
+ Defaults to None.
42
+ teacher_forcing_ratio (float, optional): Ratio of teacher forcing during training for recurrent decoders.
43
+ Defaults to 0.5.
44
+ use_learnable_preprocessing (bool, optional): Whether to use learnable preprocessing. Defaults to False.
45
+ preprocessing_type (str, optional): Type of learnable preprocessing. Options: "none", "neural_scaler",
46
+ "normalizing_flow". Defaults to "none".
47
+ preprocessing_hidden_dim (int, optional): Hidden dimension for preprocessing networks. Defaults to 64.
48
+ preprocessing_num_layers (int, optional): Number of layers in preprocessing networks. Defaults to 2.
49
+ learn_inverse_preprocessing (bool, optional): Whether to learn inverse preprocessing for reconstruction.
50
+ Defaults to True.
51
+ flow_coupling_layers (int, optional): Number of coupling layers for normalizing flows. Defaults to 4.
52
+ **kwargs: Additional keyword arguments passed to the parent class.
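+
+ Example (a minimal usage sketch; the flat-module imports mirror the Quick Start in the README):
+
+ ```python
+ >>> from configuration_autoencoder import AutoencoderConfig
+ >>> from modeling_autoencoder import AutoencoderModel
+
+ >>> # A small classic autoencoder configuration
+ >>> config = AutoencoderConfig(input_dim=784, hidden_dims=[256, 128], latent_dim=32)
+
+ >>> # Instantiate a model from the configuration
+ >>> model = AutoencoderModel(config)
+ ```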
53
+ """
54
+
55
+ model_type = "autoencoder"
56
+
57
+ def __init__(
58
+ self,
59
+ input_dim: int = 784,
60
+ hidden_dims: List[int] = None,
61
+ latent_dim: int = 64,
62
+ activation: str = "relu",
63
+ dropout_rate: float = 0.1,
64
+ use_batch_norm: bool = True,
65
+ tie_weights: bool = False,
66
+ reconstruction_loss: str = "mse",
67
+ autoencoder_type: str = "classic",
68
+ beta: float = 1.0,
69
+ temperature: float = 1.0,
70
+ noise_factor: float = 0.1,
71
+ # Recurrent autoencoder parameters
72
+ rnn_type: str = "lstm",
73
+ num_layers: int = 2,
74
+ bidirectional: bool = True,
75
+ sequence_length: Optional[int] = None,
76
+ teacher_forcing_ratio: float = 0.5,
77
+ # Deep learning preprocessing parameters
78
+ use_learnable_preprocessing: bool = False,
79
+ preprocessing_type: str = "none",
80
+ preprocessing_hidden_dim: int = 64,
81
+ preprocessing_num_layers: int = 2,
82
+ learn_inverse_preprocessing: bool = True,
83
+ flow_coupling_layers: int = 4,
84
+ **kwargs,
85
+ ):
86
+ # Validate parameters
87
+ if hidden_dims is None:
88
+ hidden_dims = [512, 256, 128]
89
+
90
+ # Extended activation functions
91
+ valid_activations = [
92
+ "relu", "tanh", "sigmoid", "leaky_relu", "gelu", "swish", "silu",
93
+ "elu", "prelu", "relu6", "hardtanh", "hardsigmoid", "hardswish",
94
+ "mish", "softplus", "softsign", "tanhshrink", "threshold"
95
+ ]
96
+ if activation not in valid_activations:
97
+ raise ValueError(
98
+ f"`activation` must be one of {valid_activations}, got {activation}."
99
+ )
100
+
101
+ # Extended loss functions
102
+ valid_losses = [
103
+ "mse", "bce", "l1", "huber", "smooth_l1", "kl_div", "cosine",
104
+ "focal", "dice", "tversky", "ssim", "perceptual"
105
+ ]
106
+ if reconstruction_loss not in valid_losses:
107
+ raise ValueError(
108
+ f"`reconstruction_loss` must be one of {valid_losses}, got {reconstruction_loss}."
109
+ )
110
+
111
+ # Autoencoder types
112
+ valid_types = ["classic", "variational", "beta_vae", "denoising", "sparse", "contractive", "recurrent"]
113
+ if autoencoder_type not in valid_types:
114
+ raise ValueError(
115
+ f"`autoencoder_type` must be one of {valid_types}, got {autoencoder_type}."
116
+ )
117
+
118
+ # RNN types for recurrent autoencoders
119
+ valid_rnn_types = ["lstm", "gru", "rnn"]
120
+ if rnn_type not in valid_rnn_types:
121
+ raise ValueError(
122
+ f"`rnn_type` must be one of {valid_rnn_types}, got {rnn_type}."
123
+ )
124
+
125
+ if not (0.0 <= dropout_rate <= 1.0):
126
+ raise ValueError(f"`dropout_rate` must be between 0.0 and 1.0, got {dropout_rate}.")
127
+
128
+ if input_dim <= 0:
129
+ raise ValueError(f"`input_dim` must be positive, got {input_dim}.")
130
+
131
+ if latent_dim <= 0:
132
+ raise ValueError(f"`latent_dim` must be positive, got {latent_dim}.")
133
+
134
+ if not all(dim > 0 for dim in hidden_dims):
135
+ raise ValueError("All dimensions in `hidden_dims` must be positive.")
136
+
137
+ if beta <= 0:
138
+ raise ValueError(f"`beta` must be positive, got {beta}.")
139
+
140
+ if num_layers <= 0:
141
+ raise ValueError(f"`num_layers` must be positive, got {num_layers}.")
142
+
143
+ if not (0.0 <= teacher_forcing_ratio <= 1.0):
144
+ raise ValueError(f"`teacher_forcing_ratio` must be between 0.0 and 1.0, got {teacher_forcing_ratio}.")
145
+
146
+ if sequence_length is not None and sequence_length <= 0:
147
+ raise ValueError(f"`sequence_length` must be positive when specified, got {sequence_length}.")
148
+
149
+ # Preprocessing validation
150
+ valid_preprocessing = ["none", "neural_scaler", "normalizing_flow"]
151
+ if preprocessing_type not in valid_preprocessing:
152
+ raise ValueError(
153
+ f"`preprocessing_type` must be one of {valid_preprocessing}, got {preprocessing_type}."
154
+ )
155
+
156
+ if preprocessing_hidden_dim <= 0:
157
+ raise ValueError(f"`preprocessing_hidden_dim` must be positive, got {preprocessing_hidden_dim}.")
158
+
159
+ if preprocessing_num_layers <= 0:
160
+ raise ValueError(f"`preprocessing_num_layers` must be positive, got {preprocessing_num_layers}.")
161
+
162
+ if flow_coupling_layers <= 0:
163
+ raise ValueError(f"`flow_coupling_layers` must be positive, got {flow_coupling_layers}.")
164
+
165
+ # Set configuration attributes
166
+ self.input_dim = input_dim
167
+ self.hidden_dims = hidden_dims
168
+ self.latent_dim = latent_dim
169
+ self.activation = activation
170
+ self.dropout_rate = dropout_rate
171
+ self.use_batch_norm = use_batch_norm
172
+ self.tie_weights = tie_weights
173
+ self.reconstruction_loss = reconstruction_loss
174
+ self.autoencoder_type = autoencoder_type
175
+ self.beta = beta
176
+ self.temperature = temperature
177
+ self.noise_factor = noise_factor
178
+ self.rnn_type = rnn_type
179
+ self.num_layers = num_layers
180
+ self.bidirectional = bidirectional
181
+ self.sequence_length = sequence_length
182
+ self.teacher_forcing_ratio = teacher_forcing_ratio
183
+ self.use_learnable_preprocessing = use_learnable_preprocessing
184
+ self.preprocessing_type = preprocessing_type
185
+ self.preprocessing_hidden_dim = preprocessing_hidden_dim
186
+ self.preprocessing_num_layers = preprocessing_num_layers
187
+ self.learn_inverse_preprocessing = learn_inverse_preprocessing
188
+ self.flow_coupling_layers = flow_coupling_layers
189
+
190
+ # Call parent constructor
191
+ super().__init__(**kwargs)
192
+
193
+ @property
194
+ def decoder_dims(self) -> List[int]:
195
+ """Get decoder dimensions (reverse of encoder hidden dims)."""
196
+ return list(reversed(self.hidden_dims))
197
+
198
+ @property
199
+ def is_variational(self) -> bool:
200
+ """Check if this is a variational autoencoder."""
201
+ return self.autoencoder_type in ["variational", "beta_vae"]
202
+
203
+ @property
204
+ def is_denoising(self) -> bool:
205
+ """Check if this is a denoising autoencoder."""
206
+ return self.autoencoder_type == "denoising"
207
+
208
+ @property
209
+ def is_sparse(self) -> bool:
210
+ """Check if this is a sparse autoencoder."""
211
+ return self.autoencoder_type == "sparse"
212
+
213
+ @property
214
+ def is_contractive(self) -> bool:
215
+ """Check if this is a contractive autoencoder."""
216
+ return self.autoencoder_type == "contractive"
217
+
218
+ @property
219
+ def is_recurrent(self) -> bool:
220
+ """Check if this is a recurrent autoencoder."""
221
+ return self.autoencoder_type == "recurrent"
222
+
223
+ @property
224
+ def rnn_hidden_size(self) -> int:
225
+ """Get the RNN hidden size (same as latent_dim for recurrent AE)."""
226
+ return self.latent_dim
227
+
228
+ @property
229
+ def rnn_output_size(self) -> int:
230
+ """Get the RNN output size considering bidirectionality."""
231
+ return self.latent_dim * (2 if self.bidirectional else 1)
232
+
233
+ @property
234
+ def has_preprocessing(self) -> bool:
235
+ """Check if learnable preprocessing is enabled."""
236
+ return self.use_learnable_preprocessing and self.preprocessing_type != "none"
237
+
238
+ @property
239
+ def is_neural_scaler(self) -> bool:
240
+ """Check if using neural scaler preprocessing."""
241
+ return self.preprocessing_type == "neural_scaler"
242
+
243
+ @property
244
+ def is_normalizing_flow(self) -> bool:
245
+ """Check if using normalizing flow preprocessing."""
246
+ return self.preprocessing_type == "normalizing_flow"
247
+
248
+ def to_dict(self):
249
+ """
250
+ Serializes this instance to a Python dictionary.
251
+ """
252
+ output = super().to_dict()
253
+ return output
modeling_autoencoder.py ADDED
@@ -0,0 +1,1099 @@
1
+ """
2
+ PyTorch Autoencoder model for Hugging Face Transformers.
3
+ """
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+ import torch.nn.functional as F
8
+ from typing import Optional, Tuple, Union, Dict, Any, List
9
+ from dataclasses import dataclass
10
+ import random
11
+
12
+ from transformers import PreTrainedModel
13
+ from transformers.modeling_outputs import BaseModelOutput
14
+ from transformers.utils import ModelOutput
15
+
16
+ from configuration_autoencoder import AutoencoderConfig
17
+
18
+
19
+ class NeuralScaler(nn.Module):
20
+ """Learnable alternative to StandardScaler using neural networks."""
21
+
22
+ def __init__(self, config: AutoencoderConfig):
23
+ super().__init__()
24
+ self.config = config
25
+ input_dim = config.input_dim
26
+ hidden_dim = config.preprocessing_hidden_dim
27
+
28
+ # Networks to learn data-dependent statistics
29
+ self.mean_estimator = nn.Sequential(
30
+ nn.Linear(input_dim, hidden_dim),
31
+ nn.ReLU(),
32
+ nn.Linear(hidden_dim, hidden_dim),
33
+ nn.ReLU(),
34
+ nn.Linear(hidden_dim, input_dim)
35
+ )
36
+
37
+ self.std_estimator = nn.Sequential(
38
+ nn.Linear(input_dim, hidden_dim),
39
+ nn.ReLU(),
40
+ nn.Linear(hidden_dim, hidden_dim),
41
+ nn.ReLU(),
42
+ nn.Linear(hidden_dim, input_dim),
43
+ nn.Softplus() # Ensure positive standard deviation
44
+ )
45
+
46
+ # Learnable affine transformation parameters
47
+ self.weight = nn.Parameter(torch.ones(input_dim))
48
+ self.bias = nn.Parameter(torch.zeros(input_dim))
49
+
50
+ # Running statistics for inference (like BatchNorm)
51
+ self.register_buffer('running_mean', torch.zeros(input_dim))
52
+ self.register_buffer('running_std', torch.ones(input_dim))
53
+ self.register_buffer('num_batches_tracked', torch.tensor(0, dtype=torch.long))
54
+
55
+ # Momentum for running statistics
56
+ self.momentum = 0.1
57
+
58
+ def forward(self, x: torch.Tensor, inverse: bool = False) -> Tuple[torch.Tensor, torch.Tensor]:
59
+ """
60
+ Forward pass through neural scaler.
61
+
62
+ Args:
63
+ x: Input tensor (2D or 3D)
64
+ inverse: Whether to apply inverse transformation
65
+
66
+ Returns:
67
+ Tuple of (transformed_tensor, regularization_loss)
68
+ """
69
+ if inverse:
70
+ return self._inverse_transform(x)
71
+
72
+ # Handle both 2D and 3D tensors
73
+ original_shape = x.shape
74
+ if x.dim() == 3:
75
+ # Reshape (batch, seq, features) -> (batch*seq, features)
76
+ x = x.view(-1, x.size(-1))
77
+
78
+ if self.training:
79
+ # Training mode: learn statistics from current batch
80
+ batch_mean = x.mean(dim=0, keepdim=True)
81
+ batch_std = x.std(dim=0, keepdim=True)
82
+
83
+ # Learn data-dependent adjustments
84
+ learned_mean_adj = self.mean_estimator(batch_mean)
85
+ learned_std_adj = self.std_estimator(batch_std)
86
+
87
+ # Combine batch statistics with learned adjustments
88
+ effective_mean = batch_mean + learned_mean_adj
89
+ effective_std = batch_std + learned_std_adj + 1e-8
90
+
91
+ # Update running statistics
92
+ with torch.no_grad():
93
+ self.num_batches_tracked += 1
94
+ if self.num_batches_tracked == 1:
95
+ self.running_mean.copy_(batch_mean.squeeze())
96
+ self.running_std.copy_(batch_std.squeeze())
97
+ else:
98
+ self.running_mean.mul_(1 - self.momentum).add_(batch_mean.squeeze(), alpha=self.momentum)
99
+ self.running_std.mul_(1 - self.momentum).add_(batch_std.squeeze(), alpha=self.momentum)
100
+ else:
101
+ # Inference mode: use running statistics
102
+ effective_mean = self.running_mean.unsqueeze(0)
103
+ effective_std = self.running_std.unsqueeze(0) + 1e-8
104
+
105
+ # Normalize
106
+ normalized = (x - effective_mean) / effective_std
107
+
108
+ # Apply learnable affine transformation
109
+ transformed = normalized * self.weight + self.bias
110
+
111
+ # Reshape back to original shape if needed
112
+ if len(original_shape) == 3:
113
+ transformed = transformed.view(original_shape)
114
+
115
+ # Regularization loss to encourage meaningful learning
116
+ reg_loss = 0.01 * (self.weight.var() + self.bias.var())
117
+
118
+ return transformed, reg_loss
119
+
120
+ def _inverse_transform(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
121
+ """Apply inverse transformation to get back original scale."""
122
+ if not self.config.learn_inverse_preprocessing:
123
+ return x, torch.tensor(0.0, device=x.device)
124
+
125
+ # Handle both 2D and 3D tensors
126
+ original_shape = x.shape
127
+ if x.dim() == 3:
128
+ # Reshape (batch, seq, features) -> (batch*seq, features)
129
+ x = x.view(-1, x.size(-1))
130
+
131
+ # Reverse affine transformation
132
+ x = (x - self.bias) / (self.weight + 1e-8)
133
+
134
+ # Reverse normalization using running statistics
135
+ effective_mean = self.running_mean.unsqueeze(0)
136
+ effective_std = self.running_std.unsqueeze(0) + 1e-8
137
+ x = x * effective_std + effective_mean
138
+
139
+ # Reshape back to original shape if needed
140
+ if len(original_shape) == 3:
141
+ x = x.view(original_shape)
142
+
143
+ return x, torch.tensor(0.0, device=x.device)
144
+
145
+
146
+ class CouplingLayer(nn.Module):
147
+ """Coupling layer for normalizing flows."""
148
+
149
+ def __init__(self, input_dim: int, hidden_dim: int = 64, mask_type: str = "alternating"):
150
+ super().__init__()
151
+ self.input_dim = input_dim
152
+ self.hidden_dim = hidden_dim
153
+
154
+ # Create mask for coupling
155
+ if mask_type == "alternating":
156
+ self.register_buffer('mask', torch.arange(input_dim) % 2)
157
+ elif mask_type == "half":
158
+ mask = torch.zeros(input_dim)
159
+ mask[:input_dim // 2] = 1
160
+ self.register_buffer('mask', mask)
161
+ else:
162
+ raise ValueError(f"Unknown mask type: {mask_type}")
163
+
164
+ # Scale and translation networks
165
+ masked_dim = int(self.mask.sum().item())
166
+ unmasked_dim = input_dim - masked_dim
167
+
168
+ self.scale_net = nn.Sequential(
169
+ nn.Linear(masked_dim, hidden_dim),
170
+ nn.ReLU(),
171
+ nn.Linear(hidden_dim, hidden_dim),
172
+ nn.ReLU(),
173
+ nn.Linear(hidden_dim, unmasked_dim),
174
+ nn.Tanh() # Bounded output for stability
175
+ )
176
+
177
+ self.translate_net = nn.Sequential(
178
+ nn.Linear(masked_dim, hidden_dim),
179
+ nn.ReLU(),
180
+ nn.Linear(hidden_dim, hidden_dim),
181
+ nn.ReLU(),
182
+ nn.Linear(hidden_dim, unmasked_dim)
183
+ )
184
+
185
+ def forward(self, x: torch.Tensor, inverse: bool = False) -> Tuple[torch.Tensor, torch.Tensor]:
186
+ """
187
+ Forward pass through coupling layer.
188
+
189
+ Args:
190
+ x: Input tensor
191
+ inverse: Whether to apply inverse transformation
192
+
193
+ Returns:
194
+ Tuple of (transformed_tensor, log_determinant)
195
+ """
196
+ mask = self.mask.bool()
197
+ x_masked = x[:, mask]
198
+ x_unmasked = x[:, ~mask]
199
+
200
+ # Compute scale and translation
201
+ s = self.scale_net(x_masked)
202
+ t = self.translate_net(x_masked)
203
+
204
+ if not inverse:
205
+ # Forward transformation
206
+ y_unmasked = x_unmasked * torch.exp(s) + t
207
+ log_det = s.sum(dim=1)
208
+ else:
209
+ # Inverse transformation
210
+ y_unmasked = (x_unmasked - t) * torch.exp(-s)
211
+ log_det = -s.sum(dim=1)
212
+
213
+ # Reconstruct output
214
+ y = torch.zeros_like(x)
215
+ y[:, mask] = x_masked
216
+ y[:, ~mask] = y_unmasked
217
+
218
+ return y, log_det
219
+
220
+
221
+ class NormalizingFlowPreprocessor(nn.Module):
222
+ """Normalizing flow for learnable data preprocessing."""
223
+
224
+ def __init__(self, config: AutoencoderConfig):
225
+ super().__init__()
226
+ self.config = config
227
+ input_dim = config.input_dim
228
+ hidden_dim = config.preprocessing_hidden_dim
229
+ num_layers = config.flow_coupling_layers
230
+
231
+ # Create coupling layers with alternating masks
232
+ self.layers = nn.ModuleList()
233
+ for i in range(num_layers):
234
+ mask_type = "alternating" if i % 2 == 0 else "half"
235
+ self.layers.append(CouplingLayer(input_dim, hidden_dim, mask_type))
236
+
237
+ # Optional: Add batch normalization between layers
238
+ if config.use_batch_norm:
239
+ self.batch_norms = nn.ModuleList([
240
+ nn.BatchNorm1d(input_dim) for _ in range(num_layers - 1)
241
+ ])
242
+ else:
243
+ self.batch_norms = None
244
+
245
+ def forward(self, x: torch.Tensor, inverse: bool = False) -> Tuple[torch.Tensor, torch.Tensor]:
246
+ """
247
+ Forward pass through normalizing flow.
248
+
249
+ Args:
250
+ x: Input tensor (2D or 3D)
251
+ inverse: Whether to apply inverse transformation
252
+
253
+ Returns:
254
+ Tuple of (transformed_tensor, total_log_determinant)
255
+ """
256
+ # Handle both 2D and 3D tensors
257
+ original_shape = x.shape
258
+ if x.dim() == 3:
259
+ # Reshape (batch, seq, features) -> (batch*seq, features)
260
+ x = x.view(-1, x.size(-1))
261
+
262
+ log_det_total = torch.zeros(x.size(0), device=x.device)
263
+
264
+ if not inverse:
265
+ # Forward pass
266
+ for i, layer in enumerate(self.layers):
267
+ x, log_det = layer(x, inverse=False)
268
+ log_det_total += log_det
269
+
270
+ # Apply batch normalization (except for last layer)
271
+ if self.batch_norms and i < len(self.layers) - 1:
272
+ x = self.batch_norms[i](x)
273
+ else:
274
+ # Inverse pass
275
+ for i, layer in enumerate(reversed(self.layers)):
276
+ # Reverse batch normalization (except for first layer in reverse)
277
+ if self.batch_norms and i > 0:
278
+ # Note: This is approximate inverse of batch norm
279
+ bn_idx = len(self.layers) - 1 - i
280
+ x = self.batch_norms[bn_idx](x)
281
+
282
+ x, log_det = layer(x, inverse=True)
283
+ log_det_total += log_det
284
+
285
+ # Reshape back to original shape if needed
286
+ if len(original_shape) == 3:
287
+ x = x.view(original_shape)
288
+
289
+ # Convert log determinant to regularization loss
290
+ # Encourage the flow to preserve information (log_det close to 0)
291
+ reg_loss = 0.01 * log_det_total.abs().mean()
292
+
293
+ return x, reg_loss
294
+
295
+
296
+ class LearnablePreprocessor(nn.Module):
297
+ """Unified interface for learnable preprocessing methods."""
298
+
299
+ def __init__(self, config: AutoencoderConfig):
300
+ super().__init__()
301
+ self.config = config
302
+
303
+ if not config.has_preprocessing:
304
+ self.preprocessor = nn.Identity()
305
+ elif config.is_neural_scaler:
306
+ self.preprocessor = NeuralScaler(config)
307
+ elif config.is_normalizing_flow:
308
+ self.preprocessor = NormalizingFlowPreprocessor(config)
309
+ else:
310
+ raise ValueError(f"Unknown preprocessing type: {config.preprocessing_type}")
311
+
312
+ def forward(self, x: torch.Tensor, inverse: bool = False) -> Tuple[torch.Tensor, torch.Tensor]:
313
+ """
314
+ Apply preprocessing transformation.
315
+
316
+ Args:
317
+ x: Input tensor
318
+ inverse: Whether to apply inverse transformation
319
+
320
+ Returns:
321
+ Tuple of (transformed_tensor, regularization_loss)
322
+ """
323
+ if isinstance(self.preprocessor, nn.Identity):
324
+ return x, torch.tensor(0.0, device=x.device)
325
+
326
+ return self.preprocessor(x, inverse=inverse)
327
+
328
+
329
+ @dataclass
330
+ class AutoencoderOutput(ModelOutput):
331
+ """
332
+ Output type of AutoencoderModel.
333
+
334
+ Args:
335
+ last_hidden_state (torch.FloatTensor): The latent representation of the input.
336
+ reconstructed (torch.FloatTensor, optional): The reconstructed input.
337
+ hidden_states (tuple(torch.FloatTensor), optional): Hidden states of the encoder layers.
338
+ attentions (tuple(torch.FloatTensor), optional): Not used in basic autoencoder.
339
+ preprocessing_loss (torch.FloatTensor, optional): Loss from learnable preprocessing.
340
+ """
341
+
342
+ last_hidden_state: torch.FloatTensor = None
343
+ reconstructed: Optional[torch.FloatTensor] = None
344
+ hidden_states: Optional[Tuple[torch.FloatTensor]] = None
345
+ attentions: Optional[Tuple[torch.FloatTensor]] = None
346
+ preprocessing_loss: Optional[torch.FloatTensor] = None
347
+
348
+
349
+ @dataclass
350
+ class AutoencoderForReconstructionOutput(ModelOutput):
351
+ """
352
+ Output type of AutoencoderForReconstruction.
353
+
354
+ Args:
355
+ loss (torch.FloatTensor, optional): The reconstruction loss.
356
+ reconstructed (torch.FloatTensor): The reconstructed input.
357
+ last_hidden_state (torch.FloatTensor): The latent representation.
358
+ hidden_states (tuple(torch.FloatTensor), optional): Hidden states of the encoder layers.
359
+ preprocessing_loss (torch.FloatTensor, optional): Loss from learnable preprocessing.
360
+ """
361
+
362
+ loss: Optional[torch.FloatTensor] = None
363
+ reconstructed: torch.FloatTensor = None
364
+ last_hidden_state: torch.FloatTensor = None
365
+ hidden_states: Optional[Tuple[torch.FloatTensor]] = None
366
+ preprocessing_loss: Optional[torch.FloatTensor] = None
367
+
368
+
369
+ class AutoencoderEncoder(nn.Module):
370
+ """Encoder part of the autoencoder."""
371
+
372
+ def __init__(self, config: AutoencoderConfig):
373
+ super().__init__()
374
+ self.config = config
375
+
376
+ # Build encoder layers
377
+ layers = []
378
+ input_dim = config.input_dim
379
+
380
+ for hidden_dim in config.hidden_dims:
381
+ layers.append(nn.Linear(input_dim, hidden_dim))
382
+
383
+ if config.use_batch_norm:
384
+ layers.append(nn.BatchNorm1d(hidden_dim))
385
+
386
+ layers.append(self._get_activation(config.activation))
387
+
388
+ if config.dropout_rate > 0:
389
+ layers.append(nn.Dropout(config.dropout_rate))
390
+
391
+ input_dim = hidden_dim
392
+
393
+ self.encoder = nn.Sequential(*layers)
394
+
395
+ # For variational autoencoders, we need separate layers for mean and log variance
396
+ if config.is_variational:
397
+ self.fc_mu = nn.Linear(input_dim, config.latent_dim)
398
+ self.fc_logvar = nn.Linear(input_dim, config.latent_dim)
399
+ else:
400
+ # Standard encoder output
401
+ self.fc_out = nn.Linear(input_dim, config.latent_dim)
402
+
403
+ def _get_activation(self, activation: str) -> nn.Module:
404
+ """Get activation function by name."""
405
+ activations = {
406
+ "relu": nn.ReLU(),
407
+ "tanh": nn.Tanh(),
408
+ "sigmoid": nn.Sigmoid(),
409
+ "leaky_relu": nn.LeakyReLU(),
410
+ "gelu": nn.GELU(),
411
+ "swish": nn.SiLU(),
412
+ "silu": nn.SiLU(),
413
+ "elu": nn.ELU(),
414
+ "prelu": nn.PReLU(),
415
+ "relu6": nn.ReLU6(),
416
+ "hardtanh": nn.Hardtanh(),
417
+ "hardsigmoid": nn.Hardsigmoid(),
418
+ "hardswish": nn.Hardswish(),
419
+ "mish": nn.Mish(),
420
+ "softplus": nn.Softplus(),
421
+ "softsign": nn.Softsign(),
422
+ "tanhshrink": nn.Tanhshrink(),
423
+ "threshold": nn.Threshold(threshold=0.1, value=0),
424
+ }
425
+ return activations[activation]
426
+
427
+ def forward(self, x: torch.Tensor) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]:
428
+ """Forward pass through encoder."""
429
+ # Add noise for denoising autoencoders
430
+ if self.config.is_denoising and self.training:
431
+ noise = torch.randn_like(x) * self.config.noise_factor
432
+ x = x + noise
433
+
434
+ encoded = self.encoder(x)
435
+
436
+ if self.config.is_variational:
437
+ # Variational autoencoder: return mean, log variance, and sampled latent
438
+ mu = self.fc_mu(encoded)
439
+ logvar = self.fc_logvar(encoded)
440
+
441
+ # Reparameterization trick
442
+ if self.training:
443
+ std = torch.exp(0.5 * logvar)
444
+ eps = torch.randn_like(std)
445
+ z = mu + eps * std
446
+ else:
447
+ z = mu # Use mean during inference
448
+
449
+ return z, mu, logvar
450
+ else:
451
+ # Standard autoencoder
452
+ latent = self.fc_out(encoded)
453
+
454
+ # Add sparsity constraint for sparse autoencoders
455
+ if self.config.is_sparse and self.training:
456
+ # Apply L1 regularization to encourage sparsity
457
+ latent = F.relu(latent) # Ensure non-negative activations
458
+
459
+ return latent
460
+
461
+
462
+ class AutoencoderDecoder(nn.Module):
463
+ """Decoder part of the autoencoder."""
464
+
465
+ def __init__(self, config: AutoencoderConfig):
466
+ super().__init__()
467
+ self.config = config
468
+
469
+ # Build decoder layers (reverse of encoder)
470
+ layers = []
471
+ input_dim = config.latent_dim
472
+ decoder_dims = config.decoder_dims + [config.input_dim]
473
+
474
+ for i, hidden_dim in enumerate(decoder_dims):
475
+ layers.append(nn.Linear(input_dim, hidden_dim))
476
+
477
+ # Don't add batch norm, activation, or dropout to the final layer
478
+ if i < len(decoder_dims) - 1:
479
+ if config.use_batch_norm:
480
+ layers.append(nn.BatchNorm1d(hidden_dim))
481
+
482
+ layers.append(self._get_activation(config.activation))
483
+
484
+ if config.dropout_rate > 0:
485
+ layers.append(nn.Dropout(config.dropout_rate))
486
+ else:
487
+ # Final layer - add appropriate activation based on reconstruction loss
488
+ if config.reconstruction_loss == "bce":
489
+ layers.append(nn.Sigmoid())
490
+
491
+ input_dim = hidden_dim
492
+
493
+ self.decoder = nn.Sequential(*layers)
494
+
495
+ def _get_activation(self, activation: str) -> nn.Module:
496
+ """Get activation function by name."""
497
+ activations = {
498
+ "relu": nn.ReLU(),
499
+ "tanh": nn.Tanh(),
500
+ "sigmoid": nn.Sigmoid(),
501
+ "leaky_relu": nn.LeakyReLU(),
502
+ "gelu": nn.GELU(),
503
+ "swish": nn.SiLU(),
504
+ "silu": nn.SiLU(),
505
+ "elu": nn.ELU(),
506
+ "prelu": nn.PReLU(),
507
+ "relu6": nn.ReLU6(),
508
+ "hardtanh": nn.Hardtanh(),
509
+ "hardsigmoid": nn.Hardsigmoid(),
510
+ "hardswish": nn.Hardswish(),
511
+ "mish": nn.Mish(),
512
+ "softplus": nn.Softplus(),
513
+ "softsign": nn.Softsign(),
514
+ "tanhshrink": nn.Tanhshrink(),
515
+ "threshold": nn.Threshold(threshold=0.1, value=0),
516
+ }
517
+ return activations[activation]
518
+
519
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
520
+ """Forward pass through decoder."""
521
+ return self.decoder(x)
522
+
523
+
524
+ class RecurrentEncoder(nn.Module):
525
+ """Recurrent encoder for sequence data."""
526
+
527
+ def __init__(self, config: AutoencoderConfig):
528
+ super().__init__()
529
+ self.config = config
530
+
531
+ # Get RNN class
532
+ if config.rnn_type == "lstm":
533
+ rnn_class = nn.LSTM
534
+ elif config.rnn_type == "gru":
535
+ rnn_class = nn.GRU
536
+ elif config.rnn_type == "rnn":
537
+ rnn_class = nn.RNN
538
+ else:
539
+ raise ValueError(f"Unknown RNN type: {config.rnn_type}")
540
+
541
+ # Create RNN layers
542
+ self.rnn = rnn_class(
543
+ input_size=config.input_dim,
544
+ hidden_size=config.latent_dim,
545
+ num_layers=config.num_layers,
546
+ batch_first=True,
547
+ dropout=config.dropout_rate if config.num_layers > 1 else 0,
548
+ bidirectional=config.bidirectional
549
+ )
550
+
551
+ # Projection layer for bidirectional RNN
552
+ if config.bidirectional:
553
+ self.projection = nn.Linear(config.latent_dim * 2, config.latent_dim)
554
+ else:
555
+ self.projection = None
556
+
557
+ # Batch normalization
558
+ if config.use_batch_norm:
559
+ self.batch_norm = nn.BatchNorm1d(config.latent_dim)
560
+ else:
561
+ self.batch_norm = None
562
+
563
+ # Dropout
564
+ if config.dropout_rate > 0:
565
+ self.dropout = nn.Dropout(config.dropout_rate)
566
+ else:
567
+ self.dropout = None
568
+
569
+ def forward(self, x: torch.Tensor, lengths: Optional[torch.Tensor] = None) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]:
570
+ """
571
+ Forward pass through recurrent encoder.
572
+
573
+ Args:
574
+ x: Input tensor of shape (batch_size, seq_len, input_dim)
575
+ lengths: Sequence lengths for packed sequences (optional)
576
+
577
+ Returns:
578
+ Encoded representation or tuple for VAE
579
+ """
580
+ batch_size, seq_len, _ = x.shape
581
+
582
+ # Add noise for denoising autoencoders
583
+ if self.config.is_denoising and self.training:
584
+ noise = torch.randn_like(x) * self.config.noise_factor
585
+ x = x + noise
586
+
587
+ # Pack sequences if lengths provided
588
+ if lengths is not None:
589
+ x = nn.utils.rnn.pack_padded_sequence(x, lengths, batch_first=True, enforce_sorted=False)
590
+
591
+ # RNN forward pass
592
+ if self.config.rnn_type == "lstm":
593
+ output, (hidden, cell) = self.rnn(x)
594
+ else:
595
+ output, hidden = self.rnn(x)
596
+ cell = None
597
+
598
+ # Unpack if necessary
599
+ if lengths is not None:
600
+ output, _ = nn.utils.rnn.pad_packed_sequence(output, batch_first=True)
601
+
602
+ # Use last hidden state as encoding
603
+ if self.config.bidirectional:
604
+ # Concatenate forward and backward hidden states
605
+ hidden = hidden.view(self.config.num_layers, 2, batch_size, self.config.latent_dim)
606
+ hidden = hidden[-1] # Take last layer
607
+ hidden = hidden.transpose(0, 1).contiguous().view(batch_size, -1) # Concatenate directions
608
+
609
+ # Project to latent dimension
610
+ if self.projection:
611
+ hidden = self.projection(hidden)
612
+ else:
613
+ hidden = hidden[-1] # Take last layer
614
+
615
+ # Apply batch normalization
616
+ if self.batch_norm:
617
+ hidden = self.batch_norm(hidden)
618
+
619
+ # Apply dropout
620
+ if self.dropout and self.training:
621
+ hidden = self.dropout(hidden)
622
+
623
+ # Handle variational encoding
624
+ if self.config.is_variational:
625
+ # Split hidden into mean and log variance
626
+ mu = hidden[:, :self.config.latent_dim // 2]
627
+ logvar = hidden[:, self.config.latent_dim // 2:]
628
+
629
+ # Reparameterization trick
630
+ if self.training:
631
+ std = torch.exp(0.5 * logvar)
632
+ eps = torch.randn_like(std)
633
+ z = mu + eps * std
634
+ else:
635
+ z = mu
636
+
637
+ return z, mu, logvar
638
+ else:
639
+ return hidden
640
+
641
+
642
+ class RecurrentDecoder(nn.Module):
643
+ """Recurrent decoder for sequence data."""
644
+
645
+ def __init__(self, config: AutoencoderConfig):
646
+ super().__init__()
647
+ self.config = config
648
+
649
+ # Get RNN class
650
+ if config.rnn_type == "lstm":
651
+ rnn_class = nn.LSTM
652
+ elif config.rnn_type == "gru":
653
+ rnn_class = nn.GRU
654
+ elif config.rnn_type == "rnn":
655
+ rnn_class = nn.RNN
656
+ else:
657
+ raise ValueError(f"Unknown RNN type: {config.rnn_type}")
658
+
659
+ # Create RNN layers
660
+ self.rnn = rnn_class(
661
+ input_size=config.latent_dim,
662
+ hidden_size=config.latent_dim,
663
+ num_layers=config.num_layers,
664
+ batch_first=True,
665
+ dropout=config.dropout_rate if config.num_layers > 1 else 0,
666
+ bidirectional=False # Decoder is always unidirectional
667
+ )
668
+
669
+ # Output projection
670
+ self.output_projection = nn.Linear(config.latent_dim, config.input_dim)
671
+
672
+ # Batch normalization
673
+ if config.use_batch_norm:
674
+ self.batch_norm = nn.BatchNorm1d(config.latent_dim)
675
+ else:
676
+ self.batch_norm = None
677
+
678
+ # Dropout
679
+ if config.dropout_rate > 0:
680
+ self.dropout = nn.Dropout(config.dropout_rate)
681
+ else:
682
+ self.dropout = None
683
+
684
+ def forward(self, z: torch.Tensor, target_length: int, target_sequence: Optional[torch.Tensor] = None) -> torch.Tensor:
685
+ """
686
+ Forward pass through recurrent decoder.
687
+
688
+ Args:
689
+ z: Latent representation of shape (batch_size, latent_dim)
690
+ target_length: Length of sequence to generate
691
+ target_sequence: Target sequence for teacher forcing (optional)
692
+
693
+ Returns:
694
+ Decoded sequence of shape (batch_size, seq_len, input_dim)
695
+ """
696
+ batch_size = z.size(0)
697
+ device = z.device
698
+
699
+ # Initialize hidden state with latent representation
700
+ if self.config.rnn_type == "lstm":
701
+ h_0 = z.unsqueeze(0).repeat(self.config.num_layers, 1, 1)
702
+ c_0 = torch.zeros_like(h_0)
703
+ hidden = (h_0, c_0)
704
+ else:
705
+ hidden = z.unsqueeze(0).repeat(self.config.num_layers, 1, 1)
706
+
707
+ outputs = []
708
+
709
+ # Initialize input (can be learned or zero)
710
+ current_input = torch.zeros(batch_size, 1, self.config.latent_dim, device=device)
711
+
712
+ for t in range(target_length):
713
+ # Teacher forcing decision
714
+ use_teacher_forcing = (target_sequence is not None and
715
+ self.training and
716
+ random.random() < self.config.teacher_forcing_ratio)
717
+
718
+ if use_teacher_forcing and t > 0:
719
+ # Use previous target as input
720
+ current_input = target_sequence[:, t-1:t, :]
721
+ # No learned projection here: fall back to zeros when the target feature size differs from the latent size
722
+ if current_input.size(-1) != self.config.latent_dim:
723
+ current_input = torch.zeros(batch_size, 1, self.config.latent_dim, device=device)
724
+
725
+ # RNN forward step
726
+ if self.config.rnn_type == "lstm":
727
+ output, hidden = self.rnn(current_input, hidden)
728
+ else:
729
+ output, hidden = self.rnn(current_input, hidden)
730
+
731
+ # Apply batch normalization and dropout
732
+ output_flat = output.squeeze(1) # Remove sequence dimension
733
+
734
+ if self.batch_norm:
735
+ output_flat = self.batch_norm(output_flat)
736
+
737
+ if self.dropout and self.training:
738
+ output_flat = self.dropout(output_flat)
739
+
740
+ # Project to output dimension
741
+ step_output = self.output_projection(output_flat)
742
+ outputs.append(step_output.unsqueeze(1))
743
+
744
+ # Without teacher forcing, the next input stays a zero vector; information is carried
+ # between steps through the RNN hidden state rather than through the previous output.
+ if not use_teacher_forcing:
+ current_input = torch.zeros(batch_size, 1, self.config.latent_dim, device=device)
748
+
749
+ # Concatenate all outputs
750
+ return torch.cat(outputs, dim=1)
751
+
752
+
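For orientation, here is a minimal sketch of how the decoder above unrolls a single latent vector into a sequence. The `AutoencoderConfig` keyword arguments and their values are illustrative assumptions, not defaults taken from this repo:

```python
import torch
from configuration_autoencoder import AutoencoderConfig
from modeling_autoencoder import RecurrentDecoder

# Assumed constructor arguments; the real AutoencoderConfig may use different names/defaults.
config = AutoencoderConfig(input_dim=16, latent_dim=8, num_layers=1, rnn_type="gru",
                           dropout_rate=0.0, use_batch_norm=False)
decoder = RecurrentDecoder(config)

z = torch.randn(4, 8)                    # (batch_size, latent_dim)
sequence = decoder(z, target_length=10)  # expected shape (4, 10, 16): (batch, seq_len, input_dim)
```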
753
+ class AutoencoderModel(PreTrainedModel):
754
+ """
755
+ The bare Autoencoder Model transformer outputting raw hidden-states without any specific head on top.
756
+
757
+ This model inherits from PreTrainedModel. Check the superclass documentation for the generic methods the
758
+ library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads,
759
+ etc.)
760
+
761
+ This model is also a PyTorch torch.nn.Module subclass. Use it as a regular PyTorch Module and refer to the
762
+ PyTorch documentation for all matters related to general usage and behavior.
763
+ """
764
+
765
+ config_class = AutoencoderConfig
766
+ base_model_prefix = "autoencoder"
767
+ supports_gradient_checkpointing = False
768
+
769
+ def __init__(self, config: AutoencoderConfig):
770
+ super().__init__(config)
771
+ self.config = config
772
+
773
+ # Initialize learnable preprocessing
774
+ if config.has_preprocessing:
775
+ self.preprocessor = LearnablePreprocessor(config)
776
+ else:
777
+ self.preprocessor = None
778
+
779
+ # Initialize encoder and decoder based on type
780
+ if config.is_recurrent:
781
+ self.encoder = RecurrentEncoder(config)
782
+ self.decoder = RecurrentDecoder(config)
783
+ else:
784
+ self.encoder = AutoencoderEncoder(config)
785
+ self.decoder = AutoencoderDecoder(config)
786
+
787
+ # Tie weights if specified
788
+ if config.tie_weights:
789
+ self._tie_weights()
790
+
791
+ # Initialize weights
792
+ self.post_init()
793
+
794
+ def _tie_weights(self):
795
+ """Tie encoder and decoder weights (transpose relationship)."""
796
+ # This is a simplified weight tying - in practice, you might want more sophisticated tying
797
+ pass
798
+
799
+ def get_input_embeddings(self):
800
+ """Get input embeddings (not applicable for basic autoencoder)."""
801
+ return None
802
+
803
+ def set_input_embeddings(self, value):
804
+ """Set input embeddings (not applicable for basic autoencoder)."""
805
+ pass
806
+
807
+ def forward(
808
+ self,
809
+ input_values: torch.Tensor,
810
+ sequence_lengths: Optional[torch.Tensor] = None,
811
+ target_length: Optional[int] = None,
812
+ output_hidden_states: Optional[bool] = None,
813
+ return_dict: Optional[bool] = None,
814
+ ) -> Union[Tuple[torch.Tensor], AutoencoderOutput]:
815
+ """
816
+ Forward pass through the autoencoder.
817
+
818
+ Args:
819
+ input_values (torch.Tensor): Input tensor. Shape depends on autoencoder type:
820
+ - Standard: (batch_size, input_dim)
821
+ - Recurrent: (batch_size, seq_len, input_dim)
822
+ sequence_lengths (torch.Tensor, optional): Sequence lengths for recurrent AE.
823
+ target_length (int, optional): Target sequence length for recurrent decoder.
824
+ output_hidden_states (bool, optional): Whether to return hidden states.
825
+ return_dict (bool, optional): Whether to return a ModelOutput instead of a plain tuple.
826
+
827
+ Returns:
828
+ AutoencoderOutput or tuple: The model outputs.
829
+ """
830
+ output_hidden_states = (
831
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
832
+ )
833
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
834
+
835
+ # Apply learnable preprocessing
836
+ preprocessing_loss = torch.tensor(0.0, device=input_values.device)
837
+ if self.preprocessor is not None:
838
+ input_values, preprocessing_loss = self.preprocessor(input_values, inverse=False)
839
+
840
+ # Handle different autoencoder types
841
+ if self.config.is_recurrent:
842
+ # Recurrent autoencoder
843
+ if sequence_lengths is not None:
844
+ encoder_output = self.encoder(input_values, sequence_lengths)
845
+ else:
846
+ encoder_output = self.encoder(input_values)
847
+
848
+ if self.config.is_variational:
849
+ latent, mu, logvar = encoder_output
850
+ self._mu = mu
851
+ self._logvar = logvar
852
+ else:
853
+ latent = encoder_output
854
+ self._mu = None
855
+ self._logvar = None
856
+
857
+ # Determine target length for decoder
858
+ if target_length is None:
859
+ if self.config.sequence_length is not None:
860
+ target_length = self.config.sequence_length
861
+ else:
862
+ target_length = input_values.size(1) # Use input sequence length
863
+
864
+ # Decode latent back to sequence space
865
+ reconstructed = self.decoder(latent, target_length, input_values if self.training else None)
866
+ else:
867
+ # Standard autoencoder
868
+ encoder_output = self.encoder(input_values)
869
+
870
+ if self.config.is_variational:
871
+ latent, mu, logvar = encoder_output
872
+ self._mu = mu
873
+ self._logvar = logvar
874
+ else:
875
+ latent = encoder_output
876
+ self._mu = None
877
+ self._logvar = None
878
+
879
+ # Decode latent back to input space
880
+ reconstructed = self.decoder(latent)
881
+
882
+ # Apply inverse preprocessing to reconstruction
883
+ if self.preprocessor is not None and self.config.learn_inverse_preprocessing:
884
+ reconstructed, inverse_loss = self.preprocessor(reconstructed, inverse=True)
885
+ preprocessing_loss += inverse_loss
886
+
887
+ hidden_states = None
888
+ if output_hidden_states:
889
+ if self.config.is_variational:
890
+ hidden_states = (latent, mu, logvar)
891
+ else:
892
+ hidden_states = (latent,)
893
+
894
+ if not return_dict:
895
+ return tuple(v for v in [latent, reconstructed, hidden_states] if v is not None)
896
+
897
+ return AutoencoderOutput(
898
+ last_hidden_state=latent,
899
+ reconstructed=reconstructed,
900
+ hidden_states=hidden_states,
901
+ preprocessing_loss=preprocessing_loss,
902
+ )
903
+
904
+
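Before the task-specific head below, a hedged sketch of calling the base model on flat feature vectors; the config field values are illustrative assumptions and rely on the non-recurrent `AutoencoderEncoder`/`AutoencoderDecoder` defined earlier in this file:

```python
import torch
from configuration_autoencoder import AutoencoderConfig
from modeling_autoencoder import AutoencoderModel

# Illustrative, assumed field values for a plain (non-recurrent, non-variational) autoencoder.
config = AutoencoderConfig(input_dim=20, latent_dim=5)
model = AutoencoderModel(config)

x = torch.randn(32, 20)             # (batch_size, input_dim)
out = model(x, return_dict=True)
print(out.last_hidden_state.shape)  # latent codes, expected (32, 5)
print(out.reconstructed.shape)      # reconstruction, expected (32, 20)
```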
905
+ class AutoencoderForReconstruction(PreTrainedModel):
906
+ """
907
+ Autoencoder Model with a reconstruction head on top for reconstruction tasks.
908
+
909
+ This model inherits from PreTrainedModel and adds a reconstruction loss calculation.
910
+ """
911
+
912
+ config_class = AutoencoderConfig
913
+ base_model_prefix = "autoencoder"
914
+
915
+ def __init__(self, config: AutoencoderConfig):
916
+ super().__init__(config)
917
+ self.config = config
918
+
919
+ # Initialize the base autoencoder model
920
+ self.autoencoder = AutoencoderModel(config)
921
+
922
+ # Initialize weights
923
+ self.post_init()
924
+
925
+ def get_input_embeddings(self):
926
+ """Get input embeddings."""
927
+ return self.autoencoder.get_input_embeddings()
928
+
929
+ def set_input_embeddings(self, value):
930
+ """Set input embeddings."""
931
+ self.autoencoder.set_input_embeddings(value)
932
+
933
+ def _compute_reconstruction_loss(
934
+ self,
935
+ reconstructed: torch.Tensor,
936
+ target: torch.Tensor
937
+ ) -> torch.Tensor:
938
+ """Compute reconstruction loss based on the configured loss type."""
939
+ if self.config.reconstruction_loss == "mse":
940
+ return F.mse_loss(reconstructed, target, reduction="mean")
941
+ elif self.config.reconstruction_loss == "bce":
942
+ return F.binary_cross_entropy_with_logits(reconstructed, target, reduction="mean")
943
+ elif self.config.reconstruction_loss == "l1":
944
+ return F.l1_loss(reconstructed, target, reduction="mean")
945
+ elif self.config.reconstruction_loss == "huber":
946
+ return F.huber_loss(reconstructed, target, reduction="mean")
947
+ elif self.config.reconstruction_loss == "smooth_l1":
948
+ return F.smooth_l1_loss(reconstructed, target, reduction="mean")
949
+ elif self.config.reconstruction_loss == "kl_div":
950
+ # "batchmean" is the reduction that matches the mathematical definition of KL divergence
+ return F.kl_div(F.log_softmax(reconstructed, dim=-1), F.softmax(target, dim=-1), reduction="batchmean")
951
+ elif self.config.reconstruction_loss == "cosine":
952
+ return 1 - F.cosine_similarity(reconstructed, target, dim=-1).mean()
953
+ elif self.config.reconstruction_loss == "focal":
954
+ return self._focal_loss(reconstructed, target)
955
+ elif self.config.reconstruction_loss == "dice":
956
+ return self._dice_loss(reconstructed, target)
957
+ elif self.config.reconstruction_loss == "tversky":
958
+ return self._tversky_loss(reconstructed, target)
959
+ elif self.config.reconstruction_loss == "ssim":
960
+ return self._ssim_loss(reconstructed, target)
961
+ elif self.config.reconstruction_loss == "perceptual":
962
+ return self._perceptual_loss(reconstructed, target)
963
+ else:
964
+ raise ValueError(f"Unknown reconstruction loss: {self.config.reconstruction_loss}")
965
+
966
+ def _focal_loss(self, pred: torch.Tensor, target: torch.Tensor, alpha: float = 1.0, gamma: float = 2.0) -> torch.Tensor:
967
+ """Compute focal loss for handling class imbalance."""
968
+ ce_loss = F.mse_loss(pred, target, reduction="none")
969
+ pt = torch.exp(-ce_loss)
970
+ focal_loss = alpha * (1 - pt) ** gamma * ce_loss
971
+ return focal_loss.mean()
972
+
973
+ def _dice_loss(self, pred: torch.Tensor, target: torch.Tensor, smooth: float = 1e-6) -> torch.Tensor:
974
+ """Compute Dice loss for segmentation-like tasks."""
975
+ pred_flat = pred.view(-1)
976
+ target_flat = target.view(-1)
977
+ intersection = (pred_flat * target_flat).sum()
978
+ dice = (2.0 * intersection + smooth) / (pred_flat.sum() + target_flat.sum() + smooth)
979
+ return 1 - dice
980
+
981
+ def _tversky_loss(self, pred: torch.Tensor, target: torch.Tensor, alpha: float = 0.7, beta: float = 0.3, smooth: float = 1e-6) -> torch.Tensor:
982
+ """Compute Tversky loss, a generalization of Dice loss."""
983
+ pred_flat = pred.view(-1)
984
+ target_flat = target.view(-1)
985
+ true_pos = (pred_flat * target_flat).sum()
986
+ false_neg = (target_flat * (1 - pred_flat)).sum()
987
+ false_pos = ((1 - target_flat) * pred_flat).sum()
988
+ tversky = (true_pos + smooth) / (true_pos + alpha * false_neg + beta * false_pos + smooth)
989
+ return 1 - tversky
990
+
991
+ def _ssim_loss(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
992
+ """Compute SSIM-based loss (simplified version)."""
993
+ # Simplified SSIM for 1D data
994
+ mu1 = pred.mean(dim=-1, keepdim=True)
995
+ mu2 = target.mean(dim=-1, keepdim=True)
996
+ sigma1_sq = ((pred - mu1) ** 2).mean(dim=-1, keepdim=True)
997
+ sigma2_sq = ((target - mu2) ** 2).mean(dim=-1, keepdim=True)
998
+ sigma12 = ((pred - mu1) * (target - mu2)).mean(dim=-1, keepdim=True)
999
+
1000
+ c1, c2 = 0.01, 0.03  # stability constants (classic SSIM uses (0.01 * L) ** 2 and (0.03 * L) ** 2 for dynamic range L)
1001
+ ssim = ((2 * mu1 * mu2 + c1) * (2 * sigma12 + c2)) / ((mu1**2 + mu2**2 + c1) * (sigma1_sq + sigma2_sq + c2))
1002
+ return 1 - ssim.mean()
1003
+
1004
+ def _perceptual_loss(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
1005
+ """Compute perceptual loss (simplified version using feature differences)."""
1006
+ # For simplicity, use L2 loss on normalized features
1007
+ pred_norm = F.normalize(pred, p=2, dim=-1)
1008
+ target_norm = F.normalize(target, p=2, dim=-1)
1009
+ return F.mse_loss(pred_norm, target_norm)
1010
+
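One relationship worth noting between the two overlap losses above: with `alpha = beta = 0.5` the Tversky index reduces to the Dice coefficient (up to the smoothing constant), because `TP + 0.5 * (FN + FP) = 0.5 * (pred.sum() + target.sum())`. A minimal numeric check, with made-up values:

```python
import torch

pred = torch.tensor([0.9, 0.1, 0.8, 0.0])
target = torch.tensor([1.0, 0.0, 1.0, 0.0])

tp = (pred * target).sum()
fn = (target * (1 - pred)).sum()
fp = ((1 - target) * pred).sum()

dice = 2 * tp / (pred.sum() + target.sum())
tversky_half = tp / (tp + 0.5 * fn + 0.5 * fp)
print(torch.isclose(dice, tversky_half))  # True: Tversky(0.5, 0.5) matches Dice
```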
1011
+ def forward(
1012
+ self,
1013
+ input_values: torch.Tensor,
1014
+ labels: Optional[torch.Tensor] = None,
1015
+ sequence_lengths: Optional[torch.Tensor] = None,
1016
+ target_length: Optional[int] = None,
1017
+ output_hidden_states: Optional[bool] = None,
1018
+ return_dict: Optional[bool] = None,
1019
+ ) -> Union[Tuple[torch.Tensor], AutoencoderForReconstructionOutput]:
1020
+ """
1021
+ Forward pass with reconstruction loss calculation.
1022
+
1023
+ Args:
1024
+ input_values (torch.Tensor): Input tensor. Shape depends on autoencoder type:
1025
+ - Standard: (batch_size, input_dim)
1026
+ - Recurrent: (batch_size, seq_len, input_dim)
1027
+ labels (torch.Tensor, optional): Target tensor for reconstruction. If None, uses input_values.
1028
+ sequence_lengths (torch.Tensor, optional): Sequence lengths for recurrent AE.
1029
+ target_length (int, optional): Target sequence length for recurrent decoder.
1030
+ output_hidden_states (bool, optional): Whether to return hidden states.
1031
+ return_dict (bool, optional): Whether to return a ModelOutput instead of a plain tuple.
1032
+
1033
+ Returns:
1034
+ AutoencoderForReconstructionOutput or tuple: The model outputs including loss.
1035
+ """
1036
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1037
+
1038
+ # If no labels provided, use input as target (standard autoencoder)
1039
+ if labels is None:
1040
+ labels = input_values
1041
+
1042
+ # Forward pass through autoencoder
1043
+ outputs = self.autoencoder(
1044
+ input_values=input_values,
1045
+ sequence_lengths=sequence_lengths,
1046
+ target_length=target_length,
1047
+ output_hidden_states=output_hidden_states,
1048
+ return_dict=True,
1049
+ )
1050
+
1051
+ reconstructed = outputs.reconstructed
1052
+ latent = outputs.last_hidden_state
1053
+ hidden_states = outputs.hidden_states
1054
+
1055
+ # Compute reconstruction loss
1056
+ recon_loss = self._compute_reconstruction_loss(reconstructed, labels)
1057
+
1058
+ # Add regularization losses based on autoencoder type
1059
+ total_loss = recon_loss
1060
+
1061
+ # Add preprocessing loss if available
1062
+ if hasattr(outputs, 'preprocessing_loss') and outputs.preprocessing_loss is not None:
1063
+ total_loss += outputs.preprocessing_loss
1064
+
1065
+ if self.config.is_variational and hasattr(self.autoencoder, '_mu') and self.autoencoder._mu is not None:
1066
+ # KL divergence loss for variational autoencoders
1067
+ kl_loss = -0.5 * torch.sum(1 + self.autoencoder._logvar - self.autoencoder._mu.pow(2) - self.autoencoder._logvar.exp())
1068
+ kl_loss = kl_loss / (self.autoencoder._mu.size(0) * self.autoencoder._mu.size(1)) # Normalize by batch size and latent dim
1069
+ total_loss = total_loss + self.config.beta * kl_loss  # keep the preprocessing term accumulated above
1070
+
1071
+ elif self.config.is_sparse:
1072
+ # Sparsity loss for sparse autoencoders (L1 penalty on the latent code)
+ sparsity_loss = torch.mean(torch.abs(latent))
+ total_loss = total_loss + 0.1 * sparsity_loss  # fixed sparsity weight
1076
+
1077
+ elif self.config.is_contractive:
+ # Contractive penalty (placeholder): latent.grad is only populated after a backward
+ # pass, so in a plain forward pass this branch adds no extra term. A full implementation
+ # would penalize the Jacobian of the latent code with respect to the input.
+ if latent.requires_grad:
+ latent.retain_grad()
+ if latent.grad is not None:
+ contractive_loss = torch.sum(latent.grad ** 2)
+ total_loss = total_loss + 0.1 * contractive_loss
1084
+
1085
+ loss = total_loss
1086
+
1087
+ if not return_dict:
1088
+ output = (reconstructed, latent)
1089
+ if hidden_states is not None:
1090
+ output = output + (hidden_states,)
1091
+ return ((loss,) + output) if loss is not None else output
1092
+
1093
+ return AutoencoderForReconstructionOutput(
1094
+ loss=loss,
1095
+ reconstructed=reconstructed,
1096
+ last_hidden_state=latent,
1097
+ hidden_states=hidden_states,
1098
+ preprocessing_loss=outputs.preprocessing_loss if hasattr(outputs, 'preprocessing_loss') else None,
1099
+ )
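The head above composes `loss = reconstruction + preprocessing (+ beta * KL, or + sparsity, or + contractive)`, where the VAE term is the usual closed-form KL against a unit Gaussian, `-0.5 * sum(1 + logvar - mu**2 - exp(logvar))`, normalized by batch size and latent dimension. A hedged sketch of a manual training step (assumed config values, and assuming `return_dict` defaults to `True`):

```python
import torch
from configuration_autoencoder import AutoencoderConfig
from modeling_autoencoder import AutoencoderForReconstruction

config = AutoencoderConfig(input_dim=20, latent_dim=5)   # illustrative field values
model = AutoencoderForReconstruction(config)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)

batch = torch.randn(32, 20)
outputs = model(input_values=batch)   # labels default to the inputs
outputs.loss.backward()
optimizer.step()
optimizer.zero_grad()
```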
register_autoencoder.py ADDED
@@ -0,0 +1,62 @@
1
+ """
2
+ Registration script for Autoencoder models with Hugging Face AutoModel framework.
3
+ """
4
+
5
+ from transformers import AutoConfig, AutoModel
6
+ from configuration_autoencoder import AutoencoderConfig
7
+ from modeling_autoencoder import AutoencoderModel, AutoencoderForReconstruction
8
+
9
+
10
+ def register_autoencoder_models():
11
+ """
12
+ Register the autoencoder models with the Hugging Face AutoModel framework.
13
+
14
+ This function registers:
15
+ - AutoencoderConfig with AutoConfig
16
+ - AutoencoderModel with AutoModel
17
+ - AutoencoderForReconstruction (not auto-registered; import it directly, see the note below)
18
+
19
+ After calling this function, you can use:
20
+ - AutoConfig.from_pretrained() to load autoencoder configs
21
+ - AutoModel.from_pretrained() to load autoencoder models
22
+ """
23
+
24
+ # Register configuration
25
+ AutoConfig.register("autoencoder", AutoencoderConfig)
26
+
27
+ # Register base model
28
+ AutoModel.register(AutoencoderConfig, AutoencoderModel)
29
+
30
+ # Note: For task-specific models like AutoencoderForReconstruction,
31
+ # we would typically create a custom AutoModelForReconstruction class
32
+ # and register it separately. For now, users can import directly.
33
+
34
+ print("✅ Autoencoder models registered with Hugging Face AutoModel framework!")
35
+ print("You can now use:")
36
+ print(" - AutoConfig.from_pretrained() for configs")
37
+ print(" - AutoModel.from_pretrained() for models")
38
+ print(" - Direct imports for task-specific models")
39
+
40
+
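A hedged sketch of what the registration above enables, assuming the default `AutoencoderConfig()` constructor works without arguments:

```python
from transformers import AutoModel
from configuration_autoencoder import AutoencoderConfig
from register_autoencoder import register_autoencoder_models

register_autoencoder_models()

config = AutoencoderConfig()           # assumed to have workable defaults
model = AutoModel.from_config(config)  # resolves to AutoencoderModel via the registry
print(type(model).__name__)            # "AutoencoderModel"
```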
41
+ def register_for_auto_class():
42
+ """
43
+ Register models for auto class functionality when saving/loading.
44
+
45
+ This enables the models to be automatically discovered when using
46
+ save_pretrained() and from_pretrained() methods.
47
+ """
48
+
49
+ # Register config for auto class
50
+ AutoencoderConfig.register_for_auto_class()
51
+
52
+ # Register models for auto class
53
+ AutoencoderModel.register_for_auto_class("AutoModel")
54
+ AutoencoderForReconstruction.register_for_auto_class("AutoModel")
55
+
56
+ print("✅ Models registered for auto class functionality!")
57
+
58
+
59
+ if __name__ == "__main__":
60
+ # Register models when script is run directly
61
+ register_autoencoder_models()
62
+ register_for_auto_class()
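After `register_for_auto_class()`, a saved checkpoint should carry an `auto_map` entry so it can be reloaded through the Auto classes with `trust_remote_code=True`. A sketch of that round trip (the local path is illustrative, and the config defaults are assumed to be usable):

```python
from transformers import AutoModel
from configuration_autoencoder import AutoencoderConfig
from modeling_autoencoder import AutoencoderForReconstruction
from register_autoencoder import register_autoencoder_models, register_for_auto_class

register_autoencoder_models()
register_for_auto_class()

model = AutoencoderForReconstruction(AutoencoderConfig())
model.save_pretrained("./autoencoder-demo")  # illustrative path

# trust_remote_code is needed because the model code lives alongside the checkpoint.
reloaded = AutoModel.from_pretrained("./autoencoder-demo", trust_remote_code=True)
```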