diff --git a/pyproject.toml b/pyproject.toml index de31360..7fc1ca8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -25,7 +25,7 @@ dynamic = ["version"] [project.optional-dependencies] # Update dependencies in `all` if any are added or removed -testing = ["pettingzoo[butterfly,classic]>=1.23.1", "pytest"] +testing = ["pettingzoo[butterfly,classic]>=1.23.1", "pytest", "stable-baselines3>=2.0.0", "moviepy >=1.0.0"] [project.urls] Homepage = "https://farama.org" diff --git a/supersuit/vector/sb3_vector_wrapper.py b/supersuit/vector/sb3_vector_wrapper.py index 1130864..c34a0e0 100644 --- a/supersuit/vector/sb3_vector_wrapper.py +++ b/supersuit/vector/sb3_vector_wrapper.py @@ -12,6 +12,7 @@ def __init__(self, venv): self.num_envs = venv.num_envs self.observation_space = venv.observation_space self.action_space = venv.action_space + self.render_mode = venv.render_mode self.reset_infos = [] def reset(self, seed=None, options=None): diff --git a/test/test_vector/test_render.py b/test/test_vector/test_render.py index 78f4067..ec03e0f 100644 --- a/test/test_vector/test_render.py +++ b/test/test_vector/test_render.py @@ -1,25 +1,49 @@ from pettingzoo.butterfly import pistonball_v6 +from stable_baselines3.common.vec_env import VecVideoRecorder -from supersuit import pettingzoo_env_to_vec_env_v1 +import supersuit as ss + + +def schedule(episode_idx): + print(episode_idx) + return episode_idx <= 1 + + +def make_sb3_record_env(): + env = pistonball_v6.parallel_env(render_mode="rgb_array") + print(env.render_mode) + env = ss.pettingzoo_env_to_vec_env_v1(env) + envs = ss.concat_vec_envs_v1(env, 1, num_cpus=0, base_class="stable_baselines3") + envs = VecVideoRecorder(envs, "/tmp", schedule) + return envs + + +def test_record_video_sb3(): + envs = make_sb3_record_env() + envs.reset() + for _ in range(100): + envs.step([envs.action_space.sample() for _ in range(envs.num_envs)]) + envs.close() def make_env(): - env = pistonball_v6.parallel_env() - env = pettingzoo_env_to_vec_env_v1(env) + env = pistonball_v6.parallel_env(render_mode="rgb_array") + env = ss.pettingzoo_env_to_vec_env_v1(env) return env -# unfortunately this test does not pass -# def test_vector_render_multiproc(): -# env = make_env() -# num_envs = 3 -# venv = concat_vec_envs_v1(env, num_envs, num_cpus=num_envs, base_class='stable_baselines3') -# venv.reset() -# arr = venv.render(mode="rgb_array") -# venv.reset() -# assert len(arr.shape) == 3 and arr.shape[2] == 3 -# venv.reset() -# try: -# venv.close() -# except RuntimeError: -# pass +def test_vector_render_multiproc(): + env = make_env() + num_envs = 1 + venv = ss.concat_vec_envs_v1( + env, num_envs, num_cpus=num_envs, base_class="stable_baselines3" + ) + venv.reset() + arr = venv.render() + venv.reset() + assert len(arr.shape) == 3 and arr.shape[2] == 3 + venv.reset() + try: + venv.close() + except RuntimeError: + pass