From fda0bb0de597b8f86fb5fe8f33e7c18f987fb006 Mon Sep 17 00:00:00 2001 From: Elliot Tower Date: Thu, 14 Mar 2024 12:46:02 -0400 Subject: [PATCH] Expand testing, refactor into`pettingzoo_api_test.py`, add pytest-xdist (performance/parallelization) (#248) --- .github/workflows/linux-test.yml | 1 + pyproject.toml | 3 +- test/pettingzoo_api_test.py | 123 ++++++++++++++++++++----------- 3 files changed, 85 insertions(+), 42 deletions(-) diff --git a/.github/workflows/linux-test.yml b/.github/workflows/linux-test.yml index 86b543d..b0c6367 100644 --- a/.github/workflows/linux-test.yml +++ b/.github/workflows/linux-test.yml @@ -32,6 +32,7 @@ jobs: - name: Install python dependencies run: | pip install -e .[testing] + AutoROM -v - name: Test with pytest run: | xvfb-run -s "-screen 0 1400x900x24" pytest ./test diff --git a/pyproject.toml b/pyproject.toml index 7fc1ca8..f84c314 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -25,7 +25,7 @@ dynamic = ["version"] [project.optional-dependencies] # Update dependencies in `all` if any are added or removed -testing = ["pettingzoo[butterfly,classic]>=1.23.1", "pytest", "stable-baselines3>=2.0.0", "moviepy >=1.0.0"] +testing = ["pettingzoo[all,atari]>=1.23.1", "AutoROM", "pytest", "pytest-xdist", "stable-baselines3>=2.0.0", "moviepy >=1.0.0"] [project.urls] Homepage = "https://farama.org" @@ -84,3 +84,4 @@ reportUnboundVariable = "warning" [tool.pytest.ini_options] filterwarnings = ["ignore::DeprecationWarning:gymnasium.*:"] +addopts = [ "-n=auto" ] diff --git a/test/pettingzoo_api_test.py b/test/pettingzoo_api_test.py index 20878a5..572a652 100644 --- a/test/pettingzoo_api_test.py +++ b/test/pettingzoo_api_test.py @@ -1,9 +1,21 @@ import numpy as np import pytest -from pettingzoo.butterfly import knights_archers_zombies_v10 +from pettingzoo.butterfly import ( + cooperative_pong_v5, + knights_archers_zombies_v10, + pistonball_v6, +) from pettingzoo.classic import connect_four_v3 -from pettingzoo.mpe import simple_push_v3, simple_world_comm_v3 +from pettingzoo.mpe import simple_push_v3, simple_spread_v3, simple_world_comm_v3 +from pettingzoo.sisl import pursuit_v4 from pettingzoo.test import api_test, parallel_api_test, seed_test +from pettingzoo.utils.all_modules import ( + atari_environments, + butterfly_environments, + classic_environments, + mpe_environments, + sisl_environments, +) import supersuit from supersuit import ( @@ -16,18 +28,30 @@ from supersuit.utils.convert_box import convert_box -BUTTERFLY_MPE_CLASSIC = [knights_archers_zombies_v10, simple_push_v3, connect_four_v3] -BUTTERFLY_MPE = [knights_archers_zombies_v10, simple_push_v3] +atari = list(atari_environments.values()) +butterfly = list(butterfly_environments.values()) +classic = list(classic_environments.values()) +mpe = list(mpe_environments.values()) +sisl = list(sisl_environments.values()) +all = atari + butterfly + classic + mpe + sisl + +BUTTERFLY_MPE_CLASSIC = [ + knights_archers_zombies_v10, + simple_push_v3, + connect_four_v3, + simple_spread_v3, +] +BUTTERFLY_MPE = [knights_archers_zombies_v10, simple_push_v3, simple_spread_v3] -@pytest.mark.parametrize("env_fn", [simple_push_v3, simple_world_comm_v3]) +@pytest.mark.parametrize("env_fn", atari + butterfly + mpe + sisl) def test_frame_stack(env_fn): _env = env_fn.env() wrapped_env = frame_stack_v2(_env) api_test(wrapped_env) -@pytest.mark.parametrize("env_fn", [simple_push_v3, simple_world_comm_v3]) +@pytest.mark.parametrize("env_fn", atari + butterfly + mpe + sisl) def test_frame_stack_parallel(env_fn): _env = env_fn.parallel_env() wrapped_env = frame_stack_v2(_env) @@ -60,30 +84,34 @@ def test_frame_skip_parallel(env_fn): x += env.num_agents -@pytest.mark.parametrize("env_fn", [simple_world_comm_v3, connect_four_v3]) +@pytest.mark.parametrize("env_fn", atari + butterfly + classic + mpe + sisl) def test_pad_action_space(env_fn): - _env = simple_world_comm_v3.env() + _env = env_fn.env() wrapped_env = pad_action_space_v0(_env) api_test(wrapped_env) seed_test(lambda: sticky_actions_v0(simple_world_comm_v3.env(), 0.5), 100) -@pytest.mark.parametrize("env_fn", [simple_world_comm_v3]) +@pytest.mark.parametrize("env_fn", atari + butterfly + mpe + sisl) def test_pad_action_space_parallel(env_fn): _env = env_fn.parallel_env() wrapped_env = pad_action_space_v0(_env) parallel_api_test(wrapped_env) -@pytest.mark.parametrize("env_fn", [knights_archers_zombies_v10]) +@pytest.mark.parametrize( + "env_fn", atari + [pistonball_v6, cooperative_pong_v5, pursuit_v4] +) def test_color_reduction(env_fn): - env = supersuit.color_reduction_v0(env_fn.env(vector_state=False), "R") + env = supersuit.color_reduction_v0(env_fn.env(), "R") api_test(env) -@pytest.mark.parametrize("env_fn", [knights_archers_zombies_v10]) +@pytest.mark.parametrize( + "env_fn", atari + [pistonball_v6, cooperative_pong_v5, pursuit_v4] +) def test_color_reduction_parallel(env_fn): - env = supersuit.color_reduction_v0(env_fn.parallel_env(vector_state=False), "R") + env = supersuit.color_reduction_v0(env_fn.parallel_env(), "R") parallel_api_test(env) @@ -111,26 +139,31 @@ def test_resize_dtype_parallel(env_fn, wrapper_kwargs): parallel_api_test(env) -@pytest.mark.parametrize("env_fn", [knights_archers_zombies_v10]) +@pytest.mark.parametrize( + "env_fn", + atari + + butterfly + + [v for k, v in sisl_environments.items() if k != "sisl/multiwalker_v9"], +) def test_dtype(env_fn): env = supersuit.dtype_v0(env_fn.env(), np.int32) api_test(env) -@pytest.mark.parametrize("env_fn", [knights_archers_zombies_v10]) +@pytest.mark.parametrize("env_fn", atari + butterfly + sisl) def test_dtype_parallel(env_fn): env = supersuit.dtype_v0(env_fn.parallel_env(), np.int32) parallel_api_test(env) -@pytest.mark.parametrize("env_fn", BUTTERFLY_MPE_CLASSIC) +@pytest.mark.parametrize("env_fn", atari + butterfly + classic + mpe + sisl) def test_flatten(env_fn): env = supersuit.flatten_v0(knights_archers_zombies_v10.env()) api_test(env) # Classic environments don't have parallel envs so this doesn't apply -@pytest.mark.parametrize("env_fn", BUTTERFLY_MPE) +@pytest.mark.parametrize("env_fn", atari + butterfly + mpe + sisl) def test_flatten_parallel(env_fn): env = supersuit.flatten_v0(env_fn.parallel_env()) parallel_api_test(env) @@ -165,28 +198,28 @@ def test_normalize_obs_parallel(env_fn): parallel_api_test(env) -@pytest.mark.parametrize("env_fn", BUTTERFLY_MPE_CLASSIC) +@pytest.mark.parametrize("env_fn", atari + butterfly + mpe + sisl) def test_pad_observations(env_fn): - env = supersuit.pad_observations_v0(simple_world_comm_v3.env()) + env = supersuit.pad_observations_v0(env_fn.env()) api_test(env) -@pytest.mark.parametrize("env_fn", BUTTERFLY_MPE_CLASSIC) +@pytest.mark.parametrize("env_fn", atari + butterfly + mpe + sisl) def test_pad_observations_parallel(env_fn): - env = supersuit.pad_observations_v0(simple_world_comm_v3.parallel_env()) + env = supersuit.pad_observations_v0(env_fn.parallel_env()) parallel_api_test(env) @pytest.mark.skip( reason="Black death wrapper is only designed for parallel envs, AEC envs should simply skip the agent by setting env.agent_selection manually" ) -@pytest.mark.parametrize("env_fn", BUTTERFLY_MPE) +@pytest.mark.parametrize("env_fn", atari + butterfly + classic + mpe + sisl) def test_black_death(env_fn): env = supersuit.black_death_v3(env_fn.env()) api_test(env) -@pytest.mark.parametrize("env_fn", BUTTERFLY_MPE) +@pytest.mark.parametrize("env_fn", atari + butterfly + mpe + sisl) def test_black_death_parallel(env_fn): env = supersuit.black_death_v3(env_fn.parallel_env()) parallel_api_test(env) @@ -206,19 +239,24 @@ def test_agent_indicator_parallel(env_fn, env_kwargs): parallel_api_test(env) -@pytest.mark.parametrize("env_fn", BUTTERFLY_MPE_CLASSIC) +@pytest.mark.parametrize("env_fn", atari + butterfly + classic + mpe + sisl) def test_reward_lambda(env_fn): env = supersuit.reward_lambda_v0(env_fn.env(), lambda x: x / 10) api_test(env) -@pytest.mark.parametrize("env_fn", BUTTERFLY_MPE) +@pytest.mark.parametrize("env_fn", atari + butterfly + mpe + sisl) def test_reward_lambda_parallel(env_fn): env = supersuit.reward_lambda_v0(env_fn.parallel_env(), lambda x: x / 10) parallel_api_test(env) -@pytest.mark.parametrize("env_fn", BUTTERFLY_MPE) +@pytest.mark.parametrize( + "env_fn", + [v for k, v in butterfly_environments.items() if k != "butterfly/pistonball_v6"] + + mpe + + sisl, +) def test_observation_lambda(env_fn): env = supersuit.observation_lambda_v0(env_fn.env(), lambda obs, obs_space: obs - 1) api_test(env) @@ -270,90 +308,93 @@ def change_observation_fn(observation, old_obs_space): api_test(env) -@pytest.mark.parametrize("env_fn", BUTTERFLY_MPE) +@pytest.mark.parametrize("env_fn", atari + butterfly + mpe + sisl) def test_observation_lambda_parallel(env_fn): pass -@pytest.mark.parametrize("env_fn", BUTTERFLY_MPE_CLASSIC) +@pytest.mark.parametrize("env_fn", atari + butterfly + classic + mpe + sisl) def test_clip_reward(env_fn): env = supersuit.clip_reward_v0(env_fn.env()) api_test(env) -@pytest.mark.parametrize("env_fn", BUTTERFLY_MPE) +@pytest.mark.parametrize("env_fn", atari + butterfly + mpe + sisl) def test_clip_reward_parallel(env_fn): env = supersuit.clip_reward_v0(env_fn.parallel_env()) parallel_api_test(env) -@pytest.mark.parametrize("env_fn", BUTTERFLY_MPE_CLASSIC) +@pytest.mark.parametrize("env_fn", atari + butterfly + classic + mpe + sisl) def test_nan_noop(env_fn): env = supersuit.nan_noop_v0(env_fn.env(), 0) api_test(env) -@pytest.mark.parametrize("env_fn", BUTTERFLY_MPE) +@pytest.mark.parametrize("env_fn", atari + butterfly + mpe + sisl) def test_nan_noop_parallel(env_fn): env = supersuit.nan_noop_v0(env_fn.parallel_env(), 0) parallel_api_test(env) -@pytest.mark.parametrize("env_fn", BUTTERFLY_MPE_CLASSIC) +@pytest.mark.parametrize("env_fn", atari + butterfly + classic + mpe + sisl) def test_nan_zeros(env_fn): env = supersuit.nan_zeros_v0(env_fn.env()) api_test(env) -@pytest.mark.parametrize("env_fn", BUTTERFLY_MPE) +@pytest.mark.parametrize("env_fn", atari + butterfly + mpe + sisl) def test_nan_zeros_parallel(env_fn): env = supersuit.nan_zeros_v0(env_fn.parallel_env()) parallel_api_test(env) -@pytest.mark.parametrize("env_fn", BUTTERFLY_MPE_CLASSIC) +# Note: hanabi v5 fails here +@pytest.mark.parametrize("env_fn", atari + butterfly + mpe + sisl) def test_nan_random(env_fn): env = supersuit.nan_random_v0(env_fn.env()) api_test(env) -@pytest.mark.parametrize("env_fn", BUTTERFLY_MPE) +@pytest.mark.parametrize("env_fn", atari + butterfly + mpe + sisl) def test_nan_random_parallel(env_fn): env = supersuit.nan_random_v0(env_fn.parallel_env()) parallel_api_test(env) -@pytest.mark.parametrize("env_fn", BUTTERFLY_MPE_CLASSIC) +# Note: hanabi v5 fails here +@pytest.mark.parametrize("env_fn", atari + butterfly + mpe + sisl) def test_sticky_actions(env_fn): env = supersuit.sticky_actions_v0(env_fn.env(), 0.75) api_test(env) -@pytest.mark.parametrize("env_fn", BUTTERFLY_MPE) +@pytest.mark.parametrize("env_fn", atari + butterfly + mpe + sisl) def test_sticky_actions_parallel(env_fn): env = supersuit.sticky_actions_v0(env_fn.parallel_env(), 0.75) parallel_api_test(env) -@pytest.mark.parametrize("env_fn", [connect_four_v3]) +# Note: hanabi_v5 and texas_holdem_v4 fail here +@pytest.mark.parametrize("env_fn", atari + butterfly + mpe + sisl) def test_delay_observations(env_fn): env = supersuit.delay_observations_v0(env_fn.env(), 3) api_test(env) -@pytest.mark.parametrize("env_fn", BUTTERFLY_MPE) +@pytest.mark.parametrize("env_fn", atari + butterfly + mpe + sisl) def test_delay_observations_parallel(env_fn): env = supersuit.delay_observations_v0(env_fn.parallel_env(), 3) parallel_api_test(env) -@pytest.mark.parametrize("env_fn", BUTTERFLY_MPE_CLASSIC) +@pytest.mark.parametrize("env_fn", atari + butterfly + classic + mpe + sisl) def test_max_observation(env_fn): env = supersuit.max_observation_v0(knights_archers_zombies_v10.env(), 3) api_test(env) -@pytest.mark.parametrize("env_fn", BUTTERFLY_MPE) +@pytest.mark.parametrize("env_fn", atari + butterfly + mpe + sisl) def test_max_observation_parallel(env_fn): env = supersuit.max_observation_v0(knights_archers_zombies_v10.parallel_env(), 3) parallel_api_test(env)