Skip to content

Commit

Permalink
fixes for computations with none values for pB and pA
Browse files Browse the repository at this point in the history
  • Loading branch information
dimarkov committed Jul 1, 2024
1 parent 264ee7d commit 865ce9e
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 3 deletions.
1 change: 0 additions & 1 deletion pymdp/jax/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -432,7 +432,6 @@ def multiaction_probabilities(self, q_pi: Array):
)
marginals = jnp.where(locs, q_pi, 0.).sum(-1)

# assert jnp.isclose(jnp.sum(marginals), 1.) # this fails inside scan
return marginals

@vmap
Expand Down
7 changes: 5 additions & 2 deletions pymdp/jax/control.py
Original file line number Diff line number Diff line change
Expand Up @@ -342,8 +342,11 @@ def scan_body(carry, t):

inductive_value = calc_inductive_value_t(qs_init, qs_next, I, epsilon=inductive_epsilon) if use_inductive else 0.

param_info_gain = calc_pA_info_gain(pA, qo, qs_next, A_dependencies) if use_param_info_gain else 0.
param_info_gain += calc_pB_info_gain(pB, qs_next, qs, B_dependencies, policy_i[t]) if use_param_info_gain else 0.
param_info_gain = 0.
if pA is not None:
param_info_gain += calc_pA_info_gain(pA, qo, qs_next, A_dependencies) if use_param_info_gain else 0.
if pB is not None:
param_info_gain += calc_pB_info_gain(pB, qs_next, qs, B_dependencies, policy_i[t]) if use_param_info_gain else 0.

neg_G += info_gain + utility - param_info_gain + inductive_value

Expand Down

0 comments on commit 865ce9e

Please sign in to comment.