from typing import Dict, List, Sequence

import copy

import h5py
import numpy as np
import torch
from tqdm import tqdm

from equi_diffpo.common.pytorch_util import dict_apply
from equi_diffpo.dataset.base_dataset import BaseLowdimDataset, LinearNormalizer
from equi_diffpo.model.common.normalizer import LinearNormalizer, SingleFieldLinearNormalizer
from equi_diffpo.model.common.rotation_transformer import RotationTransformer
from equi_diffpo.common.replay_buffer import ReplayBuffer
from equi_diffpo.common.sampler import (
    SequenceSampler, get_val_mask, downsample_mask)
from equi_diffpo.common.normalize_util import (
    robomimic_abs_action_only_normalizer_from_stat,
    robomimic_abs_action_only_dual_arm_normalizer_from_stat,
    get_identity_normalizer_from_stat,
    array_to_stats
)


class RobomimicReplayLowdimDataset(BaseLowdimDataset):
    """Low-dimensional robomimic demonstration dataset.

    Loads the first ``n_demo`` episodes of a robomimic hdf5 file into an
    in-memory ``ReplayBuffer``, splits episodes into train/validation masks,
    and serves fixed-length observation/action sequences via a
    ``SequenceSampler``.
    """

    def __init__(self,
            dataset_path: str,
            horizon: int = 1,
            pad_before: int = 0,
            pad_after: int = 0,
            # NOTE: immutable tuple default (copied to a list below) avoids the
            # mutable-default-argument pitfall of the previous list literal.
            obs_keys: Sequence[str] = (
                'object',
                'robot0_eef_pos',
                'robot0_eef_quat',
                'robot0_gripper_qpos'),
            abs_action: bool = False,
            rotation_rep: str = 'rotation_6d',
            use_legacy_normalizer: bool = False,
            seed: int = 42,
            val_ratio: float = 0.0,
            max_train_episodes=None,
            n_demo: int = 100
        ):
        """
        Args:
            dataset_path: path to a robomimic hdf5 file with groups
                ``data/demo_0 ... data/demo_{n_demo-1}``.
            horizon: sequence length returned by ``__getitem__``.
            pad_before / pad_after: padding steps at episode boundaries.
            obs_keys: observation entries concatenated (in order) along the
                last axis to form the flat ``obs`` vector.
            abs_action: if True, actions are absolute poses whose rotation
                component is converted from axis-angle to ``rotation_rep``.
            rotation_rep: target rotation representation for abs actions.
            use_legacy_normalizer: use max-abs scaling for abs actions.
            seed: RNG seed for the val split and train downsampling.
            val_ratio: fraction of episodes held out for validation.
            max_train_episodes: optional cap on training episodes.
            n_demo: number of demos to load from the file.
        """
        obs_keys = list(obs_keys)
        rotation_transformer = RotationTransformer(
            from_rep='axis_angle', to_rep=rotation_rep)

        replay_buffer = ReplayBuffer.create_empty_numpy()
        # Open read-only: explicit 'r' prevents accidental append-mode
        # opens (the default on older h5py versions).
        with h5py.File(dataset_path, 'r') as file:
            demos = file['data']
            for i in tqdm(range(n_demo), desc="Loading hdf5 to ReplayBuffer"):
                demo = demos[f'demo_{i}']
                episode = _data_to_obs(
                    raw_obs=demo['obs'],
                    raw_actions=demo['actions'][:].astype(np.float32),
                    obs_keys=obs_keys,
                    abs_action=abs_action,
                    rotation_transformer=rotation_transformer)
                replay_buffer.add_episode(episode)

        # Episode-level train/val split, then optional downsampling of the
        # training episodes (val episodes are never downsampled).
        val_mask = get_val_mask(
            n_episodes=replay_buffer.n_episodes,
            val_ratio=val_ratio,
            seed=seed)
        train_mask = ~val_mask
        train_mask = downsample_mask(
            mask=train_mask,
            max_n=max_train_episodes,
            seed=seed)

        sampler = SequenceSampler(
            replay_buffer=replay_buffer,
            sequence_length=horizon,
            pad_before=pad_before,
            pad_after=pad_after,
            episode_mask=train_mask)

        self.replay_buffer = replay_buffer
        self.sampler = sampler
        self.abs_action = abs_action
        self.train_mask = train_mask
        self.horizon = horizon
        self.pad_before = pad_before
        self.pad_after = pad_after
        self.use_legacy_normalizer = use_legacy_normalizer

    def get_validation_dataset(self):
        """Return a shallow copy of this dataset that samples only the
        held-out (validation) episodes.

        The copy shares the underlying replay buffer; only the sampler and
        mask are replaced.
        """
        val_set = copy.copy(self)
        val_set.sampler = SequenceSampler(
            replay_buffer=self.replay_buffer,
            sequence_length=self.horizon,
            pad_before=self.pad_before,
            pad_after=self.pad_after,
            episode_mask=~self.train_mask
        )
        val_set.train_mask = ~self.train_mask
        return val_set

    def get_normalizer(self, **kwargs) -> LinearNormalizer:
        """Build a LinearNormalizer fit to this dataset's action/obs stats.

        Absolute actions get a robomimic-specific normalizer (dual-arm when
        the action dim exceeds 10); delta actions are assumed already
        normalized and get an identity normalizer. Observations are scaled
        by the max absolute value across all dims.
        """
        normalizer = LinearNormalizer()

        # action
        stat = array_to_stats(self.replay_buffer['action'])
        if self.abs_action:
            if stat['mean'].shape[-1] > 10:
                # dual arm
                this_normalizer = robomimic_abs_action_only_dual_arm_normalizer_from_stat(stat)
            else:
                this_normalizer = robomimic_abs_action_only_normalizer_from_stat(stat)

            if self.use_legacy_normalizer:
                this_normalizer = normalizer_from_stat(stat)
        else:
            # already normalized
            this_normalizer = get_identity_normalizer_from_stat(stat)
        normalizer['action'] = this_normalizer

        # aggregate obs stats
        obs_stat = array_to_stats(self.replay_buffer['obs'])
        normalizer['obs'] = normalizer_from_stat(obs_stat)
        return normalizer

    def get_all_actions(self) -> torch.Tensor:
        """Return every action in the buffer as a single (T, action_dim) tensor."""
        return torch.from_numpy(self.replay_buffer['action'])

    def __len__(self):
        """Number of sampleable sequences (not episodes)."""
        return len(self.sampler)

    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
        """Return the idx-th sequence as a dict of torch tensors
        (keys: 'obs', 'action')."""
        data = self.sampler.sample_sequence(idx)
        torch_data = dict_apply(data, torch.from_numpy)
        return torch_data


def normalizer_from_stat(stat):
    """Create a symmetric max-abs normalizer from an array stats dict.

    Every dimension shares one scale ``1 / max(|max|, |min|)`` and zero
    offset, so the normalized data lies in [-1, 1] while preserving the
    relative magnitude across dimensions.
    """
    max_abs = np.maximum(stat['max'].max(), np.abs(stat['min']).max())
    scale = np.full_like(stat['max'], fill_value=1/max_abs)
    offset = np.zeros_like(stat['max'])
    return SingleFieldLinearNormalizer.create_manual(
        scale=scale,
        offset=offset,
        input_stats_dict=stat
    )


def _data_to_obs(raw_obs, raw_actions, obs_keys, abs_action, rotation_transformer):
    """Convert one raw demo into a flat obs/action episode dict.

    Observations are the ``obs_keys`` entries concatenated along the last
    axis (float32). For absolute actions, each 7-dim arm action is
    interpreted as [pos(3), axis-angle rot(3), gripper(1)] and its rotation
    is converted via ``rotation_transformer``; a 14-dim action is treated as
    two such arms and flattened back to 20 dims after conversion
    (3 + 6 + 1 per arm with rotation_6d).
    """
    obs = np.concatenate([
        raw_obs[key] for key in obs_keys
    ], axis=-1).astype(np.float32)

    if abs_action:
        is_dual_arm = False
        if raw_actions.shape[-1] == 14:
            # dual arm
            raw_actions = raw_actions.reshape(-1,2,7)
            is_dual_arm = True

        pos = raw_actions[...,:3]
        rot = raw_actions[...,3:6]
        gripper = raw_actions[...,6:]
        rot = rotation_transformer.forward(rot)
        raw_actions = np.concatenate([
            pos, rot, gripper
        ], axis=-1).astype(np.float32)

        if is_dual_arm:
            raw_actions = raw_actions.reshape(-1,20)

    data = {
        'obs': obs,
        'action': raw_actions
    }
    return data