From 35f1805e42635b518a24ebb72adf90b787f0139d Mon Sep 17 00:00:00 2001 From: ariel Date: Mon, 4 Dec 2023 23:42:49 +0100 Subject: [PATCH 01/13] Add a OneOf space for exclusive unions --- gymnasium/spaces/__init__.py | 2 + gymnasium/spaces/box.py | 6 ++ gymnasium/spaces/oneof.py | 151 +++++++++++++++++++++++++++++++++++ gymnasium/spaces/utils.py | 51 ++++++++++++ tests/spaces/test_oneof.py | 72 +++++++++++++++++ tests/spaces/test_utils.py | 8 +- tests/spaces/utils.py | 4 + 7 files changed, 293 insertions(+), 1 deletion(-) create mode 100644 gymnasium/spaces/oneof.py create mode 100644 tests/spaces/test_oneof.py diff --git a/gymnasium/spaces/__init__.py b/gymnasium/spaces/__init__.py index f1d726f2d..0d07ee7e3 100644 --- a/gymnasium/spaces/__init__.py +++ b/gymnasium/spaces/__init__.py @@ -16,6 +16,7 @@ from gymnasium.spaces.graph import Graph, GraphInstance from gymnasium.spaces.multi_binary import MultiBinary from gymnasium.spaces.multi_discrete import MultiDiscrete +from gymnasium.spaces.oneof import OneOf from gymnasium.spaces.sequence import Sequence from gymnasium.spaces.space import Space from gymnasium.spaces.text import Text @@ -38,6 +39,7 @@ "Tuple", "Sequence", "Dict", + "OneOf", # util functions (there are more utility functions in vector/utils/spaces.py) "flatdim", "flatten_space", diff --git a/gymnasium/spaces/box.py b/gymnasium/spaces/box.py index 418cb7158..36206c512 100644 --- a/gymnasium/spaces/box.py +++ b/gymnasium/spaces/box.py @@ -59,6 +59,7 @@ def __init__( shape: Sequence[int] | None = None, dtype: type[np.floating[Any]] | type[np.integer[Any]] = np.float32, seed: int | np.random.Generator | None = None, + nullable: bool = False, ): r"""Constructor of :class:`Box`. @@ -144,6 +145,8 @@ def __init__( self.low_repr = _short_repr(self.low) self.high_repr = _short_repr(self.high) + self.nullable = nullable + super().__init__(self.shape, self.dtype, seed) @property @@ -290,6 +293,9 @@ def __setstate__(self, state: Iterable[tuple[str, Any]] | Mapping[str, Any]): if not hasattr(self, "high_repr"): self.high_repr = _short_repr(self.high) + if not hasattr(self, "nullable"): + self.nullable = False + def get_precision(dtype: np.dtype) -> SupportsFloat: """Get precision of a data type.""" diff --git a/gymnasium/spaces/oneof.py b/gymnasium/spaces/oneof.py new file mode 100644 index 000000000..402fba22f --- /dev/null +++ b/gymnasium/spaces/oneof.py @@ -0,0 +1,151 @@ +"""Implementation of a space that represents the cartesian product of other spaces.""" +from __future__ import annotations + +import collections.abc +import typing +from typing import Any, Iterable + +import numpy as np + +from gymnasium.spaces.space import Space + + +class OneOf(Space[Any]): + """An exclusive tuple (more precisely: the direct sum) of :class:`Space` instances. + + Elements of this space are elements of one of the constituent spaces. + + Example: + >>> from gymnasium.spaces import OneOf, Box, Discrete + >>> observation_space = OneOf((Discrete(2), Box(-1, 1, shape=(2,))), seed=42) + >>> observation_space.sample() + (0, array([-0.3991573 , 0.21649833], dtype=float32)) + """ + + def __init__( + self, + spaces: Iterable[Space[Any]], + seed: int | typing.Sequence[int] | np.random.Generator | None = None, + ): + r"""Constructor of :class:`Tuple` space. + + The generated instance will represent the cartesian product :math:`\text{spaces}[0] \times ... \times \text{spaces}[-1]`. + + Args: + spaces (Iterable[Space]): The spaces that are involved in the cartesian product. + seed: Optionally, you can use this argument to seed the RNGs of the ``spaces`` to ensure reproducible sampling. + """ + self.spaces = tuple(spaces) + for space in self.spaces: + assert isinstance( + space, Space + ), f"{space} does not inherit from `gymnasium.Space`. Actual Type: {type(space)}" + super().__init__(None, None, seed) + + @property + def is_np_flattenable(self): + """Checks whether this space can be flattened to a :class:`spaces.Box`.""" + return all(space.is_np_flattenable for space in self.spaces) + + def seed(self, seed: int | typing.Sequence[int] | None = None) -> list[int]: + """Seed the PRNG of this space and all subspaces. + + Depending on the type of seed, the subspaces will be seeded differently + + * ``None`` - All the subspaces will use a random initial seed + * ``Int`` - The integer is used to seed the :class:`Tuple` space that is used to generate seed values for each of the subspaces. Warning, this does not guarantee unique seeds for all the subspaces. + * ``List`` - Values used to seed the subspaces. This allows the seeding of multiple composite subspaces ``[42, 54, ...]``. + + Args: + seed: An optional list of ints or int to seed the (sub-)spaces. + """ + if isinstance(seed, collections.abc.Sequence): + assert ( + len(seed) == len(self.spaces) + 1 + ), f"Expects that the subspaces of seeds equals the number of subspaces. Actual length of seeds: {len(seed)}, length of subspaces: {len(self.spaces)}" + seeds = super().seed(seed[0]) + for subseed, space in zip(seed, self.spaces): + seeds += space.seed(subseed) + elif isinstance(seed, int): + seeds = super().seed(seed) + subseeds = self.np_random.integers( + np.iinfo(np.int32).max, size=len(self.spaces) + ) + for subspace, subseed in zip(self.spaces, subseeds): + seeds += subspace.seed(int(subseed)) + elif seed is None: + seeds = super().seed(None) + for space in self.spaces: + seeds += space.seed(None) + else: + raise TypeError( + f"Expected seed type: list, tuple, int or None, actual type: {type(seed)}" + ) + + return seeds + + def sample(self, mask: tuple[Any | None, ...] | None = None) -> tuple[Any, ...]: + """Generates a single random sample inside this space. + + This method draws independent samples from the subspaces. + + Args: + mask: An optional tuple of optional masks for each of the subspace's samples, + expects the same number of masks as spaces + + Returns: + Tuple of the subspace's samples + """ + subspace_idx = self.np_random.integers(0, len(self.spaces)) + subspace = self.spaces[subspace_idx] + if mask is not None: + assert isinstance( + mask, tuple + ), f"Expected type of mask is tuple, actual type: {type(mask)}" + assert len(mask) == len( + self.spaces + ), f"Expected length of mask is {len(self.spaces)}, actual length: {len(mask)}" + + mask = mask[subspace_idx] + + return (subspace_idx, subspace.sample(mask=mask)) + + def contains(self, x: tuple[int, Any]) -> bool: + """Return boolean specifying if x is a valid member of this space.""" + (idx, value) = x + + return isinstance(x, tuple) and self.spaces[idx].contains(value) + + def __repr__(self) -> str: + """Gives a string representation of this space.""" + return "OneOf(" + ", ".join([str(s) for s in self.spaces]) + ")" + + def to_jsonable( + self, sample_n: typing.Sequence[tuple[int, Any]] + ) -> list[list[Any]]: + """Convert a batch of samples from this space to a JSONable data type.""" + # serialize as list-repr of tuple of vectors + return [ + [i, space.to_jsonable([sample[i] for sample in sample_n])] + for i, space in enumerate(self.spaces) + ] + + def from_jsonable(self, sample_n: list[list[Any]]) -> list[tuple[Any, ...]]: + """Convert a JSONable data type to a batch of samples from this space.""" + return [ + (space_idx, sample) + for space_idx, jsonable_samples in sample_n + for sample in self.spaces[space_idx].from_jsonable(jsonable_samples) + ] + + def __getitem__(self, index: int) -> Space[Any]: + """Get the subspace at specific `index`.""" + return self.spaces[index] + + def __len__(self) -> int: + """Get the number of subspaces that are involved in the cartesian product.""" + return len(self.spaces) + + def __eq__(self, other: Any) -> bool: + """Check whether ``other`` is equivalent to this instance.""" + return isinstance(other, OneOf) and self.spaces == other.spaces diff --git a/gymnasium/spaces/utils.py b/gymnasium/spaces/utils.py index 31fef9df7..9e270d980 100644 --- a/gymnasium/spaces/utils.py +++ b/gymnasium/spaces/utils.py @@ -23,6 +23,7 @@ GraphInstance, MultiBinary, MultiDiscrete, + OneOf, Sequence, Space, Text, @@ -104,6 +105,11 @@ def _flatdim_text(space: Text) -> int: return space.max_length +@flatdim.register(OneOf) +def _flatdim_oneof(space: OneOf) -> int: + return 1 + max(flatdim(s) for s in space.spaces) + + T = TypeVar("T") FlatType = Union[ NDArray[Any], typing.Dict[str, Any], typing.Tuple[Any, ...], GraphInstance @@ -256,6 +262,22 @@ def _flatten_sequence( return tuple(flatten(space.feature_space, item) for item in x) +@flatten.register(OneOf) +def _flatten_oneof(space: OneOf, x: tuple[int, Any]) -> NDArray[Any]: + idx, sample = x + sub_space = space.spaces[idx] + flat_sample = flatten(sub_space, sample) + + max_flatdim = flatdim(space) - 1 # Don't include the index + if flat_sample.size < max_flatdim: + padding = np.full( + max_flatdim - flat_sample.size, np.nan, dtype=flat_sample.dtype + ) + flat_sample = np.concatenate([flat_sample, padding]) + + return np.concatenate([[idx], flat_sample]) + + @singledispatch def unflatten(space: Space[T], x: FlatType) -> T: """Unflatten a data point from a space. @@ -399,6 +421,17 @@ def _unflatten_sequence(space: Sequence, x: tuple[Any, ...]) -> tuple[Any, ...] return tuple(unflatten(space.feature_space, item) for item in x) +@unflatten.register(OneOf) +def _unflatten_oneof(space: OneOf, x: NDArray[Any]) -> tuple[int, Any]: + idx = int(x[0]) + sub_space = space.spaces[idx] + + original_size = flatdim(sub_space) + trimmed_sample = x[1 : 1 + original_size] + + return idx, unflatten(sub_space, trimmed_sample) + + @singledispatch def flatten_space(space: Space[Any]) -> Box | Dict | Sequence | Tuple | Graph: """Flatten a space into a space that is as flat as possible. @@ -525,3 +558,21 @@ def _flatten_space_text(space: Text) -> Box: @flatten_space.register(Sequence) def _flatten_space_sequence(space: Sequence) -> Sequence: return Sequence(flatten_space(space.feature_space), stack=space.stack) + + +@flatten_space.register(OneOf) +def _flatten_space_oneof(space: OneOf) -> Box: + num_subspaces = len(space.spaces) + max_flatdim = max(flatdim(s) for s in space.spaces) + 1 + + lows = np.array([np.min(flatten_space(s).low) for s in space.spaces]) + highs = np.array([np.max(flatten_space(s).high) for s in space.spaces]) + + overall_low = np.min(lows) + overall_high = np.max(highs) + + low = np.concatenate([[0], np.full(max_flatdim - 1, overall_low)]) + high = np.concatenate([[num_subspaces - 1], np.full(max_flatdim - 1, overall_high)]) + + dtype = np.result_type(*[s.dtype for s in space.spaces if hasattr(s, "dtype")]) + return Box(low=low, high=high, shape=(max_flatdim,), dtype=dtype) diff --git a/tests/spaces/test_oneof.py b/tests/spaces/test_oneof.py new file mode 100644 index 000000000..61088768f --- /dev/null +++ b/tests/spaces/test_oneof.py @@ -0,0 +1,72 @@ +import numpy as np +import pytest + +from gymnasium.spaces import Box, Discrete, MultiBinary, OneOf +from gymnasium.utils.env_checker import data_equivalence + + +def test_oneof_inheritance(): + """Tests that OneOf space properly inherits and implements required methods.""" + spaces = [Discrete(5), Box(-1, 1, shape=(3,)), MultiBinary(2)] + oneof_space = OneOf(spaces) + + assert len(oneof_space) == len(spaces) + # Test indexing + for i in range(len(oneof_space)): + assert oneof_space[i] == spaces[i] + + # Test iterable + for space in oneof_space: + assert space in spaces + + +@pytest.mark.parametrize( + "spaces, seed, expected_len", + [ + ([Discrete(5), Box(-1, 1, shape=(3,))], None, 3), + ([Discrete(5), Box(-1, 1, shape=(3,))], 123, 3), + ([Discrete(5), Box(-1, 1, shape=(3,))], [123, 456, 789], 3), + ], +) +def test_oneof_seeds(spaces, seed, expected_len): + oneof_space = OneOf(spaces) + seeds = oneof_space.seed(seed) + assert isinstance(seeds, list) and all(isinstance(elem, int) for elem in seeds) + assert len(seeds) == expected_len + + sample1 = oneof_space.sample() + + seeds2 = oneof_space.seed(seed) + sample2 = oneof_space.sample() + + data_equivalence(seeds, seeds2) + data_equivalence(sample1, sample2) + + +@pytest.mark.parametrize( + "spaces_fn", + [ + lambda: OneOf(["abc"]), + lambda: OneOf([Box(0, 1), "abc"]), + lambda: OneOf("abc"), + ], +) +def test_bad_oneof_calls(spaces_fn): + with pytest.raises(AssertionError): + spaces_fn() + + +def test_oneof_contains(): + space = OneOf([Box(0, 1), Box(-1, 0, (2,))]) + + assert (0, np.array([0.5], dtype=np.float32)) in space + assert (1, np.array([-0.5, -0.5], dtype=np.float32)) in space + + +def test_bad_oneof_seed(): + space = OneOf([Box(0, 1), Box(0, 1)]) + with pytest.raises( + TypeError, + match="Expected seed type: list, tuple, int or None, actual type: ", + ): + space.seed(0.0) diff --git a/tests/spaces/test_utils.py b/tests/spaces/test_utils.py index ed3d7d32f..5e4698876 100644 --- a/tests/spaces/test_utils.py +++ b/tests/spaces/test_utils.py @@ -54,6 +54,11 @@ None, None, None, + None, + None, + # OneOf + 4, + 5, ] @@ -106,7 +111,8 @@ def test_flatten_space(space): @pytest.mark.parametrize("space", TESTING_SPACES, ids=TESTING_SPACES_IDS) def test_flatten(space): """Test that a flattened sample have the `flatdim` shape.""" - flattened_sample = utils.flatten(space, space.sample()) + sample = space.sample() + flattened_sample = utils.flatten(space, sample) if space.is_np_flattenable: assert isinstance(flattened_sample, np.ndarray) diff --git a/tests/spaces/utils.py b/tests/spaces/utils.py index b8f8ba1c5..05f6a12db 100644 --- a/tests/spaces/utils.py +++ b/tests/spaces/utils.py @@ -9,6 +9,7 @@ Graph, MultiBinary, MultiDiscrete, + OneOf, Sequence, Space, Text, @@ -108,6 +109,9 @@ Sequence(Graph(node_space=Box(-100, 100, shape=(2, 2)), edge_space=Discrete(4))), Sequence(Box(low=0.0, high=1.0), stack=True), Sequence(Dict({"a": Box(0, 1, (3,)), "b": Discrete(5)}), stack=True), + # OneOf spaces + OneOf([Discrete(3), Box(low=0.0, high=1.0)]), + OneOf([MultiBinary(2), MultiDiscrete([2, 2])]), ] TESTING_COMPOSITE_SPACES_IDS = [f"{space}" for space in TESTING_COMPOSITE_SPACES] From 27bf12679a0a2e0ada392b183ab797d5c6650c21 Mon Sep 17 00:00:00 2001 From: ariel Date: Tue, 5 Dec 2023 00:36:51 +0100 Subject: [PATCH 02/13] Add proper handling of nullable boxes --- gymnasium/spaces/box.py | 10 ++++++++-- gymnasium/spaces/utils.py | 2 +- 2 files changed, 9 insertions(+), 3 deletions(-) diff --git a/gymnasium/spaces/box.py b/gymnasium/spaces/box.py index 36206c512..3fc6be7ec 100644 --- a/gymnasium/spaces/box.py +++ b/gymnasium/spaces/box.py @@ -246,11 +246,17 @@ def contains(self, x: Any) -> bool: except (ValueError, TypeError): return False + bounded_below = x >= self.low + bounded_above = x <= self.high + + bounded = bounded_below & bounded_above + if self.nullable: + bounded |= np.isnan(x) + return bool( np.can_cast(x.dtype, self.dtype) and x.shape == self.shape - and np.all(x >= self.low) - and np.all(x <= self.high) + and np.all(bounded) ) def to_jsonable(self, sample_n: Sequence[NDArray[Any]]) -> list[list]: diff --git a/gymnasium/spaces/utils.py b/gymnasium/spaces/utils.py index 9e270d980..3df8db147 100644 --- a/gymnasium/spaces/utils.py +++ b/gymnasium/spaces/utils.py @@ -575,4 +575,4 @@ def _flatten_space_oneof(space: OneOf) -> Box: high = np.concatenate([[num_subspaces - 1], np.full(max_flatdim - 1, overall_high)]) dtype = np.result_type(*[s.dtype for s in space.spaces if hasattr(s, "dtype")]) - return Box(low=low, high=high, shape=(max_flatdim,), dtype=dtype) + return Box(low=low, high=high, shape=(max_flatdim,), dtype=dtype, nullable=True) From 785860f048306e8312893b32fcce43ae8d44fbc3 Mon Sep 17 00:00:00 2001 From: ariel Date: Tue, 5 Dec 2023 00:56:11 +0100 Subject: [PATCH 03/13] Fix tests --- gymnasium/spaces/oneof.py | 12 +++++------- tests/spaces/test_spaces.py | 1 + 2 files changed, 6 insertions(+), 7 deletions(-) diff --git a/gymnasium/spaces/oneof.py b/gymnasium/spaces/oneof.py index 402fba22f..7709e7115 100644 --- a/gymnasium/spaces/oneof.py +++ b/gymnasium/spaces/oneof.py @@ -96,7 +96,7 @@ def sample(self, mask: tuple[Any | None, ...] | None = None) -> tuple[Any, ...]: Returns: Tuple of the subspace's samples """ - subspace_idx = self.np_random.integers(0, len(self.spaces)) + subspace_idx = int(self.np_random.integers(0, len(self.spaces))) subspace = self.spaces[subspace_idx] if mask is not None: assert isinstance( @@ -124,18 +124,16 @@ def to_jsonable( self, sample_n: typing.Sequence[tuple[int, Any]] ) -> list[list[Any]]: """Convert a batch of samples from this space to a JSONable data type.""" - # serialize as list-repr of tuple of vectors return [ - [i, space.to_jsonable([sample[i] for sample in sample_n])] - for i, space in enumerate(self.spaces) + [int(i), self.spaces[i].to_jsonable([subsample])[0]] + for (i, subsample) in sample_n ] def from_jsonable(self, sample_n: list[list[Any]]) -> list[tuple[Any, ...]]: """Convert a JSONable data type to a batch of samples from this space.""" return [ - (space_idx, sample) - for space_idx, jsonable_samples in sample_n - for sample in self.spaces[space_idx].from_jsonable(jsonable_samples) + (space_idx, self.spaces[space_idx].from_jsonable([jsonable_sample])[0]) + for space_idx, jsonable_sample in sample_n ] def __getitem__(self, index: int) -> Space[Any]: diff --git a/tests/spaces/test_spaces.py b/tests/spaces/test_spaces.py index 82f3a84f9..139edc9a4 100644 --- a/tests/spaces/test_spaces.py +++ b/tests/spaces/test_spaces.py @@ -534,6 +534,7 @@ def test_seed_reproducibility(space): {"spaces": {"a": Discrete(3), "b": Discrete(2)}}, # Dict {"node_space": Discrete(4), "edge_space": Discrete(3)}, # Graph {"space": Discrete(4)}, # Sequence + {"spaces": (Discrete(3), Discrete(5))}, # OneOf ] assert len(SPACE_CLS) == len(SPACE_KWARGS) From cf8baed6cd0957a4d2aa6c9586971609843acc49 Mon Sep 17 00:00:00 2001 From: ariel Date: Tue, 5 Dec 2023 01:19:16 +0100 Subject: [PATCH 04/13] Unbreak contains for box --- gymnasium/spaces/box.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/gymnasium/spaces/box.py b/gymnasium/spaces/box.py index 3fc6be7ec..36d8c9824 100644 --- a/gymnasium/spaces/box.py +++ b/gymnasium/spaces/box.py @@ -246,6 +246,12 @@ def contains(self, x: Any) -> bool: except (ValueError, TypeError): return False + if not np.can_cast(x.dtype, self.dtype): + return False + + if x.shape != self.shape: + return False + bounded_below = x >= self.low bounded_above = x <= self.high @@ -253,11 +259,7 @@ def contains(self, x: Any) -> bool: if self.nullable: bounded |= np.isnan(x) - return bool( - np.can_cast(x.dtype, self.dtype) - and x.shape == self.shape - and np.all(bounded) - ) + return np.all(bounded) def to_jsonable(self, sample_n: Sequence[NDArray[Any]]) -> list[list]: """Convert a batch of samples from this space to a JSONable data type.""" From f753f1a873dfdb8b749622431f358919c5767bf0 Mon Sep 17 00:00:00 2001 From: ariel Date: Tue, 5 Dec 2023 01:31:51 +0100 Subject: [PATCH 05/13] Add support for (naive) batching --- gymnasium/vector/utils/space_utils.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/gymnasium/vector/utils/space_utils.py b/gymnasium/vector/utils/space_utils.py index 3d103ecdf..8d38f5704 100644 --- a/gymnasium/vector/utils/space_utils.py +++ b/gymnasium/vector/utils/space_utils.py @@ -23,6 +23,7 @@ GraphInstance, MultiBinary, MultiDiscrete, + OneOf, Sequence, Space, Text, @@ -121,6 +122,11 @@ def _batch_space_dict(space: Dict, n: int = 1): ) +@batch_space.register(OneOf) +def _batch_space_oneof(space: OneOf, n: int = 1): + return Sequence(space) + + @batch_space.register(Graph) @batch_space.register(Text) @batch_space.register(Sequence) @@ -227,6 +233,11 @@ def _iterate_dict(space: Dict, items: dict[str, Any]): yield OrderedDict({key: value for key, value in zip(keys, item)}) +@iterate.register(Sequence) +def _iterate_sequence(space: Sequence, items: list[Any]): + yield from items + + @singledispatch def concatenate( space: Space, items: Iterable, out: tuple[Any, ...] | dict[str, Any] | np.ndarray From 31061a530631f1635b19d2caebbaa0b744007c0e Mon Sep 17 00:00:00 2001 From: ariel Date: Tue, 5 Dec 2023 01:37:45 +0100 Subject: [PATCH 06/13] Fix tests again --- gymnasium/spaces/box.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/gymnasium/spaces/box.py b/gymnasium/spaces/box.py index 36d8c9824..e22a9e30d 100644 --- a/gymnasium/spaces/box.py +++ b/gymnasium/spaces/box.py @@ -259,7 +259,7 @@ def contains(self, x: Any) -> bool: if self.nullable: bounded |= np.isnan(x) - return np.all(bounded) + return bool(np.all(bounded)) def to_jsonable(self, sample_n: Sequence[NDArray[Any]]) -> list[list]: """Convert a batch of samples from this space to a JSONable data type.""" From 60eb634dd7042f2d6841ae9e9682171186b05904 Mon Sep 17 00:00:00 2001 From: ariel Date: Tue, 5 Dec 2023 01:42:09 +0100 Subject: [PATCH 07/13] Add an attempt at concatenate create_empty_array for OneOf? --- gymnasium/vector/utils/space_utils.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/gymnasium/vector/utils/space_utils.py b/gymnasium/vector/utils/space_utils.py index 8d38f5704..ae317a05a 100644 --- a/gymnasium/vector/utils/space_utils.py +++ b/gymnasium/vector/utils/space_utils.py @@ -308,6 +308,7 @@ def _concatenate_dict( @concatenate.register(Text) @concatenate.register(Sequence) @concatenate.register(Space) +@concatenate.register(OneOf) def _concatenate_custom(space: Space, items: Iterable, out: None) -> tuple[Any, ...]: return tuple(items) @@ -413,6 +414,11 @@ def _create_empty_array_sequence( return tuple(tuple() for _ in range(n)) +@create_empty_array.register(OneOf) +def _create_empty_array_oneof(space: OneOf, n: int = 1, fn=np.zeros): + return tuple(tuple() for _ in range(n)) + + @create_empty_array.register(Space) def _create_empty_array_custom(space, n=1, fn=np.zeros): return None From 9fcbbd5d71551c33baf95c56d067551ac44fb481 Mon Sep 17 00:00:00 2001 From: pseudo-rnd-thoughts Date: Mon, 29 Jan 2024 17:23:09 +0000 Subject: [PATCH 08/13] proposed changes --- gymnasium/spaces/box.py | 26 ++++++-------------------- gymnasium/spaces/oneof.py | 4 ++-- gymnasium/spaces/utils.py | 4 ++-- gymnasium/vector/utils/space_utils.py | 16 +++------------- gymnasium/wrappers/utils.py | 7 ++++++- 5 files changed, 19 insertions(+), 38 deletions(-) diff --git a/gymnasium/spaces/box.py b/gymnasium/spaces/box.py index e22a9e30d..418cb7158 100644 --- a/gymnasium/spaces/box.py +++ b/gymnasium/spaces/box.py @@ -59,7 +59,6 @@ def __init__( shape: Sequence[int] | None = None, dtype: type[np.floating[Any]] | type[np.integer[Any]] = np.float32, seed: int | np.random.Generator | None = None, - nullable: bool = False, ): r"""Constructor of :class:`Box`. @@ -145,8 +144,6 @@ def __init__( self.low_repr = _short_repr(self.low) self.high_repr = _short_repr(self.high) - self.nullable = nullable - super().__init__(self.shape, self.dtype, seed) @property @@ -246,20 +243,12 @@ def contains(self, x: Any) -> bool: except (ValueError, TypeError): return False - if not np.can_cast(x.dtype, self.dtype): - return False - - if x.shape != self.shape: - return False - - bounded_below = x >= self.low - bounded_above = x <= self.high - - bounded = bounded_below & bounded_above - if self.nullable: - bounded |= np.isnan(x) - - return bool(np.all(bounded)) + return bool( + np.can_cast(x.dtype, self.dtype) + and x.shape == self.shape + and np.all(x >= self.low) + and np.all(x <= self.high) + ) def to_jsonable(self, sample_n: Sequence[NDArray[Any]]) -> list[list]: """Convert a batch of samples from this space to a JSONable data type.""" @@ -301,9 +290,6 @@ def __setstate__(self, state: Iterable[tuple[str, Any]] | Mapping[str, Any]): if not hasattr(self, "high_repr"): self.high_repr = _short_repr(self.high) - if not hasattr(self, "nullable"): - self.nullable = False - def get_precision(dtype: np.dtype) -> SupportsFloat: """Get precision of a data type.""" diff --git a/gymnasium/spaces/oneof.py b/gymnasium/spaces/oneof.py index 7709e7115..021a6c3c6 100644 --- a/gymnasium/spaces/oneof.py +++ b/gymnasium/spaces/oneof.py @@ -84,7 +84,7 @@ def seed(self, seed: int | typing.Sequence[int] | None = None) -> list[int]: return seeds - def sample(self, mask: tuple[Any | None, ...] | None = None) -> tuple[Any, ...]: + def sample(self, mask: tuple[Any | None, ...] | None = None) -> tuple[int, Any]: """Generates a single random sample inside this space. This method draws independent samples from the subspaces. @@ -108,7 +108,7 @@ def sample(self, mask: tuple[Any | None, ...] | None = None) -> tuple[Any, ...]: mask = mask[subspace_idx] - return (subspace_idx, subspace.sample(mask=mask)) + return subspace_idx, subspace.sample(mask=mask) def contains(self, x: tuple[int, Any]) -> bool: """Return boolean specifying if x is a valid member of this space.""" diff --git a/gymnasium/spaces/utils.py b/gymnasium/spaces/utils.py index 3df8db147..853bfa724 100644 --- a/gymnasium/spaces/utils.py +++ b/gymnasium/spaces/utils.py @@ -271,7 +271,7 @@ def _flatten_oneof(space: OneOf, x: tuple[int, Any]) -> NDArray[Any]: max_flatdim = flatdim(space) - 1 # Don't include the index if flat_sample.size < max_flatdim: padding = np.full( - max_flatdim - flat_sample.size, np.nan, dtype=flat_sample.dtype + max_flatdim - flat_sample.size, flat_sample[0], dtype=flat_sample.dtype ) flat_sample = np.concatenate([flat_sample, padding]) @@ -575,4 +575,4 @@ def _flatten_space_oneof(space: OneOf) -> Box: high = np.concatenate([[num_subspaces - 1], np.full(max_flatdim - 1, overall_high)]) dtype = np.result_type(*[s.dtype for s in space.spaces if hasattr(s, "dtype")]) - return Box(low=low, high=high, shape=(max_flatdim,), dtype=dtype, nullable=True) + return Box(low=low, high=high, shape=(max_flatdim,), dtype=dtype) diff --git a/gymnasium/vector/utils/space_utils.py b/gymnasium/vector/utils/space_utils.py index ae317a05a..42d3f9b16 100644 --- a/gymnasium/vector/utils/space_utils.py +++ b/gymnasium/vector/utils/space_utils.py @@ -122,16 +122,12 @@ def _batch_space_dict(space: Dict, n: int = 1): ) -@batch_space.register(OneOf) -def _batch_space_oneof(space: OneOf, n: int = 1): - return Sequence(space) - - @batch_space.register(Graph) @batch_space.register(Text) @batch_space.register(Sequence) +@batch_space.register(OneOf) @batch_space.register(Space) -def _batch_space_custom(space: Graph | Text | Sequence, n: int = 1): +def _batch_space_custom(space: Graph | Text | Sequence | OneOf, n: int = 1): # Without deepcopy, then the space.np_random is batched_space.spaces[0].np_random # Which is an issue if you are sampling actions of both the original space and the batched space batched_space = Tuple( @@ -233,11 +229,6 @@ def _iterate_dict(space: Dict, items: dict[str, Any]): yield OrderedDict({key: value for key, value in zip(keys, item)}) -@iterate.register(Sequence) -def _iterate_sequence(space: Sequence, items: list[Any]): - yield from items - - @singledispatch def concatenate( space: Space, items: Iterable, out: tuple[Any, ...] | dict[str, Any] | np.ndarray @@ -348,7 +339,7 @@ def create_empty_array( ) -# It is possible for the some of the Box low to be greater than 0, then array is not in space +# It is possible for some of the Box low to be greater than 0, then array is not in space @create_empty_array.register(Box) # If the Discrete start > 0 or start + length < 0 then array is not in space @create_empty_array.register(Discrete) @@ -418,7 +409,6 @@ def _create_empty_array_sequence( def _create_empty_array_oneof(space: OneOf, n: int = 1, fn=np.zeros): return tuple(tuple() for _ in range(n)) - @create_empty_array.register(Space) def _create_empty_array_custom(space, n=1, fn=np.zeros): return None diff --git a/gymnasium/wrappers/utils.py b/gymnasium/wrappers/utils.py index fac8e1fde..e89e0dc7c 100644 --- a/gymnasium/wrappers/utils.py +++ b/gymnasium/wrappers/utils.py @@ -16,7 +16,7 @@ MultiDiscrete, Sequence, Text, - Tuple, + Tuple, OneOf, ) from gymnasium.spaces.space import T_cov @@ -145,3 +145,8 @@ def _create_graph_zero_array(space: Graph): edges = np.expand_dims(create_zero_array(space.edge_space), axis=0) edge_links = np.zeros((1, 2), dtype=np.int64) return GraphInstance(nodes=nodes, edges=edges, edge_links=edge_links) + + +@create_zero_array.register(OneOf) +def _create_one_of_zero_array(space: OneOf): + return 0, create_zero_array(space.spaces[0]) From be5389db0e2d5b590d416029d16986346c756c51 Mon Sep 17 00:00:00 2001 From: pseudo-rnd-thoughts Date: Tue, 27 Feb 2024 15:02:01 +0000 Subject: [PATCH 09/13] pre-commit --- gymnasium/vector/utils/space_utils.py | 1 + gymnasium/wrappers/utils.py | 3 ++- 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/gymnasium/vector/utils/space_utils.py b/gymnasium/vector/utils/space_utils.py index 42d3f9b16..a54e0aa36 100644 --- a/gymnasium/vector/utils/space_utils.py +++ b/gymnasium/vector/utils/space_utils.py @@ -409,6 +409,7 @@ def _create_empty_array_sequence( def _create_empty_array_oneof(space: OneOf, n: int = 1, fn=np.zeros): return tuple(tuple() for _ in range(n)) + @create_empty_array.register(Space) def _create_empty_array_custom(space, n=1, fn=np.zeros): return None diff --git a/gymnasium/wrappers/utils.py b/gymnasium/wrappers/utils.py index e89e0dc7c..e9d38bd24 100644 --- a/gymnasium/wrappers/utils.py +++ b/gymnasium/wrappers/utils.py @@ -14,9 +14,10 @@ GraphInstance, MultiBinary, MultiDiscrete, + OneOf, Sequence, Text, - Tuple, OneOf, + Tuple, ) from gymnasium.spaces.space import T_cov From ca5d6f9112b035b49f935b21356852f1f6c4b387 Mon Sep 17 00:00:00 2001 From: pseudo-rnd-thoughts Date: Tue, 27 Feb 2024 15:26:52 +0000 Subject: [PATCH 10/13] Add shared_memory function --- gymnasium/vector/utils/shared_memory.py | 7 +++++-- tests/vector/utils/test_shared_memory.py | 6 ++++-- 2 files changed, 9 insertions(+), 4 deletions(-) diff --git a/gymnasium/vector/utils/shared_memory.py b/gymnasium/vector/utils/shared_memory.py index 6f4d9c8a1..ad2610454 100644 --- a/gymnasium/vector/utils/shared_memory.py +++ b/gymnasium/vector/utils/shared_memory.py @@ -21,7 +21,7 @@ Space, Text, Tuple, - flatten, + flatten, OneOf, ) @@ -72,6 +72,7 @@ def _create_base_shared_memory( @create_shared_memory.register(Tuple) +@create_shared_memory.register(OneOf) def _create_tuple_shared_memory(space: Tuple, n: int = 1, ctx=mp): return tuple( create_shared_memory(subspace, n=n, ctx=ctx) for subspace in space.spaces @@ -147,6 +148,7 @@ def _read_base_from_shared_memory( @read_from_shared_memory.register(Tuple) +@read_from_shared_memory.register(OneOf) def _read_tuple_from_shared_memory(space: Tuple, shared_memory, n: int = 1): return tuple( read_from_shared_memory(subspace, memory, n=n) @@ -165,7 +167,7 @@ def _read_dict_from_shared_memory(space: Dict, shared_memory, n: int = 1): @read_from_shared_memory.register(Text) -def _read_text_from_shared_memory(space: Text, shared_memory, n: int = 1) -> tuple[str]: +def _read_text_from_shared_memory(space: Text, shared_memory, n: int = 1) -> tuple[str, ...]: data = np.frombuffer(shared_memory.get_obj(), dtype=np.int32).reshape( (n, space.max_length) ) @@ -230,6 +232,7 @@ def _write_base_to_shared_memory( @write_to_shared_memory.register(Tuple) +@write_to_shared_memory.register(OneOf) def _write_tuple_to_shared_memory( space: Tuple, index: int, values: tuple[Any, ...], shared_memory ): diff --git a/tests/vector/utils/test_shared_memory.py b/tests/vector/utils/test_shared_memory.py index c4b732c81..354576248 100644 --- a/tests/vector/utils/test_shared_memory.py +++ b/tests/vector/utils/test_shared_memory.py @@ -22,7 +22,7 @@ "ctx", [None, "fork", "spawn"], ids=["default", "fork", "spawn"] ) def test_shared_memory_create_read_write(space, num, ctx): - """Test the shared memory functions, create, read and write for all of the testing spaces.""" + """Test the shared memory functions, create, read and write for all testing spaces.""" if ctx not in mp.get_all_start_methods(): pytest.skip( f"Multiprocessing start method {ctx} not available on this platform." @@ -41,7 +41,9 @@ def test_shared_memory_create_read_write(space, num, ctx): read_samples = read_from_shared_memory(space, shared_memory, n=num) for read_sample, sample in zip(read_samples, samples): - data_equivalence(read_sample, sample) + assert sample in space + assert read_sample in space + assert data_equivalence(read_sample, sample) def test_custom_space(): From d871c04ccb454750e1a4d708a3b40917cdf52733 Mon Sep 17 00:00:00 2001 From: pseudo-rnd-thoughts Date: Tue, 27 Feb 2024 16:27:22 +0000 Subject: [PATCH 11/13] Add shared memory for OneOf --- gymnasium/vector/utils/shared_memory.py | 50 +++++++++++++++++++++---- 1 file changed, 43 insertions(+), 7 deletions(-) diff --git a/gymnasium/vector/utils/shared_memory.py b/gymnasium/vector/utils/shared_memory.py index ad2610454..ec3789a73 100644 --- a/gymnasium/vector/utils/shared_memory.py +++ b/gymnasium/vector/utils/shared_memory.py @@ -17,11 +17,12 @@ Graph, MultiBinary, MultiDiscrete, + OneOf, Sequence, Space, Text, Tuple, - flatten, OneOf, + flatten, ) @@ -72,7 +73,6 @@ def _create_base_shared_memory( @create_shared_memory.register(Tuple) -@create_shared_memory.register(OneOf) def _create_tuple_shared_memory(space: Tuple, n: int = 1, ctx=mp): return tuple( create_shared_memory(subspace, n=n, ctx=ctx) for subspace in space.spaces @@ -94,6 +94,11 @@ def _create_text_shared_memory(space: Text, n: int = 1, ctx=mp): return ctx.Array(np.dtype(np.int32).char, n * space.max_length) +@create_shared_memory.register(OneOf) +def _create_oneof_shared_memory(space: OneOf, n: int = 1, ctx=mp): + return (ctx.Array(np.int32, n),) + _create_tuple_shared_memory(space) + + @create_shared_memory.register(Graph) @create_shared_memory.register(Sequence) def _create_dynamic_shared_memory(space: Graph | Sequence, n: int = 1, ctx=mp): @@ -148,26 +153,32 @@ def _read_base_from_shared_memory( @read_from_shared_memory.register(Tuple) -@read_from_shared_memory.register(OneOf) def _read_tuple_from_shared_memory(space: Tuple, shared_memory, n: int = 1): - return tuple( + subspace_samples = tuple( read_from_shared_memory(subspace, memory, n=n) for (memory, subspace) in zip(shared_memory, space.spaces) ) + return tuple(zip(*subspace_samples)) @read_from_shared_memory.register(Dict) def _read_dict_from_shared_memory(space: Dict, shared_memory, n: int = 1): - return OrderedDict( + subspace_samples = OrderedDict( [ (key, read_from_shared_memory(subspace, shared_memory[key], n=n)) for (key, subspace) in space.spaces.items() ] ) + return tuple( + OrderedDict({key: subspace_samples[key][i] for key in space.keys()}) + for i in range(n) + ) @read_from_shared_memory.register(Text) -def _read_text_from_shared_memory(space: Text, shared_memory, n: int = 1) -> tuple[str, ...]: +def _read_text_from_shared_memory( + space: Text, shared_memory, n: int = 1 +) -> tuple[str, ...]: data = np.frombuffer(shared_memory.get_obj(), dtype=np.int32).reshape( (n, space.max_length) ) @@ -184,6 +195,21 @@ def _read_text_from_shared_memory(space: Text, shared_memory, n: int = 1) -> tup ) +@read_from_shared_memory.register(OneOf) +def _read_one_of_from_shared_memory( + space: OneOf, shared_memory, n: int = 1 +) -> tuple[Any, ...]: + sample_indexes = np.frombuffer(shared_memory[0].get_obj(), dtype=space.dtype) + subspace_samples = tuple( + read_from_shared_memory(subspace, memory, n=n) + for (memory, subspace) in zip(shared_memory[1:], space.spaces) + ) + return tuple( + (index, sample[index]) + for index, sample in zip(sample_indexes, subspace_samples) + ) + + @singledispatch def write_to_shared_memory( space: Space, @@ -232,7 +258,6 @@ def _write_base_to_shared_memory( @write_to_shared_memory.register(Tuple) -@write_to_shared_memory.register(OneOf) def _write_tuple_to_shared_memory( space: Tuple, index: int, values: tuple[Any, ...], shared_memory ): @@ -256,3 +281,14 @@ def _write_text_to_shared_memory(space: Text, index: int, values: str, shared_me destination[index * size : (index + 1) * size], flatten(space, values), ) + + +@write_to_shared_memory.register(OneOf) +def _write_oneof_to_shared_memory( + space: OneOf, index: int, values: tuple[int, Any, ...], shared_memory +): + destination = np.frombuffer(shared_memory[0].get_obj(), dtype=np.int32) + np.copyto(destination[index : index + 1], values[0]) + + for value, memory, subspace in zip(values[1], shared_memory[1:], space.spaces): + write_to_shared_memory(subspace, index, value, memory) From 54d79e8e7a9088d6b80190be3a238d12a38675bd Mon Sep 17 00:00:00 2001 From: pseudo-rnd-thoughts Date: Tue, 27 Feb 2024 17:09:47 +0000 Subject: [PATCH 12/13] Fix docstring and type hint --- gymnasium/spaces/oneof.py | 2 +- gymnasium/vector/utils/shared_memory.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/gymnasium/spaces/oneof.py b/gymnasium/spaces/oneof.py index 021a6c3c6..097372518 100644 --- a/gymnasium/spaces/oneof.py +++ b/gymnasium/spaces/oneof.py @@ -19,7 +19,7 @@ class OneOf(Space[Any]): >>> from gymnasium.spaces import OneOf, Box, Discrete >>> observation_space = OneOf((Discrete(2), Box(-1, 1, shape=(2,))), seed=42) >>> observation_space.sample() - (0, array([-0.3991573 , 0.21649833], dtype=float32)) + (1, array([-0.3991573 , 0.21649833], dtype=float32)) """ def __init__( diff --git a/gymnasium/vector/utils/shared_memory.py b/gymnasium/vector/utils/shared_memory.py index ec3789a73..f9ee50f65 100644 --- a/gymnasium/vector/utils/shared_memory.py +++ b/gymnasium/vector/utils/shared_memory.py @@ -285,7 +285,7 @@ def _write_text_to_shared_memory(space: Text, index: int, values: str, shared_me @write_to_shared_memory.register(OneOf) def _write_oneof_to_shared_memory( - space: OneOf, index: int, values: tuple[int, Any, ...], shared_memory + space: OneOf, index: int, values: tuple[Any, ...], shared_memory ): destination = np.frombuffer(shared_memory[0].get_obj(), dtype=np.int32) np.copyto(destination[index : index + 1], values[0]) From 6513099806d3bd9661d667925d6b4b7a78bf8ba5 Mon Sep 17 00:00:00 2001 From: pseudo-rnd-thoughts Date: Mon, 11 Mar 2024 12:10:41 +0000 Subject: [PATCH 13/13] Update the documentation --- docs/api/spaces.md | 3 ++- docs/api/spaces/composite.md | 5 +++++ gymnasium/spaces/oneof.py | 13 +++++++++++-- 3 files changed, 18 insertions(+), 3 deletions(-) diff --git a/docs/api/spaces.md b/docs/api/spaces.md index 9f749a66a..014daec15 100644 --- a/docs/api/spaces.md +++ b/docs/api/spaces.md @@ -66,7 +66,8 @@ Often environment spaces require joining fundamental spaces together for vectori * :class:`Dict` - Supports a dictionary of keys and subspaces, used for a fixed number of unordered spaces * :class:`Tuple` - Supports a tuple of subspaces, used for multiple for a fixed number of ordered spaces * :class:`Sequence` - Supports a variable number of instances of a single subspace, used for entities spaces or selecting a variable number of actions -* :py:class:`Graph` - Supports graph based actions or observations with discrete or continuous nodes and edge values. +* :class:`Graph` - Supports graph based actions or observations with discrete or continuous nodes and edge values +* :class:`OneOf` - Supports optional action spaces such that an action can be one of N possible subspaces ``` ## Utility functions diff --git a/docs/api/spaces/composite.md b/docs/api/spaces/composite.md index b43a5b0f5..72d8aee27 100644 --- a/docs/api/spaces/composite.md +++ b/docs/api/spaces/composite.md @@ -21,4 +21,9 @@ .. automethod:: gymnasium.spaces.Graph.sample .. automethod:: gymnasium.spaces.Graph.seed + +.. autoclass:: gymnasium.spaces.OneOf + + .. automethod:: gymnasium.spaces.OneOf.sample + .. automethod:: gymnasium.spaces.OneOf.seed ``` diff --git a/gymnasium/spaces/oneof.py b/gymnasium/spaces/oneof.py index 097372518..d88f0b130 100644 --- a/gymnasium/spaces/oneof.py +++ b/gymnasium/spaces/oneof.py @@ -18,8 +18,16 @@ class OneOf(Space[Any]): Example: >>> from gymnasium.spaces import OneOf, Box, Discrete >>> observation_space = OneOf((Discrete(2), Box(-1, 1, shape=(2,))), seed=42) - >>> observation_space.sample() + >>> observation_space.sample() # the first element is the space index (Box in this case) and the second element is the sample from Box (1, array([-0.3991573 , 0.21649833], dtype=float32)) + >>> observation_space.sample() # this time the Discrete space was sampled as index=0 + (0, 0) + >>> observation_space[0] + Discrete(2) + >>> observation_space[1] + Box(-1.0, 1.0, (2,), float32) + >>> len(observation_space) + 2 """ def __init__( @@ -27,7 +35,7 @@ def __init__( spaces: Iterable[Space[Any]], seed: int | typing.Sequence[int] | np.random.Generator | None = None, ): - r"""Constructor of :class:`Tuple` space. + r"""Constructor of :class:`OneOf` space. The generated instance will represent the cartesian product :math:`\text{spaces}[0] \times ... \times \text{spaces}[-1]`. @@ -36,6 +44,7 @@ def __init__( seed: Optionally, you can use this argument to seed the RNGs of the ``spaces`` to ensure reproducible sampling. """ self.spaces = tuple(spaces) + assert len(self.spaces) > 0, "Empty `OneOf` spaces are not supported." for space in self.spaces: assert isinstance( space, Space