|
from typing import Dict, List, Optional

import torch
import numpy as np

from equi_diffpo.common.pytorch_util import dict_apply
from equi_diffpo.dataset.base_dataset import LinearNormalizer
from equi_diffpo.dataset.robomimic_replay_lowdim_dataset import RobomimicReplayLowdimDataset, normalizer_from_stat
from equi_diffpo.common.normalize_util import robomimic_abs_action_only_symmetric_normalizer_from_stat
from equi_diffpo.common.normalize_util import (
    robomimic_abs_action_only_symmetric_normalizer_from_stat,
    get_identity_normalizer_from_stat,
    array_to_stats
)
|
|
|
|
|
class RobomimicReplayLowdimSymDataset(RobomimicReplayLowdimDataset):
    """Variant of ``RobomimicReplayLowdimDataset`` with its own normalizer.

    Construction is delegated entirely to the parent class; the only
    behavior defined here is :meth:`get_normalizer`, which uses
    ``robomimic_abs_action_only_symmetric_normalizer_from_stat`` for
    absolute actions.
    """

    def __init__(self,
            dataset_path: str,
            horizon=1,
            pad_before=0,
            pad_after=0,
            obs_keys: Optional[List[str]]=None,
            abs_action=False,
            rotation_rep='rotation_6d',
            use_legacy_normalizer=False,
            seed=42,
            val_ratio=0.0,
            max_train_episodes=None,
            n_demo=100
        ):
        """Forward every argument to the parent constructor.

        Args:
            obs_keys: Observation keys to pass to the parent. ``None``
                (the default) selects ``['object', 'robot0_eef_pos',
                'robot0_eef_quat', 'robot0_gripper_qpos']``.

        All remaining arguments are passed through untouched, positionally
        and in declaration order; their meaning is defined by
        ``RobomimicReplayLowdimDataset.__init__``.
        """
        # Build the default list per call instead of using a list literal in
        # the signature, which would be one object shared by every instance.
        if obs_keys is None:
            obs_keys = [
                'object',
                'robot0_eef_pos',
                'robot0_eef_quat',
                'robot0_gripper_qpos',
            ]

        super().__init__(
            dataset_path,
            horizon,
            pad_before,
            pad_after,
            obs_keys,
            abs_action,
            rotation_rep,
            use_legacy_normalizer,
            seed,
            val_ratio,
            max_train_episodes,
            n_demo,
        )

    def get_normalizer(self, **kwargs) -> LinearNormalizer:
        """Build a ``LinearNormalizer`` with ``'action'`` and ``'obs'`` entries.

        The action entry is chosen as follows:

        * ``abs_action`` false: ``get_identity_normalizer_from_stat``.
        * ``abs_action`` true: ``robomimic_abs_action_only_symmetric_normalizer_from_stat``,
          replaced by ``normalizer_from_stat`` when ``use_legacy_normalizer``
          is set.

        The obs entry always comes from ``normalizer_from_stat`` applied to
        the statistics of ``self.replay_buffer['obs']``.

        Args:
            **kwargs: Accepted but not used by this implementation.

        Returns:
            The populated ``LinearNormalizer``.

        Raises:
            NotImplementedError: If ``abs_action`` is set and the last
                dimension of the action mean is greater than 10. This is
                checked before ``use_legacy_normalizer`` is consulted.
        """
        normalizer = LinearNormalizer()

        action_stat = array_to_stats(self.replay_buffer['action'])
        if self.abs_action:
            action_dim = action_stat['mean'].shape[-1]
            if action_dim > 10:
                raise NotImplementedError(
                    'abs_action normalization is not implemented for '
                    f'actions with more than 10 dimensions (got {action_dim})'
                )
            action_normalizer = robomimic_abs_action_only_symmetric_normalizer_from_stat(action_stat)

            # The legacy option overrides the symmetric normalizer computed above.
            if self.use_legacy_normalizer:
                action_normalizer = normalizer_from_stat(action_stat)
        else:
            action_normalizer = get_identity_normalizer_from_stat(action_stat)
        normalizer['action'] = action_normalizer

        obs_stat = array_to_stats(self.replay_buffer['obs'])
        normalizer['obs'] = normalizer_from_stat(obs_stat)
        return normalizer
|
|