From 019a59311714e4062b0d50afc94e291ef857d423 Mon Sep 17 00:00:00 2001
From: Omar Younis <42100908+younik@users.noreply.github.com>
Date: Tue, 17 Jan 2023 09:13:43 +0100
Subject: [PATCH] ENH: implement RecordVideoV0 (#246)

---
Review notes (this section is ignored by `git am`; not part of the commit message):
  * ValueError / logger.warn: adjacent string literals are now concatenated
    instead of being passed as separate positional arguments.
  * __del__ no longer assumes that __init__ ran to completion.
  * The class docstring names the metadata key the code actually reads
    (`render_fps`); the trigger parameters are annotated as optional.
  * Line counts per hunk are unchanged, so hunk headers and diffstat still hold.

 gymnasium/experimental/wrappers/rendering.py  | 199 +++++++++++++++++-
 gymnasium/wrappers/record_video.py            |   2 +-
 .../wrappers/test_record_video.py             | 165 +++++++++++++++
 3 files changed, 362 insertions(+), 4 deletions(-)

diff --git a/gymnasium/experimental/wrappers/rendering.py b/gymnasium/experimental/wrappers/rendering.py
index 860a21e25..6b146e14d 100644
--- a/gymnasium/experimental/wrappers/rendering.py
+++ b/gymnasium/experimental/wrappers/rendering.py
@@ -6,12 +6,14 @@
 """
 from __future__ import annotations
 
+import os
 from copy import deepcopy
-from typing import Any, SupportsFloat
+from typing import Any, Callable, List, SupportsFloat
 
 import numpy as np
 
 import gymnasium as gym
+from gymnasium import error, logger
 from gymnasium.core import ActType, ObsType, RenderFrame, WrapperActType, WrapperObsType
 from gymnasium.error import DependencyNotInstalled
 
@@ -79,9 +81,200 @@ def render(self) -> RenderFrame | list[RenderFrame] | None:
 
 
 class RecordVideoV0(gym.Wrapper):
-    """Record a video of an environment."""
+    """This wrapper records videos of rollouts.
+
+    Usually, you only want to record episodes intermittently, say every hundredth episode.
+    To do this, you can specify ``episode_trigger`` or ``step_trigger``.
+    They should be functions returning a boolean that indicates whether a recording should be started at the
+    current episode or step, respectively.
+    If neither :attr:`episode_trigger` nor ``step_trigger`` is passed, a default ``episode_trigger`` will be employed,
+    i.e. capped_cubic_video_schedule. This function starts a video at every episode that is a power of 3 until 1000 and
+    then every 1000 episodes.
+    By default, the recording will be stopped once reset is called. However, you can also create recordings of fixed
+    length (possibly spanning several episodes) by passing a strictly positive value for ``video_length``.
+    This wrapper uses the value `render_fps` from metadata as the number of frames per second;
+    if `render_fps` is not defined in metadata, the default value 30 is used.
+    """
+
+    def __init__(
+        self,
+        env: gym.Env,
+        video_folder: str,
+        episode_trigger: Callable[[int], bool] | None = None,
+        step_trigger: Callable[[int], bool] | None = None,
+        video_length: int = 0,
+        name_prefix: str = "rl-video",
+        disable_logger: bool = False,
+    ):
+        """Wrapper records videos of rollouts.
+
+        Args:
+            env: The environment that will be wrapped
+            video_folder (str): The folder where the recordings will be stored
+            episode_trigger: Function that accepts an integer and returns ``True`` iff a recording should be started at this episode
+            step_trigger: Function that accepts an integer and returns ``True`` iff a recording should be started at this step
+            video_length (int): The length of recorded episodes. If 0, entire episodes are recorded.
+                Otherwise, snippets of the specified length are captured
+            name_prefix (str): Will be prepended to the filename of the recordings
+            disable_logger (bool): Whether to disable moviepy logger or not
+
+        """
+        super().__init__(env)
+        try:
+            import moviepy  # noqa: F401
+        except ImportError as e:
+            raise error.DependencyNotInstalled(
+                "MoviePy is not installed, run `pip install moviepy`"
+            ) from e
+
+        if env.render_mode in {None, "human", "ansi"}:
+            raise ValueError(
+                f"Render mode is {env.render_mode}, which is incompatible with RecordVideo. "
+                "Initialize your environment with a render_mode that returns an image, such as rgb_array."
+            )
+
+        if episode_trigger is None and step_trigger is None:
+
+            def capped_cubic_video_schedule(episode_id: int) -> bool:
+                if episode_id < 1000:
+                    return int(round(episode_id ** (1.0 / 3))) ** 3 == episode_id
+                else:
+                    return episode_id % 1000 == 0
+
+            episode_trigger = capped_cubic_video_schedule
+
+        self.episode_trigger = episode_trigger
+        self.step_trigger = step_trigger
+        self.disable_logger = disable_logger
+
+        self.video_folder = os.path.abspath(video_folder)
+        if os.path.isdir(self.video_folder):
+            logger.warn(
+                f"Overwriting existing videos at {self.video_folder} folder "
+                f"(try specifying a different `video_folder` for the `RecordVideo` wrapper if this is not desired)"
+            )
+        os.makedirs(self.video_folder, exist_ok=True)
+
+        self.name_prefix = name_prefix
+        self._video_name = None
+        self.frames_per_sec = self.metadata.get("render_fps", 30)
+        self.video_length = video_length if video_length != 0 else float("inf")
+        self.recording = False
+        self.recorded_frames = []
+        self.render_history = []
+
+        self.step_id = -1
+        self.episode_id = -1
+
+    def _capture_frame(self):
+        assert self.recording, "Cannot capture a frame, recording wasn't started."
+
+        frame = self.env.render()
+        if isinstance(frame, List):
+            if len(frame) == 0:  # render was called
+                return
+            self.render_history += frame
+            frame = frame[-1]
+
+        if isinstance(frame, np.ndarray):
+            self.recorded_frames.append(frame)
+        else:
+            self.stop_recording()
+            logger.warn(
+                "Recording stopped: expected type of frame returned by render "
+                f"to be a numpy array, got instead {type(frame)}."
+            )
+
+    def reset(
+        self, *, seed: int | None = None, options: dict[str, Any] | None = None
+    ) -> tuple[WrapperObsType, dict[str, Any]]:
+        """Reset the environment and eventually starts a new recording."""
+        obs, info = super().reset(seed=seed, options=options)
+        self.episode_id += 1
+
+        if self.recording and self.video_length == float("inf"):
+            self.stop_recording()
+
+        if self.episode_trigger and self.episode_trigger(self.episode_id):
+            self.start_recording(f"{self.name_prefix}-episode-{self.episode_id}")
+        if self.recording:
+            self._capture_frame()
+            if len(self.recorded_frames) > self.video_length:
+                self.stop_recording()
+
+        return obs, info
+
+    def step(
+        self, action: WrapperActType
+    ) -> tuple[WrapperObsType, SupportsFloat, bool, bool, dict[str, Any]]:
+        """Steps through the environment using action, recording observations if :attr:`self.recording`."""
+        obs, rew, terminated, truncated, info = self.env.step(action)
+        self.step_id += 1
+
+        if self.step_trigger and self.step_trigger(self.step_id):
+            self.start_recording(f"{self.name_prefix}-step-{self.step_id}")
+        if self.recording:
+            self._capture_frame()
+
+            if len(self.recorded_frames) > self.video_length:
+                self.stop_recording()
+
+        return obs, rew, terminated, truncated, info
+
+    def start_recording(self, video_name):
+        """Start a new recording. If it is already recording, stops the current recording before starting the new one."""
+        if self.recording:
+            self.stop_recording()
+
+        self.recording = True
+        self._video_name = video_name
+
+    def stop_recording(self):
+        """Stop current recording and saves the video."""
+        assert self.recording, "stop_recording was called, but no recording was started"
+
+        if len(self.recorded_frames) == 0:
+            logger.warn("Ignored saving a video as there were zero frames to save.")
+        else:
+            try:
+                from moviepy.video.io.ImageSequenceClip import ImageSequenceClip
+            except ImportError as e:
+                raise error.DependencyNotInstalled(
+                    "MoviePy is not installed, run `pip install moviepy`"
+                ) from e
+
+            clip = ImageSequenceClip(self.recorded_frames, fps=self.frames_per_sec)
+            moviepy_logger = None if self.disable_logger else "bar"
+            path = os.path.join(self.video_folder, f"{self._video_name}.mp4")
+            clip.write_videofile(path, logger=moviepy_logger)
+
+        self.recorded_frames = []
+        self.recording = False
+        self._video_name = None
+
+    def render(self):
+        """Compute the render frames as specified by render_mode attribute during initialization of the environment."""
+        render_out = super().render()
+        if self.recording and isinstance(render_out, List):
+            self.recorded_frames += render_out
+
+        if len(self.render_history) > 0:
+            tmp_history = self.render_history
+            self.render_history = []
+            return tmp_history + render_out
+        else:
+            return render_out
+
+    def close(self):
+        """Closes the wrapper then the video recorder."""
+        super().close()
+        if self.recording:
+            self.stop_recording()
 
-    pass
+    def __del__(self):
+        """Warn the user in case last video wasn't saved."""
+        if len(getattr(self, "recorded_frames", [])) > 0:
+            logger.warn("Unable to save last video! Did you call close()?")
 
 
 class HumanRenderingV0(gym.Wrapper):
diff --git a/gymnasium/wrappers/record_video.py b/gymnasium/wrappers/record_video.py
index 6cf32d279..c5236ce2d 100644
--- a/gymnasium/wrappers/record_video.py
+++ b/gymnasium/wrappers/record_video.py
@@ -10,7 +10,7 @@
 def capped_cubic_video_schedule(episode_id: int) -> bool:
     """The default episode trigger.
 
-    This function will trigger recordings at the episode indices 0, 1, 4, 8, 27, ..., :math:`k^3`, ..., 729, 1000, 2000, 3000, ...
+    This function will trigger recordings at the episode indices 0, 1, 8, 27, ..., :math:`k^3`, ..., 729, 1000, 2000, 3000, ...
 
     Args:
         episode_id: The episode number
diff --git a/tests/experimental/wrappers/test_record_video.py b/tests/experimental/wrappers/test_record_video.py
index d79f672b6..20b35d6e2 100644
--- a/tests/experimental/wrappers/test_record_video.py
+++ b/tests/experimental/wrappers/test_record_video.py
@@ -1 +1,166 @@
 """Test suite for RecordVideoV0."""
+import os
+import shutil
+from typing import List
+
+import gymnasium as gym
+from gymnasium.experimental.wrappers import RecordVideoV0
+
+
+def test_record_video_using_default_trigger():
+    """Test RecordVideo using the default episode trigger."""
+    env = gym.make("CartPole-v1", render_mode="rgb_array_list")
+    env = RecordVideoV0(env, "videos")
+    env.reset()
+    episode_count = 0
+    for _ in range(199):
+        action = env.action_space.sample()
+        _, _, terminated, truncated, _ = env.step(action)
+        if terminated or truncated:
+            env.reset()
+            episode_count += 1
+
+    env.close()
+    assert os.path.isdir("videos")
+    mp4_files = [file for file in os.listdir("videos") if file.endswith(".mp4")]
+    assert len(mp4_files) == sum(
+        env.episode_trigger(i) for i in range(episode_count + 1)
+    )
+    shutil.rmtree("videos")
+
+
+def test_record_video_while_rendering():
+    """Test RecordVideo while calling render and using a _list render mode."""
+    env = gym.make("FrozenLake-v1", render_mode="rgb_array_list")
+    env = RecordVideoV0(env, "videos")
+    env.reset()
+    episode_count = 0
+    for _ in range(199):
+        action = env.action_space.sample()
+        _, _, terminated, truncated, _ = env.step(action)
+        env.render()
+        if terminated or truncated:
+            env.reset()
+            episode_count += 1
+
+    env.close()
+    assert os.path.isdir("videos")
+    mp4_files = [file for file in os.listdir("videos") if file.endswith(".mp4")]
+    assert len(mp4_files) == sum(
+        env.episode_trigger(i) for i in range(episode_count + 1)
+    )
+    shutil.rmtree("videos")
+
+
+def test_record_video_step_trigger():
+    """Test RecordVideo defining step trigger function."""
+    env = gym.make("CartPole-v1", render_mode="rgb_array", disable_env_checker=True)
+    env._max_episode_steps = 20
+    env = RecordVideoV0(env, "videos", step_trigger=lambda x: x % 100 == 0)
+    env.reset()
+    for _ in range(199):
+        action = env.action_space.sample()
+        _, _, terminated, truncated, _ = env.step(action)
+        if terminated or truncated:
+            env.reset()
+    env.close()
+    assert os.path.isdir("videos")
+    mp4_files = [file for file in os.listdir("videos") if file.endswith(".mp4")]
+    shutil.rmtree("videos")
+    assert len(mp4_files) == 2
+
+
+def test_record_video_both_trigger():
+    """Test RecordVideo defining both step and episode trigger functions."""
+    env = gym.make(
+        "CartPole-v1", render_mode="rgb_array_list", disable_env_checker=True
+    )
+    env._max_episode_steps = 20
+    env = RecordVideoV0(
+        env,
+        "videos",
+        step_trigger=lambda x: x == 100,
+        episode_trigger=lambda x: x == 0 or x == 3,
+    )
+    env.reset()
+    for _ in range(199):
+        action = env.action_space.sample()
+        _, _, terminated, truncated, _ = env.step(action)
+        if terminated or truncated:
+            env.reset()
+    env.close()
+    assert os.path.isdir("videos")
+    mp4_files = [file for file in os.listdir("videos") if file.endswith(".mp4")]
+    shutil.rmtree("videos")
+    assert len(mp4_files) == 3
+
+
+def test_record_video_length():
+    """Test if argument video_length of RecordVideo works properly."""
+    env = gym.make("CartPole-v1", render_mode="rgb_array_list")
+    env._max_episode_steps = 20
+    env = RecordVideoV0(env, "videos", step_trigger=lambda x: x == 0, video_length=10)
+    env.reset()
+    for _ in range(10):
+        action = env.action_space.sample()
+        env.step(action)
+
+    assert env.recording
+    action = env.action_space.sample()
+    env.step(action)
+    assert not env.recording
+    env.close()
+    assert os.path.isdir("videos")
+    mp4_files = [file for file in os.listdir("videos") if file.endswith(".mp4")]
+    assert len(mp4_files) == 1
+    shutil.rmtree("videos")
+
+
+def test_rendering_works():
+    """Test if render output is as expected when the env is wrapped with RecordVideo."""
+    env = gym.make("CartPole-v1", render_mode="rgb_array_list")
+    env._max_episode_steps = 20
+    env = RecordVideoV0(env, "videos")
+    env.reset()
+    n_steps = 10
+    for _ in range(n_steps):
+        action = env.action_space.sample()
+        env.step(action)
+
+    render_out = env.render()
+    assert isinstance(render_out, List)
+    assert len(render_out) == n_steps + 1
+    render_out = env.render()
+    assert isinstance(render_out, List)
+    assert len(render_out) == 0
+    env.close()
+    shutil.rmtree("videos")
+
+
+def make_env(gym_id, idx, **kwargs):
+    """Utility function to make an env and wrap it with RecordVideo only the first time."""
+
+    def thunk():
+        env = gym.make(gym_id, disable_env_checker=True, **kwargs)
+        env._max_episode_steps = 20
+        if idx == 0:
+            env = RecordVideoV0(env, "videos", step_trigger=lambda x: x % 100 == 0)
+        return env
+
+    return thunk
+
+
+def test_record_video_within_vector():
+    """Test RecordVideo used as env of SyncVectorEnv."""
+    envs = gym.vector.SyncVectorEnv(
+        [make_env("CartPole-v1", i, render_mode="rgb_array") for i in range(2)]
+    )
+    envs = gym.wrappers.RecordEpisodeStatistics(envs)
+    envs.reset()
+    for i in range(199):
+        _, _, _, _, infos = envs.step(envs.action_space.sample())
+
+    assert os.path.isdir("videos")
+    mp4_files = [file for file in os.listdir("videos") if file.endswith(".mp4")]
+    assert len(mp4_files) == 2
+    shutil.rmtree("videos")