File size: 5,439 Bytes
1501ed7 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 |
from typing import List, Optional
from matplotlib.pyplot import fill
import numpy as np
import gym
from gym import spaces
from omegaconf import OmegaConf
from robomimic.envs.env_robosuite import EnvRobosuite
class RobomimicImageWrapper(gym.Env):
    """Gym-style adapter around a robomimic ``EnvRobosuite``.

    ``action_space`` and ``observation_space`` are built from ``shape_meta``.
    Observations returned by :meth:`reset` / :meth:`step` are filtered down to
    the keys declared in ``shape_meta['obs']``. The raw value stored under
    ``render_obs_key`` is cached on every observation so that :meth:`render`
    can return it without querying the simulator again.
    """

    # Ordered (suffix, (low, high)) pairs; the first suffix an observation key
    # ends with decides that key's Box bounds.
    _OBS_BOUNDS = (
        ('image', (0, 1)),
        ('depth', (0, 1)),
        ('voxels', (0, 1)),
        ('point_cloud', (-10, 10)),
        ('quat', (-1, 1)),
        ('qpos', (-1, 1)),
        # better range?
        ('pos', (-1, 1)),
    )

    def __init__(self,
            env: EnvRobosuite,
            shape_meta: dict,
            init_state: Optional[np.ndarray]=None,
            render_obs_key='agentview_image',
        ):
        # wrapped simulator and user configuration
        self.env = env
        self.shape_meta = shape_meta
        self.render_obs_key = render_obs_key
        self.init_state = init_state

        # seed -> flattened sim state captured right after a seeded reset
        self.seed_state_map = dict()
        # seed to apply on the next reset(); consumed by that reset
        self._seed = None
        # last raw value of render_obs_key, consumed by render()
        self.render_cache = None
        # whether a full env.reset() has already happened (init_state mode)
        self.has_reset_before = False

        self.action_space = spaces.Box(
            low=-1,
            high=1,
            shape=shape_meta['action']['shape'],
            dtype=np.float32,
        )

        # Populate item by item (rather than passing a dict to the
        # constructor) so keys keep the order given in shape_meta.
        obs_space = spaces.Dict()
        for obs_key, obs_meta in shape_meta['obs'].items():
            obs_shape = obs_meta['shape']
            low, high = self._bounds_for_key(obs_key)
            obs_space[obs_key] = spaces.Box(
                low=low,
                high=high,
                shape=obs_shape,
                dtype=np.float32,
            )
        self.observation_space = obs_space

    @classmethod
    def _bounds_for_key(cls, key):
        """Return the ``(low, high)`` bounds for observation ``key``.

        Raises:
            RuntimeError: if ``key`` ends with none of the known suffixes.
        """
        for suffix, bounds in cls._OBS_BOUNDS:
            if key.endswith(suffix):
                return bounds
        raise RuntimeError(f"Unsupported type {key}")

    def get_observation(self, raw_obs=None):
        """Filter ``raw_obs`` to the declared observation keys.

        If ``raw_obs`` is None it is fetched from the wrapped env. As a side
        effect, ``raw_obs[self.render_obs_key]`` is cached for :meth:`render`.
        """
        if raw_obs is None:
            raw_obs = self.env.get_observation()
        self.render_cache = raw_obs[self.render_obs_key]
        return {key: raw_obs[key] for key in self.observation_space.keys()}

    def seed(self, seed=None):
        """Seed numpy's global RNG and remember ``seed`` for the next reset."""
        np.random.seed(seed=seed)
        self._seed = seed

    def reset(self):
        """Reset the env and return the filtered observation.

        Priority: a fixed ``init_state`` if one was given, otherwise the seed
        set by :meth:`seed` (used once), otherwise an unseeded reset.
        """
        if self.init_state is not None:
            raw_obs = self._reset_to_init_state()
        elif self._seed is not None:
            raw_obs = self._reset_with_seed(self._seed)
            # a seed only governs the first reset after seed()
            self._seed = None
        else:
            # random reset
            raw_obs = self.env.reset()
        return self.get_observation(raw_obs)

    def _reset_to_init_state(self):
        """Restore ``init_state``, doing one full env reset the first time."""
        if not self.has_reset_before:
            # the env must be fully reset at least once to ensure correct rendering
            self.env.reset()
            self.has_reset_before = True
        # always reset to the same state to be compatible with gym
        return self.env.reset_to({'states': self.init_state})

    def _reset_with_seed(self, seed):
        """Reset for ``seed``, replaying a cached sim state when available."""
        if seed in self.seed_state_map:
            # env.reset is expensive, use cache
            return self.env.reset_to({'states': self.seed_state_map[seed]})
        # robosuite's initializers all use numpy's global random state
        np.random.seed(seed=seed)
        raw_obs = self.env.reset()
        self.seed_state_map[seed] = self.env.get_state()['states']
        return raw_obs

    def step(self, action):
        """Apply ``action``; return ``(obs, reward, done, info)`` with filtered obs."""
        raw_obs, reward, done, info = self.env.step(action)
        return self.get_observation(raw_obs), reward, done, info

    def render(self, mode='rgb_array'):
        """Return the cached render observation as a ``uint8`` image.

        ``mode`` is accepted for gym compatibility but not used. Axis 0 of the
        cached array is moved to the end and values are scaled by 255. This
        presumably converts a CHW float image in [0, 1] to HWC uint8; TODO
        confirm against the env's image format.

        Raises:
            RuntimeError: if neither reset() nor step() has been called yet.
        """
        frame = self.render_cache
        if frame is None:
            raise RuntimeError('Must run reset or step before render.')
        channels_last = np.moveaxis(frame, 0, -1)
        return (channels_last * 255).astype(np.uint8)
def test(
        cfg_path='~/dev/diffusion_policy/diffusion_policy/config/task/lift_image.yaml',
        dataset_path='~/dev/diffusion_policy/data/robomimic/datasets/square/ph/image.hdf5',
        check_determinism=False,
    ):
    """Manual smoke test: build the wrapper from a dataset, reset it, plot a frame.

    Args:
        cfg_path: task config file providing ``shape_meta``; ``~`` is expanded.
        dataset_path: robomimic hdf5 dataset whose env metadata is used to
            create the environment; ``~`` is expanded.
        check_determinism: if True, additionally reset twice with seed 0 and
            verify both resets produce the same simulator state.

    Returns:
        The ``uint8`` image returned by ``wrapper.render()`` after the first reset.

    Raises:
        RuntimeError: if ``check_determinism`` is set and the two seeded
            resets yield different states.
    """
    import os
    import robomimic.utils.file_utils as FileUtils
    import robomimic.utils.env_utils as EnvUtils
    from matplotlib import pyplot as plt

    # NOTE(review): the default config is the *lift* task but the default
    # dataset is *square*. This presumably only works if both tasks share the
    # same shape_meta keys; confirm which task was intended.
    cfg = OmegaConf.load(os.path.expanduser(cfg_path))
    shape_meta = cfg['shape_meta']

    env_meta = FileUtils.get_env_metadata_from_dataset(
        os.path.expanduser(dataset_path))
    env = EnvUtils.create_env_from_metadata(
        env_meta=env_meta,
        render=False,
        render_offscreen=False,
        use_image_obs=True,
    )
    wrapper = RobomimicImageWrapper(
        env=env,
        shape_meta=shape_meta
    )

    wrapper.seed(0)
    wrapper.reset()
    img = wrapper.render()
    plt.imshow(img)

    if check_determinism:
        # Re-seeding with the same value must reproduce the same sim state.
        states = []
        for _ in range(2):
            wrapper.seed(0)
            wrapper.reset()
            states.append(wrapper.env.get_state()['states'])
        if not np.allclose(states[0], states[1]):
            raise RuntimeError('Seeded resets produced different simulator states.')

    return img
|