Skip to content

Commit

Permalink
Fix tests
Browse files Browse the repository at this point in the history
  • Loading branch information
RedTachyon committed Dec 4, 2023
1 parent 27bf126 commit 785860f
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 7 deletions.
12 changes: 5 additions & 7 deletions gymnasium/spaces/oneof.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ def sample(self, mask: tuple[Any | None, ...] | None = None) -> tuple[Any, ...]:
Returns:
Tuple of the subspace's samples
"""
subspace_idx = self.np_random.integers(0, len(self.spaces))
subspace_idx = int(self.np_random.integers(0, len(self.spaces)))
subspace = self.spaces[subspace_idx]
if mask is not None:
assert isinstance(
Expand Down Expand Up @@ -124,18 +124,16 @@ def to_jsonable(
self, sample_n: typing.Sequence[tuple[int, Any]]
) -> list[list[Any]]:
"""Convert a batch of samples from this space to a JSONable data type."""
# serialize as list-repr of tuple of vectors
return [
[i, space.to_jsonable([sample[i] for sample in sample_n])]
for i, space in enumerate(self.spaces)
[int(i), self.spaces[i].to_jsonable([subsample])[0]]
for (i, subsample) in sample_n
]

def from_jsonable(self, sample_n: list[list[Any]]) -> list[tuple[Any, ...]]:
"""Convert a JSONable data type to a batch of samples from this space."""
return [
(space_idx, sample)
for space_idx, jsonable_samples in sample_n
for sample in self.spaces[space_idx].from_jsonable(jsonable_samples)
(space_idx, self.spaces[space_idx].from_jsonable([jsonable_sample])[0])
for space_idx, jsonable_sample in sample_n
]

def __getitem__(self, index: int) -> Space[Any]:
Expand Down
1 change: 1 addition & 0 deletions tests/spaces/test_spaces.py
Original file line number Diff line number Diff line change
Expand Up @@ -534,6 +534,7 @@ def test_seed_reproducibility(space):
{"spaces": {"a": Discrete(3), "b": Discrete(2)}}, # Dict
{"node_space": Discrete(4), "edge_space": Discrete(3)}, # Graph
{"space": Discrete(4)}, # Sequence
{"spaces": (Discrete(3), Discrete(5))}, # OneOf
]
assert len(SPACE_CLS) == len(SPACE_KWARGS)

Expand Down

0 comments on commit 785860f

Please sign in to comment.