From c25d3561d5fec5e1412afe875e9f9e11d8f23bea Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Cl=C3=A9ment=20Dumas?= Date: Wed, 17 Jan 2024 14:35:43 +0100 Subject: [PATCH 1/3] fix render mode attribute not being set by sb3 wrapper. Added a test to record video --- supersuit/vector/sb3_vector_wrapper.py | 1 + test/test_vector/test_render.py | 42 +++++++++++++++++++++----- 2 files changed, 36 insertions(+), 7 deletions(-) diff --git a/supersuit/vector/sb3_vector_wrapper.py b/supersuit/vector/sb3_vector_wrapper.py index 1130864..c34a0e0 100644 --- a/supersuit/vector/sb3_vector_wrapper.py +++ b/supersuit/vector/sb3_vector_wrapper.py @@ -12,6 +12,7 @@ def __init__(self, venv): self.num_envs = venv.num_envs self.observation_space = venv.observation_space self.action_space = venv.action_space + self.render_mode = venv.render_mode self.reset_infos = [] def reset(self, seed=None, options=None): diff --git a/test/test_vector/test_render.py b/test/test_vector/test_render.py index 78f4067..3acfcd1 100644 --- a/test/test_vector/test_render.py +++ b/test/test_vector/test_render.py @@ -1,21 +1,49 @@ from pettingzoo.butterfly import pistonball_v6 +import supersuit as ss +from stable_baselines3.common.vec_env import VecVideoRecorder -from supersuit import pettingzoo_env_to_vec_env_v1 +def schedule(episode_idx): + print(episode_idx) + return episode_idx <= 1 -def make_env(): - env = pistonball_v6.parallel_env() - env = pettingzoo_env_to_vec_env_v1(env) - return env + +def make_record_env(): + env = pistonball_v6.parallel_env(render_mode="rgb_array") + print(env.render_mode) + env = ss.pettingzoo_env_to_vec_env_v1(env) + envs = ss.concat_vec_envs_v1(env, 1, num_cpus=0, base_class="stable_baselines3") + # envs.render_mode = "rgb_array" + envs = VecVideoRecorder(envs, f".", schedule) + return envs + + +def record_video_test(): + envs = make_record_env() + envs.reset() + for _ in range(100): + envs.step([envs.action_space.sample() for _ in range(envs.num_envs)]) + envs.close() + + +record_video_test() + + +# def make_env(): +# env = pistonball_v6.parallel_env(render_mode="rgb_array") +# env = ss.pettingzoo_env_to_vec_env_v1(env) +# return env # unfortunately this test does not pass # def test_vector_render_multiproc(): # env = make_env() # num_envs = 3 -# venv = concat_vec_envs_v1(env, num_envs, num_cpus=num_envs, base_class='stable_baselines3') +# venv = ss.concat_vec_envs_v1( +# env, num_envs, num_cpus=num_envs, base_class="stable_baselines3" +# ) # venv.reset() -# arr = venv.render(mode="rgb_array") +# arr = venv.render() # venv.reset() # assert len(arr.shape) == 3 and arr.shape[2] == 3 # venv.reset() From 730019704c7d6fb2ede97f61f65e4beb4c1950ed Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Cl=C3=A9ment=20Dumas?= Date: Wed, 17 Jan 2024 14:40:32 +0100 Subject: [PATCH 2/3] Removed comment and save the test video in /tmp --- test/test_vector/test_render.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/test/test_vector/test_render.py b/test/test_vector/test_render.py index 3acfcd1..6a3a842 100644 --- a/test/test_vector/test_render.py +++ b/test/test_vector/test_render.py @@ -13,8 +13,7 @@ def make_record_env(): print(env.render_mode) env = ss.pettingzoo_env_to_vec_env_v1(env) envs = ss.concat_vec_envs_v1(env, 1, num_cpus=0, base_class="stable_baselines3") - # envs.render_mode = "rgb_array" - envs = VecVideoRecorder(envs, f".", schedule) + envs = VecVideoRecorder(envs, f"/tmp", schedule) return envs From bb9a52c15eb758a5a3d6db360ea3f46825cfca18 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Cl=C3=A9ment=20Dumas?= Date: Wed, 17 Jan 2024 14:41:51 +0100 Subject: [PATCH 3/3] reformat --- test/test_vector/test_render.py | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/test/test_vector/test_render.py b/test/test_vector/test_render.py index 6a3a842..b6898c0 100644 --- a/test/test_vector/test_render.py +++ b/test/test_vector/test_render.py @@ -8,7 +8,7 @@ def schedule(episode_idx): return episode_idx <= 1 -def make_record_env(): +def make_sb3_record_env(): env = pistonball_v6.parallel_env(render_mode="rgb_array") print(env.render_mode) env = ss.pettingzoo_env_to_vec_env_v1(env) @@ -17,17 +17,14 @@ def make_record_env(): return envs -def record_video_test(): - envs = make_record_env() +def test_record_video_sb3(): + envs = make_sb3_record_env() envs.reset() for _ in range(100): envs.step([envs.action_space.sample() for _ in range(envs.num_envs)]) envs.close() -record_video_test() - - # def make_env(): # env = pistonball_v6.parallel_env(render_mode="rgb_array") # env = ss.pettingzoo_env_to_vec_env_v1(env)