Skip to content

Commit

Permalink
Add explicit error messages when unflatten discrete and multidiscrete…
Browse files Browse the repository at this point in the history
… fail (openai#267)
  • Loading branch information
PierreMardon authored Jan 18, 2023
1 parent bb368fe commit 6ba886a
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 3 deletions.
17 changes: 14 additions & 3 deletions gymnasium/spaces/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -271,7 +271,13 @@ def _unflatten_box_multibinary(

@unflatten.register(Discrete)
def _unflatten_discrete(space: Discrete, x: NDArray[np.int64]) -> np.int64:
return space.start + np.nonzero(x)[0][0]
nonzero = np.nonzero(x)
if len(nonzero[0]) == 0:
raise ValueError(
f"{x} is not a valid one-hot encoded vector and can not be unflattened to space {space}. "
"Not all valid samples in a flattened space can be unflattened."
)
return space.start + nonzero[0][0]


@unflatten.register(MultiDiscrete)
Expand All @@ -280,8 +286,13 @@ def _unflatten_multidiscrete(
) -> NDArray[np.integer[Any]]:
offsets = np.zeros((space.nvec.size + 1,), dtype=space.dtype)
offsets[1:] = np.cumsum(space.nvec.flatten())

(indices,) = cast(type(offsets[:-1]), np.nonzero(x))
nonzero = np.nonzero(x)
if len(nonzero[0]) == 0:
raise ValueError(
f"{x} is not a concatenation of one-hot encoded vectors and can not be unflattened to space {space}. "
"Not all valid samples in a flattened space can be unflattened."
)
(indices,) = cast(type(offsets[:-1]), nonzero)
return np.asarray(indices - offsets[:-1], dtype=space.dtype).reshape(space.shape)


Expand Down
12 changes: 12 additions & 0 deletions tests/spaces/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,3 +135,15 @@ def test_flatten_roundtripping(space):

for original, roundtripped in zip(samples, unflattened_samples):
assert data_equivalence(original, roundtripped)


def test_unflatten_discrete_error():
value = np.array([0])
with pytest.raises(ValueError):
utils.unflatten(gym.spaces.Discrete(1), value)


def test_unflatten_multidiscrete_error():
value = np.array([0, 0])
with pytest.raises(ValueError):
utils.unflatten(gym.spaces.MultiDiscrete([1, 1]), value)

0 comments on commit 6ba886a

Please sign in to comment.