from typing import List, Optional from matplotlib.pyplot import fill import numpy as np import gym from gym import spaces from omegaconf import OmegaConf from robomimic.envs.env_robosuite import EnvRobosuite class RobomimicImageWrapper(gym.Env): def __init__(self, env: EnvRobosuite, shape_meta: dict, init_state: Optional[np.ndarray]=None, render_obs_key='agentview_image', ): self.env = env self.render_obs_key = render_obs_key self.init_state = init_state self.seed_state_map = dict() self._seed = None self.shape_meta = shape_meta self.render_cache = None self.has_reset_before = False # setup spaces action_shape = shape_meta['action']['shape'] action_space = spaces.Box( low=-1, high=1, shape=action_shape, dtype=np.float32 ) self.action_space = action_space observation_space = spaces.Dict() for key, value in shape_meta['obs'].items(): shape = value['shape'] min_value, max_value = -1, 1 if key.endswith('image'): min_value, max_value = 0, 1 elif key.endswith('depth'): min_value, max_value = 0, 1 elif key.endswith('voxels'): min_value, max_value = 0, 1 elif key.endswith('point_cloud'): min_value, max_value = -10, 10 elif key.endswith('quat'): min_value, max_value = -1, 1 elif key.endswith('qpos'): min_value, max_value = -1, 1 elif key.endswith('pos'): # better range? min_value, max_value = -1, 1 else: raise RuntimeError(f"Unsupported type {key}") this_space = spaces.Box( low=min_value, high=max_value, shape=shape, dtype=np.float32 ) observation_space[key] = this_space self.observation_space = observation_space def get_observation(self, raw_obs=None): if raw_obs is None: raw_obs = self.env.get_observation() self.render_cache = raw_obs[self.render_obs_key] obs = dict() for key in self.observation_space.keys(): obs[key] = raw_obs[key] return obs def seed(self, seed=None): np.random.seed(seed=seed) self._seed = seed def reset(self): if self.init_state is not None: if not self.has_reset_before: # the env must be fully reset at least once to ensure correct rendering self.env.reset() self.has_reset_before = True # always reset to the same state # to be compatible with gym raw_obs = self.env.reset_to({'states': self.init_state}) elif self._seed is not None: # reset to a specific seed seed = self._seed if seed in self.seed_state_map: # env.reset is expensive, use cache raw_obs = self.env.reset_to({'states': self.seed_state_map[seed]}) else: # robosuite's initializes all use numpy global random state np.random.seed(seed=seed) raw_obs = self.env.reset() state = self.env.get_state()['states'] self.seed_state_map[seed] = state self._seed = None else: # random reset raw_obs = self.env.reset() # return obs obs = self.get_observation(raw_obs) return obs def step(self, action): raw_obs, reward, done, info = self.env.step(action) obs = self.get_observation(raw_obs) return obs, reward, done, info def render(self, mode='rgb_array'): if self.render_cache is None: raise RuntimeError('Must run reset or step before render.') img = np.moveaxis(self.render_cache, 0, -1) img = (img * 255).astype(np.uint8) return img def test(): import os from omegaconf import OmegaConf cfg_path = os.path.expanduser('~/dev/diffusion_policy/diffusion_policy/config/task/lift_image.yaml') cfg = OmegaConf.load(cfg_path) shape_meta = cfg['shape_meta'] import robomimic.utils.file_utils as FileUtils import robomimic.utils.env_utils as EnvUtils from matplotlib import pyplot as plt dataset_path = os.path.expanduser('~/dev/diffusion_policy/data/robomimic/datasets/square/ph/image.hdf5') env_meta = FileUtils.get_env_metadata_from_dataset( dataset_path) env = EnvUtils.create_env_from_metadata( env_meta=env_meta, render=False, render_offscreen=False, use_image_obs=True, ) wrapper = RobomimicImageWrapper( env=env, shape_meta=shape_meta ) wrapper.seed(0) obs = wrapper.reset() img = wrapper.render() plt.imshow(img) # states = list() # for _ in range(2): # wrapper.seed(0) # wrapper.reset() # states.append(wrapper.env.get_state()['states']) # assert np.allclose(states[0], states[1]) # img = wrapper.render() # plt.imshow(img) # wrapper.seed() # states.append(wrapper.env.get_state()['states'])