|
from typing import Dict, List |
|
import torch |
|
import numpy as np |
|
import h5py |
|
from tqdm import tqdm |
|
import copy |
|
from equi_diffpo.common.pytorch_util import dict_apply |
|
from equi_diffpo.dataset.base_dataset import BaseLowdimDataset, LinearNormalizer |
|
from equi_diffpo.model.common.normalizer import LinearNormalizer, SingleFieldLinearNormalizer |
|
from equi_diffpo.model.common.rotation_transformer import RotationTransformer |
|
from equi_diffpo.common.replay_buffer import ReplayBuffer |
|
from equi_diffpo.common.sampler import ( |
|
SequenceSampler, get_val_mask, downsample_mask) |
|
from equi_diffpo.common.normalize_util import ( |
|
robomimic_abs_action_only_normalizer_from_stat, |
|
robomimic_abs_action_only_dual_arm_normalizer_from_stat, |
|
get_identity_normalizer_from_stat, |
|
array_to_stats |
|
) |
|
|
|
class RobomimicReplayLowdimDataset(BaseLowdimDataset):
    """Sequence dataset over low-dimensional demonstrations stored in HDF5.

    On construction, the first ``n_demo`` demonstrations (``data/demo_0`` ..
    ``data/demo_{n_demo-1}``) are read from ``dataset_path``, converted with
    ``_data_to_obs`` and appended to an in-memory ``ReplayBuffer``. A
    ``SequenceSampler`` over the training episodes then serves fixed-length
    windows through ``__getitem__``.
    """

    def __init__(self,
            dataset_path: str,
            horizon=1,
            pad_before=0,
            pad_after=0,
            # NOTE: the list default is never mutated; it is copied below.
            obs_keys: List[str]=[
                'object',
                'robot0_eef_pos',
                'robot0_eef_quat',
                'robot0_gripper_qpos'],
            abs_action=False,
            rotation_rep='rotation_6d',
            use_legacy_normalizer=False,
            seed=42,
            val_ratio=0.0,
            max_train_episodes=None,
            n_demo=100
        ):
        """Load demonstrations and build the training sampler.

        Args:
            dataset_path: HDF5 file containing a ``data`` group whose
                ``demo_{i}`` members each hold ``obs`` and ``actions``.
            horizon: Passed to ``SequenceSampler`` as ``sequence_length``.
            pad_before: Forwarded to ``SequenceSampler``.
            pad_after: Forwarded to ``SequenceSampler``.
            obs_keys: Observation keys forwarded to ``_data_to_obs``.
            abs_action: Forwarded to ``_data_to_obs``; also selects the
                action normalizer in ``get_normalizer``.
            rotation_rep: Target representation of the
                ``RotationTransformer`` (source is ``'axis_angle'``).
            use_legacy_normalizer: With ``abs_action``, normalize actions
                with ``normalizer_from_stat`` instead.
            seed: Seed forwarded to ``get_val_mask`` and ``downsample_mask``.
            val_ratio: Forwarded to ``get_val_mask``.
            max_train_episodes: Forwarded to ``downsample_mask`` as ``max_n``.
            n_demo: Number of demonstrations to load, starting at ``demo_0``.

        Raises:
            KeyError: If one of the first ``n_demo`` demonstrations is
                missing from the file.
        """
        obs_keys = list(obs_keys)
        rotation_transformer = RotationTransformer(
            from_rep='axis_angle', to_rep=rotation_rep)

        replay_buffer = ReplayBuffer.create_empty_numpy()
        # Open explicitly read-only: the data is only read here, and older
        # h5py releases default to a read/write mode when none is given.
        with h5py.File(dataset_path, 'r') as file:
            demos = file['data']
            for i in tqdm(range(n_demo), desc="Loading hdf5 to ReplayBuffer"):
                demo_key = f'demo_{i}'
                if demo_key not in demos:
                    # Same exception type h5py raises for a missing member,
                    # but with enough context to fix the configuration.
                    raise KeyError(
                        f"'{demo_key}' not found under 'data' in "
                        f"{dataset_path!r} (n_demo={n_demo}, "
                        f"entries available: {len(demos)})")
                demo = demos[demo_key]
                episode = _data_to_obs(
                    raw_obs=demo['obs'],
                    raw_actions=demo['actions'][:].astype(np.float32),
                    obs_keys=obs_keys,
                    abs_action=abs_action,
                    rotation_transformer=rotation_transformer)
                replay_buffer.add_episode(episode)

        val_mask = get_val_mask(
            n_episodes=replay_buffer.n_episodes,
            val_ratio=val_ratio,
            seed=seed)
        train_mask = ~val_mask
        train_mask = downsample_mask(
            mask=train_mask,
            max_n=max_train_episodes,
            seed=seed)

        sampler = SequenceSampler(
            replay_buffer=replay_buffer,
            sequence_length=horizon,
            pad_before=pad_before,
            pad_after=pad_after,
            episode_mask=train_mask)

        self.replay_buffer = replay_buffer
        self.sampler = sampler
        self.abs_action = abs_action
        self.train_mask = train_mask
        self.horizon = horizon
        self.pad_before = pad_before
        self.pad_after = pad_after
        self.use_legacy_normalizer = use_legacy_normalizer

    def get_validation_dataset(self):
        """Return a shallow copy of this dataset sampling the non-training episodes.

        The copy shares the replay buffer with ``self``; only its sampler and
        ``train_mask`` are replaced, both using the complement of
        ``self.train_mask``.
        """
        # NOTE(review): the complement of train_mask presumably also contains
        # episodes removed by downsample_mask, not only those selected by
        # get_val_mask -- confirm this is intended.
        val_set = copy.copy(self)
        val_set.sampler = SequenceSampler(
            replay_buffer=self.replay_buffer,
            sequence_length=self.horizon,
            pad_before=self.pad_before,
            pad_after=self.pad_after,
            episode_mask=~self.train_mask
        )
        val_set.train_mask = ~self.train_mask
        return val_set

    def get_normalizer(self, **kwargs) -> LinearNormalizer:
        """Build a ``LinearNormalizer`` with ``'action'`` and ``'obs'`` fields.

        Statistics come from the whole replay buffer. The action normalizer
        is chosen as follows:

        * ``abs_action`` false: identity normalizer.
        * ``abs_action`` and ``use_legacy_normalizer``: ``normalizer_from_stat``.
        * ``abs_action`` otherwise: the dual-arm robomimic normalizer when the
          action width exceeds 10, else the single-arm one.

        Observations always use ``normalizer_from_stat``. ``kwargs`` is
        accepted but unused.
        """
        normalizer = LinearNormalizer()

        # Action field. Only the normalizer that is actually used gets built.
        stat = array_to_stats(self.replay_buffer['action'])
        if not self.abs_action:
            this_normalizer = get_identity_normalizer_from_stat(stat)
        elif self.use_legacy_normalizer:
            this_normalizer = normalizer_from_stat(stat)
        elif stat['mean'].shape[-1] > 10:
            this_normalizer = robomimic_abs_action_only_dual_arm_normalizer_from_stat(stat)
        else:
            this_normalizer = robomimic_abs_action_only_normalizer_from_stat(stat)
        normalizer['action'] = this_normalizer

        # Observation field.
        obs_stat = array_to_stats(self.replay_buffer['obs'])
        normalizer['obs'] = normalizer_from_stat(obs_stat)
        return normalizer

    def get_all_actions(self) -> torch.Tensor:
        """Return the replay buffer's ``'action'`` array as a tensor."""
        return torch.from_numpy(self.replay_buffer['action'])

    def __len__(self):
        """Return the number of sequences the sampler can produce."""
        return len(self.sampler)

    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
        """Sample sequence ``idx`` and convert its arrays with ``torch.from_numpy``."""
        data = self.sampler.sample_sequence(idx)
        torch_data = dict_apply(data, torch.from_numpy)
        return torch_data
|
|
|
def normalizer_from_stat(stat):
    """Create a zero-offset normalizer that scales by the largest magnitude.

    A single scalar, ``1 / max(stat['max'].max(), |stat['min']|.max())``, is
    used as the scale for every element, and the offset is zero.

    Args:
        stat: Mapping with at least ``'max'`` and ``'min'`` entries
            (assumed to be numpy arrays of matching shape -- TODO confirm
            against ``array_to_stats``). It is also passed through as
            ``input_stats_dict``.

    Returns:
        The result of ``SingleFieldLinearNormalizer.create_manual``.
    """
    max_abs = np.maximum(stat['max'].max(), np.abs(stat['min']).max())
    if max_abs == 0:
        # All-zero data: 1 / 0 would yield an infinite scale and turn every
        # normalized value into NaN (0 * inf). Fall back to a unit scale.
        max_abs = 1.0
    scale = np.full_like(stat['max'], fill_value=1 / max_abs)
    offset = np.zeros_like(stat['max'])
    return SingleFieldLinearNormalizer.create_manual(
        scale=scale,
        offset=offset,
        input_stats_dict=stat
    )
|
|
|
def _data_to_obs(raw_obs, raw_actions, obs_keys, abs_action, rotation_transformer): |
|
obs = np.concatenate([ |
|
raw_obs[key] for key in obs_keys |
|
], axis=-1).astype(np.float32) |
|
|
|
if abs_action: |
|
is_dual_arm = False |
|
if raw_actions.shape[-1] == 14: |
|
|
|
raw_actions = raw_actions.reshape(-1,2,7) |
|
is_dual_arm = True |
|
|
|
pos = raw_actions[...,:3] |
|
rot = raw_actions[...,3:6] |
|
gripper = raw_actions[...,6:] |
|
rot = rotation_transformer.forward(rot) |
|
raw_actions = np.concatenate([ |
|
pos, rot, gripper |
|
], axis=-1).astype(np.float32) |
|
|
|
if is_dual_arm: |
|
raw_actions = raw_actions.reshape(-1,20) |
|
|
|
data = { |
|
'obs': obs, |
|
'action': raw_actions |
|
} |
|
return data |
|
|