# Source: mimicgen @ c1f1d32 (Lillianwei)
import gym
import numpy as np
class VideoWrapper(gym.Wrapper):
    """Gym wrapper that records rendered frames over an episode.

    A frame is captured with ``env.render(mode=self.mode, **kwargs)`` on
    every ``reset`` and periodically during ``step``. Calling ``render`` on
    the wrapper returns the recorded frames, not a fresh render.

    Args:
        env: The environment to wrap.
        mode: Render mode forwarded to the wrapped env's ``render``.
        enabled: When False, no frames are captured and ``render`` returns
            an empty list.
        steps_per_render: Capture period in steps; must be >= 1.
        **kwargs: Extra keyword arguments forwarded to every
            ``env.render`` call.

    Raises:
        ValueError: If ``steps_per_render`` is smaller than 1.
    """

    def __init__(self,
                 env,
                 mode='rgb_array',
                 enabled=True,
                 steps_per_render=1,
                 **kwargs):
        if steps_per_render < 1:
            # A zero period would raise ZeroDivisionError inside ``step``;
            # a negative period is meaningless. Fail early instead.
            raise ValueError(
                f"steps_per_render must be >= 1, got {steps_per_render!r}")
        super().__init__(env)

        self.mode = mode
        self.enabled = enabled
        self.render_kwargs = kwargs
        self.steps_per_render = steps_per_render

        # Frames recorded since the last reset (or since construction).
        self.frames = list()
        self.step_count = 0

    def _capture_frame(self):
        """Render the wrapped env once and append the frame to ``frames``.

        Raises:
            TypeError: If the rendered frame is not a ``np.uint8`` array.
        """
        frame = self.env.render(mode=self.mode, **self.render_kwargs)
        # Explicit check rather than ``assert`` so it still runs under
        # ``python -O``. ``getattr`` turns a frame without a ``dtype``
        # (e.g. ``None`` from an env that cannot render) into a clear error
        # instead of an AttributeError.
        dtype = getattr(frame, 'dtype', None)
        if dtype != np.uint8:
            raise TypeError(
                "VideoWrapper expects uint8 frames from env.render(), got "
                f"{type(frame).__name__} with dtype {dtype!r}")
        self.frames.append(frame)

    def reset(self, **kwargs):
        """Reset the env, clear recorded frames and capture the first frame.

        Returns whatever the wrapped env's ``reset`` returns.
        """
        obs = super().reset(**kwargs)
        self.frames = list()
        # The reset frame counts as step 1 of the episode.
        self.step_count = 1
        if self.enabled:
            self._capture_frame()
        return obs

    def step(self, action):
        """Step the env, capturing a frame every ``steps_per_render`` steps.

        ``step_count`` is 1 right after ``reset``. Therefore a frame is
        captured after the n-th step of an episode when
        ``(n + 1) % steps_per_render == 0``.

        NOTE(review): with steps_per_render = k > 1, the first captured
        step is k - 1 and later captures follow every k steps. Confirm that
        this offset is intended.

        Returns whatever the wrapped env's ``step`` returns.
        """
        result = super().step(action)
        self.step_count += 1
        if self.enabled and self.step_count % self.steps_per_render == 0:
            self._capture_frame()
        return result

    def render(self, mode='rgb_array', **kwargs):
        """Return the frames recorded since the last ``reset``.

        ``mode`` and ``kwargs`` are accepted for signature compatibility
        with ``gym.Env.render`` but are ignored. The internal list is
        returned (not a copy), so mutating it changes the recording.
        """
        return self.frames