Skip to content

Commit

Permalink
Expand testing, refactor intopettingzoo_api_test.py, add pytest-xdi…
Browse files Browse the repository at this point in the history
…st (performance/parallelization) (#248)
  • Loading branch information
elliottower authored Mar 14, 2024
1 parent ed2641e commit fda0bb0
Show file tree
Hide file tree
Showing 3 changed files with 85 additions and 42 deletions.
1 change: 1 addition & 0 deletions .github/workflows/linux-test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ jobs:
- name: Install python dependencies
run: |
pip install -e .[testing]
AutoROM -v
- name: Test with pytest
run: |
xvfb-run -s "-screen 0 1400x900x24" pytest ./test
Expand Down
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ dynamic = ["version"]

[project.optional-dependencies]
# Update dependencies in `all` if any are added or removed
testing = ["pettingzoo[butterfly,classic]>=1.23.1", "pytest", "stable-baselines3>=2.0.0", "moviepy >=1.0.0"]
testing = ["pettingzoo[all,atari]>=1.23.1", "AutoROM", "pytest", "pytest-xdist", "stable-baselines3>=2.0.0", "moviepy >=1.0.0"]

[project.urls]
Homepage = "https://farama.org"
Expand Down Expand Up @@ -84,3 +84,4 @@ reportUnboundVariable = "warning"

[tool.pytest.ini_options]
filterwarnings = ["ignore::DeprecationWarning:gymnasium.*:"]
addopts = [ "-n=auto" ]
123 changes: 82 additions & 41 deletions test/pettingzoo_api_test.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,21 @@
import numpy as np
import pytest
from pettingzoo.butterfly import knights_archers_zombies_v10
from pettingzoo.butterfly import (
cooperative_pong_v5,
knights_archers_zombies_v10,
pistonball_v6,
)
from pettingzoo.classic import connect_four_v3
from pettingzoo.mpe import simple_push_v3, simple_world_comm_v3
from pettingzoo.mpe import simple_push_v3, simple_spread_v3, simple_world_comm_v3
from pettingzoo.sisl import pursuit_v4
from pettingzoo.test import api_test, parallel_api_test, seed_test
from pettingzoo.utils.all_modules import (
atari_environments,
butterfly_environments,
classic_environments,
mpe_environments,
sisl_environments,
)

import supersuit
from supersuit import (
Expand All @@ -16,18 +28,30 @@
from supersuit.utils.convert_box import convert_box


BUTTERFLY_MPE_CLASSIC = [knights_archers_zombies_v10, simple_push_v3, connect_four_v3]
BUTTERFLY_MPE = [knights_archers_zombies_v10, simple_push_v3]
atari = list(atari_environments.values())
butterfly = list(butterfly_environments.values())
classic = list(classic_environments.values())
mpe = list(mpe_environments.values())
sisl = list(sisl_environments.values())
all = atari + butterfly + classic + mpe + sisl

BUTTERFLY_MPE_CLASSIC = [
knights_archers_zombies_v10,
simple_push_v3,
connect_four_v3,
simple_spread_v3,
]
BUTTERFLY_MPE = [knights_archers_zombies_v10, simple_push_v3, simple_spread_v3]


@pytest.mark.parametrize("env_fn", [simple_push_v3, simple_world_comm_v3])
@pytest.mark.parametrize("env_fn", atari + butterfly + mpe + sisl)
def test_frame_stack(env_fn):
_env = env_fn.env()
wrapped_env = frame_stack_v2(_env)
api_test(wrapped_env)


@pytest.mark.parametrize("env_fn", [simple_push_v3, simple_world_comm_v3])
@pytest.mark.parametrize("env_fn", atari + butterfly + mpe + sisl)
def test_frame_stack_parallel(env_fn):
_env = env_fn.parallel_env()
wrapped_env = frame_stack_v2(_env)
Expand Down Expand Up @@ -60,30 +84,34 @@ def test_frame_skip_parallel(env_fn):
x += env.num_agents


@pytest.mark.parametrize("env_fn", [simple_world_comm_v3, connect_four_v3])
@pytest.mark.parametrize("env_fn", atari + butterfly + classic + mpe + sisl)
def test_pad_action_space(env_fn):
_env = simple_world_comm_v3.env()
_env = env_fn.env()
wrapped_env = pad_action_space_v0(_env)
api_test(wrapped_env)
seed_test(lambda: sticky_actions_v0(simple_world_comm_v3.env(), 0.5), 100)


@pytest.mark.parametrize("env_fn", [simple_world_comm_v3])
@pytest.mark.parametrize("env_fn", atari + butterfly + mpe + sisl)
def test_pad_action_space_parallel(env_fn):
_env = env_fn.parallel_env()
wrapped_env = pad_action_space_v0(_env)
parallel_api_test(wrapped_env)


@pytest.mark.parametrize("env_fn", [knights_archers_zombies_v10])
@pytest.mark.parametrize(
"env_fn", atari + [pistonball_v6, cooperative_pong_v5, pursuit_v4]
)
def test_color_reduction(env_fn):
env = supersuit.color_reduction_v0(env_fn.env(vector_state=False), "R")
env = supersuit.color_reduction_v0(env_fn.env(), "R")
api_test(env)


@pytest.mark.parametrize("env_fn", [knights_archers_zombies_v10])
@pytest.mark.parametrize(
"env_fn", atari + [pistonball_v6, cooperative_pong_v5, pursuit_v4]
)
def test_color_reduction_parallel(env_fn):
env = supersuit.color_reduction_v0(env_fn.parallel_env(vector_state=False), "R")
env = supersuit.color_reduction_v0(env_fn.parallel_env(), "R")
parallel_api_test(env)


Expand Down Expand Up @@ -111,26 +139,31 @@ def test_resize_dtype_parallel(env_fn, wrapper_kwargs):
parallel_api_test(env)


@pytest.mark.parametrize("env_fn", [knights_archers_zombies_v10])
@pytest.mark.parametrize(
"env_fn",
atari
+ butterfly
+ [v for k, v in sisl_environments.items() if k != "sisl/multiwalker_v9"],
)
def test_dtype(env_fn):
env = supersuit.dtype_v0(env_fn.env(), np.int32)
api_test(env)


@pytest.mark.parametrize("env_fn", [knights_archers_zombies_v10])
@pytest.mark.parametrize("env_fn", atari + butterfly + sisl)
def test_dtype_parallel(env_fn):
env = supersuit.dtype_v0(env_fn.parallel_env(), np.int32)
parallel_api_test(env)


@pytest.mark.parametrize("env_fn", BUTTERFLY_MPE_CLASSIC)
@pytest.mark.parametrize("env_fn", atari + butterfly + classic + mpe + sisl)
def test_flatten(env_fn):
env = supersuit.flatten_v0(knights_archers_zombies_v10.env())
api_test(env)


# Classic environments don't have parallel envs so this doesn't apply
@pytest.mark.parametrize("env_fn", BUTTERFLY_MPE)
@pytest.mark.parametrize("env_fn", atari + butterfly + mpe + sisl)
def test_flatten_parallel(env_fn):
env = supersuit.flatten_v0(env_fn.parallel_env())
parallel_api_test(env)
Expand Down Expand Up @@ -165,28 +198,28 @@ def test_normalize_obs_parallel(env_fn):
parallel_api_test(env)


@pytest.mark.parametrize("env_fn", BUTTERFLY_MPE_CLASSIC)
@pytest.mark.parametrize("env_fn", atari + butterfly + mpe + sisl)
def test_pad_observations(env_fn):
env = supersuit.pad_observations_v0(simple_world_comm_v3.env())
env = supersuit.pad_observations_v0(env_fn.env())
api_test(env)


@pytest.mark.parametrize("env_fn", BUTTERFLY_MPE_CLASSIC)
@pytest.mark.parametrize("env_fn", atari + butterfly + mpe + sisl)
def test_pad_observations_parallel(env_fn):
env = supersuit.pad_observations_v0(simple_world_comm_v3.parallel_env())
env = supersuit.pad_observations_v0(env_fn.parallel_env())
parallel_api_test(env)


@pytest.mark.skip(
reason="Black death wrapper is only designed for parallel envs, AEC envs should simply skip the agent by setting env.agent_selection manually"
)
@pytest.mark.parametrize("env_fn", BUTTERFLY_MPE)
@pytest.mark.parametrize("env_fn", atari + butterfly + classic + mpe + sisl)
def test_black_death(env_fn):
env = supersuit.black_death_v3(env_fn.env())
api_test(env)


@pytest.mark.parametrize("env_fn", BUTTERFLY_MPE)
@pytest.mark.parametrize("env_fn", atari + butterfly + mpe + sisl)
def test_black_death_parallel(env_fn):
env = supersuit.black_death_v3(env_fn.parallel_env())
parallel_api_test(env)
Expand All @@ -206,19 +239,24 @@ def test_agent_indicator_parallel(env_fn, env_kwargs):
parallel_api_test(env)


@pytest.mark.parametrize("env_fn", BUTTERFLY_MPE_CLASSIC)
@pytest.mark.parametrize("env_fn", atari + butterfly + classic + mpe + sisl)
def test_reward_lambda(env_fn):
env = supersuit.reward_lambda_v0(env_fn.env(), lambda x: x / 10)
api_test(env)


@pytest.mark.parametrize("env_fn", BUTTERFLY_MPE)
@pytest.mark.parametrize("env_fn", atari + butterfly + mpe + sisl)
def test_reward_lambda_parallel(env_fn):
env = supersuit.reward_lambda_v0(env_fn.parallel_env(), lambda x: x / 10)
parallel_api_test(env)


@pytest.mark.parametrize("env_fn", BUTTERFLY_MPE)
@pytest.mark.parametrize(
"env_fn",
[v for k, v in butterfly_environments.items() if k != "butterfly/pistonball_v6"]
+ mpe
+ sisl,
)
def test_observation_lambda(env_fn):
env = supersuit.observation_lambda_v0(env_fn.env(), lambda obs, obs_space: obs - 1)
api_test(env)
Expand Down Expand Up @@ -270,90 +308,93 @@ def change_observation_fn(observation, old_obs_space):
api_test(env)


@pytest.mark.parametrize("env_fn", BUTTERFLY_MPE)
@pytest.mark.parametrize("env_fn", atari + butterfly + mpe + sisl)
def test_observation_lambda_parallel(env_fn):
pass


@pytest.mark.parametrize("env_fn", BUTTERFLY_MPE_CLASSIC)
@pytest.mark.parametrize("env_fn", atari + butterfly + classic + mpe + sisl)
def test_clip_reward(env_fn):
env = supersuit.clip_reward_v0(env_fn.env())
api_test(env)


@pytest.mark.parametrize("env_fn", BUTTERFLY_MPE)
@pytest.mark.parametrize("env_fn", atari + butterfly + mpe + sisl)
def test_clip_reward_parallel(env_fn):
env = supersuit.clip_reward_v0(env_fn.parallel_env())
parallel_api_test(env)


@pytest.mark.parametrize("env_fn", BUTTERFLY_MPE_CLASSIC)
@pytest.mark.parametrize("env_fn", atari + butterfly + classic + mpe + sisl)
def test_nan_noop(env_fn):
env = supersuit.nan_noop_v0(env_fn.env(), 0)
api_test(env)


@pytest.mark.parametrize("env_fn", BUTTERFLY_MPE)
@pytest.mark.parametrize("env_fn", atari + butterfly + mpe + sisl)
def test_nan_noop_parallel(env_fn):
env = supersuit.nan_noop_v0(env_fn.parallel_env(), 0)
parallel_api_test(env)


@pytest.mark.parametrize("env_fn", BUTTERFLY_MPE_CLASSIC)
@pytest.mark.parametrize("env_fn", atari + butterfly + classic + mpe + sisl)
def test_nan_zeros(env_fn):
env = supersuit.nan_zeros_v0(env_fn.env())
api_test(env)


@pytest.mark.parametrize("env_fn", BUTTERFLY_MPE)
@pytest.mark.parametrize("env_fn", atari + butterfly + mpe + sisl)
def test_nan_zeros_parallel(env_fn):
env = supersuit.nan_zeros_v0(env_fn.parallel_env())
parallel_api_test(env)


@pytest.mark.parametrize("env_fn", BUTTERFLY_MPE_CLASSIC)
# Note: hanabi v5 fails here
@pytest.mark.parametrize("env_fn", atari + butterfly + mpe + sisl)
def test_nan_random(env_fn):
env = supersuit.nan_random_v0(env_fn.env())
api_test(env)


@pytest.mark.parametrize("env_fn", BUTTERFLY_MPE)
@pytest.mark.parametrize("env_fn", atari + butterfly + mpe + sisl)
def test_nan_random_parallel(env_fn):
env = supersuit.nan_random_v0(env_fn.parallel_env())
parallel_api_test(env)


@pytest.mark.parametrize("env_fn", BUTTERFLY_MPE_CLASSIC)
# Note: hanabi v5 fails here
@pytest.mark.parametrize("env_fn", atari + butterfly + mpe + sisl)
def test_sticky_actions(env_fn):
env = supersuit.sticky_actions_v0(env_fn.env(), 0.75)
api_test(env)


@pytest.mark.parametrize("env_fn", BUTTERFLY_MPE)
@pytest.mark.parametrize("env_fn", atari + butterfly + mpe + sisl)
def test_sticky_actions_parallel(env_fn):
env = supersuit.sticky_actions_v0(env_fn.parallel_env(), 0.75)
parallel_api_test(env)


@pytest.mark.parametrize("env_fn", [connect_four_v3])
# Note: hanabi_v5 and texas_holdem_v4 fail here
@pytest.mark.parametrize("env_fn", atari + butterfly + mpe + sisl)
def test_delay_observations(env_fn):
env = supersuit.delay_observations_v0(env_fn.env(), 3)
api_test(env)


@pytest.mark.parametrize("env_fn", BUTTERFLY_MPE)
@pytest.mark.parametrize("env_fn", atari + butterfly + mpe + sisl)
def test_delay_observations_parallel(env_fn):
env = supersuit.delay_observations_v0(env_fn.parallel_env(), 3)
parallel_api_test(env)


@pytest.mark.parametrize("env_fn", BUTTERFLY_MPE_CLASSIC)
@pytest.mark.parametrize("env_fn", atari + butterfly + classic + mpe + sisl)
def test_max_observation(env_fn):
env = supersuit.max_observation_v0(knights_archers_zombies_v10.env(), 3)
api_test(env)


@pytest.mark.parametrize("env_fn", BUTTERFLY_MPE)
@pytest.mark.parametrize("env_fn", atari + butterfly + mpe + sisl)
def test_max_observation_parallel(env_fn):
env = supersuit.max_observation_v0(knights_archers_zombies_v10.parallel_env(), 3)
parallel_api_test(env)

0 comments on commit fda0bb0

Please sign in to comment.