# Equidiff / equidiff / equi_diffpo / dataset / robomimic_replay_lowdim_sym_dataset.py
# Last commit by Lillianwei: "update" (1501ed7)
from typing import Dict, List, Optional

import numpy as np
import torch

from equi_diffpo.common.normalize_util import robomimic_abs_action_only_symmetric_normalizer_from_stat
from equi_diffpo.common.normalize_util import (
    robomimic_abs_action_only_symmetric_normalizer_from_stat,
    get_identity_normalizer_from_stat,
    array_to_stats,
)
from equi_diffpo.common.pytorch_util import dict_apply
from equi_diffpo.dataset.base_dataset import LinearNormalizer
from equi_diffpo.dataset.robomimic_replay_lowdim_dataset import RobomimicReplayLowdimDataset, normalizer_from_stat
class RobomimicReplayLowdimSymDataset(RobomimicReplayLowdimDataset):
    """Low-dim robomimic replay dataset with a symmetric absolute-action normalizer.

    Construction is delegated entirely to ``RobomimicReplayLowdimDataset``;
    only ``get_normalizer`` is overridden. For absolute actions it builds the
    action normalizer with
    ``robomimic_abs_action_only_symmetric_normalizer_from_stat``.
    """

    def __init__(self,
            dataset_path: str,
            horizon=1,
            pad_before=0,
            pad_after=0,
            obs_keys: Optional[List[str]] = None,
            abs_action=False,
            rotation_rep='rotation_6d',
            use_legacy_normalizer=False,
            seed=42,
            val_ratio=0.0,
            max_train_episodes=None,
            n_demo=100
        ):
        """Forward all construction arguments to the parent dataset.

        Args:
            dataset_path: dataset location, forwarded unchanged.
            horizon, pad_before, pad_after: forwarded unchanged.
            obs_keys: observation keys to load. ``None`` (the default) selects
                ``['object', 'robot0_eef_pos', 'robot0_eef_quat',
                'robot0_gripper_qpos']``.
            abs_action: forwarded; also read later by ``get_normalizer``.
            rotation_rep: forwarded unchanged.
            use_legacy_normalizer: forwarded; also read by ``get_normalizer``.
            seed, val_ratio, max_train_episodes, n_demo: forwarded unchanged.
        """
        if obs_keys is None:
            # A fresh list per call: a list literal as the default value would
            # be one object shared by every instance.
            obs_keys = [
                'object',
                'robot0_eef_pos',
                'robot0_eef_quat',
                'robot0_gripper_qpos',
            ]
        # NOTE(review): arguments are forwarded positionally, which assumes the
        # parent's parameter order matches this signature — confirm.
        super().__init__(
            dataset_path,
            horizon,
            pad_before,
            pad_after,
            obs_keys,
            abs_action,
            rotation_rep,
            use_legacy_normalizer,
            seed,
            val_ratio,
            max_train_episodes,
            n_demo,
        )

    def get_normalizer(self, **kwargs) -> LinearNormalizer:
        """Build a normalizer for the ``'action'`` and ``'obs'`` entries.

        Action normalizer selection:
          * ``abs_action`` with action dim > 10: unsupported (raises).
          * ``abs_action`` with ``use_legacy_normalizer``: ``normalizer_from_stat``.
          * ``abs_action`` otherwise: the symmetric absolute-action normalizer.
          * not ``abs_action``: identity normalizer.

        Observations always use ``normalizer_from_stat`` over the statistics of
        ``replay_buffer['obs']``.

        Args:
            **kwargs: accepted for interface compatibility; unused.

        Returns:
            A ``LinearNormalizer`` with ``'action'`` and ``'obs'`` set.

        Raises:
            NotImplementedError: for absolute actions wider than 10 dims.
        """
        normalizer = LinearNormalizer()

        # action
        action_stat = array_to_stats(self.replay_buffer['action'])
        if self.abs_action:
            action_dim = action_stat['mean'].shape[-1]
            if action_dim > 10:
                # Presumably dual-arm: a single-arm absolute action with
                # rotation_6d would be pos(3) + rot(6) + gripper(1) = 10 —
                # TODO confirm against the action layout.
                raise NotImplementedError(
                    f'RobomimicReplayLowdimSymDataset does not support absolute '
                    f'actions with dim {action_dim} > 10 (dual arm)')
            # Only build the normalizer that is actually used. The symmetric
            # one was previously computed and then discarded in legacy mode.
            if self.use_legacy_normalizer:
                action_normalizer = normalizer_from_stat(action_stat)
            else:
                action_normalizer = robomimic_abs_action_only_symmetric_normalizer_from_stat(action_stat)
        else:
            # Relative actions are treated as already normalized.
            action_normalizer = get_identity_normalizer_from_stat(action_stat)
        normalizer['action'] = action_normalizer

        # aggregate obs stats
        obs_stat = array_to_stats(self.replay_buffer['obs'])
        normalizer['obs'] = normalizer_from_stat(obs_stat)
        return normalizer