Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add a OneOf space for exclusive unions #812

Merged
merged 15 commits into from
Mar 11, 2024
Merged
2 changes: 2 additions & 0 deletions gymnasium/spaces/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from gymnasium.spaces.graph import Graph, GraphInstance
from gymnasium.spaces.multi_binary import MultiBinary
from gymnasium.spaces.multi_discrete import MultiDiscrete
from gymnasium.spaces.oneof import OneOf
from gymnasium.spaces.sequence import Sequence
from gymnasium.spaces.space import Space
from gymnasium.spaces.text import Text
Expand All @@ -38,6 +39,7 @@
"Tuple",
"Sequence",
"Dict",
"OneOf",
# util functions (there are more utility functions in vector/utils/spaces.py)
"flatdim",
"flatten_space",
Expand Down
26 changes: 20 additions & 6 deletions gymnasium/spaces/box.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ def __init__(
shape: Sequence[int] | None = None,
dtype: type[np.floating[Any]] | type[np.integer[Any]] = np.float32,
seed: int | np.random.Generator | None = None,
nullable: bool = False,
):
r"""Constructor of :class:`Box`.

Expand Down Expand Up @@ -144,6 +145,8 @@ def __init__(
self.low_repr = _short_repr(self.low)
self.high_repr = _short_repr(self.high)

self.nullable = nullable

super().__init__(self.shape, self.dtype, seed)

@property
Expand Down Expand Up @@ -243,12 +246,20 @@ def contains(self, x: Any) -> bool:
except (ValueError, TypeError):
return False

return bool(
np.can_cast(x.dtype, self.dtype)
and x.shape == self.shape
and np.all(x >= self.low)
and np.all(x <= self.high)
)
if not np.can_cast(x.dtype, self.dtype):
return False

if x.shape != self.shape:
return False

bounded_below = x >= self.low
bounded_above = x <= self.high

bounded = bounded_below & bounded_above
if self.nullable:
bounded |= np.isnan(x)

return bool(np.all(bounded))

def to_jsonable(self, sample_n: Sequence[NDArray[Any]]) -> list[list]:
"""Convert a batch of samples from this space to a JSONable data type."""
Expand Down Expand Up @@ -290,6 +301,9 @@ def __setstate__(self, state: Iterable[tuple[str, Any]] | Mapping[str, Any]):
if not hasattr(self, "high_repr"):
self.high_repr = _short_repr(self.high)

if not hasattr(self, "nullable"):
self.nullable = False


def get_precision(dtype: np.dtype) -> SupportsFloat:
"""Get precision of a data type."""
Expand Down
149 changes: 149 additions & 0 deletions gymnasium/spaces/oneof.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,149 @@
"""Implementation of a space that represents the exclusive union (direct sum) of other spaces."""
from __future__ import annotations

import collections.abc
import typing
from typing import Any, Iterable

import numpy as np

from gymnasium.spaces.space import Space


class OneOf(Space[Any]):
    """An exclusive union (more precisely: the direct sum) of :class:`Space` instances.

    Elements of this space are pairs ``(index, value)`` where ``index`` selects one of
    the constituent spaces and ``value`` is an element of that selected space.

    Example:
        >>> from gymnasium.spaces import OneOf, Box, Discrete
        >>> observation_space = OneOf((Discrete(2), Box(-1, 1, shape=(2,))), seed=42)
        >>> idx, value = observation_space.sample()
        >>> observation_space[idx].contains(value)
        True
        >>> observation_space.contains((idx, value))
        True
    """

    def __init__(
        self,
        spaces: Iterable[Space[Any]],
        seed: int | typing.Sequence[int] | np.random.Generator | None = None,
    ):
        r"""Constructor of :class:`OneOf` space.

        The generated instance will represent the direct sum :math:`\text{spaces}[0] \oplus ... \oplus \text{spaces}[-1]`.

        Args:
            spaces (Iterable[Space]): The spaces of which exactly one is selected for each element.
            seed: Optionally, you can use this argument to seed the RNGs of the ``spaces`` to ensure reproducible sampling.
        """
        self.spaces = tuple(spaces)
        # Sampling draws an index in ``[0, len(spaces))``, which is impossible for an empty union.
        assert len(self.spaces) > 0, "Empty `OneOf` spaces are not supported."
        for space in self.spaces:
            assert isinstance(
                space, Space
            ), f"{space} does not inherit from `gymnasium.Space`. Actual Type: {type(space)}"
        super().__init__(None, None, seed)

    @property
    def is_np_flattenable(self):
        """Checks whether this space can be flattened to a :class:`spaces.Box`."""
        return all(space.is_np_flattenable for space in self.spaces)

    def seed(self, seed: int | typing.Sequence[int] | None = None) -> list[int]:
        """Seed the PRNG of this space and all subspaces.

        Depending on the type of seed, the subspaces will be seeded differently

        * ``None`` - All the subspaces will use a random initial seed
        * ``Int`` - The integer is used to seed the :class:`OneOf` space that is used to generate seed values for each of the subspaces. Warning, this does not guarantee unique seeds for all the subspaces.
        * ``List`` - ``len(spaces) + 1`` values: the first seeds the :class:`OneOf` space itself (used to pick the subspace index), the remaining ones seed the subspaces in order, e.g. ``[42, 54, ...]``.

        Args:
            seed: An optional list of ints or int to seed the (sub-)spaces.

        Returns:
            The seeds returned by this space followed by those returned by each subspace.

        Raises:
            TypeError: If ``seed`` is neither a sequence, an int nor ``None``.
        """
        if isinstance(seed, collections.abc.Sequence):
            assert (
                len(seed) == len(self.spaces) + 1
            ), f"Expects the number of seeds to equal the number of subspaces + 1 (the first seed is for the OneOf space itself). Actual length of seeds: {len(seed)}, length of subspaces: {len(self.spaces)}"
            seeds = super().seed(seed[0])
            # seed[0] was consumed by this space; the subspaces get seed[1:].
            for subseed, space in zip(seed[1:], self.spaces):
                seeds += space.seed(subseed)
        elif isinstance(seed, int):
            seeds = super().seed(seed)
            subseeds = self.np_random.integers(
                np.iinfo(np.int32).max, size=len(self.spaces)
            )
            for subspace, subseed in zip(self.spaces, subseeds):
                seeds += subspace.seed(int(subseed))
        elif seed is None:
            seeds = super().seed(None)
            for space in self.spaces:
                seeds += space.seed(None)
        else:
            raise TypeError(
                f"Expected seed type: list, tuple, int or None, actual type: {type(seed)}"
            )

        return seeds

    def sample(self, mask: tuple[Any | None, ...] | None = None) -> tuple[int, Any]:
        """Generates a single random sample inside this space.

        A subspace index is drawn with ``self.np_random.integers`` and a sample is
        then drawn from the subspace at that index.

        Args:
            mask: An optional tuple of optional masks for each of the subspace's samples,
                expects the same number of masks as spaces. Only the mask at the drawn
                index is forwarded to the selected subspace.

        Returns:
            Tuple ``(subspace_idx, subspace_sample)``
        """
        subspace_idx = int(self.np_random.integers(0, len(self.spaces)))
        subspace = self.spaces[subspace_idx]
        if mask is not None:
            assert isinstance(
                mask, tuple
            ), f"Expected type of mask is tuple, actual type: {type(mask)}"
            assert len(mask) == len(
                self.spaces
            ), f"Expected length of mask is {len(self.spaces)}, actual length: {len(mask)}"

            mask = mask[subspace_idx]

        return (subspace_idx, subspace.sample(mask=mask))

    def contains(self, x: tuple[int, Any]) -> bool:
        """Return boolean specifying if x is a valid member of this space.

        ``x`` must be a 2-tuple ``(index, value)`` whose index is an integer in
        ``[0, len(spaces))`` and whose value is contained in the indexed subspace.
        Anything else yields ``False`` rather than raising.
        """
        # Validate the shape of ``x`` before unpacking so malformed input returns False.
        if not isinstance(x, tuple) or len(x) != 2:
            return False

        idx, value = x
        if not isinstance(idx, (int, np.integer)):
            return False
        if not 0 <= idx < len(self.spaces):
            return False

        return bool(self.spaces[idx].contains(value))

    def __repr__(self) -> str:
        """Gives a string representation of this space."""
        return "OneOf(" + ", ".join([str(s) for s in self.spaces]) + ")"

    def to_jsonable(
        self, sample_n: typing.Sequence[tuple[int, Any]]
    ) -> list[list[Any]]:
        """Convert a batch of samples from this space to a JSONable data type."""
        return [
            [int(i), self.spaces[i].to_jsonable([subsample])[0]]
            for (i, subsample) in sample_n
        ]

    def from_jsonable(self, sample_n: list[list[Any]]) -> list[tuple[Any, ...]]:
        """Convert a JSONable data type to a batch of samples from this space."""
        return [
            (space_idx, self.spaces[space_idx].from_jsonable([jsonable_sample])[0])
            for space_idx, jsonable_sample in sample_n
        ]

    def __getitem__(self, index: int) -> Space[Any]:
        """Get the subspace at specific `index`."""
        return self.spaces[index]

    def __len__(self) -> int:
        """Get the number of subspaces that are involved in the exclusive union."""
        return len(self.spaces)

    def __eq__(self, other: Any) -> bool:
        """Check whether ``other`` is equivalent to this instance."""
        return isinstance(other, OneOf) and self.spaces == other.spaces
51 changes: 51 additions & 0 deletions gymnasium/spaces/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
GraphInstance,
MultiBinary,
MultiDiscrete,
OneOf,
Sequence,
Space,
Text,
Expand Down Expand Up @@ -104,6 +105,11 @@ def _flatdim_text(space: Text) -> int:
return space.max_length


@flatdim.register(OneOf)
def _flatdim_oneof(space: OneOf) -> int:
    """Return the flattened size of a ``OneOf`` space.

    The result is one more than the largest ``flatdim`` among the subspaces.
    """
    widest_subspace = max(map(flatdim, space.spaces))
    return 1 + widest_subspace


T = TypeVar("T")
FlatType = Union[
NDArray[Any], typing.Dict[str, Any], typing.Tuple[Any, ...], GraphInstance
Expand Down Expand Up @@ -256,6 +262,22 @@ def _flatten_sequence(
return tuple(flatten(space.feature_space, item) for item in x)


@flatten.register(OneOf)
def _flatten_oneof(space: OneOf, x: tuple[int, Any]) -> NDArray[Any]:
    """Flatten a ``(index, sample)`` element of a ``OneOf`` space.

    The selected subspace's sample is flattened, padded up to the widest
    subspace's flat size, and prefixed with the subspace index.

    Args:
        space: The ``OneOf`` space the element belongs to.
        x: A ``(index, sample)`` pair where ``sample`` belongs to ``space.spaces[index]``.

    Returns:
        A 1-D array of length ``flatdim(space)`` whose first entry is the index.
    """
    idx, sample = x
    sub_space = space.spaces[idx]
    flat_sample = flatten(sub_space, sample)

    max_flatdim = flatdim(space) - 1  # Don't include the index
    pad_len = max_flatdim - flat_sample.size
    if pad_len > 0:
        if np.issubdtype(flat_sample.dtype, np.inexact):
            # Floating dtypes can mark the unused tail with NaN.
            pad_value = np.nan
        elif flat_sample.size > 0:
            # NaN is not representable in integer/bool dtypes; repeat an existing
            # element instead (the tail is discarded again when unflattening).
            pad_value = flat_sample.flat[0]
        else:
            pad_value = 0
        padding = np.full(pad_len, pad_value, dtype=flat_sample.dtype)
        flat_sample = np.concatenate([flat_sample, padding])

    # NOTE(review): ``[idx]`` becomes an integer array, so NumPy type promotion may
    # widen the result dtype beyond ``flat_sample.dtype`` - confirm this is intended.
    return np.concatenate([[idx], flat_sample])


@singledispatch
def unflatten(space: Space[T], x: FlatType) -> T:
"""Unflatten a data point from a space.
Expand Down Expand Up @@ -399,6 +421,17 @@ def _unflatten_sequence(space: Sequence, x: tuple[Any, ...]) -> tuple[Any, ...]
return tuple(unflatten(space.feature_space, item) for item in x)


@unflatten.register(OneOf)
def _unflatten_oneof(space: OneOf, x: NDArray[Any]) -> tuple[int, Any]:
    """Rebuild a ``(index, sample)`` pair from a flattened ``OneOf`` array.

    The first entry of ``x`` is read as the subspace index; the next
    ``flatdim(subspace)`` entries are unflattened with that subspace and any
    remaining entries are ignored.
    """
    chosen_idx = int(x[0])
    chosen_space = space.spaces[chosen_idx]

    payload_end = 1 + flatdim(chosen_space)
    payload = x[1:payload_end]

    return chosen_idx, unflatten(chosen_space, payload)


@singledispatch
def flatten_space(space: Space[Any]) -> Box | Dict | Sequence | Tuple | Graph:
"""Flatten a space into a space that is as flat as possible.
Expand Down Expand Up @@ -525,3 +558,21 @@ def _flatten_space_text(space: Text) -> Box:
@flatten_space.register(Sequence)
def _flatten_space_sequence(space: Sequence) -> Sequence:
    """Flatten a ``Sequence`` space by flattening its feature space.

    The result is a new ``Sequence`` over the flattened feature space with the
    same ``stack`` setting as the input.
    """
    flat_feature_space = flatten_space(space.feature_space)
    return Sequence(flat_feature_space, stack=space.stack)


@flatten_space.register(OneOf)
def _flatten_space_oneof(space: OneOf) -> Box:
    """Flatten a ``OneOf`` space into a single nullable ``Box``.

    The first component of the box holds the subspace index, bounded by
    ``[0, len(space.spaces) - 1]``. The remaining components share one pair of
    bounds: the smallest ``low`` and the largest ``high`` over all flattened
    subspaces. The box is created with ``nullable=True``.
    """
    num_subspaces = len(space.spaces)
    max_flatdim = max(flatdim(s) for s in space.spaces) + 1

    # Flatten every subspace once and reuse the result for both bounds.
    flat_subspaces = [flatten_space(s) for s in space.spaces]
    lows = np.array([np.min(flat.low) for flat in flat_subspaces])
    highs = np.array([np.max(flat.high) for flat in flat_subspaces])

    overall_low = np.min(lows)
    overall_high = np.max(highs)

    low = np.concatenate([[0], np.full(max_flatdim - 1, overall_low)])
    high = np.concatenate([[num_subspaces - 1], np.full(max_flatdim - 1, overall_high)])

    # NOTE(review): the dtype is promoted from the subspaces' own dtypes, not from
    # their flattened boxes - confirm this is right for composite subspaces.
    dtype = np.result_type(*[s.dtype for s in space.spaces if hasattr(s, "dtype")])
    return Box(low=low, high=high, shape=(max_flatdim,), dtype=dtype, nullable=True)
pseudo-rnd-thoughts marked this conversation as resolved.
Show resolved Hide resolved
17 changes: 17 additions & 0 deletions gymnasium/vector/utils/space_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
GraphInstance,
MultiBinary,
MultiDiscrete,
OneOf,
Sequence,
Space,
Text,
Expand Down Expand Up @@ -121,6 +122,11 @@ def _batch_space_dict(space: Dict, n: int = 1):
)


@batch_space.register(OneOf)
def _batch_space_oneof(space: OneOf, n: int = 1):
    """Batch a ``OneOf`` space by wrapping it in a ``Sequence`` space.

    NOTE(review): ``n`` is not used by this implementation - confirm that the
    batch size is intentionally left unconstrained.
    """
    batched_space = Sequence(space)
    return batched_space


@batch_space.register(Graph)
@batch_space.register(Text)
@batch_space.register(Sequence)
Expand Down Expand Up @@ -227,6 +233,11 @@ def _iterate_dict(space: Dict, items: dict[str, Any]):
yield OrderedDict({key: value for key, value in zip(keys, item)})


@iterate.register(Sequence)
def _iterate_sequence(space: Sequence, items: list[Any]):
    """Yield the elements of ``items`` one by one, in order.

    ``space`` is not read in the body; it only selects this implementation
    through the ``Sequence`` registration.
    """
    for item in items:
        yield item


@singledispatch
def concatenate(
space: Space, items: Iterable, out: tuple[Any, ...] | dict[str, Any] | np.ndarray
Expand Down Expand Up @@ -297,6 +308,7 @@ def _concatenate_dict(
@concatenate.register(Text)
@concatenate.register(Sequence)
@concatenate.register(Space)
@concatenate.register(OneOf)
def _concatenate_custom(space: Space, items: Iterable, out: None) -> tuple[Any, ...]:
    """Collect ``items`` into a tuple.

    Neither ``space`` nor ``out`` is read in the body; the items are returned
    as-is in iteration order.
    """
    return (*items,)

Expand Down Expand Up @@ -402,6 +414,11 @@ def _create_empty_array_sequence(
return tuple(tuple() for _ in range(n))


@create_empty_array.register(OneOf)
def _create_empty_array_oneof(space: OneOf, n: int = 1, fn=np.zeros):
    """Return a tuple of ``n`` empty tuples as the placeholder for a ``OneOf`` batch.

    ``space`` and ``fn`` are not used in the body.
    """
    placeholders = [() for _ in range(n)]
    return tuple(placeholders)


@create_empty_array.register(Space)
def _create_empty_array_custom(space, n=1, fn=np.zeros):
    """Fallback for generic ``Space`` instances: no array is created, ``None`` is returned."""
    empty = None
    return empty
Loading
Loading