diff --git a/pymdp/jax/learning.py b/pymdp/jax/learning.py index 499b94a3..ba7d9608 100644 --- a/pymdp/jax/learning.py +++ b/pymdp/jax/learning.py @@ -34,10 +34,16 @@ def update_obs_likelihood_dirichlet(pA, A, obs, qs, *, A_dependencies, onehot_ob """ JAX version of ``pymdp.learning.update_obs_likelihood_dirichlet`` """ obs_m = lambda o, dim: nn.one_hot(o, dim) if not onehot_obs else o - update_A_fn = lambda pA_m, o_m, dim, dependencies_m: update_obs_likelihood_dirichlet_m( + update_A_fn = lambda pA_m, o_m, dim, dependencies_m: None if pA_m is None else update_obs_likelihood_dirichlet_m( pA_m, obs_m(o_m, dim), qs, dependencies_m, lr=lr ) - result = tree_map(update_A_fn, pA, obs, num_obs, A_dependencies) + + result = tree_map( + update_A_fn, + pA, obs, num_obs, A_dependencies, + is_leaf=lambda x: x is None + ) + qA = [] E_qA = [] for i, r in enumerate(result): @@ -70,22 +76,28 @@ def update_state_transition_dirichlet_f(pB_f, actions_f, joint_qs_f, lr=1.0): return qB_f, dirichlet_expected_value(qB_f) -def update_state_transition_dirichlet(pB, joint_beliefs, actions, *, num_controls, lr): +def update_state_transition_dirichlet(pB, B, joint_beliefs, actions, *, num_controls, lr): nf = len(pB) actions_onehot_fn = lambda f, dim: nn.one_hot(actions[..., f], dim, axis=-1) - update_B_f_fn = lambda pB_f, joint_qs_f, f, na: update_state_transition_dirichlet_f( + update_B_f_fn = lambda pB_f, joint_qs_f, f, na: None if pB_f is None else update_state_transition_dirichlet_f( pB_f, actions_onehot_fn(f, na), joint_qs_f, lr=lr ) result = tree_map( - update_B_f_fn, pB, joint_beliefs, list(range(nf)), num_controls + update_B_f_fn, + pB, joint_beliefs, list(range(nf)), num_controls, + is_leaf=lambda x: x is None ) qB = [] E_qB = [] - for r in result: - qB.append(r[0]) - E_qB.append(r[1]) + for i, r in enumerate(result): + if r is None: + qB.append(None) + E_qB.append(B[i]) + else: + qB.append(r[0]) + E_qB.append(r[1]) return qB, E_qB