import gym
from gym import spaces
import numpy as np
from collections import defaultdict, deque
import dill


def stack_repeated(x, n):
    """Return ``x`` repeated ``n`` times along a new leading axis.

    The result has shape ``(n,) + np.shape(x)``.
    """
    return np.repeat(np.expand_dims(x, axis=0), n, axis=0)


def repeated_box(box_space, n):
    """Build a ``spaces.Box`` whose bounds are ``box_space``'s, stacked ``n`` times.

    The new space has shape ``(n,) + box_space.shape`` and keeps the dtype of
    ``box_space``.
    """
    return spaces.Box(
        low=stack_repeated(box_space.low, n),
        high=stack_repeated(box_space.high, n),
        shape=(n,) + box_space.shape,
        dtype=box_space.dtype,
    )


def repeated_space(space, n):
    """Return a copy of ``space`` with every ``Box`` leaf repeated ``n`` times.

    ``spaces.Box`` and (possibly nested) ``spaces.Dict`` are supported.

    Raises:
        RuntimeError: if ``space`` (or a nested sub-space) is of another type.
    """
    if isinstance(space, spaces.Box):
        return repeated_box(space, n)
    if isinstance(space, spaces.Dict):
        # Start from an empty Dict and fill it key by key, recursing into
        # each sub-space.
        result_space = spaces.Dict()
        for key, value in space.items():
            result_space[key] = repeated_space(value, n)
        return result_space
    raise RuntimeError(f'Unsupported space type {type(space)}')


def take_last_n(x, n):
    """Return the last ``n`` items of iterable ``x`` as a numpy array.

    If ``x`` holds fewer than ``n`` items, all of them are returned. A
    non-positive ``n`` yields an empty array.
    """
    items = list(x)
    # Clamp to [0, len(items)] and slice from a non-negative start index:
    # ``items[-0:]`` would be the whole list rather than an empty one.
    count = max(0, min(len(items), n))
    return np.array(items[len(items) - count:])


def dict_take_last_n(x, n):
    """Apply :func:`take_last_n` to every value of mapping ``x``.

    Returns a new ``dict`` with the same keys.
    """
    return {key: take_last_n(value, n) for key, value in x.items()}


def aggregate(data, method='max'):
    """Reduce ``data`` to a single value with the named numpy reduction.

    Args:
        data: array-like of values to reduce.
        method: one of ``'max'``, ``'min'``, ``'mean'`` or ``'sum'``.

    Raises:
        NotImplementedError: if ``method`` is not one of the names above.
    """
    if method == 'max':
        # equivalent to any() for boolean data
        return np.max(data)
    if method == 'min':
        # equivalent to all() for boolean data
        return np.min(data)
    if method == 'mean':
        return np.mean(data)
    if method == 'sum':
        return np.sum(data)
    raise NotImplementedError(f'Unsupported aggregation method: {method!r}')


def stack_last_n_obs(all_obs, n_steps):
    """Stack the last ``n_steps`` observations into one array.

    Args:
        all_obs: non-empty iterable of numpy arrays, oldest first. The shape
            and dtype of the output are taken from the last element.
        n_steps: number of rows in the output.

    Returns:
        Array of shape ``(n_steps,) + obs_shape``. When fewer than ``n_steps``
        observations are available, the leading rows are padded with a copy
        of the oldest available observation.
    """
    all_obs = list(all_obs)
    assert len(all_obs) > 0, 'all_obs must contain at least one observation'
    latest = all_obs[-1]
    result = np.zeros((n_steps,) + latest.shape, dtype=latest.dtype)

    n_avail = min(n_steps, len(all_obs))
    if n_avail <= 0:
        # Nothing to copy. Returning here also avoids the ``[-0:]`` slice,
        # which would select the whole history instead of nothing.
        return result

    first = n_steps - n_avail  # index of the first row backed by real data
    result[first:] = np.array(all_obs[-n_avail:])
    if first > 0:
        # pad the front with the oldest available observation
        result[:first] = result[first]
    return result


class MultiStepWrapper(gym.Wrapper):
    """Gym wrapper that runs several actions per ``step`` call.

    Each call to :meth:`step` iterates over a sequence of actions, stepping
    the wrapped environment once per action, and returns a stack of the last
    ``n_obs_steps`` observations. The wrapped environment's ``step`` is
    expected to return a 4-tuple ``(observation, reward, done, info)``.

    Buffers kept on the instance (all cleared by :meth:`reset`):
        obs: deque of the most recent ``n_obs_steps + 1`` observations.
        reward: list of every reward received since the last reset.
        done: list of every done flag recorded since the last reset.
        info: per-key deques of the most recent ``n_obs_steps + 1`` info values.
    """

    def __init__(self,
            env,
            n_obs_steps,
            n_action_steps,
            max_episode_steps=None,
            reward_agg_method='max'
        ):
        """
        Args:
            env: environment to wrap.
            n_obs_steps: number of observations stacked in each output.
            n_action_steps: leading dimension of the wrapper's action space.
            max_episode_steps: if not None, ``done`` is forced to True once
                this many environment steps have been taken since reset.
            reward_agg_method: method name passed to :func:`aggregate` to
                reduce the recorded rewards to the value returned by ``step``.
        """
        super().__init__(env)
        self._action_space = repeated_space(env.action_space, n_action_steps)
        self._observation_space = repeated_space(
            env.observation_space, n_obs_steps)
        self.max_episode_steps = max_episode_steps
        self.n_obs_steps = n_obs_steps
        self.n_action_steps = n_action_steps
        self.reward_agg_method = reward_agg_method

        self.obs = deque(maxlen=n_obs_steps+1)
        self.reward = list()
        self.done = list()
        self.info = defaultdict(lambda: deque(maxlen=n_obs_steps+1))

    def reset(self):
        """Reset the wrapped environment and clear all buffers.

        Returns the initial observation stacked to ``n_obs_steps`` rows (the
        single available observation is repeated as padding).
        """
        obs = super().reset()

        self.obs = deque([obs], maxlen=self.n_obs_steps+1)
        self.reward = list()
        self.done = list()
        self.info = defaultdict(lambda: deque(maxlen=self.n_obs_steps+1))

        obs = self._get_obs(self.n_obs_steps)
        return obs

    def step(self, action):
        """Step the environment once per action in ``action``.

        Args:
            action: sequence of actions, ``(n_action_steps,) + action_shape``.

        Returns:
            ``(observation, reward, done, info)`` where ``observation`` is
            the stack of the last ``n_obs_steps`` observations, ``reward`` is
            the aggregate (``reward_agg_method``) of all rewards recorded
            since the last reset, ``done`` is True if any recorded done flag
            is True, and ``info`` maps each info key to an array of its last
            ``n_obs_steps`` values.
        """
        for act in action:
            if self.done and self.done[-1]:
                # termination: do not step an environment that is finished
                break
            observation, reward, done, info = super().step(act)

            self.obs.append(observation)
            self.reward.append(reward)
            if (
                self.max_episode_steps is not None
                and len(self.reward) >= self.max_episode_steps
            ):
                # truncation
                done = True
            self.done.append(done)
            self._add_info(info)

        observation = self._get_obs(self.n_obs_steps)
        # NOTE: the reward and done buffers are only cleared in reset(), so
        # both aggregates cover everything recorded since the last reset.
        reward = aggregate(self.reward, self.reward_agg_method)
        done = aggregate(self.done, 'max')
        info = dict_take_last_n(self.info, self.n_obs_steps)
        return observation, reward, done, info

    def _get_obs(self, n_steps=1):
        """Return the last ``n_steps`` observations stacked on a new axis.

        Output is ``(n_steps,) + obs_shape`` for a ``Box`` observation space,
        or a dict of such arrays (one per key) for a ``Dict`` space.

        Raises:
            RuntimeError: if the observation space is neither Box nor Dict.
        """
        assert len(self.obs) > 0, 'no observation recorded; call reset() first'
        if isinstance(self.observation_space, spaces.Box):
            return stack_last_n_obs(self.obs, n_steps)
        if isinstance(self.observation_space, spaces.Dict):
            result = dict()
            for key in self.observation_space.keys():
                result[key] = stack_last_n_obs(
                    [obs[key] for obs in self.obs],
                    n_steps
                )
            return result
        raise RuntimeError('Unsupported space type')

    def _add_info(self, info):
        """Append each value of ``info`` to the buffer kept for its key."""
        for key, value in info.items():
            self.info[key].append(value)

    def get_rewards(self):
        """Return the list of rewards recorded since the last reset."""
        return self.reward

    def get_attr(self, name):
        """Return the attribute ``name`` looked up on this wrapper."""
        return getattr(self, name)

    def run_dill_function(self, dill_fn):
        """Deserialize ``dill_fn`` with dill and call the result on ``self``.

        SECURITY: deserializing with ``dill.loads`` can execute arbitrary
        code. Only pass payloads that come from a trusted source.
        """
        fn = dill.loads(dill_fn)
        return fn(self)

    def get_infos(self):
        """Return the buffered info values as a dict of plain lists."""
        return {key: list(values) for key, values in self.info.items()}