From a0b2ccdca5547bd36598ecbf5533e96186fb8674 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Scott=20Carroll=20=F0=9F=90=99?= Date: Wed, 12 Jun 2024 15:09:31 -0400 Subject: [PATCH] Try updating syntax so coverage isn't broken; if it does not work try passing in --ignore-errors=True --- pymdp/jax/task.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pymdp/jax/task.py b/pymdp/jax/task.py index 5de0315e..0f349d94 100644 --- a/pymdp/jax/task.py +++ b/pymdp/jax/task.py @@ -9,9 +9,9 @@ def select_probs(positions, matrix, dependency_list, actions=None): args = tuple(p for i, p in enumerate(positions) if i in dependency_list) - args += () if actions is None else (actions,) + args = args + (actions,) if actions is not None else args - return matrix[..., *args] + return matrix[(...,) + args] def cat_sample(key, p): a = jnp.arange(p.shape[-1])