From 0cdf49eeda99d07bb524278869484456deecd107 Mon Sep 17 00:00:00 2001 From: David Ackerman <145808634+dm-ackerman@users.noreply.github.com> Date: Fri, 22 Mar 2024 10:57:02 -0400 Subject: [PATCH] Fix test to handle nested observation dicts (#1172) --- pettingzoo/test/api_test.py | 52 ++++++++++++++++++++++++++++++++----- 1 file changed, 45 insertions(+), 7 deletions(-) diff --git a/pettingzoo/test/api_test.py b/pettingzoo/test/api_test.py index f8718579c..558123b6c 100644 --- a/pettingzoo/test/api_test.py +++ b/pettingzoo/test/api_test.py @@ -3,6 +3,7 @@ import re import warnings from collections import defaultdict +from typing import Any import gymnasium import numpy as np @@ -383,6 +384,46 @@ def test_rewards_terminations_truncations(env, agent_0): test_reward(env.rewards[agent]) +def _test_observation_space_compatibility( + expected: gymnasium.spaces.Space[Any], + seen: gymnasium.spaces.Space[Any] | dict, + recursed_keys: list[str], +) -> None: + """Ensure observation's dtypes are same as in observation_space. + + This tests that the dtypes of the spaces are the same. + The function will recursively check observation dicts to ensure that + all components have the same dtype as declared in the observation space. + + Args: + expected: Observation space that is expected. + seen: The observation actually seen. + recursed_keys: A list of all the dict keys that led to the current + observations. This enables a more helpful error message if + an assert fails. The initial call should have an empty list. + """ + if isinstance(expected, gymnasium.spaces.Dict): + for key in expected.keys(): + if not recursed_keys and key != "observation": + # For the top level, we only care about the 'observation' key. + continue + # We know a dict is expected. Anything else is an error. + assert isinstance( + seen, dict + ), f"observation at [{']['.join(recursed_keys)}] is {seen.dtype}, but expected dict." + + # note: a previous test (expected.contains(seen)) ensures that + # the two dicts have the same keys. + _test_observation_space_compatibility( + expected[key], seen[key], recursed_keys + [key] + ) + else: + # done recursing, now the actual space types should match + assert ( + expected.dtype == seen.dtype + ), f"dtype for observation at [{']['.join(recursed_keys)}] is {seen.dtype}, but observation space specifies {expected.dtype}." + + def play_test(env, observation_0, num_cycles): """ plays through environment and does dynamic checks to make @@ -466,13 +507,10 @@ def play_test(env, observation_0, num_cycles): prev_observe ), "Out of bounds observation: " + str(prev_observe) - if isinstance(env.observation_space(agent), gymnasium.spaces.Box): - assert env.observation_space(agent).dtype == prev_observe.dtype - elif isinstance(env.observation_space(agent), gymnasium.spaces.Dict): - assert ( - env.observation_space(agent)["observation"].dtype - == prev_observe["observation"].dtype - ) + _test_observation_space_compatibility( + env.observation_space(agent), prev_observe, recursed_keys=[] + ) + test_observation(prev_observe, observation_0, str(env.unwrapped)) if not isinstance(env.infos[env.agent_selection], dict): warnings.warn(