|
from typing import Dict
import copy

import torch
import torch.nn.functional as F
from einops import reduce
from diffusers.schedulers.scheduling_ddpm import DDPMScheduler
from termcolor import cprint

from equi_diffpo.model.common.module_attr_mixin import ModuleAttrMixin
from equi_diffpo.model.common.normalizer import LinearNormalizer
from equi_diffpo.model.diffusion.dp3_conditional_unet1d import ConditionalUnet1D
from equi_diffpo.model.diffusion.mask_generator import LowdimMaskGenerator
from equi_diffpo.common.pytorch_util import dict_apply
from equi_diffpo.model.vision.pointnet_extractor import DP3Encoder
|
|
|
class BasePolicy(ModuleAttrMixin):
    def predict_action(self, obs_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """
        obs_dict: maps each observation key to a (B, To, ...) tensor.
        return: dict with an 'action' entry of shape (B, Ta, Da).
        """
        raise NotImplementedError()

    def reset(self):
        # Stateless by default; subclasses with inference-time state override this.
        pass

    def set_normalizer(self, normalizer: LinearNormalizer):
        raise NotImplementedError()
|
|
|
class DP3(BasePolicy):
    def __init__(self,
            shape_meta: dict,
            noise_scheduler: DDPMScheduler,
            horizon,
            n_action_steps,
            n_obs_steps,
            num_inference_steps=None,
            obs_as_global_cond=True,
            diffusion_step_embed_dim=256,
            down_dims=(256, 512, 1024),
            kernel_size=5,
            n_groups=8,
            condition_type="film",
            use_down_condition=True,
            use_mid_condition=True,
            use_up_condition=True,
            encoder_output_dim=256,
            crop_shape=None,
            use_pc_color=False,
            pointnet_type="pointnet",
            pointcloud_encoder_cfg=None,
            **kwargs):
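        """
        3D Diffusion Policy: denoises action trajectories with a conditional
        1D UNet, conditioned on point-cloud observation features.

        shape_meta: dataset metadata with 'action' and 'obs' shape entries.
        noise_scheduler: diffusers DDPMScheduler driving the forward and
            reverse diffusion processes.
        horizon: length T of the predicted action trajectory.
        n_action_steps: number of predicted steps actually executed (Ta).
        n_obs_steps: number of past observation frames consumed (To).
        obs_as_global_cond: if True, observation features condition the UNet
            globally; if False, they are inpainted into the trajectory.
        """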
|
        super().__init__()

        self.condition_type = condition_type

        # Flat (Da,) action shapes are used as-is; (T, Da) shapes are
        # flattened to a single vector per timestep.
        action_shape = shape_meta['action']['shape']
        self.action_shape = action_shape
        if len(action_shape) == 1:
            action_dim = action_shape[0]
        elif len(action_shape) == 2:
            action_dim = action_shape[0] * action_shape[1]
        else:
            raise NotImplementedError(f"Unsupported action shape {action_shape}")

        obs_shape_meta = shape_meta['obs']
        obs_dict = dict_apply(obs_shape_meta, lambda x: x['shape'])

        obs_encoder = DP3Encoder(observation_space=obs_dict,
                                 img_crop_shape=crop_shape,
                                 out_channel=encoder_output_dim,
                                 pointcloud_encoder_cfg=pointcloud_encoder_cfg,
                                 use_pc_color=use_pc_color,
                                 pointnet_type=pointnet_type,
                                 )

        # With global conditioning the UNet denoises actions only; otherwise
        # observation features are appended to the trajectory and inpainted.
        obs_feature_dim = obs_encoder.output_shape()
        input_dim = action_dim + obs_feature_dim
        global_cond_dim = None
        if obs_as_global_cond:
            input_dim = action_dim
            if "cross_attention" in self.condition_type:
                # Cross-attention consumes per-step features (B, To, Do).
                global_cond_dim = obs_feature_dim
            else:
                # FiLM-style conditioning consumes one flat vector (B, To * Do).
                global_cond_dim = obs_feature_dim * n_obs_steps

        self.use_pc_color = use_pc_color
        self.pointnet_type = pointnet_type
        cprint(f"[DP3] use_pc_color: {self.use_pc_color}", "yellow")
        cprint(f"[DP3] pointnet_type: {self.pointnet_type}", "yellow")

        model = ConditionalUnet1D(
            input_dim=input_dim,
            local_cond_dim=None,
            global_cond_dim=global_cond_dim,
            diffusion_step_embed_dim=diffusion_step_embed_dim,
            down_dims=down_dims,
            kernel_size=kernel_size,
            n_groups=n_groups,
            condition_type=condition_type,
            use_down_condition=use_down_condition,
            use_mid_condition=use_mid_condition,
            use_up_condition=use_up_condition,
        )

        self.obs_encoder = obs_encoder
        self.model = model
        self.noise_scheduler = noise_scheduler
        # Scheduler copy for point-cloud denoising variants (unused in this class).
        self.noise_scheduler_pc = copy.deepcopy(noise_scheduler)
        self.mask_generator = LowdimMaskGenerator(
            action_dim=action_dim,
            obs_dim=0 if obs_as_global_cond else obs_feature_dim,
            max_n_obs_steps=n_obs_steps,
            fix_obs_steps=True,
            action_visible=False
        )

        self.normalizer = LinearNormalizer()
        self.horizon = horizon
        self.obs_feature_dim = obs_feature_dim
        self.action_dim = action_dim
        self.n_action_steps = n_action_steps
        self.n_obs_steps = n_obs_steps
        self.obs_as_global_cond = obs_as_global_cond
        self.kwargs = kwargs

        # By default, sample with as many steps as the scheduler was trained with.
        if num_inference_steps is None:
            num_inference_steps = noise_scheduler.config.num_train_timesteps
        self.num_inference_steps = num_inference_steps

    def conditional_sample(self,
            condition_data, condition_mask,
            condition_data_pc=None, condition_mask_pc=None,
            local_cond=None, global_cond=None,
            generator=None,
            **kwargs
            ):
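        """
        Inpainting-style conditional sampling (DDPM reverse process).

        Known entries of the trajectory (condition_mask == True) are clamped
        to condition_data before every denoising step, so the model only has
        to generate the unknown entries. condition_data_pc / condition_mask_pc
        are accepted for API compatibility but unused here. Extra kwargs are
        forwarded to scheduler.step().
        """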
|
        model = self.model
        scheduler = self.noise_scheduler

        # Start the reverse process from pure Gaussian noise.
        trajectory = torch.randn(
            size=condition_data.shape,
            dtype=condition_data.dtype,
            device=condition_data.device,
            generator=generator)

        scheduler.set_timesteps(self.num_inference_steps)

        for t in scheduler.timesteps:
            # 1. Impose conditioning on the known entries.
            trajectory[condition_mask] = condition_data[condition_mask]

            # 2. Predict the model output (noise, sample, or v, depending on
            #    the scheduler's prediction type).
            model_output = model(sample=trajectory,
                                 timestep=t,
                                 local_cond=local_cond, global_cond=global_cond)

            # 3. One denoising step: x_t -> x_{t-1}.
            trajectory = scheduler.step(
                model_output, t, trajectory,
                generator=generator,
                **kwargs).prev_sample

        # Enforce conditioning on the final, fully denoised sample.
        trajectory[condition_mask] = condition_data[condition_mask]

        return trajectory

    def predict_action(self, obs_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """
        obs_dict: maps each observation key to a (B, To, ...) tensor;
            must include 'point_cloud'.
        return: dict with 'action' (B, Ta, Da), the steps to execute, and
            'action_pred' (B, T, Da), the full predicted trajectory.
        """
|
        # This policy is point-cloud only: drop any RGB observations.
        if 'robot0_eye_in_hand_image' in obs_dict:
            del obs_dict['robot0_eye_in_hand_image']
        if 'agentview_image' in obs_dict:
            del obs_dict['agentview_image']

        nobs = self.normalizer.normalize(obs_dict)

        if not self.use_pc_color:
            # Keep only xyz channels, discarding rgb.
            nobs['point_cloud'] = nobs['point_cloud'][..., :3]

        value = next(iter(nobs.values()))
        B = value.shape[0]
        T = self.horizon
        Da = self.action_dim
        Do = self.obs_feature_dim
        To = self.n_obs_steps

        device = self.device
        dtype = self.dtype

        # Build conditioning.
        local_cond = None
        global_cond = None
        if self.obs_as_global_cond:
            # Encode the first To observation frames and condition globally.
            this_nobs = dict_apply(nobs, lambda x: x[:,:To,...].reshape(-1,*x.shape[2:]))
            nobs_features = self.obs_encoder(this_nobs)
            if "cross_attention" in self.condition_type:
                # Keep the time axis for attention: (B, To, Do).
                global_cond = nobs_features.reshape(B, self.n_obs_steps, -1)
            else:
                # Flatten time into one conditioning vector: (B, To * Do).
                global_cond = nobs_features.reshape(B, -1)
            # Empty inpainting data: nothing is clamped in the trajectory.
            cond_data = torch.zeros(size=(B, T, Da), device=device, dtype=dtype)
            cond_mask = torch.zeros_like(cond_data, dtype=torch.bool)
        else:
            # Inpainting variant: observation features occupy the trailing Do
            # channels of the first To steps of the trajectory.
            this_nobs = dict_apply(nobs, lambda x: x[:,:To,...].reshape(-1,*x.shape[2:]))
            nobs_features = self.obs_encoder(this_nobs)
            nobs_features = nobs_features.reshape(B, To, -1)
            cond_data = torch.zeros(size=(B, T, Da+Do), device=device, dtype=dtype)
            cond_mask = torch.zeros_like(cond_data, dtype=torch.bool)
            cond_data[:,:To,Da:] = nobs_features
            cond_mask[:,:To,Da:] = True

        # Run the reverse diffusion process.
        nsample = self.conditional_sample(
            cond_data,
            cond_mask,
            local_cond=local_cond,
            global_cond=global_cond,
            **self.kwargs)

        # Unnormalize the action channels of the sampled trajectory.
        naction_pred = nsample[...,:Da]
        action_pred = self.normalizer['action'].unnormalize(naction_pred)

        # Receding horizon: execute n_action_steps starting at the last
        # observed step.
        start = To - 1
        end = start + self.n_action_steps
        action = action_pred[:,start:end]

        result = {
            'action': action,
            'action_pred': action_pred,
        }

        return result

    def set_normalizer(self, normalizer: LinearNormalizer):
        self.normalizer.load_state_dict(normalizer.state_dict())

    def compute_loss(self, batch):
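        """
        Behavior-cloning diffusion loss: noise a normalized action trajectory
        to a random timestep, predict with the conditional UNet, and regress
        the scheduler's target (epsilon, sample, or v) with masked MSE.
        """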
|
        # This policy is point-cloud only: drop RGB observations if present.
        if 'robot0_eye_in_hand_image' in batch['obs']:
            del batch['obs']['robot0_eye_in_hand_image']

        nobs = self.normalizer.normalize(batch['obs'])
        nactions = self.normalizer['action'].normalize(batch['action'])

        if not self.use_pc_color:
            # Keep only xyz channels, discarding rgb.
            nobs['point_cloud'] = nobs['point_cloud'][..., :3]

        batch_size = nactions.shape[0]
        horizon = nactions.shape[1]

        # Build conditioning (mirrors predict_action).
        local_cond = None
        global_cond = None
        trajectory = nactions
        cond_data = trajectory

        if self.obs_as_global_cond:
            # Encode the first n_obs_steps frames and condition globally.
            this_nobs = dict_apply(nobs,
                lambda x: x[:,:self.n_obs_steps,...].reshape(-1,*x.shape[2:]))
            nobs_features = self.obs_encoder(this_nobs)
            if "cross_attention" in self.condition_type:
                global_cond = nobs_features.reshape(batch_size, self.n_obs_steps, -1)
            else:
                global_cond = nobs_features.reshape(batch_size, -1)
        else:
            # Inpainting variant: encode every frame and append observation
            # features to the action trajectory.
            this_nobs = dict_apply(nobs, lambda x: x.reshape(-1, *x.shape[2:]))
            nobs_features = self.obs_encoder(this_nobs)
            nobs_features = nobs_features.reshape(batch_size, horizon, -1)
            cond_data = torch.cat([nactions, nobs_features], dim=-1)
            trajectory = cond_data.detach()

        # Mask marking which entries are conditioned on (not denoised).
        condition_mask = self.mask_generator(trajectory.shape)

        # Forward diffusion: sample noise and a per-sample timestep, then
        # noise the clean trajectory to that timestep.
        noise = torch.randn(trajectory.shape, device=trajectory.device)
        bsz = trajectory.shape[0]
        timesteps = torch.randint(
            0, self.noise_scheduler.config.num_train_timesteps,
            (bsz,), device=trajectory.device
        ).long()
        noisy_trajectory = self.noise_scheduler.add_noise(
            trajectory, noise, timesteps)

        # Only unconditioned entries contribute to the loss.
        loss_mask = ~condition_mask

        # Inpaint conditioned entries with their clean values.
        noisy_trajectory[condition_mask] = cond_data[condition_mask]

        pred = self.model(sample=noisy_trajectory,
                          timestep=timesteps,
                          local_cond=local_cond,
                          global_cond=global_cond)
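        # The regression target depends on the scheduler's parameterization:
        #   epsilon:      target = eps                       (the injected noise)
        #   sample:       target = x_0                       (the clean trajectory)
        #   v_prediction: target = alpha_t*eps - sigma_t*x_0 (Salimans & Ho, 2022)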
|
|
|
|
|
        pred_type = self.noise_scheduler.config.prediction_type
        if pred_type == 'epsilon':
            target = noise
        elif pred_type == 'sample':
            target = trajectory
        elif pred_type == 'v_prediction':
            # Assumes the scheduler exposes alpha_t / sigma_t tables (these
            # are not attributes of the stock diffusers DDPMScheduler).
            self.noise_scheduler.alpha_t = self.noise_scheduler.alpha_t.to(self.device)
            self.noise_scheduler.sigma_t = self.noise_scheduler.sigma_t.to(self.device)
            alpha_t, sigma_t = self.noise_scheduler.alpha_t[timesteps], self.noise_scheduler.sigma_t[timesteps]
            # Broadcast (B,) coefficients over the (B, T, D) trajectory.
            alpha_t = alpha_t.unsqueeze(-1).unsqueeze(-1)
            sigma_t = sigma_t.unsqueeze(-1).unsqueeze(-1)
            v_t = alpha_t * noise - sigma_t * trajectory
            target = v_t
        else:
            raise ValueError(f"Unsupported prediction type {pred_type}")

        # Masked MSE: zero the conditioned entries, then average per sample
        # and over the batch.
        loss = F.mse_loss(pred, target, reduction='none')
        loss = loss * loss_mask.type(loss.dtype)
        loss = reduce(loss, 'b ... -> b (...)', 'mean')
        loss = loss.mean()

        loss_dict = {
            'bc_loss': loss.item(),
        }

        return loss, loss_dict
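

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the module). The env API and
# observation keys below are assumptions; constructing DP3 also requires a
# shape_meta and pointcloud_encoder_cfg matching your DP3Encoder config.
#
#   policy = DP3(shape_meta=shape_meta,
#                noise_scheduler=DDPMScheduler(num_train_timesteps=100),
#                horizon=16, n_action_steps=8, n_obs_steps=2,
#                pointcloud_encoder_cfg=encoder_cfg)
#   policy.set_normalizer(normalizer_fitted_on_dataset)
#   policy.eval()
#
#   obs_dict = {
#       'point_cloud': pc,        # (B, To, N, 3) xyz, or (B, To, N, 6) xyzrgb
#       'agent_pos': agent_pos,   # (B, To, D_pos) low-dim state
#   }
#   with torch.no_grad():
#       action = policy.predict_action(obs_dict)['action']   # (B, Ta, Da)
# ---------------------------------------------------------------------------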