File size: 4,045 Bytes
1501ed7 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 |
from typing import List, Dict, Optional
import numpy as np
import gym
from gym.spaces import Box
from robomimic.envs.env_robosuite import EnvRobosuite
class RobomimicLowdimWrapper(gym.Env):
    """Gym adapter around a robomimic ``EnvRobosuite`` with flat low-dim observations.

    Each observation is the concatenation (along axis 0) of the raw
    observation entries named in ``obs_keys``, in that order.

    Reset behavior (see ``reset``):
      * if ``init_state`` is given, every reset restores exactly that state;
      * otherwise, if ``seed`` was called, the next reset is reproducible for
        that seed and the resulting simulator state is cached per seed;
      * otherwise the wrapped env performs an ordinary reset.
    """

    # Used when the caller does not pass ``obs_keys``. A tuple so it can never
    # be mutated and leak between instances.
    DEFAULT_OBS_KEYS = (
        'object',
        'robot0_eef_pos',
        'robot0_eef_quat',
        'robot0_gripper_qpos',
    )

    def __init__(self,
            env: EnvRobosuite,
            obs_keys: Optional[List[str]] = None,
            init_state: Optional[np.ndarray] = None,
            render_hw=(256, 256),
            render_camera_name='agentview'
            ):
        """Wrap ``env`` and build the gym action/observation spaces.

        Args:
            env: robomimic robosuite environment to wrap.
            obs_keys: raw observation keys to concatenate into the flat
                observation. ``None`` selects ``DEFAULT_OBS_KEYS``. The
                sequence is copied, so later mutation by the caller has no effect.
            init_state: if given, a flattened simulator state that every
                ``reset()`` restores via ``env.reset_to``.
            render_hw: ``(height, width)`` passed to ``env.render``.
            render_camera_name: camera name passed to ``env.render``.

        Raises:
            ValueError: if ``obs_keys`` is empty (nothing to concatenate).
        """
        if obs_keys is None:
            obs_keys = self.DEFAULT_OBS_KEYS
        # Copy so instances never share (or alias the caller's) key list.
        obs_keys = list(obs_keys)
        if not obs_keys:
            raise ValueError('obs_keys must contain at least one observation key')

        self.env = env
        self.obs_keys = obs_keys
        self.init_state = init_state
        self.render_hw = render_hw
        self.render_camera_name = render_camera_name
        # seed -> flattened simulator state captured right after a seeded reset.
        self.seed_state_map = dict()
        # Seed to apply on the next reset(); cleared once consumed.
        self._seed = None

        # Action space: [-1, 1] per action dimension. The dtype must be a float.
        # np.full(..., fill_value=-1) would infer int64 and produce an
        # integer-valued Box, which is wrong for continuous actions.
        act_low = np.full(env.action_dimension, fill_value=-1.0, dtype=np.float32)
        act_high = np.full(env.action_dimension, fill_value=1.0, dtype=np.float32)
        self.action_space = Box(
            low=act_low,
            high=act_high,
            shape=act_low.shape,
            dtype=act_low.dtype
        )

        # Observation space: shape/dtype taken from a live observation.
        # NOTE(review): the [-1, 1] bounds are nominal placeholders. Observations
        # are not clipped, so real values may lie outside them.
        obs_example = self.get_observation()
        obs_low = np.full_like(obs_example, fill_value=-1)
        obs_high = np.full_like(obs_example, fill_value=1)
        self.observation_space = Box(
            low=obs_low,
            high=obs_high,
            shape=obs_low.shape,
            dtype=obs_low.dtype
        )

    def _flatten_obs(self, raw_obs):
        """Concatenate ``raw_obs[key]`` for each key in ``self.obs_keys`` along axis 0."""
        return np.concatenate([raw_obs[key] for key in self.obs_keys], axis=0)

    def get_observation(self):
        """Return the current flat observation of the wrapped env."""
        return self._flatten_obs(self.env.get_observation())

    def seed(self, seed=None):
        """Seed numpy's global RNG and schedule a seeded reset.

        The seed is consumed by the next ``reset()`` unless ``init_state`` is
        set, in which case ``reset()`` ignores it.

        Returns:
            ``[seed]``, following the gym ``Env.seed`` convention.
        """
        np.random.seed(seed=seed)
        self._seed = seed
        return [seed]

    def reset(self):
        """Reset the wrapped env and return the flat observation."""
        if self.init_state is not None:
            # Always reset to the same state to be compatible with gym.
            self.env.reset_to({'states': self.init_state})
        elif self._seed is not None:
            # Reset to a specific seed.
            seed = self._seed
            if seed in self.seed_state_map:
                # env.reset is expensive; restore the cached state instead.
                self.env.reset_to({'states': self.seed_state_map[seed]})
            else:
                # Robosuite's initializers use numpy's global random state.
                np.random.seed(seed=seed)
                self.env.reset()
                self.seed_state_map[seed] = self.env.get_state()['states']
            # The seed applies to a single reset only.
            self._seed = None
        else:
            # Random reset.
            self.env.reset()

        return self.get_observation()

    def step(self, action):
        """Step the wrapped env and return ``(flat_obs, reward, done, info)``."""
        raw_obs, reward, done, info = self.env.step(action)
        return self._flatten_obs(raw_obs), reward, done, info

    def render(self, mode='rgb_array'):
        """Render from ``render_camera_name`` at ``render_hw`` via the wrapped env."""
        h, w = self.render_hw
        return self.env.render(mode=mode,
            height=h, width=w,
            camera_name=self.render_camera_name)
def test(dataset_path='/home/cchi/dev/diffusion_policy/data/robomimic/datasets/square/ph/low_dim.hdf5'):
    """Smoke-test seeded resets of ``RobomimicLowdimWrapper`` on a robomimic dataset.

    Builds a low-dim (non-image) env from the dataset's metadata and resets it
    twice with seed 0. It asserts that both resets yield the same simulator
    state; the second reset is served from the wrapper's per-seed cache.
    Finally it renders one frame and draws it with matplotlib.

    Args:
        dataset_path: path to a robomimic low-dim hdf5 dataset. The default is
            the original author's local path; override it on other machines.
    """
    import robomimic.utils.file_utils as FileUtils
    import robomimic.utils.env_utils as EnvUtils
    from matplotlib import pyplot as plt

    env_meta = FileUtils.get_env_metadata_from_dataset(
        dataset_path)
    env = EnvUtils.create_env_from_metadata(
        env_meta=env_meta,
        render=False,
        render_offscreen=False,
        use_image_obs=False,
    )
    wrapper = RobomimicLowdimWrapper(
        env=env,
        obs_keys=[
            'object',
            'robot0_eef_pos',
            'robot0_eef_quat',
            'robot0_gripper_qpos'
        ]
    )

    # Two resets with the same seed must land in the same simulator state.
    states = list()
    for _ in range(2):
        wrapper.seed(0)
        wrapper.reset()
        states.append(wrapper.env.get_state()['states'])
    assert np.allclose(states[0], states[1])

    img = wrapper.render()
    plt.imshow(img)
|