From df97b31ebc68346424a81a4c869732c645f3e9fc Mon Sep 17 00:00:00 2001 From: Trey Wager Date: Tue, 7 May 2024 19:39:09 -0600 Subject: [PATCH] Add shape check to MultiDiscrete __eq__ (#1044) (#1045) --- gymnasium/spaces/multi_discrete.py | 1 + tests/spaces/test_multidiscrete.py | 20 ++++++++++++++++++++ 2 files changed, 21 insertions(+) diff --git a/gymnasium/spaces/multi_discrete.py b/gymnasium/spaces/multi_discrete.py index caf488764..24b63c2c4 100644 --- a/gymnasium/spaces/multi_discrete.py +++ b/gymnasium/spaces/multi_discrete.py @@ -207,6 +207,7 @@ def __eq__(self, other: Any) -> bool: return bool( isinstance(other, MultiDiscrete) and self.dtype == other.dtype + and self.shape == other.shape and np.all(self.nvec == other.nvec) and np.all(self.start == other.start) ) diff --git a/tests/spaces/test_multidiscrete.py b/tests/spaces/test_multidiscrete.py index f762171e6..09668f96e 100644 --- a/tests/spaces/test_multidiscrete.py +++ b/tests/spaces/test_multidiscrete.py @@ -157,6 +157,26 @@ def test_multidiscrete_start_contains(): assert [13, 23, 34] not in space +def test_multidiscrete_equality(): + # Check if two spaces are equivalent. + space_a = MultiDiscrete(nvec=[2, 3, 4], start=[0, 0, 1]) + + space_b = MultiDiscrete(nvec=[2, 3, 4], start=[0, 0, 1]) + assert space_a == space_b + + space_b = MultiDiscrete(nvec=[2, 4, 3], start=[0, 0, 1]) + assert space_a != space_b + + space_b = MultiDiscrete(nvec=[2, 3, 4], start=[1, 0, 1]) + assert space_a != space_b + + space_b = MultiDiscrete(nvec=[2, 3, 4], start=[0, 1, 1]) + assert space_a != space_b + + space_b = MultiDiscrete(nvec=[2, 3, 4, 2], start=[1, 0, 0, 0]) + assert space_a != space_b + + def test_space_legacy_pickling(): """Test the legacy pickle of Discrete that is missing the `start` parameter.""" # Test that start is corrected passed