File size: 3,208 Bytes
1501ed7 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 |
from equi_diffpo.dataset.base_dataset import LinearNormalizer
from equi_diffpo.model.common.normalizer import LinearNormalizer
from equi_diffpo.dataset.robomimic_replay_image_dataset import RobomimicReplayImageDataset, normalizer_from_stat
from equi_diffpo.common.normalize_util import (
robomimic_abs_action_only_symmetric_normalizer_from_stat,
get_range_normalizer_from_stat,
get_range_symmetric_normalizer_from_stat,
get_image_range_normalizer,
get_identity_normalizer_from_stat,
array_to_stats
)
import numpy as np
class RobomimicReplayImageSymDataset(RobomimicReplayImageDataset):
def __init__(self,
shape_meta: dict,
dataset_path: str,
horizon=1,
pad_before=0,
pad_after=0,
n_obs_steps=None,
abs_action=False,
rotation_rep='rotation_6d', # ignored when abs_action=False
use_legacy_normalizer=False,
use_cache=False,
seed=42,
val_ratio=0.0,
n_demo=100
):
super().__init__(
shape_meta,
dataset_path,
horizon,
pad_before,
pad_after,
n_obs_steps,
abs_action,
rotation_rep,
use_legacy_normalizer,
use_cache,
seed,
val_ratio,
n_demo
)
def get_normalizer(self, **kwargs) -> LinearNormalizer:
normalizer = LinearNormalizer()
# action
stat = array_to_stats(self.replay_buffer['action'])
if self.abs_action:
if stat['mean'].shape[-1] > 10:
# dual arm
raise NotImplementedError
else:
this_normalizer = robomimic_abs_action_only_symmetric_normalizer_from_stat(stat)
if self.use_legacy_normalizer:
this_normalizer = normalizer_from_stat(stat)
else:
# already normalized
this_normalizer = get_identity_normalizer_from_stat(stat)
normalizer['action'] = this_normalizer
# obs
for key in self.lowdim_keys:
stat = array_to_stats(self.replay_buffer[key])
if key.endswith('qpos'):
this_normalizer = get_range_normalizer_from_stat(stat)
elif key.endswith('pos'):
this_normalizer = get_range_symmetric_normalizer_from_stat(stat)
elif key.endswith('quat'):
# quaternion is in [-1,1] already
this_normalizer = get_identity_normalizer_from_stat(stat)
elif key.find('bbox') > -1:
this_normalizer = get_identity_normalizer_from_stat(stat)
else:
raise RuntimeError('unsupported')
normalizer[key] = this_normalizer
# image
for key in self.rgb_keys:
normalizer[key] = get_image_range_normalizer()
normalizer['pos_vecs'] = get_identity_normalizer_from_stat({'min': -1 * np.ones([10, 2], np.float32), 'max': np.ones([10, 2], np.float32)})
normalizer['crops'] = get_image_range_normalizer()
return normalizer
|