diff --git a/jumanji/environments/routing/pacman/env.py b/jumanji/environments/routing/pacman/env.py index 434cbaf84..47138c69c 100644 --- a/jumanji/environments/routing/pacman/env.py +++ b/jumanji/environments/routing/pacman/env.py @@ -444,8 +444,8 @@ def check_power_up( ps = jnp.array([ps_y, ps_x]) # Check if player and power_up position are shared - valid = (ps == power_up_locations).all(axis=-1).any() - eat = 1 * valid + on_powerup = (ps == power_up_locations).all(axis=-1).any() + eat = on_powerup.astype(int) mask = (ps == power_up_locations).all(axis=-1) invert_mask = mask != True # type: ignore # noqa: E712 invert_mask = invert_mask.reshape(4, 1)