Skip to content

Commit

Permalink
Fix invalid values from integer Box sampling
Browse files Browse the repository at this point in the history
  • Loading branch information
Jammf committed Nov 12, 2023
1 parent 6f744e1 commit 2c25426
Show file tree
Hide file tree
Showing 2 changed files with 59 additions and 3 deletions.
19 changes: 18 additions & 1 deletion gymnasium/spaces/box.py
Original file line number Diff line number Diff line change
Expand Up @@ -257,7 +257,24 @@ def sample(self, mask: None = None) -> NDArray[Any]:
if self.dtype.kind in ["i", "u", "b"]:
sample = np.floor(sample)

return sample.astype(self.dtype)
# clip values that would underflow/overflow
if np.issubdtype(self.dtype, np.signedinteger):
dtype_min = np.iinfo(self.dtype).min + 2
dtype_max = np.iinfo(self.dtype).max - 2
sample = sample.clip(min=dtype_min, max=dtype_max)
elif np.issubdtype(self.dtype, np.unsignedinteger):
dtype_min = np.iinfo(self.dtype).min
dtype_max = np.iinfo(self.dtype).max
sample = sample.clip(min=dtype_min, max=dtype_max)

sample = sample.astype(self.dtype)

# float64 values have lower than integer precision near int64 min/max, so clip
# again in case something has been cast to an out-of-bounds value
if self.dtype == np.int64:
sample = sample.clip(min=self.low, max=self.high)

return sample

def contains(self, x: Any) -> bool:
"""Return boolean specifying if x is a valid member of this space."""
Expand Down
43 changes: 41 additions & 2 deletions tests/spaces/test_box.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,8 +221,7 @@ def test_infinite_space(space):
np.sign(space.low) <= np.sign(sample)
), f"Sign of sample ({sample}) is more than space lower bound ({space.low})"

# check that int bounds are bounded for everything
# but floats are unbounded for infinite
# check that int and float bounds are unbounded for infinite
if np.any(space.high != 0):
assert (
space.is_bounded("above") is False
Expand Down Expand Up @@ -282,6 +281,46 @@ def test_legacy_state_pickling():
assert b.high_repr == "1.0"


@pytest.mark.parametrize(
"space",
[
# exponential
Box(low=np.iinfo(np.uint8).max, high=np.inf, shape=(20,), dtype=np.uint8),
Box(low=np.iinfo(np.int8).max - 2, high=np.inf, shape=(20,), dtype=np.int8),
Box(low=np.iinfo(np.uint16).max, high=np.inf, shape=(20,), dtype=np.uint16),
Box(low=np.iinfo(np.int16).max - 2, high=np.inf, shape=(20,), dtype=np.int16),
Box(low=np.finfo(np.float16).max, high=np.inf, shape=(20,), dtype=np.float16),
Box(low=np.iinfo(np.uint32).max, high=np.inf, shape=(20,), dtype=np.uint32),
Box(low=np.iinfo(np.int32).max - 2, high=np.inf, shape=(20,), dtype=np.int32),
Box(low=np.finfo(np.float32).max, high=np.inf, shape=(20,), dtype=np.float32),
Box(low=np.iinfo(np.uint64).max, high=np.inf, shape=(20,), dtype=np.uint64),
Box(low=np.iinfo(np.int64).max - 2, high=np.inf, shape=(20,), dtype=np.int64),
Box(low=np.finfo(np.float64).max, high=np.inf, shape=(20,), dtype=np.float64),
# negative exponential
Box(low=-np.inf, high=np.iinfo(np.uint8).min, shape=(20,), dtype=np.uint8),
Box(low=-np.inf, high=np.iinfo(np.int8).min + 2, shape=(20,), dtype=np.int8),
Box(low=-np.inf, high=np.iinfo(np.uint16).min, shape=(20,), dtype=np.uint16),
Box(low=-np.inf, high=np.iinfo(np.int16).min + 2, shape=(20,), dtype=np.int16),
Box(low=-np.inf, high=np.finfo(np.float16).min, shape=(20,), dtype=np.float16),
Box(low=-np.inf, high=np.iinfo(np.uint32).min, shape=(20,), dtype=np.uint32),
Box(low=-np.inf, high=np.iinfo(np.int32).min + 2, shape=(20,), dtype=np.int32),
Box(low=-np.inf, high=np.finfo(np.float32).min, shape=(20,), dtype=np.float32),
Box(low=-np.inf, high=np.iinfo(np.uint64).min, shape=(20,), dtype=np.uint64),
Box(low=-np.inf, high=np.iinfo(np.int64).min + 2, shape=(20,), dtype=np.int64),
Box(low=-np.inf, high=np.finfo(np.float64).min, shape=(20,), dtype=np.float64),
],
)
def test_sample_valid(space):
space.seed(0)
sample = space.sample()
print(sample)

# check if space contains sample
assert (
sample in space
), f"Sample ({sample}) not inside space according to `space.contains()`"


def test_sample_mask():
"""Box cannot have a mask applied."""
space = Box(0, 1)
Expand Down

0 comments on commit 2c25426

Please sign in to comment.