Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Jax environment return jax data rather than numpy data #817

Merged
merged 10 commits into from
Apr 5, 2024

Conversation

RedTachyon
Copy link
Member

With this PR, phys2d environments will output jax arrays instead of smushing them into numpy.

This also changes Box spaces with an additional optional argument to determine whether they should care about the element being a numpy array when doing contains. I don't think any other spaces need this as well, but I'm not 100% sure

@pseudo-rnd-thoughts
Copy link
Member

I agree that this is a good idea and I'm worried by the number of parameters we are adding but I don't see another way around implementing this

@pseudo-rnd-thoughts
Copy link
Member

I investigated how to fix this PR and the problem is the check_env function which we would need to add support for jax observations.
The problem is that functions within check_env are user facing meaning that import jax would happen everytime we weren't careful.

One solution might be to add a metadata parameter specifying that the observations are jax which changes the testing for check_env etc.
@RedTachyon Thoughts?

Copy link
Contributor

@RogerJL RogerJL left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

When I prepared my HumanVectorRenderingPR I noticed some problems

gymnasium/spaces/multi_discrete.py:158
requires that the input array (action array to step) is numpy.ndarray...

Another might be related issue in wrappers/vector/init.py
"JaxToNumpy",
"JaxToTorch",
"NumpyToTorch",

@pseudo-rnd-thoughts pseudo-rnd-thoughts changed the title Stop casting jax env outputs to numpy Jax environment return jax data rather than numpy data Apr 5, 2024
@pseudo-rnd-thoughts
Copy link
Member

To make these changes pass the tests, I have made changes to either include the JaxToNumpy wrapper or "disable" the tests

@pseudo-rnd-thoughts pseudo-rnd-thoughts merged commit d430379 into main Apr 5, 2024
16 checks passed
@@ -367,6 +367,11 @@ def check_env(
f"The environment ({env}) is different from the unwrapped version ({env.unwrapped}). This could effect the environment checker as the environment most likely has a wrapper applied to it. We recommend using the raw environment for `check_env` using `env.unwrapped`."
)

if env.metadata.get("jax", False):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does this work for all environments?
Most if environments do not have data library in metadata

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, as False is the default value

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should all other environments have metadata["numpy"] = True then? (all do we expect it to be stated)

Realistically since there is no way to auto-detect it, so the solution is fine, but we should say something about in Env documentation

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good question, I think we can just assume that to be true so we don't need to do that.

if env.metadata.get("jax", False):
env = gym.wrappers.JaxToNumpy(env)
elif env.metadata.get("torch", False):
env = gym.wrappers.JaxToTorch(env)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

shouldn't this be TorchToNumpy

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Damn, yes

@@ -70,6 +70,9 @@ def test_discrete_actions_out_of_bound(env: gym.Env):
Args:
env (gym.Env): the gymnasium environment
"""
if env.metadata.get("jax", False):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

wouldn't it be better to use a wrapper here?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The issue is not actually with the wrapper, it is rather that this check from removed the environments themselves.
We could re-add the checks in the environments then apply the wrapper but I decided not to do that for now

@pseudo-rnd-thoughts pseudo-rnd-thoughts deleted the jax-spaces branch May 21, 2024 08:56
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants