# template.py — ready-to-use configuration templates for the block-based Autoencoder.
"""
Ready-to-use configuration templates for the block-based Autoencoder.
These helpers demonstrate how to assemble encoder_blocks and decoder_blocks
for a variety of architectures using the new block system. Each class extends
AutoencoderConfig and can be passed directly to AutoencoderModel.
Example:
from modeling_autoencoder import AutoencoderModel
from template import ClassicAutoencoderConfig
cfg = ClassicAutoencoderConfig(input_dim=784, latent_dim=64)
model = AutoencoderModel(cfg)
"""
from __future__ import annotations
from typing import List
# Support both package-relative and flat import
try:
from .configuration_autoencoder import (
AutoencoderConfig,
)
except Exception: # pragma: no cover
from configuration_autoencoder import (
AutoencoderConfig,
)
# ------------------------------- Helpers ------------------------------- #
def _linear_stack(input_dim: int, dims: List[int], activation: str = "relu", normalization: str = "batch", dropout: float = 0.0):
"""Build a list of Linear block dict configs mapping input_dim -> dims sequentially."""
blocks = []
prev = input_dim
for h in dims:
blocks.append({
"type": "linear",
"input_dim": prev,
"output_dim": h,
"activation": activation,
"normalization": normalization,
"dropout_rate": dropout,
"use_residual": False,
})
prev = h
return blocks
def _default_decoder(latent_dim: int, hidden: List[int], out_dim: int, activation: str = "relu", normalization: str = "batch", dropout: float = 0.0):
    """Linear decoder: latent_dim -> hidden -> out_dim (final layer identity)."""
    blocks = _linear_stack(latent_dim, hidden + [out_dim], activation, normalization, dropout)
    if not blocks:
        return blocks
    # The reconstruction layer emits raw linear output: no activation,
    # normalization, or dropout on the final projection.
    final = blocks[-1]
    final["activation"] = "identity"
    final["normalization"] = "none"
    final["dropout_rate"] = 0.0
    return blocks
# ---------------------------- Class-based templates ---------------------------- #
class ClassicAutoencoderConfig(AutoencoderConfig):
    """Classic dense autoencoder built from Linear blocks.
    Example:
        cfg = ClassicAutoencoderConfig(input_dim=784, latent_dim=64)
    """
    def __init__(self, input_dim: int = 784, latent_dim: int = 64, hidden: List[int] = (512, 256, 128), activation: str = "relu", dropout: float = 0.1, use_batch_norm: bool = True, **kwargs):
        widths = list(hidden)
        layer_norm = "batch" if use_batch_norm else "none"
        encoder_cfg = _linear_stack(input_dim, widths, activation, layer_norm, dropout)
        # Decoder mirrors the encoder widths back up to the input dimensionality.
        decoder_cfg = _default_decoder(latent_dim, widths[::-1], input_dim, activation, layer_norm, dropout)
        super().__init__(
            input_dim=input_dim,
            latent_dim=latent_dim,
            activation=activation,
            dropout_rate=dropout,
            use_batch_norm=use_batch_norm,
            autoencoder_type="classic",
            encoder_blocks=encoder_cfg,
            decoder_blocks=decoder_cfg,
            **kwargs,
        )
class VariationalAutoencoderConfig(AutoencoderConfig):
    """Variational autoencoder (MLP). Uses VariationalBlock in the model.
    Example:
        cfg = VariationalAutoencoderConfig(input_dim=784, latent_dim=32)
    """
    def __init__(self, input_dim: int = 784, latent_dim: int = 32, hidden: List[int] = (512, 256, 128), activation: str = "relu", dropout: float = 0.1, use_batch_norm: bool = True, beta: float = 1.0, **kwargs):
        widths = list(hidden)
        layer_norm = "batch" if use_batch_norm else "none"
        encoder_cfg = _linear_stack(input_dim, widths, activation, layer_norm, dropout)
        decoder_cfg = _default_decoder(latent_dim, widths[::-1], input_dim, activation, layer_norm, dropout)
        super().__init__(
            input_dim=input_dim,
            latent_dim=latent_dim,
            activation=activation,
            dropout_rate=dropout,
            use_batch_norm=use_batch_norm,
            autoencoder_type="variational",
            # beta weights the KL term; 1.0 is the standard VAE objective.
            beta=beta,
            encoder_blocks=encoder_cfg,
            decoder_blocks=decoder_cfg,
            **kwargs,
        )
class TransformerAutoencoderConfig(AutoencoderConfig):
    """Transformer-style autoencoder with attention encoder and MLP decoder.
    Works with (batch, input_dim) or (batch, time, input_dim).
    Example:
        cfg = TransformerAutoencoderConfig(input_dim=256, latent_dim=128)
    """
    def __init__(self, input_dim: int = 256, latent_dim: int = 128, num_layers: int = 2, num_heads: int = 4, ffn_mult: int = 4, activation: str = "relu", dropout: float = 0.1, use_batch_norm: bool = False, **kwargs):
        layer_norm = "batch" if use_batch_norm else "none"

        def _projection():
            # Width-preserving linear block bracketing the attention stack.
            return {"type": "linear", "input_dim": input_dim, "output_dim": input_dim, "activation": activation, "normalization": layer_norm, "dropout_rate": dropout}

        attention_layers = [
            {"type": "attention", "input_dim": input_dim, "num_heads": num_heads, "ffn_dim": ffn_mult * input_dim, "dropout_rate": dropout}
            for _ in range(num_layers)
        ]
        encoder_cfg = [_projection(), *attention_layers, _projection()]
        decoder_cfg = _default_decoder(latent_dim, [input_dim], input_dim, activation, layer_norm, dropout)
        super().__init__(
            input_dim=input_dim,
            latent_dim=latent_dim,
            activation=activation,
            dropout_rate=dropout,
            use_batch_norm=use_batch_norm,
            autoencoder_type="classic",
            encoder_blocks=encoder_cfg,
            decoder_blocks=decoder_cfg,
            **kwargs,
        )
class RecurrentAutoencoderConfig(AutoencoderConfig):
    """Recurrent encoder (LSTM/GRU/RNN) for sequence data.
    Expected input: (batch, time, input_dim). Decoder is MLP back to features per step.
    Example:
        cfg = RecurrentAutoencoderConfig(input_dim=128, latent_dim=64, rnn_type="lstm")
    """
    def __init__(self, input_dim: int = 128, latent_dim: int = 64, rnn_type: str = "lstm", num_layers: int = 2, bidirectional: bool = False, activation: str = "relu", dropout: float = 0.1, use_batch_norm: bool = False, **kwargs):
        layer_norm = "batch" if use_batch_norm else "none"
        # Single recurrent block projecting sequences straight to latent width.
        recurrent_block = {
            "type": "recurrent",
            "input_dim": input_dim,
            "hidden_size": latent_dim,
            "num_layers": num_layers,
            "rnn_type": rnn_type,
            "bidirectional": bidirectional,
            "dropout_rate": dropout,
            "output_dim": latent_dim,
        }
        # Hidden width never shrinks below the reconstruction target.
        decoder_cfg = _default_decoder(latent_dim, [max(latent_dim, input_dim)], input_dim, activation, layer_norm, dropout)
        super().__init__(
            input_dim=input_dim,
            latent_dim=latent_dim,
            activation=activation,
            dropout_rate=dropout,
            use_batch_norm=use_batch_norm,
            autoencoder_type="classic",
            encoder_blocks=[recurrent_block],
            decoder_blocks=decoder_cfg,
            **kwargs,
        )
class ConvolutionalAutoencoderConfig(AutoencoderConfig):
    """1D convolutional encoder for sequence data; decoder is per-step MLP.
    Expected input: (batch, time, input_dim).
    Example:
        cfg = ConvolutionalAutoencoderConfig(input_dim=64, conv_channels=(64, 64))
    """
    def __init__(self, input_dim: int = 64, latent_dim: int = 64, conv_channels: List[int] = (64, 64), kernel_size: int = 3, activation: str = "relu", dropout: float = 0.0, use_batch_norm: bool = True, **kwargs):
        layer_norm = "batch" if use_batch_norm else "none"
        widths = [input_dim, *conv_channels]
        # Same-padded conv blocks walking through the channel widths in order.
        encoder_cfg = [
            {"type": "conv1d", "input_dim": w_in, "output_dim": w_out, "kernel_size": kernel_size, "padding": "same", "activation": activation, "normalization": layer_norm, "dropout_rate": dropout}
            for w_in, w_out in zip(widths, widths[1:])
        ]
        encoder_cfg.append({"type": "linear", "input_dim": widths[-1], "output_dim": latent_dim, "activation": activation, "normalization": layer_norm, "dropout_rate": dropout})
        decoder_cfg = _default_decoder(latent_dim, [widths[-1]], input_dim, activation, layer_norm, dropout)
        super().__init__(
            input_dim=input_dim,
            latent_dim=latent_dim,
            activation=activation,
            dropout_rate=dropout,
            use_batch_norm=use_batch_norm,
            autoencoder_type="classic",
            encoder_blocks=encoder_cfg,
            decoder_blocks=decoder_cfg,
            **kwargs,
        )
class ConvAttentionAutoencoderConfig(AutoencoderConfig):
    """Mixed Conv + Attention encoder for sequence data.
    Example:
        cfg = ConvAttentionAutoencoderConfig(input_dim=64, latent_dim=64)
    """
    def __init__(self, input_dim: int = 64, latent_dim: int = 64, conv_channels: List[int] = (64,), num_heads: int = 4, activation: str = "relu", dropout: float = 0.1, use_batch_norm: bool = True, **kwargs):
        layer_norm = "batch" if use_batch_norm else "none"
        widths = [input_dim, *conv_channels]
        # Conv front-end (kernel 3, same padding) followed by one attention block.
        encoder_cfg = [
            {"type": "conv1d", "input_dim": w_in, "output_dim": w_out, "kernel_size": 3, "padding": "same", "activation": activation, "normalization": layer_norm, "dropout_rate": dropout}
            for w_in, w_out in zip(widths, widths[1:])
        ]
        encoder_cfg.append({"type": "attention", "input_dim": widths[-1], "num_heads": num_heads, "ffn_dim": 4 * widths[-1], "dropout_rate": dropout})
        encoder_cfg.append({"type": "linear", "input_dim": widths[-1], "output_dim": latent_dim, "activation": activation, "normalization": layer_norm, "dropout_rate": dropout})
        decoder_cfg = _default_decoder(latent_dim, [widths[-1]], input_dim, activation, layer_norm, dropout)
        super().__init__(
            input_dim=input_dim,
            latent_dim=latent_dim,
            activation=activation,
            dropout_rate=dropout,
            use_batch_norm=use_batch_norm,
            autoencoder_type="classic",
            encoder_blocks=encoder_cfg,
            decoder_blocks=decoder_cfg,
            **kwargs,
        )
class LinearRecurrentAutoencoderConfig(AutoencoderConfig):
    """Linear down-projection then Recurrent encoder.
    Example:
        cfg = LinearRecurrentAutoencoderConfig(input_dim=256, latent_dim=64, rnn_type="gru")
    """
    def __init__(self, input_dim: int = 256, latent_dim: int = 64, rnn_type: str = "gru", activation: str = "relu", dropout: float = 0.1, use_batch_norm: bool = False, **kwargs):
        layer_norm = "batch" if use_batch_norm else "none"
        encoder_cfg = []
        # Project features down to latent width before the single-layer RNN.
        encoder_cfg.append({"type": "linear", "input_dim": input_dim, "output_dim": latent_dim, "activation": activation, "normalization": layer_norm, "dropout_rate": dropout})
        encoder_cfg.append({"type": "recurrent", "input_dim": latent_dim, "hidden_size": latent_dim, "num_layers": 1, "rnn_type": rnn_type, "bidirectional": False, "dropout_rate": dropout, "output_dim": latent_dim})
        # No hidden layers: decoder is a single latent -> input projection.
        decoder_cfg = _default_decoder(latent_dim, [], input_dim, activation, layer_norm, dropout)
        super().__init__(
            input_dim=input_dim,
            latent_dim=latent_dim,
            activation=activation,
            dropout_rate=dropout,
            use_batch_norm=use_batch_norm,
            autoencoder_type="classic",
            encoder_blocks=encoder_cfg,
            decoder_blocks=decoder_cfg,
            **kwargs,
        )
class PreprocessedAutoencoderConfig(AutoencoderConfig):
    """Classic MLP AE with learnable preprocessing/inverse.
    Example:
        cfg = PreprocessedAutoencoderConfig(input_dim=64, preprocessing_type="neural_scaler")
    """
    def __init__(self, input_dim: int = 64, latent_dim: int = 32, preprocessing_type: str = "neural_scaler", hidden: List[int] = (128, 64), activation: str = "relu", dropout: float = 0.0, use_batch_norm: bool = True, **kwargs):
        layer_norm = "batch" if use_batch_norm else "none"
        widths = list(hidden)
        encoder_cfg = _linear_stack(input_dim, widths, activation, layer_norm, dropout)
        decoder_cfg = _default_decoder(latent_dim, widths[::-1], input_dim, activation, layer_norm, dropout)
        super().__init__(
            input_dim=input_dim,
            latent_dim=latent_dim,
            activation=activation,
            dropout_rate=dropout,
            use_batch_norm=use_batch_norm,
            autoencoder_type="classic",
            # Enable the learnable pre/post transform around the autoencoder.
            use_learnable_preprocessing=True,
            preprocessing_type=preprocessing_type,
            encoder_blocks=encoder_cfg,
            decoder_blocks=decoder_cfg,
            **kwargs,
        )
class BetaVariationalAutoencoderConfig(AutoencoderConfig):
    """Beta-VAE (MLP). Like VAE but with beta > 1 controlling KL weight.
    Example:
        cfg = BetaVariationalAutoencoderConfig(input_dim=784, latent_dim=32, beta=4.0)
    """
    def __init__(self, input_dim: int = 784, latent_dim: int = 32, hidden: List[int] = (512, 256, 128), activation: str = "relu", dropout: float = 0.1, use_batch_norm: bool = True, beta: float = 4.0, **kwargs):
        widths = list(hidden)
        layer_norm = "batch" if use_batch_norm else "none"
        encoder_cfg = _linear_stack(input_dim, widths, activation, layer_norm, dropout)
        decoder_cfg = _default_decoder(latent_dim, widths[::-1], input_dim, activation, layer_norm, dropout)
        super().__init__(
            input_dim=input_dim,
            latent_dim=latent_dim,
            activation=activation,
            dropout_rate=dropout,
            use_batch_norm=use_batch_norm,
            autoencoder_type="beta_vae",
            beta=beta,
            encoder_blocks=encoder_cfg,
            decoder_blocks=decoder_cfg,
            **kwargs,
        )
class DenoisingAutoencoderConfig(AutoencoderConfig):
    """Denoising AE: adds noise during training (handled by training loop/model if supported).
    Example:
        cfg = DenoisingAutoencoderConfig(input_dim=128, latent_dim=32, noise_factor=0.2)
    """
    def __init__(self, input_dim: int = 128, latent_dim: int = 32, hidden: List[int] = (128, 64), activation: str = "relu", dropout: float = 0.0, use_batch_norm: bool = True, noise_factor: float = 0.2, **kwargs):
        widths = list(hidden)
        layer_norm = "batch" if use_batch_norm else "none"
        encoder_cfg = _linear_stack(input_dim, widths, activation, layer_norm, dropout)
        decoder_cfg = _default_decoder(latent_dim, widths[::-1], input_dim, activation, layer_norm, dropout)
        super().__init__(
            input_dim=input_dim,
            latent_dim=latent_dim,
            activation=activation,
            dropout_rate=dropout,
            use_batch_norm=use_batch_norm,
            autoencoder_type="denoising",
            noise_factor=noise_factor,
            encoder_blocks=encoder_cfg,
            decoder_blocks=decoder_cfg,
            **kwargs,
        )
class SparseAutoencoderConfig(AutoencoderConfig):
    """Sparse AE (typical L1 activation penalty applied in training loop).
    Example:
        cfg = SparseAutoencoderConfig(input_dim=256, latent_dim=64)
    """
    def __init__(self, input_dim: int = 256, latent_dim: int = 64, hidden: List[int] = (128, 64), activation: str = "relu", dropout: float = 0.0, use_batch_norm: bool = True, **kwargs):
        widths = list(hidden)
        layer_norm = "batch" if use_batch_norm else "none"
        encoder_cfg = _linear_stack(input_dim, widths, activation, layer_norm, dropout)
        decoder_cfg = _default_decoder(latent_dim, widths[::-1], input_dim, activation, layer_norm, dropout)
        super().__init__(
            input_dim=input_dim,
            latent_dim=latent_dim,
            activation=activation,
            dropout_rate=dropout,
            use_batch_norm=use_batch_norm,
            autoencoder_type="sparse",
            encoder_blocks=encoder_cfg,
            decoder_blocks=decoder_cfg,
            **kwargs,
        )
class ContractiveAutoencoderConfig(AutoencoderConfig):
    """Contractive AE (requires Jacobian penalty in training loop).
    Example:
        cfg = ContractiveAutoencoderConfig(input_dim=64, latent_dim=16)
    """
    def __init__(self, input_dim: int = 64, latent_dim: int = 16, hidden: List[int] = (64, 32), activation: str = "relu", dropout: float = 0.0, use_batch_norm: bool = True, **kwargs):
        widths = list(hidden)
        layer_norm = "batch" if use_batch_norm else "none"
        encoder_cfg = _linear_stack(input_dim, widths, activation, layer_norm, dropout)
        decoder_cfg = _default_decoder(latent_dim, widths[::-1], input_dim, activation, layer_norm, dropout)
        super().__init__(
            input_dim=input_dim,
            latent_dim=latent_dim,
            activation=activation,
            dropout_rate=dropout,
            use_batch_norm=use_batch_norm,
            autoencoder_type="contractive",
            encoder_blocks=encoder_cfg,
            decoder_blocks=decoder_cfg,
            **kwargs,
        )
# Public API of this module: the template config classes above.
__all__ = [
    "ClassicAutoencoderConfig",
    "VariationalAutoencoderConfig",
    "TransformerAutoencoderConfig",
    "RecurrentAutoencoderConfig",
    "ConvolutionalAutoencoderConfig",
    "ConvAttentionAutoencoderConfig",
    "LinearRecurrentAutoencoderConfig",
    "PreprocessedAutoencoderConfig",
    "BetaVariationalAutoencoderConfig",
    "DenoisingAutoencoderConfig",
    "SparseAutoencoderConfig",
    "ContractiveAutoencoderConfig",
]