From b09ce1a810e7c086fc61a9cf6b782e6d5dfc40d0 Mon Sep 17 00:00:00 2001 From: Roger Larsson Date: Tue, 28 Nov 2023 12:47:43 +0100 Subject: [PATCH] Add support for `JaxToNumpy` to handle NamedTuples (fixes #780) (#789) --- gymnasium/wrappers/jax_to_numpy.py | 14 ++++++++++++-- tests/wrappers/test_jax_to_numpy.py | 16 ++++++++++++++++ 2 files changed, 28 insertions(+), 2 deletions(-) diff --git a/gymnasium/wrappers/jax_to_numpy.py b/gymnasium/wrappers/jax_to_numpy.py index 807f19522..dfa682594 100644 --- a/gymnasium/wrappers/jax_to_numpy.py +++ b/gymnasium/wrappers/jax_to_numpy.py @@ -59,7 +59,12 @@ def _iterable_numpy_to_jax( value: Iterable[np.ndarray | Any], ) -> Iterable[jax.Array | Any]: """Converts an Iterable from Numpy Arrays to an iterable of Jax Array.""" - return type(value)(numpy_to_jax(v) for v in value) + if hasattr(value, "_make"): + # namedtuple - underline used to prevent potential name conflicts + # noinspection PyProtectedMember + return type(value)._make(numpy_to_jax(v) for v in value) + else: + return type(value)(numpy_to_jax(v) for v in value) @functools.singledispatch @@ -89,7 +94,12 @@ def _iterable_jax_to_numpy( value: Iterable[np.ndarray | Any], ) -> Iterable[jax.Array | Any]: """Converts an Iterable from Numpy arrays to an iterable of Jax Array.""" - return type(value)(jax_to_numpy(v) for v in value) + if hasattr(value, "_make"): + # namedtuple - underline used to prevent potential name conflicts + # noinspection PyProtectedMember + return type(value)._make(jax_to_numpy(v) for v in value) + else: + return type(value)(jax_to_numpy(v) for v in value) class JaxToNumpy( diff --git a/tests/wrappers/test_jax_to_numpy.py b/tests/wrappers/test_jax_to_numpy.py index 11311c66e..61422a2df 100644 --- a/tests/wrappers/test_jax_to_numpy.py +++ b/tests/wrappers/test_jax_to_numpy.py @@ -1,4 +1,5 @@ """Test suite for JaxToNumpy wrapper.""" +from typing import NamedTuple import numpy as np import pytest @@ -16,6 +17,11 @@ from tests.testing_env import GenericTestEnv # noqa: E402 +class TestingNamedTuple(NamedTuple): + a: jax.Array + b: jax.Array + + @pytest.mark.parametrize( "value, expected_value", [ @@ -55,6 +61,16 @@ "b": {"c": np.array(5, dtype=np.int32)}, }, ), + ( + TestingNamedTuple( + a=np.array([1, 2], dtype=np.int32), + b=np.array([1.0, 2.0], dtype=np.float32), + ), + TestingNamedTuple( + a=np.array([1, 2], dtype=np.int32), + b=np.array([1.0, 2.0], dtype=np.float32), + ), + ), ], ) def test_roundtripping(value, expected_value):