diff --git a/gymnasium/spaces/box.py b/gymnasium/spaces/box.py index 418cb7158..9b177c642 100644 --- a/gymnasium/spaces/box.py +++ b/gymnasium/spaces/box.py @@ -10,7 +10,7 @@ from gymnasium.spaces.space import Space -def _short_repr(arr: NDArray[Any]) -> str: +def array_short_repr(arr: NDArray[Any]) -> str: """Create a shortened string representation of a numpy array. If arr is a multiple of the all-ones vector, return a string representation of the multiplier. @@ -28,7 +28,7 @@ def _short_repr(arr: NDArray[Any]) -> str: def is_float_integer(var: Any) -> bool: - """Checks if a variable is an integer or float.""" + """Checks if a scalar variable is an integer or float (does not include bool).""" return np.issubdtype(type(var), np.integer) or np.issubdtype(type(var), np.floating) @@ -80,71 +80,231 @@ def __init__( ValueError: If no shape information is provided (shape is None, low is None and high is None) then a value error is raised. """ - assert ( - dtype is not None - ), "Box dtype must be explicitly provided, cannot be None." + # determine dtype + if dtype is None: + raise ValueError("Box dtype must be explicitly provided, cannot be None.") self.dtype = np.dtype(dtype) - # determine shape if it isn't provided directly + # * check that dtype is an accepted dtype + if not ( + np.issubdtype(self.dtype, np.integer) + or np.issubdtype(self.dtype, np.floating) + or self.dtype == np.bool_ + ): + raise ValueError( + f"Invalid Box dtype ({self.dtype}), must be an integer, floating, or bool dtype" + ) + + # determine shape if shape is not None: - assert all( - np.issubdtype(type(dim), np.integer) for dim in shape - ), f"Expected all shape elements to be an integer, actual type: {tuple(type(dim) for dim in shape)}" - shape = tuple(int(dim) for dim in shape) # This changes any np types to int + if not isinstance(shape, Iterable): + raise TypeError( + f"Expected Box shape to be an iterable, actual type={type(shape)}" + ) + elif not all(np.issubdtype(type(dim), np.integer) for dim in shape): + raise TypeError( + f"Expected all Box shape elements to be integer, actual type={tuple(type(dim) for dim in shape)}" + ) + + # Casts the `shape` argument to tuple[int, ...] (otherwise dim can `np.int64`) + shape = tuple(int(dim) for dim in shape) + elif isinstance(low, np.ndarray) and isinstance(high, np.ndarray): + if low.shape != high.shape: + raise ValueError( + f"Box low.shape and high.shape don't match, low.shape={low.shape}, high.shape={high.shape}" + ) + shape = low.shape elif isinstance(low, np.ndarray): shape = low.shape elif isinstance(high, np.ndarray): shape = high.shape elif is_float_integer(low) and is_float_integer(high): - shape = (1,) + shape = (1,) # low and high are scalars else: raise ValueError( - f"Box shape is inferred from low and high, expected their types to be np.ndarray, an integer or a float, actual type low: {type(low)}, high: {type(high)}" + "Box shape is not specified, therefore inferred from low and high. Expected low and high to be np.ndarray, integer, or float." + f"Actual types low={type(low)}, high={type(high)}" ) + self._shape: tuple[int, ...] = shape - # Capture the boundedness information before replacing np.inf with get_inf - _low = np.full(shape, low, dtype=float) if is_float_integer(low) else low - self.bounded_below: NDArray[np.bool_] = -np.inf < _low - - _high = np.full(shape, high, dtype=float) if is_float_integer(high) else high - self.bounded_above: NDArray[np.bool_] = np.inf > _high + # Cast scalar values to `np.ndarray` and capture the boundedness information + # disallowed cases + # * out of range - this must be done before casting to low and high otherwise, the value is within dtype and cannot be out of range + # * nan - must be done beforehand as int dtype can cast `nan` to another value + # * unsign int inf and -inf - special case that is disallowed + + if self.dtype == np.bool_: + dtype_min, dtype_max = 0, 1 + elif np.issubdtype(self.dtype, np.floating): + dtype_min = float(np.finfo(self.dtype).min) + dtype_max = float(np.finfo(self.dtype).max) + else: + dtype_min = int(np.iinfo(self.dtype).min) + dtype_max = int(np.iinfo(self.dtype).max) - low = _broadcast(low, self.dtype, shape) - high = _broadcast(high, self.dtype, shape) + # Cast `low` and `high` to ndarray for the dtype min and max for out of range tests + self.low, self.bounded_below = self._cast_low(low, dtype_min) + self.high, self.bounded_above = self._cast_high(high, dtype_max) - assert isinstance(low, np.ndarray) - assert ( - low.shape == shape - ), f"low.shape doesn't match provided shape, low.shape: {low.shape}, shape: {shape}" - assert isinstance(high, np.ndarray) - assert ( - high.shape == shape - ), f"high.shape doesn't match provided shape, high.shape: {high.shape}, shape: {shape}" + # recheck shape for case where shape and (low or high) are provided + if self.low.shape != shape: + raise ValueError( + f"Box low.shape doesn't match provided shape, low.shape={self.low.shape}, shape={self.shape}" + ) + if self.high.shape != shape: + raise ValueError( + f"Box high.shape doesn't match provided shape, high.shape={self.high.shape}, shape={self.shape}" + ) - # check that we don't have invalid low or high - if np.any(low > high): + # check that low <= high + if np.any(self.low > self.high): raise ValueError( - f"Some low values are greater than high, low={low}, high={high}" + f"Box all low values must be less than or equal to high (some values break this), low={self.low}, high={self.high}" ) - if np.any(np.isposinf(low)): - raise ValueError(f"No low value can be equal to `np.inf`, low={low}") - if np.any(np.isneginf(high)): - raise ValueError(f"No high value can be equal to `-np.inf`, high={high}") - self._shape: tuple[int, ...] = shape + self.low_repr = array_short_repr(self.low) + self.high_repr = array_short_repr(self.high) - low_precision = get_precision(low.dtype) - high_precision = get_precision(high.dtype) - dtype_precision = get_precision(self.dtype) - if min(low_precision, high_precision) > dtype_precision: - gym.logger.warn(f"Box bound precision lowered by casting to {self.dtype}") - self.low = low.astype(self.dtype) - self.high = high.astype(self.dtype) + super().__init__(self.shape, self.dtype, seed) - self.low_repr = _short_repr(self.low) - self.high_repr = _short_repr(self.high) + def _cast_low(self, low, dtype_min) -> tuple[np.ndarray, np.ndarray]: + """Casts the input Box low value to ndarray with provided dtype. - super().__init__(self.shape, self.dtype, seed) + Args: + low: The input box low value + dtype_min: The dtype's minimum value + + Returns: + The updated low value and for what values the input is bounded (below) + """ + if is_float_integer(low): + bounded_below = -np.inf < np.full(self.shape, low, dtype=float) + + if np.isnan(low): + raise ValueError(f"No low value can be equal to `np.nan`, low={low}") + elif np.isneginf(low): + if self.dtype.kind == "i": # signed int + low = dtype_min + elif self.dtype.kind in {"u", "b"}: # unsigned int and bool + raise ValueError( + f"Box unsigned int dtype don't support `-np.inf`, low={low}" + ) + elif low < dtype_min: + raise ValueError( + f"Box low is out of bounds of the dtype range, low={low}, min dtype={dtype_min}" + ) + + low = np.full(self.shape, low, dtype=self.dtype) + return low, bounded_below + else: # cast for low - array + if not isinstance(low, np.ndarray): + raise ValueError( + f"Box low must be a np.ndarray, integer, or float, actual type={type(low)}" + ) + elif not ( + np.issubdtype(low.dtype, np.floating) + or np.issubdtype(low.dtype, np.integer) + or low.dtype == np.bool_ + ): + raise ValueError( + f"Box low must be a floating, integer, or bool dtype, actual dtype={low.dtype}" + ) + elif np.any(np.isnan(low)): + raise ValueError(f"No low value can be equal to `np.nan`, low={low}") + + bounded_below = -np.inf < low + + if np.any(np.isneginf(low)): + if self.dtype.kind == "i": # signed int + low[np.isneginf(low)] = dtype_min + elif self.dtype.kind in {"u", "b"}: # unsigned int and bool + raise ValueError( + f"Box unsigned int dtype don't support `-np.inf`, low={low}" + ) + elif low.dtype != self.dtype and np.any(low < dtype_min): + raise ValueError( + f"Box low is out of bounds of the dtype range, low={low}, min dtype={dtype_min}" + ) + + if ( + np.issubdtype(low.dtype, np.floating) + and np.issubdtype(self.dtype, np.floating) + and np.finfo(self.dtype).precision < np.finfo(low.dtype).precision + ): + gym.logger.warn( + f"Box low's precision lowered by casting to {self.dtype}, current low.dtype={low.dtype}" + ) + return low.astype(self.dtype), bounded_below + + def _cast_high(self, high, dtype_max) -> tuple[np.ndarray, np.ndarray]: + """Casts the input Box high value to ndarray with provided dtype. + + Args: + high: The input box high value + dtype_max: The dtype's maximum value + + Returns: + The updated high value and for what values the input is bounded (above) + """ + if is_float_integer(high): + bounded_above = np.full(self.shape, high, dtype=float) < np.inf + + if np.isnan(high): + raise ValueError(f"No high value can be equal to `np.nan`, high={high}") + elif np.isposinf(high): + if self.dtype.kind == "i": # signed int + high = dtype_max + elif self.dtype.kind in {"u", "b"}: # unsigned int + raise ValueError( + f"Box unsigned int dtype don't support `np.inf`, high={high}" + ) + elif high > dtype_max: + raise ValueError( + f"Box high is out of bounds of the dtype range, high={high}, max dtype={dtype_max}" + ) + + high = np.full(self.shape, high, dtype=self.dtype) + return high, bounded_above + else: + if not isinstance(high, np.ndarray): + raise ValueError( + f"Box high must be a np.ndarray, integer, or float, actual type={type(high)}" + ) + elif not ( + np.issubdtype(high.dtype, np.floating) + or np.issubdtype(high.dtype, np.integer) + or high.dtype == np.bool_ + ): + raise ValueError( + f"Box high must be a floating or integer dtype, actual dtype={high.dtype}" + ) + elif np.any(np.isnan(high)): + raise ValueError(f"No high value can be equal to `np.nan`, high={high}") + + bounded_above = high < np.inf + + posinf = np.isposinf(high) + if np.any(posinf): + if self.dtype.kind == "i": # signed int + high[posinf] = dtype_max + elif self.dtype.kind in {"u", "b"}: # unsigned int + raise ValueError( + f"Box unsigned int dtype don't support `np.inf`, high={high}" + ) + elif high.dtype != self.dtype and np.any(dtype_max < high): + raise ValueError( + f"Box high is out of bounds of the dtype range, high={high}, max dtype={dtype_max}" + ) + + if ( + np.issubdtype(high.dtype, np.floating) + and np.issubdtype(self.dtype, np.floating) + and np.finfo(self.dtype).precision < np.finfo(high.dtype).precision + ): + gym.logger.warn( + f"Box high's precision lowered by casting to {self.dtype}, current high.dtype={high.dtype}" + ) + return high.astype(self.dtype), bounded_above @property def shape(self) -> tuple[int, ...]: @@ -232,7 +392,24 @@ def sample(self, mask: None = None) -> NDArray[Any]: if self.dtype.kind in ["i", "u", "b"]: sample = np.floor(sample) - return sample.astype(self.dtype) + # clip values that would underflow/overflow + if np.issubdtype(self.dtype, np.signedinteger): + dtype_min = np.iinfo(self.dtype).min + 2 + dtype_max = np.iinfo(self.dtype).max - 2 + sample = sample.clip(min=dtype_min, max=dtype_max) + elif np.issubdtype(self.dtype, np.unsignedinteger): + dtype_min = np.iinfo(self.dtype).min + dtype_max = np.iinfo(self.dtype).max + sample = sample.clip(min=dtype_min, max=dtype_max) + + sample = sample.astype(self.dtype) + + # float64 values have lower than integer precision near int64 min/max, so clip + # again in case something has been cast to an out-of-bounds value + if self.dtype == np.int64: + sample = sample.clip(min=self.low, max=self.high) + + return sample def contains(self, x: Any) -> bool: """Return boolean specifying if x is a valid member of this space.""" @@ -285,53 +462,7 @@ def __setstate__(self, state: Iterable[tuple[str, Any]] | Mapping[str, Any]): # legacy support through re-adding "low_repr" and "high_repr" if missing from pickled state if not hasattr(self, "low_repr"): - self.low_repr = _short_repr(self.low) + self.low_repr = array_short_repr(self.low) if not hasattr(self, "high_repr"): - self.high_repr = _short_repr(self.high) - - -def get_precision(dtype: np.dtype) -> SupportsFloat: - """Get precision of a data type.""" - if np.issubdtype(dtype, np.floating): - return np.finfo(dtype).precision - else: - return np.inf - - -def _broadcast( - value: SupportsFloat | NDArray[Any], - dtype: np.dtype, - shape: tuple[int, ...], -) -> NDArray[Any]: - """Handle infinite bounds and broadcast at the same time if needed. - - This is needed primarily because: - >>> import numpy as np - >>> np.full((2,), np.inf, dtype=np.int32) - array([-2147483648, -2147483648], dtype=int32) - """ - if is_float_integer(value): - if np.isneginf(value) and np.dtype(dtype).kind == "i": - value = np.iinfo(dtype).min + 2 - elif np.isposinf(value) and np.dtype(dtype).kind == "i": - value = np.iinfo(dtype).max - 2 - - return np.full(shape, value, dtype=dtype) - - elif isinstance(value, np.ndarray): - # this is needed because we can't stuff np.iinfo(int).min into an array of dtype float - casted_value = value.astype(dtype) - - # change bounds only if values are negative or positive infinite - if np.dtype(dtype).kind == "i": - casted_value[np.isneginf(value)] = np.iinfo(dtype).min + 2 - casted_value[np.isposinf(value)] = np.iinfo(dtype).max - 2 - - return casted_value - - else: - # only np.ndarray allowed beyond this point - raise TypeError( - f"Unknown dtype for `value`, expected `np.ndarray` or float/integer, got {type(value)}" - ) + self.high_repr = array_short_repr(self.high) diff --git a/tests/spaces/test_box.py b/tests/spaces/test_box.py index 85dab15ab..10bf47809 100644 --- a/tests/spaces/test_box.py +++ b/tests/spaces/test_box.py @@ -9,20 +9,54 @@ @pytest.mark.parametrize( - "box,expected_shape", + "dtype, error, message", [ - ( # Test with same 1-dim low and high shape - Box(low=np.zeros(2), high=np.ones(2), dtype=np.int32), - (2,), + ( + None, + ValueError, + "Box dtype must be explicitly provided, cannot be None.", ), - ( # Test with same multi-dim low and high shape - Box(low=np.zeros((2, 1)), high=np.ones((2, 1)), dtype=np.int32), - (2, 1), + (0, TypeError, "Cannot interpret '0' as a data type"), + ("unknown", TypeError, "data type 'unknown' not understood"), + (np.zeros(1), TypeError, "Cannot construct a dtype from an array"), + # disabled datatypes + ( + np.complex64, + ValueError, + "Invalid Box dtype (complex64), must be an integer, floating, or bool dtype", ), - ( # Test with scalar low high and different shape - Box(low=0, high=1, shape=(5, 2)), - (5, 2), + ( + complex, + ValueError, + "Invalid Box dtype (complex128), must be an integer, floating, or bool dtype", ), + ( + object, + ValueError, + "Invalid Box dtype (object), must be an integer, floating, or bool dtype", + ), + ( + str, + ValueError, + "Invalid Box dtype (", ), ( 0, 1, - {"shape": (None,)}, - AssertionError, - "Expected all shape elements to be an integer, actual type: (,)", + (None,), + TypeError, + "Expected all Box shape elements to be integer, actual type=(,)", ), ( 0, 1, - { - "shape": ( - 1, - None, - ) - }, - AssertionError, - "Expected all shape elements to be an integer, actual type: (, )", + (1, None), + TypeError, + "Expected all Box shape elements to be integer, actual type=(, )", ), ( 0, 1, - { - "shape": ( - np.int64(1), - None, - ) - }, - AssertionError, - "Expected all shape elements to be an integer, actual type: (, )", + (np.int64(1), None), + TypeError, + "Expected all Box shape elements to be integer, actual type=(, )", ), ( + np.zeros(3), + np.ones(2), None, - None, - {}, ValueError, - "Box shape is inferred from low and high, expected their types to be np.ndarray, an integer or a float, actual type low: , high: ", + "Box low.shape and high.shape don't match, low.shape=(3,), high.shape=(2,)", ), ( - 0, - None, - {}, + np.zeros(2), + np.ones(2), + (3,), ValueError, - "Box shape is inferred from low and high, expected their types to be np.ndarray, an integer or a float, actual type low: , high: ", + "Box low.shape doesn't match provided shape, low.shape=(2,), shape=(3,)", ), ( - np.zeros(3), + np.zeros(2), + 1, + (3,), + ValueError, + "Box low.shape doesn't match provided shape, low.shape=(2,), shape=(3,)", + ), + ( + 0, np.ones(2), - {}, - AssertionError, - "high.shape doesn't match provided shape, high.shape: (2,), shape: (3,)", + (3,), + ValueError, + "Box high.shape doesn't match provided shape, high.shape=(2,), shape=(3,)", ), ], ) -def test_init_errors(low, high, kwargs, error, message): - """Test all constructor errors.""" - with pytest.raises(error, match=f"^{re.escape(message)}$"): - Box(low=low, high=high, **kwargs) +def test_shape_errors(low, high, shape, error_type, message): + """Test errors due to shape mismatch.""" + with pytest.raises(error_type, match=f"^{re.escape(message)}$"): + Box(low=low, high=high, shape=shape) -def test_dtype_check(): +@pytest.mark.parametrize( + "low, high, dtype", + [ + # floats + (0, 65505.0, np.float16), + (-65505.0, 0, np.float16), + # signed int + (0, 32768, np.int16), + (-32769, 0, np.int16), + # unsigned int + (-1, 100, np.uint8), + (0, 300, np.uint8), + # boolean + (-1, 1, np.bool_), + (0, 2, np.bool_), + # array inputs + ( + np.array([-1, 0]), + np.array([0, 100]), + np.uint8, + ), + ( + np.array([[-1], [0]]), + np.array([[0], [100]]), + np.uint8, + ), + ( + np.array([0, 0]), + np.array([0, 300]), + np.uint8, + ), + ( + np.array([[0], [0]]), + np.array([[0], [300]]), + np.uint8, + ), + ], +) +def test_out_of_bounds_error(low, high, dtype): + with pytest.raises( + ValueError, match=re.escape("is out of bounds of the dtype range,") + ): + Box(low=low, high=high, dtype=dtype) + + +@pytest.mark.parametrize( + "low, high, dtype", + [ + # Floats + (np.nan, 0, np.float32), + (0, np.nan, np.float32), + (np.array([0, np.nan]), np.ones(2), np.float32), + # Signed ints + (np.nan, 0, np.int32), + (0, np.nan, np.int32), + (np.array([0, np.nan]), np.ones(2), np.int32), + # Unsigned ints + # (np.nan, 0, np.uint8), + # (0, np.nan, np.uint8), + # (np.array([0, np.nan]), np.ones(2), np.uint8), + (-np.inf, 1, np.uint8), + (np.array([-np.inf, 0]), 1, np.uint8), + (0, np.inf, np.uint8), + (0, np.array([1, np.inf]), np.uint8), + # boolean + (-np.inf, 1, np.bool_), + (0, np.inf, np.bool_), + ], +) +def test_invalid_low_high(low, high, dtype): + if dtype == np.uint8 or dtype == np.bool_: + with pytest.raises( + ValueError, match=re.escape("Box unsigned int dtype don't support") + ): + Box(low=low, high=high, dtype=dtype) + else: + with pytest.raises( + ValueError, match=re.escape("value can be equal to `np.nan`,") + ): + Box(low=low, high=high, dtype=dtype) + + +@pytest.mark.parametrize( + "low, high, dtype", + [ + # floats + (0, 1, float), + (0, 1, np.float64), + (0, 1, np.float32), + (0, 1, np.float16), + (np.zeros(2), np.ones(2), np.float32), + (np.zeros(2), 1, np.float32), + (-np.inf, 1, np.float32), + (np.array([-np.inf, 0]), 1, np.float32), + (0, np.inf, np.float32), + (0, np.array([np.inf, 1]), np.float32), + (-np.inf, np.inf, np.float32), + (np.full((2,), -np.inf), np.full((2,), np.inf), np.float32), + # signed ints + (0, 1, int), + (0, 1, np.int64), + (0, 1, np.int32), + (0, 1, np.int16), + (0, 1, np.int8), + (np.zeros(2), np.ones(2), np.int32), + (np.zeros(2), 1, np.int32), + (-np.inf, 1, np.int32), + (np.array([-np.inf, 0]), 1, np.int32), + (0, np.inf, np.int32), + (0, np.array([np.inf, 1]), np.int32), + # unsigned ints + (0, 1, np.uint64), + (0, 1, np.uint32), + (0, 1, np.uint16), + (0, 1, np.uint8), + # boolean + (0, 1, np.bool_), + ], +) +def test_valid_low_high(low, high, dtype): + with warnings.catch_warnings(record=True) as caught_warnings: + space = Box(low=low, high=high, dtype=dtype) + assert space.dtype == dtype + assert space.low.dtype == dtype + assert space.high.dtype == dtype + + space.seed(0) + sample = space.sample() + assert sample.dtype == dtype + assert space.contains(sample) + + for warn in caught_warnings: + if "precision lowered by casting to float32" not in warn.message.args[0]: + raise Exception(warn) + + +def test_contains_dtype(): """Tests the Box contains function with different dtypes.""" # Related Issues: # https://github.com/openai/gym/issues/2357 @@ -163,102 +294,54 @@ def test_dtype_check(): @pytest.mark.parametrize( - "space", + "lowhighshape", [ - Box(low=0, high=np.inf, shape=(2,), dtype=np.int32), - Box(low=0, high=np.inf, shape=(2,), dtype=np.float32), - Box(low=0, high=np.inf, shape=(2,), dtype=np.int64), - Box(low=0, high=np.inf, shape=(2,), dtype=np.float64), - Box(low=-np.inf, high=0, shape=(2,), dtype=np.int32), - Box(low=-np.inf, high=0, shape=(2,), dtype=np.float32), - Box(low=-np.inf, high=0, shape=(2,), dtype=np.int64), - Box(low=-np.inf, high=0, shape=(2,), dtype=np.float64), - Box(low=-np.inf, high=np.inf, shape=(2,), dtype=np.int32), - Box(low=-np.inf, high=np.inf, shape=(2,), dtype=np.float32), - Box(low=-np.inf, high=np.inf, shape=(2,), dtype=np.int64), - Box(low=-np.inf, high=np.inf, shape=(2,), dtype=np.float64), - Box(low=0, high=np.inf, shape=(2, 3), dtype=np.int32), - Box(low=0, high=np.inf, shape=(2, 3), dtype=np.float32), - Box(low=0, high=np.inf, shape=(2, 3), dtype=np.int64), - Box(low=0, high=np.inf, shape=(2, 3), dtype=np.float64), - Box(low=-np.inf, high=0, shape=(2, 3), dtype=np.int32), - Box(low=-np.inf, high=0, shape=(2, 3), dtype=np.float32), - Box(low=-np.inf, high=0, shape=(2, 3), dtype=np.int64), - Box(low=-np.inf, high=0, shape=(2, 3), dtype=np.float64), - Box(low=-np.inf, high=np.inf, shape=(2, 3), dtype=np.int32), - Box(low=-np.inf, high=np.inf, shape=(2, 3), dtype=np.float32), - Box(low=-np.inf, high=np.inf, shape=(2, 3), dtype=np.int64), - Box(low=-np.inf, high=np.inf, shape=(2, 3), dtype=np.float64), - Box(low=np.array([-np.inf, 0]), high=np.array([0.0, np.inf]), dtype=np.int32), - Box(low=np.array([-np.inf, 0]), high=np.array([0.0, np.inf]), dtype=np.float32), - Box(low=np.array([-np.inf, 0]), high=np.array([0.0, np.inf]), dtype=np.int64), - Box(low=np.array([-np.inf, 0]), high=np.array([0.0, np.inf]), dtype=np.float64), + dict(low=0, high=np.inf, shape=(2,)), + dict(low=-np.inf, high=0, shape=(2,)), + dict(low=-np.inf, high=np.inf, shape=(2,)), + dict(low=0, high=np.inf, shape=(2, 3)), + dict(low=-np.inf, high=0, shape=(2, 3)), + dict(low=-np.inf, high=np.inf, shape=(2, 3)), + dict(low=np.array([-np.inf, 0]), high=np.array([0.0, np.inf])), ], ) -def test_infinite_space(space): +@pytest.mark.parametrize("dtype", [np.int32, np.int64, np.float32, np.float64]) +def test_infinite_space(lowhighshape, dtype): """ To test spaces that are passed in have only 0 or infinite bounds because `space.high` and `space.low` are both modified within the init, we check for infinite when we know it's not 0 """ + space = Box(**lowhighshape, dtype=dtype) - assert np.all( - space.low < space.high - ), f"Box low bound ({space.low}) is not lower than the high bound ({space.high})" + assert np.all(space.low < space.high) - space.seed(0) - sample = space.sample() - - # check if space contains sample - assert ( - sample in space - ), f"Sample ({sample}) not inside space according to `space.contains()`" - - # manually check that the sign of the sample is within the bounds - assert np.all( - np.sign(sample) <= np.sign(space.high) - ), f"Sign of sample ({sample}) is less than space upper bound ({space.high})" - assert np.all( - np.sign(space.low) <= np.sign(sample) - ), f"Sign of sample ({sample}) is more than space lower bound ({space.low})" - - # check that int bounds are bounded for everything - # but floats are unbounded for infinite - if np.any(space.high != 0): - assert ( - space.is_bounded("above") is False - ), "inf upper bound supposed to be unbounded" - else: - assert ( - space.is_bounded("above") is True - ), "non-inf upper bound supposed to be bounded" - - if np.any(space.low != 0): - assert ( - space.is_bounded("below") is False - ), "inf lower bound supposed to be unbounded" - else: - assert ( - space.is_bounded("below") is True - ), "non-inf lower bound supposed to be bounded" - - if np.any(space.low != 0) or np.any(space.high != 0): - assert space.is_bounded("both") is False - else: - assert space.is_bounded("both") is True + # check that int bounds are bounded for everything but floats are unbounded for infinite + assert space.is_bounded("above") is not np.any(space.high != 0) + assert space.is_bounded("below") is not np.any(space.low != 0) + assert space.is_bounded("both") is not ( + np.any(space.high != 0) | np.any(space.high != 0) + ) # check for dtype - assert ( - space.high.dtype == space.dtype - ), f"High's dtype {space.high.dtype} doesn't match `space.dtype`'" - assert ( - space.low.dtype == space.dtype - ), f"Low's dtype {space.high.dtype} doesn't match `space.dtype`'" + assert space.high.dtype == space.dtype + assert space.low.dtype == space.dtype with pytest.raises( ValueError, match="manner is not in {'below', 'above', 'both'}, actual value:" ): space.is_bounded("test") + # Check sample + space.seed(0) + sample = space.sample() + + # check if space contains sample + assert sample in space + + # manually check that the sign of the sample is within the bounds + assert np.all(np.sign(sample) <= np.sign(space.high)) + assert np.all(np.sign(space.low) <= np.sign(sample)) + def test_legacy_state_pickling(): legacy_state = { @@ -290,56 +373,3 @@ def test_sample_mask(): match=re.escape("Box.sample cannot be provided a mask, actual value: "), ): space.sample(mask=np.array([0, 1, 0], dtype=np.int8)) - - -@pytest.mark.parametrize( - "low, high, shape, dtype, reason", - [ - ( - 5.0, - 3.0, - (), - np.float32, - "Some low values are greater than high, low=5.0, high=3.0", - ), - ( - np.array([5.0, 6.0]), - np.array([1.0, 5.99]), - (2,), - np.float32, - "Some low values are greater than high, low=[5. 6.], high=[1. 5.99]", - ), - ( - np.inf, - np.inf, - (), - np.float32, - "No low value can be equal to `np.inf`, low=inf", - ), - ( - np.array([0, np.inf]), - np.array([np.inf, np.inf]), - (2,), - np.float32, - "No low value can be equal to `np.inf`, low=[ 0. inf]", - ), - ( - -np.inf, - -np.inf, - (), - np.float32, - "No high value can be equal to `-np.inf`, high=-inf", - ), - ( - np.array([-np.inf, -np.inf]), - np.array([0, -np.inf]), - (2,), - np.float32, - "No high value can be equal to `-np.inf`, high=[ 0. -inf]", - ), - ], -) -def test_invalid_low_high(low, high, dtype, shape, reason): - """Tests that we don't allow spaces with degenerate bounds, such as `Box(np.inf, -np.inf)`.""" - with pytest.raises(ValueError, match=re.escape(reason)): - Box(low=low, high=high, dtype=dtype, shape=shape) diff --git a/tests/wrappers/vector/test_vector_wrappers.py b/tests/wrappers/vector/test_vector_wrappers.py index cf0db8b69..783c73944 100644 --- a/tests/wrappers/vector/test_vector_wrappers.py +++ b/tests/wrappers/vector/test_vector_wrappers.py @@ -45,7 +45,7 @@ def custom_environments(): ("CarRacing-v2", "ResizeObservation", {"shape": (35, 45)}), ("CarRacing-v2", "ReshapeObservation", {"shape": (96, 48, 6)}), ("CartPole-v1", "RescaleObservation", {"min_obs": 0, "max_obs": 1}), - ("CartPole-v1", "DtypeObservation", {"dtype": np.int32}), + ("CarRacing-v2", "DtypeObservation", {"dtype": np.int32}), # ("CartPole-v1", "RenderObservation", {}), # not implemented # ("CartPole-v1", "TimeAwareObservation", {}), # not implemented # ("CartPole-v1", "FrameStackObservation", {}), # not implemented