Skip to content

Commit

Permalink
Fix wrappers.vector.HumanRendering (#1069)
Browse files Browse the repository at this point in the history
Co-authored-by: Mark Towers <[email protected]>
  • Loading branch information
RogerJL and pseudo-rnd-thoughts authored May 31, 2024
1 parent 0607994 commit 04fb345
Show file tree
Hide file tree
Showing 2 changed files with 71 additions and 31 deletions.
59 changes: 28 additions & 31 deletions gymnasium/wrappers/vector/rendering.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,38 +106,39 @@ def _render_frame(self):
width_ratio = subenv_size[0] / self.screen_size[0]
height_ratio = subenv_size[1] / self.screen_size[1]

rows, cols = 1, 1
while rows * cols < self.num_envs:
row_ratio = rows * height_ratio
col_ratio = cols * width_ratio
num_rows, num_cols = 1, 1
while num_rows * num_cols < self.num_envs:
row_ratio = num_rows * height_ratio
col_ratio = num_cols * width_ratio

if row_ratio == col_ratio:
rows, cols = rows + 1, cols + 1
num_rows, num_cols = num_rows + 1, num_cols + 1
elif row_ratio > col_ratio:
cols += 1
num_cols += 1
else:
rows += 1

self.rows = rows
self.cols = cols
num_rows += 1

scaling_factor = min(
self.screen_size[0] / (cols * subenv_size[0]),
self.screen_size[1] / (rows * subenv_size[1]),
self.screen_size[0] / (num_cols * subenv_size[0]),
self.screen_size[1] / (num_rows * subenv_size[1]),
)
assert (
num_cols * subenv_size[0] * scaling_factor == self.screen_size[0]
) or (num_rows * subenv_size[1] * scaling_factor == self.screen_size[1])

self.num_rows = num_rows
self.num_cols = num_cols
self.scaled_subenv_size = (
int(subenv_size[0] * scaling_factor),
int(subenv_size[1] * scaling_factor),
)

assert (cols * subenv_size[0] * scaling_factor == self.screen_size[0]) or (
rows * subenv_size[1] * scaling_factor == self.screen_size[1]
)

assert self.num_rows * self.num_cols >= self.num_envs
assert self.scaled_subenv_size[0] * self.num_cols <= self.screen_size[0]
assert self.scaled_subenv_size[1] * self.num_rows <= self.screen_size[1]

# print(f'{self.num_envs=}, {self.num_rows=}, {self.num_cols=}, {self.screen_size=}, {self.scaled_subenv_size=}')

try:
import cv2
except ImportError as e:
Expand All @@ -146,21 +147,17 @@ def _render_frame(self):
) from e

merged_rgb_array = np.zeros(self.screen_size + (3,), dtype=np.uint8)
i = 0
for x in np.arange(
0, self.screen_size[0], self.scaled_subenv_size[0], dtype=np.int32
):
for y in np.arange(
0, self.screen_size[1], self.scaled_subenv_size[1], dtype=np.int32
):
scaled_render = cv2.resize(
subenv_renders[i], self.scaled_subenv_size[::-1]
)
merged_rgb_array[
x : x + self.scaled_subenv_size[0],
y : y + self.scaled_subenv_size[1],
] = scaled_render
i += 1
cols, rows = np.meshgrid(np.arange(self.num_cols), np.arange(self.num_rows))

for i, col, row in zip(range(self.num_envs), cols.flatten(), rows.flatten()):
scaled_render = cv2.resize(subenv_renders[i], self.scaled_subenv_size[::-1])
x = col * self.scaled_subenv_size[0]
y = row * self.scaled_subenv_size[1]

merged_rgb_array[
x : x + self.scaled_subenv_size[0],
y : y + self.scaled_subenv_size[1],
] = scaled_render

if self.window is None:
pygame.init()
Expand Down
43 changes: 43 additions & 0 deletions tests/wrappers/vector/test_human_rendering.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
"""Test suite of HumanRendering wrapper."""
import re

import pytest

import gymnasium as gym
from gymnasium.wrappers.vector import HumanRendering


@pytest.mark.parametrize("env_id", ["CartPole-v1", "Ant-v4"])
@pytest.mark.parametrize("num_envs", [1, 3, 9])
@pytest.mark.parametrize("screen_size", [None, (400, 300), (300, 600), (600, 600)])
def test_num_envs_screen_size(env_id, num_envs, screen_size):
envs = gym.make_vec(env_id, num_envs=num_envs, render_mode="rgb_array")
envs = HumanRendering(envs, screen_size=screen_size)

assert envs.render_mode == "human"

envs.reset()
for _ in range(25):
envs.step(envs.action_space.sample())
envs.close()


def test_render_modes():
envs = HumanRendering(
gym.make_vec("CartPole-v1", num_envs=3, render_mode="rgb_array_list")
)
assert envs.render_mode == "human"

envs.reset()
for _ in range(25):
envs.step(envs.action_space.sample())
envs.close()

# HumanRenderer on human renderer should not work
with pytest.raises(
AssertionError,
match=re.escape(
"Expected env.render_mode to be one of ['rgb_array', 'rgb_array_list', 'depth_array', 'depth_array_list'] but got 'human'"
),
):
HumanRendering(envs)

0 comments on commit 04fb345

Please sign in to comment.