diff --git a/src/pymgrid/modules/base/base_module.py b/src/pymgrid/modules/base/base_module.py index e673d17d..bbddd13f 100644 --- a/src/pymgrid/modules/base/base_module.py +++ b/src/pymgrid/modules/base/base_module.py @@ -149,9 +149,10 @@ def step(self, action, normalized=True): except (IndexError, TypeError): if not isinstance(denormalized_action, (float, int)): try: - flat_dim = np.product(denormalized_action.shape) - assert flat_dim == 0 - except (AttributeError, AssertionError): + flat_dim = np.prod(denormalized_action.shape) + if flat_dim != 0: + raise ValueError(f'Bad action {denormalized_action}') + except AttributeError: raise ValueError(f'Bad action {denormalized_action}') else: denormalized_action = 0.0