Skip to content

Commit

Permalink
fix control sequences api
Browse files Browse the repository at this point in the history
  • Loading branch information
dimarkov committed Jul 3, 2024
1 parent 7f2bbb4 commit b30a508
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions pymdp/jax/control.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def get_marginals(q_pi, policies, num_controls):

return action_marginals

def sample_action(q_pi, policies, num_controls, action_selection="deterministic", alpha=16.0, rng_key=None):
def sample_action(policies, num_controls, q_pi, action_selection="deterministic", alpha=16.0, rng_key=None):
"""
Samples an action from posterior marginals, one action per control factor.
Expand Down Expand Up @@ -85,7 +85,7 @@ def sample_action(q_pi, policies, num_controls, action_selection="deterministic"

return jnp.array(selected_policy)

def sample_policy(q_pi, policies, num_controls, action_selection="deterministic", alpha = 16.0, rng_key=None):
def sample_policy(policies, q_pi, action_selection="deterministic", alpha = 16.0, rng_key=None):

if action_selection == "deterministic":
policy_idx = jnp.argmax(q_pi)
Expand Down

0 comments on commit b30a508

Please sign in to comment.