Skip to content

Commit

Permalink
Add support for JaxToNumpy to handle NamedTuples (fixes #780) (#789)
Browse files Browse the repository at this point in the history
  • Loading branch information
RogerJL authored Nov 28, 2023
1 parent f4c302d commit b09ce1a
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 2 deletions.
14 changes: 12 additions & 2 deletions gymnasium/wrappers/jax_to_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,12 @@ def _iterable_numpy_to_jax(
value: Iterable[np.ndarray | Any],
) -> Iterable[jax.Array | Any]:
"""Converts an Iterable from Numpy Arrays to an iterable of Jax Array."""
return type(value)(numpy_to_jax(v) for v in value)
if hasattr(value, "_make"):
# namedtuple - underline used to prevent potential name conflicts
# noinspection PyProtectedMember
return type(value)._make(numpy_to_jax(v) for v in value)
else:
return type(value)(numpy_to_jax(v) for v in value)


@functools.singledispatch
Expand Down Expand Up @@ -89,7 +94,12 @@ def _iterable_jax_to_numpy(
value: Iterable[np.ndarray | Any],
) -> Iterable[jax.Array | Any]:
"""Converts an Iterable from Numpy arrays to an iterable of Jax Array."""
return type(value)(jax_to_numpy(v) for v in value)
if hasattr(value, "_make"):
# namedtuple - underline used to prevent potential name conflicts
# noinspection PyProtectedMember
return type(value)._make(jax_to_numpy(v) for v in value)
else:
return type(value)(jax_to_numpy(v) for v in value)


class JaxToNumpy(
Expand Down
16 changes: 16 additions & 0 deletions tests/wrappers/test_jax_to_numpy.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Test suite for JaxToNumpy wrapper."""
from typing import NamedTuple

import numpy as np
import pytest
Expand All @@ -16,6 +17,11 @@
from tests.testing_env import GenericTestEnv # noqa: E402


class TestingNamedTuple(NamedTuple):
a: jax.Array
b: jax.Array


@pytest.mark.parametrize(
"value, expected_value",
[
Expand Down Expand Up @@ -55,6 +61,16 @@
"b": {"c": np.array(5, dtype=np.int32)},
},
),
(
TestingNamedTuple(
a=np.array([1, 2], dtype=np.int32),
b=np.array([1.0, 2.0], dtype=np.float32),
),
TestingNamedTuple(
a=np.array([1, 2], dtype=np.int32),
b=np.array([1.0, 2.0], dtype=np.float32),
),
),
],
)
def test_roundtripping(value, expected_value):
Expand Down

0 comments on commit b09ce1a

Please sign in to comment.