|
from equi_diffpo.dataset.base_dataset import LinearNormalizer |
|
from equi_diffpo.model.common.normalizer import LinearNormalizer |
|
from equi_diffpo.dataset.robomimic_replay_image_dataset import RobomimicReplayImageDataset, normalizer_from_stat |
|
from equi_diffpo.common.normalize_util import ( |
|
robomimic_abs_action_only_symmetric_normalizer_from_stat, |
|
get_range_normalizer_from_stat, |
|
get_range_symmetric_normalizer_from_stat, |
|
get_image_range_normalizer, |
|
get_identity_normalizer_from_stat, |
|
array_to_stats |
|
) |
|
import numpy as np |
|
|
|
class RobomimicReplayImageSymDataset(RobomimicReplayImageDataset):
    """Robomimic replay image dataset with its own normalizer construction.

    Construction is delegated unchanged to ``RobomimicReplayImageDataset``;
    only ``get_normalizer`` is overridden here.
    """

    def __init__(
        self,
        shape_meta: dict,
        dataset_path: str,
        horizon=1,
        pad_before=0,
        pad_after=0,
        n_obs_steps=None,
        abs_action=False,
        rotation_rep='rotation_6d',
        use_legacy_normalizer=False,
        use_cache=False,
        seed=42,
        val_ratio=0.0,
        n_demo=100
    ):
        """Forward every argument, positionally and unmodified, to the parent."""
        super().__init__(
            shape_meta, dataset_path, horizon, pad_before, pad_after,
            n_obs_steps, abs_action, rotation_rep, use_legacy_normalizer,
            use_cache, seed, val_ratio, n_demo
        )

    def get_normalizer(self, **kwargs) -> LinearNormalizer:
        """Assemble a ``LinearNormalizer`` for actions and observations.

        Entries are registered in this order: ``'action'``, each key of
        ``self.lowdim_keys``, each key of ``self.rgb_keys``, and finally the
        fixed keys ``'pos_vecs'`` and ``'crops'``.

        ``**kwargs`` is accepted but not used.

        Raises:
            NotImplementedError: ``self.abs_action`` is set and the action
                statistics have more than 10 entries on their last axis.
            RuntimeError: a low-dim key does not end in ``'qpos'``, ``'pos'``
                or ``'quat'`` and does not contain ``'bbox'``.
        """
        result = LinearNormalizer()

        # ---- action ----
        action_stat = array_to_stats(self.replay_buffer['action'])
        if not self.abs_action:
            # NOTE(review): presumably relative actions are already scaled,
            # hence the identity normalizer -- confirm against the data.
            action_norm = get_identity_normalizer_from_stat(action_stat)
        else:
            if action_stat['mean'].shape[-1] > 10:
                # NOTE(review): looks like the wide (dual-arm?) action layout;
                # no symmetric normalizer is provided for it here.
                raise NotImplementedError
            action_norm = robomimic_abs_action_only_symmetric_normalizer_from_stat(action_stat)
            # The legacy normalizer, when requested, replaces the one above.
            if self.use_legacy_normalizer:
                action_norm = normalizer_from_stat(action_stat)
        result['action'] = action_norm

        # ---- low-dimensional observations ----
        for name in self.lowdim_keys:
            # Statistics are computed before the key is validated.
            key_stat = array_to_stats(self.replay_buffer[name])

            # 'qpos' must be tested before 'pos': every '...qpos' key also
            # ends with 'pos'.
            if name.endswith('qpos'):
                build = get_range_normalizer_from_stat
            elif name.endswith('pos'):
                build = get_range_symmetric_normalizer_from_stat
            elif name.endswith('quat') or 'bbox' in name:
                # NOTE(review): presumably these are already in a bounded
                # range, so they pass through unchanged -- confirm.
                build = get_identity_normalizer_from_stat
            else:
                raise RuntimeError('unsupported')
            result[name] = build(key_stat)

        # ---- image observations ----
        for name in self.rgb_keys:
            result[name] = get_image_range_normalizer()

        # ---- fixed extra keys ----
        # 'pos_vecs' gets an identity normalizer built from [-1, 1] bounds of
        # shape (10, 2); 'crops' is normalized like the image keys.
        unit = np.ones([10, 2], np.float32)
        result['pos_vecs'] = get_identity_normalizer_from_stat({'min': -1 * unit, 'max': unit})
        result['crops'] = get_image_range_normalizer()

        return result
|
|
|
|