Skip to content

Commit

Permalink
Improve RecordVideo wrapper testing (#797)
Browse files Browse the repository at this point in the history
  • Loading branch information
pseudo-rnd-thoughts authored Nov 28, 2023
1 parent 2f19f3e commit f4c302d
Show file tree
Hide file tree
Showing 2 changed files with 100 additions and 56 deletions.
11 changes: 5 additions & 6 deletions gymnasium/wrappers/rendering.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,7 +235,7 @@ def __init__(
video_length: int = 0,
name_prefix: str = "rl-video",
fps: int | None = None,
disable_logger: bool = False,
disable_logger: bool = True,
):
"""Wrapper records videos of rollouts.
Expand All @@ -247,9 +247,9 @@ def __init__(
video_length (int): The length of recorded episodes. If 0, entire episodes are recorded.
Otherwise, snippets of the specified length are captured
name_prefix (str): Will be prepended to the filename of the recordings
fps (int): The frame per second in the video. The default value is the one specified in the environment metadata.
If the environment metadata doesn't specify ``render_fps``, the value 30 is used.
disable_logger (bool): Whether to disable moviepy logger or not
fps (int): The frame per second in the video. Provides a custom video fps for environment, if ``None`` then
the environment metadata ``render_fps`` key is used if it exists, otherwise a default value of 30 is used.
disable_logger (bool): Whether to disable moviepy logger or not, default it is disabled
"""
gym.utils.RecordConstructorArgs.__init__(
self,
Expand Down Expand Up @@ -320,8 +320,7 @@ def _capture_frame(self):
else:
self.stop_recording()
logger.warn(
"Recording stopped: expected type of frame returned by render ",
f"to be a numpy array, got instead {type(frame)}.",
f"Recording stopped: expected type of frame returned by render to be a numpy array, got instead {type(frame)}."
)

def reset(
Expand Down
145 changes: 95 additions & 50 deletions tests/wrappers/test_record_video.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,48 +3,61 @@
import shutil
from typing import List

import numpy as np
import pytest

import gymnasium as gym
from gymnasium.wrappers import RecordVideo
from gymnasium.wrappers import RecordVideo, RenderCollection


def test_record_video_using_default_trigger():
"""Test RecordVideo using the default episode trigger."""
env = gym.make("CartPole-v1", render_mode="rgb_array_list")
env = RecordVideo(env, "videos")
env.reset()
episode_count = 0
for _ in range(199):
def test_video_folder_and_filenames(
video_folder="custom_video_folder", name_prefix="video-prefix"
):
env = gym.make("CartPole-v1", render_mode="rgb_array")
env = RecordVideo(
env,
video_folder=video_folder,
name_prefix=name_prefix,
episode_trigger=lambda x: x in [1, 4],
step_trigger=lambda x: x in [0, 25],
)

env.reset(seed=123)
env.action_space.seed(123)
for _ in range(100):
action = env.action_space.sample()
_, _, terminated, truncated, _ = env.step(action)
if terminated or truncated:
env.reset()
episode_count += 1

env.close()
assert os.path.isdir("videos")
mp4_files = [file for file in os.listdir("videos") if file.endswith(".mp4")]
assert env.episode_trigger is not None
assert len(mp4_files) == sum(
env.episode_trigger(i) for i in range(episode_count + 1)
)
shutil.rmtree("videos")

assert os.path.isdir(video_folder)
mp4_files = {file for file in os.listdir(video_folder) if file.endswith(".mp4")}
shutil.rmtree(video_folder)
assert mp4_files == {
"video-prefix-step-0.mp4", # step triggers
"video-prefix-step-25.mp4",
"video-prefix-episode-1.mp4", # episode triggers
"video-prefix-episode-4.mp4",
}


@pytest.mark.parametrize("episodic_trigger", [None, lambda x: x in [0, 3, 5, 10, 12]])
def test_episodic_trigger(episodic_trigger):
"""Test RecordVideo using the default episode trigger."""
env = gym.make("CartPole-v1", render_mode="rgb_array")
env = RecordVideo(env, "videos", episode_trigger=episodic_trigger)

def test_record_video_while_rendering():
"""Test RecordVideo while calling render and using a _list render mode."""
env = gym.make("FrozenLake-v1", render_mode="rgb_array_list")
env = RecordVideo(env, "videos")
env.reset()
episode_count = 0
for _ in range(199):
action = env.action_space.sample()
_, _, terminated, truncated, _ = env.step(action)
env.render()
if terminated or truncated:
env.reset()
episode_count += 1

env.close()

assert os.path.isdir("videos")
mp4_files = [file for file in os.listdir("videos") if file.endswith(".mp4")]
assert env.episode_trigger is not None
Expand All @@ -54,10 +67,9 @@ def test_record_video_while_rendering():
shutil.rmtree("videos")


def test_record_video_step_trigger():
def test_step_trigger():
"""Test RecordVideo defining step trigger function."""
env = gym.make("CartPole-v1", render_mode="rgb_array", disable_env_checker=True)
env._max_episode_steps = 20
env = gym.make("CartPole-v1", render_mode="rgb_array")
env = RecordVideo(env, "videos", step_trigger=lambda x: x % 100 == 0)
env.reset()
for _ in range(199):
Expand All @@ -72,69 +84,102 @@ def test_record_video_step_trigger():
assert len(mp4_files) == 2


def test_record_video_both_trigger():
def test_both_episodic_and_step_trigger():
"""Test RecordVideo defining both step and episode trigger functions."""
env = gym.make(
"CartPole-v1", render_mode="rgb_array_list", disable_env_checker=True
)
env._max_episode_steps = 20
env = gym.make("CartPole-v1", render_mode="rgb_array")
env = RecordVideo(
env,
"videos",
step_trigger=lambda x: x == 100,
episode_trigger=lambda x: x == 0 or x == 3,
)
env.reset()
for _ in range(199):
# episode reset time steps: 0, 18, 44, 55, 80, 103, 117, 143, 173, 191
# steps recorded: 0-18, 55-80, 100-103

env.reset(seed=123)
env.action_space.seed(123)
for i in range(199):
action = env.action_space.sample()
_, _, terminated, truncated, _ = env.step(action)
if terminated or truncated:
env.reset()
env.close()

assert os.path.isdir("videos")
mp4_files = [file for file in os.listdir("videos") if file.endswith(".mp4")]
shutil.rmtree("videos")
assert len(mp4_files) == 3


def test_record_video_length():
def test_video_length(video_length: int = 10):
"""Test if argument video_length of RecordVideo works properly."""
env = gym.make("CartPole-v1", render_mode="rgb_array_list")
env._max_episode_steps = 20
env = RecordVideo(env, "videos", step_trigger=lambda x: x == 0, video_length=10)
env.reset()
for _ in range(10):
env = gym.make("CartPole-v1", render_mode="rgb_array")
env = RecordVideo(
env, "videos", step_trigger=lambda x: x == 0, video_length=video_length
)

env.reset(seed=123)
env.action_space.seed(123)
for _ in range(video_length):
_, _, term, trunc, _ = env.step(env.action_space.sample())
if term or trunc:
break

# check that the environment is still recording then take a step to take the number of steps > video length
assert env.recording
action = env.action_space.sample()
env.step(action)
env.step(env.action_space.sample())
assert not env.recording
env.close()

# check that only one video is recorded
assert os.path.isdir("videos")
mp4_files = [file for file in os.listdir("videos") if file.endswith(".mp4")]
assert len(mp4_files) == 1
shutil.rmtree("videos")


def test_rendering_works():
"""Test if render output is as expected when the env is wrapped with RecordVideo."""
env = gym.make("CartPole-v1", render_mode="rgb_array_list")
env._max_episode_steps = 20
env = RecordVideo(env, "videos")
env.reset()
n_steps = 10
def test_with_rgb_array_list(n_steps: int = 10):
"""Test if `env.render` works with RenderCollection and RecordVideo."""
# fyi, can't work as a `pytest.mark.parameterize`
env = RecordVideo(
RenderCollection(gym.make("CartPole-v1", render_mode="rgb_array")), "videos"
)
env.reset(seed=123)
env.action_space.seed(123)
for _ in range(n_steps):
action = env.action_space.sample()
env.step(action)
env.step(env.action_space.sample())

render_out = env.render()
assert isinstance(render_out, List)
assert len(render_out) == n_steps + 1
assert all(isinstance(render, np.ndarray) for render in render_out)
assert all(render.ndim == 3 for render in render_out)

render_out = env.render()
assert isinstance(render_out, List)
assert len(render_out) == 0

env.close()
shutil.rmtree("videos")

# Test in reverse order
env = RenderCollection(
RecordVideo(gym.make("CartPole-v1", render_mode="rgb_array"), "videos")
)
env.reset(seed=123)
env.action_space.seed(123)
for _ in range(n_steps):
env.step(env.action_space.sample())

render_out = env.render()
assert isinstance(render_out, List)
assert len(render_out) == n_steps + 1
assert all(isinstance(render, np.ndarray) for render in render_out)
assert all(render.ndim == 3 for render in render_out)

render_out = env.render()
assert isinstance(render_out, List)
assert len(render_out) == 0

env.close()
shutil.rmtree("videos")

0 comments on commit f4c302d

Please sign in to comment.