From f4c302d8bb5bd923db812e6d9634082f2a981ed6 Mon Sep 17 00:00:00 2001 From: Mark Towers Date: Tue, 28 Nov 2023 11:33:55 +0000 Subject: [PATCH] Improve `RecordVideo` wrapper testing (#797) --- gymnasium/wrappers/rendering.py | 11 +-- tests/wrappers/test_record_video.py | 145 ++++++++++++++++++---------- 2 files changed, 100 insertions(+), 56 deletions(-) diff --git a/gymnasium/wrappers/rendering.py b/gymnasium/wrappers/rendering.py index 6fc1e451d..6b1afade9 100644 --- a/gymnasium/wrappers/rendering.py +++ b/gymnasium/wrappers/rendering.py @@ -235,7 +235,7 @@ def __init__( video_length: int = 0, name_prefix: str = "rl-video", fps: int | None = None, - disable_logger: bool = False, + disable_logger: bool = True, ): """Wrapper records videos of rollouts. @@ -247,9 +247,9 @@ def __init__( video_length (int): The length of recorded episodes. If 0, entire episodes are recorded. Otherwise, snippets of the specified length are captured name_prefix (str): Will be prepended to the filename of the recordings - fps (int): The frame per second in the video. The default value is the one specified in the environment metadata. - If the environment metadata doesn't specify ``render_fps``, the value 30 is used. - disable_logger (bool): Whether to disable moviepy logger or not + fps (int): The frame per second in the video. Provides a custom video fps for environment, if ``None`` then + the environment metadata ``render_fps`` key is used if it exists, otherwise a default value of 30 is used. + disable_logger (bool): Whether to disable moviepy logger or not, default it is disabled """ gym.utils.RecordConstructorArgs.__init__( self, @@ -320,8 +320,7 @@ def _capture_frame(self): else: self.stop_recording() logger.warn( - "Recording stopped: expected type of frame returned by render ", - f"to be a numpy array, got instead {type(frame)}.", + f"Recording stopped: expected type of frame returned by render to be a numpy array, got instead {type(frame)}." ) def reset( diff --git a/tests/wrappers/test_record_video.py b/tests/wrappers/test_record_video.py index b0659cdee..a328a3be0 100644 --- a/tests/wrappers/test_record_video.py +++ b/tests/wrappers/test_record_video.py @@ -3,48 +3,61 @@ import shutil from typing import List +import numpy as np +import pytest + import gymnasium as gym -from gymnasium.wrappers import RecordVideo +from gymnasium.wrappers import RecordVideo, RenderCollection -def test_record_video_using_default_trigger(): - """Test RecordVideo using the default episode trigger.""" - env = gym.make("CartPole-v1", render_mode="rgb_array_list") - env = RecordVideo(env, "videos") - env.reset() - episode_count = 0 - for _ in range(199): +def test_video_folder_and_filenames( + video_folder="custom_video_folder", name_prefix="video-prefix" +): + env = gym.make("CartPole-v1", render_mode="rgb_array") + env = RecordVideo( + env, + video_folder=video_folder, + name_prefix=name_prefix, + episode_trigger=lambda x: x in [1, 4], + step_trigger=lambda x: x in [0, 25], + ) + + env.reset(seed=123) + env.action_space.seed(123) + for _ in range(100): action = env.action_space.sample() _, _, terminated, truncated, _ = env.step(action) if terminated or truncated: env.reset() - episode_count += 1 - env.close() - assert os.path.isdir("videos") - mp4_files = [file for file in os.listdir("videos") if file.endswith(".mp4")] - assert env.episode_trigger is not None - assert len(mp4_files) == sum( - env.episode_trigger(i) for i in range(episode_count + 1) - ) - shutil.rmtree("videos") + assert os.path.isdir(video_folder) + mp4_files = {file for file in os.listdir(video_folder) if file.endswith(".mp4")} + shutil.rmtree(video_folder) + assert mp4_files == { + "video-prefix-step-0.mp4", # step triggers + "video-prefix-step-25.mp4", + "video-prefix-episode-1.mp4", # episode triggers + "video-prefix-episode-4.mp4", + } + + +@pytest.mark.parametrize("episodic_trigger", [None, lambda x: x in [0, 3, 5, 10, 12]]) +def test_episodic_trigger(episodic_trigger): + """Test RecordVideo using the default episode trigger.""" + env = gym.make("CartPole-v1", render_mode="rgb_array") + env = RecordVideo(env, "videos", episode_trigger=episodic_trigger) -def test_record_video_while_rendering(): - """Test RecordVideo while calling render and using a _list render mode.""" - env = gym.make("FrozenLake-v1", render_mode="rgb_array_list") - env = RecordVideo(env, "videos") env.reset() episode_count = 0 for _ in range(199): action = env.action_space.sample() _, _, terminated, truncated, _ = env.step(action) - env.render() if terminated or truncated: env.reset() episode_count += 1 - env.close() + assert os.path.isdir("videos") mp4_files = [file for file in os.listdir("videos") if file.endswith(".mp4")] assert env.episode_trigger is not None @@ -54,10 +67,9 @@ def test_record_video_while_rendering(): shutil.rmtree("videos") -def test_record_video_step_trigger(): +def test_step_trigger(): """Test RecordVideo defining step trigger function.""" - env = gym.make("CartPole-v1", render_mode="rgb_array", disable_env_checker=True) - env._max_episode_steps = 20 + env = gym.make("CartPole-v1", render_mode="rgb_array") env = RecordVideo(env, "videos", step_trigger=lambda x: x % 100 == 0) env.reset() for _ in range(199): @@ -72,69 +84,102 @@ def test_record_video_step_trigger(): assert len(mp4_files) == 2 -def test_record_video_both_trigger(): +def test_both_episodic_and_step_trigger(): """Test RecordVideo defining both step and episode trigger functions.""" - env = gym.make( - "CartPole-v1", render_mode="rgb_array_list", disable_env_checker=True - ) - env._max_episode_steps = 20 + env = gym.make("CartPole-v1", render_mode="rgb_array") env = RecordVideo( env, "videos", step_trigger=lambda x: x == 100, episode_trigger=lambda x: x == 0 or x == 3, ) - env.reset() - for _ in range(199): + # episode reset time steps: 0, 18, 44, 55, 80, 103, 117, 143, 173, 191 + # steps recorded: 0-18, 55-80, 100-103 + + env.reset(seed=123) + env.action_space.seed(123) + for i in range(199): action = env.action_space.sample() _, _, terminated, truncated, _ = env.step(action) if terminated or truncated: env.reset() env.close() + assert os.path.isdir("videos") mp4_files = [file for file in os.listdir("videos") if file.endswith(".mp4")] shutil.rmtree("videos") assert len(mp4_files) == 3 -def test_record_video_length(): +def test_video_length(video_length: int = 10): """Test if argument video_length of RecordVideo works properly.""" - env = gym.make("CartPole-v1", render_mode="rgb_array_list") - env._max_episode_steps = 20 - env = RecordVideo(env, "videos", step_trigger=lambda x: x == 0, video_length=10) - env.reset() - for _ in range(10): + env = gym.make("CartPole-v1", render_mode="rgb_array") + env = RecordVideo( + env, "videos", step_trigger=lambda x: x == 0, video_length=video_length + ) + + env.reset(seed=123) + env.action_space.seed(123) + for _ in range(video_length): _, _, term, trunc, _ = env.step(env.action_space.sample()) if term or trunc: break + # check that the environment is still recording then take a step to take the number of steps > video length assert env.recording - action = env.action_space.sample() - env.step(action) + env.step(env.action_space.sample()) assert not env.recording env.close() + + # check that only one video is recorded assert os.path.isdir("videos") mp4_files = [file for file in os.listdir("videos") if file.endswith(".mp4")] assert len(mp4_files) == 1 shutil.rmtree("videos") -def test_rendering_works(): - """Test if render output is as expected when the env is wrapped with RecordVideo.""" - env = gym.make("CartPole-v1", render_mode="rgb_array_list") - env._max_episode_steps = 20 - env = RecordVideo(env, "videos") - env.reset() - n_steps = 10 +def test_with_rgb_array_list(n_steps: int = 10): + """Test if `env.render` works with RenderCollection and RecordVideo.""" + # fyi, can't work as a `pytest.mark.parameterize` + env = RecordVideo( + RenderCollection(gym.make("CartPole-v1", render_mode="rgb_array")), "videos" + ) + env.reset(seed=123) + env.action_space.seed(123) for _ in range(n_steps): - action = env.action_space.sample() - env.step(action) + env.step(env.action_space.sample()) + + render_out = env.render() + assert isinstance(render_out, List) + assert len(render_out) == n_steps + 1 + assert all(isinstance(render, np.ndarray) for render in render_out) + assert all(render.ndim == 3 for render in render_out) + + render_out = env.render() + assert isinstance(render_out, List) + assert len(render_out) == 0 + + env.close() + shutil.rmtree("videos") + + # Test in reverse order + env = RenderCollection( + RecordVideo(gym.make("CartPole-v1", render_mode="rgb_array"), "videos") + ) + env.reset(seed=123) + env.action_space.seed(123) + for _ in range(n_steps): + env.step(env.action_space.sample()) render_out = env.render() assert isinstance(render_out, List) assert len(render_out) == n_steps + 1 + assert all(isinstance(render, np.ndarray) for render in render_out) + assert all(render.ndim == 3 for render in render_out) + render_out = env.render() assert isinstance(render_out, List) assert len(render_out) == 0 + env.close() shutil.rmtree("videos")