diff --git a/pymdp/jax/task.py b/pymdp/jax/task.py index 5de0315e..0f349d94 100644 --- a/pymdp/jax/task.py +++ b/pymdp/jax/task.py @@ -9,9 +9,9 @@ def select_probs(positions, matrix, dependency_list, actions=None): args = tuple(p for i, p in enumerate(positions) if i in dependency_list) - args += () if actions is None else (actions,) + args = args + (actions,) if actions is not None else args - return matrix[..., *args] + return matrix[(...,) + args] def cat_sample(key, p): a = jnp.arange(p.shape[-1])