File size: 1,261 Bytes
c1f1d32 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 |
import gym
import numpy as np
class VideoWrapper(gym.Wrapper):
    """Gym wrapper that collects rendered frames while an episode runs.

    A frame is requested from the wrapped environment on every ``reset`` and
    on every ``step`` whose running step count is a multiple of
    ``steps_per_render``. The collected frames are returned by ``render``.
    """

    def __init__(self,
            env,
            mode='rgb_array',
            enabled=True,
            steps_per_render=1,
            **kwargs
        ):
        """Wrap ``env`` and set up frame recording.

        Args:
            env: Environment to wrap; its ``render`` is called to get frames.
            mode: Value passed as ``mode=`` to ``env.render``.
            enabled: When false, no frames are rendered or stored.
            steps_per_render: A frame is captured on steps where the step
                count is divisible by this value.
            **kwargs: Extra keyword arguments forwarded to ``env.render``
                on every capture.
        """
        super().__init__(env)

        self.mode = mode
        self.enabled = enabled
        self.render_kwargs = kwargs
        self.steps_per_render = steps_per_render

        # Frames captured since the last reset.
        self.frames = list()
        self.step_count = 0

    def _capture_frame(self):
        """Render one frame from the wrapped env and append it to ``frames``.

        Raises:
            AssertionError: If the rendered frame's dtype is not ``np.uint8``.
        """
        frame = self.env.render(mode=self.mode, **self.render_kwargs)
        # Explicit check instead of a bare ``assert`` so the validation is
        # not stripped when Python runs with -O; the exception type is kept
        # so existing callers observe the same failure.
        if frame.dtype != np.uint8:
            raise AssertionError(
                f"rendered frame must have dtype uint8, got {frame.dtype}")
        self.frames.append(frame)

    def reset(self, **kwargs):
        """Reset the wrapped env, discard old frames, and record a first frame.

        Args:
            **kwargs: Forwarded to the parent ``reset``.

        Returns:
            Whatever the parent ``reset`` returns.
        """
        obs = super().reset(**kwargs)
        self.frames = list()
        # The counter restarts at 1, so the first ``step`` after a reset
        # sees a step count of 2.
        self.step_count = 1
        if self.enabled:
            self._capture_frame()
        return obs

    def step(self, action):
        """Step the wrapped env and record a frame when one is due.

        Args:
            action: Forwarded to the parent ``step``.

        Returns:
            Whatever the parent ``step`` returns, unchanged.
        """
        result = super().step(action)
        self.step_count += 1
        if self.enabled and ((self.step_count % self.steps_per_render) == 0):
            self._capture_frame()
        return result

    def render(self, mode='rgb_array', **kwargs):
        """Return the list of frames recorded since the last reset.

        ``mode`` and ``**kwargs`` are accepted but not used; nothing is
        rendered here. The returned list is the wrapper's own ``frames``
        list, not a copy.
        """
        return self.frames
|