# autoencoder/utils.py
# Last commit 8abd44b (AndrewMayesPrezzee): "Feat - Block like transformer structure"
from __future__ import annotations
from dataclasses import dataclass
from typing import Any, Dict, Iterable, List, Optional, Sequence, Tuple, Union
import torch
import torch.nn as nn
import torch.nn.functional as F
# ---------------------------- Utilities ---------------------------- #
def _get_activation(name: Optional[str]) -> nn.Module:
    """Build a new activation module from a case-insensitive name.

    Args:
        name: Activation name such as ``"relu"``, ``"GELU"``, ``"swish"`` or
            ``"leaky_relu"``; ``None`` means no activation.

    Returns:
        A freshly constructed ``nn.Module``. ``nn.Identity()`` is returned for
        ``None`` and for ``"identity"``.

    Raises:
        ValueError: If *name* (after lowercasing) is not a supported activation.
    """
    if name is None:
        return nn.Identity()
    # Map names to factories rather than instances so only the requested
    # module is constructed, while each call still gets its own instance.
    factories = {
        "relu": nn.ReLU,
        "gelu": nn.GELU,
        "silu": nn.SiLU,
        "swish": nn.SiLU,  # alias of "silu"
        "tanh": nn.Tanh,
        "sigmoid": nn.Sigmoid,
        "leaky_relu": lambda: nn.LeakyReLU(0.2),
        "elu": nn.ELU,
        "mish": nn.Mish,
        "softplus": nn.Softplus,
        "identity": nn.Identity,
    }
    key = name.lower()
    factory = factories.get(key)
    if factory is None:
        raise ValueError(f"Unknown activation: {key}")
    return factory()
def _get_norm(name: Optional[str], num_features: int) -> nn.Module:
    """Build a normalization layer from a case-insensitive name.

    Args:
        name: One of ``"batch"``, ``"layer"``, ``"instance"``, ``"group"`` or
            ``"none"`` (any letter case), or ``None`` for no normalization.
        num_features: Number of features passed to the normalization layer.

    Returns:
        ``nn.Identity()`` for ``None``/``"none"``; otherwise ``nn.BatchNorm1d``,
        ``nn.LayerNorm``, ``nn.InstanceNorm1d`` or ``nn.GroupNorm``. For
        ``"group"``, ``nn.LayerNorm`` is returned when no group count greater
        than 1 (and at most 8) divides ``num_features``.

    Raises:
        ValueError: If *name* is not a supported normalization.
    """
    if name is None:
        return nn.Identity()
    # Lowercase before any comparison so "None"/"NONE" are accepted like
    # every other option (previously only the exact string "none" was).
    key = name.lower()
    if key == "none":
        return nn.Identity()
    if key == "batch":
        return nn.BatchNorm1d(num_features)
    if key == "layer":
        return nn.LayerNorm(num_features)
    if key == "instance":
        return nn.InstanceNorm1d(num_features)
    if key == "group":
        # Pick the largest group count <= 8 that evenly divides num_features.
        groups = max(1, min(8, num_features))
        while groups > 1 and num_features % groups != 0:
            groups -= 1
        if groups == 1:
            # No usable multi-group split exists; fall back to LayerNorm.
            return nn.LayerNorm(num_features)
        return nn.GroupNorm(groups, num_features)
    raise ValueError(f"Unknown normalization: {key}")
def _flatten_3d_to_2d(x: torch.Tensor) -> Tuple[torch.Tensor, Optional[Tuple[int, int]]]:
    """Merge the first two dims of a rank-3 tensor; pass other ranks through.

    Returns:
        ``(flat, hint)``. For a 3-D input of shape ``(d0, d1, d2)``, ``flat``
        has shape ``(d0 * d1, d2)`` and ``hint`` is ``(d0, d1)``. For any other
        rank, the input tensor itself is returned with ``hint`` set to ``None``.
    """
    if x.dim() != 3:
        return x, None
    lead, mid, last = x.shape
    flat = x.reshape(lead * mid, last)
    return flat, (lead, mid)
def _maybe_restore_3d(x: torch.Tensor, shape_hint: Optional[Tuple[int, int]]) -> torch.Tensor:
    """Undo a leading-dim merge using *shape_hint*, if one is given.

    When *shape_hint* is ``(d0, d1)``, reshape *x* to ``(d0, d1, x.shape[-1])``;
    when it is ``None``, return *x* unchanged.
    """
    if shape_hint is not None:
        lead, mid = shape_hint
        return x.reshape(lead, mid, x.shape[-1])
    return x