Skip to content

Commit

Permalink
Modify testing to add a JaxToNumpy wrapper
Browse files Browse the repository at this point in the history
  • Loading branch information
pseudo-rnd-thoughts committed Feb 27, 2024
1 parent 0145ad2 commit c5b5cc2
Show file tree
Hide file tree
Showing 8 changed files with 23 additions and 27 deletions.
19 changes: 5 additions & 14 deletions gymnasium/envs/functional_jax_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
import jax
import jax.numpy as jnp
import jax.random as jrng
import numpy as np

import gymnasium as gym
from gymnasium.envs.registration import EnvSpec
Expand All @@ -31,7 +30,11 @@ def __init__(
):
"""Initialize the environment from a FuncEnv."""
if metadata is None:
metadata = {"render_mode": []}
metadata = {"render_mode": [], "jax": True}
if "jax" not in metadata:
gym.logger.warn(
'For environments that use Jax observations and actions (and not NumPy), specify `metadata["jax"] = True` for users'
)

self.func_env = func_env

Expand All @@ -44,8 +47,6 @@ def __init__(

self.spec = spec

self._is_box_action_space = isinstance(self.action_space, gym.spaces.Box)

if self.render_mode == "rgb_array":
self.render_state = self.func_env.render_init()
else:
Expand All @@ -72,14 +73,6 @@ def reset(self, *, seed: int | None = None, options: dict | None = None):

def step(self, action: ActType):
"""Steps through the environment using the action."""
if self._is_box_action_space:
assert isinstance(self.action_space, gym.spaces.Box) # For typing
action = np.clip(action, self.action_space.low, self.action_space.high)
else: # Discrete
# For now we assume jax envs don't use complex spaces
err_msg = f"{action!r} ({type(action)}) invalid"
assert self.action_space.contains(action), err_msg

rng, self.rng = jrng.split(self.rng)

next_state = self.func_env.transition(self.state, action, rng)
Expand Down Expand Up @@ -148,8 +141,6 @@ def __init__(

self.autoreset_envs = jnp.zeros(self.num_envs, dtype=jnp.bool_)

self._is_box_action_space = isinstance(self.action_space, gym.spaces.Box)

if self.render_mode == "rgb_array":
self.render_state = self.func_env.render_init()
else:
Expand Down
4 changes: 2 additions & 2 deletions gymnasium/envs/phys2d/cartpole.py
Original file line number Diff line number Diff line change
Expand Up @@ -259,7 +259,7 @@ def render_close(self, render_state: RenderStateType) -> None:
class CartPoleJaxEnv(FunctionalJaxEnv, EzPickle):
"""Jax-based implementation of the CartPole environment."""

metadata = {"render_modes": ["rgb_array"], "render_fps": 50}
metadata = {"render_modes": ["rgb_array"], "render_fps": 50, "jax": True}

def __init__(self, render_mode: str | None = None, **kwargs: Any):
"""Constructor for the CartPole where the kwargs are applied to the functional environment."""
Expand All @@ -279,7 +279,7 @@ def __init__(self, render_mode: str | None = None, **kwargs: Any):
class CartPoleJaxVectorEnv(FunctionalJaxVectorEnv, EzPickle):
"""Jax-based implementation of the vectorized CartPole environment."""

metadata = {"render_modes": ["rgb_array"], "render_fps": 50}
metadata = {"render_modes": ["rgb_array"], "render_fps": 50, "jax": True}

def __init__(
self,
Expand Down
4 changes: 2 additions & 2 deletions gymnasium/envs/phys2d/pendulum.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,7 +192,7 @@ def render_close(self, render_state: RenderStateType):
class PendulumJaxEnv(FunctionalJaxEnv, EzPickle):
"""Jax-based pendulum environment using the functional version as base."""

metadata = {"render_modes": ["rgb_array"], "render_fps": 30}
metadata = {"render_modes": ["rgb_array"], "render_fps": 30, "jax": True}

def __init__(self, render_mode: str | None = None, **kwargs: Any):
"""Constructor where the kwargs are passed to the base environment to modify the parameters."""
Expand All @@ -211,7 +211,7 @@ def __init__(self, render_mode: str | None = None, **kwargs: Any):
class PendulumJaxVectorEnv(FunctionalJaxVectorEnv, EzPickle):
"""Jax-based implementation of the vectorized CartPole environment."""

metadata = {"render_modes": ["rgb_array"], "render_fps": 50}
metadata = {"render_modes": ["rgb_array"], "render_fps": 50, "jax": True}

def __init__(
self,
Expand Down
2 changes: 1 addition & 1 deletion gymnasium/envs/tabular/blackjack.py
Original file line number Diff line number Diff line change
Expand Up @@ -476,7 +476,7 @@ def render_close(self, render_state: RenderStateType) -> None:
class BlackJackJaxEnv(FunctionalJaxEnv, EzPickle):
"""A Gymnasium Env wrapper for the functional blackjack env."""

metadata = {"render_modes": ["rgb_array"], "render_fps": 50}
metadata = {"render_modes": ["rgb_array"], "render_fps": 50, "jax": True}

def __init__(self, render_mode: Optional[str] = None, **kwargs):
"""Initializes Gym wrapper for blackjack functional env."""
Expand Down
2 changes: 1 addition & 1 deletion gymnasium/envs/tabular/cliffwalking.py
Original file line number Diff line number Diff line change
Expand Up @@ -350,7 +350,7 @@ def render_close(self, render_state: RenderStateType) -> None:
class CliffWalkingJaxEnv(FunctionalJaxEnv, EzPickle):
"""A Gymnasium Env wrapper for the functional cliffwalking env."""

metadata = {"render_modes": ["rgb_array"], "render_fps": 50}
metadata = {"render_modes": ["rgb_array"], "render_fps": 50, "jax": True}

def __init__(self, render_mode: str | None = None, **kwargs):
"""Initializes Gym wrapper for cliffwalking functional env."""
Expand Down
6 changes: 1 addition & 5 deletions gymnasium/spaces/box.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,6 @@ def __init__(
shape: Sequence[int] | None = None,
dtype: type[np.floating[Any]] | type[np.integer[Any]] = np.float32,
seed: int | np.random.Generator | None = None,
require_numpy: bool = False,
):
r"""Constructor of :class:`Box`.
Expand Down Expand Up @@ -145,8 +144,6 @@ def __init__(
self.low_repr = _short_repr(self.low)
self.high_repr = _short_repr(self.high)

self.require_numpy = require_numpy

super().__init__(self.shape, self.dtype, seed)

@property
Expand Down Expand Up @@ -240,8 +237,7 @@ def sample(self, mask: None = None) -> NDArray[Any]:
def contains(self, x: Any) -> bool:
"""Return boolean specifying if x is a valid member of this space."""
if not isinstance(x, np.ndarray):
if self.require_numpy:
gym.logger.warn("Casting input x to numpy array.")
gym.logger.warn("Casting input x to numpy array.")
try:
x = np.asarray(x, dtype=self.dtype)
except (ValueError, TypeError):
Expand Down
2 changes: 1 addition & 1 deletion gymnasium/utils/passive_env_checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ def check_obs(obs, observation_space: spaces.Space, method_name: str):
logger.warn(f"{pre} should be an int or np.int64, actual type: {type(obs)}")
elif isinstance(observation_space, spaces.Box):
if observation_space.shape != ():
if not isinstance(obs, np.ndarray) and observation_space.require_numpy:
if not isinstance(obs, np.ndarray):
logger.warn(
f"{pre} was expecting a numpy array, actual type: {type(obs)}"
)
Expand Down
11 changes: 10 additions & 1 deletion tests/envs/test_envs.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,13 +42,22 @@ def test_all_env_api(spec):
"""Check that all environments pass the environment checker with no warnings other than the expected."""
with warnings.catch_warnings(record=True) as caught_warnings:
env = spec.make().unwrapped

if env.metadata.get("jax", False):
env = gym.wrappers.JaxToNumpy(env)

check_env(env, skip_render_check=True)

env.close()

for warning in caught_warnings:
if warning.message.args[0] not in CHECK_ENV_IGNORE_WARNINGS:
raise gym.error.Error(f"Unexpected warning: {warning.message}")
if not (
env.metadata.get("jax", False)
and "\x1b[33mWARN: The environment (<JaxToNumpy<"
in warning.message.args[0]
):
raise gym.error.Error(f"Unexpected warning: {warning.message}")


@pytest.mark.parametrize(
Expand Down

0 comments on commit c5b5cc2

Please sign in to comment.