|
""" |
|
Modular, block-based components for building autoencoders in PyTorch. |
|
|
|
Core goals: |
|
- Composable building blocks with consistent interfaces |
|
- Support 2D (B, F) and 3D (B, T, F) tensors where applicable |
|
- Simple configs to construct blocks and sequences |
|
- Safe-by-default validation and helpful errors |
|
|
|
This module is intentionally self-contained to allow gradual integration with |
|
existing models. It does not change any existing behavior.
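
Example (illustrative sketch; assumes the *BlockConfig dataclasses provide
defaults for any fields not shown here):

    seq = BlockFactory.build_sequence([
        {"type": "linear", "input_dim": 128, "output_dim": 64},
        {"type": "linear", "input_dim": 64, "output_dim": 32},
    ])
    z = seq(torch.randn(8, 128))  # (B, F) -> (8, 32)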
|
""" |
|
from __future__ import annotations |
|
|
|
from dataclasses import dataclass |
|
from typing import Any, Dict, Iterable, List, Optional, Sequence, Tuple, Union |
|
|
|
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint  # used by BlockSequence for per-block checkpointing
|
|
|
|
|
try: |
|
from .configuration_autoencoder import ( |
|
BlockConfig, |
|
LinearBlockConfig, |
|
AttentionBlockConfig, |
|
RecurrentBlockConfig, |
|
ConvolutionalBlockConfig, |
|
VariationalBlockConfig, |
|
) |
|
except ImportError:
|
from configuration_autoencoder import ( |
|
BlockConfig, |
|
LinearBlockConfig, |
|
AttentionBlockConfig, |
|
RecurrentBlockConfig, |
|
ConvolutionalBlockConfig, |
|
VariationalBlockConfig, |
|
) |
|
|
|
|
|
|
|
try: |
|
from .utils import _get_activation, _get_norm, _flatten_3d_to_2d, _maybe_restore_3d |
|
except ImportError:
|
from utils import _get_activation, _get_norm, _flatten_3d_to_2d, _maybe_restore_3d |
|
|
|
|
|
|
|
|
|
class BaseBlock(nn.Module): |
|
"""Abstract base for all blocks. |
|
|
|
All blocks should accept 2D (B, F) or 3D (B, T, F) tensors and return the |
|
same rank, with last-dim equal to `output_dim`. |
|
""" |
|
|
|
def forward(self, x: torch.Tensor, **kwargs) -> torch.Tensor: |
|
raise NotImplementedError |
|
|
|
@property |
|
def output_dim(self) -> int: |
|
raise NotImplementedError |
|
|
|
|
|
|
|
|
|
class ResidualBlock(BaseBlock): |
|
"""Base class for blocks supporting residual connections. |
|
|
|
Implements a safe residual add when input and output dims match; otherwise |
|
falls back to a learned projection. Residuals can be scaled. |
|
""" |
|
|
|
    def __init__(
        self,
        residual: bool = False,
        residual_scale: float = 1.0,
        proj_dim_in: Optional[int] = None,
        proj_dim_out: Optional[int] = None,
    ):
|
super().__init__() |
|
self.use_residual = residual |
|
self.residual_scale = residual_scale |
|
self._proj: Optional[nn.Module] = None |
|
if residual and proj_dim_in is not None and proj_dim_out is not None and proj_dim_in != proj_dim_out: |
|
self._proj = nn.Linear(proj_dim_in, proj_dim_out) |
|
|
|
def _apply_residual(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: |
|
if not self.use_residual: |
|
return y |
|
x2d, hint = _flatten_3d_to_2d(x) |
|
y2d, _ = _flatten_3d_to_2d(y) |
|
if x2d.shape[-1] != y2d.shape[-1]: |
|
            if self._proj is None:
                # Lazily build the projection on first use. The module is
                # registered, but optimizers constructed before this first
                # forward pass will not see its parameters; pass
                # proj_dim_in/proj_dim_out at construction time to avoid that.
                self._proj = nn.Linear(x2d.shape[-1], y2d.shape[-1]).to(y2d.device)
|
x2d = self._proj(x2d) |
|
out = x2d + self.residual_scale * y2d |
|
return _maybe_restore_3d(out, hint) |
|
|
|
|
|
|
|
|
|
class LinearBlock(ResidualBlock): |
|
"""Basic linear transformation with normalization and activation. |
|
|
|
- Handles both 2D (B, F) and 3D (B, T, F) tensors |
|
- Optional normalization: batch|layer|group|instance|none |
|
- Configurable activation |
|
- Optional dropout |
|
- Optional residual connection (with auto projection) |
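
    Example (illustrative; assumes defaults for the remaining config fields):

        cfg = LinearBlockConfig(input_dim=128, output_dim=64)
        block = LinearBlock(cfg)
        block(torch.randn(8, 128)).shape      # torch.Size([8, 64])
        block(torch.randn(8, 10, 128)).shape  # torch.Size([8, 10, 64])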
|
""" |
|
|
|
def __init__(self, cfg: LinearBlockConfig): |
|
super().__init__(residual=cfg.use_residual, residual_scale=cfg.residual_scale, proj_dim_in=cfg.input_dim, proj_dim_out=cfg.output_dim) |
|
self.cfg = cfg |
|
|
|
self.linear = nn.Linear(cfg.input_dim, cfg.output_dim) |
|
|
|
|
|
if cfg.normalization == "layer": |
|
self.norm = nn.LayerNorm(cfg.output_dim) |
|
else: |
|
self.norm = _get_norm(cfg.normalization, cfg.output_dim) |
|
self.act = _get_activation(cfg.activation) |
|
self.drop = nn.Dropout(cfg.dropout_rate) if cfg.dropout_rate and cfg.dropout_rate > 0 else nn.Identity() |
|
|
|
@property |
|
def output_dim(self) -> int: |
|
return self.cfg.output_dim |
|
|
|
def forward(self, x: torch.Tensor) -> torch.Tensor: |
|
x_in = x |
|
x2d, hint = _flatten_3d_to_2d(x) |
|
y = self.linear(x2d) |
|
|
|
        # The configured normalization is applied on the flattened (B*T, F) view.
        y = self.norm(y)
|
y = self.act(y) |
|
y = self.drop(y) |
|
y = _maybe_restore_3d(y, hint) |
|
return self._apply_residual(x_in, y) |
|
|
|
|
|
|
|
|
|
class AttentionBlock(BaseBlock): |
|
"""Multi-head self-attention with optional FFN. |
|
|
|
Expects inputs as 3D (B, T, D) or 2D (B, D) which will be treated as (B, 1, D). |
|
Supports optional attn mask and key padding mask via kwargs. |
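
    Example (illustrative; assumes defaults for the remaining config fields):

        cfg = AttentionBlockConfig(input_dim=64, num_heads=4, dropout_rate=0.1)
        block = AttentionBlock(cfg)
        x = torch.randn(8, 20, 64)                  # (B, T, D)
        pad = torch.zeros(8, 20, dtype=torch.bool)  # True marks padded positions
        y = block(x, key_padding_mask=pad)          # -> (8, 20, 64)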
|
""" |
|
|
|
def __init__(self, cfg: AttentionBlockConfig): |
|
super().__init__() |
|
self.cfg = cfg |
|
d_model = cfg.input_dim |
|
self.mha = nn.MultiheadAttention(d_model, num_heads=cfg.num_heads, dropout=cfg.dropout_rate, batch_first=True) |
|
self.ln1 = nn.LayerNorm(d_model) |
|
ffn_dim = cfg.ffn_dim or (4 * d_model) |
|
self.ffn = nn.Sequential( |
|
nn.Linear(d_model, ffn_dim), |
|
_get_activation("gelu"), |
|
nn.Dropout(cfg.dropout_rate), |
|
nn.Linear(ffn_dim, d_model), |
|
) |
|
self.ln2 = nn.LayerNorm(d_model) |
|
self.dropout = nn.Dropout(cfg.dropout_rate) |
|
|
|
@property |
|
def output_dim(self) -> int: |
|
return self.cfg.input_dim |
|
|
|
def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None, key_padding_mask: Optional[torch.Tensor] = None) -> torch.Tensor: |
|
if x.dim() == 2: |
|
x = x.unsqueeze(1) |
|
squeeze_back = True |
|
else: |
|
squeeze_back = False |
|
|
|
residual = x |
|
attn_out, _ = self.mha(x, x, x, attn_mask=attn_mask, key_padding_mask=key_padding_mask, need_weights=False) |
|
x = self.ln1(residual + self.dropout(attn_out)) |
|
|
|
residual = x |
|
x = self.ffn(x) |
|
x = self.ln2(residual + self.dropout(x)) |
|
if squeeze_back: |
|
x = x.squeeze(1) |
|
return x |
|
|
|
|
|
|
|
|
|
class RecurrentBlock(BaseBlock): |
|
"""RNN processing block supporting LSTM/GRU/RNN. |
|
|
|
Input: 3D (B, T, F) preferred. If 2D, treated as (B, 1, F). |
|
Output dim equals cfg.output_dim if set; otherwise hidden_size * directions. |
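
    Example (illustrative; assumes defaults for the remaining config fields):

        cfg = RecurrentBlockConfig(input_dim=32, hidden_size=64, num_layers=1,
                                   rnn_type="gru", bidirectional=False)
        block = RecurrentBlock(cfg)
        y = block(torch.randn(8, 20, 32))  # last-step summary -> (8, 1, 64)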
|
""" |
|
|
|
def __init__(self, cfg: RecurrentBlockConfig): |
|
super().__init__() |
|
self.cfg = cfg |
|
rnn_type = cfg.rnn_type.lower() |
|
rnn_cls = {"lstm": nn.LSTM, "gru": nn.GRU, "rnn": nn.RNN}.get(rnn_type) |
|
if rnn_cls is None: |
|
raise ValueError(f"Unknown rnn_type: {cfg.rnn_type}") |
|
self.rnn = rnn_cls( |
|
input_size=cfg.input_dim, |
|
hidden_size=cfg.hidden_size, |
|
num_layers=cfg.num_layers, |
|
batch_first=True, |
|
dropout=cfg.dropout_rate if cfg.num_layers > 1 else 0.0, |
|
bidirectional=cfg.bidirectional, |
|
) |
|
out_dim = cfg.hidden_size * (2 if cfg.bidirectional else 1) |
|
self._out_dim = cfg.output_dim or out_dim |
|
self.proj = None if self._out_dim == out_dim else nn.Linear(out_dim, self._out_dim) |
|
|
|
@property |
|
def output_dim(self) -> int: |
|
return self._out_dim |
|
|
|
    def forward(self, x: torch.Tensor, lengths: Optional[torch.Tensor] = None) -> torch.Tensor:
        squeeze_back = False
        if x.dim() == 2:
            x = x.unsqueeze(1)
            squeeze_back = True
        if lengths is not None:
            # pack_padded_sequence expects lengths on the CPU.
            x = nn.utils.rnn.pack_padded_sequence(x, lengths.cpu(), batch_first=True, enforce_sorted=False)
        # The hidden state is unused: LSTM returns an (h, c) tuple, GRU/RNN return h.
        out, _ = self.rnn(x)
        if lengths is not None:
            out, out_lengths = nn.utils.rnn.pad_packed_sequence(out, batch_first=True)
            # Summarize each sequence with its last *valid* timestep rather than
            # the last (possibly padded) position.
            idx = (out_lengths - 1).to(out.device)
            y = out[torch.arange(out.size(0), device=out.device), idx]
        else:
            # Summarize with the final timestep.
            y = out[:, -1, :]
        if self.proj is not None:
            y = self.proj(y)
        if squeeze_back:
            return y
        # Preserve a 3D rank for 3D inputs: (B, 1, output_dim).
        return y.unsqueeze(1)
|
|
|
|
|
|
|
|
|
class ConvolutionalBlock(BaseBlock): |
|
"""1D convolutional block for sequence-like data. |
|
Accepts 3D (B, T, F) or 2D (B, F) which is treated as (B, 1, F). |
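
    Example (illustrative; assumes defaults for the remaining config fields):

        cfg = ConvolutionalBlockConfig(input_dim=32, output_dim=64,
                                       kernel_size=3, padding="same")
        block = ConvolutionalBlock(cfg)
        y = block(torch.randn(8, 20, 32))  # -> (8, 20, 64)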
|
""" |
|
|
|
def __init__(self, cfg: ConvolutionalBlockConfig): |
|
super().__init__() |
|
self.cfg = cfg |
|
|
|
|
|
padding = cfg.padding |
|
        if isinstance(padding, str) and padding == "same":
            # Length-preserving only for odd kernel sizes (stride 1).
            pad = cfg.kernel_size // 2
|
else: |
|
pad = int(padding) |
|
self.conv = nn.Conv1d(cfg.input_dim, cfg.output_dim, kernel_size=cfg.kernel_size, padding=pad) |
|
|
|
        if cfg.normalization == "layer":
            # GroupNorm with one group normalizes across channels in the
            # channel-first (B, C, T) layout, standing in for LayerNorm.
            self.norm = nn.GroupNorm(1, cfg.output_dim)
|
else: |
|
self.norm = _get_norm(cfg.normalization, cfg.output_dim) |
|
self.act = _get_activation(cfg.activation) |
|
self.drop = nn.Dropout(cfg.dropout_rate) if cfg.dropout_rate and cfg.dropout_rate > 0 else nn.Identity() |
|
|
|
@property |
|
def output_dim(self) -> int: |
|
return self.cfg.output_dim |
|
|
|
def forward(self, x: torch.Tensor) -> torch.Tensor: |
|
squeeze_back = False |
|
if x.dim() == 2: |
|
x = x.unsqueeze(1) |
|
squeeze_back = True |
|
|
|
x = x.transpose(1, 2) |
|
|
|
y = self.conv(x) |
|
        if isinstance(self.norm, (nn.BatchNorm1d, nn.InstanceNorm1d, nn.GroupNorm)):
            # Channel-first norms operate directly on (B, C, T).
            y = self.norm(y)
        else:
            # Other norms (e.g. Identity) are applied on the channel-last view.
            y = self.norm(y.transpose(1, 2)).transpose(1, 2)
|
y = self.act(y) |
|
y = self.drop(y) |
|
y = y.transpose(1, 2) |
|
if squeeze_back: |
|
y = y.squeeze(1) |
|
return y |
|
|
|
|
|
|
|
class VariationalBlock(BaseBlock): |
|
"""Encapsulates mu/logvar projection and reparameterization. |
|
|
|
Input can be 2D (B, F) or 3D (B, T, F); for 3D, operates per timestep and returns same rank. |
|
Stores mu/logvar on the module for downstream loss usage. |
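
    Example (illustrative; the stored tensors are the private attributes set in forward):

        cfg = VariationalBlockConfig(input_dim=64, latent_dim=16)
        vb = VariationalBlock(cfg)
        z = vb(torch.randn(8, 64))  # -> (8, 16)
        kl = -0.5 * torch.mean(
            torch.sum(1 + vb._logvar - vb._mu.pow(2) - vb._logvar.exp(), dim=-1)
        )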
|
""" |
|
|
|
def __init__(self, cfg: VariationalBlockConfig): |
|
super().__init__() |
|
self.cfg = cfg |
|
self.fc_mu = nn.Linear(cfg.input_dim, cfg.latent_dim) |
|
self.fc_logvar = nn.Linear(cfg.input_dim, cfg.latent_dim) |
|
self._mu: Optional[torch.Tensor] = None |
|
self._logvar: Optional[torch.Tensor] = None |
|
|
|
@property |
|
def output_dim(self) -> int: |
|
return self.cfg.latent_dim |
|
|
|
def forward(self, x: torch.Tensor, training: Optional[bool] = None) -> torch.Tensor: |
|
if training is None: |
|
training = self.training |
|
x2d, hint = _flatten_3d_to_2d(x) |
|
mu = self.fc_mu(x2d) |
|
logvar = self.fc_logvar(x2d) |
|
        if training:
            # Reparameterization trick: z = mu + sigma * eps, sigma = exp(0.5 * logvar).
            std = torch.exp(0.5 * logvar)
            eps = torch.randn_like(std)
            z = mu + eps * std
|
else: |
|
z = mu |
|
self._mu = mu |
|
self._logvar = logvar |
|
z = _maybe_restore_3d(z, hint) |
|
return z |
|
|
|
|
|
|
|
|
|
|
|
|
|
class BlockSequence(nn.Module): |
|
"""Compose multiple blocks into a validated sequence. |
|
|
|
- Validates dimension flow between blocks |
|
- Supports gradient checkpointing (per-block) via forward(checkpoint=True) |
|
- Supports optional skip connections: pass `skips` as list of (src_idx, dst_idx) |
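
    Example (illustrative; a skip from block 0 into block 2):

        seq = BlockSequence(
            [
                LinearBlock(LinearBlockConfig(input_dim=32, output_dim=32)),
                LinearBlock(LinearBlockConfig(input_dim=32, output_dim=32)),
                LinearBlock(LinearBlockConfig(input_dim=32, output_dim=32)),
            ],
            skips=[(0, 2)],
        )
        y = seq(torch.randn(8, 32))  # -> (8, 32)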
|
""" |
|
|
|
    def __init__(
        self,
        blocks: Sequence[BaseBlock],
        validate_dims: bool = True,
        skips: Optional[List[Tuple[int, int]]] = None,
    ):
        super().__init__()
        self.blocks = nn.ModuleList(blocks)
        self.skips = skips or []
        if validate_dims and len(blocks) > 1:
            for i in range(1, len(blocks)):
                prev = blocks[i - 1]
                cur = blocks[i]
                prev_out = getattr(prev, "output_dim", None)
                cur_in = getattr(getattr(cur, "cfg", None), "input_dim", None)
                if prev_out is None or cur_in is None:
                    continue
                if prev_out != cur_in:
                    raise ValueError(
                        f"Dimension mismatch in BlockSequence: block {i - 1} outputs "
                        f"{prev_out} features but block {i} expects {cur_in}."
                    )
|
|
|
    def forward(self, x: torch.Tensor, checkpoint: bool = False, **kwargs) -> torch.Tensor:
        activations: Dict[int, torch.Tensor] = {}
        for i, block in enumerate(self.blocks):
            if checkpoint and x.requires_grad:
                # Bind the current block explicitly: checkpoint re-runs the
                # function during the backward pass, and a closure over the
                # loop variable would see the last block by then.
                x = torch.utils.checkpoint.checkpoint(lambda inp, blk=block: blk(inp, **kwargs), x)
            else:
                x = block(x, **kwargs)
            activations[i] = x
            # Apply skip connections that terminate at this block.
            for src, dst in self.skips:
                if dst == i and src in activations:
                    x = x + activations[src]
        return x
|
|
|
|
|
|
|
|
|
class BlockFactory: |
|
"""Factory to build blocks/sequences from configs. |
|
|
|
This is intentionally minimal; extend as needed. |
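
    Example (illustrative; assumes config defaults for omitted fields):

        block = BlockFactory.build_block(LinearBlockConfig(input_dim=128, output_dim=64))
        seq = BlockFactory.build_sequence([
            {"type": "linear", "input_dim": 128, "output_dim": 64},
            {"type": "attention", "input_dim": 64, "num_heads": 4},
        ])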
|
""" |
|
|
|
@staticmethod |
|
def build_block(cfg: Union[BlockConfig, Dict[str, Any]]) -> BaseBlock: |
|
|
|
if isinstance(cfg, dict): |
|
type_name = cfg.get("type") |
|
|
|
params = dict(cfg) |
|
params.pop("type", None) |
|
if type_name == "linear": |
|
return LinearBlock(LinearBlockConfig(**params)) |
|
if type_name == "attention": |
|
return AttentionBlock(AttentionBlockConfig(**params)) |
|
if type_name == "recurrent": |
|
return RecurrentBlock(RecurrentBlockConfig(**params)) |
|
            if type_name == "conv1d":
                return ConvolutionalBlock(ConvolutionalBlockConfig(**params))
            if type_name == "variational":
                return VariationalBlock(VariationalBlockConfig(**params))
            raise ValueError(f"Unsupported block type in dict cfg: {type_name} cfg={cfg}")
|
|
|
if isinstance(cfg, LinearBlockConfig) or getattr(cfg, "type", None) == "linear": |
|
if not isinstance(cfg, LinearBlockConfig): |
|
cfg = LinearBlockConfig(**cfg.__dict__) |
|
return LinearBlock(cfg) |
|
if isinstance(cfg, AttentionBlockConfig) or getattr(cfg, "type", None) == "attention": |
|
if not isinstance(cfg, AttentionBlockConfig): |
|
cfg = AttentionBlockConfig(**cfg.__dict__) |
|
return AttentionBlock(cfg) |
|
if isinstance(cfg, RecurrentBlockConfig) or getattr(cfg, "type", None) == "recurrent": |
|
if not isinstance(cfg, RecurrentBlockConfig): |
|
cfg = RecurrentBlockConfig(**cfg.__dict__) |
|
return RecurrentBlock(cfg) |
|
if isinstance(cfg, ConvolutionalBlockConfig) or getattr(cfg, "type", None) == "conv1d": |
|
if not isinstance(cfg, ConvolutionalBlockConfig): |
|
cfg = ConvolutionalBlockConfig(**cfg.__dict__) |
|
return ConvolutionalBlock(cfg) |
|
if isinstance(cfg, VariationalBlockConfig) or getattr(cfg, "type", None) == "variational": |
|
if not isinstance(cfg, VariationalBlockConfig): |
|
cfg = VariationalBlockConfig(**cfg.__dict__) |
|
return VariationalBlock(cfg) |
|
raise ValueError(f"Unsupported block type: {cfg}") |
|
|
|
@staticmethod |
|
def build_sequence(configs: Sequence[Union[BlockConfig, Dict[str, Any]]]) -> BlockSequence: |
|
blocks: List[BaseBlock] = [BlockFactory.build_block(c) for c in configs] |
|
return BlockSequence(blocks) |
|
|
|
|
|
__all__ = [ |
|
"BlockConfig", |
|
"LinearBlockConfig", |
|
"AttentionBlockConfig", |
|
"RecurrentBlockConfig", |
|
"ConvolutionalBlockConfig", |
|
"VariationalBlockConfig", |
|
"BaseBlock", |
|
"ResidualBlock", |
|
"LinearBlock", |
|
"AttentionBlock", |
|
"RecurrentBlock", |
|
"ConvolutionalBlock", |
|
"VariationalBlock", |
|
"BlockSequence", |
|
"BlockFactory", |
|
] |
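

if __name__ == "__main__":
    # Minimal smoke test (illustrative; assumes the config dataclasses provide
    # defaults for any fields not set here).
    seq = BlockFactory.build_sequence(
        [
            {"type": "linear", "input_dim": 16, "output_dim": 32},
            {"type": "attention", "input_dim": 32, "num_heads": 4},
            {"type": "linear", "input_dim": 32, "output_dim": 8},
        ]
    )
    x = torch.randn(2, 5, 16)  # (B, T, F)
    print(seq(x).shape)        # expected: torch.Size([2, 5, 8])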
|
|
|
|