from typing import Dict, List
import torch
import numpy as np
import h5py
from tqdm import tqdm
import copy
from equi_diffpo.common.pytorch_util import dict_apply
from equi_diffpo.dataset.base_dataset import BaseLowdimDataset
from equi_diffpo.model.common.normalizer import LinearNormalizer, SingleFieldLinearNormalizer
from equi_diffpo.model.common.rotation_transformer import RotationTransformer
from equi_diffpo.common.replay_buffer import ReplayBuffer
from equi_diffpo.common.sampler import (
SequenceSampler, get_val_mask, downsample_mask)
from equi_diffpo.common.normalize_util import (
robomimic_abs_action_only_normalizer_from_stat,
robomimic_abs_action_only_dual_arm_normalizer_from_stat,
get_identity_normalizer_from_stat,
array_to_stats
)
class RobomimicReplayLowdimDataset(BaseLowdimDataset):
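    """Low-dimensional robomimic demonstration dataset backed by an in-memory ReplayBuffer.

    The first `n_demo` episodes are read from the hdf5 file, the requested
    `obs_keys` are concatenated into a flat observation vector per step, and
    (for absolute actions) the axis-angle rotation is converted to
    `rotation_rep`. Samples are fixed-length windows of `horizon` steps.
    """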
def __init__(self,
dataset_path: str,
horizon=1,
pad_before=0,
pad_after=0,
obs_keys: List[str]=[
'object',
'robot0_eef_pos',
'robot0_eef_quat',
'robot0_gripper_qpos'],
abs_action=False,
rotation_rep='rotation_6d',
use_legacy_normalizer=False,
seed=42,
val_ratio=0.0,
max_train_episodes=None,
n_demo=100
):
obs_keys = list(obs_keys)
rotation_transformer = RotationTransformer(
from_rep='axis_angle', to_rep=rotation_rep)
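        # Read the first n_demo demonstrations from the hdf5 file into an
        # in-memory (numpy) replay buffer, one episode at a time.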
replay_buffer = ReplayBuffer.create_empty_numpy()
        with h5py.File(dataset_path, 'r') as file:
demos = file['data']
for i in tqdm(range(n_demo), desc="Loading hdf5 to ReplayBuffer"):
demo = demos[f'demo_{i}']
episode = _data_to_obs(
raw_obs=demo['obs'],
raw_actions=demo['actions'][:].astype(np.float32),
obs_keys=obs_keys,
abs_action=abs_action,
rotation_transformer=rotation_transformer)
replay_buffer.add_episode(episode)
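        # Hold out a fraction of episodes for validation and optionally cap the
        # number of training episodes.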
val_mask = get_val_mask(
n_episodes=replay_buffer.n_episodes,
val_ratio=val_ratio,
seed=seed)
train_mask = ~val_mask
train_mask = downsample_mask(
mask=train_mask,
max_n=max_train_episodes,
seed=seed)
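        # The sampler draws fixed-length windows from the training episodes,
        # padding at episode boundaries as configured.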
sampler = SequenceSampler(
replay_buffer=replay_buffer,
sequence_length=horizon,
pad_before=pad_before,
pad_after=pad_after,
episode_mask=train_mask)
self.replay_buffer = replay_buffer
self.sampler = sampler
self.abs_action = abs_action
self.train_mask = train_mask
self.horizon = horizon
self.pad_before = pad_before
self.pad_after = pad_after
self.use_legacy_normalizer = use_legacy_normalizer
def get_validation_dataset(self):
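        """Return a shallow copy of this dataset whose sampler draws from the held-out validation episodes."""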
val_set = copy.copy(self)
val_set.sampler = SequenceSampler(
replay_buffer=self.replay_buffer,
sequence_length=self.horizon,
pad_before=self.pad_before,
pad_after=self.pad_after,
episode_mask=~self.train_mask
)
val_set.train_mask = ~self.train_mask
return val_set
def get_normalizer(self, **kwargs) -> LinearNormalizer:
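        """Fit a LinearNormalizer from replay-buffer statistics.

        Absolute actions use a robomimic-specific absolute-action normalizer
        (dual-arm variant when the action has more than 10 dimensions); delta
        actions are treated as already normalized. Observations are scaled
        symmetrically by their maximum absolute value.
        """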
normalizer = LinearNormalizer()
# action
stat = array_to_stats(self.replay_buffer['action'])
if self.abs_action:
if stat['mean'].shape[-1] > 10:
# dual arm
this_normalizer = robomimic_abs_action_only_dual_arm_normalizer_from_stat(stat)
else:
this_normalizer = robomimic_abs_action_only_normalizer_from_stat(stat)
if self.use_legacy_normalizer:
this_normalizer = normalizer_from_stat(stat)
else:
# already normalized
this_normalizer = get_identity_normalizer_from_stat(stat)
normalizer['action'] = this_normalizer
# aggregate obs stats
obs_stat = array_to_stats(self.replay_buffer['obs'])
normalizer['obs'] = normalizer_from_stat(obs_stat)
return normalizer
def get_all_actions(self) -> torch.Tensor:
return torch.from_numpy(self.replay_buffer['action'])
def __len__(self):
return len(self.sampler)
def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
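        """Return one `horizon`-length window of observations and actions as torch tensors."""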
data = self.sampler.sample_sequence(idx)
torch_data = dict_apply(data, torch.from_numpy)
return torch_data
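
# Legacy normalizer: scale every dimension by the same 1/max(|value|) factor with
# zero offset, mapping the data into [-1, 1] without shifting it.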
def normalizer_from_stat(stat):
max_abs = np.maximum(stat['max'].max(), np.abs(stat['min']).max())
scale = np.full_like(stat['max'], fill_value=1/max_abs)
offset = np.zeros_like(stat['max'])
return SingleFieldLinearNormalizer.create_manual(
scale=scale,
offset=offset,
input_stats_dict=stat
)
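
# Concatenate the selected low-dimensional observation keys into one vector per
# step; for absolute actions, convert the axis-angle rotation to the target
# representation (handling both single-arm (7-D) and dual-arm (14-D) raw actions).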
def _data_to_obs(raw_obs, raw_actions, obs_keys, abs_action, rotation_transformer):
obs = np.concatenate([
raw_obs[key] for key in obs_keys
], axis=-1).astype(np.float32)
if abs_action:
is_dual_arm = False
if raw_actions.shape[-1] == 14:
# dual arm
raw_actions = raw_actions.reshape(-1,2,7)
is_dual_arm = True
pos = raw_actions[...,:3]
rot = raw_actions[...,3:6]
gripper = raw_actions[...,6:]
rot = rotation_transformer.forward(rot)
raw_actions = np.concatenate([
pos, rot, gripper
], axis=-1).astype(np.float32)
if is_dual_arm:
raw_actions = raw_actions.reshape(-1,20)
data = {
'obs': obs,
'action': raw_actions
}
return data
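
# Minimal usage sketch (illustrative only; the hdf5 path and window parameters
# below are placeholders, not values shipped with this module):
#
#   dataset = RobomimicReplayLowdimDataset(
#       dataset_path='path/to/low_dim_abs.hdf5',
#       horizon=16, pad_before=1, pad_after=7,
#       abs_action=True, n_demo=100)
#   normalizer = dataset.get_normalizer()
#   val_dataset = dataset.get_validation_dataset()
#   batch = dataset[0]  # dict of torch tensors: 'obs' (horizon, Do), 'action' (horizon, Da)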