From c25d3561d5fec5e1412afe875e9f9e11d8f23bea Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Cl=C3=A9ment=20Dumas?= Date: Wed, 17 Jan 2024 14:35:43 +0100 Subject: [PATCH 1/6] fix render mode attribute not being set by sb3 wrapper. Added a test to record video --- supersuit/vector/sb3_vector_wrapper.py | 1 + test/test_vector/test_render.py | 42 +++++++++++++++++++++----- 2 files changed, 36 insertions(+), 7 deletions(-) diff --git a/supersuit/vector/sb3_vector_wrapper.py b/supersuit/vector/sb3_vector_wrapper.py index 1130864..c34a0e0 100644 --- a/supersuit/vector/sb3_vector_wrapper.py +++ b/supersuit/vector/sb3_vector_wrapper.py @@ -12,6 +12,7 @@ def __init__(self, venv): self.num_envs = venv.num_envs self.observation_space = venv.observation_space self.action_space = venv.action_space + self.render_mode = venv.render_mode self.reset_infos = [] def reset(self, seed=None, options=None): diff --git a/test/test_vector/test_render.py b/test/test_vector/test_render.py index 78f4067..3acfcd1 100644 --- a/test/test_vector/test_render.py +++ b/test/test_vector/test_render.py @@ -1,21 +1,49 @@ from pettingzoo.butterfly import pistonball_v6 +import supersuit as ss +from stable_baselines3.common.vec_env import VecVideoRecorder -from supersuit import pettingzoo_env_to_vec_env_v1 +def schedule(episode_idx): + print(episode_idx) + return episode_idx <= 1 -def make_env(): - env = pistonball_v6.parallel_env() - env = pettingzoo_env_to_vec_env_v1(env) - return env + +def make_record_env(): + env = pistonball_v6.parallel_env(render_mode="rgb_array") + print(env.render_mode) + env = ss.pettingzoo_env_to_vec_env_v1(env) + envs = ss.concat_vec_envs_v1(env, 1, num_cpus=0, base_class="stable_baselines3") + # envs.render_mode = "rgb_array" + envs = VecVideoRecorder(envs, f".", schedule) + return envs + + +def record_video_test(): + envs = make_record_env() + envs.reset() + for _ in range(100): + envs.step([envs.action_space.sample() for _ in range(envs.num_envs)]) + envs.close() + + +record_video_test() + + +# def make_env(): +# env = pistonball_v6.parallel_env(render_mode="rgb_array") +# env = ss.pettingzoo_env_to_vec_env_v1(env) +# return env # unfortunately this test does not pass # def test_vector_render_multiproc(): # env = make_env() # num_envs = 3 -# venv = concat_vec_envs_v1(env, num_envs, num_cpus=num_envs, base_class='stable_baselines3') +# venv = ss.concat_vec_envs_v1( +# env, num_envs, num_cpus=num_envs, base_class="stable_baselines3" +# ) # venv.reset() -# arr = venv.render(mode="rgb_array") +# arr = venv.render() # venv.reset() # assert len(arr.shape) == 3 and arr.shape[2] == 3 # venv.reset() From 730019704c7d6fb2ede97f61f65e4beb4c1950ed Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Cl=C3=A9ment=20Dumas?= Date: Wed, 17 Jan 2024 14:40:32 +0100 Subject: [PATCH 2/6] Removed comment and save the test video in /tmp --- test/test_vector/test_render.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/test/test_vector/test_render.py b/test/test_vector/test_render.py index 3acfcd1..6a3a842 100644 --- a/test/test_vector/test_render.py +++ b/test/test_vector/test_render.py @@ -13,8 +13,7 @@ def make_record_env(): print(env.render_mode) env = ss.pettingzoo_env_to_vec_env_v1(env) envs = ss.concat_vec_envs_v1(env, 1, num_cpus=0, base_class="stable_baselines3") - # envs.render_mode = "rgb_array" - envs = VecVideoRecorder(envs, f".", schedule) + envs = VecVideoRecorder(envs, f"/tmp", schedule) return envs From bb9a52c15eb758a5a3d6db360ea3f46825cfca18 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Cl=C3=A9ment=20Dumas?= Date: Wed, 17 Jan 2024 14:41:51 +0100 Subject: [PATCH 3/6] reformat --- test/test_vector/test_render.py | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/test/test_vector/test_render.py b/test/test_vector/test_render.py index 6a3a842..b6898c0 100644 --- a/test/test_vector/test_render.py +++ b/test/test_vector/test_render.py @@ -8,7 +8,7 @@ def schedule(episode_idx): return episode_idx <= 1 -def make_record_env(): +def make_sb3_record_env(): env = pistonball_v6.parallel_env(render_mode="rgb_array") print(env.render_mode) env = ss.pettingzoo_env_to_vec_env_v1(env) @@ -17,17 +17,14 @@ def make_record_env(): return envs -def record_video_test(): - envs = make_record_env() +def test_record_video_sb3(): + envs = make_sb3_record_env() envs.reset() for _ in range(100): envs.step([envs.action_space.sample() for _ in range(envs.num_envs)]) envs.close() -record_video_test() - - # def make_env(): # env = pistonball_v6.parallel_env(render_mode="rgb_array") # env = ss.pettingzoo_env_to_vec_env_v1(env) From ef2cf68fc9e734bdcc1df7008b72297523ea9d5d Mon Sep 17 00:00:00 2001 From: elliottower Date: Thu, 18 Jan 2024 12:51:18 -0500 Subject: [PATCH 4/6] Fix sb3 requirements for testing and uncomment other render test (works locally) --- pyproject.toml | 2 +- test/test_vector/test_render.py | 43 ++++++++++++++++----------------- 2 files changed, 22 insertions(+), 23 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index de31360..8578c53 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -25,7 +25,7 @@ dynamic = ["version"] [project.optional-dependencies] # Update dependencies in `all` if any are added or removed -testing = ["pettingzoo[butterfly,classic]>=1.23.1", "pytest"] +testing = ["pettingzoo[butterfly,classic]>=1.23.1", "pytest", "stable-baselines3>=2.0.0"] [project.urls] Homepage = "https://farama.org" diff --git a/test/test_vector/test_render.py b/test/test_vector/test_render.py index b6898c0..cc1913e 100644 --- a/test/test_vector/test_render.py +++ b/test/test_vector/test_render.py @@ -25,25 +25,24 @@ def test_record_video_sb3(): envs.close() -# def make_env(): -# env = pistonball_v6.parallel_env(render_mode="rgb_array") -# env = ss.pettingzoo_env_to_vec_env_v1(env) -# return env - - -# unfortunately this test does not pass -# def test_vector_render_multiproc(): -# env = make_env() -# num_envs = 3 -# venv = ss.concat_vec_envs_v1( -# env, num_envs, num_cpus=num_envs, base_class="stable_baselines3" -# ) -# venv.reset() -# arr = venv.render() -# venv.reset() -# assert len(arr.shape) == 3 and arr.shape[2] == 3 -# venv.reset() -# try: -# venv.close() -# except RuntimeError: -# pass +def make_env(): + env = pistonball_v6.parallel_env(render_mode="rgb_array") + env = ss.pettingzoo_env_to_vec_env_v1(env) + return env + + +def test_vector_render_multiproc(): + env = make_env() + num_envs = 1 + venv = ss.concat_vec_envs_v1( + env, num_envs, num_cpus=num_envs, base_class="stable_baselines3" + ) + venv.reset() + arr = venv.render() + venv.reset() + assert len(arr.shape) == 3 and arr.shape[2] == 3 + venv.reset() + try: + venv.close() + except RuntimeError: + pass From b3ab650c65f318bc5047bdd163bdd1b32492981a Mon Sep 17 00:00:00 2001 From: elliottower Date: Thu, 18 Jan 2024 12:55:05 -0500 Subject: [PATCH 5/6] Fix pre-commit --- test/test_vector/test_render.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/test/test_vector/test_render.py b/test/test_vector/test_render.py index cc1913e..ec03e0f 100644 --- a/test/test_vector/test_render.py +++ b/test/test_vector/test_render.py @@ -1,7 +1,8 @@ from pettingzoo.butterfly import pistonball_v6 -import supersuit as ss from stable_baselines3.common.vec_env import VecVideoRecorder +import supersuit as ss + def schedule(episode_idx): print(episode_idx) @@ -13,7 +14,7 @@ def make_sb3_record_env(): print(env.render_mode) env = ss.pettingzoo_env_to_vec_env_v1(env) envs = ss.concat_vec_envs_v1(env, 1, num_cpus=0, base_class="stable_baselines3") - envs = VecVideoRecorder(envs, f"/tmp", schedule) + envs = VecVideoRecorder(envs, "/tmp", schedule) return envs From f6ab50456eadfbb1f942ca14ce2e9e7bf8e41e3f Mon Sep 17 00:00:00 2001 From: elliottower Date: Thu, 18 Jan 2024 13:43:26 -0500 Subject: [PATCH 6/6] Add moviepy requirement (error with gymnasium recorder) --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 8578c53..7fc1ca8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -25,7 +25,7 @@ dynamic = ["version"] [project.optional-dependencies] # Update dependencies in `all` if any are added or removed -testing = ["pettingzoo[butterfly,classic]>=1.23.1", "pytest", "stable-baselines3>=2.0.0"] +testing = ["pettingzoo[butterfly,classic]>=1.23.1", "pytest", "stable-baselines3>=2.0.0", "moviepy >=1.0.0"] [project.urls] Homepage = "https://farama.org"