From e891f00d2b507ca9a42bc7fc81b4d0d004a51391 Mon Sep 17 00:00:00 2001 From: David Ackerman <145808634+dm-ackerman@users.noreply.github.com> Date: Tue, 7 May 2024 21:41:24 +0000 Subject: [PATCH] Move wrapper test to file with other tests --- test/terminite_illegal_bug_test.py | 71 ------------------------------ test/wrapper_test.py | 71 +++++++++++++++++++++++++++++- 2 files changed, 69 insertions(+), 73 deletions(-) delete mode 100644 test/terminite_illegal_bug_test.py diff --git a/test/terminite_illegal_bug_test.py b/test/terminite_illegal_bug_test.py deleted file mode 100644 index be8f28c8e..000000000 --- a/test/terminite_illegal_bug_test.py +++ /dev/null @@ -1,71 +0,0 @@ -"""Check for a problem with terminate illegal wrapper. - -The problem is that env variables, including agent_selection, are set by -calls from TerminateIllegalWrapper to env functions. However, they are -called by the wrapper object, not the env so they are set in the wrapper -object rather than the base env object. When the code later tries to run, -the values get updated in the env code, but the wrapper pulls it's own -values that shadow them. - -The test here confirms that is fixed. -""" - -from pettingzoo.classic import tictactoe_v3 -from pettingzoo.utils.wrappers import BaseWrapper, TerminateIllegalWrapper - - -def _do_game(env: TerminateIllegalWrapper, seed: int) -> None: - """Run a single game with reproducible random moves.""" - assert isinstance( - env, TerminateIllegalWrapper - ), "test_terminate_illegal must use TerminateIllegalWrapper" - env.reset(seed) - for agent in env.agents: - # make the random moves reproducible - env.action_space(agent).seed(seed) - - for agent in env.agent_iter(): - _, _, termination, truncation, _ = env.last() - - if termination or truncation: - env.step(None) - else: - action = env.action_space(agent).sample() - env.step(action) - - -def test_terminate_illegal() -> None: - """Test for error in TerminateIllegalWrapper. - - A bug caused TerminateIllegalWrapper to set values on the wrapper - rather than the environment. This tests for a recurrence of that - bug. - """ - # not using env() because we need to ensure that the env is - # wrapped by TerminateIllegalWrapper - raw_env = tictactoe_v3.raw_env() - env = TerminateIllegalWrapper(raw_env, illegal_reward=-1) - - _do_game(env, 42) - # bug is triggered by a corrupted state after a game is terminated - # due to an illegal move. So we need to run the game twice to - # see the effect. - _do_game(env, 42) - - # get a list of what all the agent_selection values in the wrapper stack - unwrapped = env - agent_selections = [] - while unwrapped != env.unwrapped: - # the actual value for this wrapper (or None if no value) - agent_selections.append(unwrapped.__dict__.get("agent_selection", None)) - assert isinstance(unwrapped, BaseWrapper) - unwrapped = unwrapped.env - - # last one from the actual env - agent_selections.append(unwrapped.__dict__.get("agent_selection", None)) - - # remove None from agent_selections - agent_selections = [x for x in agent_selections if x is not None] - - # all values must be the same, or else the wrapper and env are mismatched - assert len(set(agent_selections)) == 1, "agent_selection mismatch" diff --git a/test/wrapper_test.py b/test/wrapper_test.py index 650fe328b..a03bd81b3 100644 --- a/test/wrapper_test.py +++ b/test/wrapper_test.py @@ -3,8 +3,13 @@ import pytest from pettingzoo.butterfly import pistonball_v6 -from pettingzoo.classic import texas_holdem_no_limit_v6 -from pettingzoo.utils.wrappers import MultiEpisodeEnv, MultiEpisodeParallelEnv +from pettingzoo.classic import texas_holdem_no_limit_v6, tictactoe_v3 +from pettingzoo.utils.wrappers import ( + BaseWrapper, + MultiEpisodeEnv, + MultiEpisodeParallelEnv, + TerminateIllegalWrapper, +) @pytest.mark.parametrize(("num_episodes"), [1, 2, 3, 4, 5, 6]) @@ -67,3 +72,65 @@ def test_multi_episode_parallel_env_wrapper(num_episodes) -> None: assert ( steps == num_episodes * 125 ), f"Expected to have 125 steps per episode, got {steps / num_episodes}." + + +def _do_game(env: TerminateIllegalWrapper, seed: int) -> None: + """Run a single game with reproducible random moves.""" + assert isinstance( + env, TerminateIllegalWrapper + ), "test_terminate_illegal must use TerminateIllegalWrapper" + env.reset(seed) + for agent in env.agents: + # make the random moves reproducible + env.action_space(agent).seed(seed) + + for agent in env.agent_iter(): + _, _, termination, truncation, _ = env.last() + + if termination or truncation: + env.step(None) + else: + action = env.action_space(agent).sample() + env.step(action) + + +def test_terminate_illegal() -> None: + """Test for a problem with terminate illegal wrapper. + + The problem is that env variables, including agent_selection, are set by + calls from TerminateIllegalWrapper to env functions. However, they are + called by the wrapper object, not the env so they are set in the wrapper + object rather than the base env object. When the code later tries to run, + the values get updated in the env code, but the wrapper pulls it's own + values that shadow them. + + The test here confirms that is fixed. + """ + # not using env() because we need to ensure that the env is + # wrapped by TerminateIllegalWrapper + raw_env = tictactoe_v3.raw_env() + env = TerminateIllegalWrapper(raw_env, illegal_reward=-1) + + _do_game(env, 42) + # bug is triggered by a corrupted state after a game is terminated + # due to an illegal move. So we need to run the game twice to + # see the effect. + _do_game(env, 42) + + # get a list of what all the agent_selection values in the wrapper stack + unwrapped = env + agent_selections = [] + while unwrapped != env.unwrapped: + # the actual value for this wrapper (or None if no value) + agent_selections.append(unwrapped.__dict__.get("agent_selection", None)) + assert isinstance(unwrapped, BaseWrapper) + unwrapped = unwrapped.env + + # last one from the actual env + agent_selections.append(unwrapped.__dict__.get("agent_selection", None)) + + # remove None from agent_selections + agent_selections = [x for x in agent_selections if x is not None] + + # all values must be the same, or else the wrapper and env are mismatched + assert len(set(agent_selections)) == 1, "agent_selection mismatch"