from typing import Dict
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange, reduce
from diffusers.schedulers.scheduling_ddpm import DDPMScheduler
from termcolor import cprint
import copy
import time
import pytorch3d.ops as torch3d_ops

from equi_diffpo.model.common.module_attr_mixin import ModuleAttrMixin
from equi_diffpo.model.common.normalizer import LinearNormalizer
from equi_diffpo.model.diffusion.dp3_conditional_unet1d import ConditionalUnet1D
from equi_diffpo.model.diffusion.mask_generator import LowdimMaskGenerator
from equi_diffpo.common.pytorch_util import dict_apply
from equi_diffpo.model.vision.pointnet_extractor import DP3Encoder


class BasePolicy(ModuleAttrMixin):
    """Abstract policy interface: action prediction, reset, and normalizer setup."""
    # init accepts keyword argument shape_meta, see config/task/*_image.yaml

    def predict_action(self, obs_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """
        obs_dict:
            str: B,To,*
        return: B,Ta,Da
        """
        raise NotImplementedError()

    # reset state for stateful policies
    def reset(self):
        pass

    # ========== training ===========
    # no standard training interface except setting normalizer
    def set_normalizer(self, normalizer: LinearNormalizer):
        raise NotImplementedError()


class DP3(BasePolicy):
    """Point-cloud-conditioned diffusion policy over action trajectories.

    Observations are encoded by ``DP3Encoder`` and fed to a
    ``ConditionalUnet1D`` either as a global conditioning vector/sequence
    (``obs_as_global_cond=True``) or by inpainting the observation features
    into the denoised trajectory (``obs_as_global_cond=False``).
    """

    # Image observations this policy never consumes. They are removed (from a
    # copy, never from the caller's dict) before normalization so training and
    # inference see the same key set.
    _UNUSED_OBS_KEYS = ('robot0_eye_in_hand_image', 'agentview_image')

    def __init__(self,
            shape_meta: dict,
            noise_scheduler: DDPMScheduler,
            horizon,
            n_action_steps,
            n_obs_steps,
            num_inference_steps=None,
            obs_as_global_cond=True,
            diffusion_step_embed_dim=256,
            down_dims=(256, 512, 1024),
            kernel_size=5,
            n_groups=8,
            condition_type="film",
            use_down_condition=True,
            use_mid_condition=True,
            use_up_condition=True,
            encoder_output_dim=256,
            crop_shape=None,
            use_pc_color=False,
            pointnet_type="pointnet",
            pointcloud_encoder_cfg=None,
            # parameters passed to step
            **kwargs):
        """Build the observation encoder, the 1D conditional UNet and helpers.

        Args:
            shape_meta: dict with ``['action']['shape']`` and ``['obs']``
                (a mapping of obs key -> meta dict containing ``'shape'``).
            noise_scheduler: diffusers scheduler used for training and sampling.
            horizon: length T of the predicted trajectory.
            n_action_steps: number of actions returned by ``predict_action``.
            n_obs_steps: number of observation steps To used for conditioning.
            num_inference_steps: sampling steps; defaults to the scheduler's
                ``config.num_train_timesteps``.
            obs_as_global_cond: condition via global features (True) or via
                inpainting observation features into the trajectory (False).
            **kwargs: stored in ``self.kwargs`` and forwarded to
                ``conditional_sample`` by ``predict_action``.

        Raises:
            NotImplementedError: if the action shape has rank other than 1 or 2.
        """
        super().__init__()

        self.condition_type = condition_type

        # parse shape_meta
        action_shape = shape_meta['action']['shape']
        self.action_shape = action_shape
        if len(action_shape) == 1:
            action_dim = action_shape[0]
        elif len(action_shape) == 2:  # use multiple hands
            action_dim = action_shape[0] * action_shape[1]
        else:
            raise NotImplementedError(f"Unsupported action shape {action_shape}")

        obs_shape_meta = shape_meta['obs']
        obs_dict = dict_apply(obs_shape_meta, lambda x: x['shape'])

        obs_encoder = DP3Encoder(observation_space=obs_dict,
                                 img_crop_shape=crop_shape,
                                 out_channel=encoder_output_dim,
                                 pointcloud_encoder_cfg=pointcloud_encoder_cfg,
                                 use_pc_color=use_pc_color,
                                 pointnet_type=pointnet_type,
                                 )

        # create diffusion model
        obs_feature_dim = obs_encoder.output_shape()
        input_dim = action_dim + obs_feature_dim
        global_cond_dim = None
        if obs_as_global_cond:
            input_dim = action_dim
            if "cross_attention" in self.condition_type:
                # features are passed as a (B, To, Do) sequence
                global_cond_dim = obs_feature_dim
            else:
                # features of all To steps are flattened into one vector
                global_cond_dim = obs_feature_dim * n_obs_steps

        self.use_pc_color = use_pc_color
        self.pointnet_type = pointnet_type
        cprint(f"[DiffusionUnetHybridPointcloudPolicy] use_pc_color: {self.use_pc_color}", "yellow")
        cprint(f"[DiffusionUnetHybridPointcloudPolicy] pointnet_type: {self.pointnet_type}", "yellow")

        model = ConditionalUnet1D(
            input_dim=input_dim,
            local_cond_dim=None,
            global_cond_dim=global_cond_dim,
            diffusion_step_embed_dim=diffusion_step_embed_dim,
            down_dims=down_dims,
            kernel_size=kernel_size,
            n_groups=n_groups,
            condition_type=condition_type,
            use_down_condition=use_down_condition,
            use_mid_condition=use_mid_condition,
            use_up_condition=use_up_condition,
        )

        self.obs_encoder = obs_encoder
        self.model = model
        self.noise_scheduler = noise_scheduler

        # NOTE(review): not used in this module; kept in case other code or
        # checkpoints reference it.
        self.noise_scheduler_pc = copy.deepcopy(noise_scheduler)
        self.mask_generator = LowdimMaskGenerator(
            action_dim=action_dim,
            obs_dim=0 if obs_as_global_cond else obs_feature_dim,
            max_n_obs_steps=n_obs_steps,
            fix_obs_steps=True,
            action_visible=False
        )

        self.normalizer = LinearNormalizer()
        self.horizon = horizon
        self.obs_feature_dim = obs_feature_dim
        self.action_dim = action_dim
        self.n_action_steps = n_action_steps
        self.n_obs_steps = n_obs_steps
        self.obs_as_global_cond = obs_as_global_cond
        self.kwargs = kwargs

        if num_inference_steps is None:
            num_inference_steps = noise_scheduler.config.num_train_timesteps
        self.num_inference_steps = num_inference_steps

    @classmethod
    def _drop_unused_obs(cls, obs_dict):
        """Return a shallow copy of ``obs_dict`` without the unused image keys.

        Key order of the remaining entries is preserved. The input dict is not
        modified.
        """
        return {key: value for key, value in obs_dict.items()
                if key not in cls._UNUSED_OBS_KEYS}

    # ========= inference ============
    def conditional_sample(self,
            condition_data, condition_mask,
            condition_data_pc=None, condition_mask_pc=None,
            local_cond=None, global_cond=None,
            generator=None,
            # keyword arguments to scheduler.step
            **kwargs
            ):
        """Run the reverse diffusion loop, enforcing ``condition_data`` where masked.

        Starts from standard-normal noise shaped like ``condition_data``. At
        every scheduler timestep it overwrites masked entries with
        ``condition_data`` and denoises one step with ``self.model``.

        NOTE(review): ``condition_data_pc``, ``condition_mask_pc``, ``generator``
        and ``**kwargs`` are accepted for interface compatibility but are
        currently unused. In particular, nothing is forwarded to
        ``scheduler.step`` and the initial noise ignores ``generator``.

        Returns:
            Tensor with the same shape/dtype/device as ``condition_data``.
        """
        model = self.model
        scheduler = self.noise_scheduler

        trajectory = torch.randn(
            size=condition_data.shape,
            dtype=condition_data.dtype,
            device=condition_data.device)

        # set step values
        scheduler.set_timesteps(self.num_inference_steps)

        for t in scheduler.timesteps:
            # 1. apply conditioning
            trajectory[condition_mask] = condition_data[condition_mask]

            # 2. predict model output
            model_output = model(sample=trajectory,
                                 timestep=t,
                                 local_cond=local_cond, global_cond=global_cond)

            # 3. compute previous image: x_t -> x_t-1
            trajectory = scheduler.step(
                model_output, t, trajectory, ).prev_sample

        # finally make sure conditioning is enforced
        trajectory[condition_mask] = condition_data[condition_mask]

        return trajectory

    def predict_action(self, obs_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """Sample an action trajectory conditioned on observations.

        Args:
            obs_dict: observation tensors keyed by name (must include
                ``'point_cloud'``). Each value is indexed as ``x[:, :To, ...]``,
                i.e. batch on axis 0 and time on axis 1. The dict itself is not
                modified; unused image keys are filtered from a copy.

        Returns:
            dict with ``'action'`` (``action_pred[:, To-1 : To-1+n_action_steps]``)
            and ``'action_pred'`` (full unnormalized trajectory, ``(B, T, Da)``).
        """
        # Filter into a copy so the caller's observation dict is not mutated.
        obs_dict = self._drop_unused_obs(obs_dict)
        # normalize input
        nobs = self.normalizer.normalize(obs_dict)
        if not self.use_pc_color:
            # only use xyz coordinates
            nobs['point_cloud'] = nobs['point_cloud'][..., :3]

        value = next(iter(nobs.values()))
        B = value.shape[0]
        T = self.horizon
        Da = self.action_dim
        Do = self.obs_feature_dim
        To = self.n_obs_steps

        # build input
        device = self.device
        dtype = self.dtype

        # handle different ways of passing observation
        local_cond = None
        global_cond = None
        if self.obs_as_global_cond:
            # condition through global feature
            this_nobs = dict_apply(nobs, lambda x: x[:, :To, ...].reshape(-1, *x.shape[2:]))
            nobs_features = self.obs_encoder(this_nobs)
            if "cross_attention" in self.condition_type:
                # treat as a sequence
                global_cond = nobs_features.reshape(B, self.n_obs_steps, -1)
            else:
                # reshape back to B, Do
                global_cond = nobs_features.reshape(B, -1)
            # empty data for action
            cond_data = torch.zeros(size=(B, T, Da), device=device, dtype=dtype)
            cond_mask = torch.zeros_like(cond_data, dtype=torch.bool)
        else:
            # condition through inpainting
            this_nobs = dict_apply(nobs, lambda x: x[:, :To, ...].reshape(-1, *x.shape[2:]))
            nobs_features = self.obs_encoder(this_nobs)
            # reshape back to B, T, Do
            nobs_features = nobs_features.reshape(B, To, -1)
            cond_data = torch.zeros(size=(B, T, Da + Do), device=device, dtype=dtype)
            cond_mask = torch.zeros_like(cond_data, dtype=torch.bool)
            cond_data[:, :To, Da:] = nobs_features
            cond_mask[:, :To, Da:] = True

        # run sampling
        nsample = self.conditional_sample(
            cond_data,
            cond_mask,
            local_cond=local_cond,
            global_cond=global_cond,
            **self.kwargs)

        # unnormalize prediction
        naction_pred = nsample[..., :Da]
        action_pred = self.normalizer['action'].unnormalize(naction_pred)

        # get action: start at the last observed step
        start = To - 1
        end = start + self.n_action_steps
        action = action_pred[:, start:end]

        # get prediction
        result = {
            'action': action,
            'action_pred': action_pred,
        }

        return result

    # ========= training ============
    def set_normalizer(self, normalizer: LinearNormalizer):
        """Copy the given normalizer's state into ``self.normalizer``."""
        self.normalizer.load_state_dict(normalizer.state_dict())

    def compute_loss(self, batch):
        """Compute the masked diffusion training loss for one batch.

        Args:
            batch: dict with ``'obs'`` (observation tensors, batch on axis 0,
                time on axis 1) and ``'action'``. ``batch`` is not modified;
                unused image keys are filtered from a copy of ``batch['obs']``.

        Returns:
            ``(loss, loss_dict)``: ``loss`` is a scalar tensor; ``loss_dict``
            holds ``{'bc_loss': float}``.

        Raises:
            ValueError: if the scheduler's ``prediction_type`` is not
                ``'epsilon'``, ``'sample'`` or ``'v_prediction'``.
        """
        # Filter into a copy so the caller's batch is not mutated, dropping the
        # same keys that predict_action drops.
        obs = self._drop_unused_obs(batch['obs'])
        # normalize input
        nobs = self.normalizer.normalize(obs)
        nactions = self.normalizer['action'].normalize(batch['action'])

        if not self.use_pc_color:
            # only use xyz coordinates
            nobs['point_cloud'] = nobs['point_cloud'][..., :3]

        batch_size = nactions.shape[0]
        horizon = nactions.shape[1]

        # handle different ways of passing observation
        local_cond = None
        global_cond = None
        trajectory = nactions
        cond_data = trajectory

        if self.obs_as_global_cond:
            # reshape B, T, ... to B*T
            this_nobs = dict_apply(nobs,
                lambda x: x[:, :self.n_obs_steps, ...].reshape(-1, *x.shape[2:]))
            nobs_features = self.obs_encoder(this_nobs)

            if "cross_attention" in self.condition_type:
                # treat as a sequence
                global_cond = nobs_features.reshape(batch_size, self.n_obs_steps, -1)
            else:
                # reshape back to B, Do
                global_cond = nobs_features.reshape(batch_size, -1)
        else:
            # reshape B, T, ... to B*T
            this_nobs = dict_apply(nobs, lambda x: x.reshape(-1, *x.shape[2:]))
            nobs_features = self.obs_encoder(this_nobs)
            # reshape back to B, T, Do
            nobs_features = nobs_features.reshape(batch_size, horizon, -1)
            cond_data = torch.cat([nactions, nobs_features], dim=-1)
            trajectory = cond_data.detach()

        # generate inpainting mask
        condition_mask = self.mask_generator(trajectory.shape)

        # Sample noise that we'll add to the trajectories
        noise = torch.randn(trajectory.shape, device=trajectory.device)

        bsz = trajectory.shape[0]
        # Sample a random timestep for each sample
        timesteps = torch.randint(
            0, self.noise_scheduler.config.num_train_timesteps,
            (bsz,), device=trajectory.device
        ).long()

        # Add noise to the clean trajectories according to the noise magnitude
        # at each timestep (this is the forward diffusion process)
        noisy_trajectory = self.noise_scheduler.add_noise(
            trajectory, noise, timesteps)

        # compute loss mask: only un-conditioned entries contribute
        loss_mask = ~condition_mask

        # apply conditioning
        noisy_trajectory[condition_mask] = cond_data[condition_mask]

        # Predict the noise residual
        pred = self.model(sample=noisy_trajectory,
                          timestep=timesteps,
                          local_cond=local_cond,
                          global_cond=global_cond)

        pred_type = self.noise_scheduler.config.prediction_type
        if pred_type == 'epsilon':
            target = noise
        elif pred_type == 'sample':
            target = trajectory
        elif pred_type == 'v_prediction':
            # https://github.com/huggingface/diffusers/blob/main/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py
            # https://github.com/huggingface/diffusers/blob/v0.11.1-patch/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py
            # NOTE(review): assumes the scheduler exposes per-timestep
            # ``alpha_t`` / ``sigma_t`` tensors (DPM-solver style schedulers);
            # confirm for the scheduler configured with this policy. The tensors
            # are moved to the policy device in place so they can be indexed
            # with on-device ``timesteps``.
            self.noise_scheduler.alpha_t = self.noise_scheduler.alpha_t.to(self.device)
            self.noise_scheduler.sigma_t = self.noise_scheduler.sigma_t.to(self.device)
            alpha_t, sigma_t = self.noise_scheduler.alpha_t[timesteps], self.noise_scheduler.sigma_t[timesteps]
            # broadcast (B,) -> (B, 1, 1) against the (B, T, D) trajectory
            alpha_t = alpha_t.unsqueeze(-1).unsqueeze(-1)
            sigma_t = sigma_t.unsqueeze(-1).unsqueeze(-1)
            v_t = alpha_t * noise - sigma_t * trajectory
            target = v_t
        else:
            raise ValueError(f"Unsupported prediction type {pred_type}")

        loss = F.mse_loss(pred, target, reduction='none')
        loss = loss * loss_mask.type(loss.dtype)
        loss = reduce(loss, 'b ... -> b (...)', 'mean')
        loss = loss.mean()

        loss_dict = {
            'bc_loss': loss.item(),
        }

        return loss, loss_dict