From b30a508b29e4a959bd437207b34c9791edffe1d2 Mon Sep 17 00:00:00 2001 From: Dimitrije Markovic <5038100+dimarkov@users.noreply.github.com> Date: Wed, 3 Jul 2024 15:11:43 +0200 Subject: [PATCH] fix control sequences api --- pymdp/jax/control.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pymdp/jax/control.py b/pymdp/jax/control.py index 177a5a41..b3cc0ad2 100644 --- a/pymdp/jax/control.py +++ b/pymdp/jax/control.py @@ -46,7 +46,7 @@ def get_marginals(q_pi, policies, num_controls): return action_marginals -def sample_action(q_pi, policies, num_controls, action_selection="deterministic", alpha=16.0, rng_key=None): +def sample_action(policies, num_controls, q_pi, action_selection="deterministic", alpha=16.0, rng_key=None): """ Samples an action from posterior marginals, one action per control factor. @@ -85,7 +85,7 @@ def sample_action(q_pi, policies, num_controls, action_selection="deterministic" return jnp.array(selected_policy) -def sample_policy(q_pi, policies, num_controls, action_selection="deterministic", alpha = 16.0, rng_key=None): +def sample_policy(policies, q_pi, action_selection="deterministic", alpha = 16.0, rng_key=None): if action_selection == "deterministic": policy_idx = jnp.argmax(q_pi)