From e22b7ee06c8701f761f2ad69b1e665824ff3d3ff Mon Sep 17 00:00:00 2001 From: ahalev Date: Thu, 8 Aug 2024 15:24:08 -0700 Subject: [PATCH] fix for numpy>=2.0 --- src/pymgrid/modules/base/base_module.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/src/pymgrid/modules/base/base_module.py b/src/pymgrid/modules/base/base_module.py index e673d17d..bbddd13f 100644 --- a/src/pymgrid/modules/base/base_module.py +++ b/src/pymgrid/modules/base/base_module.py @@ -149,9 +149,10 @@ def step(self, action, normalized=True): except (IndexError, TypeError): if not isinstance(denormalized_action, (float, int)): try: - flat_dim = np.product(denormalized_action.shape) - assert flat_dim == 0 - except (AttributeError, AssertionError): + flat_dim = np.prod(denormalized_action.shape) + if flat_dim != 0: + raise ValueError(f'Bad action {denormalized_action}') + except AttributeError: raise ValueError(f'Bad action {denormalized_action}') else: denormalized_action = 0.0