Skip to content

Commit

Permalink
Update test_env_implementation.py
Browse files Browse the repository at this point in the history
  • Loading branch information
pseudo-rnd-thoughts authored Nov 18, 2024
1 parent 6aa8413 commit 17e16ac
Showing 1 changed file with 19 additions and 24 deletions.
43 changes: 19 additions & 24 deletions tests/envs/test_env_implementation.py
Original file line number Diff line number Diff line change
Expand Up @@ -346,27 +346,22 @@ def test_cartpole_vector_equiv():
envs.close()


def test_discrete_action_validation():
env_ids = ["CarRacing-v3", "LunarLander-v3"]

for env_id in env_ids:
# get continuous action
continuous_env = gym.make(env_id, continuous=True)
continuous_action = continuous_env.action_space.sample()
continuous_env.close()

# create discrete env
discrete_env = gym.make(env_id, continuous=False).unwrapped
discrete_env.reset()

# expect InvalidAction (caused by CarRacing) or AssertionError (caused by LunarLander)
try:
discrete_env.step(continuous_action)
except Exception as e:
assert isinstance(e, InvalidAction) or isinstance(e, AssertionError)

# expect no error
discrete_env.reset()
discrete_action = discrete_env.action_space.sample()
discrete_env.step(discrete_action)
discrete_env.close()
@pytest.mark.parametrize("env_id", ["CarRacing-v3", "LunarLander-v3"])
def test_discrete_action_validation(env_id):
# get continuous action
continuous_env = gym.make(env_id, continuous=True)
continuous_action = continuous_env.action_space.sample()
continuous_env.close()

# create discrete env
discrete_env = gym.make(env_id, continuous=False)
discrete_env.reset()

# expect InvalidAction (caused by CarRacing) or AssertionError (caused by LunarLander)
with pytest.raises((InvalidAction, AssertionError)):
discrete_env.step(continuous_action)

# expect no error
discrete_action = discrete_env.action_space.sample()
discrete_env.step(discrete_action)
discrete_env.close()

0 comments on commit 17e16ac

Please sign in to comment.