diff --git a/tests/envs/test_env_implementation.py b/tests/envs/test_env_implementation.py index ef1e5967d..f9ae75328 100644 --- a/tests/envs/test_env_implementation.py +++ b/tests/envs/test_env_implementation.py @@ -346,27 +346,22 @@ def test_cartpole_vector_equiv(): envs.close() -def test_discrete_action_validation(): - env_ids = ["CarRacing-v3", "LunarLander-v3"] - - for env_id in env_ids: - # get continuous action - continuous_env = gym.make(env_id, continuous=True) - continuous_action = continuous_env.action_space.sample() - continuous_env.close() - - # create discrete env - discrete_env = gym.make(env_id, continuous=False).unwrapped - discrete_env.reset() - - # expect InvalidAction (caused by CarRacing) or AssertionError (caused by LunarLander) - try: - discrete_env.step(continuous_action) - except Exception as e: - assert isinstance(e, InvalidAction) or isinstance(e, AssertionError) - - # expect no error - discrete_env.reset() - discrete_action = discrete_env.action_space.sample() - discrete_env.step(discrete_action) - discrete_env.close() +@pytest.mark.parametrize("env_id", ["CarRacing-v3", "LunarLander-v3"]) +def test_discrete_action_validation(env_id): + # get continuous action + continuous_env = gym.make(env_id, continuous=True) + continuous_action = continuous_env.action_space.sample() + continuous_env.close() + + # create discrete env + discrete_env = gym.make(env_id, continuous=False) + discrete_env.reset() + + # expect InvalidAction (caused by CarRacing) or AssertionError (caused by LunarLander) + with pytest.raises((InvalidAction, AssertionError)): + discrete_env.step(continuous_action) + + # expect no error + discrete_action = discrete_env.action_space.sample() + discrete_env.step(discrete_action) + discrete_env.close()