From 27bf12679a0a2e0ada392b183ab797d5c6650c21 Mon Sep 17 00:00:00 2001 From: ariel Date: Tue, 5 Dec 2023 00:36:51 +0100 Subject: [PATCH] Add proper handling of nullable boxes --- gymnasium/spaces/box.py | 10 ++++++++-- gymnasium/spaces/utils.py | 2 +- 2 files changed, 9 insertions(+), 3 deletions(-) diff --git a/gymnasium/spaces/box.py b/gymnasium/spaces/box.py index 36206c512..3fc6be7ec 100644 --- a/gymnasium/spaces/box.py +++ b/gymnasium/spaces/box.py @@ -246,11 +246,17 @@ def contains(self, x: Any) -> bool: except (ValueError, TypeError): return False + bounded_below = x >= self.low + bounded_above = x <= self.high + + bounded = bounded_below & bounded_above + if self.nullable: + bounded |= np.isnan(x) + return bool( np.can_cast(x.dtype, self.dtype) and x.shape == self.shape - and np.all(x >= self.low) - and np.all(x <= self.high) + and np.all(bounded) ) def to_jsonable(self, sample_n: Sequence[NDArray[Any]]) -> list[list]: diff --git a/gymnasium/spaces/utils.py b/gymnasium/spaces/utils.py index 9e270d980..3df8db147 100644 --- a/gymnasium/spaces/utils.py +++ b/gymnasium/spaces/utils.py @@ -575,4 +575,4 @@ def _flatten_space_oneof(space: OneOf) -> Box: high = np.concatenate([[num_subspaces - 1], np.full(max_flatdim - 1, overall_high)]) dtype = np.result_type(*[s.dtype for s in space.spaces if hasattr(s, "dtype")]) - return Box(low=low, high=high, shape=(max_flatdim,), dtype=dtype) + return Box(low=low, high=high, shape=(max_flatdim,), dtype=dtype, nullable=True)