Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add VectorizeMode enum for make_vec(..., vectorization_mode) #767

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions gymnasium/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
registry,
pprint_registry,
make_vec,
VectorizeMode,
register_envs,
)

Expand All @@ -38,6 +39,7 @@
"spec",
"register",
"registry",
"VectorizeMode",
"pprint_registry",
"register_envs",
# module folders
Expand Down
44 changes: 31 additions & 13 deletions gymnasium/envs/registration.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import sys
from collections import defaultdict
from dataclasses import dataclass, field
from enum import Enum
from types import ModuleType
from typing import Any, Callable, Iterable, Sequence

Expand All @@ -37,6 +38,7 @@
"current_namespace",
"EnvSpec",
"WrapperSpec",
"VectorizeMode",
# Functions
"register",
"make",
Expand All @@ -57,7 +59,7 @@ def __call__(self, **kwargs: Any) -> Env:
class VectorEnvCreator(Protocol):
"""Function type expected for an environment."""

def __call__(self, **kwargs: Any) -> gym.experimental.vector.VectorEnv:
def __call__(self, **kwargs: Any) -> gym.vector.VectorEnv:
...


Expand Down Expand Up @@ -249,6 +251,14 @@ def pprint(
print(output)


class VectorizeMode(Enum):
"""All possible vectorization modes used in `make_vec`."""

ASYNC = "async"
SYNC = "sync"
VECTOR_ENTRY_POINT = "vector_entry_point"


# Global registry of environments. Meant to be accessed through `register` and `make`
registry: dict[str, EnvSpec] = {}
current_namespace: str | None = None
Expand Down Expand Up @@ -809,7 +819,7 @@ def make(
def make_vec(
id: str | EnvSpec,
num_envs: int = 1,
vectorization_mode: str | None = None,
vectorization_mode: VectorizeMode | str | None = None,
vector_kwargs: dict[str, Any] | None = None,
wrappers: Sequence[Callable[[Env], Wrapper]] | None = None,
**kwargs,
Expand All @@ -822,9 +832,9 @@ def make_vec(
Args:
id: Name of the environment. Optionally, a module to import can be included, eg. 'module:Env-v0'
num_envs: Number of environments to create
vectorization_mode: The vectorization method used, defaults to ``None`` such that if a ``vector_entry_point`` exists,
vectorization_mode: The vectorization method used, defaults to ``None`` such that if env id' spec has a ``vector_entry_point`` (not ``None``),
this is first used otherwise defaults to ``sync`` to use the :class:`gymnasium.vector.SyncVectorEnv`.
Valid modes are ``"async"``, ``"sync"`` or ``"vector_entry_point"``.
Valid modes are ``"async"``, ``"sync"`` or ``"vector_entry_point"``. Recommended to use the :class:`VectorizeMode` enum rather than strings.
vector_kwargs: Additional arguments to pass to the vectorizor environment constructor, i.e., ``SyncVectorEnv(..., **vector_kwargs)``.
wrappers: A sequence of wrapper functions to apply to the base environment. Can only be used in ``"sync"`` or ``"async"`` mode.
**kwargs: Additional arguments passed to the base environment constructor.
Expand Down Expand Up @@ -856,12 +866,21 @@ def make_vec(

env_spec_kwargs.update(kwargs)

# Update the vectorization_mode if None
# Specify the vectorization mode if None or update to a `VectorizeMode`
if vectorization_mode is None:
if id_env_spec.vector_entry_point is not None:
vectorization_mode = "vector_entry_point"
vectorization_mode = VectorizeMode.VECTOR_ENTRY_POINT
else:
vectorization_mode = "sync"
vectorization_mode = VectorizeMode.SYNC
else:
try:
vectorization_mode = VectorizeMode(vectorization_mode)
except ValueError:
raise ValueError(
f"Invalid vectorization mode: {vectorization_mode!r}, "
f"valid modes: {[mode.value for mode in VectorizeMode]}"
)
assert isinstance(vectorization_mode, VectorizeMode)

def create_single_env() -> Env:
single_env = make(id_env_spec.id, **env_spec_kwargs.copy())
Expand All @@ -870,7 +889,7 @@ def create_single_env() -> Env:
single_env = wrapper(single_env)
return single_env

if vectorization_mode == "sync":
if vectorization_mode == VectorizeMode.SYNC:
if id_env_spec.entry_point is None:
raise error.Error(
f"Cannot create vectorized environment for {id_env_spec.id} because it doesn't have an entry point defined."
Expand All @@ -880,7 +899,7 @@ def create_single_env() -> Env:
env_fns=(create_single_env for _ in range(num_envs)),
**vector_kwargs,
)
elif vectorization_mode == "async":
elif vectorization_mode == VectorizeMode.ASYNC:
if id_env_spec.entry_point is None:
raise error.Error(
f"Cannot create vectorized environment for {id_env_spec.id} because it doesn't have an entry point defined."
Expand All @@ -890,7 +909,7 @@ def create_single_env() -> Env:
env_fns=[create_single_env for _ in range(num_envs)],
**vector_kwargs,
)
elif vectorization_mode == "vector_entry_point":
elif vectorization_mode == VectorizeMode.VECTOR_ENTRY_POINT:
entry_point = id_env_spec.vector_entry_point
if entry_point is None:
raise error.Error(
Expand All @@ -910,15 +929,14 @@ def create_single_env() -> Env:

env = env_creator(num_envs=num_envs, **vector_kwargs)
else:
raise error.Error(f"Invalid vectorization mode: {vectorization_mode}")
raise error.Error(f"Unknown vectorization mode: {vectorization_mode}")

# Copies the environment creation specification and kwargs to add to the environment specification details
copied_id_spec = copy.deepcopy(id_env_spec)
copied_id_spec.kwargs = env_spec_kwargs
if num_envs != 1:
copied_id_spec.kwargs["num_envs"] = num_envs
if vectorization_mode != "async":
copied_id_spec.kwargs["vectorization_mode"] = vectorization_mode
copied_id_spec.kwargs["vectorization_mode"] = vectorization_mode.value
if vector_kwargs is not None:
copied_id_spec.kwargs["vector_kwargs"] = vector_kwargs
if wrappers is not None:
Expand Down
46 changes: 41 additions & 5 deletions tests/envs/registration/test_make_vec.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import pytest

import gymnasium as gym
from gymnasium import VectorizeMode
from gymnasium.envs.classic_control import CartPoleEnv
from gymnasium.envs.classic_control.cartpole import CartPoleVectorEnv
from gymnasium.vector import AsyncVectorEnv, SyncVectorEnv
Expand Down Expand Up @@ -45,11 +46,17 @@ def test_make_vec_vectorization_mode():
assert isinstance(env, SyncVectorEnv)
env.close()

# Test `vector_entry_point`
# Test `vector_entry_point` for env specs with and without it
env = gym.make_vec("CartPole-v1", vectorization_mode="vector_entry_point")
assert isinstance(env, CartPoleVectorEnv)
env.close()

env = gym.make_vec(
"CartPole-v1", vectorization_mode=VectorizeMode.VECTOR_ENTRY_POINT
)
assert isinstance(env, CartPoleVectorEnv)
env.close()

with pytest.raises(
gym.error.Error,
match=re.escape(
Expand All @@ -58,12 +65,28 @@ def test_make_vec_vectorization_mode():
):
gym.make_vec("Pendulum-v1", vectorization_mode="vector_entry_point")

# Test `async`
# Test `async` and `sync`
env = gym.make_vec("CartPole-v1", vectorization_mode="async")
assert isinstance(env, AsyncVectorEnv)
env.close()

env = gym.make_vec("CartPole-v1", vectorization_mode=VectorizeMode.ASYNC)
assert isinstance(env, AsyncVectorEnv)
env.close()

env = gym.make_vec("CartPole-v1", vectorization_mode="sync")
assert isinstance(env, SyncVectorEnv)
env.close()

env = gym.make_vec("CartPole-v1", vectorization_mode=VectorizeMode.SYNC)
assert isinstance(env, SyncVectorEnv)
env.close()

# Test environment with only a vector entry point and no entry point
gym.register("VecOnlyEnv-v0", vector_entry_point=CartPoleVectorEnv)
env_spec = gym.spec("VecOnlyEnv-v0")
assert env_spec.entry_point is None and env_spec.vector_entry_point is not None

with pytest.raises(
gym.error.Error,
match=re.escape(
Expand All @@ -73,9 +96,22 @@ def test_make_vec_vectorization_mode():
gym.make_vec("VecOnlyEnv-v0", vectorization_mode="async")
del gym.registry["VecOnlyEnv-v0"]

env = gym.make_vec("CartPole-v1", vectorization_mode="sync")
assert isinstance(env, SyncVectorEnv)
env.close()
# Test with invalid vectorization mode
with pytest.raises(
ValueError,
match=re.escape(
"Invalid vectorization mode: 'invalid', valid modes: ['async', 'sync', 'vector_entry_point']"
),
):
gym.make_vec("CartPole-v1", vectorization_mode="invalid")

with pytest.raises(
ValueError,
match=re.escape(
"Invalid vectorization mode: 123, valid modes: ['async', 'sync', 'vector_entry_point']"
),
):
gym.make_vec("CartPole-v1", vectorization_mode=123)


def test_make_vec_wrappers():
Expand Down
Loading