from __future__ import annotations

from dataclasses import dataclass
from typing import Any, Dict, Iterable, List, Optional, Sequence, Tuple, Union

import torch
import torch.nn as nn
import torch.nn.functional as F


def _get_activation(name: Optional[str]) -> nn.Module:
    """Return a fresh activation module for ``name`` (case-insensitive); ``None`` maps to identity."""
    if name is None:
        return nn.Identity()
    name = name.lower()
    mapping = {
        "relu": nn.ReLU(),
        "gelu": nn.GELU(),
        "silu": nn.SiLU(),
        "swish": nn.SiLU(),
        "tanh": nn.Tanh(),
        "sigmoid": nn.Sigmoid(),
        "leaky_relu": nn.LeakyReLU(0.2),
        "elu": nn.ELU(),
        "mish": nn.Mish(),
        "softplus": nn.Softplus(),
        "identity": nn.Identity(),
    }
    if name not in mapping:
        raise ValueError(f"Unknown activation: {name}")
    return mapping[name]
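
# Usage sketch (illustrative; the model code that would call this helper is not
# shown in this excerpt):
#
#     act = _get_activation("gelu")    # -> nn.GELU()
#     act = _get_activation("swish")   # alias for nn.SiLU()
#     act = _get_activation(None)      # -> nn.Identity()
#     _get_activation("prelu")         # raises ValueError: Unknown activation: prelu
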
def _get_norm(name: Optional[str], num_features: int) -> nn.Module:
    """Return a normalization module by name (case-insensitive); ``None``/"none" maps to identity."""
    if name is None:
        return nn.Identity()
    name = name.lower()
    if name == "none":
        return nn.Identity()
    if name == "batch":
        return nn.BatchNorm1d(num_features)
    if name == "layer":
        return nn.LayerNorm(num_features)
    if name == "instance":
        return nn.InstanceNorm1d(num_features)
    if name == "group":
        # Use at most 8 groups, shrinking the count until it divides num_features.
        groups = max(1, min(8, num_features))
        while num_features % groups != 0 and groups > 1:
            groups -= 1
        if groups == 1:
            # No usable grouping; fall back to LayerNorm.
            return nn.LayerNorm(num_features)
        return nn.GroupNorm(groups, num_features)
    raise ValueError(f"Unknown normalization: {name}")
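
# Examples (illustrative) of the "group" fallback logic above:
#
#     _get_norm("group", 48)   # -> nn.GroupNorm(8, 48), since 8 divides 48
#     _get_norm("group", 13)   # -> nn.LayerNorm(13): no group count in 2..8 divides 13
#     _get_norm("LAYER", 256)  # -> nn.LayerNorm(256), lookup is case-insensitive
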
def _flatten_3d_to_2d(x: torch.Tensor) -> Tuple[torch.Tensor, Optional[Tuple[int, int]]]:
    """Collapse a (batch, time, features) tensor to 2-D, returning a (batch, time) hint for restoring it."""
    if x.dim() == 3:
        b, t, f = x.shape
        return x.reshape(b * t, f), (b, t)
    return x, None
def _maybe_restore_3d(x: torch.Tensor, shape_hint: Optional[Tuple[int, int]]) -> torch.Tensor:
    """Inverse of _flatten_3d_to_2d: reshape back to (batch, time, features) when a hint is given."""
    if shape_hint is None:
        return x
    b, t = shape_hint
    f = x.shape[-1]
    return x.reshape(b, t, f)
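
# Round-trip sketch (illustrative): the flatten/restore pair lets 2-D-only layers
# such as nn.BatchNorm1d be applied to (batch, time, features) inputs:
#
#     x = torch.randn(8, 16, 32)           # (B, T, F)
#     flat, hint = _flatten_3d_to_2d(x)    # flat: (128, 32), hint: (8, 16)
#     flat = nn.BatchNorm1d(32)(flat)      # any module that expects (N, F)
#     x = _maybe_restore_3d(flat, hint)    # back to (8, 16, 32)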