diff --git a/gymnasium/envs/box2d/car_racing.py b/gymnasium/envs/box2d/car_racing.py index fe3320cce..a989467e3 100644 --- a/gymnasium/envs/box2d/car_racing.py +++ b/gymnasium/envs/box2d/car_racing.py @@ -541,8 +541,8 @@ def reset( def step(self, action: Union[np.ndarray, int]): assert self.car is not None if action is not None: - action = action.astype(np.float64) if self.continuous: + action = action.astype(np.float64) self.car.steer(-action[0]) self.car.gas(action[1]) self.car.brake(action[2]) diff --git a/tests/envs/test_env_implementation.py b/tests/envs/test_env_implementation.py index bed8a8ab2..f9ae75328 100644 --- a/tests/envs/test_env_implementation.py +++ b/tests/envs/test_env_implementation.py @@ -8,6 +8,7 @@ from gymnasium.envs.box2d.lunar_lander import demo_heuristic_lander from gymnasium.envs.toy_text import CliffWalkingEnv, TaxiEnv from gymnasium.envs.toy_text.frozen_lake import generate_random_map +from gymnasium.error import InvalidAction def test_lunar_lander_heuristics(): @@ -343,3 +344,24 @@ def test_cartpole_vector_equiv(): env.close() envs.close() + + +@pytest.mark.parametrize("env_id", ["CarRacing-v3", "LunarLander-v3"]) +def test_discrete_action_validation(env_id): + # get continuous action + continuous_env = gym.make(env_id, continuous=True) + continuous_action = continuous_env.action_space.sample() + continuous_env.close() + + # create discrete env + discrete_env = gym.make(env_id, continuous=False) + discrete_env.reset() + + # expect InvalidAction (caused by CarRacing) or AssertionError (caused by LunarLander) + with pytest.raises((InvalidAction, AssertionError)): + discrete_env.step(continuous_action) + + # expect no error + discrete_action = discrete_env.action_space.sample() + discrete_env.step(discrete_action) + discrete_env.close()