From c5b5cc26e98caf6996d48df8d8b345cc157e4e7e Mon Sep 17 00:00:00 2001 From: pseudo-rnd-thoughts Date: Tue, 27 Feb 2024 17:53:04 +0000 Subject: [PATCH] Modify testing to add a `JaxToNumpy` wrapper --- gymnasium/envs/functional_jax_env.py | 19 +++++-------------- gymnasium/envs/phys2d/cartpole.py | 4 ++-- gymnasium/envs/phys2d/pendulum.py | 4 ++-- gymnasium/envs/tabular/blackjack.py | 2 +- gymnasium/envs/tabular/cliffwalking.py | 2 +- gymnasium/spaces/box.py | 6 +----- gymnasium/utils/passive_env_checker.py | 2 +- tests/envs/test_envs.py | 11 ++++++++++- 8 files changed, 23 insertions(+), 27 deletions(-) diff --git a/gymnasium/envs/functional_jax_env.py b/gymnasium/envs/functional_jax_env.py index c6cb96dce..9fc038f49 100644 --- a/gymnasium/envs/functional_jax_env.py +++ b/gymnasium/envs/functional_jax_env.py @@ -6,7 +6,6 @@ import jax import jax.numpy as jnp import jax.random as jrng -import numpy as np import gymnasium as gym from gymnasium.envs.registration import EnvSpec @@ -31,7 +30,11 @@ def __init__( ): """Initialize the environment from a FuncEnv.""" if metadata is None: - metadata = {"render_mode": []} + metadata = {"render_mode": [], "jax": True} + if "jax" not in metadata: + gym.logger.warn( + 'For environments that use Jax observations and actions (and not NumPy), specify `metadata["jax"] = True` for users' + ) self.func_env = func_env @@ -44,8 +47,6 @@ def __init__( self.spec = spec - self._is_box_action_space = isinstance(self.action_space, gym.spaces.Box) - if self.render_mode == "rgb_array": self.render_state = self.func_env.render_init() else: @@ -72,14 +73,6 @@ def reset(self, *, seed: int | None = None, options: dict | None = None): def step(self, action: ActType): """Steps through the environment using the action.""" - if self._is_box_action_space: - assert isinstance(self.action_space, gym.spaces.Box) # For typing - action = np.clip(action, self.action_space.low, self.action_space.high) - else: # Discrete - # For now we assume jax envs don't use complex spaces - err_msg = f"{action!r} ({type(action)}) invalid" - assert self.action_space.contains(action), err_msg - rng, self.rng = jrng.split(self.rng) next_state = self.func_env.transition(self.state, action, rng) @@ -148,8 +141,6 @@ def __init__( self.autoreset_envs = jnp.zeros(self.num_envs, dtype=jnp.bool_) - self._is_box_action_space = isinstance(self.action_space, gym.spaces.Box) - if self.render_mode == "rgb_array": self.render_state = self.func_env.render_init() else: diff --git a/gymnasium/envs/phys2d/cartpole.py b/gymnasium/envs/phys2d/cartpole.py index 1aca6ee3f..b91c5b05f 100644 --- a/gymnasium/envs/phys2d/cartpole.py +++ b/gymnasium/envs/phys2d/cartpole.py @@ -259,7 +259,7 @@ def render_close(self, render_state: RenderStateType) -> None: class CartPoleJaxEnv(FunctionalJaxEnv, EzPickle): """Jax-based implementation of the CartPole environment.""" - metadata = {"render_modes": ["rgb_array"], "render_fps": 50} + metadata = {"render_modes": ["rgb_array"], "render_fps": 50, "jax": True} def __init__(self, render_mode: str | None = None, **kwargs: Any): """Constructor for the CartPole where the kwargs are applied to the functional environment.""" @@ -279,7 +279,7 @@ def __init__(self, render_mode: str | None = None, **kwargs: Any): class CartPoleJaxVectorEnv(FunctionalJaxVectorEnv, EzPickle): """Jax-based implementation of the vectorized CartPole environment.""" - metadata = {"render_modes": ["rgb_array"], "render_fps": 50} + metadata = {"render_modes": ["rgb_array"], "render_fps": 50, "jax": True} def __init__( self, diff --git a/gymnasium/envs/phys2d/pendulum.py b/gymnasium/envs/phys2d/pendulum.py index 04d55dc10..2b9544555 100644 --- a/gymnasium/envs/phys2d/pendulum.py +++ b/gymnasium/envs/phys2d/pendulum.py @@ -192,7 +192,7 @@ def render_close(self, render_state: RenderStateType): class PendulumJaxEnv(FunctionalJaxEnv, EzPickle): """Jax-based pendulum environment using the functional version as base.""" - metadata = {"render_modes": ["rgb_array"], "render_fps": 30} + metadata = {"render_modes": ["rgb_array"], "render_fps": 30, "jax": True} def __init__(self, render_mode: str | None = None, **kwargs: Any): """Constructor where the kwargs are passed to the base environment to modify the parameters.""" @@ -211,7 +211,7 @@ def __init__(self, render_mode: str | None = None, **kwargs: Any): class PendulumJaxVectorEnv(FunctionalJaxVectorEnv, EzPickle): """Jax-based implementation of the vectorized CartPole environment.""" - metadata = {"render_modes": ["rgb_array"], "render_fps": 50} + metadata = {"render_modes": ["rgb_array"], "render_fps": 50, "jax": True} def __init__( self, diff --git a/gymnasium/envs/tabular/blackjack.py b/gymnasium/envs/tabular/blackjack.py index bfdf341dc..0b7109128 100644 --- a/gymnasium/envs/tabular/blackjack.py +++ b/gymnasium/envs/tabular/blackjack.py @@ -476,7 +476,7 @@ def render_close(self, render_state: RenderStateType) -> None: class BlackJackJaxEnv(FunctionalJaxEnv, EzPickle): """A Gymnasium Env wrapper for the functional blackjack env.""" - metadata = {"render_modes": ["rgb_array"], "render_fps": 50} + metadata = {"render_modes": ["rgb_array"], "render_fps": 50, "jax": True} def __init__(self, render_mode: Optional[str] = None, **kwargs): """Initializes Gym wrapper for blackjack functional env.""" diff --git a/gymnasium/envs/tabular/cliffwalking.py b/gymnasium/envs/tabular/cliffwalking.py index d887e75ca..d3b387a3e 100644 --- a/gymnasium/envs/tabular/cliffwalking.py +++ b/gymnasium/envs/tabular/cliffwalking.py @@ -350,7 +350,7 @@ def render_close(self, render_state: RenderStateType) -> None: class CliffWalkingJaxEnv(FunctionalJaxEnv, EzPickle): """A Gymnasium Env wrapper for the functional cliffwalking env.""" - metadata = {"render_modes": ["rgb_array"], "render_fps": 50} + metadata = {"render_modes": ["rgb_array"], "render_fps": 50, "jax": True} def __init__(self, render_mode: str | None = None, **kwargs): """Initializes Gym wrapper for cliffwalking functional env.""" diff --git a/gymnasium/spaces/box.py b/gymnasium/spaces/box.py index 491036b67..418cb7158 100644 --- a/gymnasium/spaces/box.py +++ b/gymnasium/spaces/box.py @@ -59,7 +59,6 @@ def __init__( shape: Sequence[int] | None = None, dtype: type[np.floating[Any]] | type[np.integer[Any]] = np.float32, seed: int | np.random.Generator | None = None, - require_numpy: bool = False, ): r"""Constructor of :class:`Box`. @@ -145,8 +144,6 @@ def __init__( self.low_repr = _short_repr(self.low) self.high_repr = _short_repr(self.high) - self.require_numpy = require_numpy - super().__init__(self.shape, self.dtype, seed) @property @@ -240,8 +237,7 @@ def sample(self, mask: None = None) -> NDArray[Any]: def contains(self, x: Any) -> bool: """Return boolean specifying if x is a valid member of this space.""" if not isinstance(x, np.ndarray): - if self.require_numpy: - gym.logger.warn("Casting input x to numpy array.") + gym.logger.warn("Casting input x to numpy array.") try: x = np.asarray(x, dtype=self.dtype) except (ValueError, TypeError): diff --git a/gymnasium/utils/passive_env_checker.py b/gymnasium/utils/passive_env_checker.py index b46859747..d9272dafe 100644 --- a/gymnasium/utils/passive_env_checker.py +++ b/gymnasium/utils/passive_env_checker.py @@ -129,7 +129,7 @@ def check_obs(obs, observation_space: spaces.Space, method_name: str): logger.warn(f"{pre} should be an int or np.int64, actual type: {type(obs)}") elif isinstance(observation_space, spaces.Box): if observation_space.shape != (): - if not isinstance(obs, np.ndarray) and observation_space.require_numpy: + if not isinstance(obs, np.ndarray): logger.warn( f"{pre} was expecting a numpy array, actual type: {type(obs)}" ) diff --git a/tests/envs/test_envs.py b/tests/envs/test_envs.py index 0a8aa0b3a..6f52a7320 100644 --- a/tests/envs/test_envs.py +++ b/tests/envs/test_envs.py @@ -42,13 +42,22 @@ def test_all_env_api(spec): """Check that all environments pass the environment checker with no warnings other than the expected.""" with warnings.catch_warnings(record=True) as caught_warnings: env = spec.make().unwrapped + + if env.metadata.get("jax", False): + env = gym.wrappers.JaxToNumpy(env) + check_env(env, skip_render_check=True) env.close() for warning in caught_warnings: if warning.message.args[0] not in CHECK_ENV_IGNORE_WARNINGS: - raise gym.error.Error(f"Unexpected warning: {warning.message}") + if not ( + env.metadata.get("jax", False) + and "\x1b[33mWARN: The environment (