Skip to content

Commit

Permalink
Option to save frames to disk in video recorder
Browse files Browse the repository at this point in the history
  • Loading branch information
rk1a committed Sep 17, 2023
1 parent baf7807 commit 56229c4
Show file tree
Hide file tree
Showing 3 changed files with 44 additions and 1 deletion.
22 changes: 21 additions & 1 deletion gymnasium/wrappers/monitoring/video_recorder.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,12 @@
import json
import os
import os.path
import shutil
import tempfile
from typing import List, Optional

import imageio

from gymnasium import error, logger


Expand All @@ -25,6 +28,7 @@ def __init__(
enabled: bool = True,
base_path: Optional[str] = None,
disable_logger: bool = False,
frames_to_disk: bool = False,
):
"""Video recorder renders a nice movie of a rollout, frame by frame.
Expand All @@ -35,6 +39,7 @@ def __init__(
enabled (bool): Whether to actually record video, or just no-op (for convenience)
base_path (Optional[str]): Alternatively, path to the video file without extension, which will be added.
disable_logger (bool): Whether to disable moviepy logger or not.
frames_to_disk (bool): Whether to save frames to disk to reduce RAM usage.
Raises:
Error: You can pass at most one of `path` or `base_path`
Expand All @@ -43,6 +48,7 @@ def __init__(
self._async = env.metadata.get("semantics.async")
self.enabled = enabled
self.disable_logger = disable_logger
self.frames_to_disk = frames_to_disk
self._closed = False

self.render_history = []
Expand Down Expand Up @@ -82,6 +88,9 @@ def __init__(
with tempfile.NamedTemporaryFile(suffix=required_ext) as f:
path = f.name
self.path = path
if self.frames_to_disk:
self.frames_dir = os.path.join(os.path.dirname(self.path), "frames")
os.makedirs(self.frames_dir, exist_ok=True)

path_base, actual_ext = os.path.splitext(self.path)

Expand Down Expand Up @@ -136,7 +145,16 @@ def capture_frame(self):
)
self.broken = True
else:
self.recorded_frames.append(frame)
if self.frames_to_disk:
frame_path = os.path.join(
self.frames_dir,
f"frame_{str(len(self.recorded_frames))}.png",
)
# Write frame and remember its path
imageio.imwrite(frame_path, frame)
self.recorded_frames.append(frame_path)
else:
self.recorded_frames.append(frame)

def close(self):
"""Flush all data to disk and close any open frame encoders."""
Expand All @@ -155,6 +173,8 @@ def close(self):
clip = ImageSequenceClip(self.recorded_frames, fps=self.frames_per_sec)
moviepy_logger = None if self.disable_logger else "bar"
clip.write_videofile(self.path, logger=moviepy_logger)
if self.frames_to_disk:
shutil.rmtree(self.frames_dir)
else:
# No frames captured. Set metadata.
if self.metadata is None:
Expand Down
5 changes: 5 additions & 0 deletions gymnasium/wrappers/record_video.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ def __init__(
video_length: int = 0,
name_prefix: str = "rl-video",
disable_logger: bool = False,
frames_to_disk: bool = False,
):
"""Wrapper records videos of rollouts.
Expand All @@ -58,6 +59,7 @@ def __init__(
Otherwise, snippets of the specified length are captured
name_prefix (str): Will be prepended to the filename of the recordings
disable_logger (bool): Whether to disable moviepy logger or not.
frames_to_disk (bool): Whether to save frames to disk to reduce RAM usage.
"""
gym.utils.RecordConstructorArgs.__init__(
self,
Expand All @@ -67,6 +69,7 @@ def __init__(
video_length=video_length,
name_prefix=name_prefix,
disable_logger=disable_logger,
frames_to_disk=frames_to_disk,
)
gym.Wrapper.__init__(self, env)

Expand All @@ -87,6 +90,7 @@ def __init__(
self.step_trigger = step_trigger
self.video_recorder: Optional[video_recorder.VideoRecorder] = None
self.disable_logger = disable_logger
self.frames_to_disk = frames_to_disk

self.video_folder = os.path.abspath(video_folder)
# Create output folder if needed
Expand Down Expand Up @@ -143,6 +147,7 @@ def start_video_recorder(self):
base_path=base_path,
metadata={"step_id": self.step_id, "episode_id": self.episode_id},
disable_logger=self.disable_logger,
frames_to_disk=self.frames_to_disk,
)

self.video_recorder.capture_frame()
Expand Down
18 changes: 18 additions & 0 deletions tests/wrappers/test_record_video.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,24 @@ def test_record_video_reset():
assert isinstance(info, dict)


def test_record_video_frames_to_disk():
env = gym.make("CartPole-v1", render_mode="rgb_array", disable_env_checker=True)
env = gym.wrappers.RecordVideo(
env, "videos", step_trigger=lambda x: x % 100 == 0, frames_to_disk=True
)
ob_space = env.observation_space
obs, info = env.reset()
assert os.path.isdir(os.path.join("videos", "frames"))
assert os.path.isfile(os.path.join("videos", "frames", "frame_0.png"))
env.close()
assert os.path.isdir("videos")
assert any(file.endswith(".mp4") for file in os.listdir("videos"))
assert not os.path.isdir(os.path.join("videos", "frames"))
shutil.rmtree("videos")
assert ob_space.contains(obs)
assert isinstance(info, dict)


def test_record_video_step_trigger():
env = gym.make("CartPole-v1", render_mode="rgb_array", disable_env_checker=True)
env._max_episode_steps = 20
Expand Down

0 comments on commit 56229c4

Please sign in to comment.