import gym
import numpy as np


class VideoWrapper(gym.Wrapper):
    """Gym wrapper that accumulates rendered frames of the wrapped env.

    A frame is captured on every ``reset`` and, afterwards, on every
    ``step`` for which the internal step counter is a multiple of
    ``steps_per_render``. The captured frames are kept in ``self.frames``
    and are returned by ``render``.

    Args:
        env: Environment to wrap.
        mode: Render mode forwarded to ``env.render`` when capturing.
        enabled: When false, no frames are captured.
        steps_per_render: Capture a frame whenever the step counter is a
            multiple of this value. It is used as a modulus in ``step``,
            so it must be non-zero while ``enabled`` is true.
        **kwargs: Extra keyword arguments forwarded to ``env.render`` on
            every capture.
    """

    def __init__(self, env, mode='rgb_array', enabled=True,
                 steps_per_render=1, **kwargs):
        super().__init__(env)

        self.mode = mode
        self.enabled = enabled
        self.render_kwargs = kwargs
        self.steps_per_render = steps_per_render

        # Frames captured since the last reset.
        self.frames = []
        # Set to 1 by reset() and incremented by every step().
        self.step_count = 0

    def _capture_frame(self):
        """Render the wrapped env once and append the frame to ``frames``.

        Raises:
            AssertionError: If the rendered frame's dtype is not uint8.
        """
        frame = self.env.render(mode=self.mode, **self.render_kwargs)
        # Explicit raise instead of an ``assert`` statement so the check is
        # not stripped when Python runs with -O; the exception type is kept.
        if not (frame.dtype == np.uint8):
            raise AssertionError(
                f"expected a uint8 frame from env.render, got {frame.dtype}")
        self.frames.append(frame)

    def reset(self, **kwargs):
        """Reset the wrapped env, discard old frames and capture a new one.

        Args:
            **kwargs: Forwarded to the wrapped environment's ``reset``.

        Returns:
            Whatever the wrapped environment's ``reset`` returns.
        """
        obs = super().reset(**kwargs)
        # A new list is bound (not cleared in place), so a list previously
        # returned by render() is left untouched.
        self.frames = []
        # The counter restarts at 1, not 0; step() increments it before
        # testing it against steps_per_render.
        self.step_count = 1
        if self.enabled:
            self._capture_frame()
        return obs

    def step(self, action):
        """Step the wrapped env and capture a frame when one is due.

        Args:
            action: Forwarded to the wrapped environment's ``step``.

        Returns:
            Whatever the wrapped environment's ``step`` returns.
        """
        result = super().step(action)
        self.step_count += 1
        if self.enabled and (self.step_count % self.steps_per_render) == 0:
            self._capture_frame()
        return result

    def render(self, mode='rgb_array', **kwargs):
        """Return the list of frames captured since the last reset.

        ``mode`` and ``**kwargs`` are accepted but not used. The returned
        object is the internal list itself, not a copy.
        """
        return self.frames