Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add a OneOf space for exclusive unions #812

Merged
merged 15 commits into from
Mar 11, 2024
Merged
3 changes: 2 additions & 1 deletion docs/api/spaces.md
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,8 @@ Often environment spaces require joining fundamental spaces together for vectori
* :class:`Dict` - Supports a dictionary of keys and subspaces, used for a fixed number of unordered spaces
* :class:`Tuple` - Supports a tuple of subspaces, used for multiple for a fixed number of ordered spaces
* :class:`Sequence` - Supports a variable number of instances of a single subspace, used for entities spaces or selecting a variable number of actions
* :py:class:`Graph` - Supports graph based actions or observations with discrete or continuous nodes and edge values.
* :class:`Graph` - Supports graph based actions or observations with discrete or continuous nodes and edge values
* :class:`OneOf` - Supports optional action spaces such that an action can be one of N possible subspaces
```

## Utility functions
Expand Down
5 changes: 5 additions & 0 deletions docs/api/spaces/composite.md
Original file line number Diff line number Diff line change
Expand Up @@ -21,4 +21,9 @@

.. automethod:: gymnasium.spaces.Graph.sample
.. automethod:: gymnasium.spaces.Graph.seed

.. autoclass:: gymnasium.spaces.OneOf

.. automethod:: gymnasium.spaces.OneOf.sample
.. automethod:: gymnasium.spaces.OneOf.seed
```
2 changes: 2 additions & 0 deletions gymnasium/spaces/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from gymnasium.spaces.graph import Graph, GraphInstance
from gymnasium.spaces.multi_binary import MultiBinary
from gymnasium.spaces.multi_discrete import MultiDiscrete
from gymnasium.spaces.oneof import OneOf
from gymnasium.spaces.sequence import Sequence
from gymnasium.spaces.space import Space
from gymnasium.spaces.text import Text
Expand All @@ -38,6 +39,7 @@
"Tuple",
"Sequence",
"Dict",
"OneOf",
# util functions (there are more utility functions in vector/utils/spaces.py)
"flatdim",
"flatten_space",
Expand Down
158 changes: 158 additions & 0 deletions gymnasium/spaces/oneof.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,158 @@
"""Implementation of a space that represents the cartesian product of other spaces."""
from __future__ import annotations

import collections.abc
import typing
from typing import Any, Iterable

import numpy as np

from gymnasium.spaces.space import Space


class OneOf(Space[Any]):
"""An exclusive tuple (more precisely: the direct sum) of :class:`Space` instances.

Elements of this space are elements of one of the constituent spaces.

Example:
>>> from gymnasium.spaces import OneOf, Box, Discrete
>>> observation_space = OneOf((Discrete(2), Box(-1, 1, shape=(2,))), seed=42)
>>> observation_space.sample() # the first element is the space index (Box in this case) and the second element is the sample from Box
(1, array([-0.3991573 , 0.21649833], dtype=float32))
>>> observation_space.sample() # this time the Discrete space was sampled as index=0
(0, 0)
>>> observation_space[0]
Discrete(2)
>>> observation_space[1]
Box(-1.0, 1.0, (2,), float32)
>>> len(observation_space)
2
"""

def __init__(
self,
spaces: Iterable[Space[Any]],
seed: int | typing.Sequence[int] | np.random.Generator | None = None,
):
r"""Constructor of :class:`OneOf` space.

The generated instance will represent the cartesian product :math:`\text{spaces}[0] \times ... \times \text{spaces}[-1]`.

Args:
spaces (Iterable[Space]): The spaces that are involved in the cartesian product.
seed: Optionally, you can use this argument to seed the RNGs of the ``spaces`` to ensure reproducible sampling.
"""
self.spaces = tuple(spaces)
assert len(self.spaces) > 0, "Empty `OneOf` spaces are not supported."
for space in self.spaces:
assert isinstance(
space, Space
), f"{space} does not inherit from `gymnasium.Space`. Actual Type: {type(space)}"
super().__init__(None, None, seed)

@property
def is_np_flattenable(self):
"""Checks whether this space can be flattened to a :class:`spaces.Box`."""
return all(space.is_np_flattenable for space in self.spaces)

def seed(self, seed: int | typing.Sequence[int] | None = None) -> list[int]:
"""Seed the PRNG of this space and all subspaces.

Depending on the type of seed, the subspaces will be seeded differently

* ``None`` - All the subspaces will use a random initial seed
* ``Int`` - The integer is used to seed the :class:`Tuple` space that is used to generate seed values for each of the subspaces. Warning, this does not guarantee unique seeds for all the subspaces.
* ``List`` - Values used to seed the subspaces. This allows the seeding of multiple composite subspaces ``[42, 54, ...]``.

Args:
seed: An optional list of ints or int to seed the (sub-)spaces.
"""
if isinstance(seed, collections.abc.Sequence):
assert (
len(seed) == len(self.spaces) + 1
), f"Expects that the subspaces of seeds equals the number of subspaces. Actual length of seeds: {len(seed)}, length of subspaces: {len(self.spaces)}"
seeds = super().seed(seed[0])
for subseed, space in zip(seed, self.spaces):
seeds += space.seed(subseed)
elif isinstance(seed, int):
seeds = super().seed(seed)
subseeds = self.np_random.integers(
np.iinfo(np.int32).max, size=len(self.spaces)
)
for subspace, subseed in zip(self.spaces, subseeds):
seeds += subspace.seed(int(subseed))
elif seed is None:
seeds = super().seed(None)
for space in self.spaces:
seeds += space.seed(None)
else:
raise TypeError(
f"Expected seed type: list, tuple, int or None, actual type: {type(seed)}"
)

return seeds

def sample(self, mask: tuple[Any | None, ...] | None = None) -> tuple[int, Any]:
"""Generates a single random sample inside this space.

This method draws independent samples from the subspaces.

Args:
mask: An optional tuple of optional masks for each of the subspace's samples,

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we support masking of the subspace index? This adds more complexity that might not be needed

expects the same number of masks as spaces

Returns:
Tuple of the subspace's samples
"""
subspace_idx = int(self.np_random.integers(0, len(self.spaces)))
subspace = self.spaces[subspace_idx]
if mask is not None:
assert isinstance(
mask, tuple
), f"Expected type of mask is tuple, actual type: {type(mask)}"
assert len(mask) == len(
self.spaces
), f"Expected length of mask is {len(self.spaces)}, actual length: {len(mask)}"

mask = mask[subspace_idx]

return subspace_idx, subspace.sample(mask=mask)

def contains(self, x: tuple[int, Any]) -> bool:
"""Return boolean specifying if x is a valid member of this space."""
(idx, value) = x

return isinstance(x, tuple) and self.spaces[idx].contains(value)

def __repr__(self) -> str:
"""Gives a string representation of this space."""
return "OneOf(" + ", ".join([str(s) for s in self.spaces]) + ")"

def to_jsonable(
self, sample_n: typing.Sequence[tuple[int, Any]]
) -> list[list[Any]]:
"""Convert a batch of samples from this space to a JSONable data type."""
return [
[int(i), self.spaces[i].to_jsonable([subsample])[0]]
for (i, subsample) in sample_n
]

def from_jsonable(self, sample_n: list[list[Any]]) -> list[tuple[Any, ...]]:
"""Convert a JSONable data type to a batch of samples from this space."""
return [
(space_idx, self.spaces[space_idx].from_jsonable([jsonable_sample])[0])
for space_idx, jsonable_sample in sample_n
]

def __getitem__(self, index: int) -> Space[Any]:
"""Get the subspace at specific `index`."""
return self.spaces[index]

def __len__(self) -> int:
"""Get the number of subspaces that are involved in the cartesian product."""
return len(self.spaces)

def __eq__(self, other: Any) -> bool:
"""Check whether ``other`` is equivalent to this instance."""
return isinstance(other, OneOf) and self.spaces == other.spaces
51 changes: 51 additions & 0 deletions gymnasium/spaces/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
GraphInstance,
MultiBinary,
MultiDiscrete,
OneOf,
Sequence,
Space,
Text,
Expand Down Expand Up @@ -104,6 +105,11 @@ def _flatdim_text(space: Text) -> int:
return space.max_length


@flatdim.register(OneOf)
def _flatdim_oneof(space: OneOf) -> int:
return 1 + max(flatdim(s) for s in space.spaces)


T = TypeVar("T")
FlatType = Union[
NDArray[Any], typing.Dict[str, Any], typing.Tuple[Any, ...], GraphInstance
Expand Down Expand Up @@ -256,6 +262,22 @@ def _flatten_sequence(
return tuple(flatten(space.feature_space, item) for item in x)


@flatten.register(OneOf)
def _flatten_oneof(space: OneOf, x: tuple[int, Any]) -> NDArray[Any]:
idx, sample = x
sub_space = space.spaces[idx]
flat_sample = flatten(sub_space, sample)

max_flatdim = flatdim(space) - 1 # Don't include the index
if flat_sample.size < max_flatdim:
padding = np.full(
max_flatdim - flat_sample.size, flat_sample[0], dtype=flat_sample.dtype
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm now wondering if it might be better to use np.empty instead if we're not using a dedicated placeholder value.

Either way, the array will be filled with some throwaway data, but with empty we're not pretending that this data has any actual meaning.

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I ran a quick check and this fails test_flat_space_contains_flat_points as the flattened version might not be contained within the flattened space.
But this is just a weird artifact of flatten rather than anything actually incorrect with the approach

)
flat_sample = np.concatenate([flat_sample, padding])

return np.concatenate([[idx], flat_sample])


@singledispatch
def unflatten(space: Space[T], x: FlatType) -> T:
"""Unflatten a data point from a space.
Expand Down Expand Up @@ -399,6 +421,17 @@ def _unflatten_sequence(space: Sequence, x: tuple[Any, ...]) -> tuple[Any, ...]
return tuple(unflatten(space.feature_space, item) for item in x)


@unflatten.register(OneOf)
def _unflatten_oneof(space: OneOf, x: NDArray[Any]) -> tuple[int, Any]:
idx = int(x[0])
sub_space = space.spaces[idx]

original_size = flatdim(sub_space)
trimmed_sample = x[1 : 1 + original_size]

return idx, unflatten(sub_space, trimmed_sample)


@singledispatch
def flatten_space(space: Space[Any]) -> Box | Dict | Sequence | Tuple | Graph:
"""Flatten a space into a space that is as flat as possible.
Expand Down Expand Up @@ -525,3 +558,21 @@ def _flatten_space_text(space: Text) -> Box:
@flatten_space.register(Sequence)
def _flatten_space_sequence(space: Sequence) -> Sequence:
return Sequence(flatten_space(space.feature_space), stack=space.stack)


@flatten_space.register(OneOf)
def _flatten_space_oneof(space: OneOf) -> Box:
num_subspaces = len(space.spaces)
max_flatdim = max(flatdim(s) for s in space.spaces) + 1

lows = np.array([np.min(flatten_space(s).low) for s in space.spaces])
highs = np.array([np.max(flatten_space(s).high) for s in space.spaces])

overall_low = np.min(lows)
overall_high = np.max(highs)

low = np.concatenate([[0], np.full(max_flatdim - 1, overall_low)])
high = np.concatenate([[num_subspaces - 1], np.full(max_flatdim - 1, overall_high)])

dtype = np.result_type(*[s.dtype for s in space.spaces if hasattr(s, "dtype")])
return Box(low=low, high=high, shape=(max_flatdim,), dtype=dtype)
36 changes: 35 additions & 1 deletion gymnasium/vector/utils/shared_memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
Graph,
MultiBinary,
MultiDiscrete,
OneOf,
Sequence,
Space,
Text,
Expand Down Expand Up @@ -93,6 +94,11 @@ def _create_text_shared_memory(space: Text, n: int = 1, ctx=mp):
return ctx.Array(np.dtype(np.int32).char, n * space.max_length)


@create_shared_memory.register(OneOf)
def _create_oneof_shared_memory(space: OneOf, n: int = 1, ctx=mp):
return (ctx.Array(np.int32, n),) + _create_tuple_shared_memory(space)


@create_shared_memory.register(Graph)
@create_shared_memory.register(Sequence)
def _create_dynamic_shared_memory(space: Graph | Sequence, n: int = 1, ctx=mp):
Expand Down Expand Up @@ -170,7 +176,9 @@ def _read_dict_from_shared_memory(space: Dict, shared_memory, n: int = 1):


@read_from_shared_memory.register(Text)
def _read_text_from_shared_memory(space: Text, shared_memory, n: int = 1) -> tuple[str]:
def _read_text_from_shared_memory(
space: Text, shared_memory, n: int = 1
) -> tuple[str, ...]:
data = np.frombuffer(shared_memory.get_obj(), dtype=np.int32).reshape(
(n, space.max_length)
)
Expand All @@ -187,6 +195,21 @@ def _read_text_from_shared_memory(space: Text, shared_memory, n: int = 1) -> tup
)


@read_from_shared_memory.register(OneOf)
def _read_one_of_from_shared_memory(
space: OneOf, shared_memory, n: int = 1
) -> tuple[Any, ...]:
sample_indexes = np.frombuffer(shared_memory[0].get_obj(), dtype=space.dtype)
subspace_samples = tuple(
read_from_shared_memory(subspace, memory, n=n)
for (memory, subspace) in zip(shared_memory[1:], space.spaces)
)
return tuple(
(index, sample[index])
for index, sample in zip(sample_indexes, subspace_samples)
)


@singledispatch
def write_to_shared_memory(
space: Space,
Expand Down Expand Up @@ -258,3 +281,14 @@ def _write_text_to_shared_memory(space: Text, index: int, values: str, shared_me
destination[index * size : (index + 1) * size],
flatten(space, values),
)


@write_to_shared_memory.register(OneOf)
def _write_oneof_to_shared_memory(
space: OneOf, index: int, values: tuple[Any, ...], shared_memory
):
destination = np.frombuffer(shared_memory[0].get_obj(), dtype=np.int32)
np.copyto(destination[index : index + 1], values[0])

for value, memory, subspace in zip(values[1], shared_memory[1:], space.spaces):
write_to_shared_memory(subspace, index, value, memory)
12 changes: 10 additions & 2 deletions gymnasium/vector/utils/space_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
GraphInstance,
MultiBinary,
MultiDiscrete,
OneOf,
Sequence,
Space,
Text,
Expand Down Expand Up @@ -124,8 +125,9 @@ def _batch_space_dict(space: Dict, n: int = 1):
@batch_space.register(Graph)
@batch_space.register(Text)
@batch_space.register(Sequence)
@batch_space.register(OneOf)
@batch_space.register(Space)
def _batch_space_custom(space: Graph | Text | Sequence, n: int = 1):
def _batch_space_custom(space: Graph | Text | Sequence | OneOf, n: int = 1):
# Without deepcopy, then the space.np_random is batched_space.spaces[0].np_random
# Which is an issue if you are sampling actions of both the original space and the batched space
batched_space = Tuple(
Expand Down Expand Up @@ -297,6 +299,7 @@ def _concatenate_dict(
@concatenate.register(Text)
@concatenate.register(Sequence)
@concatenate.register(Space)
@concatenate.register(OneOf)
def _concatenate_custom(space: Space, items: Iterable, out: None) -> tuple[Any, ...]:
return tuple(items)

Expand Down Expand Up @@ -336,7 +339,7 @@ def create_empty_array(
)


# It is possible for the some of the Box low to be greater than 0, then array is not in space
# It is possible for some of the Box low to be greater than 0, then array is not in space
@create_empty_array.register(Box)
# If the Discrete start > 0 or start + length < 0 then array is not in space
@create_empty_array.register(Discrete)
Expand Down Expand Up @@ -402,6 +405,11 @@ def _create_empty_array_sequence(
return tuple(tuple() for _ in range(n))


@create_empty_array.register(OneOf)
def _create_empty_array_oneof(space: OneOf, n: int = 1, fn=np.zeros):
return tuple(tuple() for _ in range(n))


@create_empty_array.register(Space)
def _create_empty_array_custom(space, n=1, fn=np.zeros):
return None
Loading
Loading