-
-
Notifications
You must be signed in to change notification settings - Fork 946
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add support for JaxToNumpy
to handle NamedTuples (fixes #780)
#789
Conversation
They caused "missing N required positional argument: 'name', ..." in _iterable_numpy_to_jax or _iterable_jax_to_numpy as the for loop resulted in a generator that NamedTuples could not handle. NamedTuple is not a useful base class, every used named type needs to be registered individually. jax_to_numpy.register_namedtuple(NewNamedTuple) [Bug Report] TypeError: <NamedTuple>.__new__() missing 1 required positional argument: '<argument>'
They caused "missing N required positional argument: 'name', ..." in _iterable_numpy_to_jax or _iterable_jax_to_numpy as the for loop resulted in a generator that NamedTuples could not handle. NamedTuple is not a useful base class, every used named type needs to be registered individually. jax_to_numpy.register_namedtuple(NewNamedTuple) [Bug Report] TypeError: <NamedTuple>.__new__() missing 1 required positional argument: '<argument>'
Feels like register_namedtuple() should be added to documentation at a higher level, but I can not find out where... |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks Roger for looking into this.
I was surprised that you would need to register the NamedTuple so I had a look into it
I found a StackOverflow that discussed a solution to this with this proposed solution
@numpy_to_jax.register(abc.Iterable)
def _iterable_numpy_to_jax(
value: Iterable[np.ndarray | Any],
) -> Iterable[jax.Array | Any]:
"""Converts an Iterable from Numpy Arrays to an iterable of Jax Array."""
if hasattr(value, "_make"):
return type(value)._make(numpy_to_jax(v) for v in value)
else:
return type(value)(numpy_to_jax(v) for v in value)
and
@jax_to_numpy.register(abc.Iterable)
def _iterable_jax_to_numpy(
value: Iterable[np.ndarray | Any],
) -> Iterable[jax.Array | Any]:
"""Converts an Iterable from Numpy arrays to an iterable of Jax Array."""
if hasattr(value, "_make"):
return type(value)._make(jax_to_numpy(v) for v in value)
else:
return type(value)(jax_to_numpy(v) for v in value)
For testing, I had
class TestingNamedTuple(NamedTuple):
a: np.ndarray
b: np.ndarray
with adding this in the list of parameterized options
(TestingNamedTuple(np.array([1, 2], dtype=np.int32), np.array([1.0, 2.0], dtype=np.float32)),
TestingNamedTuple(np.array([1, 2], dtype=np.int32), np.array([1.0, 2.0], dtype=np.float32)))
Could you make these changes and remove all of the rest
They caused "missing N required positional argument: 'name', ..." in _iterable_numpy_to_jax or _iterable_jax_to_numpy as the for loop resulted in a generator that NamedTuples could not handle. [Bug Report] TypeError: <NamedTuple>.__new__() missing 1 required positional argument: '<argument>'
JaxToNumpy
to handle NamedTuples (fixes #780)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks @RogerJL for the changes, all good
They caused "missing N required positional argument: 'name', ..." in _iterable_numpy_to_jax or _iterable_jax_to_numpy as the for loop resulted in a generator that NamedTuples could not handle.
NamedTuple is not a useful base class, every used named type needs to be registered individually. jax_to_numpy.register_namedtuple(NewNamedTuple)
[Bug Report] TypeError: .new() missing 1 required positional argument: ''
Description
Please include a summary of the change and which issue is fixed. Please also include relevant motivation and context. List any dependencies that are required for this change.
Fixes # (780)
Type of change
Please delete options that are not relevant.
Checklist:
pre-commit
checks withpre-commit run --all-files
(seeCONTRIBUTING.md
instructions to set it up) - did changes that looks Wrong