From d43037920f076376d66e1b51c345c28f0f162359 Mon Sep 17 00:00:00 2001 From: Ariel Kwiatkowski Date: Fri, 5 Apr 2024 18:21:10 +0200 Subject: [PATCH] Jax environment return jax data rather than numpy data (#817) Co-authored-by: pseudo-rnd-thoughts --- gymnasium/envs/__init__.py | 5 +++ gymnasium/envs/functional_jax_env.py | 37 ++-------------------- gymnasium/envs/phys2d/cartpole.py | 4 +-- gymnasium/envs/phys2d/pendulum.py | 4 +-- gymnasium/envs/tabular/blackjack.py | 2 +- gymnasium/envs/tabular/cliffwalking.py | 2 +- gymnasium/utils/env_checker.py | 5 +++ tests/envs/functional/test_jax.py | 10 +++--- tests/envs/test_action_dim_check.py | 6 ++++ tests/envs/test_envs.py | 24 +++++++++----- tests/envs/utils.py | 27 ---------------- tests/wrappers/test_passive_env_checker.py | 3 ++ 12 files changed, 48 insertions(+), 81 deletions(-) diff --git a/gymnasium/envs/__init__.py b/gymnasium/envs/__init__.py index d0cf5cf8a..ce95a5064 100644 --- a/gymnasium/envs/__init__.py +++ b/gymnasium/envs/__init__.py @@ -60,6 +60,7 @@ vector_entry_point="gymnasium.envs.phys2d.cartpole:CartPoleJaxVectorEnv", max_episode_steps=200, reward_threshold=195.0, + disable_env_checker=True, ) register( @@ -68,6 +69,7 @@ vector_entry_point="gymnasium.envs.phys2d.cartpole:CartPoleJaxVectorEnv", max_episode_steps=500, reward_threshold=475.0, + disable_env_checker=True, ) register( @@ -75,6 +77,7 @@ entry_point="gymnasium.envs.phys2d.pendulum:PendulumJaxEnv", vector_entry_point="gymnasium.envs.phys2d.pendulum:PendulumJaxVectorEnv", max_episode_steps=200, + disable_env_checker=True, ) # Box2d @@ -161,11 +164,13 @@ register( id="tabular/Blackjack-v0", entry_point="gymnasium.envs.tabular.blackjack:BlackJackJaxEnv", + disable_env_checker=True, ) register( id="tabular/CliffWalking-v0", entry_point="gymnasium.envs.tabular.cliffwalking:CliffWalkingJaxEnv", + disable_env_checker=True, ) diff --git a/gymnasium/envs/functional_jax_env.py b/gymnasium/envs/functional_jax_env.py index 2fc5e55c7..46e468936 100644 --- a/gymnasium/envs/functional_jax_env.py +++ b/gymnasium/envs/functional_jax_env.py @@ -6,14 +6,12 @@ import jax import jax.numpy as jnp import jax.random as jrng -import numpy as np import gymnasium as gym from gymnasium.envs.registration import EnvSpec from gymnasium.functional import ActType, FuncEnv, StateType from gymnasium.utils import seeding from gymnasium.vector.utils import batch_space -from gymnasium.wrappers.jax_to_numpy import jax_to_numpy class FunctionalJaxEnv(gym.Env): @@ -32,7 +30,8 @@ def __init__( ): """Initialize the environment from a FuncEnv.""" if metadata is None: - metadata = {"render_mode": []} + # metadata.get("jax", False) can be used downstream to know that the environment returns jax arrays + metadata = {"render_mode": [], "jax": True} self.func_env = func_env @@ -45,8 +44,6 @@ def __init__( self.spec = spec - self._is_box_action_space = isinstance(self.action_space, gym.spaces.Box) - if self.render_mode == "rgb_array": self.render_state = self.func_env.render_init() else: @@ -69,20 +66,10 @@ def reset(self, *, seed: int | None = None, options: dict | None = None): obs = self.func_env.observation(self.state) info = self.func_env.state_info(self.state) - obs = jax_to_numpy(obs) - return obs, info def step(self, action: ActType): """Steps through the environment using the action.""" - if self._is_box_action_space: - assert isinstance(self.action_space, gym.spaces.Box) # For typing - action = np.clip(action, self.action_space.low, self.action_space.high) - else: # Discrete - # For now we assume jax envs don't use complex spaces - err_msg = f"{action!r} ({type(action)}) invalid" - assert self.action_space.contains(action), err_msg - rng, self.rng = jrng.split(self.rng) next_state = self.func_env.transition(self.state, action, rng) @@ -92,8 +79,6 @@ def step(self, action: ActType): info = self.func_env.transition_info(self.state, action, next_state) self.state = next_state - observation = jax_to_numpy(observation) - return observation, float(reward), bool(terminated), False, info def render(self): @@ -153,8 +138,6 @@ def __init__( self.autoreset_envs = jnp.zeros(self.num_envs, dtype=jnp.bool_) - self._is_box_action_space = isinstance(self.action_space, gym.spaces.Box) - if self.render_mode == "rgb_array": self.render_state = self.func_env.render_init() else: @@ -183,20 +166,10 @@ def reset(self, *, seed: int | None = None, options: dict | None = None): self.steps = jnp.zeros(self.num_envs, dtype=jnp.int32) - obs = jax_to_numpy(obs) - return obs, info def step(self, action: ActType): """Steps through the environment using the action.""" - if self._is_box_action_space: - assert isinstance(self.action_space, gym.spaces.Box) # For typing - action = np.clip(action, self.action_space.low, self.action_space.high) - else: # Discrete - # For now we assume jax envs don't use complex spaces - assert self.action_space.contains( - action - ), f"{action!r} ({type(action)}) invalid" self.steps += 1 rng, self.rng = jrng.split(self.rng) @@ -232,12 +205,6 @@ def step(self, action: ActType): self.autoreset_envs = done observation = self.func_env.observation(next_state) - observation = jax_to_numpy(observation) - - reward = jax_to_numpy(reward) - - terminated = jax_to_numpy(terminated) - truncated = jax_to_numpy(truncated) self.state = next_state diff --git a/gymnasium/envs/phys2d/cartpole.py b/gymnasium/envs/phys2d/cartpole.py index 81e5554b2..010a702b3 100644 --- a/gymnasium/envs/phys2d/cartpole.py +++ b/gymnasium/envs/phys2d/cartpole.py @@ -245,7 +245,7 @@ def get_default_params(self, **kwargs) -> CartPoleParams: class CartPoleJaxEnv(FunctionalJaxEnv, EzPickle): """Jax-based implementation of the CartPole environment.""" - metadata = {"render_modes": ["rgb_array"], "render_fps": 50} + metadata = {"render_modes": ["rgb_array"], "render_fps": 50, "jax": True} def __init__(self, render_mode: str | None = None, **kwargs: Any): """Constructor for the CartPole where the kwargs are applied to the functional environment.""" @@ -265,7 +265,7 @@ def __init__(self, render_mode: str | None = None, **kwargs: Any): class CartPoleJaxVectorEnv(FunctionalJaxVectorEnv, EzPickle): """Jax-based implementation of the vectorized CartPole environment.""" - metadata = {"render_modes": ["rgb_array"], "render_fps": 50} + metadata = {"render_modes": ["rgb_array"], "render_fps": 50, "jax": True} def __init__( self, diff --git a/gymnasium/envs/phys2d/pendulum.py b/gymnasium/envs/phys2d/pendulum.py index b8bfe9deb..10862c951 100644 --- a/gymnasium/envs/phys2d/pendulum.py +++ b/gymnasium/envs/phys2d/pendulum.py @@ -223,7 +223,7 @@ def get_default_params(self, **kwargs) -> PendulumParams: class PendulumJaxEnv(FunctionalJaxEnv, EzPickle): """Jax-based pendulum environment using the functional version as base.""" - metadata = {"render_modes": ["rgb_array"], "render_fps": 30} + metadata = {"render_modes": ["rgb_array"], "render_fps": 30, "jax": True} def __init__(self, render_mode: str | None = None, **kwargs: Any): """Constructor where the kwargs are passed to the base environment to modify the parameters.""" @@ -242,7 +242,7 @@ def __init__(self, render_mode: str | None = None, **kwargs: Any): class PendulumJaxVectorEnv(FunctionalJaxVectorEnv, EzPickle): """Jax-based implementation of the vectorized CartPole environment.""" - metadata = {"render_modes": ["rgb_array"], "render_fps": 50} + metadata = {"render_modes": ["rgb_array"], "render_fps": 50, "jax": True} def __init__( self, diff --git a/gymnasium/envs/tabular/blackjack.py b/gymnasium/envs/tabular/blackjack.py index 3ebd034de..35cb4bb93 100644 --- a/gymnasium/envs/tabular/blackjack.py +++ b/gymnasium/envs/tabular/blackjack.py @@ -496,7 +496,7 @@ def get_default_params(self, **kwargs) -> BlackJackParams: class BlackJackJaxEnv(FunctionalJaxEnv, EzPickle): """A Gymnasium Env wrapper for the functional blackjack env.""" - metadata = {"render_modes": ["rgb_array"], "render_fps": 50} + metadata = {"render_modes": ["rgb_array"], "render_fps": 50, "jax": True} def __init__(self, render_mode: Optional[str] = None, **kwargs): """Initializes Gym wrapper for blackjack functional env.""" diff --git a/gymnasium/envs/tabular/cliffwalking.py b/gymnasium/envs/tabular/cliffwalking.py index fc2574b65..5b8ad7c78 100644 --- a/gymnasium/envs/tabular/cliffwalking.py +++ b/gymnasium/envs/tabular/cliffwalking.py @@ -358,7 +358,7 @@ def render_close(self, render_state: RenderStateType) -> None: class CliffWalkingJaxEnv(FunctionalJaxEnv, EzPickle): """A Gymnasium Env wrapper for the functional cliffwalking env.""" - metadata = {"render_modes": ["rgb_array"], "render_fps": 50} + metadata = {"render_modes": ["rgb_array"], "render_fps": 50, "jax": True} def __init__(self, render_mode: str | None = None, **kwargs): """Initializes Gym wrapper for cliffwalking functional env.""" diff --git a/gymnasium/utils/env_checker.py b/gymnasium/utils/env_checker.py index 118346565..fd05031f2 100644 --- a/gymnasium/utils/env_checker.py +++ b/gymnasium/utils/env_checker.py @@ -367,6 +367,11 @@ def check_env( f"The environment ({env}) is different from the unwrapped version ({env.unwrapped}). This could effect the environment checker as the environment most likely has a wrapper applied to it. We recommend using the raw environment for `check_env` using `env.unwrapped`." ) + if env.metadata.get("jax", False): + env = gym.wrappers.JaxToNumpy(env) + elif env.metadata.get("torch", False): + env = gym.wrappers.JaxToTorch(env) + # ============= Check the spaces (observation and action) ================ if not hasattr(env, "action_space"): raise AttributeError( diff --git a/tests/envs/functional/test_jax.py b/tests/envs/functional/test_jax.py index 01b8ad198..0a109eedc 100644 --- a/tests/envs/functional/test_jax.py +++ b/tests/envs/functional/test_jax.py @@ -120,7 +120,7 @@ def test_vectorized(env_class): obs, info = env.reset(seed=0) assert obs.shape == (10,) + env.single_observation_space.shape - assert isinstance(obs, np.ndarray) + assert isinstance(obs, jax.Array) assert isinstance(info, dict) for t in range(100): @@ -128,13 +128,13 @@ def test_vectorized(env_class): obs, reward, terminated, truncated, info = env.step(action) assert obs.shape == (10,) + env.single_observation_space.shape - assert isinstance(obs, np.ndarray) + assert isinstance(obs, jax.Array) assert reward.shape == (10,) - assert isinstance(reward, np.ndarray) + assert isinstance(reward, jax.Array) assert terminated.shape == (10,) - assert isinstance(terminated, np.ndarray) + assert isinstance(terminated, jax.Array) assert truncated.shape == (10,) - assert isinstance(truncated, np.ndarray) + assert isinstance(truncated, jax.Array) assert isinstance(info, dict) # These were removed in the new autoreset order diff --git a/tests/envs/test_action_dim_check.py b/tests/envs/test_action_dim_check.py index 0e6699baa..afe230e09 100644 --- a/tests/envs/test_action_dim_check.py +++ b/tests/envs/test_action_dim_check.py @@ -70,6 +70,9 @@ def test_discrete_actions_out_of_bound(env: gym.Env): Args: env (gym.Env): the gymnasium environment """ + if env.metadata.get("jax", False): + return + assert isinstance(env.action_space, spaces.Discrete) upper_bound = env.action_space.start + env.action_space.n - 1 @@ -102,6 +105,9 @@ def test_box_actions_out_of_bound(env: gym.Env): Args: env (gym.Env): the gymnasium environment """ + if env.metadata.get("jax", False): + return + env.reset(seed=42) assert env.spec is not None diff --git a/tests/envs/test_envs.py b/tests/envs/test_envs.py index 0a8aa0b3a..f6b48b135 100644 --- a/tests/envs/test_envs.py +++ b/tests/envs/test_envs.py @@ -7,11 +7,7 @@ import gymnasium as gym from gymnasium.envs.registration import EnvSpec from gymnasium.utils.env_checker import check_env, data_equivalence -from tests.envs.utils import ( - all_testing_env_specs, - all_testing_initialised_envs, - assert_equals, -) +from tests.envs.utils import all_testing_env_specs, all_testing_initialised_envs # This runs a smoketest on each official registered env. We may want @@ -42,6 +38,7 @@ def test_all_env_api(spec): """Check that all environments pass the environment checker with no warnings other than the expected.""" with warnings.catch_warnings(record=True) as caught_warnings: env = spec.make().unwrapped + check_env(env, skip_render_check=True) env.close() @@ -98,9 +95,13 @@ def test_env_determinism_rollout(env_spec: EnvSpec): env_1 = env_spec.make(disable_env_checker=True) env_2 = env_spec.make(disable_env_checker=True) + if env_1.metadata.get("jax", False): + env_1 = gym.wrappers.JaxToNumpy(env_1) + env_2 = gym.wrappers.JaxToNumpy(env_2) + initial_obs_1, initial_info_1 = env_1.reset(seed=SEED) initial_obs_2, initial_info_2 = env_2.reset(seed=SEED) - assert_equals(initial_obs_1, initial_obs_2) + assert data_equivalence(initial_obs_1, initial_obs_2, exact=True) env_1.action_space.seed(SEED) @@ -111,7 +112,9 @@ def test_env_determinism_rollout(env_spec: EnvSpec): obs_1, rew_1, terminated_1, truncated_1, info_1 = env_1.step(action) obs_2, rew_2, terminated_2, truncated_2, info_2 = env_2.step(action) - assert_equals(obs_1, obs_2, f"[{time_step}] ") + assert data_equivalence( + obs_1, obs_2, exact=True + ), f"[{time_step}] obs_1={obs_1}, obs_2={obs_2}" assert env_1.observation_space.contains( obs_1 ) # obs_2 verified by previous assertion @@ -123,7 +126,9 @@ def test_env_determinism_rollout(env_spec: EnvSpec): assert ( truncated_1 == truncated_2 ), f"[{time_step}] done 1={truncated_1}, done 2={truncated_2}" - assert_equals(info_1, info_2, f"[{time_step}] ") + assert data_equivalence( + info_1, info_2, exact=True + ), f"[{time_step}] info_1={info_1}, info_2={info_2}" if ( terminated_1 or truncated_1 @@ -141,6 +146,9 @@ def test_env_determinism_rollout(env_spec: EnvSpec): ids=[env.spec.id for env in all_testing_initialised_envs if env.spec is not None], ) def test_pickle_env(env: gym.Env): + if env.metadata.get("jax", False): + env = gym.wrappers.JaxToNumpy(env) + pickled_env = pickle.loads(pickle.dumps(env)) data_equivalence(env.reset(), pickled_env.reset()) diff --git a/tests/envs/utils.py b/tests/envs/utils.py index 1613c4918..ff8f8a783 100644 --- a/tests/envs/utils.py +++ b/tests/envs/utils.py @@ -1,8 +1,6 @@ """Finds all the specs that we can test with""" from typing import List, Optional -import numpy as np - import gymnasium as gym from gymnasium import logger from gymnasium.envs.registration import EnvSpec @@ -55,28 +53,3 @@ def try_make_env(env_spec: EnvSpec) -> Optional[gym.Env]: for ep in ["box2d", "classic_control", "toy_text"] ) ] - - -def assert_equals(a, b, prefix=None): - """Assert equality of data structures `a` and `b`. - - Args: - a: first data structure - b: second data structure - prefix: prefix for failed assertion message for types and dicts - """ - assert type(a) is type(b), f"{prefix}Differing types: {a} and {b}" - if isinstance(a, dict): - assert list(a.keys()) == list(b.keys()), f"{prefix}Key sets differ: {a} and {b}" - - for k in a.keys(): - v_a = a[k] - v_b = b[k] - assert_equals(v_a, v_b) - elif isinstance(a, np.ndarray): - np.testing.assert_array_equal(a, b) - elif isinstance(a, tuple): - for elem_from_a, elem_from_b in zip(a, b): - assert_equals(elem_from_a, elem_from_b) - else: - assert a == b diff --git a/tests/wrappers/test_passive_env_checker.py b/tests/wrappers/test_passive_env_checker.py index 12a6c104b..02e47dfa9 100644 --- a/tests/wrappers/test_passive_env_checker.py +++ b/tests/wrappers/test_passive_env_checker.py @@ -19,6 +19,9 @@ ids=[env.spec.id for env in all_testing_initialised_envs if env.spec is not None], ) def test_passive_checker_wrapper_warnings(env): + if env.spec is not None and env.spec.disable_env_checker: + return + with warnings.catch_warnings(record=True) as caught_warnings: checker_env = PassiveEnvChecker(env) checker_env.reset()