diff --git a/gymnasium/__init__.py b/gymnasium/__init__.py index 106f7c4b9..8f1e28ad4 100644 --- a/gymnasium/__init__.py +++ b/gymnasium/__init__.py @@ -16,6 +16,7 @@ registry, pprint_registry, make_vec, + VectorizeMode, register_envs, ) @@ -38,6 +39,7 @@ "spec", "register", "registry", + "VectorizeMode", "pprint_registry", "register_envs", # module folders diff --git a/gymnasium/envs/registration.py b/gymnasium/envs/registration.py index 515c097e4..b968831d9 100644 --- a/gymnasium/envs/registration.py +++ b/gymnasium/envs/registration.py @@ -12,6 +12,7 @@ import sys from collections import defaultdict from dataclasses import dataclass, field +from enum import Enum from types import ModuleType from typing import Any, Callable, Iterable, Sequence @@ -37,6 +38,7 @@ "current_namespace", "EnvSpec", "WrapperSpec", + "VectorizeMode", # Functions "register", "make", @@ -57,7 +59,7 @@ def __call__(self, **kwargs: Any) -> Env: class VectorEnvCreator(Protocol): """Function type expected for an environment.""" - def __call__(self, **kwargs: Any) -> gym.experimental.vector.VectorEnv: + def __call__(self, **kwargs: Any) -> gym.vector.VectorEnv: ... @@ -249,6 +251,14 @@ def pprint( print(output) +class VectorizeMode(Enum): + """All possible vectorization modes used in `make_vec`.""" + + ASYNC = "async" + SYNC = "sync" + VECTOR_ENTRY_POINT = "vector_entry_point" + + # Global registry of environments. Meant to be accessed through `register` and `make` registry: dict[str, EnvSpec] = {} current_namespace: str | None = None @@ -809,7 +819,7 @@ def make( def make_vec( id: str | EnvSpec, num_envs: int = 1, - vectorization_mode: str | None = None, + vectorization_mode: VectorizeMode | str | None = None, vector_kwargs: dict[str, Any] | None = None, wrappers: Sequence[Callable[[Env], Wrapper]] | None = None, **kwargs, @@ -822,9 +832,9 @@ def make_vec( Args: id: Name of the environment. Optionally, a module to import can be included, eg. 
'module:Env-v0' num_envs: Number of environments to create - vectorization_mode: The vectorization method used, defaults to ``None`` such that if a ``vector_entry_point`` exists, + vectorization_mode: The vectorization method used, defaults to ``None`` such that if the env id's spec has a ``vector_entry_point`` (not ``None``), this is used first, otherwise it defaults to ``sync`` to use the :class:`gymnasium.vector.SyncVectorEnv`. - Valid modes are ``"async"``, ``"sync"`` or ``"vector_entry_point"``. + Valid modes are ``"async"``, ``"sync"`` or ``"vector_entry_point"``. Recommended to use the :class:`VectorizeMode` enum rather than strings. vector_kwargs: Additional arguments to pass to the vectorizor environment constructor, i.e., ``SyncVectorEnv(..., **vector_kwargs)``. wrappers: A sequence of wrapper functions to apply to the base environment. Can only be used in ``"sync"`` or ``"async"`` mode. **kwargs: Additional arguments passed to the base environment constructor. @@ -856,12 +866,21 @@ def make_vec( env_spec_kwargs.update(kwargs) - # Update the vectorization_mode if None + # Specify the vectorization mode if None or update to a `VectorizeMode` if vectorization_mode is None: if id_env_spec.vector_entry_point is not None: - vectorization_mode = "vector_entry_point" + vectorization_mode = VectorizeMode.VECTOR_ENTRY_POINT else: - vectorization_mode = "sync" + vectorization_mode = VectorizeMode.SYNC + else: + try: + vectorization_mode = VectorizeMode(vectorization_mode) + except ValueError: + raise ValueError( + f"Invalid vectorization mode: {vectorization_mode!r}, " + f"valid modes: {[mode.value for mode in VectorizeMode]}" + ) + assert isinstance(vectorization_mode, VectorizeMode) def create_single_env() -> Env: single_env = make(id_env_spec.id, **env_spec_kwargs.copy()) @@ -870,7 +889,7 @@ def create_single_env() -> Env: single_env = wrapper(single_env) return single_env - if vectorization_mode == "sync": + if vectorization_mode == VectorizeMode.SYNC: if
id_env_spec.entry_point is None: raise error.Error( f"Cannot create vectorized environment for {id_env_spec.id} because it doesn't have an entry point defined." @@ -880,7 +899,7 @@ def create_single_env() -> Env: env_fns=(create_single_env for _ in range(num_envs)), **vector_kwargs, ) - elif vectorization_mode == "async": + elif vectorization_mode == VectorizeMode.ASYNC: if id_env_spec.entry_point is None: raise error.Error( f"Cannot create vectorized environment for {id_env_spec.id} because it doesn't have an entry point defined." @@ -890,7 +909,7 @@ def create_single_env() -> Env: env_fns=[create_single_env for _ in range(num_envs)], **vector_kwargs, ) - elif vectorization_mode == "vector_entry_point": + elif vectorization_mode == VectorizeMode.VECTOR_ENTRY_POINT: entry_point = id_env_spec.vector_entry_point if entry_point is None: raise error.Error( @@ -910,15 +929,14 @@ def create_single_env() -> Env: env = env_creator(num_envs=num_envs, **vector_kwargs) else: - raise error.Error(f"Invalid vectorization mode: {vectorization_mode}") + raise error.Error(f"Unknown vectorization mode: {vectorization_mode}") # Copies the environment creation specification and kwargs to add to the environment specification details copied_id_spec = copy.deepcopy(id_env_spec) copied_id_spec.kwargs = env_spec_kwargs if num_envs != 1: copied_id_spec.kwargs["num_envs"] = num_envs - if vectorization_mode != "async": - copied_id_spec.kwargs["vectorization_mode"] = vectorization_mode + copied_id_spec.kwargs["vectorization_mode"] = vectorization_mode.value if vector_kwargs is not None: copied_id_spec.kwargs["vector_kwargs"] = vector_kwargs if wrappers is not None: diff --git a/tests/envs/registration/test_make_vec.py b/tests/envs/registration/test_make_vec.py index 866c4f755..96660e9b9 100644 --- a/tests/envs/registration/test_make_vec.py +++ b/tests/envs/registration/test_make_vec.py @@ -4,6 +4,7 @@ import pytest import gymnasium as gym +from gymnasium import VectorizeMode from 
gymnasium.envs.classic_control import CartPoleEnv from gymnasium.envs.classic_control.cartpole import CartPoleVectorEnv from gymnasium.vector import AsyncVectorEnv, SyncVectorEnv @@ -45,11 +46,17 @@ def test_make_vec_vectorization_mode(): assert isinstance(env, SyncVectorEnv) env.close() - # Test `vector_entry_point` + # Test `vector_entry_point` for env specs with and without it env = gym.make_vec("CartPole-v1", vectorization_mode="vector_entry_point") assert isinstance(env, CartPoleVectorEnv) env.close() + env = gym.make_vec( + "CartPole-v1", vectorization_mode=VectorizeMode.VECTOR_ENTRY_POINT + ) + assert isinstance(env, CartPoleVectorEnv) + env.close() + with pytest.raises( gym.error.Error, match=re.escape( @@ -58,12 +65,28 @@ def test_make_vec_vectorization_mode(): ): gym.make_vec("Pendulum-v1", vectorization_mode="vector_entry_point") - # Test `async` + # Test `async` and `sync` env = gym.make_vec("CartPole-v1", vectorization_mode="async") assert isinstance(env, AsyncVectorEnv) env.close() + env = gym.make_vec("CartPole-v1", vectorization_mode=VectorizeMode.ASYNC) + assert isinstance(env, AsyncVectorEnv) + env.close() + + env = gym.make_vec("CartPole-v1", vectorization_mode="sync") + assert isinstance(env, SyncVectorEnv) + env.close() + + env = gym.make_vec("CartPole-v1", vectorization_mode=VectorizeMode.SYNC) + assert isinstance(env, SyncVectorEnv) + env.close() + + # Test environment with only a vector entry point and no entry point gym.register("VecOnlyEnv-v0", vector_entry_point=CartPoleVectorEnv) + env_spec = gym.spec("VecOnlyEnv-v0") + assert env_spec.entry_point is None and env_spec.vector_entry_point is not None + with pytest.raises( gym.error.Error, match=re.escape( @@ -73,9 +96,22 @@ def test_make_vec_vectorization_mode(): gym.make_vec("VecOnlyEnv-v0", vectorization_mode="async") del gym.registry["VecOnlyEnv-v0"] - env = gym.make_vec("CartPole-v1", vectorization_mode="sync") - assert isinstance(env, SyncVectorEnv) - env.close() + # Test with invalid 
vectorization mode + with pytest.raises( + ValueError, + match=re.escape( + "Invalid vectorization mode: 'invalid', valid modes: ['async', 'sync', 'vector_entry_point']" + ), + ): + gym.make_vec("CartPole-v1", vectorization_mode="invalid") + + with pytest.raises( + ValueError, + match=re.escape( + "Invalid vectorization mode: 123, valid modes: ['async', 'sync', 'vector_entry_point']" + ), + ): + gym.make_vec("CartPole-v1", vectorization_mode=123) def test_make_vec_wrappers():