Skip to content

Commit

Permalink
Add wrappers.vector.HumanRendering (#1013)
Browse files Browse the repository at this point in the history
  • Loading branch information
pseudo-rnd-thoughts authored Apr 23, 2024
1 parent 1b2c1ff commit d196497
Show file tree
Hide file tree
Showing 6 changed files with 285 additions and 94 deletions.
73 changes: 26 additions & 47 deletions gymnasium/envs/classic_control/cartpole.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,7 +210,7 @@ def step(self, action):
self.steps_beyond_terminated = 0
if self._sutton_barto_reward:
reward = -1.0
elif not self._sutton_barto_reward:
else:
reward = 1.0
else:
if self.steps_beyond_terminated == 0:
Expand All @@ -223,7 +223,7 @@ def step(self, action):
self.steps_beyond_terminated += 1
if self._sutton_barto_reward:
reward = -1.0
elif not self._sutton_barto_reward:
else:
reward = 0.0

if self.render_mode == "human":
Expand Down Expand Up @@ -359,16 +359,16 @@ def close(self):

class CartPoleVectorEnv(VectorEnv):
metadata = {
"render_modes": ["human", "rgb_array"],
"render_modes": ["rgb_array"],
"render_fps": 50,
}

def __init__(
self,
sutton_barto_reward: bool = False,
num_envs: int = 2,
num_envs: int = 1,
max_episode_steps: int = 500,
render_mode: Optional[str] = None,
sutton_barto_reward: bool = False,
):
self._sutton_barto_reward = sutton_barto_reward

Expand All @@ -386,6 +386,8 @@ def __init__(
self.tau = 0.02 # seconds between state updates
self.kinematics_integrator = "euler"

self.state = None

self.steps = np.zeros(num_envs, dtype=np.int32)
self.prev_done = np.zeros(num_envs, dtype=np.bool_)

Expand Down Expand Up @@ -416,9 +418,7 @@ def __init__(
self.screen_width = 600
self.screen_height = 400
self.screens = None
self.clocks = None
self.isopen = True
self.state = None
self.surf = None

self.steps_beyond_terminated = None

Expand Down Expand Up @@ -469,10 +469,10 @@ def step(

truncated = self.steps >= self.max_episode_steps

if self._sutton_barto_reward is False:
reward = np.ones_like(terminated, dtype=np.float32)
elif self._sutton_barto_reward is True:
if self._sutton_barto_reward is True:
reward = -np.array(terminated, dtype=np.float32)
else:
reward = np.ones_like(terminated, dtype=np.float32)

# Reset all environments which terminated or were truncated in the last step
self.state[:, self.prev_done] = self.np_random.uniform(
Expand All @@ -485,9 +485,6 @@ def step(

self.prev_done = terminated | truncated

if self.render_mode == "human":
self.render()

return self.state.T.astype(np.float32), reward, terminated, truncated, {}

def reset(
Expand All @@ -509,8 +506,6 @@ def reset(
self.steps = np.zeros(self.num_envs, dtype=np.int32)
self.prev_done = np.zeros(self.num_envs, dtype=np.bool_)

if self.render_mode == "human":
self.render()
return self.state.T.astype(np.float32), {}

def render(self):
Expand All @@ -519,7 +514,7 @@ def render(self):
gym.logger.warn(
"You are calling render method without specifying any render mode. "
"You can specify the render_mode at initialization, "
f'e.g. gym("{self.spec.id}", render_mode="rgb_array")'
f'e.g. gym.make_vec("{self.spec.id}", render_mode="rgb_array")'
)
return

Expand All @@ -533,19 +528,11 @@ def render(self):

if self.screens is None:
pygame.init()
if self.render_mode == "human":
pygame.display.init()
self.screens = [
pygame.display.set_mode((self.screen_width, self.screen_height))
for _ in range(self.num_envs)
]
else: # mode == "rgb_array"
self.screens = [
pygame.Surface((self.screen_width, self.screen_height))
for _ in range(self.num_envs)
]
if self.clocks is None:
self.clock = [pygame.time.Clock() for _ in range(self.num_envs)]

self.screens = [
pygame.Surface((self.screen_width, self.screen_height))
for _ in range(self.num_envs)
]

world_width = self.x_threshold * 2
scale = self.screen_width / world_width
Expand All @@ -555,10 +542,12 @@ def render(self):
cartheight = 30.0

if self.state is None:
return None
raise ValueError(
"Cartpole's state is None, it probably hasn't be reset yet."
)

for state, screen, clock in zip(self.state, self.screens, self.clocks):
x = self.state.T
for x, screen in zip(self.state.T, self.screens):
assert isinstance(x, np.ndarray) and x.shape == (4,)

self.surf = pygame.Surface((self.screen_width, self.screen_height))
self.surf.fill((255, 255, 255))
Expand Down Expand Up @@ -607,23 +596,13 @@ def render(self):
self.surf = pygame.transform.flip(self.surf, False, True)
screen.blit(self.surf, (0, 0))

if self.render_mode == "human":
pygame.event.pump()
[clock.tick(self.metadata["render_fps"]) for clock in self.clocks]
pygame.display.flip()

elif self.render_mode == "rgb_array":
return [
np.transpose(
np.array(pygame.surfarray.pixels3d(screen)), axes=(1, 0, 2)
)
for screen in self.screens
]
return [
np.transpose(np.array(pygame.surfarray.pixels3d(screen)), axes=(1, 0, 2))
for screen in self.screens
]

def close(self):
if self.screens is not None:
import pygame

pygame.display.quit()
pygame.quit()
self.isopen = False
83 changes: 51 additions & 32 deletions gymnasium/vector/vector_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,9 +92,7 @@ class VectorEnv(Generic[ObsType, ActType, ArrayType]):
:func:`make_vec` is the equivalent function to :func:`make` for vector environments.
"""

# Set this in SOME subclasses
metadata: dict[str, Any] = {"render_modes": []}

metadata: dict[str, Any] = {}
spec: EnvSpec | None = None
render_mode: str | None = None
closed: bool = False
Expand Down Expand Up @@ -214,20 +212,6 @@ def close_extras(self, **kwargs: Any):
"""Clean up the extra resources e.g. beyond what's in this base class."""
pass

@property
def np_random_seed(self) -> int | None:
    """Return the seed stored in :attr:`_np_random_seed`, initialising it on first access.

    If no seed has been stored yet, a new generator and seed are created with
    ``seeding.np_random()`` and both :attr:`_np_random` and
    :attr:`_np_random_seed` are updated before the seed is returned.

    If the generator was assigned directly through the :attr:`np_random` setter
    instead of through :meth:`reset` or :meth:`set_np_random_through_seed`,
    the seed will take the value -1.

    Returns:
        int: the seed of the current `np_random` or -1, if the seed of the rng is unknown
    """
    # Create the generator and its seed together on first access.
    if self._np_random_seed is None:
        self._np_random, self._np_random_seed = seeding.np_random()
    return self._np_random_seed

@property
def np_random(self) -> np.random.Generator:
"""Returns the environment's internal :attr:`_np_random` that if not set will initialise with a random seed.
Expand All @@ -244,6 +228,20 @@ def np_random(self, value: np.random.Generator):
self._np_random = value
self._np_random_seed = -1

@property
def np_random_seed(self) -> int | None:
    """Return the seed stored in :attr:`_np_random_seed`, initialising it on first access.

    If no seed has been stored yet, a new generator and seed are created with
    ``seeding.np_random()`` and both :attr:`_np_random` and
    :attr:`_np_random_seed` are updated before the seed is returned.

    If the generator was assigned directly through the :attr:`np_random` setter
    instead of through :meth:`reset` or :meth:`set_np_random_through_seed`,
    the seed will take the value -1.

    Returns:
        int: the seed of the current `np_random` or -1, if the seed of the rng is unknown
    """
    # Create the generator and its seed together on first access.
    if self._np_random_seed is None:
        self._np_random, self._np_random_seed = seeding.np_random()
    return self._np_random_seed

@property
def unwrapped(self):
"""Return the base environment."""
Expand Down Expand Up @@ -349,6 +347,7 @@ def __init__(self, env: VectorEnv):
self._action_space: gym.Space | None = None
self._single_observation_space: gym.Space | None = None
self._single_action_space: gym.Space | None = None
self._metadata: dict[str, Any] | None = None

def reset(
self,
Expand Down Expand Up @@ -386,11 +385,6 @@ def __repr__(self):
"""Return the string representation of the vectorized environment."""
return f"<{self.__class__.__name__}, {self.env}>"

@property
def spec(self) -> EnvSpec | None:
    """Return the ``spec`` of the wrapped environment (``self.env.spec``)."""
    return self.env.spec

@property
def observation_space(self) -> gym.Space:
"""Gets the observation space of the vector environment."""
Expand Down Expand Up @@ -444,16 +438,6 @@ def num_envs(self) -> int:
"""Gets the wrapped vector environment's num of the sub-environments."""
return self.env.num_envs

@property
def render_mode(self) -> str | None:
    """Return the ``render_mode`` of the wrapped environment (``self.env.render_mode``)."""
    return self.env.render_mode

@property
def metadata(self) -> dict[str, Any]:
    """Return the ``metadata`` of the wrapped environment (``self.env.metadata``)."""
    return self.env.metadata

@property
def np_random(self) -> np.random.Generator:
"""Returns the environment's internal :attr:`_np_random` that if not set will initialise with a random seed.
Expand All @@ -467,6 +451,41 @@ def np_random(self) -> np.random.Generator:
def np_random(self, value: np.random.Generator):
self.env.np_random = value

@property
def np_random_seed(self) -> int | None:
    """Return the wrapped environment's :attr:`np_random_seed`.

    The value is read from ``self.env`` on every access.
    """
    return self.env.np_random_seed

@property
def metadata(self) -> dict[str, Any]:
    """Return the wrapper's metadata, falling back to the wrapped environment's.

    If a value has been assigned through the ``metadata`` setter (stored in
    ``_metadata``), that value is returned; otherwise ``self.env.metadata``
    is returned.
    """
    # `_metadata` is None unless it has been explicitly overridden.
    if self._metadata is not None:
        return self._metadata
    return self.env.metadata

@metadata.setter
def metadata(self, value: dict[str, Any]):
    """Override the metadata reported by this wrapper.

    The value is stored in ``_metadata`` on the wrapper; ``self.env.metadata``
    is not assigned here.
    """
    self._metadata = value

@property
def spec(self) -> EnvSpec | None:
    """Return the ``spec`` of the wrapped environment (``self.env.spec``)."""
    return self.env.spec

@property
def render_mode(self) -> str | None:
    """Return the ``render_mode`` of the wrapped environment (``self.env.render_mode``)."""
    return self.env.render_mode

@property
def closed(self) -> bool:
    """Whether the wrapped environment has been closed (``self.env.closed``)."""
    return self.env.closed

@closed.setter
def closed(self, value: bool):
    """Set the ``closed`` flag of the wrapped environment."""
    # The flag is written to the wrapped environment, mirroring the getter.
    self.env.closed = value


class VectorObservationWrapper(VectorWrapper):
"""Wraps the vectorized environment to allow a modular transformation of the observation.
Expand Down
32 changes: 19 additions & 13 deletions gymnasium/wrappers/rendering.py
Original file line number Diff line number Diff line change
Expand Up @@ -456,6 +456,13 @@ class HumanRendering(
* v0.25.0 - Initially added
"""

ACCEPTED_RENDER_MODES = [
"rgb_array",
"rgb_array_list",
"depth_array",
"depth_array_list",
]

def __init__(self, env: gym.Env[ObsType, ActType]):
"""Initialize a :class:`HumanRendering` instance.
Expand All @@ -465,12 +472,11 @@ def __init__(self, env: gym.Env[ObsType, ActType]):
gym.utils.RecordConstructorArgs.__init__(self)
gym.Wrapper.__init__(self, env)

assert env.render_mode in [
"rgb_array",
"rgb_array_list",
], f"Expected env.render_mode to be one of 'rgb_array' or 'rgb_array_list' but got '{env.render_mode}'"
assert (
"render_fps" in env.metadata
self.env.render_mode in self.ACCEPTED_RENDER_MODES
), f"Expected env.render_mode to be one of {self.ACCEPTED_RENDER_MODES} but got '{env.render_mode}'"
assert (
"render_fps" in self.env.metadata
), "The base environment must specify 'render_fps' to be used with the HumanRendering wrapper"

self.screen_size = None
Expand Down Expand Up @@ -510,19 +516,19 @@ def _render_frame(self):
import pygame
except ImportError:
raise DependencyNotInstalled(
'pygame is not installed, run `pip install "gymnasium[box2d]"`'
'pygame is not installed, run `pip install "gymnasium[classic-control]"`'
)
if self.env.render_mode == "rgb_array_list":
assert self.env.render_mode is not None
if self.env.render_mode.endswith("_list"):
last_rgb_array = self.env.render()
assert isinstance(last_rgb_array, list)
last_rgb_array = last_rgb_array[-1]
elif self.env.render_mode == "rgb_array":
last_rgb_array = self.env.render()
else:
raise Exception(
f"Wrapped environment must have mode 'rgb_array' or 'rgb_array_list', actual render mode: {self.env.render_mode}"
)
assert isinstance(last_rgb_array, np.ndarray)
last_rgb_array = self.env.render()

assert isinstance(
last_rgb_array, np.ndarray
), f"Expected `env.render()` to return a numpy array, actually returned {type(last_rgb_array)}"

rgb_array = np.transpose(last_rgb_array, axes=(1, 0, 2))

Expand Down
3 changes: 2 additions & 1 deletion gymnasium/wrappers/vector/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

from gymnasium.wrappers.vector.common import RecordEpisodeStatistics
from gymnasium.wrappers.vector.dict_info_to_list import DictInfoToList
from gymnasium.wrappers.vector.rendering import HumanRendering
from gymnasium.wrappers.vector.stateful_observation import NormalizeObservation
from gymnasium.wrappers.vector.stateful_reward import NormalizeReward
from gymnasium.wrappers.vector.vectorize_action import (
Expand Down Expand Up @@ -63,7 +64,7 @@
# --- Rendering ---
# "RenderCollection",
# "RecordVideo",
# "HumanRendering",
"HumanRendering",
# --- Conversion ---
"JaxToNumpy",
"JaxToTorch",
Expand Down
Loading

0 comments on commit d196497

Please sign in to comment.