Skip to content

Commit

Permalink
Add a OneOf space for exclusive unions (#812)
Browse files Browse the repository at this point in the history
Co-authored-by: pseudo-rnd-thoughts <[email protected]>
  • Loading branch information
RedTachyon and pseudo-rnd-thoughts authored Mar 11, 2024
1 parent fd4ae52 commit 2b2e853
Show file tree
Hide file tree
Showing 12 changed files with 353 additions and 5 deletions.
3 changes: 2 additions & 1 deletion docs/api/spaces.md
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,8 @@ Often environment spaces require joining fundamental spaces together for vectori
* :class:`Dict` - Supports a dictionary of keys and subspaces, used for a fixed number of unordered spaces
* :class:`Tuple` - Supports a tuple of subspaces, used for a fixed number of ordered spaces
* :class:`Sequence` - Supports a variable number of instances of a single subspace, used for entities spaces or selecting a variable number of actions
* :py:class:`Graph` - Supports graph based actions or observations with discrete or continuous nodes and edge values.
* :class:`Graph` - Supports graph based actions or observations with discrete or continuous nodes and edge values
* :class:`OneOf` - Supports optional action spaces such that an action can be one of N possible subspaces
```

## Utility functions
Expand Down
5 changes: 5 additions & 0 deletions docs/api/spaces/composite.md
Original file line number Diff line number Diff line change
Expand Up @@ -21,4 +21,9 @@
.. automethod:: gymnasium.spaces.Graph.sample
.. automethod:: gymnasium.spaces.Graph.seed
.. autoclass:: gymnasium.spaces.OneOf
.. automethod:: gymnasium.spaces.OneOf.sample
.. automethod:: gymnasium.spaces.OneOf.seed
```
2 changes: 2 additions & 0 deletions gymnasium/spaces/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from gymnasium.spaces.graph import Graph, GraphInstance
from gymnasium.spaces.multi_binary import MultiBinary
from gymnasium.spaces.multi_discrete import MultiDiscrete
from gymnasium.spaces.oneof import OneOf
from gymnasium.spaces.sequence import Sequence
from gymnasium.spaces.space import Space
from gymnasium.spaces.text import Text
Expand All @@ -38,6 +39,7 @@
"Tuple",
"Sequence",
"Dict",
"OneOf",
# util functions (there are more utility functions in vector/utils/spaces.py)
"flatdim",
"flatten_space",
Expand Down
158 changes: 158 additions & 0 deletions gymnasium/spaces/oneof.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,158 @@
"""Implementation of a space that represents the cartesian product of other spaces."""
from __future__ import annotations

import collections.abc
import typing
from typing import Any, Iterable

import numpy as np

from gymnasium.spaces.space import Space


class OneOf(Space[Any]):
    """An exclusive tuple (more precisely: the direct sum) of :class:`Space` instances.

    Elements of this space are elements of one of the constituent spaces. Every element is
    represented as a pair ``(index, value)`` where ``index`` selects the subspace and ``value``
    is an element of that subspace.

    Example:
        >>> from gymnasium.spaces import OneOf, Box, Discrete
        >>> observation_space = OneOf((Discrete(2), Box(-1, 1, shape=(2,))), seed=42)
        >>> observation_space.sample()  # the first element is the space index (Box in this case) and the second element is the sample from Box
        (1, array([-0.3991573 ,  0.21649833], dtype=float32))
        >>> observation_space.sample()  # this time the Discrete space was sampled as index=0
        (0, 0)
        >>> observation_space[0]
        Discrete(2)
        >>> observation_space[1]
        Box(-1.0, 1.0, (2,), float32)
        >>> len(observation_space)
        2
    """

    def __init__(
        self,
        spaces: Iterable[Space[Any]],
        seed: int | typing.Sequence[int] | np.random.Generator | None = None,
    ):
        r"""Constructor of :class:`OneOf` space.

        The generated instance will represent the exclusive union of ``spaces``: each element
        belongs to exactly one of ``spaces[0], ..., spaces[-1]`` and carries the index of that subspace.

        Args:
            spaces (Iterable[Space]): The spaces an element can be drawn from. Must not be empty.
            seed: Optionally, you can use this argument to seed the RNGs of the ``spaces`` to ensure reproducible sampling.
        """
        # NOTE(review): ``self.spaces`` is assigned before ``super().__init__`` - presumably because
        # the base constructor may call ``self.seed``, which iterates the subspaces. Keep this order.
        self.spaces = tuple(spaces)
        assert len(self.spaces) > 0, "Empty `OneOf` spaces are not supported."
        for space in self.spaces:
            assert isinstance(
                space, Space
            ), f"{space} does not inherit from `gymnasium.Space`. Actual Type: {type(space)}"
        super().__init__(None, None, seed)

    @property
    def is_np_flattenable(self):
        """Checks whether this space can be flattened to a :class:`spaces.Box`."""
        return all(space.is_np_flattenable for space in self.spaces)

    def seed(self, seed: int | typing.Sequence[int] | None = None) -> list[int]:
        """Seed the PRNG of this space and all subspaces.

        Depending on the type of seed, the subspaces will be seeded differently

        * ``None`` - This space and all the subspaces will use a random initial seed
        * ``Int`` - The integer is used to seed the :class:`OneOf` space that is used to generate seed values for each of the subspaces. Warning, this does not guarantee unique seeds for all the subspaces.
        * ``List`` - ``len(spaces) + 1`` values: the first seeds the :class:`OneOf` space itself (i.e. the subspace selection), the remaining ones seed the subspaces in order, e.g. ``[42, 54, ...]``.

        Args:
            seed: An optional list of ints or int to seed the (sub-)spaces.

        Returns:
            The seeds returned by the base class followed by the seeds returned by every subspace.

        Raises:
            TypeError: If ``seed`` is neither a sequence, an int nor ``None``.
        """
        if isinstance(seed, collections.abc.Sequence):
            assert (
                len(seed) == len(self.spaces) + 1
            ), f"Expects that the number of seeds equals the number of subspaces plus one (the first seed is used for the `OneOf` space itself). Actual length of seeds: {len(seed)}, length of subspaces: {len(self.spaces)}"
            seeds = super().seed(seed[0])
            # seed[0] was consumed by this space; the subspaces use the remaining entries.
            for subseed, space in zip(seed[1:], self.spaces):
                seeds += space.seed(subseed)
        elif isinstance(seed, int):
            seeds = super().seed(seed)
            subseeds = self.np_random.integers(
                np.iinfo(np.int32).max, size=len(self.spaces)
            )
            for subspace, subseed in zip(self.spaces, subseeds):
                seeds += subspace.seed(int(subseed))
        elif seed is None:
            seeds = super().seed(None)
            for space in self.spaces:
                seeds += space.seed(None)
        else:
            raise TypeError(
                f"Expected seed type: list, tuple, int or None, actual type: {type(seed)}"
            )

        return seeds

    def sample(self, mask: tuple[Any | None, ...] | None = None) -> tuple[int, Any]:
        """Generates a single random sample inside this space.

        This method first draws the index of a subspace uniformly at random and then draws
        a sample from that subspace only.

        Args:
            mask: An optional tuple of optional masks, one per subspace. Only the mask
                belonging to the subspace that was drawn is forwarded to its ``sample``.

        Returns:
            A pair ``(subspace_index, subspace_sample)``.
        """
        subspace_idx = int(self.np_random.integers(0, len(self.spaces)))
        subspace = self.spaces[subspace_idx]
        if mask is not None:
            assert isinstance(
                mask, tuple
            ), f"Expected type of mask is tuple, actual type: {type(mask)}"
            assert len(mask) == len(
                self.spaces
            ), f"Expected length of mask is {len(self.spaces)}, actual length: {len(mask)}"

            mask = mask[subspace_idx]

        return subspace_idx, subspace.sample(mask=mask)

    def contains(self, x: tuple[int, Any]) -> bool:
        """Return boolean specifying if x is a valid member of this space.

        ``x`` is a member when it is a ``(index, value)`` tuple, ``index`` is an integer
        in ``[0, len(self))`` and ``value`` is contained in the subspace at ``index``.
        Malformed inputs yield ``False`` instead of raising.
        """
        # Validate the shape before unpacking so that arbitrary objects cannot raise here.
        if not isinstance(x, tuple) or len(x) != 2:
            return False

        idx, value = x
        if not isinstance(idx, (int, np.integer)):
            return False
        # Negative indices would silently wrap around with tuple indexing; they are never sampled.
        if not 0 <= idx < len(self.spaces):
            return False

        return self.spaces[idx].contains(value)

    def __repr__(self) -> str:
        """Gives a string representation of this space."""
        return "OneOf(" + ", ".join([str(s) for s in self.spaces]) + ")"

    def to_jsonable(
        self, sample_n: typing.Sequence[tuple[int, Any]]
    ) -> list[list[Any]]:
        """Convert a batch of samples from this space to a JSONable data type.

        Each ``(index, value)`` sample becomes ``[index, jsonable_value]`` where the value
        is converted by the subspace selected through ``index``.
        """
        return [
            [int(i), self.spaces[i].to_jsonable([subsample])[0]]
            for (i, subsample) in sample_n
        ]

    def from_jsonable(self, sample_n: list[list[Any]]) -> list[tuple[Any, ...]]:
        """Convert a JSONable data type to a batch of samples from this space.

        Inverse of :meth:`to_jsonable`: every ``[index, jsonable_value]`` entry is turned back
        into ``(index, value)`` using the subspace selected through ``index``.
        """
        return [
            (space_idx, self.spaces[space_idx].from_jsonable([jsonable_sample])[0])
            for space_idx, jsonable_sample in sample_n
        ]

    def __getitem__(self, index: int) -> Space[Any]:
        """Get the subspace at specific `index`."""
        return self.spaces[index]

    def __len__(self) -> int:
        """Get the number of subspaces an element of this space can be drawn from."""
        return len(self.spaces)

    def __eq__(self, other: Any) -> bool:
        """Check whether ``other`` is equivalent to this instance."""
        return isinstance(other, OneOf) and self.spaces == other.spaces
51 changes: 51 additions & 0 deletions gymnasium/spaces/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
GraphInstance,
MultiBinary,
MultiDiscrete,
OneOf,
Sequence,
Space,
Text,
Expand Down Expand Up @@ -104,6 +105,11 @@ def _flatdim_text(space: Text) -> int:
return space.max_length


@flatdim.register(OneOf)
def _flatdim_oneof(space: OneOf) -> int:
    """Return the flattened size of a ``OneOf``: one slot for the index plus the widest subspace."""
    widest_subspace = max(flatdim(subspace) for subspace in space.spaces)
    return 1 + widest_subspace


T = TypeVar("T")
FlatType = Union[
NDArray[Any], typing.Dict[str, Any], typing.Tuple[Any, ...], GraphInstance
Expand Down Expand Up @@ -256,6 +262,22 @@ def _flatten_sequence(
return tuple(flatten(space.feature_space, item) for item in x)


@flatten.register(OneOf)
def _flatten_oneof(space: OneOf, x: tuple[int, Any]) -> NDArray[Any]:
    """Flatten an ``(index, sample)`` element into ``[index, *flat_sample, *padding]``.

    The sample is flattened by its own subspace and, when shorter than the widest
    subspace, padded by repeating its first flattened entry.
    """
    subspace_index, subspace_sample = x
    flattened = flatten(space.spaces[subspace_index], subspace_sample)

    # flatdim(space) counts the leading index slot, which is not part of the payload.
    payload_width = flatdim(space) - 1
    missing = payload_width - flattened.size
    if missing > 0:
        filler = np.full(missing, flattened[0], dtype=flattened.dtype)
        flattened = np.concatenate([flattened, filler])

    return np.concatenate([[subspace_index], flattened])


@singledispatch
def unflatten(space: Space[T], x: FlatType) -> T:
"""Unflatten a data point from a space.
Expand Down Expand Up @@ -399,6 +421,17 @@ def _unflatten_sequence(space: Sequence, x: tuple[Any, ...]) -> tuple[Any, ...]
return tuple(unflatten(space.feature_space, item) for item in x)


@unflatten.register(OneOf)
def _unflatten_oneof(space: OneOf, x: NDArray[Any]) -> tuple[int, Any]:
    """Rebuild an ``(index, sample)`` element from its flat form.

    The first entry selects the subspace; the following ``flatdim(subspace)`` entries
    are unflattened by that subspace and any trailing entries are ignored.
    """
    subspace_index = int(x[0])
    chosen_space = space.spaces[subspace_index]

    payload_end = 1 + flatdim(chosen_space)
    payload = x[1:payload_end]

    return subspace_index, unflatten(chosen_space, payload)


@singledispatch
def flatten_space(space: Space[Any]) -> Box | Dict | Sequence | Tuple | Graph:
"""Flatten a space into a space that is as flat as possible.
Expand Down Expand Up @@ -525,3 +558,21 @@ def _flatten_space_text(space: Text) -> Box:
@flatten_space.register(Sequence)
def _flatten_space_sequence(space: Sequence) -> Sequence:
return Sequence(flatten_space(space.feature_space), stack=space.stack)


@flatten_space.register(OneOf)
def _flatten_space_oneof(space: OneOf) -> Box:
    """Flatten a ``OneOf`` into a single ``Box``.

    The first dimension is bounded by the valid subspace indices; every remaining
    dimension shares the loosest low/high bound found across the flattened subspaces.
    """
    subspaces = space.spaces
    total_width = max(flatdim(s) for s in subspaces) + 1
    payload_width = total_width - 1

    # Smallest lower bound and largest upper bound over all flattened subspaces.
    overall_low = np.min(np.array([np.min(flatten_space(s).low) for s in subspaces]))
    overall_high = np.max(np.array([np.max(flatten_space(s).high) for s in subspaces]))

    index_upper_bound = len(subspaces) - 1
    low = np.concatenate([[0], np.full(payload_width, overall_low)])
    high = np.concatenate([[index_upper_bound], np.full(payload_width, overall_high)])

    subspace_dtypes = [s.dtype for s in subspaces if hasattr(s, "dtype")]
    dtype = np.result_type(*subspace_dtypes)
    return Box(low=low, high=high, shape=(total_width,), dtype=dtype)
36 changes: 35 additions & 1 deletion gymnasium/vector/utils/shared_memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
Graph,
MultiBinary,
MultiDiscrete,
OneOf,
Sequence,
Space,
Text,
Expand Down Expand Up @@ -93,6 +94,11 @@ def _create_text_shared_memory(space: Text, n: int = 1, ctx=mp):
return ctx.Array(np.dtype(np.int32).char, n * space.max_length)


@create_shared_memory.register(OneOf)
def _create_oneof_shared_memory(space: OneOf, n: int = 1, ctx=mp):
    """Allocate shared memory for ``n`` elements of a ``OneOf`` space.

    Args:
        space: The ``OneOf`` space to allocate memory for.
        n: Number of elements (environments) the memory must hold.
        ctx: The multiprocessing context used to allocate the arrays.

    Returns:
        A tuple whose first entry is an int32 array of ``n`` subspace indices, followed
        by one shared-memory object per subspace, in subspace order.
    """
    # `Array` expects a ctypes type or a typecode, so use the int32 typecode here.
    index_memory = ctx.Array(np.dtype(np.int32).char, n)
    # Forward `n` and `ctx` so every subspace buffer holds `n` elements in the same context.
    subspace_memories = tuple(
        create_shared_memory(subspace, n=n, ctx=ctx) for subspace in space.spaces
    )
    return (index_memory,) + subspace_memories


@create_shared_memory.register(Graph)
@create_shared_memory.register(Sequence)
def _create_dynamic_shared_memory(space: Graph | Sequence, n: int = 1, ctx=mp):
Expand Down Expand Up @@ -170,7 +176,9 @@ def _read_dict_from_shared_memory(space: Dict, shared_memory, n: int = 1):


@read_from_shared_memory.register(Text)
def _read_text_from_shared_memory(space: Text, shared_memory, n: int = 1) -> tuple[str]:
def _read_text_from_shared_memory(
space: Text, shared_memory, n: int = 1
) -> tuple[str, ...]:
data = np.frombuffer(shared_memory.get_obj(), dtype=np.int32).reshape(
(n, space.max_length)
)
Expand All @@ -187,6 +195,21 @@ def _read_text_from_shared_memory(space: Text, shared_memory, n: int = 1) -> tup
)


@read_from_shared_memory.register(OneOf)
def _read_one_of_from_shared_memory(
    space: OneOf, shared_memory, n: int = 1
) -> tuple[Any, ...]:
    """Read ``n`` ``OneOf`` elements back from shared memory.

    Args:
        space: The ``OneOf`` space describing the memory layout.
        shared_memory: Tuple whose first entry holds the subspace indices and whose
            remaining entries hold the per-subspace memories, in subspace order.
        n: Number of elements (environments) stored in the memory.

    Returns:
        A tuple with one ``(subspace_index, sample)`` pair per stored element.
    """
    # The index buffer is written as int32; `space.dtype` is None for OneOf, so it cannot be used here.
    sample_indexes = np.frombuffer(shared_memory[0].get_obj(), dtype=np.int32)
    subspace_samples = tuple(
        read_from_shared_memory(subspace, memory, n=n)
        for (memory, subspace) in zip(shared_memory[1:], space.spaces)
    )
    # For each stored element, pick its entry from the batch of the subspace it selected.
    return tuple(
        (int(subspace_index), subspace_samples[subspace_index][element_index])
        for element_index, subspace_index in enumerate(sample_indexes)
    )


@singledispatch
def write_to_shared_memory(
space: Space,
Expand Down Expand Up @@ -258,3 +281,14 @@ def _write_text_to_shared_memory(space: Text, index: int, values: str, shared_me
destination[index * size : (index + 1) * size],
flatten(space, values),
)


@write_to_shared_memory.register(OneOf)
def _write_oneof_to_shared_memory(
    space: OneOf, index: int, values: tuple[Any, ...], shared_memory
):
    """Write one ``OneOf`` element into slot ``index`` of the shared memory.

    Args:
        space: The ``OneOf`` space describing the memory layout.
        index: Slot (environment index) to write to.
        values: The ``(subspace_index, sample)`` element to store.
        shared_memory: Tuple whose first entry holds the subspace indices and whose
            remaining entries hold the per-subspace memories, in subspace order.
    """
    subspace_index, subspace_value = values

    destination = np.frombuffer(shared_memory[0].get_obj(), dtype=np.int32)
    np.copyto(destination[index : index + 1], subspace_index)

    # Only the selected subspace's memory is updated: the value belongs to that
    # subspace alone and need not be valid for any of the others.
    write_to_shared_memory(
        space.spaces[subspace_index],
        index,
        subspace_value,
        shared_memory[1 + subspace_index],
    )
12 changes: 10 additions & 2 deletions gymnasium/vector/utils/space_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
GraphInstance,
MultiBinary,
MultiDiscrete,
OneOf,
Sequence,
Space,
Text,
Expand Down Expand Up @@ -124,8 +125,9 @@ def _batch_space_dict(space: Dict, n: int = 1):
@batch_space.register(Graph)
@batch_space.register(Text)
@batch_space.register(Sequence)
@batch_space.register(OneOf)
@batch_space.register(Space)
def _batch_space_custom(space: Graph | Text | Sequence, n: int = 1):
def _batch_space_custom(space: Graph | Text | Sequence | OneOf, n: int = 1):
# Without deepcopy, then the space.np_random is batched_space.spaces[0].np_random
# Which is an issue if you are sampling actions of both the original space and the batched space
batched_space = Tuple(
Expand Down Expand Up @@ -297,6 +299,7 @@ def _concatenate_dict(
@concatenate.register(Text)
@concatenate.register(Sequence)
@concatenate.register(Space)
@concatenate.register(OneOf)
def _concatenate_custom(space: Space, items: Iterable, out: None) -> tuple[Any, ...]:
return tuple(items)

Expand Down Expand Up @@ -336,7 +339,7 @@ def create_empty_array(
)


# It is possible for the some of the Box low to be greater than 0, then array is not in space
# It is possible for some of the Box low to be greater than 0, then array is not in space
@create_empty_array.register(Box)
# If the Discrete start > 0 or start + length < 0 then array is not in space
@create_empty_array.register(Discrete)
Expand Down Expand Up @@ -402,6 +405,11 @@ def _create_empty_array_sequence(
return tuple(tuple() for _ in range(n))


@create_empty_array.register(OneOf)
def _create_empty_array_oneof(space: OneOf, n: int = 1, fn=np.zeros):
return tuple(tuple() for _ in range(n))


@create_empty_array.register(Space)
def _create_empty_array_custom(space, n=1, fn=np.zeros):
    """Fallback for the base ``Space`` registration: builds no array and returns ``None``.

    All three arguments are accepted for signature compatibility and ignored.
    """
    return None
Loading

0 comments on commit 2b2e853

Please sign in to comment.