From 48502750e4f85bbc26dc8612cb022ddbf3ce524d Mon Sep 17 00:00:00 2001 From: Sebastian Griesbach <56889221+Sebastian-Griesbach@users.noreply.github.com> Date: Fri, 15 Nov 2024 10:49:02 +0100 Subject: [PATCH 1/3] Fix discrete CarRacing-v3 --- gymnasium/envs/box2d/car_racing.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/gymnasium/envs/box2d/car_racing.py b/gymnasium/envs/box2d/car_racing.py index fe3320cce..a989467e3 100644 --- a/gymnasium/envs/box2d/car_racing.py +++ b/gymnasium/envs/box2d/car_racing.py @@ -541,8 +541,8 @@ def reset( def step(self, action: Union[np.ndarray, int]): assert self.car is not None if action is not None: - action = action.astype(np.float64) if self.continuous: + action = action.astype(np.float64) self.car.steer(-action[0]) self.car.gas(action[1]) self.car.brake(action[2]) From 6aa84133c54e61f0f2a7baa7b1cd9ec2f518acb9 Mon Sep 17 00:00:00 2001 From: Sebastian Griesbach <56889221+Sebastian-Griesbach@users.noreply.github.com> Date: Sun, 17 Nov 2024 14:25:23 +0100 Subject: [PATCH 2/3] adding discrete action test --- tests/envs/test_env_implementation.py | 27 +++++++++++++++++++++++++++ 1 file changed, 27 insertions(+) diff --git a/tests/envs/test_env_implementation.py b/tests/envs/test_env_implementation.py index bed8a8ab2..ef1e5967d 100644 --- a/tests/envs/test_env_implementation.py +++ b/tests/envs/test_env_implementation.py @@ -8,6 +8,7 @@ from gymnasium.envs.box2d.lunar_lander import demo_heuristic_lander from gymnasium.envs.toy_text import CliffWalkingEnv, TaxiEnv from gymnasium.envs.toy_text.frozen_lake import generate_random_map +from gymnasium.error import InvalidAction def test_lunar_lander_heuristics(): @@ -343,3 +344,29 @@ def test_cartpole_vector_equiv(): env.close() envs.close() + + +def test_discrete_action_validation(): + env_ids = ["CarRacing-v3", "LunarLander-v3"] + + for env_id in env_ids: + # get continuous action + continuous_env = gym.make(env_id, continuous=True) + continuous_action = continuous_env.action_space.sample() + continuous_env.close() + + # create discrete env + discrete_env = gym.make(env_id, continuous=False).unwrapped + discrete_env.reset() + + # expect InvalidAction (caused by CarRacing) or AssertionError (caused by LunarLander) + try: + discrete_env.step(continuous_action) + except Exception as e: + assert isinstance(e, InvalidAction) or isinstance(e, AssertionError) + + # expect no error + discrete_env.reset() + discrete_action = discrete_env.action_space.sample() + discrete_env.step(discrete_action) + discrete_env.close() From 17e16ac5bf49848c4bc81408b71f702978e9671c Mon Sep 17 00:00:00 2001 From: Mark Towers Date: Mon, 18 Nov 2024 11:48:18 +0000 Subject: [PATCH 3/3] Update test_env_implementation.py --- tests/envs/test_env_implementation.py | 43 ++++++++++++--------------- 1 file changed, 19 insertions(+), 24 deletions(-) diff --git a/tests/envs/test_env_implementation.py b/tests/envs/test_env_implementation.py index ef1e5967d..f9ae75328 100644 --- a/tests/envs/test_env_implementation.py +++ b/tests/envs/test_env_implementation.py @@ -346,27 +346,22 @@ def test_cartpole_vector_equiv(): envs.close() -def test_discrete_action_validation(): - env_ids = ["CarRacing-v3", "LunarLander-v3"] - - for env_id in env_ids: - # get continuous action - continuous_env = gym.make(env_id, continuous=True) - continuous_action = continuous_env.action_space.sample() - continuous_env.close() - - # create discrete env - discrete_env = gym.make(env_id, continuous=False).unwrapped - discrete_env.reset() - - # expect InvalidAction (caused by CarRacing) or AssertionError (caused by LunarLander) - try: - discrete_env.step(continuous_action) - except Exception as e: - assert isinstance(e, InvalidAction) or isinstance(e, AssertionError) - - # expect no error - discrete_env.reset() - discrete_action = discrete_env.action_space.sample() - discrete_env.step(discrete_action) - discrete_env.close() +@pytest.mark.parametrize("env_id", ["CarRacing-v3", "LunarLander-v3"]) +def test_discrete_action_validation(env_id): + # get continuous action + continuous_env = gym.make(env_id, continuous=True) + continuous_action = continuous_env.action_space.sample() + continuous_env.close() + + # create discrete env + discrete_env = gym.make(env_id, continuous=False) + discrete_env.reset() + + # expect InvalidAction (caused by CarRacing) or AssertionError (caused by LunarLander) + with pytest.raises((InvalidAction, AssertionError)): + discrete_env.step(continuous_action) + + # expect no error + discrete_action = discrete_env.action_space.sample() + discrete_env.step(discrete_action) + discrete_env.close()