Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for JaxToNumpy to handle NamedTuples (fixes #780) #789

Merged
merged 3 commits into from
Nov 28, 2023

Conversation

RogerJL
Copy link
Contributor

@RogerJL RogerJL commented Nov 24, 2023

They caused "missing N required positional argument: 'name', ..." in _iterable_numpy_to_jax or _iterable_jax_to_numpy as the for loop resulted in a generator that NamedTuples could not handle.

NamedTuple is not a useful base class, every used named type needs to be registered individually. jax_to_numpy.register_namedtuple(NewNamedTuple)

[Bug Report] TypeError: .new() missing 1 required positional argument: ''

Description

Please include a summary of the change and which issue is fixed. Please also include relevant motivation and context. List any dependencies that are required for this change.

Fixes # (780)

Type of change

Please delete options that are not relevant.

  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • This change requires a documentation update

Checklist:

  • I have run the pre-commit checks with pre-commit run --all-files (see CONTRIBUTING.md instructions to set it up) - did changes that looks Wrong
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

They caused "missing N required positional argument: 'name', ..." in _iterable_numpy_to_jax or _iterable_jax_to_numpy
as the for loop resulted in a generator that NamedTuples could not handle.

NamedTuple is not a useful base class, every used named type needs to be registered individually.
jax_to_numpy.register_namedtuple(NewNamedTuple)

[Bug Report] TypeError: <NamedTuple>.__new__() missing 1 required positional argument: '<argument>'
They caused "missing N required positional argument: 'name', ..." in _iterable_numpy_to_jax or _iterable_jax_to_numpy
as the for loop resulted in a generator that NamedTuples could not handle.

NamedTuple is not a useful base class, every used named type needs to be registered individually.
jax_to_numpy.register_namedtuple(NewNamedTuple)

[Bug Report] TypeError: <NamedTuple>.__new__() missing 1 required positional argument: '<argument>'
@RogerJL
Copy link
Contributor Author

RogerJL commented Nov 24, 2023

Feels like register_namedtuple() should be added to documentation at a higher level, but I can not find out where...

Copy link
Member

@pseudo-rnd-thoughts pseudo-rnd-thoughts left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks Roger for looking into this.
I was surprised that you would need to register the NamedTuple so I had a look into it
I found a StackOverflow that discussed a solution to this with this proposed solution

@numpy_to_jax.register(abc.Iterable)
def _iterable_numpy_to_jax(
    value: Iterable[np.ndarray | Any],
) -> Iterable[jax.Array | Any]:
    """Converts an Iterable from Numpy Arrays to an iterable of Jax Array."""
    if hasattr(value, "_make"):
        return type(value)._make(numpy_to_jax(v) for v in value)
    else:
        return type(value)(numpy_to_jax(v) for v in value)

and

@jax_to_numpy.register(abc.Iterable)
def _iterable_jax_to_numpy(
    value: Iterable[np.ndarray | Any],
) -> Iterable[jax.Array | Any]:
    """Converts an Iterable from Numpy arrays to an iterable of Jax Array."""
    if hasattr(value, "_make"):
        return type(value)._make(jax_to_numpy(v) for v in value)
    else:
        return type(value)(jax_to_numpy(v) for v in value)

For testing, I had

class TestingNamedTuple(NamedTuple):
    a: np.ndarray
    b: np.ndarray

with adding this in the list of parameterized options

        (TestingNamedTuple(np.array([1, 2], dtype=np.int32), np.array([1.0, 2.0], dtype=np.float32)),
         TestingNamedTuple(np.array([1, 2], dtype=np.int32), np.array([1.0, 2.0], dtype=np.float32)))

Could you make these changes and remove all of the rest

They caused "missing N required positional argument: 'name', ..." in _iterable_numpy_to_jax or _iterable_jax_to_numpy
as the for loop resulted in a generator that NamedTuples could not handle.

[Bug Report] TypeError: <NamedTuple>.__new__() missing 1 required positional argument: '<argument>'
@RogerJL RogerJL marked this pull request as ready for review November 27, 2023 21:55
@pseudo-rnd-thoughts pseudo-rnd-thoughts changed the title Fix (#780) jax_to_numpy did not handle NamedTuples Add support for JaxToNumpy to handle NamedTuples (fixes #780) Nov 28, 2023
Copy link
Member

@pseudo-rnd-thoughts pseudo-rnd-thoughts left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks @RogerJL for the changes, all good

@pseudo-rnd-thoughts pseudo-rnd-thoughts merged commit b09ce1a into Farama-Foundation:main Nov 28, 2023
@RogerJL RogerJL deleted the fix780 branch November 30, 2023 19:03
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants