File size: 4,980 Bytes
c1f1d32 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 |
import gym
from gym import spaces
import numpy as np
from collections import defaultdict, deque
import dill
def stack_repeated(x, n):
    """Return ``x`` copied ``n`` times along a new leading axis.

    The result has shape ``(n,) + np.shape(x)``.
    """
    # Insert a length-1 axis in front, then repeat along that axis.
    with_leading_axis = np.expand_dims(x, axis=0)
    return with_leading_axis.repeat(n, axis=0)
def repeated_box(box_space, n):
    """Build a ``spaces.Box`` whose bounds are ``box_space``'s bounds stacked ``n`` times.

    The new space is created with shape ``(n,) + box_space.shape`` and the
    same dtype as ``box_space``.
    """
    stacked_low = stack_repeated(box_space.low, n)
    stacked_high = stack_repeated(box_space.high, n)
    stacked_shape = (n,) + box_space.shape
    return spaces.Box(
        low=stacked_low,
        high=stacked_high,
        shape=stacked_shape,
        dtype=box_space.dtype,
    )
def repeated_space(space, n):
    """Return a copy of ``space`` with every Box repeated ``n`` times.

    ``spaces.Box`` is handled by ``repeated_box``; ``spaces.Dict`` is handled
    by recursing into each entry. Any other space type raises RuntimeError.
    """
    if isinstance(space, spaces.Box):
        return repeated_box(space, n)

    if not isinstance(space, spaces.Dict):
        raise RuntimeError(f'Unsupported space type {type(space)}')

    # Fill an empty Dict entry by entry so keys keep the source iteration order.
    repeated = spaces.Dict()
    for name, subspace in space.items():
        repeated[name] = repeated_space(subspace, n)
    return repeated
def take_last_n(x, n):
    """Return the last ``n`` items of ``x`` as a NumPy array.

    Args:
        x: Any iterable; it is materialized into a list first.
        n: Maximum number of trailing items to keep. If ``x`` holds fewer
            than ``n`` items, all of them are returned. ``n <= 0`` yields an
            empty array.

    Returns:
        ``np.ndarray`` built from the selected trailing items.
    """
    items = list(x)
    count = max(0, min(len(items), n))
    # Slice from an explicit start index: ``items[-0:]`` would select the
    # whole list instead of nothing when ``count`` is zero.
    return np.array(items[len(items) - count:])
def dict_take_last_n(x, n):
    """Apply ``take_last_n`` to every value of mapping ``x``.

    Returns a new plain ``dict`` with the same keys as ``x``.
    """
    return {name: take_last_n(history, n) for name, history in x.items()}
def aggregate(data, method='max'):
    """Reduce ``data`` to a single value with the named NumPy reduction.

    Args:
        data: Array-like passed directly to the NumPy reduction.
        method: One of ``'max'``, ``'min'``, ``'mean'`` or ``'sum'``.
            On boolean data ``'max'`` is equivalent to ``any`` and ``'min'``
            is equivalent to ``all``.

    Returns:
        The result of ``np.max`` / ``np.min`` / ``np.mean`` / ``np.sum``.

    Raises:
        NotImplementedError: If ``method`` is not one of the supported names.
    """
    reducers = {
        'max': np.max,    # equivalent to any
        'min': np.min,    # equivalent to all
        'mean': np.mean,
        'sum': np.sum,
    }
    try:
        reducer = reducers[method]
    except (KeyError, TypeError):
        # TypeError covers unhashable ``method`` values, which the original
        # equality chain also rejected with NotImplementedError.
        raise NotImplementedError(
            f'Unsupported aggregation method {method!r}; '
            f'expected one of {sorted(reducers)}'
        ) from None
    return reducer(data)
def stack_last_n_obs(all_obs, n_steps):
    """Stack the last ``n_steps`` observations into one array, padding at the front.

    Args:
        all_obs: Non-empty iterable of equally shaped observations
            (arrays or array-likes), oldest first.
        n_steps: Number of rows in the output.

    Returns:
        Array of shape ``(n_steps,) + obs_shape`` with the dtype of the most
        recent observation. When fewer than ``n_steps`` observations are
        available, the leading rows are filled with copies of the oldest
        available observation.

    Raises:
        AssertionError: If ``all_obs`` is empty.
    """
    # Materialize first so iterators/generators are accepted as well.
    all_obs = list(all_obs)
    assert len(all_obs) > 0, 'all_obs must contain at least one observation'

    # The newest observation defines the per-step shape and dtype.
    template = np.asarray(all_obs[-1])
    result = np.zeros((n_steps,) + template.shape, dtype=template.dtype)

    n_avail = min(n_steps, len(all_obs))
    if n_avail > 0:
        # Use explicit non-negative start indices: a negated count of zero
        # would turn ``[-0:]`` into "everything" instead of "nothing".
        first_real = n_steps - n_avail
        result[first_real:] = np.array(all_obs[len(all_obs) - n_avail:])
        if first_real > 0:
            # pad
            result[:first_real] = result[first_real]
    return result
class MultiStepWrapper(gym.Wrapper):
    """Gym wrapper that executes a sequence of actions per ``step`` call and
    returns a stacked history of recent observations.

    The advertised action and observation spaces are the wrapped
    environment's spaces repeated ``n_action_steps`` and ``n_obs_steps``
    times along a new leading axis (see ``repeated_space``).

    Rewards and done flags are accumulated in ``self.reward`` / ``self.done``
    from one ``reset`` to the next; they are not cleared between ``step``
    calls, so the values returned by ``step`` aggregate over every
    environment step taken since the last reset.
    """

    def __init__(self,
            env,
            n_obs_steps,
            n_action_steps,
            max_episode_steps=None,
            reward_agg_method='max'
        ):
        """
        Args:
            env: Environment to wrap; its action/observation spaces must be
                ``spaces.Box`` or ``spaces.Dict`` (of supported spaces).
            n_obs_steps: Number of most recent observations to stack.
            n_action_steps: Leading dimension used for the wrapped action
                space. ``step`` itself iterates over whatever it is given.
            max_episode_steps: If not None, ``done`` is forced to True once
                this many environment steps were taken since the last reset.
            reward_agg_method: Method name passed to ``aggregate`` to reduce
                the accumulated rewards.
        """
        super().__init__(env)
        self._action_space = repeated_space(env.action_space, n_action_steps)
        self._observation_space = repeated_space(env.observation_space, n_obs_steps)
        self.max_episode_steps = max_episode_steps
        self.n_obs_steps = n_obs_steps
        self.n_action_steps = n_action_steps
        self.reward_agg_method = reward_agg_method

        # Rolling buffers keep one entry more than is returned to the caller.
        self.obs = deque(maxlen=n_obs_steps+1)
        self.reward = list()
        self.done = list()
        self.info = defaultdict(lambda : deque(maxlen=n_obs_steps+1))

    def reset(self, **kwargs):
        """Reset the environment and clear all per-episode buffers.

        Keyword arguments are forwarded to ``gym.Wrapper.reset``.

        Returns:
            The initial observation repeated to fill ``n_obs_steps`` rows
            (a dict of such arrays for Dict observation spaces).
        """
        obs = super().reset(**kwargs)

        self.obs = deque([obs], maxlen=self.n_obs_steps+1)
        self.reward = list()
        self.done = list()
        self.info = defaultdict(lambda : deque(maxlen=self.n_obs_steps+1))

        obs = self._get_obs(self.n_obs_steps)
        return obs

    def step(self, action):
        """
        actions: (n_action_steps,) + action_shape

        Applies each action in order, stopping early once the episode is
        done (terminated by the environment or truncated by
        ``max_episode_steps``).

        Returns:
            Tuple ``(observation, reward, done, info)`` where ``observation``
            is the last ``n_obs_steps`` observations stacked, ``reward`` is
            ``self.reward`` reduced with ``reward_agg_method``, ``done`` is
            the maximum (i.e. any) of ``self.done``, and ``info`` maps each
            info key to an array of its last ``n_obs_steps`` values.
        """
        for act in action:
            if len(self.done) > 0 and self.done[-1]:
                # termination
                break
            observation, reward, done, info = super().step(act)

            self.obs.append(observation)
            self.reward.append(reward)
            if (self.max_episode_steps is not None) \
                and (len(self.reward) >= self.max_episode_steps):
                # truncation
                done = True
            self.done.append(done)
            self._add_info(info)

        observation = self._get_obs(self.n_obs_steps)
        reward = aggregate(self.reward, self.reward_agg_method)
        done = aggregate(self.done, 'max')
        info = dict_take_last_n(self.info, self.n_obs_steps)
        return observation, reward, done, info

    def _get_obs(self, n_steps=1):
        """
        Output (n_steps,) + obs_shape

        For Dict observation spaces the output is a dict with one stacked
        array per key of the observation space.
        """
        assert(len(self.obs) > 0)
        if isinstance(self.observation_space, spaces.Box):
            return stack_last_n_obs(self.obs, n_steps)
        elif isinstance(self.observation_space, spaces.Dict):
            result = dict()
            for key in self.observation_space.keys():
                result[key] = stack_last_n_obs(
                    [obs[key] for obs in self.obs],
                    n_steps
                )
            return result
        else:
            raise RuntimeError('Unsupported space type')

    def _add_info(self, info):
        """Append every value of ``info`` to its per-key rolling buffer."""
        for key, value in info.items():
            self.info[key].append(value)

    def get_rewards(self):
        """Return the list of rewards collected since the last reset."""
        return self.reward

    def get_attr(self, name):
        """Return the attribute ``name`` of this wrapper via ``getattr``."""
        return getattr(self, name)

    def run_dill_function(self, dill_fn):
        """Deserialize a dill-pickled callable and call it with this wrapper.

        Args:
            dill_fn: Bytes produced by ``dill.dumps`` of a one-argument callable.

        Returns:
            Whatever the deserialized callable returns.
        """
        # SECURITY NOTE: ``dill.loads`` executes arbitrary code during
        # deserialization. Only pass payloads from trusted sources.
        fn = dill.loads(dill_fn)
        return fn(self)

    def get_infos(self):
        """Return the buffered info values as a dict of plain lists."""
        result = dict()
        for k, v in self.info.items():
            result[k] = list(v)
        return result
|