# autoencoder/preprocessing.py
# flake8: noqa
"""
Learnable preprocessing components for the block-based autoencoder.
Extracted from modeling_autoencoder.py to a dedicated module.
"""
from __future__ import annotations

from typing import Optional, Tuple

import torch
import torch.nn as nn

try:
    from .blocks import BaseBlock  # when loaded via HF dynamic module
except Exception:
    from blocks import BaseBlock  # local usage

try:
    from .configuration_autoencoder import AutoencoderConfig  # when loaded via HF dynamic module
except Exception:
    from configuration_autoencoder import AutoencoderConfig  # local usage


class NeuralScaler(nn.Module):
"""Learnable alternative to StandardScaler using neural networks."""
def __init__(self, config: AutoencoderConfig):
super().__init__()
self.config = config
input_dim = config.input_dim
hidden_dim = config.preprocessing_hidden_dim
self.mean_estimator = nn.Sequential(
nn.Linear(input_dim, hidden_dim), nn.ReLU(), nn.Linear(hidden_dim, hidden_dim), nn.ReLU(), nn.Linear(hidden_dim, input_dim)
)
self.std_estimator = nn.Sequential(
nn.Linear(input_dim, hidden_dim), nn.ReLU(), nn.Linear(hidden_dim, hidden_dim), nn.ReLU(), nn.Linear(hidden_dim, input_dim), nn.Softplus()
)
self.weight = nn.Parameter(torch.ones(input_dim))
self.bias = nn.Parameter(torch.zeros(input_dim))
self.register_buffer("running_mean", torch.zeros(input_dim))
self.register_buffer("running_std", torch.ones(input_dim))
self.register_buffer("num_batches_tracked", torch.tensor(0, dtype=torch.long))
self.momentum = 0.1
def forward(self, x: torch.Tensor, inverse: bool = False) -> Tuple[torch.Tensor, torch.Tensor]:
if inverse:
return self._inverse_transform(x)
original_shape = x.shape
if x.dim() == 3:
x = x.view(-1, x.size(-1))
if self.training:
batch_mean = x.mean(dim=0, keepdim=True)
batch_std = x.std(dim=0, keepdim=True)
learned_mean_adj = self.mean_estimator(batch_mean)
learned_std_adj = self.std_estimator(batch_std)
effective_mean = batch_mean + learned_mean_adj
effective_std = batch_std + learned_std_adj + 1e-8
with torch.no_grad():
self.num_batches_tracked += 1
if self.num_batches_tracked == 1:
self.running_mean.copy_(batch_mean.squeeze())
self.running_std.copy_(batch_std.squeeze())
else:
self.running_mean.mul_(1 - self.momentum).add_(batch_mean.squeeze(), alpha=self.momentum)
self.running_std.mul_(1 - self.momentum).add_(batch_std.squeeze(), alpha=self.momentum)
else:
effective_mean = self.running_mean.unsqueeze(0)
effective_std = self.running_std.unsqueeze(0) + 1e-8
normalized = (x - effective_mean) / effective_std
transformed = normalized * self.weight + self.bias
if len(original_shape) == 3:
transformed = transformed.view(original_shape)
reg_loss = 0.01 * (self.weight.var() + self.bias.var())
return transformed, reg_loss
def _inverse_transform(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
if not self.config.learn_inverse_preprocessing:
return x, torch.tensor(0.0, device=x.device)
original_shape = x.shape
if x.dim() == 3:
x = x.view(-1, x.size(-1))
x = (x - self.bias) / (self.weight + 1e-8)
effective_mean = self.running_mean.unsqueeze(0)
effective_std = self.running_std.unsqueeze(0) + 1e-8
x = x * effective_std + effective_mean
if len(original_shape) == 3:
x = x.view(original_shape)
return x, torch.tensor(0.0, device=x.device)
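
# The transforms above implement a learnable affine standardization:
#     forward: y = ((x - mu_eff) / sigma_eff) * weight + bias
#     inverse: x = ((y - bias) / weight) * sigma_run + mu_run
# In training mode mu_eff/sigma_eff blend batch statistics with learned
# adjustments, while the inverse always uses the running statistics, so an
# eval-mode forward followed by the inverse is an (almost) exact round trip,
# but a train-mode forward is not.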
class LearnableMinMaxScaler(nn.Module):
"""Learnable MinMax scaler that adapts bounds during training."""
def __init__(self, config: AutoencoderConfig):
super().__init__()
self.config = config
input_dim = config.input_dim
hidden_dim = config.preprocessing_hidden_dim
self.min_estimator = nn.Sequential(
nn.Linear(input_dim, hidden_dim), nn.ReLU(), nn.Linear(hidden_dim, hidden_dim), nn.ReLU(), nn.Linear(hidden_dim, input_dim)
)
self.range_estimator = nn.Sequential(
nn.Linear(input_dim, hidden_dim), nn.ReLU(), nn.Linear(hidden_dim, hidden_dim), nn.ReLU(), nn.Linear(hidden_dim, input_dim), nn.Softplus()
)
self.weight = nn.Parameter(torch.ones(input_dim))
self.bias = nn.Parameter(torch.zeros(input_dim))
self.register_buffer("running_min", torch.zeros(input_dim))
self.register_buffer("running_range", torch.ones(input_dim))
self.register_buffer("num_batches_tracked", torch.tensor(0, dtype=torch.long))
self.momentum = 0.1
def forward(self, x: torch.Tensor, inverse: bool = False) -> Tuple[torch.Tensor, torch.Tensor]:
if inverse:
return self._inverse_transform(x)
original_shape = x.shape
if x.dim() == 3:
x = x.view(-1, x.size(-1))
eps = 1e-8
if self.training:
batch_min = x.min(dim=0, keepdim=True).values
batch_max = x.max(dim=0, keepdim=True).values
batch_range = (batch_max - batch_min).clamp_min(eps)
learned_min_adj = self.min_estimator(batch_min)
learned_range_adj = self.range_estimator(batch_range)
effective_min = batch_min + learned_min_adj
effective_range = batch_range + learned_range_adj + eps
with torch.no_grad():
self.num_batches_tracked += 1
if self.num_batches_tracked == 1:
self.running_min.copy_(batch_min.squeeze())
self.running_range.copy_(batch_range.squeeze())
else:
self.running_min.mul_(1 - self.momentum).add_(batch_min.squeeze(), alpha=self.momentum)
self.running_range.mul_(1 - self.momentum).add_(batch_range.squeeze(), alpha=self.momentum)
else:
effective_min = self.running_min.unsqueeze(0)
effective_range = self.running_range.unsqueeze(0)
scaled = (x - effective_min) / effective_range
transformed = scaled * self.weight + self.bias
if len(original_shape) == 3:
transformed = transformed.view(original_shape)
reg_loss = 0.01 * (self.weight.var() + self.bias.var())
if self.training:
reg_loss = reg_loss + 0.001 * (1.0 / effective_range.clamp_min(1e-3)).mean()
return transformed, reg_loss
def _inverse_transform(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
if not self.config.learn_inverse_preprocessing:
return x, torch.tensor(0.0, device=x.device)
original_shape = x.shape
if x.dim() == 3:
x = x.view(-1, x.size(-1))
x = (x - self.bias) / (self.weight + 1e-8)
x = x * self.running_range.unsqueeze(0) + self.running_min.unsqueeze(0)
if len(original_shape) == 3:
x = x.view(original_shape)
return x, torch.tensor(0.0, device=x.device)
class LearnableRobustScaler(nn.Module):
"""Learnable Robust scaler using median and IQR with learnable adjustments."""
def __init__(self, config: AutoencoderConfig):
super().__init__()
self.config = config
input_dim = config.input_dim
hidden_dim = config.preprocessing_hidden_dim
self.median_estimator = nn.Sequential(
nn.Linear(input_dim, hidden_dim), nn.ReLU(), nn.Linear(hidden_dim, hidden_dim), nn.ReLU(), nn.Linear(hidden_dim, input_dim)
)
self.iqr_estimator = nn.Sequential(
nn.Linear(input_dim, hidden_dim), nn.ReLU(), nn.Linear(hidden_dim, hidden_dim), nn.ReLU(), nn.Linear(hidden_dim, input_dim), nn.Softplus()
)
self.weight = nn.Parameter(torch.ones(input_dim))
self.bias = nn.Parameter(torch.zeros(input_dim))
self.register_buffer("running_median", torch.zeros(input_dim))
self.register_buffer("running_iqr", torch.ones(input_dim))
self.register_buffer("num_batches_tracked", torch.tensor(0, dtype=torch.long))
self.momentum = 0.1
def forward(self, x: torch.Tensor, inverse: bool = False) -> Tuple[torch.Tensor, torch.Tensor]:
if inverse:
return self._inverse_transform(x)
original_shape = x.shape
if x.dim() == 3:
x = x.view(-1, x.size(-1))
eps = 1e-8
if self.training:
qs = torch.quantile(x, torch.tensor([0.25, 0.5, 0.75], device=x.device), dim=0)
q25, med, q75 = qs[0:1, :], qs[1:2, :], qs[2:3, :]
iqr = (q75 - q25).clamp_min(eps)
learned_med_adj = self.median_estimator(med)
learned_iqr_adj = self.iqr_estimator(iqr)
effective_median = med + learned_med_adj
effective_iqr = iqr + learned_iqr_adj + eps
with torch.no_grad():
self.num_batches_tracked += 1
if self.num_batches_tracked == 1:
self.running_median.copy_(med.squeeze())
self.running_iqr.copy_(iqr.squeeze())
else:
self.running_median.mul_(1 - self.momentum).add_(med.squeeze(), alpha=self.momentum)
self.running_iqr.mul_(1 - self.momentum).add_(iqr.squeeze(), alpha=self.momentum)
else:
effective_median = self.running_median.unsqueeze(0)
effective_iqr = self.running_iqr.unsqueeze(0)
normalized = (x - effective_median) / effective_iqr
transformed = normalized * self.weight + self.bias
if len(original_shape) == 3:
transformed = transformed.view(original_shape)
reg_loss = 0.01 * (self.weight.var() + self.bias.var())
if self.training:
reg_loss = reg_loss + 0.001 * (1.0 / effective_iqr.clamp_min(1e-3)).mean()
return transformed, reg_loss
def _inverse_transform(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
if not self.config.learn_inverse_preprocessing:
return x, torch.tensor(0.0, device=x.device)
original_shape = x.shape
if x.dim() == 3:
x = x.view(-1, x.size(-1))
x = (x - self.bias) / (self.weight + 1e-8)
x = x * self.running_iqr.unsqueeze(0) + self.running_median.unsqueeze(0)
if len(original_shape) == 3:
x = x.view(original_shape)
return x, torch.tensor(0.0, device=x.device)
class LearnableYeoJohnsonPreprocessor(nn.Module):
"""Learnable Yeo-Johnson power transform with per-feature lambda and affine head."""
def __init__(self, config: AutoencoderConfig):
super().__init__()
self.config = config
input_dim = config.input_dim
self.lmbda = nn.Parameter(torch.ones(input_dim))
self.weight = nn.Parameter(torch.ones(input_dim))
self.bias = nn.Parameter(torch.zeros(input_dim))
self.register_buffer("running_mean", torch.zeros(input_dim))
self.register_buffer("running_std", torch.ones(input_dim))
self.register_buffer("num_batches_tracked", torch.tensor(0, dtype=torch.long))
self.momentum = 0.1
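
    # Reference for the two helpers below: the Yeo-Johnson transform with a
    # per-feature lambda is the standard piecewise map
    #     x >= 0, lambda != 0:  ((x + 1)**lambda - 1) / lambda
    #     x >= 0, lambda == 0:  log(x + 1)
    #     x <  0, lambda != 2:  -((1 - x)**(2 - lambda) - 1) / (2 - lambda)
    #     x <  0, lambda == 2:  -log(1 - x)
    # torch.where evaluates both branches, so the clamp_min calls keep the
    # branch that is not selected finite.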
def _yeo_johnson(self, x: torch.Tensor, lmbda: torch.Tensor) -> torch.Tensor:
eps = 1e-6
lmbda = lmbda.unsqueeze(0)
pos = x >= 0
        if_part = torch.where(
            torch.abs(lmbda) > eps,
            ((x + 1.0).clamp_min(eps) ** lmbda - 1.0) / lmbda,
            torch.log((x + 1.0).clamp_min(eps)),
        )
        two_minus_lambda = 2.0 - lmbda
        else_part = torch.where(
            torch.abs(two_minus_lambda) > eps,
            -(((1.0 - x).clamp_min(eps)) ** two_minus_lambda - 1.0) / two_minus_lambda,
            -torch.log((1.0 - x).clamp_min(eps)),
        )
return torch.where(pos, if_part, else_part)
def _yeo_johnson_inverse(self, y: torch.Tensor, lmbda: torch.Tensor) -> torch.Tensor:
eps = 1e-6
lmbda = lmbda.unsqueeze(0)
pos = y >= 0
        x_pos = torch.where(
            torch.abs(lmbda) > eps,
            (y * lmbda + 1.0).clamp_min(eps) ** (1.0 / lmbda) - 1.0,
            torch.exp(y) - 1.0,
        )
        two_minus_lambda = 2.0 - lmbda
        x_neg = torch.where(
            torch.abs(two_minus_lambda) > eps,
            1.0 - (1.0 - y * two_minus_lambda).clamp_min(eps) ** (1.0 / two_minus_lambda),
            1.0 - torch.exp(-y),
        )
return torch.where(pos, x_pos, x_neg)
def forward(self, x: torch.Tensor, inverse: bool = False) -> Tuple[torch.Tensor, torch.Tensor]:
if inverse:
return self._inverse_transform(x)
orig_shape = x.shape
if x.dim() == 3:
x = x.view(-1, x.size(-1))
y = self._yeo_johnson(x, self.lmbda)
if self.training:
batch_mean = y.mean(dim=0, keepdim=True)
batch_std = y.std(dim=0, keepdim=True).clamp_min(1e-6)
with torch.no_grad():
self.num_batches_tracked += 1
if self.num_batches_tracked == 1:
self.running_mean.copy_(batch_mean.squeeze())
self.running_std.copy_(batch_std.squeeze())
else:
self.running_mean.mul_(1 - self.momentum).add_(batch_mean.squeeze(), alpha=self.momentum)
self.running_std.mul_(1 - self.momentum).add_(batch_std.squeeze(), alpha=self.momentum)
mean = batch_mean
std = batch_std
else:
mean = self.running_mean.unsqueeze(0)
std = self.running_std.unsqueeze(0)
y_norm = (y - mean) / std
out = y_norm * self.weight + self.bias
if len(orig_shape) == 3:
out = out.view(orig_shape)
reg = 0.001 * (self.lmbda - 1.0).pow(2).mean() + 0.01 * (self.weight.var() + self.bias.var())
return out, reg
def _inverse_transform(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
if not self.config.learn_inverse_preprocessing:
return x, torch.tensor(0.0, device=x.device)
orig_shape = x.shape
if x.dim() == 3:
x = x.view(-1, x.size(-1))
y = (x - self.bias) / (self.weight + 1e-8)
y = y * self.running_std.unsqueeze(0) + self.running_mean.unsqueeze(0)
out = self._yeo_johnson_inverse(y, self.lmbda)
if len(orig_shape) == 3:
out = out.view(orig_shape)
return out, torch.tensor(0.0, device=x.device)
class PreprocessingBlock(BaseBlock):
"""Wraps a LearnablePreprocessor into a BaseBlock-compatible interface.
Forward returns the transformed tensor and stores the regularization loss in .reg_loss.
The inverse flag is configured at initialization to avoid leaking kwargs to other blocks.
"""
def __init__(self, config: AutoencoderConfig, inverse: bool = False, proc: Optional[LearnablePreprocessor] = None):
super().__init__()
self.proc = proc if proc is not None else LearnablePreprocessor(config)
self._output_dim = config.input_dim
self.inverse = inverse
self.reg_loss: torch.Tensor = torch.tensor(0.0)
@property
def output_dim(self) -> int:
return self._output_dim
def forward(self, x: torch.Tensor, **kwargs) -> torch.Tensor:
y, reg = self.proc(x, inverse=self.inverse)
self.reg_loss = reg
return y
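
# Usage sketch for PreprocessingBlock (illustrative only; the surrounding
# pipeline and the `cfg` object are assumptions, not fixed API):
#
#     pre = PreprocessingBlock(cfg)                          # forward transform
#     post = PreprocessingBlock(cfg, inverse=True, proc=pre.proc)
#     z = pre(x)
#     ...
#     x_hat = post(z_decoded)
#     total_loss = recon_loss + pre.reg_loss + post.reg_loss
#
# Passing `proc=pre.proc` ties the forward and inverse blocks to the same
# learned preprocessing parameters.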
class CouplingLayer(nn.Module):
"""Coupling layer for normalizing flows."""
def __init__(self, input_dim: int, hidden_dim: int = 64, mask_type: str = "alternating"):
super().__init__()
self.input_dim = input_dim
self.hidden_dim = hidden_dim
if mask_type == "alternating":
self.register_buffer("mask", torch.arange(input_dim) % 2)
elif mask_type == "half":
mask = torch.zeros(input_dim)
mask[: input_dim // 2] = 1
self.register_buffer("mask", mask)
else:
raise ValueError(f"Unknown mask type: {mask_type}")
masked_dim = int(self.mask.sum().item())
unmasked_dim = input_dim - masked_dim
self.scale_net = nn.Sequential(
nn.Linear(masked_dim, hidden_dim), nn.ReLU(), nn.Linear(hidden_dim, hidden_dim), nn.ReLU(), nn.Linear(hidden_dim, unmasked_dim), nn.Tanh()
)
self.translate_net = nn.Sequential(
nn.Linear(masked_dim, hidden_dim), nn.ReLU(), nn.Linear(hidden_dim, hidden_dim), nn.ReLU(), nn.Linear(hidden_dim, unmasked_dim)
)
def forward(self, x: torch.Tensor, inverse: bool = False):
mask = self.mask.bool()
x_masked = x[:, mask]
x_unmasked = x[:, ~mask]
s = self.scale_net(x_masked)
t = self.translate_net(x_masked)
if not inverse:
y_unmasked = x_unmasked * torch.exp(s) + t
log_det = s.sum(dim=1)
else:
y_unmasked = (x_unmasked - t) * torch.exp(-s)
log_det = -s.sum(dim=1)
y = torch.zeros_like(x)
y[:, mask] = x_masked
y[:, ~mask] = y_unmasked
return y, log_det
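
# A quick invertibility sanity check for CouplingLayer (a sketch, not part of
# the module's test suite; shapes and tolerances are arbitrary):
#
#     layer = CouplingLayer(input_dim=6, hidden_dim=32)
#     x = torch.randn(4, 6)
#     y, log_det = layer(x, inverse=False)
#     x_rec, inv_log_det = layer(y, inverse=True)
#     assert torch.allclose(x, x_rec, atol=1e-5)
#     assert torch.allclose(log_det, -inv_log_det, atol=1e-5)
#
# The masked half passes through unchanged, so the Jacobian is triangular and
# log|det J| is simply the sum of the predicted log-scales `s`.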
class NormalizingFlowPreprocessor(nn.Module):
"""Normalizing flow for learnable data preprocessing."""
def __init__(self, config: AutoencoderConfig):
super().__init__()
self.config = config
input_dim = config.input_dim
hidden_dim = config.preprocessing_hidden_dim
num_layers = config.flow_coupling_layers
self.layers = nn.ModuleList()
for i in range(num_layers):
mask_type = "alternating" if i % 2 == 0 else "half"
self.layers.append(CouplingLayer(input_dim, hidden_dim, mask_type))
if config.use_batch_norm:
self.batch_norms = nn.ModuleList([nn.BatchNorm1d(input_dim) for _ in range(num_layers - 1)])
else:
self.batch_norms = None
def forward(self, x: torch.Tensor, inverse: bool = False):
original_shape = x.shape
if x.dim() == 3:
x = x.view(-1, x.size(-1))
log_det_total = torch.zeros(x.size(0), device=x.device)
if not inverse:
for i, layer in enumerate(self.layers):
x, log_det = layer(x, inverse=False)
log_det_total += log_det
if self.batch_norms and i < len(self.layers) - 1:
x = self.batch_norms[i](x)
else:
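            # NOTE: BatchNorm1d is applied in its forward form on the inverse
            # path as well; this mirrors the forward ordering but only
            # approximates the true inverse of the normalization.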
for i, layer in enumerate(reversed(self.layers)):
if self.batch_norms and i > 0:
bn_idx = len(self.layers) - 1 - i
x = self.batch_norms[bn_idx](x)
x, log_det = layer(x, inverse=True)
log_det_total += log_det
if len(original_shape) == 3:
x = x.view(original_shape)
reg_loss = 0.01 * log_det_total.abs().mean()
return x, reg_loss
class LearnablePreprocessor(nn.Module):
"""Unified interface for learnable preprocessing methods."""
def __init__(self, config: AutoencoderConfig):
super().__init__()
self.config = config
if not config.has_preprocessing:
self.preprocessor = nn.Identity()
elif config.is_neural_scaler:
self.preprocessor = NeuralScaler(config)
elif config.is_normalizing_flow:
self.preprocessor = NormalizingFlowPreprocessor(config)
elif getattr(config, "is_minmax_scaler", False):
self.preprocessor = LearnableMinMaxScaler(config)
elif getattr(config, "is_robust_scaler", False):
self.preprocessor = LearnableRobustScaler(config)
elif getattr(config, "is_yeo_johnson", False):
self.preprocessor = LearnableYeoJohnsonPreprocessor(config)
else:
raise ValueError(f"Unknown preprocessing type: {config.preprocessing_type}")
def forward(self, x: torch.Tensor, inverse: bool = False):
if isinstance(self.preprocessor, nn.Identity):
return x, torch.tensor(0.0, device=x.device)
return self.preprocessor(x, inverse=inverse)
__all__ = [
"NeuralScaler",
"LearnableMinMaxScaler",
"LearnableRobustScaler",
"LearnableYeoJohnsonPreprocessor",
"CouplingLayer",
"NormalizingFlowPreprocessor",
"LearnablePreprocessor",
"PreprocessingBlock",
]
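

if __name__ == "__main__":
    # Minimal smoke test (a sketch: it sidesteps AutoencoderConfig, whose
    # constructor arguments are defined elsewhere, by passing a namespace
    # with just the attributes NeuralScaler actually reads).
    from types import SimpleNamespace

    cfg = SimpleNamespace(input_dim=8, preprocessing_hidden_dim=16, learn_inverse_preprocessing=True)
    scaler = NeuralScaler(cfg)
    x = torch.randn(32, 8)

    scaler.train()
    _y, reg = scaler(x)  # populates the running statistics
    print("train reg loss:", float(reg))

    scaler.eval()
    y, _ = scaler(x)
    x_rec, _ = scaler(y, inverse=True)
    print("eval round-trip max error:", (x - x_rec).abs().max().item())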