from typing import Union
import logging

import torch
import torch.nn as nn
import torch.nn.functional as F
import einops
from einops.layers.torch import Rearrange

from equi_diffpo.model.diffusion.conv1d_components import (
    Downsample1d, Upsample1d, Conv1dBlock)
from equi_diffpo.model.diffusion.positional_embedding import SinusoidalPosEmb

logger = logging.getLogger(__name__)


class CrossAttention(nn.Module):
    def __init__(self, in_dim, cond_dim, out_dim):
        super().__init__()
        self.query_proj = nn.Linear(in_dim, out_dim)
        self.key_proj = nn.Linear(cond_dim, out_dim)
        self.value_proj = nn.Linear(cond_dim, out_dim)

    def forward(self, x, cond):
        # x: [batch_size, t_act, in_dim]
        # cond: [batch_size, t_obs, cond_dim]

        # Project x to queries and cond to keys/values
        query = self.query_proj(x)     # [batch_size, t_act, out_dim]
        key = self.key_proj(cond)      # [batch_size, t_obs, out_dim]
        value = self.value_proj(cond)  # [batch_size, t_obs, out_dim]

        # Compute (unscaled) dot-product attention over the conditioning tokens
        attn_weights = torch.matmul(query, key.transpose(-2, -1))  # [batch_size, t_act, t_obs]
        attn_weights = F.softmax(attn_weights, dim=-1)

        # Aggregate values with the attention weights
        attn_output = torch.matmul(attn_weights, value)  # [batch_size, t_act, out_dim]
        return attn_output
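
# A minimal smoke test for CrossAttention; a sketch with assumed sizes
# (batch=2, t_act=16, t_obs=4, and the feature dims below are illustrative,
# not values used elsewhere in this module). Call it manually to verify that
# attention maps t_obs conditioning tokens onto t_act query positions.
def _demo_cross_attention():
    attn = CrossAttention(in_dim=64, cond_dim=128, out_dim=64)
    x = torch.randn(2, 16, 64)     # [batch_size, t_act, in_dim]
    cond = torch.randn(2, 4, 128)  # [batch_size, t_obs, cond_dim]
    out = attn(x, cond)            # attends from action tokens to obs tokens
    assert out.shape == (2, 16, 64)
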
class ConditionalResidualBlock1D(nn.Module):
    def __init__(self,
            in_channels,
            out_channels,
            cond_dim,
            kernel_size=3,
            n_groups=8,
            condition_type='film'):
        super().__init__()

        self.blocks = nn.ModuleList([
            Conv1dBlock(in_channels, out_channels, kernel_size, n_groups=n_groups),
            Conv1dBlock(out_channels, out_channels, kernel_size, n_groups=n_groups),
        ])

        self.condition_type = condition_type

        cond_channels = out_channels
        if condition_type == 'film':
            # FiLM modulation https://arxiv.org/abs/1709.07871
            # predicts per-channel scale and bias
            cond_channels = out_channels * 2
            self.cond_encoder = nn.Sequential(
                nn.Mish(),
                nn.Linear(cond_dim, cond_channels),
                Rearrange('batch t -> batch t 1'),
            )
        elif condition_type == 'add':
            self.cond_encoder = nn.Sequential(
                nn.Mish(),
                nn.Linear(cond_dim, out_channels),
                Rearrange('batch t -> batch t 1'),
            )
        elif condition_type == 'cross_attention_add':
            self.cond_encoder = CrossAttention(in_channels, cond_dim, out_channels)
        elif condition_type == 'cross_attention_film':
            cond_channels = out_channels * 2
            self.cond_encoder = CrossAttention(in_channels, cond_dim, cond_channels)
        elif condition_type == 'mlp_film':
            cond_channels = out_channels * 2
            self.cond_encoder = nn.Sequential(
                nn.Mish(),
                nn.Linear(cond_dim, cond_dim),
                nn.Mish(),
                nn.Linear(cond_dim, cond_channels),
                Rearrange('batch t -> batch t 1'),
            )
        else:
            raise NotImplementedError(f"condition_type {condition_type} not implemented")

        self.out_channels = out_channels
        # make sure dimensions are compatible
        self.residual_conv = nn.Conv1d(in_channels, out_channels, 1) \
            if in_channels != out_channels else nn.Identity()

    def forward(self, x, cond=None):
        '''
        x : [ batch_size x in_channels x horizon ]
        cond : [ batch_size x cond_dim ] for film/add/mlp_film,
               [ batch_size x t_obs x cond_dim ] for cross-attention conditioning

        returns:
        out : [ batch_size x out_channels x horizon ]
        '''
        out = self.blocks[0](x)
        if cond is not None:
            if self.condition_type == 'film':
                embed = self.cond_encoder(cond)
                embed = embed.reshape(embed.shape[0], 2, self.out_channels, 1)
                scale = embed[:, 0, ...]
                bias = embed[:, 1, ...]
                out = scale * out + bias
            elif self.condition_type == 'add':
                embed = self.cond_encoder(cond)
                out = out + embed
            elif self.condition_type == 'cross_attention_add':
                # attend from action tokens to conditioning tokens
                embed = self.cond_encoder(x.permute(0, 2, 1), cond)
                embed = embed.permute(0, 2, 1)  # [batch_size, out_channels, horizon]
                out = out + embed
            elif self.condition_type == 'cross_attention_film':
                embed = self.cond_encoder(x.permute(0, 2, 1), cond)
                embed = embed.permute(0, 2, 1)  # [batch_size, 2*out_channels, horizon]
                embed = embed.reshape(embed.shape[0], 2, self.out_channels, -1)
                scale = embed[:, 0, ...]
                bias = embed[:, 1, ...]
                out = scale * out + bias
            elif self.condition_type == 'mlp_film':
                embed = self.cond_encoder(cond)
                embed = embed.reshape(embed.shape[0], 2, self.out_channels, -1)
                scale = embed[:, 0, ...]
                bias = embed[:, 1, ...]
                out = scale * out + bias
            else:
                raise NotImplementedError(f"condition_type {self.condition_type} not implemented")
        out = self.blocks[1](out)
        out = out + self.residual_conv(x)
        return out
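
# A minimal smoke test for ConditionalResidualBlock1D with the default FiLM
# conditioning; the sizes are assumptions for illustration, not repo defaults.
def _demo_conditional_residual_block():
    block = ConditionalResidualBlock1D(in_channels=32, out_channels=64, cond_dim=128)
    x = torch.randn(2, 32, 16)  # [batch_size, in_channels, horizon]
    cond = torch.randn(2, 128)  # [batch_size, cond_dim]
    out = block(x, cond)        # FiLM applies a per-channel scale/bias from cond
    assert out.shape == (2, 64, 16)
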
class ConditionalUnet1D(nn.Module):
    def __init__(self,
        input_dim,
        local_cond_dim=None,
        global_cond_dim=None,
        diffusion_step_embed_dim=256,
        down_dims=[256, 512, 1024],
        kernel_size=3,
        n_groups=8,
        condition_type='film',
        use_down_condition=True,
        use_mid_condition=True,
        use_up_condition=True,
        ):
        super().__init__()
        self.condition_type = condition_type
        self.use_down_condition = use_down_condition
        self.use_mid_condition = use_mid_condition
        self.use_up_condition = use_up_condition

        all_dims = [input_dim] + list(down_dims)
        start_dim = down_dims[0]

        dsed = diffusion_step_embed_dim
        diffusion_step_encoder = nn.Sequential(
            SinusoidalPosEmb(dsed),
            nn.Linear(dsed, dsed * 4),
            nn.Mish(),
            nn.Linear(dsed * 4, dsed),
        )
        cond_dim = dsed
        if global_cond_dim is not None:
            cond_dim += global_cond_dim

        in_out = list(zip(all_dims[:-1], all_dims[1:]))

        local_cond_encoder = None
        if local_cond_dim is not None:
            _, dim_out = in_out[0]
            dim_in = local_cond_dim
            local_cond_encoder = nn.ModuleList([
                # down encoder
                ConditionalResidualBlock1D(
                    dim_in, dim_out, cond_dim=cond_dim,
                    kernel_size=kernel_size, n_groups=n_groups,
                    condition_type=condition_type),
                # up encoder
                ConditionalResidualBlock1D(
                    dim_in, dim_out, cond_dim=cond_dim,
                    kernel_size=kernel_size, n_groups=n_groups,
                    condition_type=condition_type)
            ])

        mid_dim = all_dims[-1]
        self.mid_modules = nn.ModuleList([
            ConditionalResidualBlock1D(
                mid_dim, mid_dim, cond_dim=cond_dim,
                kernel_size=kernel_size, n_groups=n_groups,
                condition_type=condition_type),
            ConditionalResidualBlock1D(
                mid_dim, mid_dim, cond_dim=cond_dim,
                kernel_size=kernel_size, n_groups=n_groups,
                condition_type=condition_type),
        ])

        down_modules = nn.ModuleList([])
        for ind, (dim_in, dim_out) in enumerate(in_out):
            is_last = ind >= (len(in_out) - 1)
            down_modules.append(nn.ModuleList([
                ConditionalResidualBlock1D(
                    dim_in, dim_out, cond_dim=cond_dim,
                    kernel_size=kernel_size, n_groups=n_groups,
                    condition_type=condition_type),
                ConditionalResidualBlock1D(
                    dim_out, dim_out, cond_dim=cond_dim,
                    kernel_size=kernel_size, n_groups=n_groups,
                    condition_type=condition_type),
                Downsample1d(dim_out) if not is_last else nn.Identity()
            ]))

        up_modules = nn.ModuleList([])
        for ind, (dim_in, dim_out) in enumerate(reversed(in_out[1:])):
            is_last = ind >= (len(in_out) - 1)
            up_modules.append(nn.ModuleList([
                ConditionalResidualBlock1D(
                    dim_out * 2, dim_in, cond_dim=cond_dim,
                    kernel_size=kernel_size, n_groups=n_groups,
                    condition_type=condition_type),
                ConditionalResidualBlock1D(
                    dim_in, dim_in, cond_dim=cond_dim,
                    kernel_size=kernel_size, n_groups=n_groups,
                    condition_type=condition_type),
                Upsample1d(dim_in) if not is_last else nn.Identity()
            ]))

        final_conv = nn.Sequential(
            Conv1dBlock(start_dim, start_dim, kernel_size=kernel_size),
            nn.Conv1d(start_dim, input_dim, 1),
        )

        self.diffusion_step_encoder = diffusion_step_encoder
        self.local_cond_encoder = local_cond_encoder
        self.up_modules = up_modules
        self.down_modules = down_modules
        self.final_conv = final_conv

        logger.info(
            "number of parameters: %e", sum(p.numel() for p in self.parameters())
        )

    def forward(self,
            sample: torch.Tensor,
            timestep: Union[torch.Tensor, float, int],
            local_cond=None, global_cond=None, **kwargs):
        """
        sample: (B,T,input_dim)
        timestep: (B,) or int, diffusion step
        local_cond: (B,T,local_cond_dim)
        global_cond: (B,global_cond_dim)
        output: (B,T,input_dim)
        """
        sample = einops.rearrange(sample, 'b h t -> b t h')

        # 1. time
        timesteps = timestep
        if not torch.is_tensor(timesteps):
            # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
            timesteps = torch.tensor([timesteps], dtype=torch.long, device=sample.device)
        elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0:
            timesteps = timesteps[None].to(sample.device)
        # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
        timesteps = timesteps.expand(sample.shape[0])

        timestep_embed = self.diffusion_step_encoder(timesteps)

        # without global conditioning, the timestep embedding alone is the condition
        global_feature = timestep_embed
        if global_cond is not None:
            if self.condition_type.startswith('cross_attention'):
                # cross-attention conditioning expects token-wise cond, so tile
                # the timestep embedding across the observation tokens
                timestep_embed = timestep_embed.unsqueeze(1).expand(-1, global_cond.shape[1], -1)
            global_feature = torch.cat([timestep_embed, global_cond], dim=-1)

        # encode local features
        h_local = list()
        if local_cond is not None:
            local_cond = einops.rearrange(local_cond, 'b h t -> b t h')
            resnet, resnet2 = self.local_cond_encoder
            x = resnet(local_cond, global_feature)
            h_local.append(x)
            x = resnet2(local_cond, global_feature)
            h_local.append(x)

        x = sample
        h = []
        for idx, (resnet, resnet2, downsample) in enumerate(self.down_modules):
            if self.use_down_condition:
                x = resnet(x, global_feature)
                if idx == 0 and len(h_local) > 0:
                    x = x + h_local[0]
                x = resnet2(x, global_feature)
            else:
                x = resnet(x)
                if idx == 0 and len(h_local) > 0:
                    x = x + h_local[0]
                x = resnet2(x)
            h.append(x)
            x = downsample(x)

        for mid_module in self.mid_modules:
            if self.use_mid_condition:
                x = mid_module(x, global_feature)
            else:
                x = mid_module(x)

        # NOTE: idx == len(self.up_modules) below is never true (idx stops at
        # len - 1), so h_local[1] is never added; the intended condition was
        # likely idx == len(self.up_modules) - 1, but changing it would alter
        # behavior for checkpoints trained with this code, so it is kept as is.
        for idx, (resnet, resnet2, upsample) in enumerate(self.up_modules):
            x = torch.cat((x, h.pop()), dim=1)
            if self.use_up_condition:
                x = resnet(x, global_feature)
                if idx == len(self.up_modules) and len(h_local) > 0:
                    x = x + h_local[1]
                x = resnet2(x, global_feature)
            else:
                x = resnet(x)
                if idx == len(self.up_modules) and len(h_local) > 0:
                    x = x + h_local[1]
                x = resnet2(x)
            x = upsample(x)

        x = self.final_conv(x)

        x = einops.rearrange(x, 'b t h -> b h t')
        return x
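
# A minimal end-to-end smoke test, assuming FiLM conditioning and illustrative
# sizes (input_dim=10, global_cond_dim=66, etc. are assumptions, not repo
# defaults). The horizon must be divisible by 4 here: two Downsample1d stages
# halve it twice before the two Upsample1d stages restore it.
if __name__ == "__main__":
    model = ConditionalUnet1D(
        input_dim=10,
        global_cond_dim=66,
        diffusion_step_embed_dim=128,
        down_dims=[64, 128, 256],
    )
    sample = torch.randn(2, 16, 10)         # (B, T, input_dim)
    timestep = torch.randint(0, 100, (2,))  # (B,)
    global_cond = torch.randn(2, 66)        # (B, global_cond_dim)
    out = model(sample, timestep, global_cond=global_cond)
    assert out.shape == sample.shape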