from __future__ import annotations

from dataclasses import dataclass
from typing import Any, Dict, Iterable, List, Optional, Sequence, Tuple, Union

import torch
import torch.nn as nn
import torch.nn.functional as F

# ---------------------------- Utilities ---------------------------- #


def _get_activation(name: Optional[str]) -> nn.Module:
    """Return the activation module for a case-insensitive name; None maps to Identity."""
    if name is None:
        return nn.Identity()
    name = name.lower()
    mapping = {
        "relu": nn.ReLU(),
        "gelu": nn.GELU(),
        "silu": nn.SiLU(),
        "swish": nn.SiLU(),
        "tanh": nn.Tanh(),
        "sigmoid": nn.Sigmoid(),
        "leaky_relu": nn.LeakyReLU(0.2),
        "elu": nn.ELU(),
        "mish": nn.Mish(),
        "softplus": nn.Softplus(),
        "identity": nn.Identity(),
    }
    if name not in mapping:
        raise ValueError(f"Unknown activation: {name}")
    return mapping[name]


def _get_norm(name: Optional[str], num_features: int) -> nn.Module:
    """Return the normalization module for a case-insensitive name; None or "none" maps to Identity."""
    if name is None:
        return nn.Identity()
    name = name.lower()
    if name == "none":
        return nn.Identity()
    if name == "batch":
        return nn.BatchNorm1d(num_features)
    if name == "layer":
        return nn.LayerNorm(num_features)
    if name == "instance":
        return nn.InstanceNorm1d(num_features)
    if name == "group":
        # Default to 8 groups, reduced until the group count divides num_features.
        groups = max(1, min(8, num_features))
        while num_features % groups != 0 and groups > 1:
            groups -= 1
        if groups == 1:
            # Only a single group fits; fall back to LayerNorm over the feature dim.
            return nn.LayerNorm(num_features)
        return nn.GroupNorm(groups, num_features)
    raise ValueError(f"Unknown normalization: {name}")


def _flatten_3d_to_2d(x: torch.Tensor) -> Tuple[torch.Tensor, Optional[Tuple[int, int]]]:
    """Flatten a (B, T, F) tensor to (B*T, F) and return a (B, T) shape hint for restoring it."""
    if x.dim() == 3:
        b, t, f = x.shape
        return x.reshape(b * t, f), (b, t)
    return x, None


def _maybe_restore_3d(x: torch.Tensor, shape_hint: Optional[Tuple[int, int]]) -> torch.Tensor:
    """Undo _flatten_3d_to_2d using the (B, T) shape hint; pass the tensor through if there is no hint."""
    if shape_hint is None:
        return x
    b, t = shape_hint
    f = x.shape[-1]
    return x.reshape(b, t, f)
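
# --------------------- Usage sketch (illustrative only) --------------------- #
# A minimal smoke test showing how the helpers above compose; this block is an
# assumed example, not part of the module's public API, and the chosen shapes
# (batch=4, time=10, features=48) are arbitrary.
if __name__ == "__main__":
    act = _get_activation("gelu")
    norm = _get_norm("group", 48)              # GroupNorm(8, 48), since 8 divides 48
    x = torch.randn(4, 10, 48)                 # (batch, time, features)
    flat, hint = _flatten_3d_to_2d(x)          # (40, 48) plus shape hint (4, 10)
    y = _maybe_restore_3d(act(norm(flat)), hint)
    assert y.shape == x.shape                  # restored to (4, 10, 48)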