From 6ba886abce531aa7e6804a1bd18f0afadb625789 Mon Sep 17 00:00:00 2001
From: Pierre Mardon
Date: Wed, 18 Jan 2023 18:32:54 +0100
Subject: [PATCH] Add explicit error messages when unflatten discrete and multidiscrete fail (#267)

---
Notes (not part of the commit message):
  Both unflatten handlers now raise ValueError with an explanatory message
  when np.nonzero(x) finds no set entry. Only the all-zero case is checked
  explicitly; inputs with extra nonzero entries are not validated by this
  patch.

 gymnasium/spaces/utils.py  | 17 ++++++++++++++---
 tests/spaces/test_utils.py | 12 ++++++++++++
 2 files changed, 26 insertions(+), 3 deletions(-)

diff --git a/gymnasium/spaces/utils.py b/gymnasium/spaces/utils.py
index 8d2636f42..510ba555c 100644
--- a/gymnasium/spaces/utils.py
+++ b/gymnasium/spaces/utils.py
@@ -271,7 +271,13 @@ def _unflatten_box_multibinary(
 
 @unflatten.register(Discrete)
 def _unflatten_discrete(space: Discrete, x: NDArray[np.int64]) -> np.int64:
-    return space.start + np.nonzero(x)[0][0]
+    nonzero = np.nonzero(x)
+    if len(nonzero[0]) == 0:
+        raise ValueError(
+            f"{x} is not a valid one-hot encoded vector and can not be unflattened to space {space}. "
+            "Not all valid samples in a flattened space can be unflattened."
+        )
+    return space.start + nonzero[0][0]
 
 
 @unflatten.register(MultiDiscrete)
@@ -280,8 +286,13 @@ def _unflatten_multidiscrete(
 ) -> NDArray[np.integer[Any]]:
     offsets = np.zeros((space.nvec.size + 1,), dtype=space.dtype)
     offsets[1:] = np.cumsum(space.nvec.flatten())
-
-    (indices,) = cast(type(offsets[:-1]), np.nonzero(x))
+    nonzero = np.nonzero(x)
+    if len(nonzero[0]) == 0:
+        raise ValueError(
+            f"{x} is not a concatenation of one-hot encoded vectors and can not be unflattened to space {space}. "
+            "Not all valid samples in a flattened space can be unflattened."
+        )
+    (indices,) = cast(type(offsets[:-1]), nonzero)
     return np.asarray(indices - offsets[:-1], dtype=space.dtype).reshape(space.shape)
 
 
diff --git a/tests/spaces/test_utils.py b/tests/spaces/test_utils.py
index bee2f1607..4644e0824 100644
--- a/tests/spaces/test_utils.py
+++ b/tests/spaces/test_utils.py
@@ -135,3 +135,15 @@ def test_flatten_roundtripping(space):
 
     for original, roundtripped in zip(samples, unflattened_samples):
         assert data_equivalence(original, roundtripped)
+
+
+def test_unflatten_discrete_error():
+    value = np.array([0])
+    with pytest.raises(ValueError):
+        utils.unflatten(gym.spaces.Discrete(1), value)
+
+
+def test_unflatten_multidiscrete_error():
+    value = np.array([0, 0])
+    with pytest.raises(ValueError):
+        utils.unflatten(gym.spaces.MultiDiscrete([1, 1]), value)