-
-
Notifications
You must be signed in to change notification settings - Fork 903
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add a OneOf space for exclusive unions #812
Changes from all commits
35f1805
27bf126
785860f
cf8baed
f753f1a
31061a5
60eb634
9fcbbd5
be5389d
3904285
ca5d6f9
d871c04
54d79e8
4612418
6513099
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,158 @@ | ||
"""Implementation of a space that represents the cartesian product of other spaces.""" | ||
from __future__ import annotations | ||
|
||
import collections.abc | ||
import typing | ||
from typing import Any, Iterable | ||
|
||
import numpy as np | ||
|
||
from gymnasium.spaces.space import Space | ||
|
||
|
||
class OneOf(Space[Any]): | ||
"""An exclusive tuple (more precisely: the direct sum) of :class:`Space` instances. | ||
|
||
Elements of this space are elements of one of the constituent spaces. | ||
|
||
Example: | ||
>>> from gymnasium.spaces import OneOf, Box, Discrete | ||
>>> observation_space = OneOf((Discrete(2), Box(-1, 1, shape=(2,))), seed=42) | ||
>>> observation_space.sample() # the first element is the space index (Box in this case) and the second element is the sample from Box | ||
(1, array([-0.3991573 , 0.21649833], dtype=float32)) | ||
>>> observation_space.sample() # this time the Discrete space was sampled as index=0 | ||
(0, 0) | ||
>>> observation_space[0] | ||
Discrete(2) | ||
>>> observation_space[1] | ||
Box(-1.0, 1.0, (2,), float32) | ||
>>> len(observation_space) | ||
2 | ||
""" | ||
|
||
def __init__( | ||
self, | ||
spaces: Iterable[Space[Any]], | ||
seed: int | typing.Sequence[int] | np.random.Generator | None = None, | ||
): | ||
r"""Constructor of :class:`OneOf` space. | ||
|
||
The generated instance will represent the cartesian product :math:`\text{spaces}[0] \times ... \times \text{spaces}[-1]`. | ||
|
||
Args: | ||
spaces (Iterable[Space]): The spaces that are involved in the cartesian product. | ||
seed: Optionally, you can use this argument to seed the RNGs of the ``spaces`` to ensure reproducible sampling. | ||
""" | ||
self.spaces = tuple(spaces) | ||
assert len(self.spaces) > 0, "Empty `OneOf` spaces are not supported." | ||
for space in self.spaces: | ||
assert isinstance( | ||
space, Space | ||
), f"{space} does not inherit from `gymnasium.Space`. Actual Type: {type(space)}" | ||
super().__init__(None, None, seed) | ||
|
||
@property | ||
def is_np_flattenable(self): | ||
"""Checks whether this space can be flattened to a :class:`spaces.Box`.""" | ||
return all(space.is_np_flattenable for space in self.spaces) | ||
|
||
def seed(self, seed: int | typing.Sequence[int] | None = None) -> list[int]: | ||
"""Seed the PRNG of this space and all subspaces. | ||
|
||
Depending on the type of seed, the subspaces will be seeded differently | ||
|
||
* ``None`` - All the subspaces will use a random initial seed | ||
* ``Int`` - The integer is used to seed the :class:`Tuple` space that is used to generate seed values for each of the subspaces. Warning, this does not guarantee unique seeds for all the subspaces. | ||
* ``List`` - Values used to seed the subspaces. This allows the seeding of multiple composite subspaces ``[42, 54, ...]``. | ||
|
||
Args: | ||
seed: An optional list of ints or int to seed the (sub-)spaces. | ||
""" | ||
if isinstance(seed, collections.abc.Sequence): | ||
assert ( | ||
len(seed) == len(self.spaces) + 1 | ||
), f"Expects that the subspaces of seeds equals the number of subspaces. Actual length of seeds: {len(seed)}, length of subspaces: {len(self.spaces)}" | ||
seeds = super().seed(seed[0]) | ||
for subseed, space in zip(seed, self.spaces): | ||
seeds += space.seed(subseed) | ||
elif isinstance(seed, int): | ||
seeds = super().seed(seed) | ||
subseeds = self.np_random.integers( | ||
np.iinfo(np.int32).max, size=len(self.spaces) | ||
) | ||
for subspace, subseed in zip(self.spaces, subseeds): | ||
seeds += subspace.seed(int(subseed)) | ||
elif seed is None: | ||
seeds = super().seed(None) | ||
for space in self.spaces: | ||
seeds += space.seed(None) | ||
else: | ||
raise TypeError( | ||
f"Expected seed type: list, tuple, int or None, actual type: {type(seed)}" | ||
) | ||
|
||
return seeds | ||
|
||
def sample(self, mask: tuple[Any | None, ...] | None = None) -> tuple[int, Any]: | ||
"""Generates a single random sample inside this space. | ||
|
||
This method draws independent samples from the subspaces. | ||
|
||
Args: | ||
mask: An optional tuple of optional masks for each of the subspace's samples, | ||
expects the same number of masks as spaces | ||
|
||
Returns: | ||
Tuple of the subspace's samples | ||
""" | ||
subspace_idx = int(self.np_random.integers(0, len(self.spaces))) | ||
subspace = self.spaces[subspace_idx] | ||
if mask is not None: | ||
assert isinstance( | ||
mask, tuple | ||
), f"Expected type of mask is tuple, actual type: {type(mask)}" | ||
assert len(mask) == len( | ||
self.spaces | ||
), f"Expected length of mask is {len(self.spaces)}, actual length: {len(mask)}" | ||
|
||
mask = mask[subspace_idx] | ||
|
||
return subspace_idx, subspace.sample(mask=mask) | ||
|
||
def contains(self, x: tuple[int, Any]) -> bool: | ||
"""Return boolean specifying if x is a valid member of this space.""" | ||
(idx, value) = x | ||
|
||
return isinstance(x, tuple) and self.spaces[idx].contains(value) | ||
|
||
def __repr__(self) -> str: | ||
"""Gives a string representation of this space.""" | ||
return "OneOf(" + ", ".join([str(s) for s in self.spaces]) + ")" | ||
|
||
def to_jsonable( | ||
self, sample_n: typing.Sequence[tuple[int, Any]] | ||
) -> list[list[Any]]: | ||
"""Convert a batch of samples from this space to a JSONable data type.""" | ||
return [ | ||
[int(i), self.spaces[i].to_jsonable([subsample])[0]] | ||
for (i, subsample) in sample_n | ||
] | ||
|
||
def from_jsonable(self, sample_n: list[list[Any]]) -> list[tuple[Any, ...]]: | ||
"""Convert a JSONable data type to a batch of samples from this space.""" | ||
return [ | ||
(space_idx, self.spaces[space_idx].from_jsonable([jsonable_sample])[0]) | ||
for space_idx, jsonable_sample in sample_n | ||
] | ||
|
||
def __getitem__(self, index: int) -> Space[Any]: | ||
"""Get the subspace at specific `index`.""" | ||
return self.spaces[index] | ||
|
||
def __len__(self) -> int: | ||
"""Get the number of subspaces that are involved in the cartesian product.""" | ||
return len(self.spaces) | ||
|
||
def __eq__(self, other: Any) -> bool: | ||
"""Check whether ``other`` is equivalent to this instance.""" | ||
return isinstance(other, OneOf) and self.spaces == other.spaces |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -23,6 +23,7 @@ | |
GraphInstance, | ||
MultiBinary, | ||
MultiDiscrete, | ||
OneOf, | ||
Sequence, | ||
Space, | ||
Text, | ||
|
@@ -104,6 +105,11 @@ def _flatdim_text(space: Text) -> int: | |
return space.max_length | ||
|
||
|
||
@flatdim.register(OneOf) | ||
def _flatdim_oneof(space: OneOf) -> int: | ||
return 1 + max(flatdim(s) for s in space.spaces) | ||
|
||
|
||
T = TypeVar("T") | ||
FlatType = Union[ | ||
NDArray[Any], typing.Dict[str, Any], typing.Tuple[Any, ...], GraphInstance | ||
|
@@ -256,6 +262,22 @@ def _flatten_sequence( | |
return tuple(flatten(space.feature_space, item) for item in x) | ||
|
||
|
||
@flatten.register(OneOf) | ||
def _flatten_oneof(space: OneOf, x: tuple[int, Any]) -> NDArray[Any]: | ||
idx, sample = x | ||
sub_space = space.spaces[idx] | ||
flat_sample = flatten(sub_space, sample) | ||
|
||
max_flatdim = flatdim(space) - 1 # Don't include the index | ||
if flat_sample.size < max_flatdim: | ||
padding = np.full( | ||
max_flatdim - flat_sample.size, flat_sample[0], dtype=flat_sample.dtype | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'm now wondering if it might be better to use Either way, the array will be filled with some throwaway data, but with empty we're not pretending that this data has any actual meaning. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I ran a quick check and this fails |
||
) | ||
flat_sample = np.concatenate([flat_sample, padding]) | ||
|
||
return np.concatenate([[idx], flat_sample]) | ||
|
||
|
||
@singledispatch | ||
def unflatten(space: Space[T], x: FlatType) -> T: | ||
"""Unflatten a data point from a space. | ||
|
@@ -399,6 +421,17 @@ def _unflatten_sequence(space: Sequence, x: tuple[Any, ...]) -> tuple[Any, ...] | |
return tuple(unflatten(space.feature_space, item) for item in x) | ||
|
||
|
||
@unflatten.register(OneOf) | ||
def _unflatten_oneof(space: OneOf, x: NDArray[Any]) -> tuple[int, Any]: | ||
idx = int(x[0]) | ||
sub_space = space.spaces[idx] | ||
|
||
original_size = flatdim(sub_space) | ||
trimmed_sample = x[1 : 1 + original_size] | ||
|
||
return idx, unflatten(sub_space, trimmed_sample) | ||
|
||
|
||
@singledispatch | ||
def flatten_space(space: Space[Any]) -> Box | Dict | Sequence | Tuple | Graph: | ||
"""Flatten a space into a space that is as flat as possible. | ||
|
@@ -525,3 +558,21 @@ def _flatten_space_text(space: Text) -> Box: | |
@flatten_space.register(Sequence) | ||
def _flatten_space_sequence(space: Sequence) -> Sequence: | ||
return Sequence(flatten_space(space.feature_space), stack=space.stack) | ||
|
||
|
||
@flatten_space.register(OneOf) | ||
def _flatten_space_oneof(space: OneOf) -> Box: | ||
num_subspaces = len(space.spaces) | ||
max_flatdim = max(flatdim(s) for s in space.spaces) + 1 | ||
|
||
lows = np.array([np.min(flatten_space(s).low) for s in space.spaces]) | ||
highs = np.array([np.max(flatten_space(s).high) for s in space.spaces]) | ||
|
||
overall_low = np.min(lows) | ||
overall_high = np.max(highs) | ||
|
||
low = np.concatenate([[0], np.full(max_flatdim - 1, overall_low)]) | ||
high = np.concatenate([[num_subspaces - 1], np.full(max_flatdim - 1, overall_high)]) | ||
|
||
dtype = np.result_type(*[s.dtype for s in space.spaces if hasattr(s, "dtype")]) | ||
return Box(low=low, high=high, shape=(max_flatdim,), dtype=dtype) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Should we support masking of the subspace index? This adds more complexity that might not be needed