From fb5112fbcafdc65dcd6c65fcbf2f7920d84c6aa1 Mon Sep 17 00:00:00 2001 From: Mark Towers Date: Wed, 17 Apr 2024 14:55:46 +0100 Subject: [PATCH] Fix `make_vec` for sync or async and modifying make arguments (#1027) --- gymnasium/envs/registration.py | 4 +++- gymnasium/vector/sync_vector_env.py | 3 ++- tests/envs/registration/test_make_vec.py | 12 ++++++++++++ 3 files changed, 17 insertions(+), 2 deletions(-) diff --git a/gymnasium/envs/registration.py b/gymnasium/envs/registration.py index fc9176e7a..20d016b00 100644 --- a/gymnasium/envs/registration.py +++ b/gymnasium/envs/registration.py @@ -845,7 +845,7 @@ def make_vec( We refer to the Vector environment as the vectorizor while the environment being vectorized is the base or vectorized environment (``vectorizor(vectorized env)``). Args: - id: Name of the environment. Optionally, a module to import can be included, eg. 'module:Env-v0' + id: Name of the environment. Optionally, a module to import can be included, e.g. 'module:Env-v0' num_envs: Number of environments to create vectorization_mode: The vectorization method used, defaults to ``None`` such that if env id' spec has a ``vector_entry_point`` (not ``None``), this is first used otherwise defaults to ``sync`` to use the :class:`gymnasium.vector.SyncVectorEnv`. @@ -874,6 +874,8 @@ def make_vec( env_spec = copy.deepcopy(env_spec) env_spec_kwargs = env_spec.kwargs + # for sync or async, these parameters should be passed in `make(..., **kwargs)` rather than in the env spec kwargs, therefore, we `reset` the kwargs + env_spec.kwargs = dict() num_envs = env_spec_kwargs.pop("num_envs", num_envs) vectorization_mode = env_spec_kwargs.pop("vectorization_mode", vectorization_mode) diff --git a/gymnasium/vector/sync_vector_env.py b/gymnasium/vector/sync_vector_env.py index e59b2083c..ccecd05a2 100644 --- a/gymnasium/vector/sync_vector_env.py +++ b/gymnasium/vector/sync_vector_env.py @@ -260,7 +260,8 @@ def set_attr(self, name: str, values: list[Any] | tuple[Any, ...] | Any): def close_extras(self, **kwargs: Any): """Close the environments.""" - [env.close() for env in self.envs] + if hasattr(self, "envs"): + [env.close() for env in self.envs] def _check_spaces(self) -> bool: """Check that each of the environments obs and action spaces are equivalent to the single obs and action space.""" diff --git a/tests/envs/registration/test_make_vec.py b/tests/envs/registration/test_make_vec.py index e82a4551a..6872435ba 100644 --- a/tests/envs/registration/test_make_vec.py +++ b/tests/envs/registration/test_make_vec.py @@ -183,6 +183,9 @@ def test_make_vec_wrappers(): }, ), ("CartPole-v1", {"render_mode": "rgb_array"}), + ("CartPole-v1", {"vectorization_mode": "sync", "max_episode_steps": 5}), + ("CartPole-v1", {"sutton_barto_reward": True}), + ("CartPole-v1", {"vectorization_mode": "sync", "sutton_barto_reward": True}), (gym.spec("CartPole-v1"), {}), (gym.spec("CartPole-v1"), {"num_envs": 3}), (gym.spec("CartPole-v1"), {"vectorization_mode": "sync"}), @@ -199,6 +202,15 @@ def test_make_vec_wrappers(): }, ), (gym.spec("CartPole-v1"), {"render_mode": "rgb_array"}), + ( + gym.spec("CartPole-v1"), + {"vectorization_mode": "sync", "max_episode_steps": 5}, + ), + (gym.spec("CartPole-v1"), {"sutton_barto_reward": True}), + ( + gym.spec("CartPole-v1"), + {"vectorization_mode": "sync", "sutton_barto_reward": True}, + ), ), ) def test_make_vec_with_spec(env_id: str, kwargs: dict):