From da89efc41f8d86a603af61987e6bed33abf67a56 Mon Sep 17 00:00:00 2001 From: ariel Date: Tue, 5 Dec 2023 19:48:51 +0100 Subject: [PATCH] Remove the casts from jax to numpy in jax-based envs --- gymnasium/envs/functional_jax_env.py | 21 --------------------- gymnasium/spaces/box.py | 6 +++++- gymnasium/utils/passive_env_checker.py | 2 +- 3 files changed, 6 insertions(+), 23 deletions(-) diff --git a/gymnasium/envs/functional_jax_env.py b/gymnasium/envs/functional_jax_env.py index 2fc5e55c7..c6cb96dce 100644 --- a/gymnasium/envs/functional_jax_env.py +++ b/gymnasium/envs/functional_jax_env.py @@ -13,7 +13,6 @@ from gymnasium.functional import ActType, FuncEnv, StateType from gymnasium.utils import seeding from gymnasium.vector.utils import batch_space -from gymnasium.wrappers.jax_to_numpy import jax_to_numpy class FunctionalJaxEnv(gym.Env): @@ -69,8 +68,6 @@ def reset(self, *, seed: int | None = None, options: dict | None = None): obs = self.func_env.observation(self.state) info = self.func_env.state_info(self.state) - obs = jax_to_numpy(obs) - return obs, info def step(self, action: ActType): @@ -92,8 +89,6 @@ def step(self, action: ActType): info = self.func_env.transition_info(self.state, action, next_state) self.state = next_state - observation = jax_to_numpy(observation) - return observation, float(reward), bool(terminated), False, info def render(self): @@ -183,20 +178,10 @@ def reset(self, *, seed: int | None = None, options: dict | None = None): self.steps = jnp.zeros(self.num_envs, dtype=jnp.int32) - obs = jax_to_numpy(obs) - return obs, info def step(self, action: ActType): """Steps through the environment using the action.""" - if self._is_box_action_space: - assert isinstance(self.action_space, gym.spaces.Box) # For typing - action = np.clip(action, self.action_space.low, self.action_space.high) - else: # Discrete - # For now we assume jax envs don't use complex spaces - assert self.action_space.contains( - action - ), f"{action!r} ({type(action)}) invalid" self.steps += 1 rng, self.rng = jrng.split(self.rng) @@ -232,12 +217,6 @@ def step(self, action: ActType): self.autoreset_envs = done observation = self.func_env.observation(next_state) - observation = jax_to_numpy(observation) - - reward = jax_to_numpy(reward) - - terminated = jax_to_numpy(terminated) - truncated = jax_to_numpy(truncated) self.state = next_state diff --git a/gymnasium/spaces/box.py b/gymnasium/spaces/box.py index 418cb7158..491036b67 100644 --- a/gymnasium/spaces/box.py +++ b/gymnasium/spaces/box.py @@ -59,6 +59,7 @@ def __init__( shape: Sequence[int] | None = None, dtype: type[np.floating[Any]] | type[np.integer[Any]] = np.float32, seed: int | np.random.Generator | None = None, + require_numpy: bool = False, ): r"""Constructor of :class:`Box`. @@ -144,6 +145,8 @@ def __init__( self.low_repr = _short_repr(self.low) self.high_repr = _short_repr(self.high) + self.require_numpy = require_numpy + super().__init__(self.shape, self.dtype, seed) @property @@ -237,7 +240,8 @@ def sample(self, mask: None = None) -> NDArray[Any]: def contains(self, x: Any) -> bool: """Return boolean specifying if x is a valid member of this space.""" if not isinstance(x, np.ndarray): - gym.logger.warn("Casting input x to numpy array.") + if self.require_numpy: + gym.logger.warn("Casting input x to numpy array.") try: x = np.asarray(x, dtype=self.dtype) except (ValueError, TypeError): diff --git a/gymnasium/utils/passive_env_checker.py b/gymnasium/utils/passive_env_checker.py index d9272dafe..b46859747 100644 --- a/gymnasium/utils/passive_env_checker.py +++ b/gymnasium/utils/passive_env_checker.py @@ -129,7 +129,7 @@ def check_obs(obs, observation_space: spaces.Space, method_name: str): logger.warn(f"{pre} should be an int or np.int64, actual type: {type(obs)}") elif isinstance(observation_space, spaces.Box): if observation_space.shape != (): - if not isinstance(obs, np.ndarray): + if not isinstance(obs, np.ndarray) and observation_space.require_numpy: logger.warn( f"{pre} was expecting a numpy array, actual type: {type(obs)}" )