Skip to content

Commit

Permalink
Jax environment return jax data rather than numpy data (#817)
Browse files Browse the repository at this point in the history
Co-authored-by: pseudo-rnd-thoughts <[email protected]>
  • Loading branch information
RedTachyon and pseudo-rnd-thoughts authored Apr 5, 2024
1 parent f0202ae commit d430379
Show file tree
Hide file tree
Showing 12 changed files with 48 additions and 81 deletions.
5 changes: 5 additions & 0 deletions gymnasium/envs/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@
vector_entry_point="gymnasium.envs.phys2d.cartpole:CartPoleJaxVectorEnv",
max_episode_steps=200,
reward_threshold=195.0,
disable_env_checker=True,
)

register(
Expand All @@ -68,13 +69,15 @@
vector_entry_point="gymnasium.envs.phys2d.cartpole:CartPoleJaxVectorEnv",
max_episode_steps=500,
reward_threshold=475.0,
disable_env_checker=True,
)

register(
id="phys2d/Pendulum-v0",
entry_point="gymnasium.envs.phys2d.pendulum:PendulumJaxEnv",
vector_entry_point="gymnasium.envs.phys2d.pendulum:PendulumJaxVectorEnv",
max_episode_steps=200,
disable_env_checker=True,
)

# Box2d
Expand Down Expand Up @@ -161,11 +164,13 @@
register(
id="tabular/Blackjack-v0",
entry_point="gymnasium.envs.tabular.blackjack:BlackJackJaxEnv",
disable_env_checker=True,
)

register(
id="tabular/CliffWalking-v0",
entry_point="gymnasium.envs.tabular.cliffwalking:CliffWalkingJaxEnv",
disable_env_checker=True,
)


Expand Down
37 changes: 2 additions & 35 deletions gymnasium/envs/functional_jax_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,12 @@
import jax
import jax.numpy as jnp
import jax.random as jrng
import numpy as np

import gymnasium as gym
from gymnasium.envs.registration import EnvSpec
from gymnasium.functional import ActType, FuncEnv, StateType
from gymnasium.utils import seeding
from gymnasium.vector.utils import batch_space
from gymnasium.wrappers.jax_to_numpy import jax_to_numpy


class FunctionalJaxEnv(gym.Env):
Expand All @@ -32,7 +30,8 @@ def __init__(
):
"""Initialize the environment from a FuncEnv."""
if metadata is None:
metadata = {"render_mode": []}
# metadata.get("jax", False) can be used downstream to know that the environment returns jax arrays
metadata = {"render_mode": [], "jax": True}

self.func_env = func_env

Expand All @@ -45,8 +44,6 @@ def __init__(

self.spec = spec

self._is_box_action_space = isinstance(self.action_space, gym.spaces.Box)

if self.render_mode == "rgb_array":
self.render_state = self.func_env.render_init()
else:
Expand All @@ -69,20 +66,10 @@ def reset(self, *, seed: int | None = None, options: dict | None = None):
obs = self.func_env.observation(self.state)
info = self.func_env.state_info(self.state)

obs = jax_to_numpy(obs)

return obs, info

def step(self, action: ActType):
"""Steps through the environment using the action."""
if self._is_box_action_space:
assert isinstance(self.action_space, gym.spaces.Box) # For typing
action = np.clip(action, self.action_space.low, self.action_space.high)
else: # Discrete
# For now we assume jax envs don't use complex spaces
err_msg = f"{action!r} ({type(action)}) invalid"
assert self.action_space.contains(action), err_msg

rng, self.rng = jrng.split(self.rng)

next_state = self.func_env.transition(self.state, action, rng)
Expand All @@ -92,8 +79,6 @@ def step(self, action: ActType):
info = self.func_env.transition_info(self.state, action, next_state)
self.state = next_state

observation = jax_to_numpy(observation)

return observation, float(reward), bool(terminated), False, info

def render(self):
Expand Down Expand Up @@ -153,8 +138,6 @@ def __init__(

self.autoreset_envs = jnp.zeros(self.num_envs, dtype=jnp.bool_)

self._is_box_action_space = isinstance(self.action_space, gym.spaces.Box)

if self.render_mode == "rgb_array":
self.render_state = self.func_env.render_init()
else:
Expand Down Expand Up @@ -183,20 +166,10 @@ def reset(self, *, seed: int | None = None, options: dict | None = None):

self.steps = jnp.zeros(self.num_envs, dtype=jnp.int32)

obs = jax_to_numpy(obs)

return obs, info

def step(self, action: ActType):
"""Steps through the environment using the action."""
if self._is_box_action_space:
assert isinstance(self.action_space, gym.spaces.Box) # For typing
action = np.clip(action, self.action_space.low, self.action_space.high)
else: # Discrete
# For now we assume jax envs don't use complex spaces
assert self.action_space.contains(
action
), f"{action!r} ({type(action)}) invalid"
self.steps += 1

rng, self.rng = jrng.split(self.rng)
Expand Down Expand Up @@ -232,12 +205,6 @@ def step(self, action: ActType):
self.autoreset_envs = done

observation = self.func_env.observation(next_state)
observation = jax_to_numpy(observation)

reward = jax_to_numpy(reward)

terminated = jax_to_numpy(terminated)
truncated = jax_to_numpy(truncated)

self.state = next_state

Expand Down
4 changes: 2 additions & 2 deletions gymnasium/envs/phys2d/cartpole.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,7 +245,7 @@ def get_default_params(self, **kwargs) -> CartPoleParams:
class CartPoleJaxEnv(FunctionalJaxEnv, EzPickle):
"""Jax-based implementation of the CartPole environment."""

metadata = {"render_modes": ["rgb_array"], "render_fps": 50}
metadata = {"render_modes": ["rgb_array"], "render_fps": 50, "jax": True}

def __init__(self, render_mode: str | None = None, **kwargs: Any):
"""Constructor for the CartPole where the kwargs are applied to the functional environment."""
Expand All @@ -265,7 +265,7 @@ def __init__(self, render_mode: str | None = None, **kwargs: Any):
class CartPoleJaxVectorEnv(FunctionalJaxVectorEnv, EzPickle):
"""Jax-based implementation of the vectorized CartPole environment."""

metadata = {"render_modes": ["rgb_array"], "render_fps": 50}
metadata = {"render_modes": ["rgb_array"], "render_fps": 50, "jax": True}

def __init__(
self,
Expand Down
4 changes: 2 additions & 2 deletions gymnasium/envs/phys2d/pendulum.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,7 +223,7 @@ def get_default_params(self, **kwargs) -> PendulumParams:
class PendulumJaxEnv(FunctionalJaxEnv, EzPickle):
"""Jax-based pendulum environment using the functional version as base."""

metadata = {"render_modes": ["rgb_array"], "render_fps": 30}
metadata = {"render_modes": ["rgb_array"], "render_fps": 30, "jax": True}

def __init__(self, render_mode: str | None = None, **kwargs: Any):
"""Constructor where the kwargs are passed to the base environment to modify the parameters."""
Expand All @@ -242,7 +242,7 @@ def __init__(self, render_mode: str | None = None, **kwargs: Any):
class PendulumJaxVectorEnv(FunctionalJaxVectorEnv, EzPickle):
"""Jax-based implementation of the vectorized CartPole environment."""

metadata = {"render_modes": ["rgb_array"], "render_fps": 50}
metadata = {"render_modes": ["rgb_array"], "render_fps": 50, "jax": True}

def __init__(
self,
Expand Down
2 changes: 1 addition & 1 deletion gymnasium/envs/tabular/blackjack.py
Original file line number Diff line number Diff line change
Expand Up @@ -496,7 +496,7 @@ def get_default_params(self, **kwargs) -> BlackJackParams:
class BlackJackJaxEnv(FunctionalJaxEnv, EzPickle):
"""A Gymnasium Env wrapper for the functional blackjack env."""

metadata = {"render_modes": ["rgb_array"], "render_fps": 50}
metadata = {"render_modes": ["rgb_array"], "render_fps": 50, "jax": True}

def __init__(self, render_mode: Optional[str] = None, **kwargs):
"""Initializes Gym wrapper for blackjack functional env."""
Expand Down
2 changes: 1 addition & 1 deletion gymnasium/envs/tabular/cliffwalking.py
Original file line number Diff line number Diff line change
Expand Up @@ -358,7 +358,7 @@ def render_close(self, render_state: RenderStateType) -> None:
class CliffWalkingJaxEnv(FunctionalJaxEnv, EzPickle):
"""A Gymnasium Env wrapper for the functional cliffwalking env."""

metadata = {"render_modes": ["rgb_array"], "render_fps": 50}
metadata = {"render_modes": ["rgb_array"], "render_fps": 50, "jax": True}

def __init__(self, render_mode: str | None = None, **kwargs):
"""Initializes Gym wrapper for cliffwalking functional env."""
Expand Down
5 changes: 5 additions & 0 deletions gymnasium/utils/env_checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -367,6 +367,11 @@ def check_env(
f"The environment ({env}) is different from the unwrapped version ({env.unwrapped}). This could effect the environment checker as the environment most likely has a wrapper applied to it. We recommend using the raw environment for `check_env` using `env.unwrapped`."
)

# Environments that flag an array framework in their metadata are wrapped before the checks below run.
# NOTE(review): the "torch" branch applies JaxToTorch; judging by the wrapper's name it converts a
# jax environment to torch tensors rather than converting a torch environment — TODO confirm that
# this is the intended wrapper for environments whose metadata sets "torch".
if env.metadata.get("jax", False):
env = gym.wrappers.JaxToNumpy(env)
elif env.metadata.get("torch", False):
env = gym.wrappers.JaxToTorch(env)

# ============= Check the spaces (observation and action) ================
if not hasattr(env, "action_space"):
raise AttributeError(
Expand Down
10 changes: 5 additions & 5 deletions tests/envs/functional/test_jax.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,21 +120,21 @@ def test_vectorized(env_class):

obs, info = env.reset(seed=0)
assert obs.shape == (10,) + env.single_observation_space.shape
assert isinstance(obs, np.ndarray)
assert isinstance(obs, jax.Array)
assert isinstance(info, dict)

for t in range(100):
action = env.action_space.sample()
obs, reward, terminated, truncated, info = env.step(action)

assert obs.shape == (10,) + env.single_observation_space.shape
assert isinstance(obs, np.ndarray)
assert isinstance(obs, jax.Array)
assert reward.shape == (10,)
assert isinstance(reward, np.ndarray)
assert isinstance(reward, jax.Array)
assert terminated.shape == (10,)
assert isinstance(terminated, np.ndarray)
assert isinstance(terminated, jax.Array)
assert truncated.shape == (10,)
assert isinstance(truncated, np.ndarray)
assert isinstance(truncated, jax.Array)
assert isinstance(info, dict)

# These were removed in the new autoreset order
Expand Down
6 changes: 6 additions & 0 deletions tests/envs/test_action_dim_check.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,9 @@ def test_discrete_actions_out_of_bound(env: gym.Env):
Args:
env (gym.Env): the gymnasium environment
"""
if env.metadata.get("jax", False):
return

assert isinstance(env.action_space, spaces.Discrete)
upper_bound = env.action_space.start + env.action_space.n - 1

Expand Down Expand Up @@ -102,6 +105,9 @@ def test_box_actions_out_of_bound(env: gym.Env):
Args:
env (gym.Env): the gymnasium environment
"""
if env.metadata.get("jax", False):
return

env.reset(seed=42)

assert env.spec is not None
Expand Down
24 changes: 16 additions & 8 deletions tests/envs/test_envs.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,7 @@
import gymnasium as gym
from gymnasium.envs.registration import EnvSpec
from gymnasium.utils.env_checker import check_env, data_equivalence
from tests.envs.utils import (
all_testing_env_specs,
all_testing_initialised_envs,
assert_equals,
)
from tests.envs.utils import all_testing_env_specs, all_testing_initialised_envs


# This runs a smoketest on each official registered env. We may want
Expand Down Expand Up @@ -42,6 +38,7 @@ def test_all_env_api(spec):
"""Check that all environments pass the environment checker with no warnings other than the expected."""
with warnings.catch_warnings(record=True) as caught_warnings:
env = spec.make().unwrapped

check_env(env, skip_render_check=True)

env.close()
Expand Down Expand Up @@ -98,9 +95,13 @@ def test_env_determinism_rollout(env_spec: EnvSpec):
env_1 = env_spec.make(disable_env_checker=True)
env_2 = env_spec.make(disable_env_checker=True)

if env_1.metadata.get("jax", False):
env_1 = gym.wrappers.JaxToNumpy(env_1)
env_2 = gym.wrappers.JaxToNumpy(env_2)

initial_obs_1, initial_info_1 = env_1.reset(seed=SEED)
initial_obs_2, initial_info_2 = env_2.reset(seed=SEED)
assert_equals(initial_obs_1, initial_obs_2)
assert data_equivalence(initial_obs_1, initial_obs_2, exact=True)

env_1.action_space.seed(SEED)

Expand All @@ -111,7 +112,9 @@ def test_env_determinism_rollout(env_spec: EnvSpec):
obs_1, rew_1, terminated_1, truncated_1, info_1 = env_1.step(action)
obs_2, rew_2, terminated_2, truncated_2, info_2 = env_2.step(action)

assert_equals(obs_1, obs_2, f"[{time_step}] ")
assert data_equivalence(
obs_1, obs_2, exact=True
), f"[{time_step}] obs_1={obs_1}, obs_2={obs_2}"
assert env_1.observation_space.contains(
obs_1
) # obs_2 verified by previous assertion
Expand All @@ -123,7 +126,9 @@ def test_env_determinism_rollout(env_spec: EnvSpec):
assert (
truncated_1 == truncated_2
), f"[{time_step}] done 1={truncated_1}, done 2={truncated_2}"
assert_equals(info_1, info_2, f"[{time_step}] ")
assert data_equivalence(
info_1, info_2, exact=True
), f"[{time_step}] info_1={info_1}, info_2={info_2}"

if (
terminated_1 or truncated_1
Expand All @@ -141,6 +146,9 @@ def test_env_determinism_rollout(env_spec: EnvSpec):
ids=[env.spec.id for env in all_testing_initialised_envs if env.spec is not None],
)
def test_pickle_env(env: gym.Env):
if env.metadata.get("jax", False):
env = gym.wrappers.JaxToNumpy(env)

pickled_env = pickle.loads(pickle.dumps(env))

data_equivalence(env.reset(), pickled_env.reset())
Expand Down
27 changes: 0 additions & 27 deletions tests/envs/utils.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
"""Finds all the specs that we can test with"""
from typing import List, Optional

import numpy as np

import gymnasium as gym
from gymnasium import logger
from gymnasium.envs.registration import EnvSpec
Expand Down Expand Up @@ -55,28 +53,3 @@ def try_make_env(env_spec: EnvSpec) -> Optional[gym.Env]:
for ep in ["box2d", "classic_control", "toy_text"]
)
]


def assert_equals(a, b, prefix=None):
    """Assert equality of data structures `a` and `b`.

    Dicts are compared key by key (key order included), tuples element by
    element, NumPy arrays with ``np.testing.assert_array_equal`` and every
    other value with ``==``.

    Args:
        a: first data structure
        b: second data structure
        prefix: prefix for failed assertion messages; ``None`` means no prefix

    Raises:
        AssertionError: if the two structures differ in type, size or content.
    """
    # Normalise once so messages never start with the literal text "None".
    prefix = "" if prefix is None else prefix

    assert type(a) is type(b), f"{prefix}Differing types: {a} and {b}"
    if isinstance(a, dict):
        assert list(a.keys()) == list(b.keys()), f"{prefix}Key sets differ: {a} and {b}"

        for k in a.keys():
            # Forward the prefix so nested failures keep their context.
            assert_equals(a[k], b[k], prefix)
    elif isinstance(a, np.ndarray):
        np.testing.assert_array_equal(a, b, err_msg=prefix)
    elif isinstance(a, tuple):
        # zip() stops at the shorter tuple, so the lengths must be checked explicitly.
        assert len(a) == len(b), f"{prefix}Tuple lengths differ: {a} and {b}"
        for elem_from_a, elem_from_b in zip(a, b):
            assert_equals(elem_from_a, elem_from_b, prefix)
    else:
        assert a == b, f"{prefix}Differing values: {a} and {b}"
3 changes: 3 additions & 0 deletions tests/wrappers/test_passive_env_checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,9 @@
ids=[env.spec.id for env in all_testing_initialised_envs if env.spec is not None],
)
def test_passive_checker_wrapper_warnings(env):
if env.spec is not None and env.spec.disable_env_checker:
return

with warnings.catch_warnings(record=True) as caught_warnings:
checker_env = PassiveEnvChecker(env)
checker_env.reset()
Expand Down

0 comments on commit d430379

Please sign in to comment.