Skip to content

Commit

Permalink
fix render mode attribute not being set by sb3 wrapper.
Browse files Browse the repository at this point in the history
Added a test to record video
  • Loading branch information
Butanium committed Jan 17, 2024
1 parent 035cf0f commit c25d356
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 7 deletions.
1 change: 1 addition & 0 deletions supersuit/vector/sb3_vector_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ def __init__(self, venv):
self.num_envs = venv.num_envs
self.observation_space = venv.observation_space
self.action_space = venv.action_space
self.render_mode = venv.render_mode
self.reset_infos = []

def reset(self, seed=None, options=None):
Expand Down
42 changes: 35 additions & 7 deletions test/test_vector/test_render.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,49 @@
from pettingzoo.butterfly import pistonball_v6
import supersuit as ss
from stable_baselines3.common.vec_env import VecVideoRecorder

from supersuit import pettingzoo_env_to_vec_env_v1

def schedule(episode_idx):
print(episode_idx)
return episode_idx <= 1

def make_env():
env = pistonball_v6.parallel_env()
env = pettingzoo_env_to_vec_env_v1(env)
return env

def make_record_env():
env = pistonball_v6.parallel_env(render_mode="rgb_array")
print(env.render_mode)
env = ss.pettingzoo_env_to_vec_env_v1(env)
envs = ss.concat_vec_envs_v1(env, 1, num_cpus=0, base_class="stable_baselines3")
# envs.render_mode = "rgb_array"
envs = VecVideoRecorder(envs, f".", schedule)
return envs


def record_video_test():
envs = make_record_env()
envs.reset()
for _ in range(100):
envs.step([envs.action_space.sample() for _ in range(envs.num_envs)])
envs.close()


record_video_test()


# def make_env():
# env = pistonball_v6.parallel_env(render_mode="rgb_array")
# env = ss.pettingzoo_env_to_vec_env_v1(env)
# return env


# unfortunately this test does not pass
# def test_vector_render_multiproc():
# env = make_env()
# num_envs = 3
# venv = concat_vec_envs_v1(env, num_envs, num_cpus=num_envs, base_class='stable_baselines3')
# venv = ss.concat_vec_envs_v1(
# env, num_envs, num_cpus=num_envs, base_class="stable_baselines3"
# )
# venv.reset()
# arr = venv.render(mode="rgb_array")
# arr = venv.render()
# venv.reset()
# assert len(arr.shape) == 3 and arr.shape[2] == 3
# venv.reset()
Expand Down

0 comments on commit c25d356

Please sign in to comment.