Skip to content

Commit

Permalink
Remove the casts from jax to numpy in jax-based envs
Browse files Browse the repository at this point in the history
  • Loading branch information
RedTachyon committed Dec 5, 2023
1 parent 2790321 commit da89efc
Show file tree
Hide file tree
Showing 3 changed files with 6 additions and 23 deletions.
21 changes: 0 additions & 21 deletions gymnasium/envs/functional_jax_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
from gymnasium.functional import ActType, FuncEnv, StateType
from gymnasium.utils import seeding
from gymnasium.vector.utils import batch_space
from gymnasium.wrappers.jax_to_numpy import jax_to_numpy


class FunctionalJaxEnv(gym.Env):
Expand Down Expand Up @@ -69,8 +68,6 @@ def reset(self, *, seed: int | None = None, options: dict | None = None):
obs = self.func_env.observation(self.state)
info = self.func_env.state_info(self.state)

obs = jax_to_numpy(obs)

return obs, info

def step(self, action: ActType):
Expand All @@ -92,8 +89,6 @@ def step(self, action: ActType):
info = self.func_env.transition_info(self.state, action, next_state)
self.state = next_state

observation = jax_to_numpy(observation)

return observation, float(reward), bool(terminated), False, info

def render(self):
Expand Down Expand Up @@ -183,20 +178,10 @@ def reset(self, *, seed: int | None = None, options: dict | None = None):

self.steps = jnp.zeros(self.num_envs, dtype=jnp.int32)

obs = jax_to_numpy(obs)

return obs, info

def step(self, action: ActType):
"""Steps through the environment using the action."""
if self._is_box_action_space:
assert isinstance(self.action_space, gym.spaces.Box) # For typing
action = np.clip(action, self.action_space.low, self.action_space.high)
else: # Discrete
# For now we assume jax envs don't use complex spaces
assert self.action_space.contains(
action
), f"{action!r} ({type(action)}) invalid"
self.steps += 1

rng, self.rng = jrng.split(self.rng)
Expand Down Expand Up @@ -232,12 +217,6 @@ def step(self, action: ActType):
self.autoreset_envs = done

observation = self.func_env.observation(next_state)
observation = jax_to_numpy(observation)

reward = jax_to_numpy(reward)

terminated = jax_to_numpy(terminated)
truncated = jax_to_numpy(truncated)

self.state = next_state

Expand Down
6 changes: 5 additions & 1 deletion gymnasium/spaces/box.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ def __init__(
shape: Sequence[int] | None = None,
dtype: type[np.floating[Any]] | type[np.integer[Any]] = np.float32,
seed: int | np.random.Generator | None = None,
require_numpy: bool = False,
):
r"""Constructor of :class:`Box`.
Expand Down Expand Up @@ -144,6 +145,8 @@ def __init__(
self.low_repr = _short_repr(self.low)
self.high_repr = _short_repr(self.high)

self.require_numpy = require_numpy

super().__init__(self.shape, self.dtype, seed)

@property
Expand Down Expand Up @@ -237,7 +240,8 @@ def sample(self, mask: None = None) -> NDArray[Any]:
def contains(self, x: Any) -> bool:
"""Return boolean specifying if x is a valid member of this space."""
if not isinstance(x, np.ndarray):
gym.logger.warn("Casting input x to numpy array.")
if self.require_numpy:
gym.logger.warn("Casting input x to numpy array.")
try:
x = np.asarray(x, dtype=self.dtype)
except (ValueError, TypeError):
Expand Down
2 changes: 1 addition & 1 deletion gymnasium/utils/passive_env_checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ def check_obs(obs, observation_space: spaces.Space, method_name: str):
logger.warn(f"{pre} should be an int or np.int64, actual type: {type(obs)}")
elif isinstance(observation_space, spaces.Box):
if observation_space.shape != ():
if not isinstance(obs, np.ndarray):
if not isinstance(obs, np.ndarray) and observation_space.require_numpy:
logger.warn(
f"{pre} was expecting a numpy array, actual type: {type(obs)}"
)
Expand Down

0 comments on commit da89efc

Please sign in to comment.