Skip to content

Commit

Permalink
Add VectorizeMode enum for make_vec(..., vectorization_mode) (#767)
Browse files Browse the repository at this point in the history
  • Loading branch information
pseudo-rnd-thoughts authored Dec 3, 2023
1 parent e9c66e4 commit b57b913
Show file tree
Hide file tree
Showing 3 changed files with 74 additions and 18 deletions.
2 changes: 2 additions & 0 deletions gymnasium/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
registry,
pprint_registry,
make_vec,
VectorizeMode,
register_envs,
)

Expand All @@ -38,6 +39,7 @@
"spec",
"register",
"registry",
"VectorizeMode",
"pprint_registry",
"register_envs",
# module folders
Expand Down
44 changes: 31 additions & 13 deletions gymnasium/envs/registration.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import sys
from collections import defaultdict
from dataclasses import dataclass, field
from enum import Enum
from types import ModuleType
from typing import Any, Callable, Iterable, Sequence

Expand All @@ -37,6 +38,7 @@
"current_namespace",
"EnvSpec",
"WrapperSpec",
"VectorizeMode",
# Functions
"register",
"make",
Expand All @@ -57,7 +59,7 @@ def __call__(self, **kwargs: Any) -> Env:
class VectorEnvCreator(Protocol):
"""Function type expected for an environment."""

def __call__(self, **kwargs: Any) -> gym.experimental.vector.VectorEnv:
def __call__(self, **kwargs: Any) -> gym.vector.VectorEnv:
...


Expand Down Expand Up @@ -249,6 +251,14 @@ def pprint(
print(output)


class VectorizeMode(Enum):
"""All possible vectorization modes used in `make_vec`."""

ASYNC = "async"
SYNC = "sync"
VECTOR_ENTRY_POINT = "vector_entry_point"


# Global registry of environments. Meant to be accessed through `register` and `make`
registry: dict[str, EnvSpec] = {}
current_namespace: str | None = None
Expand Down Expand Up @@ -809,7 +819,7 @@ def make(
def make_vec(
id: str | EnvSpec,
num_envs: int = 1,
vectorization_mode: str | None = None,
vectorization_mode: VectorizeMode | str | None = None,
vector_kwargs: dict[str, Any] | None = None,
wrappers: Sequence[Callable[[Env], Wrapper]] | None = None,
**kwargs,
Expand All @@ -822,9 +832,9 @@ def make_vec(
Args:
id: Name of the environment. Optionally, a module to import can be included, eg. 'module:Env-v0'
num_envs: Number of environments to create
vectorization_mode: The vectorization method used, defaults to ``None`` such that if a ``vector_entry_point`` exists,
vectorization_mode: The vectorization method used, defaults to ``None`` such that if env id' spec has a ``vector_entry_point`` (not ``None``),
this is first used otherwise defaults to ``sync`` to use the :class:`gymnasium.vector.SyncVectorEnv`.
Valid modes are ``"async"``, ``"sync"`` or ``"vector_entry_point"``.
Valid modes are ``"async"``, ``"sync"`` or ``"vector_entry_point"``. Recommended to use the :class:`VectorizeMode` enum rather than strings.
vector_kwargs: Additional arguments to pass to the vectorizor environment constructor, i.e., ``SyncVectorEnv(..., **vector_kwargs)``.
wrappers: A sequence of wrapper functions to apply to the base environment. Can only be used in ``"sync"`` or ``"async"`` mode.
**kwargs: Additional arguments passed to the base environment constructor.
Expand Down Expand Up @@ -856,12 +866,21 @@ def make_vec(

env_spec_kwargs.update(kwargs)

# Update the vectorization_mode if None
# Specify the vectorization mode if None or update to a `VectorizeMode`
if vectorization_mode is None:
if id_env_spec.vector_entry_point is not None:
vectorization_mode = "vector_entry_point"
vectorization_mode = VectorizeMode.VECTOR_ENTRY_POINT
else:
vectorization_mode = "sync"
vectorization_mode = VectorizeMode.SYNC
else:
try:
vectorization_mode = VectorizeMode(vectorization_mode)
except ValueError:
raise ValueError(
f"Invalid vectorization mode: {vectorization_mode!r}, "
f"valid modes: {[mode.value for mode in VectorizeMode]}"
)
assert isinstance(vectorization_mode, VectorizeMode)

def create_single_env() -> Env:
single_env = make(id_env_spec.id, **env_spec_kwargs.copy())
Expand All @@ -870,7 +889,7 @@ def create_single_env() -> Env:
single_env = wrapper(single_env)
return single_env

if vectorization_mode == "sync":
if vectorization_mode == VectorizeMode.SYNC:
if id_env_spec.entry_point is None:
raise error.Error(
f"Cannot create vectorized environment for {id_env_spec.id} because it doesn't have an entry point defined."
Expand All @@ -880,7 +899,7 @@ def create_single_env() -> Env:
env_fns=(create_single_env for _ in range(num_envs)),
**vector_kwargs,
)
elif vectorization_mode == "async":
elif vectorization_mode == VectorizeMode.ASYNC:
if id_env_spec.entry_point is None:
raise error.Error(
f"Cannot create vectorized environment for {id_env_spec.id} because it doesn't have an entry point defined."
Expand All @@ -890,7 +909,7 @@ def create_single_env() -> Env:
env_fns=[create_single_env for _ in range(num_envs)],
**vector_kwargs,
)
elif vectorization_mode == "vector_entry_point":
elif vectorization_mode == VectorizeMode.VECTOR_ENTRY_POINT:
entry_point = id_env_spec.vector_entry_point
if entry_point is None:
raise error.Error(
Expand All @@ -910,15 +929,14 @@ def create_single_env() -> Env:

env = env_creator(num_envs=num_envs, **vector_kwargs)
else:
raise error.Error(f"Invalid vectorization mode: {vectorization_mode}")
raise error.Error(f"Unknown vectorization mode: {vectorization_mode}")

# Copies the environment creation specification and kwargs to add to the environment specification details
copied_id_spec = copy.deepcopy(id_env_spec)
copied_id_spec.kwargs = env_spec_kwargs
if num_envs != 1:
copied_id_spec.kwargs["num_envs"] = num_envs
if vectorization_mode != "async":
copied_id_spec.kwargs["vectorization_mode"] = vectorization_mode
copied_id_spec.kwargs["vectorization_mode"] = vectorization_mode.value
if vector_kwargs is not None:
copied_id_spec.kwargs["vector_kwargs"] = vector_kwargs
if wrappers is not None:
Expand Down
46 changes: 41 additions & 5 deletions tests/envs/registration/test_make_vec.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import pytest

import gymnasium as gym
from gymnasium import VectorizeMode
from gymnasium.envs.classic_control import CartPoleEnv
from gymnasium.envs.classic_control.cartpole import CartPoleVectorEnv
from gymnasium.vector import AsyncVectorEnv, SyncVectorEnv
Expand Down Expand Up @@ -45,11 +46,17 @@ def test_make_vec_vectorization_mode():
assert isinstance(env, SyncVectorEnv)
env.close()

# Test `vector_entry_point`
# Test `vector_entry_point` for env specs with and without it
env = gym.make_vec("CartPole-v1", vectorization_mode="vector_entry_point")
assert isinstance(env, CartPoleVectorEnv)
env.close()

env = gym.make_vec(
"CartPole-v1", vectorization_mode=VectorizeMode.VECTOR_ENTRY_POINT
)
assert isinstance(env, CartPoleVectorEnv)
env.close()

with pytest.raises(
gym.error.Error,
match=re.escape(
Expand All @@ -58,12 +65,28 @@ def test_make_vec_vectorization_mode():
):
gym.make_vec("Pendulum-v1", vectorization_mode="vector_entry_point")

# Test `async`
# Test `async` and `sync`
env = gym.make_vec("CartPole-v1", vectorization_mode="async")
assert isinstance(env, AsyncVectorEnv)
env.close()

env = gym.make_vec("CartPole-v1", vectorization_mode=VectorizeMode.ASYNC)
assert isinstance(env, AsyncVectorEnv)
env.close()

env = gym.make_vec("CartPole-v1", vectorization_mode="sync")
assert isinstance(env, SyncVectorEnv)
env.close()

env = gym.make_vec("CartPole-v1", vectorization_mode=VectorizeMode.SYNC)
assert isinstance(env, SyncVectorEnv)
env.close()

# Test environment with only a vector entry point and no entry point
gym.register("VecOnlyEnv-v0", vector_entry_point=CartPoleVectorEnv)
env_spec = gym.spec("VecOnlyEnv-v0")
assert env_spec.entry_point is None and env_spec.vector_entry_point is not None

with pytest.raises(
gym.error.Error,
match=re.escape(
Expand All @@ -73,9 +96,22 @@ def test_make_vec_vectorization_mode():
gym.make_vec("VecOnlyEnv-v0", vectorization_mode="async")
del gym.registry["VecOnlyEnv-v0"]

env = gym.make_vec("CartPole-v1", vectorization_mode="sync")
assert isinstance(env, SyncVectorEnv)
env.close()
# Test with invalid vectorization mode
with pytest.raises(
ValueError,
match=re.escape(
"Invalid vectorization mode: 'invalid', valid modes: ['async', 'sync', 'vector_entry_point']"
),
):
gym.make_vec("CartPole-v1", vectorization_mode="invalid")

with pytest.raises(
ValueError,
match=re.escape(
"Invalid vectorization mode: 123, valid modes: ['async', 'sync', 'vector_entry_point']"
),
):
gym.make_vec("CartPole-v1", vectorization_mode=123)


def test_make_vec_wrappers():
Expand Down

0 comments on commit b57b913

Please sign in to comment.