Skip to content

Commit

Permalink
Add shape check to MultiDiscrete __eq__ (#1044) (#1045)
Browse files Browse the repository at this point in the history
  • Loading branch information
DenBuzz authored May 8, 2024
1 parent 17f161e commit df97b31
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 0 deletions.
1 change: 1 addition & 0 deletions gymnasium/spaces/multi_discrete.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,6 +207,7 @@ def __eq__(self, other: Any) -> bool:
return bool(
isinstance(other, MultiDiscrete)
and self.dtype == other.dtype
and self.shape == other.shape
and np.all(self.nvec == other.nvec)
and np.all(self.start == other.start)
)
Expand Down
20 changes: 20 additions & 0 deletions tests/spaces/test_multidiscrete.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,26 @@ def test_multidiscrete_start_contains():
assert [13, 23, 34] not in space


def test_multidiscrete_equality():
# Check if two spaces are equivalent.
space_a = MultiDiscrete(nvec=[2, 3, 4], start=[0, 0, 1])

space_b = MultiDiscrete(nvec=[2, 3, 4], start=[0, 0, 1])
assert space_a == space_b

space_b = MultiDiscrete(nvec=[2, 4, 3], start=[0, 0, 1])
assert space_a != space_b

space_b = MultiDiscrete(nvec=[2, 3, 4], start=[1, 0, 1])
assert space_a != space_b

space_b = MultiDiscrete(nvec=[2, 3, 4], start=[0, 1, 1])
assert space_a != space_b

space_b = MultiDiscrete(nvec=[2, 3, 4, 2], start=[1, 0, 0, 0])
assert space_a != space_b


def test_space_legacy_pickling():
"""Test the legacy pickle of Discrete that is missing the `start` parameter."""
# Test that start is corrected passed
Expand Down

0 comments on commit df97b31

Please sign in to comment.