diff --git a/tests/envs/test_envs.py b/tests/envs/test_envs.py index 0855871a0..4ae2255b9 100644 --- a/tests/envs/test_envs.py +++ b/tests/envs/test_envs.py @@ -149,11 +149,17 @@ def test_pickle_env(env: gym.Env): if env.metadata.get("jax", False): env = gym.wrappers.JaxToNumpy(env) + action = env.action_space.sample() + + env_reset = env.reset(seed=123) + env_step = env.step(action) + pickled_env = pickle.loads(pickle.dumps(env)) + pickle_reset = pickled_env.reset(seed=123) + pickle_step = pickled_env.step(action) - data_equivalence(env.reset(), pickled_env.reset()) + assert data_equivalence(env_reset, pickle_reset) + assert data_equivalence(env_step, pickle_step) - action = env.action_space.sample() - data_equivalence(env.step(action), pickled_env.step(action)) env.close() pickled_env.close()