Skip to content

Commit

Permalink
ENH: implement RecordVideoV0 (openai#246)
Browse files Browse the repository at this point in the history
  • Loading branch information
younik authored Jan 17, 2023
1 parent dcb150e commit 019a593
Show file tree
Hide file tree
Showing 3 changed files with 362 additions and 4 deletions.
199 changes: 196 additions & 3 deletions gymnasium/experimental/wrappers/rendering.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,14 @@
"""
from __future__ import annotations

import os
from copy import deepcopy
from typing import Any, SupportsFloat
from typing import Any, Callable, List, SupportsFloat

import numpy as np

import gymnasium as gym
from gymnasium import error, logger
from gymnasium.core import ActType, ObsType, RenderFrame, WrapperActType, WrapperObsType
from gymnasium.error import DependencyNotInstalled

Expand Down Expand Up @@ -79,9 +81,200 @@ def render(self) -> RenderFrame | list[RenderFrame] | None:


class RecordVideoV0(gym.Wrapper):
    """Record videos of environment rollouts.

    Usually, you only want to record episodes intermittently, say every hundredth episode.
    To do this, you can specify ``episode_trigger`` or ``step_trigger``.
    They should be functions returning a boolean that indicates whether a recording should be started at the
    current episode or step, respectively.

    If neither ``episode_trigger`` nor ``step_trigger`` is passed, a default ``episode_trigger`` is employed,
    i.e. ``capped_cubic_video_schedule``. This function starts a video at every episode whose index is a
    perfect cube (0, 1, 8, 27, ...) below 1000, and then at every multiple of 1000.

    By default, the recording will be stopped once reset is called. However, you can also create recordings of fixed
    length (possibly spanning several episodes) by passing a strictly positive value for ``video_length``.

    This wrapper uses the value ``render_fps`` from the metadata as the number of frames per second;
    if ``render_fps`` is not defined in the metadata, the default value 30 is used.
    """

    def __init__(
        self,
        env: gym.Env,
        video_folder: str,
        episode_trigger: Callable[[int], bool] | None = None,
        step_trigger: Callable[[int], bool] | None = None,
        video_length: int = 0,
        name_prefix: str = "rl-video",
        disable_logger: bool = False,
    ):
        """Wrapper records videos of rollouts.

        Args:
            env: The environment that will be wrapped
            video_folder (str): The folder where the recordings will be stored
            episode_trigger: Function that accepts an integer and returns ``True`` iff a recording should be started at this episode
            step_trigger: Function that accepts an integer and returns ``True`` iff a recording should be started at this step
            video_length (int): The length of recorded episodes. If 0, entire episodes are recorded.
                Otherwise, snippets of the specified length are captured
            name_prefix (str): Will be prepended to the filename of the recordings
            disable_logger (bool): Whether to disable moviepy logger or not

        Raises:
            DependencyNotInstalled: If ``moviepy`` cannot be imported.
            ValueError: If the render mode of ``env`` is ``None``, ``"human"`` or ``"ansi"``.
        """
        super().__init__(env)
        try:
            import moviepy  # noqa: F401
        except ImportError as e:
            raise error.DependencyNotInstalled(
                "MoviePy is not installed, run `pip install moviepy`"
            ) from e

        if env.render_mode in {None, "human", "ansi"}:
            # The two halves are joined by implicit concatenation so the exception carries one message string.
            raise ValueError(
                f"Render mode is {env.render_mode}, which is incompatible with RecordVideo. "
                "Initialize your environment with a render_mode that returns an image, such as rgb_array."
            )

        if episode_trigger is None and step_trigger is None:

            def capped_cubic_video_schedule(episode_id: int) -> bool:
                if episode_id < 1000:
                    # True iff episode_id is a perfect cube.
                    return int(round(episode_id ** (1.0 / 3))) ** 3 == episode_id
                else:
                    return episode_id % 1000 == 0

            episode_trigger = capped_cubic_video_schedule

        self.episode_trigger = episode_trigger
        self.step_trigger = step_trigger
        self.disable_logger = disable_logger

        self.video_folder = os.path.abspath(video_folder)
        if os.path.isdir(self.video_folder):
            logger.warn(
                f"Overwriting existing videos at {self.video_folder} folder "
                f"(try specifying a different `video_folder` for the `RecordVideo` wrapper if this is not desired)"
            )
        os.makedirs(self.video_folder, exist_ok=True)

        self.name_prefix = name_prefix
        self._video_name = None
        self.frames_per_sec = self.metadata.get("render_fps", 30)
        # A length of 0 means "no cap"; infinity keeps the length comparisons below uniform.
        self.video_length = video_length if video_length != 0 else float("inf")
        self.recording = False
        # Frames that will be written to the current video file.
        self.recorded_frames = []
        # Frames taken from a list-returning render() by this wrapper; handed back on the next render() call.
        self.render_history = []

        # Both counters start at -1 so the first step/reset is numbered 0.
        self.step_id = -1
        self.episode_id = -1

    def _capture_frame(self) -> None:
        """Render the wrapped environment and add the newest frame to the current recording.

        If render returns a list, all of its frames are stashed in ``render_history`` and only
        the last one is recorded. If the resulting frame is not a numpy array, the recording is
        stopped and a warning is logged.
        """
        assert self.recording, "Cannot capture a frame, recording wasn't started."

        frame = self.env.render()
        if isinstance(frame, list):
            if len(frame) == 0:  # render was called
                return
            self.render_history += frame
            frame = frame[-1]

        if isinstance(frame, np.ndarray):
            self.recorded_frames.append(frame)
        else:
            self.stop_recording()
            # Single message string: a second positional argument would be handed to the logger
            # as a separate argument instead of being part of the text.
            logger.warn(
                "Recording stopped: expected type of frame returned by render "
                f"to be a numpy array, got instead {type(frame)}."
            )

    def reset(
        self, *, seed: int | None = None, options: dict[str, Any] | None = None
    ) -> tuple[WrapperObsType, dict[str, Any]]:
        """Reset the environment and eventually starts a new recording."""
        obs, info = super().reset(seed=seed, options=options)
        self.episode_id += 1

        # Recordings without a length cap end at the episode boundary.
        if self.recording and self.video_length == float("inf"):
            self.stop_recording()

        if self.episode_trigger and self.episode_trigger(self.episode_id):
            self.start_recording(f"{self.name_prefix}-episode-{self.episode_id}")
        if self.recording:
            self._capture_frame()
            if len(self.recorded_frames) > self.video_length:
                self.stop_recording()

        return obs, info

    def step(
        self, action: WrapperActType
    ) -> tuple[WrapperObsType, SupportsFloat, bool, bool, dict[str, Any]]:
        """Steps through the environment using action, recording observations if :attr:`self.recording`."""
        obs, rew, terminated, truncated, info = self.env.step(action)
        self.step_id += 1

        if self.step_trigger and self.step_trigger(self.step_id):
            self.start_recording(f"{self.name_prefix}-step-{self.step_id}")
        if self.recording:
            self._capture_frame()

            if len(self.recorded_frames) > self.video_length:
                self.stop_recording()

        return obs, rew, terminated, truncated, info

    def start_recording(self, video_name: str) -> None:
        """Start a new recording. If it is already recording, stops the current recording before starting the new one.

        Args:
            video_name: File name (without the ``.mp4`` extension) used when the recording is saved.
        """
        if self.recording:
            self.stop_recording()

        self.recording = True
        self._video_name = video_name

    def stop_recording(self) -> None:
        """Stop current recording and saves the video.

        The recorded frames are written to ``<video_folder>/<video_name>.mp4``; when no frame was
        recorded, nothing is written and a warning is logged instead.

        Raises:
            DependencyNotInstalled: If ``moviepy`` cannot be imported.
        """
        assert self.recording, "stop_recording was called, but no recording was started"

        if len(self.recorded_frames) == 0:
            logger.warn("Ignored saving a video as there were zero frames to save.")
        else:
            try:
                from moviepy.video.io.ImageSequenceClip import ImageSequenceClip
            except ImportError as e:
                raise error.DependencyNotInstalled(
                    "MoviePy is not installed, run `pip install moviepy`"
                ) from e

            clip = ImageSequenceClip(self.recorded_frames, fps=self.frames_per_sec)
            moviepy_logger = None if self.disable_logger else "bar"
            path = os.path.join(self.video_folder, f"{self._video_name}.mp4")
            clip.write_videofile(path, logger=moviepy_logger)

        self.recorded_frames = []
        self.recording = False
        self._video_name = None

    def render(self) -> RenderFrame | list[RenderFrame] | None:
        """Compute the render frames as specified by render_mode attribute during initialization of the environment.

        Frames this wrapper stashed in ``render_history`` while capturing are prepended to the
        output, and the history is cleared.
        """
        render_out = super().render()
        if self.recording and isinstance(render_out, list):
            self.recorded_frames += render_out

        if len(self.render_history) > 0:
            tmp_history = self.render_history
            self.render_history = []
            return tmp_history + render_out
        else:
            return render_out

    def close(self) -> None:
        """Closes the wrapper then the video recorder."""
        super().close()
        if self.recording:
            self.stop_recording()

    def __del__(self):
        """Warn the user in case last video wasn't saved."""
        # __del__ also runs when __init__ raised before `recorded_frames` was assigned, so read
        # the instance dict directly instead of assuming the attribute exists.
        if self.__dict__.get("recorded_frames"):
            logger.warn("Unable to save last video! Did you call close()?")


class HumanRenderingV0(gym.Wrapper):
Expand Down
2 changes: 1 addition & 1 deletion gymnasium/wrappers/record_video.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
def capped_cubic_video_schedule(episode_id: int) -> bool:
"""The default episode trigger.
This function will trigger recordings at the episode indices 0, 1, 4, 8, 27, ..., :math:`k^3`, ..., 729, 1000, 2000, 3000, ...
This function will trigger recordings at the episode indices 0, 1, 8, 27, ..., :math:`k^3`, ..., 729, 1000, 2000, 3000, ...
Args:
episode_id: The episode number
Expand Down
165 changes: 165 additions & 0 deletions tests/experimental/wrappers/test_record_video.py
Original file line number Diff line number Diff line change
@@ -1 +1,166 @@
"""Test suite for RecordVideoV0."""
import os
import shutil
from typing import List

import gymnasium as gym
from gymnasium.experimental.wrappers import RecordVideoV0


def test_record_video_using_default_trigger():
    """Test RecordVideo using the default episode trigger."""
    try:
        env = gym.make("CartPole-v1", render_mode="rgb_array_list")
        env = RecordVideoV0(env, "videos")
        env.reset()
        episode_count = 0
        for _ in range(199):
            action = env.action_space.sample()
            _, _, terminated, truncated, _ = env.step(action)
            if terminated or truncated:
                env.reset()
                episode_count += 1

        env.close()
        assert os.path.isdir("videos")
        mp4_files = [file for file in os.listdir("videos") if file.endswith(".mp4")]
        # One video is expected for every episode index 0..episode_count accepted by the trigger.
        assert len(mp4_files) == sum(
            env.episode_trigger(i) for i in range(episode_count + 1)
        )
    finally:
        # Clean up even when an assertion fails, so leftover videos cannot skew later tests.
        shutil.rmtree("videos", ignore_errors=True)


def test_record_video_while_rendering():
    """Test RecordVideo while calling render and using a _list render mode."""
    try:
        env = gym.make("FrozenLake-v1", render_mode="rgb_array_list")
        env = RecordVideoV0(env, "videos")
        env.reset()
        episode_count = 0
        for _ in range(199):
            action = env.action_space.sample()
            _, _, terminated, truncated, _ = env.step(action)
            # Calling render between steps must not disturb the recording.
            env.render()
            if terminated or truncated:
                env.reset()
                episode_count += 1

        env.close()
        assert os.path.isdir("videos")
        mp4_files = [file for file in os.listdir("videos") if file.endswith(".mp4")]
        # One video is expected for every episode index 0..episode_count accepted by the trigger.
        assert len(mp4_files) == sum(
            env.episode_trigger(i) for i in range(episode_count + 1)
        )
    finally:
        # Clean up even when an assertion fails, so leftover videos cannot skew later tests.
        shutil.rmtree("videos", ignore_errors=True)


def test_record_video_step_trigger():
    """Check that a step trigger firing every 100 steps yields two videos over 199 steps."""
    base_env = gym.make(
        "CartPole-v1", render_mode="rgb_array", disable_env_checker=True
    )
    base_env._max_episode_steps = 20
    env = RecordVideoV0(base_env, "videos", step_trigger=lambda x: x % 100 == 0)

    env.reset()
    for _ in range(199):
        _, _, terminated, truncated, _ = env.step(env.action_space.sample())
        if terminated or truncated:
            env.reset()
    env.close()

    assert os.path.isdir("videos")
    recorded = [name for name in os.listdir("videos") if name.endswith(".mp4")]
    # The folder is removed before the final assertion, so it is gone even when the count is wrong.
    shutil.rmtree("videos")
    assert len(recorded) == 2


def test_record_video_both_trigger():
    """Check that step and episode triggers can be combined (three videos expected)."""
    base_env = gym.make(
        "CartPole-v1", render_mode="rgb_array_list", disable_env_checker=True
    )
    base_env._max_episode_steps = 20
    env = RecordVideoV0(
        base_env,
        "videos",
        step_trigger=lambda x: x == 100,
        episode_trigger=lambda x: x == 0 or x == 3,
    )

    env.reset()
    for _ in range(199):
        _, _, terminated, truncated, _ = env.step(env.action_space.sample())
        if terminated or truncated:
            env.reset()
    env.close()

    assert os.path.isdir("videos")
    recorded = [name for name in os.listdir("videos") if name.endswith(".mp4")]
    # The folder is removed before the final assertion, so it is gone even when the count is wrong.
    shutil.rmtree("videos")
    assert len(recorded) == 3


def test_record_video_length():
    """Test if argument video_length of RecordVideo works properly."""
    try:
        env = gym.make("CartPole-v1", render_mode="rgb_array_list")
        env._max_episode_steps = 20
        env = RecordVideoV0(
            env, "videos", step_trigger=lambda x: x == 0, video_length=10
        )
        env.reset()
        for _ in range(10):
            action = env.action_space.sample()
            env.step(action)

        # After 10 steps the recording is still running; the next step ends it.
        assert env.recording
        action = env.action_space.sample()
        env.step(action)
        assert not env.recording
        env.close()
        assert os.path.isdir("videos")
        mp4_files = [file for file in os.listdir("videos") if file.endswith(".mp4")]
        assert len(mp4_files) == 1
    finally:
        # Clean up even when an assertion fails, so leftover videos cannot skew later tests.
        shutil.rmtree("videos", ignore_errors=True)


def test_rendering_works():
    """Test if render output is as expected when the env is wrapped with RecordVideo."""
    try:
        env = gym.make("CartPole-v1", render_mode="rgb_array_list")
        env._max_episode_steps = 20
        env = RecordVideoV0(env, "videos")
        env.reset()
        n_steps = 10
        for _ in range(n_steps):
            action = env.action_space.sample()
            env.step(action)

        # One frame from reset plus one per step must come back, even though the wrapper
        # has been rendering internally while recording.
        render_out = env.render()
        assert isinstance(render_out, list)
        assert len(render_out) == n_steps + 1
        # A second render without any step in between has nothing new to return.
        render_out = env.render()
        assert isinstance(render_out, list)
        assert len(render_out) == 0
        env.close()
    finally:
        # Clean up even when an assertion fails, so leftover videos cannot skew later tests.
        shutil.rmtree("videos", ignore_errors=True)


def make_env(gym_id, idx, **kwargs):
    """Return a factory that builds ``gym_id``, wrapped with RecordVideoV0 only when ``idx`` is 0.

    Args:
        gym_id: Environment id passed to ``gym.make``.
        idx: Index of the environment; only index 0 gets the video wrapper.
        **kwargs: Extra keyword arguments forwarded to ``gym.make``.

    Returns:
        A zero-argument callable that creates the environment when called.
    """

    def thunk():
        created = gym.make(gym_id, disable_env_checker=True, **kwargs)
        created._max_episode_steps = 20
        if idx != 0:
            return created
        return RecordVideoV0(created, "videos", step_trigger=lambda x: x % 100 == 0)

    return thunk


def test_record_video_within_vector():
    """Test RecordVideo used as env of SyncVectorEnv."""
    envs = None
    try:
        # Only the sub-environment with index 0 is wrapped with the video recorder.
        envs = gym.vector.SyncVectorEnv(
            [make_env("CartPole-v1", i, render_mode="rgb_array") for i in range(2)]
        )
        envs = gym.wrappers.RecordEpisodeStatistics(envs)
        envs.reset()
        for _ in range(199):
            envs.step(envs.action_space.sample())

        assert os.path.isdir("videos")
        mp4_files = [file for file in os.listdir("videos") if file.endswith(".mp4")]
        assert len(mp4_files) == 2
    finally:
        # Close the environments, then remove the output folder, whether or not the assertions
        # passed, so later tests start clean.
        try:
            if envs is not None:
                envs.close()
        finally:
            shutil.rmtree("videos", ignore_errors=True)

0 comments on commit 019a593

Please sign in to comment.