diff --git a/pymdp/jax/control.py b/pymdp/jax/control.py index 177a5a41..b3cc0ad2 100644 --- a/pymdp/jax/control.py +++ b/pymdp/jax/control.py @@ -46,7 +46,7 @@ def get_marginals(q_pi, policies, num_controls): return action_marginals -def sample_action(q_pi, policies, num_controls, action_selection="deterministic", alpha=16.0, rng_key=None): +def sample_action(policies, num_controls, q_pi, action_selection="deterministic", alpha=16.0, rng_key=None): """ Samples an action from posterior marginals, one action per control factor. @@ -85,7 +85,7 @@ def sample_action(q_pi, policies, num_controls, action_selection="deterministic" return jnp.array(selected_policy) -def sample_policy(q_pi, policies, num_controls, action_selection="deterministic", alpha = 16.0, rng_key=None): +def sample_policy(policies, q_pi, action_selection="deterministic", alpha = 16.0, rng_key=None): if action_selection == "deterministic": policy_idx = jnp.argmax(q_pi)