Skip to content

Commit

Permalink
Fix render_mode attribute not being set by sb3 wrapper (#243)
Browse files Browse the repository at this point in the history
Co-authored-by: Clément Dumas <[email protected]>
  • Loading branch information
elliottower and Butanium authored Jan 18, 2024
1 parent 035cf0f commit 5b7b2ed
Show file tree
Hide file tree
Showing 3 changed files with 43 additions and 18 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ dynamic = ["version"]

[project.optional-dependencies]
# Update dependencies in `all` if any are added or removed
testing = ["pettingzoo[butterfly,classic]>=1.23.1", "pytest"]
testing = ["pettingzoo[butterfly,classic]>=1.23.1", "pytest", "stable-baselines3>=2.0.0", "moviepy >=1.0.0"]

[project.urls]
Homepage = "https://farama.org"
Expand Down
1 change: 1 addition & 0 deletions supersuit/vector/sb3_vector_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ def __init__(self, venv):
self.num_envs = venv.num_envs
self.observation_space = venv.observation_space
self.action_space = venv.action_space
self.render_mode = venv.render_mode
self.reset_infos = []

def reset(self, seed=None, options=None):
Expand Down
58 changes: 41 additions & 17 deletions test/test_vector/test_render.py
Original file line number Diff line number Diff line change
@@ -1,25 +1,49 @@
from pettingzoo.butterfly import pistonball_v6
from stable_baselines3.common.vec_env import VecVideoRecorder

from supersuit import pettingzoo_env_to_vec_env_v1
import supersuit as ss


def schedule(episode_idx):
print(episode_idx)
return episode_idx <= 1


def make_sb3_record_env():
env = pistonball_v6.parallel_env(render_mode="rgb_array")
print(env.render_mode)
env = ss.pettingzoo_env_to_vec_env_v1(env)
envs = ss.concat_vec_envs_v1(env, 1, num_cpus=0, base_class="stable_baselines3")
envs = VecVideoRecorder(envs, "/tmp", schedule)
return envs


def test_record_video_sb3():
envs = make_sb3_record_env()
envs.reset()
for _ in range(100):
envs.step([envs.action_space.sample() for _ in range(envs.num_envs)])
envs.close()


def make_env():
env = pistonball_v6.parallel_env()
env = pettingzoo_env_to_vec_env_v1(env)
env = pistonball_v6.parallel_env(render_mode="rgb_array")
env = ss.pettingzoo_env_to_vec_env_v1(env)
return env


# unfortunately this test does not pass
# def test_vector_render_multiproc():
# env = make_env()
# num_envs = 3
# venv = concat_vec_envs_v1(env, num_envs, num_cpus=num_envs, base_class='stable_baselines3')
# venv.reset()
# arr = venv.render(mode="rgb_array")
# venv.reset()
# assert len(arr.shape) == 3 and arr.shape[2] == 3
# venv.reset()
# try:
# venv.close()
# except RuntimeError:
# pass
def test_vector_render_multiproc():
env = make_env()
num_envs = 1
venv = ss.concat_vec_envs_v1(
env, num_envs, num_cpus=num_envs, base_class="stable_baselines3"
)
venv.reset()
arr = venv.render()
venv.reset()
assert len(arr.shape) == 3 and arr.shape[2] == 3
venv.reset()
try:
venv.close()
except RuntimeError:
pass

0 comments on commit 5b7b2ed

Please sign in to comment.