File size: 3,208 Bytes
1501ed7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
from equi_diffpo.dataset.base_dataset import LinearNormalizer
from equi_diffpo.model.common.normalizer import LinearNormalizer
from equi_diffpo.dataset.robomimic_replay_image_dataset import RobomimicReplayImageDataset, normalizer_from_stat
from equi_diffpo.common.normalize_util import (
    robomimic_abs_action_only_symmetric_normalizer_from_stat,
    get_range_normalizer_from_stat,
    get_range_symmetric_normalizer_from_stat,
    get_image_range_normalizer,
    get_identity_normalizer_from_stat,
    array_to_stats
)
import numpy as np

class RobomimicReplayImageSymDataset(RobomimicReplayImageDataset):
    """Robomimic replay image dataset with its own normalizer selection.

    Construction is delegated unchanged to ``RobomimicReplayImageDataset``;
    only ``get_normalizer`` is overridden. The override picks the
    ``*_symmetric_*`` normalizer helpers for absolute actions and for low-dim
    keys ending in ``'pos'``, and registers two extra entries, ``'pos_vecs'``
    and ``'crops'``.
    """

    def __init__(self,
            shape_meta: dict,
            dataset_path: str,
            horizon=1,
            pad_before=0,
            pad_after=0,
            n_obs_steps=None,
            abs_action=False,
            rotation_rep='rotation_6d', # ignored when abs_action=False
            use_legacy_normalizer=False,
            use_cache=False,
            seed=42,
            val_ratio=0.0,
            n_demo=100
        ):
        """Forward every argument, positionally and unmodified, to the parent.

        This subclass adds no state of its own; see
        ``RobomimicReplayImageDataset.__init__`` for the meaning of each
        parameter.
        """
        super().__init__(
            shape_meta,
            dataset_path,
            horizon,
            pad_before,
            pad_after,
            n_obs_steps,
            abs_action,
            rotation_rep,
            use_legacy_normalizer,
            use_cache,
            seed,
            val_ratio,
            n_demo
        )

    def get_normalizer(self, **kwargs) -> LinearNormalizer:
        """Build the ``LinearNormalizer`` for actions, observations and images.

        Args:
            **kwargs: Accepted for interface compatibility; not used here.

        Returns:
            A ``LinearNormalizer`` holding one entry for ``'action'``, one per
            key in ``self.lowdim_keys`` and ``self.rgb_keys``, plus
            ``'pos_vecs'`` and ``'crops'``.

        Raises:
            NotImplementedError: If ``self.abs_action`` is set and the action
                statistics have more than 10 entries along the last axis
                (the dual-arm case).
            RuntimeError: If a low-dim key matches none of the supported
                name patterns.
        """
        normalizer = LinearNormalizer()

        # action
        stat = array_to_stats(self.replay_buffer['action'])
        if self.abs_action:
            action_dim = stat['mean'].shape[-1]
            if action_dim > 10:
                # dual arm: rejected up front, regardless of the legacy flag
                raise NotImplementedError(
                    f'abs_action normalizer is not implemented for dual-arm '
                    f'actions (action dim {action_dim} > 10)')
            # Build only the normalizer that is actually used instead of
            # computing the symmetric one and then overwriting it.
            if self.use_legacy_normalizer:
                this_normalizer = normalizer_from_stat(stat)
            else:
                this_normalizer = robomimic_abs_action_only_symmetric_normalizer_from_stat(stat)
        else:
            # already normalized
            this_normalizer = get_identity_normalizer_from_stat(stat)
        normalizer['action'] = this_normalizer

        # obs
        for key in self.lowdim_keys:
            stat = array_to_stats(self.replay_buffer[key])

            # 'qpos' must be tested before 'pos': every key ending in 'qpos'
            # also ends in 'pos'.
            if key.endswith('qpos'):
                this_normalizer = get_range_normalizer_from_stat(stat)
            elif key.endswith('pos'):
                this_normalizer = get_range_symmetric_normalizer_from_stat(stat)
            elif key.endswith('quat'):
                # quaternion is in [-1,1] already
                this_normalizer = get_identity_normalizer_from_stat(stat)
            elif 'bbox' in key:
                this_normalizer = get_identity_normalizer_from_stat(stat)
            else:
                raise RuntimeError(f'unsupported low-dim key: {key!r}')
            normalizer[key] = this_normalizer

        # image
        for key in self.rgb_keys:
            normalizer[key] = get_image_range_normalizer()

        # Extra entries that are not drawn from the replay buffer here.
        # NOTE(review): the (10, 2) shape is hard-coded; presumably it matches
        # the 'pos_vecs' arrays produced elsewhere in the pipeline - confirm.
        pos_vecs_shape = (10, 2)
        normalizer['pos_vecs'] = get_identity_normalizer_from_stat({
            'min': -1 * np.ones(pos_vecs_shape, np.float32),
            'max': np.ones(pos_vecs_shape, np.float32),
        })
        normalizer['crops'] = get_image_range_normalizer()

        return normalizer