Skip to content

Commit

Permalink
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge branch 'main' into add-more-introductory-pages
Browse files Browse the repository at this point in the history
# Conflicts:
#	gymnasium/wrappers/vector/common.py
pseudo-rnd-thoughts committed Dec 5, 2023
2 parents 6cc7f19 + b57b913 commit c1543ff
Showing 28 changed files with 807 additions and 575 deletions.
2 changes: 1 addition & 1 deletion docs/environments/third_party_environments.md
Original file line number Diff line number Diff line change
@@ -17,7 +17,7 @@ many-agent RL ([MAgent2](https://magent2.farama.org/)),

*This page contains environments which are not maintained by Farama Foundation and, as such, cannot be guaranteed to function as intended.*

*If you'd like to contribute an environment, please reach out on [Discord](https://discord.gg/nHg2JRN489).*
*If you'd like to contribute an environment, please reach out on [Discord](https://discord.gg/bnJ6kubTg6).*

### [CARL: context adaptive RL](https://github.com/automl/CARL)

2 changes: 2 additions & 0 deletions gymnasium/__init__.py
Original file line number Diff line number Diff line change
@@ -16,6 +16,7 @@
registry,
pprint_registry,
make_vec,
VectorizeMode,
register_envs,
)

@@ -38,6 +39,7 @@
"spec",
"register",
"registry",
"VectorizeMode",
"pprint_registry",
"register_envs",
# module folders
37 changes: 10 additions & 27 deletions gymnasium/envs/functional_jax_env.py
Original file line number Diff line number Diff line change
@@ -151,6 +151,8 @@ def __init__(

self.steps = jnp.zeros(self.num_envs, dtype=jnp.int32)

self.autoreset_envs = jnp.zeros(self.num_envs, dtype=jnp.bool_)

self._is_box_action_space = isinstance(self.action_space, gym.spaces.Box)

if self.render_mode == "rgb_array":
@@ -214,10 +216,9 @@ def step(self, action: ActType):
info = self.func_env.transition_info(self.state, action, next_state)

done = jnp.logical_or(terminated, truncated)
if jnp.any(done):
final_obs = self.func_env.observation(next_state)

to_reset = jnp.where(done)[0]
if jnp.any(self.autoreset_envs):
to_reset = jnp.where(self.autoreset_envs)[0]
reset_count = to_reset.shape[0]

rng, self.rng = jrng.split(self.rng)
@@ -228,34 +229,16 @@ def step(self, action: ActType):
next_state = self.state.at[to_reset].set(new_initials)
self.steps = self.steps.at[to_reset].set(0)

# Get the final observations and infos
info["final_observation"] = np.array([None for _ in range(self.num_envs)])
info["final_info"] = np.array([None for _ in range(self.num_envs)])

info["_final_observation"] = np.array([False for _ in range(self.num_envs)])
info["_final_info"] = np.array([False for _ in range(self.num_envs)])

# TODO: this can maybe be optimized, but right now I don't know how
for i in to_reset:
info["final_observation"][i] = final_obs[i]
info["final_info"][i] = {
k: v[i]
for k, v in info.items()
if k
not in {
"final_observation",
"final_info",
"_final_observation",
"_final_info",
}
}

info["_final_observation"][i] = True
info["_final_info"][i] = True
self.autoreset_envs = done

observation = self.func_env.observation(next_state)
observation = jax_to_numpy(observation)

reward = jax_to_numpy(reward)

terminated = jax_to_numpy(terminated)
truncated = jax_to_numpy(truncated)

self.state = next_state

return observation, reward, terminated, truncated, info
44 changes: 31 additions & 13 deletions gymnasium/envs/registration.py
Original file line number Diff line number Diff line change
@@ -12,6 +12,7 @@
import sys
from collections import defaultdict
from dataclasses import dataclass, field
from enum import Enum
from types import ModuleType
from typing import Any, Callable, Iterable, Sequence

@@ -37,6 +38,7 @@
"current_namespace",
"EnvSpec",
"WrapperSpec",
"VectorizeMode",
# Functions
"register",
"make",
@@ -57,7 +59,7 @@ def __call__(self, **kwargs: Any) -> Env:
class VectorEnvCreator(Protocol):
"""Function type expected for an environment."""

def __call__(self, **kwargs: Any) -> gym.experimental.vector.VectorEnv:
def __call__(self, **kwargs: Any) -> gym.vector.VectorEnv:
...


@@ -249,6 +251,14 @@ def pprint(
print(output)


class VectorizeMode(Enum):
"""All possible vectorization modes used in `make_vec`.

Each member's value is the plain string that `make_vec` also accepts for its
``vectorization_mode`` argument; `make_vec` converts such strings with
``VectorizeMode(value)`` and stores ``.value`` back into the spec kwargs.
"""

# Selects the ``async`` branch of `make_vec`.
ASYNC = "async"
# Selects the ``sync`` branch of `make_vec`; also the fallback when no mode is
# given and the env spec has no ``vector_entry_point``.
SYNC = "sync"
# Builds the vector environment from the env spec's ``vector_entry_point``;
# chosen automatically when no mode is given and that entry point is set.
VECTOR_ENTRY_POINT = "vector_entry_point"


# Global registry of environments. Meant to be accessed through `register` and `make`
registry: dict[str, EnvSpec] = {}
current_namespace: str | None = None
@@ -809,7 +819,7 @@ def make(
def make_vec(
id: str | EnvSpec,
num_envs: int = 1,
vectorization_mode: str | None = None,
vectorization_mode: VectorizeMode | str | None = None,
vector_kwargs: dict[str, Any] | None = None,
wrappers: Sequence[Callable[[Env], Wrapper]] | None = None,
**kwargs,
@@ -822,9 +832,9 @@ def make_vec(
Args:
id: Name of the environment. Optionally, a module to import can be included, eg. 'module:Env-v0'
num_envs: Number of environments to create
vectorization_mode: The vectorization method used, defaults to ``None`` such that if a ``vector_entry_point`` exists,
vectorization_mode: The vectorization method used, defaults to ``None`` such that if the env id's spec has a ``vector_entry_point`` (not ``None``),
this is used first, otherwise it defaults to ``sync`` to use the :class:`gymnasium.vector.SyncVectorEnv`.
Valid modes are ``"async"``, ``"sync"`` or ``"vector_entry_point"``.
Valid modes are ``"async"``, ``"sync"`` or ``"vector_entry_point"``. Recommended to use the :class:`VectorizeMode` enum rather than strings.
vector_kwargs: Additional arguments to pass to the vectorizer environment constructor, i.e., ``SyncVectorEnv(..., **vector_kwargs)``.
wrappers: A sequence of wrapper functions to apply to the base environment. Can only be used in ``"sync"`` or ``"async"`` mode.
**kwargs: Additional arguments passed to the base environment constructor.
@@ -856,12 +866,21 @@ def make_vec(

env_spec_kwargs.update(kwargs)

# Update the vectorization_mode if None
# Specify the vectorization mode if None or update to a `VectorizeMode`
if vectorization_mode is None:
if id_env_spec.vector_entry_point is not None:
vectorization_mode = "vector_entry_point"
vectorization_mode = VectorizeMode.VECTOR_ENTRY_POINT
else:
vectorization_mode = "sync"
vectorization_mode = VectorizeMode.SYNC
else:
try:
vectorization_mode = VectorizeMode(vectorization_mode)
except ValueError:
raise ValueError(
f"Invalid vectorization mode: {vectorization_mode!r}, "
f"valid modes: {[mode.value for mode in VectorizeMode]}"
)
assert isinstance(vectorization_mode, VectorizeMode)

def create_single_env() -> Env:
single_env = make(id_env_spec.id, **env_spec_kwargs.copy())
@@ -870,7 +889,7 @@ def create_single_env() -> Env:
single_env = wrapper(single_env)
return single_env

if vectorization_mode == "sync":
if vectorization_mode == VectorizeMode.SYNC:
if id_env_spec.entry_point is None:
raise error.Error(
f"Cannot create vectorized environment for {id_env_spec.id} because it doesn't have an entry point defined."
@@ -880,7 +899,7 @@ def create_single_env() -> Env:
env_fns=(create_single_env for _ in range(num_envs)),
**vector_kwargs,
)
elif vectorization_mode == "async":
elif vectorization_mode == VectorizeMode.ASYNC:
if id_env_spec.entry_point is None:
raise error.Error(
f"Cannot create vectorized environment for {id_env_spec.id} because it doesn't have an entry point defined."
@@ -890,7 +909,7 @@ def create_single_env() -> Env:
env_fns=[create_single_env for _ in range(num_envs)],
**vector_kwargs,
)
elif vectorization_mode == "vector_entry_point":
elif vectorization_mode == VectorizeMode.VECTOR_ENTRY_POINT:
entry_point = id_env_spec.vector_entry_point
if entry_point is None:
raise error.Error(
@@ -910,15 +929,14 @@ def create_single_env() -> Env:

env = env_creator(num_envs=num_envs, **vector_kwargs)
else:
raise error.Error(f"Invalid vectorization mode: {vectorization_mode}")
raise error.Error(f"Unknown vectorization mode: {vectorization_mode}")

# Copies the environment creation specification and kwargs to add to the environment specification details
copied_id_spec = copy.deepcopy(id_env_spec)
copied_id_spec.kwargs = env_spec_kwargs
if num_envs != 1:
copied_id_spec.kwargs["num_envs"] = num_envs
if vectorization_mode != "async":
copied_id_spec.kwargs["vectorization_mode"] = vectorization_mode
copied_id_spec.kwargs["vectorization_mode"] = vectorization_mode.value
if vector_kwargs is not None:
copied_id_spec.kwargs["vector_kwargs"] = vector_kwargs
if wrappers is not None:
125 changes: 3 additions & 122 deletions gymnasium/error.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,12 @@
"""Set of Error classes for gymnasium."""
import warnings


class Error(Exception):
"""Error superclass."""


# Local errors


class Unregistered(Error):
"""Raised when the user requests an item from the registry that does not actually exist."""


class UnregisteredEnv(Unregistered):
# Registration errors
class UnregisteredEnv(Error):
"""Raised when the user requests an env from the registry that does not actually exist."""


@@ -29,10 +22,6 @@ class VersionNotFound(UnregisteredEnv):
"""Raised when the user requests an env from the registry where the version doesn't exist."""


class UnregisteredBenchmark(Unregistered):
"""Raised when the user requests an env from the registry that does not actually exist."""


class DeprecatedEnv(Error):
"""Raised when the user requests an env from the registry with an older version number than the latest env with the same name."""

@@ -41,10 +30,7 @@ class RegistrationError(Error):
"""Raised when the user attempts to register an invalid env. For example, an unversioned env when a versioned env exists."""


class UnseedableEnv(Error):
"""Raised when the user tries to seed an env that does not support seeding."""


# Environment errors
class DependencyNotInstalled(Error):
"""Raised when the user has not installed a dependency."""

@@ -61,10 +47,6 @@ class ResetNeeded(Error):
"""When the order enforcing is violated, i.e. step or render is called before reset."""


class ResetNotAllowed(Error):
"""When the monitor is active, raised when the user tries to step an environment that's not yet terminated or truncated."""


class InvalidAction(Error):
"""Raised when the user performs an action not contained within the action space."""

@@ -81,113 +63,12 @@ class InvalidBound(Error):
"""Raised when clipping an array with an invalid upper and/or lower bound."""


# API errors


class APIError(Error):
"""Deprecated, to be removed at gymnasium 1.0."""

def __init__(
self,
message=None,
http_body=None,
http_status=None,
json_body=None,
headers=None,
):
"""Initialise API error."""
super().__init__(message)

warnings.warn("APIError is deprecated and will be removed at gymnasium 1.0")

if http_body and hasattr(http_body, "decode"):
try:
http_body = http_body.decode("utf-8")
except Exception:
http_body = "<Could not decode body as utf-8.>"

self._message = message
self.http_body = http_body
self.http_status = http_status
self.json_body = json_body
self.headers = headers or {}
self.request_id = self.headers.get("request-id", None)

def __unicode__(self):
"""Returns a string: if request_id is not None, the message is prefixed with the request id, otherwise the bare _message is returned."""
if self.request_id is not None:
msg = self._message or "<empty message>"
return f"Request {self.request_id}: {msg}"
else:
return self._message

def __str__(self):
"""Returns the __unicode__."""
return self.__unicode__()


class APIConnectionError(APIError):
"""Deprecated, to be removed at gymnasium 1.0."""


class InvalidRequestError(APIError):
"""Deprecated, to be removed at gymnasium 1.0."""

def __init__(
self,
message,
param,
http_body=None,
http_status=None,
json_body=None,
headers=None,
):
"""Initialises the invalid request error."""
super().__init__(message, http_body, http_status, json_body, headers)
self.param = param


class AuthenticationError(APIError):
"""Deprecated, to be removed at gymnasium 1.0."""


class RateLimitError(APIError):
"""Deprecated, to be removed at gymnasium 1.0."""


# Video errors


class VideoRecorderError(Error):
"""Unused error."""


class InvalidFrame(Error):
"""Error message when an invalid frame is captured."""


# Wrapper errors


class DoubleWrapperError(Error):
"""Error message for when using double wrappers."""


class WrapAfterConfigureError(Error):
"""Error message for using wrap after configure."""


class RetriesExceededError(Error):
"""Error message for retries exceeding set number."""


class DeprecatedWrapper(ImportError):
"""Error message for importing an old version of a wrapper."""


# Vectorized environments errors


class AlreadyPendingCallError(Exception):
"""Raised when `reset`, or `step` is called asynchronously (e.g. with `reset_async`, or `step_async` respectively), and `reset_async`, or `step_async` (respectively) is called again (without a complete call to `reset_wait`, or `step_wait` respectively)."""

Loading

0 comments on commit c1543ff

Please sign in to comment.