Skip to content

Commit

Permalink
Option to save frames to disk in video recorder
Browse files Browse the repository at this point in the history
  • Loading branch information
rk1a committed Sep 18, 2023
1 parent adcc405 commit 1391640
Show file tree
Hide file tree
Showing 2 changed files with 54 additions and 1 deletion.
37 changes: 36 additions & 1 deletion gymnasium/wrappers/rendering.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,11 @@
from __future__ import annotations

import os
import shutil
from copy import deepcopy
from typing import Any, Callable, List, SupportsFloat

import imageio
import numpy as np

import gymnasium as gym
Expand Down Expand Up @@ -222,6 +224,21 @@ class RecordVideo(
>>> len(os.listdir("./save_videos3"))
2
Examples - Run 1 episode, record everything, but write frames to disk while recording:
>>> import os
>>> import gymnasium as gym
>>> env = gym.make("LunarLander-v2", render_mode="rgb_array")
>>> env = RecordVideo(env, video_folder="./save_videos4", frames_to_disk=True, disable_logger=True)
>>> for i in range(3):
... termination, truncation = False, False
... _ = env.reset(seed=123)
... while not (termination or truncation):
... obs, rew, termination, truncation, info = env.step(env.action_space.sample())
...
>>> env.close()
>>> len(os.listdir("./save_videos4"))
1
Change logs:
* v0.25.0 - Initially added to replace ``wrappers.monitoring.VideoRecorder``
"""
Expand All @@ -236,6 +253,7 @@ def __init__(
name_prefix: str = "rl-video",
fps: int | None = None,
disable_logger: bool = False,
frames_to_disk: bool = False,
):
"""Wrapper records videos of rollouts.
Expand All @@ -250,6 +268,7 @@ def __init__(
fps (int): The frame per second in the video. The default value is the one specified in the environment metadata.
If the environment metadata doesn't specify ``render_fps``, the value 30 is used.
disable_logger (bool): Whether to disable moviepy logger or not
frames_to_disk (bool): Whether to save frames to disk to reduce RAM usage.
"""
gym.utils.RecordConstructorArgs.__init__(
self,
Expand All @@ -259,6 +278,7 @@ def __init__(
video_length=video_length,
name_prefix=name_prefix,
disable_logger=disable_logger,
frames_to_disk=frames_to_disk,
)
gym.Wrapper.__init__(self, env)

Expand All @@ -276,6 +296,7 @@ def __init__(
self.episode_trigger = episode_trigger
self.step_trigger = step_trigger
self.disable_logger = disable_logger
self.frames_to_disk = frames_to_disk

self.video_folder = os.path.abspath(video_folder)
if os.path.isdir(self.video_folder):
Expand All @@ -284,6 +305,9 @@ def __init__(
f"(try specifying a different `video_folder` for the `RecordVideo` wrapper if this is not desired)"
)
os.makedirs(self.video_folder, exist_ok=True)
if self.frames_to_disk:
self.frames_folder = os.path.join(self.video_folder, "frames")
os.makedirs(self.frames_folder, exist_ok=True)

if fps is None:
fps = self.metadata.get("render_fps", 30)
Expand Down Expand Up @@ -316,7 +340,16 @@ def _capture_frame(self):
frame = frame[-1]

if isinstance(frame, np.ndarray):
self.recorded_frames.append(frame)
if self.frames_to_disk:
frame_path = os.path.join(
self.frames_folder,
f"frame_{str(len(self.recorded_frames))}.png",
)
# Write frame and remember its path
imageio.imwrite(frame_path, frame)
self.recorded_frames.append(frame_path)
else:
self.recorded_frames.append(frame)
else:
self.stop_recording()
logger.warn(
Expand Down Expand Up @@ -406,6 +439,8 @@ def stop_recording(self):
path = os.path.join(self.video_folder, f"{self._video_name}.mp4")
clip.write_videofile(path, logger=moviepy_logger)

if self.frames_to_disk:
shutil.rmtree(self.frames_folder)
self.recorded_frames = []
self.recording = False
self._video_name = None
Expand Down
18 changes: 18 additions & 0 deletions tests/wrappers/test_record_video.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,24 @@ def test_record_video_while_rendering():
shutil.rmtree("videos")


def test_record_video_frames_to_disk():
env = gym.make("CartPole-v1", render_mode="rgb_array", disable_env_checker=True)
env = gym.wrappers.RecordVideo(
env, "videos", episode_trigger=lambda x: x % 100 == 0, frames_to_disk=True
)
ob_space = env.observation_space
obs, info = env.reset()
assert os.path.isdir(os.path.join("videos", "frames"))
assert os.path.isfile(os.path.join("videos", "frames", "frame_0.png"))
env.close()
assert os.path.isdir("videos")
assert any(file.endswith(".mp4") for file in os.listdir("videos"))
assert not os.path.isdir(os.path.join("videos", "frames"))
shutil.rmtree("videos")
assert ob_space.contains(obs)
assert isinstance(info, dict)


def test_record_video_step_trigger():
"""Test RecordVideo defining step trigger function."""
env = gym.make("CartPole-v1", render_mode="rgb_array", disable_env_checker=True)
Expand Down

0 comments on commit 1391640

Please sign in to comment.