# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import math
from functools import partial

import torch
import torch.nn as nn

from src.models.utils.modules import ACBlock as Block
from src.models.utils.modules import build_action_block_causal_attention_mask
from src.utils.tensors import trunc_normal_


class VisionTransformerPredictorAC(nn.Module):
    """Action Conditioned Vision Transformer Predictor"""

    def __init__(
        self,
        img_size=(224, 224),
        patch_size=16,
        num_frames=1,
        tubelet_size=2,
        embed_dim=768,
        predictor_embed_dim=1024,
        depth=24,
        num_heads=16,
        mlp_ratio=4.0,
        qkv_bias=True,
        qk_scale=None,
        drop_rate=0.0,
        attn_drop_rate=0.0,
        drop_path_rate=0.0,
        norm_layer=nn.LayerNorm,
        init_std=0.02,
        uniform_power=True,
        use_silu=False,
        wide_silu=True,
        is_frame_causal=True,
        use_activation_checkpointing=False,
        use_rope=True,
        action_embed_dim=7,
        use_extrinsics=False,
        **kwargs
    ):
        super().__init__()
        self.is_frame_causal = is_frame_causal
        self.use_extrinsics = use_extrinsics

        # Map input to predictor dimension
        self.predictor_embed = nn.Linear(embed_dim, predictor_embed_dim, bias=True)
        self.action_encoder = nn.Linear(action_embed_dim, predictor_embed_dim, bias=True)
        self.state_encoder = nn.Linear(action_embed_dim, predictor_embed_dim, bias=True)
        self.extrinsics_encoder = nn.Linear(action_embed_dim - 1, predictor_embed_dim, bias=True)

        # Determine positional embedding
        if type(img_size) is int:
            img_size = (img_size, img_size)
        self.img_height, self.img_width = img_size
        self.patch_size = patch_size
        # --
        self.num_frames = num_frames
        self.tubelet_size = tubelet_size
        self.is_video = num_frames > 1
        self.grid_height = img_size[0] // self.patch_size
        self.grid_width = img_size[1] // self.patch_size
        self.use_activation_checkpointing = use_activation_checkpointing

        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]  # stochastic depth decay rule

        # Position embedding
        self.uniform_power = uniform_power

        # Attention Blocks
        self.use_rope = use_rope
        self.predictor_blocks = nn.ModuleList(
            [
                Block(
                    use_rope=use_rope,
                    grid_size=self.grid_height,
                    dim=predictor_embed_dim,
                    num_heads=num_heads,
                    mlp_ratio=mlp_ratio,
                    qkv_bias=qkv_bias,
                    qk_scale=qk_scale,
                    drop=drop_rate,
                    act_layer=nn.SiLU if use_silu else nn.GELU,
                    wide_silu=wide_silu,
                    attn_drop=attn_drop_rate,
                    drop_path=dpr[i],
                    norm_layer=norm_layer,
                )
                for i in range(depth)
            ]
        )

        # Normalize & project back to input dimension
        self.predictor_norm = norm_layer(predictor_embed_dim)
        self.predictor_proj = nn.Linear(predictor_embed_dim, embed_dim, bias=True)

        # ------ initialize weights
        self.init_std = init_std
        self.apply(self._init_weights)
        self._rescale_blocks()

        attn_mask = None
        if self.is_frame_causal:
            grid_depth = self.num_frames // self.tubelet_size
            grid_height = self.img_height // self.patch_size
            grid_width = self.img_width // self.patch_size
            attn_mask = build_action_block_causal_attention_mask(
                grid_depth, grid_height, grid_width, add_tokens=3 if use_extrinsics else 2
            )
        self.attn_mask = attn_mask

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=self.init_std)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def _rescale_blocks(self):
        def rescale(param, layer_id):
            param.div_(math.sqrt(2.0 * layer_id))

        for layer_id, layer in enumerate(self.predictor_blocks):
            rescale(layer.attn.proj.weight.data, layer_id + 1)
            rescale(layer.mlp.fc2.weight.data, layer_id + 1)

    def forward(self, x, actions, states, extrinsics=None):
        """
        :param x: context tokens [B, T*H*W, embed_dim]
        :param actions: action vectors, one per frame group [B, T, action_embed_dim]
        :param states: state vectors, one per frame group [B, T, action_embed_dim]
        :param extrinsics: optional extrinsics vectors [B, T, action_embed_dim - 1],
            only used when use_extrinsics=True
        :return: predicted tokens [B, T*H*W, embed_dim]
        """
        # Map tokens to predictor dimensions
        x = self.predictor_embed(x)
        B, N_ctxt, D = x.size()
        T = N_ctxt // (self.grid_height * self.grid_width)

        # Interleave action tokens
        s = self.state_encoder(states).unsqueeze(2)
        a = self.action_encoder(actions).unsqueeze(2)
        x = x.view(B, T, self.grid_height * self.grid_width, D)  # [B, T, H*W, D]
        if self.use_extrinsics:
            e = self.extrinsics_encoder(extrinsics).unsqueeze(2)
            x = torch.cat([a, s, e, x], dim=2).flatten(1, 2)  # [B, T*(H*W+3), D]
        else:
            x = torch.cat([a, s, x], dim=2).flatten(1, 2)  # [B, T*(H*W+2), D]
        cond_tokens = 3 if self.use_extrinsics else 2

        # Slice the block-causal mask to the current sequence length
        # (self.attn_mask is None when is_frame_causal=False)
        attn_mask = self.attn_mask
        if attn_mask is not None:
            attn_mask = attn_mask[: x.size(1), : x.size(1)].to(x.device, non_blocking=True)

        # Fwd prop
        for i, blk in enumerate(self.predictor_blocks):
            if self.use_activation_checkpointing:
                x = torch.utils.checkpoint.checkpoint(
                    blk,
                    x,
                    mask=None,
                    attn_mask=attn_mask,
                    T=T,
                    H=self.grid_height,
                    W=self.grid_width,
                    action_tokens=cond_tokens,
                    use_reentrant=False,
                )
            else:
                x = blk(
                    x,
                    mask=None,
                    attn_mask=attn_mask,
                    T=T,
                    H=self.grid_height,
                    W=self.grid_width,
                    action_tokens=cond_tokens,
                )

        # Split out action and frame tokens
        x = x.view(B, T, cond_tokens + self.grid_height * self.grid_width, D)  # [B, T, K+H*W, D]
        x = x[:, :, cond_tokens:, :].flatten(1, 2)

        x = self.predictor_norm(x)
        x = self.predictor_proj(x)

        return x


def vit_ac_predictor(**kwargs):
    model = VisionTransformerPredictorAC(
        mlp_ratio=4, qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs
    )
    return model
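
# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative, not part of the original module): builds
# a small predictor and runs a dummy forward pass. The hyper-parameters and
# tensor shapes below are assumptions chosen to match the constructor defaults
# (patch_size=16, embed_dim=768, tubelet_size=2, action_embed_dim=7), not a
# training configuration.
if __name__ == "__main__":
    model = vit_ac_predictor(num_frames=4, depth=2)

    B = 2
    T = 4 // 2                      # num_frames // tubelet_size
    N = T * (224 // 16) ** 2        # context tokens per clip (T * H * W)
    x = torch.randn(B, N, 768)      # encoder features in embed_dim
    actions = torch.randn(B, T, 7)  # one action vector per frame group
    states = torch.randn(B, T, 7)   # one state vector per frame group

    out = model(x, actions, states)
    print(out.shape)                # expected: torch.Size([2, 392, 768])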