Skip to content

Commit

Permalink
fix learning with None pA and pB values because of the recent jax cha…
Browse files Browse the repository at this point in the history
…nges
  • Loading branch information
dimarkov committed Nov 26, 2024
1 parent 88651fc commit 69dee71
Showing 1 changed file with 20 additions and 8 deletions.
28 changes: 20 additions & 8 deletions pymdp/jax/learning.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,10 +34,16 @@ def update_obs_likelihood_dirichlet(pA, A, obs, qs, *, A_dependencies, onehot_ob
""" JAX version of ``pymdp.learning.update_obs_likelihood_dirichlet`` """

obs_m = lambda o, dim: nn.one_hot(o, dim) if not onehot_obs else o
update_A_fn = lambda pA_m, o_m, dim, dependencies_m: update_obs_likelihood_dirichlet_m(
update_A_fn = lambda pA_m, o_m, dim, dependencies_m: None if pA_m is None else update_obs_likelihood_dirichlet_m(
pA_m, obs_m(o_m, dim), qs, dependencies_m, lr=lr
)
result = tree_map(update_A_fn, pA, obs, num_obs, A_dependencies)

result = tree_map(
update_A_fn,
pA, obs, num_obs, A_dependencies,
is_leaf=lambda x: x is None
)

qA = []
E_qA = []
for i, r in enumerate(result):
Expand Down Expand Up @@ -70,22 +76,28 @@ def update_state_transition_dirichlet_f(pB_f, actions_f, joint_qs_f, lr=1.0):

return qB_f, dirichlet_expected_value(qB_f)

def update_state_transition_dirichlet(pB, joint_beliefs, actions, *, num_controls, lr):
def update_state_transition_dirichlet(pB, B, joint_beliefs, actions, *, num_controls, lr):

nf = len(pB)
actions_onehot_fn = lambda f, dim: nn.one_hot(actions[..., f], dim, axis=-1)
update_B_f_fn = lambda pB_f, joint_qs_f, f, na: update_state_transition_dirichlet_f(
update_B_f_fn = lambda pB_f, joint_qs_f, f, na: None if pB_f is None else update_state_transition_dirichlet_f(
pB_f, actions_onehot_fn(f, na), joint_qs_f, lr=lr
)
result = tree_map(
update_B_f_fn, pB, joint_beliefs, list(range(nf)), num_controls
update_B_f_fn,
pB, joint_beliefs, list(range(nf)), num_controls,
is_leaf=lambda x: x is None
)

qB = []
E_qB = []
for r in result:
qB.append(r[0])
E_qB.append(r[1])
for i, r in enumerate(result):
if r is None:
qB.append(None)
E_qB.append(B[i])
else:
qB.append(r[0])
E_qB.append(r[1])

return qB, E_qB

Expand Down

0 comments on commit 69dee71

Please sign in to comment.