
Add a OneOf space for exclusive unions #812

Merged (15 commits, Mar 11, 2024)
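For orientation, a minimal usage sketch of the new space (not part of the diff; it mirrors the docstring example in gymnasium/spaces/oneof.py shown further down):

# Example (not part of the diff): basic usage of the new OneOf space.
# A sample is an (index, value) pair identifying which subspace it came from.
from gymnasium.spaces import OneOf, Box, Discrete

space = OneOf((Discrete(2), Box(-1, 1, shape=(2,))), seed=42)
index, value = space.sample()          # e.g. (1, array([-0.3991573, 0.21649833], dtype=float32))
assert space.contains((index, value))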
26 changes: 6 additions & 20 deletions gymnasium/spaces/box.py
@@ -59,7 +59,6 @@ def __init__(
shape: Sequence[int] | None = None,
dtype: type[np.floating[Any]] | type[np.integer[Any]] = np.float32,
seed: int | np.random.Generator | None = None,
nullable: bool = False,
):
r"""Constructor of :class:`Box`.

@@ -145,8 +144,6 @@ def __init__(
self.low_repr = _short_repr(self.low)
self.high_repr = _short_repr(self.high)

self.nullable = nullable

super().__init__(self.shape, self.dtype, seed)

@property
@@ -246,20 +243,12 @@ def contains(self, x: Any) -> bool:
except (ValueError, TypeError):
return False

if not np.can_cast(x.dtype, self.dtype):
return False

if x.shape != self.shape:
return False

bounded_below = x >= self.low
bounded_above = x <= self.high

bounded = bounded_below & bounded_above
if self.nullable:
bounded |= np.isnan(x)

return bool(np.all(bounded))
return bool(
np.can_cast(x.dtype, self.dtype)
and x.shape == self.shape
and np.all(x >= self.low)
and np.all(x <= self.high)
)

def to_jsonable(self, sample_n: Sequence[NDArray[Any]]) -> list[list]:
"""Convert a batch of samples from this space to a JSONable data type."""
@@ -301,9 +290,6 @@ def __setstate__(self, state: Iterable[tuple[str, Any]] | Mapping[str, Any]):
if not hasattr(self, "high_repr"):
self.high_repr = _short_repr(self.high)

if not hasattr(self, "nullable"):
self.nullable = False


def get_precision(dtype: np.dtype) -> SupportsFloat:
"""Get precision of a data type."""
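With the nullable escape hatch removed, NaN entries simply fail both bound comparisons, so Box.contains rejects them; a minimal sketch (not part of the diff):

# Example (not part of the diff): NaN never satisfies x >= low or x <= high,
# so the simplified contains rejects it without needing a nullable flag.
import numpy as np
from gymnasium.spaces import Box

box = Box(-1.0, 1.0, shape=(2,), dtype=np.float32)
assert box.contains(np.array([0.5, -0.5], dtype=np.float32))
assert not box.contains(np.array([np.nan, 0.0], dtype=np.float32))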
6 changes: 3 additions & 3 deletions gymnasium/spaces/oneof.py
@@ -19,7 +19,7 @@ class OneOf(Space[Any]):
>>> from gymnasium.spaces import OneOf, Box, Discrete
>>> observation_space = OneOf((Discrete(2), Box(-1, 1, shape=(2,))), seed=42)
>>> observation_space.sample()
(0, array([-0.3991573 , 0.21649833], dtype=float32))
(1, array([-0.3991573 , 0.21649833], dtype=float32))
"""

def __init__(
@@ -84,7 +84,7 @@ def seed(self, seed: int | typing.Sequence[int] | None = None) -> list[int]:

return seeds

def sample(self, mask: tuple[Any | None, ...] | None = None) -> tuple[Any, ...]:
def sample(self, mask: tuple[Any | None, ...] | None = None) -> tuple[int, Any]:
"""Generates a single random sample inside this space.

This method draws independent samples from the subspaces.
@@ -108,7 +108,7 @@ def sample(self, mask: tuple[Any | None, ...] | None = None) -> tuple[Any, ...]:

mask = mask[subspace_idx]

return (subspace_idx, subspace.sample(mask=mask))
return subspace_idx, subspace.sample(mask=mask)

def contains(self, x: tuple[int, Any]) -> bool:
"""Return boolean specifying if x is a valid member of this space."""
4 changes: 2 additions & 2 deletions gymnasium/spaces/utils.py
@@ -271,7 +271,7 @@ def _flatten_oneof(space: OneOf, x: tuple[int, Any]) -> NDArray[Any]:
max_flatdim = flatdim(space) - 1 # Don't include the index
if flat_sample.size < max_flatdim:
padding = np.full(
max_flatdim - flat_sample.size, np.nan, dtype=flat_sample.dtype
max_flatdim - flat_sample.size, flat_sample[0], dtype=flat_sample.dtype
[Review comment from the PR author]
I'm now wondering if it might be better to use np.empty instead, since we're not using a dedicated placeholder value. Either way, the array will be filled with some throwaway data, but with empty we're not pretending that this data has any actual meaning.

[Review comment]
I ran a quick check and using np.empty fails test_flat_space_contains_flat_points, as the flattened version might not be contained within the flattened space. But this is just a weird artifact of flatten rather than anything actually incorrect with the approach.

)
flat_sample = np.concatenate([flat_sample, padding])

@@ -575,4 +575,4 @@ def _flatten_space_oneof(space: OneOf) -> Box:
high = np.concatenate([[num_subspaces - 1], np.full(max_flatdim - 1, overall_high)])

dtype = np.result_type(*[s.dtype for s in space.spaces if hasattr(s, "dtype")])
return Box(low=low, high=high, shape=(max_flatdim,), dtype=dtype, nullable=True)
return Box(low=low, high=high, shape=(max_flatdim,), dtype=dtype)
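To make the flattened layout concrete, here is a sketch (not part of the diff), assuming the flatdim registration for OneOf added elsewhere in this PR: the subspace index occupies the first slot, followed by the flattened sample, padded with its own first element up to the largest subspace flatdim.

# Example (not part of the diff): layout of a flattened OneOf sample.
from gymnasium.spaces import OneOf, Box, Discrete
from gymnasium.spaces.utils import flatdim, flatten, flatten_space

space = OneOf((Discrete(2), Box(-1, 1, shape=(2,))), seed=7)
flat_space = flatten_space(space)      # a Box of shape (1 + max subspace flatdim,)
index, value = space.sample()
flat = flatten(space, (index, value))  # [index, flattened value, padding ...]
assert flat.shape == (flatdim(space),)
assert flat[0] == index
assert flat_space.contains(flat)       # the property discussed in the review comments above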
36 changes: 35 additions & 1 deletion gymnasium/vector/utils/shared_memory.py
@@ -17,6 +17,7 @@
Graph,
MultiBinary,
MultiDiscrete,
OneOf,
Sequence,
Space,
Text,
@@ -93,6 +94,11 @@ def _create_text_shared_memory(space: Text, n: int = 1, ctx=mp):
return ctx.Array(np.dtype(np.int32).char, n * space.max_length)


@create_shared_memory.register(OneOf)
def _create_oneof_shared_memory(space: OneOf, n: int = 1, ctx=mp):
return (ctx.Array(np.int32, n),) + _create_tuple_shared_memory(space)


@create_shared_memory.register(Graph)
@create_shared_memory.register(Sequence)
def _create_dynamic_shared_memory(space: Graph | Sequence, n: int = 1, ctx=mp):
@@ -170,7 +176,9 @@ def _read_dict_from_shared_memory(space: Dict, shared_memory, n: int = 1):


@read_from_shared_memory.register(Text)
def _read_text_from_shared_memory(space: Text, shared_memory, n: int = 1) -> tuple[str]:
def _read_text_from_shared_memory(
space: Text, shared_memory, n: int = 1
) -> tuple[str, ...]:
data = np.frombuffer(shared_memory.get_obj(), dtype=np.int32).reshape(
(n, space.max_length)
)
@@ -187,6 +195,21 @@ def _read_text_from_shared_memory(space: Text, shared_memory, n: int = 1) -> tup
)


@read_from_shared_memory.register(OneOf)
def _read_one_of_from_shared_memory(
space: OneOf, shared_memory, n: int = 1
) -> tuple[Any, ...]:
sample_indexes = np.frombuffer(shared_memory[0].get_obj(), dtype=space.dtype)
subspace_samples = tuple(
read_from_shared_memory(subspace, memory, n=n)
for (memory, subspace) in zip(shared_memory[1:], space.spaces)
)
return tuple(
(index, sample[index])
for index, sample in zip(sample_indexes, subspace_samples)
)


@singledispatch
def write_to_shared_memory(
space: Space,
@@ -258,3 +281,14 @@ def _write_text_to_shared_memory(space: Text, index: int, values: str, shared_me
destination[index * size : (index + 1) * size],
flatten(space, values),
)


@write_to_shared_memory.register(OneOf)
def _write_oneof_to_shared_memory(
space: OneOf, index: int, values: tuple[Any, ...], shared_memory
):
destination = np.frombuffer(shared_memory[0].get_obj(), dtype=np.int32)
np.copyto(destination[index : index + 1], values[0])

for value, memory, subspace in zip(values[1], shared_memory[1:], space.spaces):
write_to_shared_memory(subspace, index, value, memory)
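For context, a sketch (not part of the diff) of the shared-memory round trip these handlers plug into. A Tuple space is used here because, per _create_oneof_shared_memory above, the OneOf layout is simply an int32 index Array prepended to the Tuple layout.

# Example (not part of the diff): the generic create/write/read round trip
# that the OneOf registrations above extend.
from gymnasium.spaces import Box, Discrete, Tuple
from gymnasium.vector.utils import (
    create_shared_memory,
    read_from_shared_memory,
    write_to_shared_memory,
)

space = Tuple((Discrete(2), Box(-1, 1, shape=(2,))), seed=0)
n = 3
shm = create_shared_memory(space, n=n)
for i in range(n):
    write_to_shared_memory(space, i, space.sample(), shm)
batch = read_from_shared_memory(space, shm, n=n)   # one array per subspace, first axis of size n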
15 changes: 3 additions & 12 deletions gymnasium/vector/utils/space_utils.py
@@ -122,16 +122,12 @@ def _batch_space_dict(space: Dict, n: int = 1):
)


@batch_space.register(OneOf)
def _batch_space_oneof(space: OneOf, n: int = 1):
return Sequence(space)


@batch_space.register(Graph)
@batch_space.register(Text)
@batch_space.register(Sequence)
@batch_space.register(OneOf)
@batch_space.register(Space)
def _batch_space_custom(space: Graph | Text | Sequence, n: int = 1):
def _batch_space_custom(space: Graph | Text | Sequence | OneOf, n: int = 1):
# Without deepcopy, then the space.np_random is batched_space.spaces[0].np_random
# Which is an issue if you are sampling actions of both the original space and the batched space
batched_space = Tuple(
@@ -233,11 +229,6 @@ def _iterate_dict(space: Dict, items: dict[str, Any]):
yield OrderedDict({key: value for key, value in zip(keys, item)})


@iterate.register(Sequence)
def _iterate_sequence(space: Sequence, items: list[Any]):
yield from items


@singledispatch
def concatenate(
space: Space, items: Iterable, out: tuple[Any, ...] | dict[str, Any] | np.ndarray
@@ -348,7 +339,7 @@ def create_empty_array(
)


# It is possible for the some of the Box low to be greater than 0, then array is not in space
# It is possible for some of the Box low to be greater than 0, then array is not in space
@create_empty_array.register(Box)
# If the Discrete start > 0 or start + length < 0 then array is not in space
@create_empty_array.register(Discrete)
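With OneOf moved onto the generic handler, batching now yields a Tuple of independent copies of the space rather than a Sequence; a minimal sketch (not part of the diff):

# Example (not part of the diff): batching a OneOf space.
from gymnasium.spaces import Box, Discrete, OneOf, Tuple
from gymnasium.vector.utils import batch_space

space = OneOf((Discrete(2), Box(-1, 1, shape=(2,))), seed=123)
batched = batch_space(space, n=3)
assert isinstance(batched, Tuple)      # three independent copies of the OneOf space
batch_sample = batched.sample()        # tuple of three (index, value) pairs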
6 changes: 6 additions & 0 deletions gymnasium/wrappers/utils.py
@@ -14,6 +14,7 @@
GraphInstance,
MultiBinary,
MultiDiscrete,
OneOf,
Sequence,
Text,
Tuple,
@@ -145,3 +146,8 @@ def _create_graph_zero_array(space: Graph):
edges = np.expand_dims(create_zero_array(space.edge_space), axis=0)
edge_links = np.zeros((1, 2), dtype=np.int64)
return GraphInstance(nodes=nodes, edges=edges, edge_links=edge_links)


@create_zero_array.register(OneOf)
def _create_one_of_zero_array(space: OneOf):
return 0, create_zero_array(space.spaces[0])
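Finally, a sketch (not part of the diff) of the new zero-array handler: the zero element of a OneOf space is index 0 paired with the zero element of its first subspace.

# Example (not part of the diff): zero element of a OneOf space.
from gymnasium.spaces import Box, Discrete, OneOf
from gymnasium.wrappers.utils import create_zero_array

space = OneOf((Discrete(2), Box(-1, 1, shape=(2,))))
zero = create_zero_array(space)
assert zero[0] == 0                    # always tags the first subspace
# zero[1] equals create_zero_array(space.spaces[0]), i.e. the zero element of Discrete(2)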