diff --git a/docs/api/spaces.md b/docs/api/spaces.md index 9f749a66a..014daec15 100644 --- a/docs/api/spaces.md +++ b/docs/api/spaces.md @@ -66,7 +66,8 @@ Often environment spaces require joining fundamental spaces together for vectori * :class:`Dict` - Supports a dictionary of keys and subspaces, used for a fixed number of unordered spaces * :class:`Tuple` - Supports a tuple of subspaces, used for multiple for a fixed number of ordered spaces * :class:`Sequence` - Supports a variable number of instances of a single subspace, used for entities spaces or selecting a variable number of actions -* :py:class:`Graph` - Supports graph based actions or observations with discrete or continuous nodes and edge values. +* :class:`Graph` - Supports graph based actions or observations with discrete or continuous nodes and edge values +* :class:`OneOf` - Supports optional action spaces such that an action can be one of N possible subspaces ``` ## Utility functions diff --git a/docs/api/spaces/composite.md b/docs/api/spaces/composite.md index b43a5b0f5..72d8aee27 100644 --- a/docs/api/spaces/composite.md +++ b/docs/api/spaces/composite.md @@ -21,4 +21,9 @@ .. automethod:: gymnasium.spaces.Graph.sample .. automethod:: gymnasium.spaces.Graph.seed + +.. autoclass:: gymnasium.spaces.OneOf + + .. automethod:: gymnasium.spaces.OneOf.sample + .. automethod:: gymnasium.spaces.OneOf.seed ``` diff --git a/gymnasium/spaces/__init__.py b/gymnasium/spaces/__init__.py index f1d726f2d..0d07ee7e3 100644 --- a/gymnasium/spaces/__init__.py +++ b/gymnasium/spaces/__init__.py @@ -16,6 +16,7 @@ from gymnasium.spaces.graph import Graph, GraphInstance from gymnasium.spaces.multi_binary import MultiBinary from gymnasium.spaces.multi_discrete import MultiDiscrete +from gymnasium.spaces.oneof import OneOf from gymnasium.spaces.sequence import Sequence from gymnasium.spaces.space import Space from gymnasium.spaces.text import Text @@ -38,6 +39,7 @@ "Tuple", "Sequence", "Dict", + "OneOf", # util functions (there are more utility functions in vector/utils/spaces.py) "flatdim", "flatten_space", diff --git a/gymnasium/spaces/oneof.py b/gymnasium/spaces/oneof.py new file mode 100644 index 000000000..d88f0b130 --- /dev/null +++ b/gymnasium/spaces/oneof.py @@ -0,0 +1,158 @@ +"""Implementation of a space that represents the cartesian product of other spaces.""" +from __future__ import annotations + +import collections.abc +import typing +from typing import Any, Iterable + +import numpy as np + +from gymnasium.spaces.space import Space + + +class OneOf(Space[Any]): + """An exclusive tuple (more precisely: the direct sum) of :class:`Space` instances. + + Elements of this space are elements of one of the constituent spaces. + + Example: + >>> from gymnasium.spaces import OneOf, Box, Discrete + >>> observation_space = OneOf((Discrete(2), Box(-1, 1, shape=(2,))), seed=42) + >>> observation_space.sample() # the first element is the space index (Box in this case) and the second element is the sample from Box + (1, array([-0.3991573 , 0.21649833], dtype=float32)) + >>> observation_space.sample() # this time the Discrete space was sampled as index=0 + (0, 0) + >>> observation_space[0] + Discrete(2) + >>> observation_space[1] + Box(-1.0, 1.0, (2,), float32) + >>> len(observation_space) + 2 + """ + + def __init__( + self, + spaces: Iterable[Space[Any]], + seed: int | typing.Sequence[int] | np.random.Generator | None = None, + ): + r"""Constructor of :class:`OneOf` space. + + The generated instance will represent the cartesian product :math:`\text{spaces}[0] \times ... \times \text{spaces}[-1]`. + + Args: + spaces (Iterable[Space]): The spaces that are involved in the cartesian product. + seed: Optionally, you can use this argument to seed the RNGs of the ``spaces`` to ensure reproducible sampling. + """ + self.spaces = tuple(spaces) + assert len(self.spaces) > 0, "Empty `OneOf` spaces are not supported." + for space in self.spaces: + assert isinstance( + space, Space + ), f"{space} does not inherit from `gymnasium.Space`. Actual Type: {type(space)}" + super().__init__(None, None, seed) + + @property + def is_np_flattenable(self): + """Checks whether this space can be flattened to a :class:`spaces.Box`.""" + return all(space.is_np_flattenable for space in self.spaces) + + def seed(self, seed: int | typing.Sequence[int] | None = None) -> list[int]: + """Seed the PRNG of this space and all subspaces. + + Depending on the type of seed, the subspaces will be seeded differently + + * ``None`` - All the subspaces will use a random initial seed + * ``Int`` - The integer is used to seed the :class:`Tuple` space that is used to generate seed values for each of the subspaces. Warning, this does not guarantee unique seeds for all the subspaces. + * ``List`` - Values used to seed the subspaces. This allows the seeding of multiple composite subspaces ``[42, 54, ...]``. + + Args: + seed: An optional list of ints or int to seed the (sub-)spaces. + """ + if isinstance(seed, collections.abc.Sequence): + assert ( + len(seed) == len(self.spaces) + 1 + ), f"Expects that the subspaces of seeds equals the number of subspaces. Actual length of seeds: {len(seed)}, length of subspaces: {len(self.spaces)}" + seeds = super().seed(seed[0]) + for subseed, space in zip(seed, self.spaces): + seeds += space.seed(subseed) + elif isinstance(seed, int): + seeds = super().seed(seed) + subseeds = self.np_random.integers( + np.iinfo(np.int32).max, size=len(self.spaces) + ) + for subspace, subseed in zip(self.spaces, subseeds): + seeds += subspace.seed(int(subseed)) + elif seed is None: + seeds = super().seed(None) + for space in self.spaces: + seeds += space.seed(None) + else: + raise TypeError( + f"Expected seed type: list, tuple, int or None, actual type: {type(seed)}" + ) + + return seeds + + def sample(self, mask: tuple[Any | None, ...] | None = None) -> tuple[int, Any]: + """Generates a single random sample inside this space. + + This method draws independent samples from the subspaces. + + Args: + mask: An optional tuple of optional masks for each of the subspace's samples, + expects the same number of masks as spaces + + Returns: + Tuple of the subspace's samples + """ + subspace_idx = int(self.np_random.integers(0, len(self.spaces))) + subspace = self.spaces[subspace_idx] + if mask is not None: + assert isinstance( + mask, tuple + ), f"Expected type of mask is tuple, actual type: {type(mask)}" + assert len(mask) == len( + self.spaces + ), f"Expected length of mask is {len(self.spaces)}, actual length: {len(mask)}" + + mask = mask[subspace_idx] + + return subspace_idx, subspace.sample(mask=mask) + + def contains(self, x: tuple[int, Any]) -> bool: + """Return boolean specifying if x is a valid member of this space.""" + (idx, value) = x + + return isinstance(x, tuple) and self.spaces[idx].contains(value) + + def __repr__(self) -> str: + """Gives a string representation of this space.""" + return "OneOf(" + ", ".join([str(s) for s in self.spaces]) + ")" + + def to_jsonable( + self, sample_n: typing.Sequence[tuple[int, Any]] + ) -> list[list[Any]]: + """Convert a batch of samples from this space to a JSONable data type.""" + return [ + [int(i), self.spaces[i].to_jsonable([subsample])[0]] + for (i, subsample) in sample_n + ] + + def from_jsonable(self, sample_n: list[list[Any]]) -> list[tuple[Any, ...]]: + """Convert a JSONable data type to a batch of samples from this space.""" + return [ + (space_idx, self.spaces[space_idx].from_jsonable([jsonable_sample])[0]) + for space_idx, jsonable_sample in sample_n + ] + + def __getitem__(self, index: int) -> Space[Any]: + """Get the subspace at specific `index`.""" + return self.spaces[index] + + def __len__(self) -> int: + """Get the number of subspaces that are involved in the cartesian product.""" + return len(self.spaces) + + def __eq__(self, other: Any) -> bool: + """Check whether ``other`` is equivalent to this instance.""" + return isinstance(other, OneOf) and self.spaces == other.spaces diff --git a/gymnasium/spaces/utils.py b/gymnasium/spaces/utils.py index 31fef9df7..853bfa724 100644 --- a/gymnasium/spaces/utils.py +++ b/gymnasium/spaces/utils.py @@ -23,6 +23,7 @@ GraphInstance, MultiBinary, MultiDiscrete, + OneOf, Sequence, Space, Text, @@ -104,6 +105,11 @@ def _flatdim_text(space: Text) -> int: return space.max_length +@flatdim.register(OneOf) +def _flatdim_oneof(space: OneOf) -> int: + return 1 + max(flatdim(s) for s in space.spaces) + + T = TypeVar("T") FlatType = Union[ NDArray[Any], typing.Dict[str, Any], typing.Tuple[Any, ...], GraphInstance @@ -256,6 +262,22 @@ def _flatten_sequence( return tuple(flatten(space.feature_space, item) for item in x) +@flatten.register(OneOf) +def _flatten_oneof(space: OneOf, x: tuple[int, Any]) -> NDArray[Any]: + idx, sample = x + sub_space = space.spaces[idx] + flat_sample = flatten(sub_space, sample) + + max_flatdim = flatdim(space) - 1 # Don't include the index + if flat_sample.size < max_flatdim: + padding = np.full( + max_flatdim - flat_sample.size, flat_sample[0], dtype=flat_sample.dtype + ) + flat_sample = np.concatenate([flat_sample, padding]) + + return np.concatenate([[idx], flat_sample]) + + @singledispatch def unflatten(space: Space[T], x: FlatType) -> T: """Unflatten a data point from a space. @@ -399,6 +421,17 @@ def _unflatten_sequence(space: Sequence, x: tuple[Any, ...]) -> tuple[Any, ...] return tuple(unflatten(space.feature_space, item) for item in x) +@unflatten.register(OneOf) +def _unflatten_oneof(space: OneOf, x: NDArray[Any]) -> tuple[int, Any]: + idx = int(x[0]) + sub_space = space.spaces[idx] + + original_size = flatdim(sub_space) + trimmed_sample = x[1 : 1 + original_size] + + return idx, unflatten(sub_space, trimmed_sample) + + @singledispatch def flatten_space(space: Space[Any]) -> Box | Dict | Sequence | Tuple | Graph: """Flatten a space into a space that is as flat as possible. @@ -525,3 +558,21 @@ def _flatten_space_text(space: Text) -> Box: @flatten_space.register(Sequence) def _flatten_space_sequence(space: Sequence) -> Sequence: return Sequence(flatten_space(space.feature_space), stack=space.stack) + + +@flatten_space.register(OneOf) +def _flatten_space_oneof(space: OneOf) -> Box: + num_subspaces = len(space.spaces) + max_flatdim = max(flatdim(s) for s in space.spaces) + 1 + + lows = np.array([np.min(flatten_space(s).low) for s in space.spaces]) + highs = np.array([np.max(flatten_space(s).high) for s in space.spaces]) + + overall_low = np.min(lows) + overall_high = np.max(highs) + + low = np.concatenate([[0], np.full(max_flatdim - 1, overall_low)]) + high = np.concatenate([[num_subspaces - 1], np.full(max_flatdim - 1, overall_high)]) + + dtype = np.result_type(*[s.dtype for s in space.spaces if hasattr(s, "dtype")]) + return Box(low=low, high=high, shape=(max_flatdim,), dtype=dtype) diff --git a/gymnasium/vector/utils/shared_memory.py b/gymnasium/vector/utils/shared_memory.py index c159f7093..f9ee50f65 100644 --- a/gymnasium/vector/utils/shared_memory.py +++ b/gymnasium/vector/utils/shared_memory.py @@ -17,6 +17,7 @@ Graph, MultiBinary, MultiDiscrete, + OneOf, Sequence, Space, Text, @@ -93,6 +94,11 @@ def _create_text_shared_memory(space: Text, n: int = 1, ctx=mp): return ctx.Array(np.dtype(np.int32).char, n * space.max_length) +@create_shared_memory.register(OneOf) +def _create_oneof_shared_memory(space: OneOf, n: int = 1, ctx=mp): + return (ctx.Array(np.int32, n),) + _create_tuple_shared_memory(space) + + @create_shared_memory.register(Graph) @create_shared_memory.register(Sequence) def _create_dynamic_shared_memory(space: Graph | Sequence, n: int = 1, ctx=mp): @@ -170,7 +176,9 @@ def _read_dict_from_shared_memory(space: Dict, shared_memory, n: int = 1): @read_from_shared_memory.register(Text) -def _read_text_from_shared_memory(space: Text, shared_memory, n: int = 1) -> tuple[str]: +def _read_text_from_shared_memory( + space: Text, shared_memory, n: int = 1 +) -> tuple[str, ...]: data = np.frombuffer(shared_memory.get_obj(), dtype=np.int32).reshape( (n, space.max_length) ) @@ -187,6 +195,21 @@ def _read_text_from_shared_memory(space: Text, shared_memory, n: int = 1) -> tup ) +@read_from_shared_memory.register(OneOf) +def _read_one_of_from_shared_memory( + space: OneOf, shared_memory, n: int = 1 +) -> tuple[Any, ...]: + sample_indexes = np.frombuffer(shared_memory[0].get_obj(), dtype=space.dtype) + subspace_samples = tuple( + read_from_shared_memory(subspace, memory, n=n) + for (memory, subspace) in zip(shared_memory[1:], space.spaces) + ) + return tuple( + (index, sample[index]) + for index, sample in zip(sample_indexes, subspace_samples) + ) + + @singledispatch def write_to_shared_memory( space: Space, @@ -258,3 +281,14 @@ def _write_text_to_shared_memory(space: Text, index: int, values: str, shared_me destination[index * size : (index + 1) * size], flatten(space, values), ) + + +@write_to_shared_memory.register(OneOf) +def _write_oneof_to_shared_memory( + space: OneOf, index: int, values: tuple[Any, ...], shared_memory +): + destination = np.frombuffer(shared_memory[0].get_obj(), dtype=np.int32) + np.copyto(destination[index : index + 1], values[0]) + + for value, memory, subspace in zip(values[1], shared_memory[1:], space.spaces): + write_to_shared_memory(subspace, index, value, memory) diff --git a/gymnasium/vector/utils/space_utils.py b/gymnasium/vector/utils/space_utils.py index 3d103ecdf..a54e0aa36 100644 --- a/gymnasium/vector/utils/space_utils.py +++ b/gymnasium/vector/utils/space_utils.py @@ -23,6 +23,7 @@ GraphInstance, MultiBinary, MultiDiscrete, + OneOf, Sequence, Space, Text, @@ -124,8 +125,9 @@ def _batch_space_dict(space: Dict, n: int = 1): @batch_space.register(Graph) @batch_space.register(Text) @batch_space.register(Sequence) +@batch_space.register(OneOf) @batch_space.register(Space) -def _batch_space_custom(space: Graph | Text | Sequence, n: int = 1): +def _batch_space_custom(space: Graph | Text | Sequence | OneOf, n: int = 1): # Without deepcopy, then the space.np_random is batched_space.spaces[0].np_random # Which is an issue if you are sampling actions of both the original space and the batched space batched_space = Tuple( @@ -297,6 +299,7 @@ def _concatenate_dict( @concatenate.register(Text) @concatenate.register(Sequence) @concatenate.register(Space) +@concatenate.register(OneOf) def _concatenate_custom(space: Space, items: Iterable, out: None) -> tuple[Any, ...]: return tuple(items) @@ -336,7 +339,7 @@ def create_empty_array( ) -# It is possible for the some of the Box low to be greater than 0, then array is not in space +# It is possible for some of the Box low to be greater than 0, then array is not in space @create_empty_array.register(Box) # If the Discrete start > 0 or start + length < 0 then array is not in space @create_empty_array.register(Discrete) @@ -402,6 +405,11 @@ def _create_empty_array_sequence( return tuple(tuple() for _ in range(n)) +@create_empty_array.register(OneOf) +def _create_empty_array_oneof(space: OneOf, n: int = 1, fn=np.zeros): + return tuple(tuple() for _ in range(n)) + + @create_empty_array.register(Space) def _create_empty_array_custom(space, n=1, fn=np.zeros): return None diff --git a/gymnasium/wrappers/utils.py b/gymnasium/wrappers/utils.py index fac8e1fde..e9d38bd24 100644 --- a/gymnasium/wrappers/utils.py +++ b/gymnasium/wrappers/utils.py @@ -14,6 +14,7 @@ GraphInstance, MultiBinary, MultiDiscrete, + OneOf, Sequence, Text, Tuple, @@ -145,3 +146,8 @@ def _create_graph_zero_array(space: Graph): edges = np.expand_dims(create_zero_array(space.edge_space), axis=0) edge_links = np.zeros((1, 2), dtype=np.int64) return GraphInstance(nodes=nodes, edges=edges, edge_links=edge_links) + + +@create_zero_array.register(OneOf) +def _create_one_of_zero_array(space: OneOf): + return 0, create_zero_array(space.spaces[0]) diff --git a/tests/spaces/test_oneof.py b/tests/spaces/test_oneof.py new file mode 100644 index 000000000..61088768f --- /dev/null +++ b/tests/spaces/test_oneof.py @@ -0,0 +1,72 @@ +import numpy as np +import pytest + +from gymnasium.spaces import Box, Discrete, MultiBinary, OneOf +from gymnasium.utils.env_checker import data_equivalence + + +def test_oneof_inheritance(): + """Tests that OneOf space properly inherits and implements required methods.""" + spaces = [Discrete(5), Box(-1, 1, shape=(3,)), MultiBinary(2)] + oneof_space = OneOf(spaces) + + assert len(oneof_space) == len(spaces) + # Test indexing + for i in range(len(oneof_space)): + assert oneof_space[i] == spaces[i] + + # Test iterable + for space in oneof_space: + assert space in spaces + + +@pytest.mark.parametrize( + "spaces, seed, expected_len", + [ + ([Discrete(5), Box(-1, 1, shape=(3,))], None, 3), + ([Discrete(5), Box(-1, 1, shape=(3,))], 123, 3), + ([Discrete(5), Box(-1, 1, shape=(3,))], [123, 456, 789], 3), + ], +) +def test_oneof_seeds(spaces, seed, expected_len): + oneof_space = OneOf(spaces) + seeds = oneof_space.seed(seed) + assert isinstance(seeds, list) and all(isinstance(elem, int) for elem in seeds) + assert len(seeds) == expected_len + + sample1 = oneof_space.sample() + + seeds2 = oneof_space.seed(seed) + sample2 = oneof_space.sample() + + data_equivalence(seeds, seeds2) + data_equivalence(sample1, sample2) + + +@pytest.mark.parametrize( + "spaces_fn", + [ + lambda: OneOf(["abc"]), + lambda: OneOf([Box(0, 1), "abc"]), + lambda: OneOf("abc"), + ], +) +def test_bad_oneof_calls(spaces_fn): + with pytest.raises(AssertionError): + spaces_fn() + + +def test_oneof_contains(): + space = OneOf([Box(0, 1), Box(-1, 0, (2,))]) + + assert (0, np.array([0.5], dtype=np.float32)) in space + assert (1, np.array([-0.5, -0.5], dtype=np.float32)) in space + + +def test_bad_oneof_seed(): + space = OneOf([Box(0, 1), Box(0, 1)]) + with pytest.raises( + TypeError, + match="Expected seed type: list, tuple, int or None, actual type: ", + ): + space.seed(0.0) diff --git a/tests/spaces/test_spaces.py b/tests/spaces/test_spaces.py index 82f3a84f9..139edc9a4 100644 --- a/tests/spaces/test_spaces.py +++ b/tests/spaces/test_spaces.py @@ -534,6 +534,7 @@ def test_seed_reproducibility(space): {"spaces": {"a": Discrete(3), "b": Discrete(2)}}, # Dict {"node_space": Discrete(4), "edge_space": Discrete(3)}, # Graph {"space": Discrete(4)}, # Sequence + {"spaces": (Discrete(3), Discrete(5))}, # OneOf ] assert len(SPACE_CLS) == len(SPACE_KWARGS) diff --git a/tests/spaces/test_utils.py b/tests/spaces/test_utils.py index ed3d7d32f..5e4698876 100644 --- a/tests/spaces/test_utils.py +++ b/tests/spaces/test_utils.py @@ -54,6 +54,11 @@ None, None, None, + None, + None, + # OneOf + 4, + 5, ] @@ -106,7 +111,8 @@ def test_flatten_space(space): @pytest.mark.parametrize("space", TESTING_SPACES, ids=TESTING_SPACES_IDS) def test_flatten(space): """Test that a flattened sample have the `flatdim` shape.""" - flattened_sample = utils.flatten(space, space.sample()) + sample = space.sample() + flattened_sample = utils.flatten(space, sample) if space.is_np_flattenable: assert isinstance(flattened_sample, np.ndarray) diff --git a/tests/spaces/utils.py b/tests/spaces/utils.py index b8f8ba1c5..05f6a12db 100644 --- a/tests/spaces/utils.py +++ b/tests/spaces/utils.py @@ -9,6 +9,7 @@ Graph, MultiBinary, MultiDiscrete, + OneOf, Sequence, Space, Text, @@ -108,6 +109,9 @@ Sequence(Graph(node_space=Box(-100, 100, shape=(2, 2)), edge_space=Discrete(4))), Sequence(Box(low=0.0, high=1.0), stack=True), Sequence(Dict({"a": Box(0, 1, (3,)), "b": Discrete(5)}), stack=True), + # OneOf spaces + OneOf([Discrete(3), Box(low=0.0, high=1.0)]), + OneOf([MultiBinary(2), MultiDiscrete([2, 2])]), ] TESTING_COMPOSITE_SPACES_IDS = [f"{space}" for space in TESTING_COMPOSITE_SPACES]