Skip to content

Commit

Permalink
Fix sample policies and actions to correctly pass logits to the categorical sampler
Browse files Browse the repository at this point in the history
  • Loading branch information
dimarkov committed Jun 5, 2024
1 parent 1952165 commit 55fcc5f
Showing 1 changed file with 4 additions and 3 deletions.
7 changes: 4 additions & 3 deletions pymdp/jax/control.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,8 @@ def sample_action(q_pi, policies, num_controls, action_selection="deterministic"
if action_selection == 'deterministic':
selected_policy = jtu.tree_map(lambda x: jnp.argmax(x, -1), marginal)
elif action_selection == 'stochastic':
selected_policy = jtu.tree_map(lambda x: jr.categorical(rng_key, nn.softmax(alpha * log_stable(x))), marginal)
logits = lambda x: alpha * log_stable(x)
selected_policy = jtu.tree_map(lambda x: jr.categorical(rng_key, logits(x)), marginal)
else:
raise NotImplementedError

Expand All @@ -89,8 +90,8 @@ def sample_policy(q_pi, policies, num_controls, action_selection="deterministic"
if action_selection == "deterministic":
policy_idx = jnp.argmax(q_pi)
elif action_selection == "stochastic":
p_policies = nn.softmax(log_stable(q_pi) * alpha)
policy_idx = jr.categorical(rng_key, p_policies)
log_p_policies = log_stable(q_pi) * alpha
policy_idx = jr.categorical(rng_key, log_p_policies)

selected_multiaction = policies[policy_idx, 0]
return selected_multiaction
Expand Down

0 comments on commit 55fcc5f

Please sign in to comment.