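"""Small PyTorch helpers: name-based activation and normalization factories,
plus utilities for flattening (batch, time, features) tensors to 2-D and back.
"""
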
from __future__ import annotations

from typing import Optional, Tuple

import torch
import torch.nn as nn

# ---------------------------- Utilities ---------------------------- #

def _get_activation(name: Optional[str]) -> nn.Module:
    """Return a fresh activation module for ``name`` (case-insensitive)."""
    if name is None:
        return nn.Identity()
    name = name.lower()
    # Map names to factories rather than instances, so each call returns a
    # new module and unused activations are never constructed.
    mapping = {
        "relu": nn.ReLU,
        "gelu": nn.GELU,
        "silu": nn.SiLU,
        "swish": nn.SiLU,
        "tanh": nn.Tanh,
        "sigmoid": nn.Sigmoid,
        "leaky_relu": lambda: nn.LeakyReLU(0.2),
        "elu": nn.ELU,
        "mish": nn.Mish,
        "softplus": nn.Softplus,
        "identity": nn.Identity,
        "none": nn.Identity,
    }
    if name not in mapping:
        raise ValueError(f"Unknown activation: {name}")
    return mapping[name]()


def _get_norm(name: Optional[str], num_features: int) -> nn.Module:
    """Return a normalization module for ``name`` (case-insensitive)."""
    if name is None:
        return nn.Identity()
    name = name.lower()
    if name == "none":
        return nn.Identity()
    if name == "batch":
        return nn.BatchNorm1d(num_features)
    if name == "layer":
        return nn.LayerNorm(num_features)
    if name == "instance":
        return nn.InstanceNorm1d(num_features)
    if name == "group":
        # Prefer 8 groups, falling back to the largest group count <= 8
        # that divides num_features (GroupNorm requires divisibility).
        groups = max(1, min(8, num_features))
        while num_features % groups != 0 and groups > 1:
            groups -= 1
        if groups == 1:
            # A single group degenerates to per-sample normalization over
            # all features, so use LayerNorm instead.
            return nn.LayerNorm(num_features)
        return nn.GroupNorm(groups, num_features)
    raise ValueError(f"Unknown normalization: {name}")


def _flatten_3d_to_2d(x: torch.Tensor) -> Tuple[torch.Tensor, Optional[Tuple[int, int]]]:
    """Collapse a (batch, time, features) tensor to (batch * time, features).

    Returns the flattened tensor plus a (batch, time) hint for
    ``_maybe_restore_3d``; 2-D inputs pass through with a ``None`` hint.
    """
    if x.dim() == 3:
        b, t, f = x.shape
        return x.reshape(b * t, f), (b, t)
    return x, None


def _maybe_restore_3d(x: torch.Tensor, shape_hint: Optional[Tuple[int, int]]) -> torch.Tensor:
    """Undo ``_flatten_3d_to_2d`` using its (batch, time) shape hint."""
    if shape_hint is None:
        return x
    b, t = shape_hint
    f = x.shape[-1]
    return x.reshape(b, t, f)
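

# ----------------------------- Example ----------------------------- #
# A minimal usage sketch of the helpers above (illustrative only; the
# shapes and names chosen here are assumptions, not part of any API):
# build an activation and a norm by name, then round-trip a
# (batch, time, features) tensor through the flatten/restore pair.

if __name__ == "__main__":
    act = _get_activation("gelu")        # fresh nn.GELU instance
    norm = _get_norm("group", 48)        # nn.GroupNorm(8, 48); 48 % 8 == 0

    x = torch.randn(4, 10, 48)           # (batch=4, time=10, features=48)
    flat, hint = _flatten_3d_to_2d(x)    # -> (40, 48), hint == (4, 10)
    out = _maybe_restore_3d(act(norm(flat)), hint)
    assert out.shape == x.shape          # shape is preserved end to end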