# flake8: noqa
"""
Learnable preprocessing components for the block-based autoencoder.

Extracted from modeling_autoencoder.py to a dedicated module.
"""

from __future__ import annotations

from dataclasses import dataclass
from typing import Optional, Tuple

import torch
import torch.nn as nn

try:
    from .blocks import BaseBlock
except ImportError:
    from blocks import BaseBlock

try:
    from .configuration_autoencoder import AutoencoderConfig  # when loaded via HF dynamic module
except ImportError:
    from configuration_autoencoder import AutoencoderConfig  # local usage


# ---------------------------------------------------------------------------
# Shared helpers
# ---------------------------------------------------------------------------


def _rows(x: torch.Tensor) -> torch.Tensor:
    """Flatten a 3-D input to 2-D rows over its last dimension; other ranks pass through.

    ``reshape`` (rather than ``view``) is used so non-contiguous inputs work too.
    """
    return x.reshape(-1, x.size(-1)) if x.dim() == 3 else x


def _unrows(x: torch.Tensor, original_shape: torch.Size) -> torch.Tensor:
    """Restore the original shape of an input that was flattened by ``_rows``."""
    return x.reshape(original_shape) if len(original_shape) == 3 else x


def _batch_std(x: torch.Tensor) -> torch.Tensor:
    """Per-feature standard deviation over rows (dim 0), kept finite for a single row.

    ``Tensor.std`` is unbiased by default and returns NaN when only one row is
    present. That NaN would be copied into the running buffers and poison
    every later evaluation, so a single row uses the population estimate
    (which is zero) instead.
    """
    return x.std(dim=0, keepdim=True, unbiased=x.size(0) > 1)


def _affine_reg(weight: torch.Tensor, bias: torch.Tensor) -> torch.Tensor:
    """Variance penalty ``0.01 * (var(weight) + var(bias))`` on a per-feature affine head.

    With a single feature the unbiased variance is NaN, so the penalty is
    defined as zero in that case.
    """
    if weight.numel() < 2:
        return weight.new_zeros(())
    return 0.01 * (weight.var() + bias.var())


@torch.no_grad()
def _update_running_stats(
    num_batches_tracked: torch.Tensor,
    momentum: float,
    *pairs: Tuple[torch.Tensor, torch.Tensor],
) -> None:
    """Exponential-moving-average update of each ``(buffer, batch_value)`` pair, in place.

    ``num_batches_tracked`` is incremented first. On the first tracked batch
    the buffers are overwritten with the batch values. Afterwards they are
    updated as ``buffer = (1 - momentum) * buffer + momentum * value``.
    """
    num_batches_tracked.add_(1)
    first_batch = int(num_batches_tracked.item()) == 1
    for buffer, value in pairs:
        value = value.reshape_as(buffer)
        if first_batch:
            buffer.copy_(value)
        else:
            buffer.mul_(1 - momentum).add_(value, alpha=momentum)


def _invert_batch_norm(bn: nn.BatchNorm1d, x: torch.Tensor) -> torch.Tensor:
    """Undo ``bn`` on ``x`` using its running statistics, without touching any buffer.

    If the forward pass ran in training mode it normalised with batch
    statistics, so this inverse is only approximate in that case.
    """
    if bn.affine:
        x = (x - bn.bias) / (bn.weight + 1e-8)
    if bn.running_mean is not None and bn.running_var is not None:
        x = x * torch.sqrt(bn.running_var + bn.eps) + bn.running_mean
    return x


def _mlp(input_dim: int, hidden_dim: int, positive: bool = False) -> nn.Sequential:
    """Two-hidden-layer ReLU MLP mapping ``input_dim -> input_dim``.

    When ``positive`` is true, a final ``Softplus`` makes every output non-negative.
    """
    layers = [
        nn.Linear(input_dim, hidden_dim),
        nn.ReLU(),
        nn.Linear(hidden_dim, hidden_dim),
        nn.ReLU(),
        nn.Linear(hidden_dim, input_dim),
    ]
    if positive:
        layers.append(nn.Softplus())
    return nn.Sequential(*layers)


# ---------------------------------------------------------------------------
# Scalers
# ---------------------------------------------------------------------------


class NeuralScaler(nn.Module):
    """Learnable alternative to StandardScaler using neural networks.

    Each feature is standardised as ``(x - mean) / std`` and then passed through
    a learnable per-feature affine map (``weight``, ``bias``). ``mean`` and
    ``std`` are a base statistic plus a correction from a small MLP. The
    ``std`` correction ends in ``Softplus`` and is therefore non-negative.

    The base statistic depends on the mode:
    - training: the current batch;
    - evaluation and inverse transform: the exponential running average.

    The learned corrections are applied in every mode, so evaluation
    reproduces the transform learned during training.

    ``forward`` returns ``(transformed, reg_loss)``.
    """

    def __init__(self, config: AutoencoderConfig):
        super().__init__()
        self.config = config
        input_dim = config.input_dim
        hidden_dim = config.preprocessing_hidden_dim

        self.mean_estimator = _mlp(input_dim, hidden_dim)
        self.std_estimator = _mlp(input_dim, hidden_dim, positive=True)

        # Per-feature affine head applied after standardisation.
        self.weight = nn.Parameter(torch.ones(input_dim))
        self.bias = nn.Parameter(torch.zeros(input_dim))

        self.register_buffer("running_mean", torch.zeros(input_dim))
        self.register_buffer("running_std", torch.ones(input_dim))
        self.register_buffer("num_batches_tracked", torch.tensor(0, dtype=torch.long))
        self.momentum = 0.1

    def _running_effective_stats(self) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return ``(mean, std)``, each ``(1, input_dim)``: running buffers plus learned corrections."""
        running_mean = self.running_mean.unsqueeze(0)
        running_std = self.running_std.unsqueeze(0)
        mean = running_mean + self.mean_estimator(running_mean)
        std = running_std + self.std_estimator(running_std) + 1e-8
        return mean, std

    def forward(self, x: torch.Tensor, inverse: bool = False) -> Tuple[torch.Tensor, torch.Tensor]:
        """Standardise ``x`` (2-D, or 3-D flattened over its last dim), or invert when ``inverse``.

        In training mode this also updates the running mean/std buffers.
        """
        if inverse:
            return self._inverse_transform(x)

        original_shape = x.shape
        x = _rows(x)

        if self.training:
            batch_mean = x.mean(dim=0, keepdim=True)
            batch_std = _batch_std(x)
            effective_mean = batch_mean + self.mean_estimator(batch_mean)
            effective_std = batch_std + self.std_estimator(batch_std) + 1e-8
            _update_running_stats(
                self.num_batches_tracked,
                self.momentum,
                (self.running_mean, batch_mean),
                (self.running_std, batch_std),
            )
        else:
            effective_mean, effective_std = self._running_effective_stats()

        normalized = (x - effective_mean) / effective_std
        transformed = normalized * self.weight + self.bias
        return _unrows(transformed, original_shape), _affine_reg(self.weight, self.bias)

    def _inverse_transform(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Map affine-head outputs back to the input scale; regularisation is zero.

        Returns ``x`` unchanged when ``config.learn_inverse_preprocessing`` is
        false. Otherwise it always uses the running statistics plus learned
        corrections, in both training and evaluation mode.
        """
        if not self.config.learn_inverse_preprocessing:
            return x, torch.tensor(0.0, device=x.device)

        original_shape = x.shape
        x = _rows(x)
        mean, std = self._running_effective_stats()
        x = (x - self.bias) / (self.weight + 1e-8)
        x = x * std + mean
        return _unrows(x, original_shape), torch.tensor(0.0, device=x.device)


class LearnableMinMaxScaler(nn.Module):
    """Learnable MinMax scaler that adapts bounds during training.

    Each feature is mapped as ``(x - min) / range`` and then through a
    per-feature affine head. ``min`` and ``range`` are a base statistic plus
    an MLP correction. The ``range`` correction is non-negative via
    ``Softplus``.

    The base statistic depends on the mode:
    - training: the current batch;
    - evaluation and inverse transform: the exponential running average.
    """

    def __init__(self, config: AutoencoderConfig):
        super().__init__()
        self.config = config
        input_dim = config.input_dim
        hidden_dim = config.preprocessing_hidden_dim

        self.min_estimator = _mlp(input_dim, hidden_dim)
        self.range_estimator = _mlp(input_dim, hidden_dim, positive=True)

        self.weight = nn.Parameter(torch.ones(input_dim))
        self.bias = nn.Parameter(torch.zeros(input_dim))

        self.register_buffer("running_min", torch.zeros(input_dim))
        self.register_buffer("running_range", torch.ones(input_dim))
        self.register_buffer("num_batches_tracked", torch.tensor(0, dtype=torch.long))
        self.momentum = 0.1

    def _running_effective_stats(self, eps: float = 1e-8) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return ``(min, range)``, each ``(1, input_dim)``: running buffers plus learned corrections."""
        running_min = self.running_min.unsqueeze(0)
        running_range = self.running_range.unsqueeze(0)
        effective_min = running_min + self.min_estimator(running_min)
        effective_range = running_range + self.range_estimator(running_range) + eps
        return effective_min, effective_range

    def forward(self, x: torch.Tensor, inverse: bool = False) -> Tuple[torch.Tensor, torch.Tensor]:
        """Min-max scale ``x``, or invert when ``inverse``; returns ``(output, reg_loss)``.

        In training mode, ``reg_loss`` also penalises small effective ranges
        (``1 / range``) and the running buffers are updated.
        """
        if inverse:
            return self._inverse_transform(x)

        original_shape = x.shape
        x = _rows(x)
        eps = 1e-8

        if self.training:
            batch_min = x.min(dim=0, keepdim=True).values
            batch_max = x.max(dim=0, keepdim=True).values
            batch_range = (batch_max - batch_min).clamp_min(eps)
            effective_min = batch_min + self.min_estimator(batch_min)
            effective_range = batch_range + self.range_estimator(batch_range) + eps
            _update_running_stats(
                self.num_batches_tracked,
                self.momentum,
                (self.running_min, batch_min),
                (self.running_range, batch_range),
            )
        else:
            effective_min, effective_range = self._running_effective_stats(eps)

        scaled = (x - effective_min) / effective_range
        transformed = scaled * self.weight + self.bias

        reg_loss = _affine_reg(self.weight, self.bias)
        if self.training:
            reg_loss = reg_loss + 0.001 * (1.0 / effective_range.clamp_min(1e-3)).mean()
        return _unrows(transformed, original_shape), reg_loss

    def _inverse_transform(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Map scaled outputs back to the input scale using running stats plus learned corrections.

        Returns ``x`` unchanged when ``config.learn_inverse_preprocessing`` is false.
        """
        if not self.config.learn_inverse_preprocessing:
            return x, torch.tensor(0.0, device=x.device)

        original_shape = x.shape
        x = _rows(x)
        effective_min, effective_range = self._running_effective_stats()
        x = (x - self.bias) / (self.weight + 1e-8)
        x = x * effective_range + effective_min
        return _unrows(x, original_shape), torch.tensor(0.0, device=x.device)


class LearnableRobustScaler(nn.Module):
    """Learnable Robust scaler using median and IQR with learnable adjustments.

    Each feature is mapped as ``(x - median) / iqr`` and then through a
    per-feature affine head. ``iqr`` is the 25–75% quantile spread. Both
    statistics get an MLP correction; the ``iqr`` correction is non-negative.

    The base statistic depends on the mode:
    - training: the current batch;
    - evaluation and inverse transform: the exponential running average.
    """

    def __init__(self, config: AutoencoderConfig):
        super().__init__()
        self.config = config
        input_dim = config.input_dim
        hidden_dim = config.preprocessing_hidden_dim

        self.median_estimator = _mlp(input_dim, hidden_dim)
        self.iqr_estimator = _mlp(input_dim, hidden_dim, positive=True)

        self.weight = nn.Parameter(torch.ones(input_dim))
        self.bias = nn.Parameter(torch.zeros(input_dim))

        self.register_buffer("running_median", torch.zeros(input_dim))
        self.register_buffer("running_iqr", torch.ones(input_dim))
        self.register_buffer("num_batches_tracked", torch.tensor(0, dtype=torch.long))
        self.momentum = 0.1

    def _running_effective_stats(self, eps: float = 1e-8) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return ``(median, iqr)``, each ``(1, input_dim)``: running buffers plus learned corrections."""
        running_median = self.running_median.unsqueeze(0)
        running_iqr = self.running_iqr.unsqueeze(0)
        effective_median = running_median + self.median_estimator(running_median)
        effective_iqr = running_iqr + self.iqr_estimator(running_iqr) + eps
        return effective_median, effective_iqr

    def forward(self, x: torch.Tensor, inverse: bool = False) -> Tuple[torch.Tensor, torch.Tensor]:
        """Robust-scale ``x``, or invert when ``inverse``; returns ``(output, reg_loss)``."""
        if inverse:
            return self._inverse_transform(x)

        original_shape = x.shape
        x = _rows(x)
        eps = 1e-8

        if self.training:
            # torch.quantile requires q to share the input's dtype (e.g. float64 inputs).
            q = torch.tensor([0.25, 0.5, 0.75], device=x.device, dtype=x.dtype)
            qs = torch.quantile(x, q, dim=0)
            q25, med, q75 = qs[0:1, :], qs[1:2, :], qs[2:3, :]
            iqr = (q75 - q25).clamp_min(eps)
            effective_median = med + self.median_estimator(med)
            effective_iqr = iqr + self.iqr_estimator(iqr) + eps
            _update_running_stats(
                self.num_batches_tracked,
                self.momentum,
                (self.running_median, med),
                (self.running_iqr, iqr),
            )
        else:
            effective_median, effective_iqr = self._running_effective_stats(eps)

        normalized = (x - effective_median) / effective_iqr
        transformed = normalized * self.weight + self.bias

        reg_loss = _affine_reg(self.weight, self.bias)
        if self.training:
            reg_loss = reg_loss + 0.001 * (1.0 / effective_iqr.clamp_min(1e-3)).mean()
        return _unrows(transformed, original_shape), reg_loss

    def _inverse_transform(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Map scaled outputs back to the input scale using running stats plus learned corrections.

        Returns ``x`` unchanged when ``config.learn_inverse_preprocessing`` is false.
        """
        if not self.config.learn_inverse_preprocessing:
            return x, torch.tensor(0.0, device=x.device)

        original_shape = x.shape
        x = _rows(x)
        effective_median, effective_iqr = self._running_effective_stats()
        x = (x - self.bias) / (self.weight + 1e-8)
        x = x * effective_iqr + effective_median
        return _unrows(x, original_shape), torch.tensor(0.0, device=x.device)


class LearnableYeoJohnsonPreprocessor(nn.Module):
    """Learnable Yeo-Johnson power transform with per-feature lambda and affine head.

    The pipeline is:
    1. Apply the Yeo-Johnson transform with a learnable per-feature ``lmbda``.
    2. Standardise the result with batch (training) or running (evaluation)
       mean/std.
    3. Apply a per-feature affine map.
    """

    def __init__(self, config: AutoencoderConfig):
        super().__init__()
        self.config = config
        input_dim = config.input_dim
        self.lmbda = nn.Parameter(torch.ones(input_dim))
        self.weight = nn.Parameter(torch.ones(input_dim))
        self.bias = nn.Parameter(torch.zeros(input_dim))
        self.register_buffer("running_mean", torch.zeros(input_dim))
        self.register_buffer("running_std", torch.ones(input_dim))
        self.register_buffer("num_batches_tracked", torch.tensor(0, dtype=torch.long))
        self.momentum = 0.1

    @staticmethod
    def _safe_denominator(value: torch.Tensor, eps: float) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return ``(use_power, safe)``, where ``safe`` equals ``value`` except where ``|value| <= eps``.

        Where ``|value| <= eps`` the log branch of ``torch.where`` is selected.
        Substituting 1 in the unselected power branch keeps it finite, so it
        cannot leak inf/NaN into the gradients.
        """
        use_power = torch.abs(value) > eps
        return use_power, torch.where(use_power, value, torch.ones_like(value))

    def _yeo_johnson(self, x: torch.Tensor, lmbda: torch.Tensor) -> torch.Tensor:
        """Forward Yeo-Johnson transform, element-wise with per-feature ``lmbda``."""
        eps = 1e-6
        lmbda = lmbda.unsqueeze(0)
        pos = x >= 0

        use_pow_pos, safe_l = self._safe_denominator(lmbda, eps)
        x_plus = (x + 1.0).clamp_min(eps)
        if_part = torch.where(use_pow_pos, (x_plus ** safe_l - 1.0) / safe_l, torch.log(x_plus))

        use_pow_neg, safe_t = self._safe_denominator(2.0 - lmbda, eps)
        one_minus = (1.0 - x).clamp_min(eps)
        else_part = torch.where(use_pow_neg, -(one_minus ** safe_t - 1.0) / safe_t, -torch.log(one_minus))

        return torch.where(pos, if_part, else_part)

    def _yeo_johnson_inverse(self, y: torch.Tensor, lmbda: torch.Tensor) -> torch.Tensor:
        """Inverse Yeo-Johnson transform, element-wise with per-feature ``lmbda``."""
        eps = 1e-6
        lmbda = lmbda.unsqueeze(0)
        pos = y >= 0

        use_pow_pos, safe_l = self._safe_denominator(lmbda, eps)
        x_pos = torch.where(
            use_pow_pos,
            (y * safe_l + 1.0).clamp_min(eps) ** (1.0 / safe_l) - 1.0,
            torch.exp(y) - 1.0,
        )

        use_pow_neg, safe_t = self._safe_denominator(2.0 - lmbda, eps)
        x_neg = torch.where(
            use_pow_neg,
            1.0 - (1.0 - y * safe_t).clamp_min(eps) ** (1.0 / safe_t),
            1.0 - torch.exp(-y),
        )

        return torch.where(pos, x_pos, x_neg)

    def forward(self, x: torch.Tensor, inverse: bool = False) -> Tuple[torch.Tensor, torch.Tensor]:
        """Transform ``x``, or invert when ``inverse``; returns ``(output, reg_loss)``.

        ``reg_loss`` pulls ``lmbda`` towards 1 and penalises variance of the
        affine head. In training mode the running mean/std buffers are updated.
        """
        if inverse:
            return self._inverse_transform(x)

        orig_shape = x.shape
        x = _rows(x)
        y = self._yeo_johnson(x, self.lmbda)

        if self.training:
            mean = y.mean(dim=0, keepdim=True)
            std = _batch_std(y).clamp_min(1e-6)
            _update_running_stats(
                self.num_batches_tracked,
                self.momentum,
                (self.running_mean, mean),
                (self.running_std, std),
            )
        else:
            mean = self.running_mean.unsqueeze(0)
            std = self.running_std.unsqueeze(0)

        out = (y - mean) / std * self.weight + self.bias
        reg = 0.001 * (self.lmbda - 1.0).pow(2).mean() + _affine_reg(self.weight, self.bias)
        return _unrows(out, orig_shape), reg

    def _inverse_transform(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Undo the affine head and standardisation (running stats), then invert Yeo-Johnson.

        Returns ``x`` unchanged when ``config.learn_inverse_preprocessing`` is false.
        """
        if not self.config.learn_inverse_preprocessing:
            return x, torch.tensor(0.0, device=x.device)

        orig_shape = x.shape
        x = _rows(x)
        y = (x - self.bias) / (self.weight + 1e-8)
        y = y * self.running_std.unsqueeze(0) + self.running_mean.unsqueeze(0)
        out = self._yeo_johnson_inverse(y, self.lmbda)
        return _unrows(out, orig_shape), torch.tensor(0.0, device=x.device)


class PreprocessingBlock(BaseBlock):
    """Wraps a LearnablePreprocessor into a BaseBlock-compatible interface.

    Forward returns the transformed tensor and stores the regularization loss in .reg_loss.
    The inverse flag is configured at initialization to avoid leaking kwargs to other blocks.
    """

    def __init__(self, config: AutoencoderConfig, inverse: bool = False, proc: Optional[LearnablePreprocessor] = None):
        super().__init__()
        # Passing an existing ``proc`` lets a forward and an inverse block share parameters.
        self.proc = proc if proc is not None else LearnablePreprocessor(config)
        self._output_dim = config.input_dim
        self.inverse = inverse
        self.reg_loss: torch.Tensor = torch.tensor(0.0)

    @property
    def output_dim(self) -> int:
        """Feature dimension of the output (equal to ``config.input_dim``)."""
        return self._output_dim

    def forward(self, x: torch.Tensor, **kwargs) -> torch.Tensor:
        """Run the wrapped preprocessor; extra kwargs are accepted and ignored."""
        y, reg = self.proc(x, inverse=self.inverse)
        self.reg_loss = reg
        return y


# ---------------------------------------------------------------------------
# Normalizing flow
# ---------------------------------------------------------------------------


class CouplingLayer(nn.Module):
    """Coupling layer for normalizing flows.

    Masked features pass through unchanged and condition an affine transform of
    the unmasked features: ``y = x * exp(s) + t``. The scale ``s`` is bounded
    by ``Tanh``. Returns ``(y, log_det)`` with one log-determinant per row.
    """

    def __init__(self, input_dim: int, hidden_dim: int = 64, mask_type: str = "alternating"):
        super().__init__()
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        if mask_type == "alternating":
            self.register_buffer("mask", torch.arange(input_dim) % 2)
        elif mask_type == "half":
            mask = torch.zeros(input_dim)
            mask[: input_dim // 2] = 1
            self.register_buffer("mask", mask)
        else:
            raise ValueError(f"Unknown mask type: {mask_type}")

        masked_dim = int(self.mask.sum().item())
        unmasked_dim = input_dim - masked_dim
        self.scale_net = nn.Sequential(
            nn.Linear(masked_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, unmasked_dim),
            nn.Tanh(),
        )
        self.translate_net = nn.Sequential(
            nn.Linear(masked_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, unmasked_dim),
        )

    def forward(self, x: torch.Tensor, inverse: bool = False) -> Tuple[torch.Tensor, torch.Tensor]:
        """Apply (or invert) the coupling transform on 2-D ``x``; returns ``(y, log_det)``."""
        mask = self.mask.bool()
        x_masked = x[:, mask]
        x_unmasked = x[:, ~mask]
        s = self.scale_net(x_masked)
        t = self.translate_net(x_masked)
        if not inverse:
            y_unmasked = x_unmasked * torch.exp(s) + t
            log_det = s.sum(dim=1)
        else:
            y_unmasked = (x_unmasked - t) * torch.exp(-s)
            log_det = -s.sum(dim=1)
        y = torch.zeros_like(x)
        y[:, mask] = x_masked
        y[:, ~mask] = y_unmasked
        return y, log_det


class NormalizingFlowPreprocessor(nn.Module):
    """Normalizing flow for learnable data preprocessing.

    The flow is a stack of ``config.flow_coupling_layers`` coupling layers,
    with masks alternating between ``"alternating"`` and ``"half"``. When
    ``config.use_batch_norm`` is set, a ``BatchNorm1d`` sits between
    consecutive layers.
    """

    def __init__(self, config: AutoencoderConfig):
        super().__init__()
        self.config = config
        input_dim = config.input_dim
        hidden_dim = config.preprocessing_hidden_dim
        num_layers = config.flow_coupling_layers

        self.layers = nn.ModuleList()
        for i in range(num_layers):
            mask_type = "alternating" if i % 2 == 0 else "half"
            self.layers.append(CouplingLayer(input_dim, hidden_dim, mask_type))

        if config.use_batch_norm:
            # batch_norms[i] follows layers[i]; the last layer has none.
            self.batch_norms = nn.ModuleList([nn.BatchNorm1d(input_dim) for _ in range(num_layers - 1)])
        else:
            self.batch_norms = None

    def _has_batch_norm(self, idx: int) -> bool:
        """Whether a batch norm follows coupling layer ``idx``."""
        return self.batch_norms is not None and idx < len(self.batch_norms)

    def forward(self, x: torch.Tensor, inverse: bool = False) -> Tuple[torch.Tensor, torch.Tensor]:
        """Run the flow forward, or backward when ``inverse``; returns ``(x, reg_loss)``.

        The inverse undoes each batch norm explicitly using its running
        statistics. It does not re-apply it, and it never updates BN buffers.
        ``reg_loss`` is ``0.01 * mean(|sum of coupling log-dets|)``.
        """
        original_shape = x.shape
        x = _rows(x)
        log_det_total = torch.zeros(x.size(0), device=x.device, dtype=x.dtype)

        if not inverse:
            for idx, layer in enumerate(self.layers):
                x, log_det = layer(x, inverse=False)
                log_det_total = log_det_total + log_det
                if self._has_batch_norm(idx):
                    x = self.batch_norms[idx](x)
        else:
            for idx in reversed(range(len(self.layers))):
                if self._has_batch_norm(idx):
                    x = _invert_batch_norm(self.batch_norms[idx], x)
                x, log_det = self.layers[idx](x, inverse=True)
                log_det_total = log_det_total + log_det

        reg_loss = 0.01 * log_det_total.abs().mean()
        return _unrows(x, original_shape), reg_loss


# ---------------------------------------------------------------------------
# Unified interface
# ---------------------------------------------------------------------------


class LearnablePreprocessor(nn.Module):
    """Unified interface for learnable preprocessing methods.

    Selects the implementation from the ``config`` flags. ``forward`` always
    returns ``(tensor, reg_loss)``.

    Raises:
        ValueError: if preprocessing is enabled but no known type flag is set.
    """

    def __init__(self, config: AutoencoderConfig):
        super().__init__()
        self.config = config
        if not config.has_preprocessing:
            self.preprocessor = nn.Identity()
        elif config.is_neural_scaler:
            self.preprocessor = NeuralScaler(config)
        elif config.is_normalizing_flow:
            self.preprocessor = NormalizingFlowPreprocessor(config)
        elif getattr(config, "is_minmax_scaler", False):
            self.preprocessor = LearnableMinMaxScaler(config)
        elif getattr(config, "is_robust_scaler", False):
            self.preprocessor = LearnableRobustScaler(config)
        elif getattr(config, "is_yeo_johnson", False):
            self.preprocessor = LearnableYeoJohnsonPreprocessor(config)
        else:
            raise ValueError(f"Unknown preprocessing type: {config.preprocessing_type}")

    def forward(self, x: torch.Tensor, inverse: bool = False) -> Tuple[torch.Tensor, torch.Tensor]:
        """Apply (or invert) the configured preprocessing; identity yields a zero loss."""
        if isinstance(self.preprocessor, nn.Identity):
            return x, torch.tensor(0.0, device=x.device)
        return self.preprocessor(x, inverse=inverse)


__all__ = [
    "NeuralScaler",
    "LearnableMinMaxScaler",
    "LearnableRobustScaler",
    "LearnableYeoJohnsonPreprocessor",
    "CouplingLayer",
    "NormalizingFlowPreprocessor",
    "LearnablePreprocessor",
    "PreprocessingBlock",
]