# equi_diffpo/model/diffusion/transformer_for_diffusion.py
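"""Transformer backbone for diffusion policies.

Provides ``TransformerForDiffusion``, a minGPT-style encoder/decoder transformer
that denoises an action-trajectory ``sample`` conditioned on the diffusion
timestep and, optionally, on encoded observations.
"""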
from typing import Union, Optional, Tuple
import logging
import torch
import torch.nn as nn
from equi_diffpo.model.diffusion.positional_embedding import SinusoidalPosEmb
from equi_diffpo.model.common.module_attr_mixin import ModuleAttrMixin
logger = logging.getLogger(__name__)


class TransformerForDiffusion(ModuleAttrMixin):
    def __init__(self,
            input_dim: int,
            output_dim: int,
            horizon: int,
            n_obs_steps: Optional[int] = None,
            cond_dim: int = 0,
            n_layer: int = 12,
            n_head: int = 12,
            n_emb: int = 768,
            p_drop_emb: float = 0.1,
            p_drop_attn: float = 0.1,
            causal_attn: bool = False,
            time_as_cond: bool = True,
            obs_as_cond: bool = False,
            n_cond_layers: int = 0
        ) -> None:
        super().__init__()

        # compute number of tokens for main trunk and condition encoder
        if n_obs_steps is None:
            n_obs_steps = horizon

        T = horizon
        T_cond = 1
        if not time_as_cond:
            T += 1
            T_cond -= 1
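        # note: obs_as_cond is re-derived from cond_dim below, so the constructor
        # argument of the same name is effectively overridden whenever cond_dim > 0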
        obs_as_cond = cond_dim > 0
        if obs_as_cond:
            assert time_as_cond
            T_cond += n_obs_steps

        # input embedding stem
        self.input_emb = nn.Linear(input_dim, n_emb)
        self.pos_emb = nn.Parameter(torch.zeros(1, T, n_emb))
        self.drop = nn.Dropout(p_drop_emb)

        # cond encoder
        self.time_emb = SinusoidalPosEmb(n_emb)
        self.cond_obs_emb = None
        if obs_as_cond:
            self.cond_obs_emb = nn.Linear(cond_dim, n_emb)

        self.cond_pos_emb = None
        self.encoder = None
        self.decoder = None
        encoder_only = False
        if T_cond > 0:
            self.cond_pos_emb = nn.Parameter(torch.zeros(1, T_cond, n_emb))
            if n_cond_layers > 0:
                encoder_layer = nn.TransformerEncoderLayer(
                    d_model=n_emb,
                    nhead=n_head,
                    dim_feedforward=4*n_emb,
                    dropout=p_drop_attn,
                    activation='gelu',
                    batch_first=True,
                    norm_first=True
                )
                self.encoder = nn.TransformerEncoder(
                    encoder_layer=encoder_layer,
                    num_layers=n_cond_layers
                )
            else:
                self.encoder = nn.Sequential(
                    nn.Linear(n_emb, 4 * n_emb),
                    nn.Mish(),
                    nn.Linear(4 * n_emb, n_emb)
                )
            # decoder
            decoder_layer = nn.TransformerDecoderLayer(
                d_model=n_emb,
                nhead=n_head,
                dim_feedforward=4*n_emb,
                dropout=p_drop_attn,
                activation='gelu',
                batch_first=True,
                norm_first=True  # important for stability
            )
            self.decoder = nn.TransformerDecoder(
                decoder_layer=decoder_layer,
                num_layers=n_layer
            )
        else:
            # encoder only BERT
            encoder_only = True

            encoder_layer = nn.TransformerEncoderLayer(
                d_model=n_emb,
                nhead=n_head,
                dim_feedforward=4*n_emb,
                dropout=p_drop_attn,
                activation='gelu',
                batch_first=True,
                norm_first=True
            )
            self.encoder = nn.TransformerEncoder(
                encoder_layer=encoder_layer,
                num_layers=n_layer
            )

        # attention mask
        if causal_attn:
            # causal mask to ensure that attention is only applied to the left in the input sequence
            # torch.nn.Transformer uses additive mask as opposed to multiplicative mask in minGPT
            # therefore, the upper triangle should be -inf and others (including diag) should be 0.
            sz = T
            mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
            mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
            self.register_buffer("mask", mask)

            if time_as_cond and obs_as_cond:
                S = T_cond
                t, s = torch.meshgrid(
                    torch.arange(T),
                    torch.arange(S),
                    indexing='ij'
                )
                mask = t >= (s-1)  # add one dimension since time is the first token in cond
                mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
                self.register_buffer('memory_mask', mask)
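                # memory_mask has shape (T, T_cond): decoder step t may attend to cond
                # token s when t >= s-1, i.e. the time token (s=0) is visible to every
                # step and obs token s becomes visible from step s-1 onward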
            else:
                self.memory_mask = None
        else:
            self.mask = None
            self.memory_mask = None

        # decoder head
        self.ln_f = nn.LayerNorm(n_emb)
        self.head = nn.Linear(n_emb, output_dim)

        # constants
        self.T = T
        self.T_cond = T_cond
        self.horizon = horizon
        self.time_as_cond = time_as_cond
        self.obs_as_cond = obs_as_cond
        self.encoder_only = encoder_only

        # init
        self.apply(self._init_weights)
        logger.info(
            "number of parameters: %e", sum(p.numel() for p in self.parameters())
        )

    def _init_weights(self, module):
        ignore_types = (nn.Dropout,
            SinusoidalPosEmb,
            nn.TransformerEncoderLayer,
            nn.TransformerDecoderLayer,
            nn.TransformerEncoder,
            nn.TransformerDecoder,
            nn.ModuleList,
            nn.Mish,
            nn.Sequential)
        if isinstance(module, (nn.Linear, nn.Embedding)):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if isinstance(module, nn.Linear) and module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.MultiheadAttention):
            weight_names = [
                'in_proj_weight', 'q_proj_weight', 'k_proj_weight', 'v_proj_weight']
            for name in weight_names:
                weight = getattr(module, name)
                if weight is not None:
                    torch.nn.init.normal_(weight, mean=0.0, std=0.02)
            bias_names = ['in_proj_bias', 'bias_k', 'bias_v']
            for name in bias_names:
                bias = getattr(module, name)
                if bias is not None:
                    torch.nn.init.zeros_(bias)
        elif isinstance(module, nn.LayerNorm):
            torch.nn.init.zeros_(module.bias)
            torch.nn.init.ones_(module.weight)
        elif isinstance(module, TransformerForDiffusion):
            torch.nn.init.normal_(module.pos_emb, mean=0.0, std=0.02)
            if module.cond_obs_emb is not None:
                torch.nn.init.normal_(module.cond_pos_emb, mean=0.0, std=0.02)
        elif isinstance(module, ignore_types):
            # no param
            pass
        else:
            raise RuntimeError("Unaccounted module {}".format(module))

    def get_optim_groups(self, weight_decay: float = 1e-3):
        """
        This long function is unfortunately doing something very simple and is being very defensive:
        we are separating out all parameters of the model into two buckets: those that will experience
        weight decay for regularization and those that won't (biases, and layernorm/embedding weights),
        and returning the corresponding parameter groups for the PyTorch optimizer.
        """
        # separate out all parameters to those that will and won't experience regularizing weight decay
        decay = set()
        no_decay = set()
        whitelist_weight_modules = (torch.nn.Linear, torch.nn.MultiheadAttention)
        blacklist_weight_modules = (torch.nn.LayerNorm, torch.nn.Embedding)
        for mn, m in self.named_modules():
            for pn, p in m.named_parameters():
                fpn = "%s.%s" % (mn, pn) if mn else pn  # full param name
                if pn.endswith("bias"):
                    # all biases will not be decayed
                    no_decay.add(fpn)
                elif pn.startswith("bias"):
                    # MultiheadAttention bias starts with "bias"
                    no_decay.add(fpn)
                elif pn.endswith("weight") and isinstance(m, whitelist_weight_modules):
                    # weights of whitelist modules will be weight decayed
                    decay.add(fpn)
                elif pn.endswith("weight") and isinstance(m, blacklist_weight_modules):
                    # weights of blacklist modules will NOT be weight decayed
                    no_decay.add(fpn)

        # special case the position embedding parameter in the root GPT module as not decayed
        no_decay.add("pos_emb")
        no_decay.add("_dummy_variable")
        if self.cond_pos_emb is not None:
            no_decay.add("cond_pos_emb")

        # validate that we considered every parameter
        param_dict = {pn: p for pn, p in self.named_parameters()}
        inter_params = decay & no_decay
        union_params = decay | no_decay
        assert (
            len(inter_params) == 0
        ), "parameters %s made it into both decay/no_decay sets!" % (str(inter_params),)
        assert (
            len(param_dict.keys() - union_params) == 0
        ), "parameters %s were not separated into either decay/no_decay set!" % (
            str(param_dict.keys() - union_params),
        )

        # create the pytorch optimizer parameter groups
        optim_groups = [
            {
                "params": [param_dict[pn] for pn in sorted(list(decay))],
                "weight_decay": weight_decay,
            },
            {
                "params": [param_dict[pn] for pn in sorted(list(no_decay))],
                "weight_decay": 0.0,
            },
        ]
        return optim_groups
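
    # Usage sketch (illustrative): the groups returned above plug straight into any
    # torch optimizer, e.g. torch.optim.AdamW(model.get_optim_groups(1e-3), lr=1e-4).
    # configure_optimizers below wraps exactly this call.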

    def configure_optimizers(self,
            learning_rate: float = 1e-4,
            weight_decay: float = 1e-3,
            betas: Tuple[float, float] = (0.9, 0.95)):
        optim_groups = self.get_optim_groups(weight_decay=weight_decay)
        optimizer = torch.optim.AdamW(
            optim_groups, lr=learning_rate, betas=betas
        )
        return optimizer

    def forward(self,
            sample: torch.Tensor,
            timestep: Union[torch.Tensor, float, int],
            cond: Optional[torch.Tensor] = None, **kwargs):
        """
        sample: (B,T,input_dim)
        timestep: (B,) or int, diffusion step
        cond: (B,T',cond_dim)
        output: (B,T,output_dim)
        """
        # 1. time
        timesteps = timestep
        if not torch.is_tensor(timesteps):
            # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
            timesteps = torch.tensor([timesteps], dtype=torch.long, device=sample.device)
        elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0:
            timesteps = timesteps[None].to(sample.device)
        # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
        timesteps = timesteps.expand(sample.shape[0])
        time_emb = self.time_emb(timesteps).unsqueeze(1)
        # (B,1,n_emb)

        # process input
        input_emb = self.input_emb(sample)

        if self.encoder_only:
            # BERT
            token_embeddings = torch.cat([time_emb, input_emb], dim=1)
            t = token_embeddings.shape[1]
            position_embeddings = self.pos_emb[
                :, :t, :
            ]  # each position maps to a (learnable) vector
            x = self.drop(token_embeddings + position_embeddings)
            # (B,T+1,n_emb)
            x = self.encoder(src=x, mask=self.mask)
            # (B,T+1,n_emb)
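            # discard the leading time token so only the T action tokens remain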
            x = x[:, 1:, :]
            # (B,T,n_emb)
        else:
            # encoder
            cond_embeddings = time_emb
            if self.obs_as_cond:
                cond_obs_emb = self.cond_obs_emb(cond)
                # (B,To,n_emb)
                cond_embeddings = torch.cat([cond_embeddings, cond_obs_emb], dim=1)
            tc = cond_embeddings.shape[1]
            position_embeddings = self.cond_pos_emb[
                :, :tc, :
            ]  # each position maps to a (learnable) vector
            x = self.drop(cond_embeddings + position_embeddings)
            x = self.encoder(x)
            memory = x
            # (B,T_cond,n_emb)

            # decoder
            token_embeddings = input_emb
            t = token_embeddings.shape[1]
            position_embeddings = self.pos_emb[
                :, :t, :
            ]  # each position maps to a (learnable) vector
            x = self.drop(token_embeddings + position_embeddings)
            # (B,T,n_emb)
            x = self.decoder(
                tgt=x,
                memory=memory,
                tgt_mask=self.mask,
                memory_mask=self.memory_mask
            )
            # (B,T,n_emb)

        # head
        x = self.ln_f(x)
        x = self.head(x)
        # (B,T,n_out)
        return x


def test():
    # GPT with time embedding
    transformer = TransformerForDiffusion(
        input_dim=16,
        output_dim=16,
        horizon=8,
        n_obs_steps=4,
        # cond_dim=10,
        causal_attn=True,
        # time_as_cond=False,
        # n_cond_layers=4
    )
    opt = transformer.configure_optimizers()

    timestep = torch.tensor(0)
    sample = torch.zeros((4, 8, 16))
    out = transformer(sample, timestep)

    # GPT with time embedding and obs cond
    transformer = TransformerForDiffusion(
        input_dim=16,
        output_dim=16,
        horizon=8,
        n_obs_steps=4,
        cond_dim=10,
        causal_attn=True,
        # time_as_cond=False,
        # n_cond_layers=4
    )
    opt = transformer.configure_optimizers()

    timestep = torch.tensor(0)
    sample = torch.zeros((4, 8, 16))
    cond = torch.zeros((4, 4, 10))
    out = transformer(sample, timestep, cond)

    # GPT with time embedding and obs cond and encoder
    transformer = TransformerForDiffusion(
        input_dim=16,
        output_dim=16,
        horizon=8,
        n_obs_steps=4,
        cond_dim=10,
        causal_attn=True,
        # time_as_cond=False,
        n_cond_layers=4
    )
    opt = transformer.configure_optimizers()

    timestep = torch.tensor(0)
    sample = torch.zeros((4, 8, 16))
    cond = torch.zeros((4, 4, 10))
    out = transformer(sample, timestep, cond)

    # BERT with time embedding token
    transformer = TransformerForDiffusion(
        input_dim=16,
        output_dim=16,
        horizon=8,
        n_obs_steps=4,
        # cond_dim=10,
        # causal_attn=True,
        time_as_cond=False,
        # n_cond_layers=4
    )
    opt = transformer.configure_optimizers()

    timestep = torch.tensor(0)
    sample = torch.zeros((4, 8, 16))
    out = transformer(sample, timestep)
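

# Convenience entry point so the smoke tests above can be run directly,
# e.g. `python -m equi_diffpo.model.diffusion.transformer_for_diffusion`.
if __name__ == "__main__":
    test()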