|
|
|
|
|
""" |
|
Learnable preprocessing components for the block-based autoencoder. |
|
Extracted from modeling_autoencoder.py to a dedicated module. |
|
""" |
|
from __future__ import annotations |
|
|
|
from dataclasses import dataclass |
|
from typing import Optional, Tuple |
|
|
|
import torch |
|
from typing import Tuple |
|
|
|
try: |
|
from .blocks import BaseBlock |
|
except Exception: |
|
from blocks import BaseBlock |
|
|
|
import torch.nn as nn |
|
|
|
try: |
|
from .configuration_autoencoder import AutoencoderConfig |
|
except Exception: |
|
from configuration_autoencoder import AutoencoderConfig |
|
|
|
|
|
class NeuralScaler(nn.Module):
    """Learnable alternative to StandardScaler using neural networks.

    Two small MLPs predict per-feature corrections to the batch mean and
    std, and a learnable affine head (weight/bias) is applied after
    normalization.  Running statistics are tracked with an exponential
    moving average for use at eval time and in the inverse transform.
    """

    def __init__(self, config: AutoencoderConfig):
        super().__init__()
        self.config = config
        dim = config.input_dim
        hidden = config.preprocessing_hidden_dim

        def _mlp(*tail):
            # Shared 3-layer MLP shape; ``tail`` appends an output activation.
            return nn.Sequential(
                nn.Linear(dim, hidden), nn.ReLU(),
                nn.Linear(hidden, hidden), nn.ReLU(),
                nn.Linear(hidden, dim), *tail,
            )

        # Networks predicting adjustments to the batch statistics.
        self.mean_estimator = _mlp()
        # Softplus keeps the std adjustment non-negative.
        self.std_estimator = _mlp(nn.Softplus())

        # Affine head applied after normalization.
        self.weight = nn.Parameter(torch.ones(dim))
        self.bias = nn.Parameter(torch.zeros(dim))

        # EMA statistics used at eval time and by the inverse transform.
        self.register_buffer("running_mean", torch.zeros(dim))
        self.register_buffer("running_std", torch.ones(dim))
        self.register_buffer("num_batches_tracked", torch.tensor(0, dtype=torch.long))
        self.momentum = 0.1

    def forward(self, x: torch.Tensor, inverse: bool = False) -> Tuple[torch.Tensor, torch.Tensor]:
        """Normalize ``x``; returns ``(transformed, reg_loss)``.

        3-D inputs are flattened to (N, features) and restored on return.
        """
        if inverse:
            return self._inverse_transform(x)
        shape_in = x.shape
        flat = x.view(-1, x.size(-1)) if x.dim() == 3 else x
        if self.training:
            mu = flat.mean(dim=0, keepdim=True)
            sigma = flat.std(dim=0, keepdim=True)
            # Effective statistics = batch estimate + learned correction.
            eff_mean = mu + self.mean_estimator(mu)
            eff_std = sigma + self.std_estimator(sigma) + 1e-8
            self._track(mu, sigma)
        else:
            eff_mean = self.running_mean.unsqueeze(0)
            eff_std = self.running_std.unsqueeze(0) + 1e-8
        out = ((flat - eff_mean) / eff_std) * self.weight + self.bias
        if len(shape_in) == 3:
            out = out.view(shape_in)
        # Penalize spread in the affine head to keep it near-uniform.
        reg = 0.01 * (self.weight.var() + self.bias.var())
        return out, reg

    @torch.no_grad()
    def _track(self, mu: torch.Tensor, sigma: torch.Tensor) -> None:
        """EMA update of the running statistics (first batch seeds them)."""
        self.num_batches_tracked += 1
        m, s = mu.squeeze(), sigma.squeeze()
        if self.num_batches_tracked == 1:
            self.running_mean.copy_(m)
            self.running_std.copy_(s)
        else:
            self.running_mean.mul_(1 - self.momentum).add_(m, alpha=self.momentum)
            self.running_std.mul_(1 - self.momentum).add_(s, alpha=self.momentum)

    def _inverse_transform(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Map normalized data back to input scale using the running stats."""
        if not self.config.learn_inverse_preprocessing:
            return x, torch.tensor(0.0, device=x.device)
        shape_in = x.shape
        flat = x.view(-1, x.size(-1)) if x.dim() == 3 else x
        unscaled = (flat - self.bias) / (self.weight + 1e-8)
        restored = unscaled * (self.running_std.unsqueeze(0) + 1e-8) + self.running_mean.unsqueeze(0)
        if len(shape_in) == 3:
            restored = restored.view(shape_in)
        return restored, torch.tensor(0.0, device=x.device)
|
|
|
|
|
class LearnableMinMaxScaler(nn.Module):
    """Learnable MinMax scaler that adapts bounds during training.

    Small MLPs predict per-feature corrections to the observed batch
    min/range; an affine head (weight/bias) follows the scaling.  Running
    bounds are kept as an EMA for eval-time use and the inverse transform.
    """

    def __init__(self, config: AutoencoderConfig):
        super().__init__()
        self.config = config
        dim = config.input_dim
        hidden = config.preprocessing_hidden_dim

        def _mlp(*tail):
            # Shared 3-layer MLP shape; ``tail`` appends an output activation.
            return nn.Sequential(
                nn.Linear(dim, hidden), nn.ReLU(),
                nn.Linear(hidden, hidden), nn.ReLU(),
                nn.Linear(hidden, dim), *tail,
            )

        # Networks predicting corrections to the observed bounds.
        self.min_estimator = _mlp()
        # Softplus keeps the range adjustment non-negative.
        self.range_estimator = _mlp(nn.Softplus())

        # Affine head applied after scaling.
        self.weight = nn.Parameter(torch.ones(dim))
        self.bias = nn.Parameter(torch.zeros(dim))

        # EMA bounds used at eval time and by the inverse transform.
        self.register_buffer("running_min", torch.zeros(dim))
        self.register_buffer("running_range", torch.ones(dim))
        self.register_buffer("num_batches_tracked", torch.tensor(0, dtype=torch.long))
        self.momentum = 0.1

    def forward(self, x: torch.Tensor, inverse: bool = False) -> Tuple[torch.Tensor, torch.Tensor]:
        """Scale ``x`` by learned min/range; returns ``(transformed, reg_loss)``."""
        if inverse:
            return self._inverse_transform(x)
        shape_in = x.shape
        flat = x.view(-1, x.size(-1)) if x.dim() == 3 else x
        eps = 1e-8
        if self.training:
            lo = flat.min(dim=0, keepdim=True).values
            hi = flat.max(dim=0, keepdim=True).values
            span = (hi - lo).clamp_min(eps)
            # Effective bounds = batch estimate + learned correction.
            eff_min = lo + self.min_estimator(lo)
            eff_range = span + self.range_estimator(span) + eps
            self._track(lo, span)
        else:
            eff_min = self.running_min.unsqueeze(0)
            eff_range = self.running_range.unsqueeze(0)
        out = ((flat - eff_min) / eff_range) * self.weight + self.bias
        if len(shape_in) == 3:
            out = out.view(shape_in)
        reg = 0.01 * (self.weight.var() + self.bias.var())
        if self.training:
            # Discourage the effective range from collapsing toward zero.
            reg = reg + 0.001 * (1.0 / eff_range.clamp_min(1e-3)).mean()
        return out, reg

    @torch.no_grad()
    def _track(self, lo: torch.Tensor, span: torch.Tensor) -> None:
        """EMA update of the running bounds (first batch seeds them)."""
        self.num_batches_tracked += 1
        lo_f, span_f = lo.squeeze(), span.squeeze()
        if self.num_batches_tracked == 1:
            self.running_min.copy_(lo_f)
            self.running_range.copy_(span_f)
        else:
            self.running_min.mul_(1 - self.momentum).add_(lo_f, alpha=self.momentum)
            self.running_range.mul_(1 - self.momentum).add_(span_f, alpha=self.momentum)

    def _inverse_transform(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Undo the scaling using the running bounds."""
        if not self.config.learn_inverse_preprocessing:
            return x, torch.tensor(0.0, device=x.device)
        shape_in = x.shape
        flat = x.view(-1, x.size(-1)) if x.dim() == 3 else x
        unscaled = (flat - self.bias) / (self.weight + 1e-8)
        restored = unscaled * self.running_range.unsqueeze(0) + self.running_min.unsqueeze(0)
        if len(shape_in) == 3:
            restored = restored.view(shape_in)
        return restored, torch.tensor(0.0, device=x.device)
|
|
|
|
|
class LearnableRobustScaler(nn.Module):
    """Learnable Robust scaler using median and IQR with learnable adjustments.

    Small MLPs predict per-feature corrections to the batch median/IQR; an
    affine head (weight/bias) follows normalization.  Running statistics are
    kept as an EMA for eval-time use and the inverse transform.
    """

    def __init__(self, config: AutoencoderConfig):
        super().__init__()
        self.config = config
        dim = config.input_dim
        hidden = config.preprocessing_hidden_dim

        def _mlp(*tail):
            # Shared 3-layer MLP shape; ``tail`` appends an output activation.
            return nn.Sequential(
                nn.Linear(dim, hidden), nn.ReLU(),
                nn.Linear(hidden, hidden), nn.ReLU(),
                nn.Linear(hidden, dim), *tail,
            )

        # Networks predicting corrections to the robust statistics.
        self.median_estimator = _mlp()
        # Softplus keeps the IQR adjustment non-negative.
        self.iqr_estimator = _mlp(nn.Softplus())

        # Affine head applied after normalization.
        self.weight = nn.Parameter(torch.ones(dim))
        self.bias = nn.Parameter(torch.zeros(dim))

        # EMA statistics used at eval time and by the inverse transform.
        self.register_buffer("running_median", torch.zeros(dim))
        self.register_buffer("running_iqr", torch.ones(dim))
        self.register_buffer("num_batches_tracked", torch.tensor(0, dtype=torch.long))
        self.momentum = 0.1

    def forward(self, x: torch.Tensor, inverse: bool = False) -> Tuple[torch.Tensor, torch.Tensor]:
        """Robust-normalize ``x``; returns ``(transformed, reg_loss)``."""
        if inverse:
            return self._inverse_transform(x)
        shape_in = x.shape
        flat = x.view(-1, x.size(-1)) if x.dim() == 3 else x
        eps = 1e-8
        if self.training:
            # One quantile call yields Q1 / median / Q3, each kept 2-D.
            quartiles = torch.quantile(flat, torch.tensor([0.25, 0.5, 0.75], device=flat.device), dim=0)
            q_low, center, q_high = quartiles[0:1, :], quartiles[1:2, :], quartiles[2:3, :]
            spread = (q_high - q_low).clamp_min(eps)
            # Effective statistics = batch estimate + learned correction.
            eff_med = center + self.median_estimator(center)
            eff_iqr = spread + self.iqr_estimator(spread) + eps
            self._track(center, spread)
        else:
            eff_med = self.running_median.unsqueeze(0)
            eff_iqr = self.running_iqr.unsqueeze(0)
        out = ((flat - eff_med) / eff_iqr) * self.weight + self.bias
        if len(shape_in) == 3:
            out = out.view(shape_in)
        reg = 0.01 * (self.weight.var() + self.bias.var())
        if self.training:
            # Discourage the effective IQR from collapsing toward zero.
            reg = reg + 0.001 * (1.0 / eff_iqr.clamp_min(1e-3)).mean()
        return out, reg

    @torch.no_grad()
    def _track(self, center: torch.Tensor, spread: torch.Tensor) -> None:
        """EMA update of the running statistics (first batch seeds them)."""
        self.num_batches_tracked += 1
        med_f, iqr_f = center.squeeze(), spread.squeeze()
        if self.num_batches_tracked == 1:
            self.running_median.copy_(med_f)
            self.running_iqr.copy_(iqr_f)
        else:
            self.running_median.mul_(1 - self.momentum).add_(med_f, alpha=self.momentum)
            self.running_iqr.mul_(1 - self.momentum).add_(iqr_f, alpha=self.momentum)

    def _inverse_transform(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Undo the normalization using the running median/IQR."""
        if not self.config.learn_inverse_preprocessing:
            return x, torch.tensor(0.0, device=x.device)
        shape_in = x.shape
        flat = x.view(-1, x.size(-1)) if x.dim() == 3 else x
        unscaled = (flat - self.bias) / (self.weight + 1e-8)
        restored = unscaled * self.running_iqr.unsqueeze(0) + self.running_median.unsqueeze(0)
        if len(shape_in) == 3:
            restored = restored.view(shape_in)
        return restored, torch.tensor(0.0, device=x.device)
|
|
|
|
|
class LearnableYeoJohnsonPreprocessor(nn.Module):
    """Learnable Yeo-Johnson power transform with per-feature lambda and affine head.

    The power parameter ``lmbda`` is learned per feature (lambda == 1 is the
    identity transform).  Transformed data is standardized (batch statistics
    while training, EMA statistics at eval time) and passed through a
    learnable affine head (weight/bias).
    """

    def __init__(self, config: AutoencoderConfig):
        super().__init__()
        self.config = config
        input_dim = config.input_dim
        # Per-feature power parameter; initialized at the identity (lambda=1).
        self.lmbda = nn.Parameter(torch.ones(input_dim))
        # Affine head applied after standardization.
        self.weight = nn.Parameter(torch.ones(input_dim))
        self.bias = nn.Parameter(torch.zeros(input_dim))
        # EMA statistics of the transformed data (used at eval / inverse).
        self.register_buffer("running_mean", torch.zeros(input_dim))
        self.register_buffer("running_std", torch.ones(input_dim))
        self.register_buffer("num_batches_tracked", torch.tensor(0, dtype=torch.long))
        self.momentum = 0.1

    def _yeo_johnson(self, x: torch.Tensor, lmbda: torch.Tensor) -> torch.Tensor:
        """Yeo-Johnson transform, branched on the sign of ``x``.

        BUGFIX: ``torch.where`` back-propagates NaN/Inf gradients produced in
        the *unselected* branch, so a denominator near zero (lmbda ~ 0 or
        lmbda ~ 2) previously poisoned the gradient of the learnable
        ``lmbda``.  Each denominator/exponent is replaced by a safe stand-in
        wherever its branch is not selected; selected-branch values are
        numerically unchanged.
        """
        eps = 1e-6
        lmbda = lmbda.unsqueeze(0)
        pos = x >= 0
        use_pow = torch.abs(lmbda) > eps
        safe_l = torch.where(use_pow, lmbda, torch.ones_like(lmbda))
        if_part = torch.where(
            use_pow,
            ((x + 1.0).clamp_min(eps) ** safe_l - 1.0) / safe_l,
            torch.log((x + 1.0).clamp_min(eps)),
        )
        two_minus_lambda = 2.0 - lmbda
        use_pow_neg = torch.abs(two_minus_lambda) > eps
        safe_2ml = torch.where(use_pow_neg, two_minus_lambda, torch.ones_like(two_minus_lambda))
        else_part = torch.where(
            use_pow_neg,
            -(((1.0 - x).clamp_min(eps)) ** safe_2ml - 1.0) / safe_2ml,
            -torch.log((1.0 - x).clamp_min(eps)),
        )
        return torch.where(pos, if_part, else_part)

    def _yeo_johnson_inverse(self, y: torch.Tensor, lmbda: torch.Tensor) -> torch.Tensor:
        """Inverse Yeo-Johnson transform (same safe-denominator guards)."""
        eps = 1e-6
        lmbda = lmbda.unsqueeze(0)
        pos = y >= 0
        use_pow = torch.abs(lmbda) > eps
        safe_l = torch.where(use_pow, lmbda, torch.ones_like(lmbda))
        x_pos = torch.where(
            use_pow,
            (y * safe_l + 1.0).clamp_min(eps) ** (1.0 / safe_l) - 1.0,
            torch.exp(y) - 1.0,
        )
        two_minus_lambda = 2.0 - lmbda
        use_pow_neg = torch.abs(two_minus_lambda) > eps
        safe_2ml = torch.where(use_pow_neg, two_minus_lambda, torch.ones_like(two_minus_lambda))
        x_neg = torch.where(
            use_pow_neg,
            1.0 - (1.0 - y * safe_2ml).clamp_min(eps) ** (1.0 / safe_2ml),
            1.0 - torch.exp(-y),
        )
        return torch.where(pos, x_pos, x_neg)

    def forward(self, x: torch.Tensor, inverse: bool = False) -> Tuple[torch.Tensor, torch.Tensor]:
        """Transform, standardize and affinely rescale ``x``.

        Returns:
            ``(out, reg)`` where ``reg`` pulls lambda toward 1 (identity) and
            the affine head toward uniform values.
        """
        if inverse:
            return self._inverse_transform(x)
        orig_shape = x.shape
        if x.dim() == 3:
            x = x.view(-1, x.size(-1))
        y = self._yeo_johnson(x, self.lmbda)
        if self.training:
            batch_mean = y.mean(dim=0, keepdim=True)
            batch_std = y.std(dim=0, keepdim=True).clamp_min(1e-6)
            with torch.no_grad():
                # EMA update of the running statistics (first batch seeds them).
                self.num_batches_tracked += 1
                if self.num_batches_tracked == 1:
                    self.running_mean.copy_(batch_mean.squeeze())
                    self.running_std.copy_(batch_std.squeeze())
                else:
                    self.running_mean.mul_(1 - self.momentum).add_(batch_mean.squeeze(), alpha=self.momentum)
                    self.running_std.mul_(1 - self.momentum).add_(batch_std.squeeze(), alpha=self.momentum)
            mean = batch_mean
            std = batch_std
        else:
            mean = self.running_mean.unsqueeze(0)
            std = self.running_std.unsqueeze(0)
        y_norm = (y - mean) / std
        out = y_norm * self.weight + self.bias
        if len(orig_shape) == 3:
            out = out.view(orig_shape)
        reg = 0.001 * (self.lmbda - 1.0).pow(2).mean() + 0.01 * (self.weight.var() + self.bias.var())
        return out, reg

    def _inverse_transform(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Undo affine head, de-standardize, then invert the power transform."""
        if not self.config.learn_inverse_preprocessing:
            return x, torch.tensor(0.0, device=x.device)
        orig_shape = x.shape
        if x.dim() == 3:
            x = x.view(-1, x.size(-1))
        y = (x - self.bias) / (self.weight + 1e-8)
        y = y * self.running_std.unsqueeze(0) + self.running_mean.unsqueeze(0)
        out = self._yeo_johnson_inverse(y, self.lmbda)
        if len(orig_shape) == 3:
            out = out.view(orig_shape)
        return out, torch.tensor(0.0, device=x.device)
|
|
|
|
|
|
|
class PreprocessingBlock(BaseBlock):
    """Adapter exposing a LearnablePreprocessor through the BaseBlock API.

    The forward pass returns only the transformed tensor; the auxiliary
    regularization term produced by the preprocessor is stashed on
    ``self.reg_loss`` so callers can collect it after the call.  Whether the
    wrapped preprocessor runs in inverse mode is fixed at construction time,
    which keeps ``forward`` free of extra keyword arguments that could leak
    to other blocks.
    """

    def __init__(self, config: AutoencoderConfig, inverse: bool = False, proc: Optional[LearnablePreprocessor] = None):
        super().__init__()
        # Reuse a shared preprocessor when one is supplied (e.g. tied
        # forward/inverse blocks); otherwise build a fresh one.
        if proc is None:
            proc = LearnablePreprocessor(config)
        self.proc = proc
        self._output_dim = config.input_dim
        self.inverse = inverse
        self.reg_loss: torch.Tensor = torch.tensor(0.0)

    @property
    def output_dim(self) -> int:
        """Dimensionality of the transformed output (same as the input)."""
        return self._output_dim

    def forward(self, x: torch.Tensor, **kwargs) -> torch.Tensor:
        transformed, penalty = self.proc(x, inverse=self.inverse)
        self.reg_loss = penalty
        return transformed
|
|
|
class CouplingLayer(nn.Module):
    """Affine coupling layer for normalizing flows.

    Features selected by the mask pass through unchanged and condition a
    per-feature scale/shift applied to the remaining features, so the
    transform is exactly invertible with a cheap log-determinant.
    """

    def __init__(self, input_dim: int, hidden_dim: int = 64, mask_type: str = "alternating"):
        super().__init__()
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        # Binary mask: 1 = conditioning feature (left untouched).
        if mask_type == "alternating":
            self.register_buffer("mask", torch.arange(input_dim) % 2)
        elif mask_type == "half":
            half_mask = torch.zeros(input_dim)
            half_mask[: input_dim // 2] = 1
            self.register_buffer("mask", half_mask)
        else:
            raise ValueError(f"Unknown mask type: {mask_type}")
        n_cond = int(self.mask.sum().item())
        n_out = input_dim - n_cond

        def _net(*tail):
            # MLP from conditioning features to per-feature outputs; ``tail``
            # appends an output activation (Tanh bounds the log-scale).
            return nn.Sequential(
                nn.Linear(n_cond, hidden_dim), nn.ReLU(),
                nn.Linear(hidden_dim, hidden_dim), nn.ReLU(),
                nn.Linear(hidden_dim, n_out), *tail,
            )

        self.scale_net = _net(nn.Tanh())
        self.translate_net = _net()

    def forward(self, x: torch.Tensor, inverse: bool = False):
        """Apply (or undo) the coupling transform; returns ``(y, log_det)``."""
        keep = self.mask.bool()
        conditioner = x[:, keep]
        target = x[:, ~keep]
        log_scale = self.scale_net(conditioner)
        shift = self.translate_net(conditioner)
        if inverse:
            moved = (target - shift) * torch.exp(-log_scale)
            log_det = -log_scale.sum(dim=1)
        else:
            moved = target * torch.exp(log_scale) + shift
            log_det = log_scale.sum(dim=1)
        out = torch.zeros_like(x)
        out[:, keep] = conditioner
        out[:, ~keep] = moved
        return out, log_det
|
|
|
|
|
class NormalizingFlowPreprocessor(nn.Module):
    """Normalizing flow for learnable data preprocessing.

    A stack of coupling layers, optionally interleaved with BatchNorm1d.
    ``forward(x)`` returns ``(transformed, reg_loss)`` where the
    regularization penalizes the magnitude of the accumulated
    log-determinant.
    """

    def __init__(self, config: AutoencoderConfig):
        super().__init__()
        self.config = config
        input_dim = config.input_dim
        hidden_dim = config.preprocessing_hidden_dim
        num_layers = config.flow_coupling_layers
        self.layers = nn.ModuleList()
        for i in range(num_layers):
            # Alternate mask patterns so successive layers transform
            # complementary subsets of the features.
            mask_type = "alternating" if i % 2 == 0 else "half"
            self.layers.append(CouplingLayer(input_dim, hidden_dim, mask_type))
        if config.use_batch_norm:
            # One BN between each pair of consecutive coupling layers.
            self.batch_norms = nn.ModuleList([nn.BatchNorm1d(input_dim) for _ in range(num_layers - 1)])
        else:
            self.batch_norms = None

    @staticmethod
    def _invert_batch_norm(bn: nn.BatchNorm1d, x: torch.Tensor) -> torch.Tensor:
        """Analytically invert a BatchNorm1d using its running statistics.

        Exact in eval mode; during training BatchNorm normalizes with batch
        statistics, so this inversion is only approximate then.
        """
        std = torch.sqrt(bn.running_var + bn.eps)
        if bn.affine:
            x = (x - bn.bias) / (bn.weight + 1e-8)
        return x * std + bn.running_mean

    def forward(self, x: torch.Tensor, inverse: bool = False):
        """Run the flow forward (or inverted); returns ``(y, reg_loss)``."""
        original_shape = x.shape
        if x.dim() == 3:
            x = x.view(-1, x.size(-1))
        log_det_total = torch.zeros(x.size(0), device=x.device)
        if not inverse:
            for i, layer in enumerate(self.layers):
                x, log_det = layer(x, inverse=False)
                log_det_total += log_det
                if self.batch_norms and i < len(self.layers) - 1:
                    x = self.batch_norms[i](x)
        else:
            for i, layer in enumerate(reversed(self.layers)):
                if self.batch_norms and i > 0:
                    # BUGFIX: the inverse pass must *undo* the BatchNorm that
                    # the forward pass applied after this layer — previously
                    # BN was applied forward again, so inverse(forward(x))
                    # did not recover x and, in training mode, the running
                    # statistics were corrupted by decoder-side data.
                    bn_idx = len(self.layers) - 1 - i
                    x = self._invert_batch_norm(self.batch_norms[bn_idx], x)
                x, log_det = layer(x, inverse=True)
                log_det_total += log_det
        if len(original_shape) == 3:
            x = x.view(original_shape)
        reg_loss = 0.01 * log_det_total.abs().mean()
        return x, reg_loss
|
|
|
|
|
class LearnablePreprocessor(nn.Module):
    """Unified interface for learnable preprocessing methods.

    The concrete preprocessor is chosen from the config's flags; when the
    config requests no preprocessing an ``nn.Identity`` placeholder is kept
    and ``forward`` becomes a pass-through.
    """

    def __init__(self, config: AutoencoderConfig):
        super().__init__()
        self.config = config
        self.preprocessor = self._build(config)

    @staticmethod
    def _build(config: AutoencoderConfig) -> nn.Module:
        """Instantiate the preprocessor selected by ``config``.

        The checks are ordered: the first matching flag wins, mirroring the
        original dispatch priority.
        """
        if not config.has_preprocessing:
            return nn.Identity()
        if config.is_neural_scaler:
            return NeuralScaler(config)
        if config.is_normalizing_flow:
            return NormalizingFlowPreprocessor(config)
        if getattr(config, "is_minmax_scaler", False):
            return LearnableMinMaxScaler(config)
        if getattr(config, "is_robust_scaler", False):
            return LearnableRobustScaler(config)
        if getattr(config, "is_yeo_johnson", False):
            return LearnableYeoJohnsonPreprocessor(config)
        raise ValueError(f"Unknown preprocessing type: {config.preprocessing_type}")

    def forward(self, x: torch.Tensor, inverse: bool = False):
        """Apply (or invert) the preprocessing; returns ``(tensor, reg_loss)``."""
        if isinstance(self.preprocessor, nn.Identity):
            return x, torch.tensor(0.0, device=x.device)
        return self.preprocessor(x, inverse=inverse)
|
|
|
|
|
|
|
# Explicit public API of this module (controls ``from ... import *`` and
# documents which names downstream code may rely on).
__all__ = [

    "NeuralScaler",

    "LearnableMinMaxScaler",

    "LearnableRobustScaler",

    "LearnableYeoJohnsonPreprocessor",

    "CouplingLayer",

    "NormalizingFlowPreprocessor",

    "LearnablePreprocessor",

    "PreprocessingBlock",

]
|
|