Jax environment return jax data rather than numpy data #817
I agree that this is a good idea. I'm worried by the number of parameters we are adding, but I don't see another way of implementing this.
I investigated how to fix this PR and the problem is the One solution might be to add a metadata parameter specifying that the observations are jax which changes the testing for
When I prepared my HumanVectorRendering PR I noticed some problems:
gymnasium/spaces/multi_discrete.py:158 requires that the input array (the action array passed to step) is a numpy.ndarray...
Another, possibly related, issue is in wrappers/vector/__init__.py:
"JaxToNumpy",
"JaxToTorch",
"NumpyToTorch",
To make these changes pass the tests, I have made changes to either include the
@@ -367,6 +367,11 @@ def check_env(
f"The environment ({env}) is different from the unwrapped version ({env.unwrapped}). This could effect the environment checker as the environment most likely has a wrapper applied to it. We recommend using the raw environment for `check_env` using `env.unwrapped`."
)

if env.metadata.get("jax", False):
Does this work for all environments?
Most environments do not specify a data library in their metadata.
Yes, as False is the default value.
Should all other environments have metadata["numpy"] = True then? (Or do we expect it to be stated?)
Realistically, since there is no way to auto-detect it, the solution is fine, but we should say something about it in the Env documentation.
Good question, I think we can just assume that to be true so we don't need to do that.
if env.metadata.get("jax", False):
    env = gym.wrappers.JaxToNumpy(env)
elif env.metadata.get("torch", False):
    env = gym.wrappers.JaxToTorch(env)
Shouldn't this be TorchToNumpy?
Damn, yes
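To make the point of this thread concrete, here is a minimal sketch of the selection logic, assuming each branch should pick a wrapper that converts *from* the declared backend *to* NumPy. The table `TO_NUMPY_WRAPPER` and the helper `numpy_conversion_wrapper` are hypothetical names used only for illustration, and the wrappers are represented as plain strings so the snippet runs without gymnasium, jax, or torch installed. "TorchToNumpy" is simply the name suggested in the review; whether such a wrapper exists in gymnasium.wrappers is not confirmed here.

```python
from typing import Any, Mapping, Optional

# Hypothetical lookup table: metadata flag -> name of the wrapper that would
# convert that backend's arrays to NumPy. "TorchToNumpy" is the name proposed
# in the review comment above, not a confirmed gymnasium wrapper.
TO_NUMPY_WRAPPER = {
    "jax": "JaxToNumpy",
    "torch": "TorchToNumpy",
}


def numpy_conversion_wrapper(metadata: Mapping[str, Any]) -> Optional[str]:
    """Return the name of the wrapper needed to get NumPy data, or None."""
    for flag, wrapper_name in TO_NUMPY_WRAPPER.items():
        # A missing flag defaults to False, so environments that declare
        # nothing are assumed to produce NumPy data already.
        if metadata.get(flag, False):
            return wrapper_name
    return None


print(numpy_conversion_wrapper({"jax": True}))
print(numpy_conversion_wrapper({"torch": True}))
print(numpy_conversion_wrapper({"render_modes": []}))
```

Keeping the mapping in one table makes a mismatch like the one caught above (a torch flag paired with a Jax wrapper) easier to spot than in an if/elif chain.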
@@ -70,6 +70,9 @@ def test_discrete_actions_out_of_bound(env: gym.Env):
Args:
env (gym.Env): the gymnasium environment
"""
if env.metadata.get("jax", False):
Wouldn't it be better to use a wrapper here?
The issue is not actually with the wrapper; rather, this check was removed from the environments themselves.
We could re-add the checks in the environments and then apply the wrapper, but I decided not to do that for now.
With this PR, phys2d environments will output jax arrays instead of smushing them into numpy.
This also changes Box spaces with an additional optional argument to determine whether they should care about the element being a numpy array when doing contains. I don't think any other spaces need this as well, but I'm not 100% sure.
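As a rough illustration of the idea (not the PR's actual change), the sketch below shows a toy space whose membership check can optionally skip the array-type test. Everything here is an assumption made for the example: the class `SimpleBox`, the flag name `strict_type`, and the use of the standard library's `array.array` as a stand-in for numpy.ndarray so the snippet runs without numpy or jax.

```python
from array import array
from typing import Any


class SimpleBox:
    """Toy stand-in for a Box space; this is not gymnasium's implementation."""

    def __init__(self, low: float, high: float, size: int, strict_type: bool = True):
        self.low = low
        self.high = high
        self.size = size
        # Hypothetical flag: when True, only `array.array` (our stand-in for
        # numpy.ndarray) is accepted; when False, any sequence of numbers is.
        self.strict_type = strict_type

    def contains(self, x: Any) -> bool:
        # Optional type check, skipped for spaces that accept other array types.
        if self.strict_type and not isinstance(x, array):
            return False
        try:
            values = [float(v) for v in x]
        except (TypeError, ValueError):
            return False
        # Shape and bounds are checked regardless of the container type.
        return len(values) == self.size and all(
            self.low <= v <= self.high for v in values
        )


strict = SimpleBox(-1.0, 1.0, size=2)
relaxed = SimpleBox(-1.0, 1.0, size=2, strict_type=False)

print(strict.contains(array("d", [0.0, 0.5])))
print(strict.contains((0.0, 0.5)))
print(relaxed.contains((0.0, 0.5)))
```

With the flag off, a tuple (standing in for a jax array) passes as long as its length and bounds are valid, while the default behaviour still rejects anything that is not the expected array type.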