Skip to content

Commit

Permalink
Move wrapper test to file with other tests
Browse files Browse the repository at this point in the history
  • Loading branch information
dm-ackerman committed May 7, 2024
1 parent 63543af commit e891f00
Show file tree
Hide file tree
Showing 2 changed files with 69 additions and 73 deletions.
71 changes: 0 additions & 71 deletions test/terminite_illegal_bug_test.py

This file was deleted.

71 changes: 69 additions & 2 deletions test/wrapper_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,13 @@
import pytest

from pettingzoo.butterfly import pistonball_v6
from pettingzoo.classic import texas_holdem_no_limit_v6
from pettingzoo.utils.wrappers import MultiEpisodeEnv, MultiEpisodeParallelEnv
from pettingzoo.classic import texas_holdem_no_limit_v6, tictactoe_v3
from pettingzoo.utils.wrappers import (
BaseWrapper,
MultiEpisodeEnv,
MultiEpisodeParallelEnv,
TerminateIllegalWrapper,
)


@pytest.mark.parametrize(("num_episodes"), [1, 2, 3, 4, 5, 6])
Expand Down Expand Up @@ -67,3 +72,65 @@ def test_multi_episode_parallel_env_wrapper(num_episodes) -> None:
assert (
steps == num_episodes * 125
), f"Expected to have 125 steps per episode, got {steps / num_episodes}."


def _do_game(env: TerminateIllegalWrapper, seed: int) -> None:
"""Run a single game with reproducible random moves."""
assert isinstance(
env, TerminateIllegalWrapper
), "test_terminate_illegal must use TerminateIllegalWrapper"
env.reset(seed)
for agent in env.agents:
# make the random moves reproducible
env.action_space(agent).seed(seed)

for agent in env.agent_iter():
_, _, termination, truncation, _ = env.last()

if termination or truncation:
env.step(None)
else:
action = env.action_space(agent).sample()
env.step(action)


def test_terminate_illegal() -> None:
"""Test for a problem with terminate illegal wrapper.
The problem is that env variables, including agent_selection, are set by
calls from TerminateIllegalWrapper to env functions. However, they are
called by the wrapper object, not the env so they are set in the wrapper
object rather than the base env object. When the code later tries to run,
the values get updated in the env code, but the wrapper pulls it's own
values that shadow them.
The test here confirms that is fixed.
"""
# not using env() because we need to ensure that the env is
# wrapped by TerminateIllegalWrapper
raw_env = tictactoe_v3.raw_env()
env = TerminateIllegalWrapper(raw_env, illegal_reward=-1)

_do_game(env, 42)
# bug is triggered by a corrupted state after a game is terminated
# due to an illegal move. So we need to run the game twice to
# see the effect.
_do_game(env, 42)

# get a list of what all the agent_selection values in the wrapper stack
unwrapped = env
agent_selections = []
while unwrapped != env.unwrapped:
# the actual value for this wrapper (or None if no value)
agent_selections.append(unwrapped.__dict__.get("agent_selection", None))
assert isinstance(unwrapped, BaseWrapper)
unwrapped = unwrapped.env

# last one from the actual env
agent_selections.append(unwrapped.__dict__.get("agent_selection", None))

# remove None from agent_selections
agent_selections = [x for x in agent_selections if x is not None]

# all values must be the same, or else the wrapper and env are mismatched
assert len(set(agent_selections)) == 1, "agent_selection mismatch"

0 comments on commit e891f00

Please sign in to comment.