diff --git a/gymnasium/wrappers/common.py b/gymnasium/wrappers/common.py index e66f6f7af..9b3d225bd 100644 --- a/gymnasium/wrappers/common.py +++ b/gymnasium/wrappers/common.py @@ -96,8 +96,11 @@ def __init__( Args: env: The environment to apply the wrapper - max_episode_steps: An optional max episode steps (if ``None``, ``env.spec.max_episode_steps`` is used) + max_episode_steps: the environment step after which the episode is truncated (``elapsed >= max_episode_steps``) """ + assert ( + isinstance(max_episode_steps, int) and max_episode_steps > 0 + ), f"Expect the `max_episode_steps` to be positive, actually: {max_episode_steps}" gym.utils.RecordConstructorArgs.__init__( self, max_episode_steps=max_episode_steps ) diff --git a/tests/wrappers/test_time_limit.py b/tests/wrappers/test_time_limit.py index 2d5261963..b5ac8e899 100644 --- a/tests/wrappers/test_time_limit.py +++ b/tests/wrappers/test_time_limit.py @@ -57,3 +57,24 @@ def patched_step(_action): _, _, terminated, truncated, _ = env.step(env.action_space.sample()) assert terminated is True assert truncated is True + + +def test_max_episode_steps(): + env = gym.make("CartPole-v1", disable_env_checker=True) + + assert env.spec.max_episode_steps == 500 + assert TimeLimit(env, max_episode_steps=10).spec.max_episode_steps == 10 + + with pytest.raises( + AssertionError, + match="Expect the `max_episode_steps` to be positive, actually: -1", + ): + TimeLimit(env, max_episode_steps=-1) + + with pytest.raises( + AssertionError, + match="Expect the `max_episode_steps` to be positive, actually: None", + ): + TimeLimit(env, max_episode_steps=None) + + env.close()