diff --git a/jumanji/environments/routing/pacman/utils.py b/jumanji/environments/routing/pacman/utils.py index 625ccbcdd..0420aef77 100644 --- a/jumanji/environments/routing/pacman/utils.py +++ b/jumanji/environments/routing/pacman/utils.py @@ -324,7 +324,7 @@ def get_valid_positions(pos: chex.Array) -> Any: valid_no_back = valids * ghost_mask # Get distances of valid locations valid_no_back_d = valid_no_back * ghost_dist - invert_mask = valid_no_back != True # type: ignore # noqa: E712 + invert_mask = ~valid_no_back invert_mask = invert_mask * jnp.inf # Set distance of all invalid areas to infinity valid_no_back_d = valid_no_back_d + invert_mask