File size: 2,401 Bytes
1501ed7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
from typing import Dict, List, Optional

import torch
import numpy as np

from equi_diffpo.common.pytorch_util import dict_apply
from equi_diffpo.dataset.base_dataset import LinearNormalizer
from equi_diffpo.dataset.robomimic_replay_lowdim_dataset import RobomimicReplayLowdimDataset, normalizer_from_stat
from equi_diffpo.common.normalize_util import robomimic_abs_action_only_symmetric_normalizer_from_stat
from equi_diffpo.common.normalize_util import (
    robomimic_abs_action_only_symmetric_normalizer_from_stat,
    get_identity_normalizer_from_stat,
    array_to_stats
)


class RobomimicReplayLowdimSymDataset(RobomimicReplayLowdimDataset):
    """Low-dim robomimic replay dataset with a symmetric action normalizer.

    Construction is forwarded unchanged to ``RobomimicReplayLowdimDataset``;
    the only behavior this subclass overrides is ``get_normalizer``, which
    uses ``robomimic_abs_action_only_symmetric_normalizer_from_stat`` for
    absolute actions.
    """

    # Immutable template for the default observation keys. A fresh list is
    # built from it per instance so that instances never share (and can never
    # mutate) a single default list object.
    _DEFAULT_OBS_KEYS = (
        'object',
        'robot0_eef_pos',
        'robot0_eef_quat',
        'robot0_gripper_qpos',
    )

    def __init__(self,
            dataset_path: str,
            horizon=1,
            pad_before=0,
            pad_after=0,
            obs_keys: Optional[List[str]]=None,
            abs_action=False,
            rotation_rep='rotation_6d',
            use_legacy_normalizer=False,
            seed=42,
            val_ratio=0.0,
            max_train_episodes=None,
            n_demo=100
        ):
        """Forward all arguments to the parent dataset constructor.

        Args:
            dataset_path: Passed through to the parent constructor.
            horizon: Passed through to the parent constructor.
            pad_before: Passed through to the parent constructor.
            pad_after: Passed through to the parent constructor.
            obs_keys: Observation keys to use. ``None`` (the default) selects
                ``'object'``, ``'robot0_eef_pos'``, ``'robot0_eef_quat'`` and
                ``'robot0_gripper_qpos'``.
            abs_action: Passed through to the parent constructor; read back
                as ``self.abs_action`` in ``get_normalizer``.
            rotation_rep: Passed through to the parent constructor.
            use_legacy_normalizer: Passed through to the parent constructor;
                read back as ``self.use_legacy_normalizer`` in
                ``get_normalizer``.
            seed: Passed through to the parent constructor.
            val_ratio: Passed through to the parent constructor.
            max_train_episodes: Passed through to the parent constructor.
            n_demo: Passed through to the parent constructor.
        """
        if obs_keys is None:
            obs_keys = list(self._DEFAULT_OBS_KEYS)

        # Arguments are forwarded positionally, in the same order as before.
        # NOTE(review): this relies on the parent's parameter order matching
        # this signature -- confirm against RobomimicReplayLowdimDataset.
        super().__init__(
            dataset_path,
            horizon,
            pad_before,
            pad_after,
            obs_keys,
            abs_action,
            rotation_rep,
            use_legacy_normalizer,
            seed,
            val_ratio,
            max_train_episodes,
            n_demo,
        )

    def get_normalizer(self, **kwargs) -> LinearNormalizer:
        """Build a normalizer with ``'action'`` and ``'obs'`` entries.

        The action entry is chosen as follows:

        * ``abs_action`` false: identity normalizer (the inline comment in
          the original code states the actions are already normalized).
        * ``abs_action`` true and ``use_legacy_normalizer`` true:
          ``normalizer_from_stat``.
        * ``abs_action`` true otherwise:
          ``robomimic_abs_action_only_symmetric_normalizer_from_stat``.

        The obs entry always uses ``normalizer_from_stat``.

        Args:
            **kwargs: Accepted for interface compatibility; unused here.

        Returns:
            A ``LinearNormalizer`` with ``'action'`` and ``'obs'`` set.

        Raises:
            NotImplementedError: If ``abs_action`` is set and the last
                dimension of the action mean exceeds 10 (treated as a
                dual-arm action space). This is raised regardless of
                ``use_legacy_normalizer``.
        """
        normalizer = LinearNormalizer()

        # action
        # NOTE(review): self.replay_buffer, self.abs_action and
        # self.use_legacy_normalizer are presumably set by the parent
        # constructor -- confirm against RobomimicReplayLowdimDataset.
        action_stat = array_to_stats(self.replay_buffer['action'])
        if not self.abs_action:
            # already normalized
            action_normalizer = get_identity_normalizer_from_stat(action_stat)
        elif action_stat['mean'].shape[-1] > 10:
            # dual arm
            raise NotImplementedError(
                'symmetric action normalizer does not support dual-arm '
                'actions (action dim > 10)'
            )
        elif self.use_legacy_normalizer:
            # The legacy normalizer replaces the symmetric one, so the
            # symmetric normalizer is not built at all in this case.
            action_normalizer = normalizer_from_stat(action_stat)
        else:
            action_normalizer = (
                robomimic_abs_action_only_symmetric_normalizer_from_stat(
                    action_stat))
        normalizer['action'] = action_normalizer

        # aggregate obs stats
        obs_stat = array_to_stats(self.replay_buffer['obs'])
        normalizer['obs'] = normalizer_from_stat(obs_stat)
        return normalizer