diff --git a/pymdp/jax/inference.py b/pymdp/jax/inference.py index 350d3662..4edd48d5 100644 --- a/pymdp/jax/inference.py +++ b/pymdp/jax/inference.py @@ -73,9 +73,9 @@ def step_fn(qs_smooth, xs): if isinstance(norm, JAXSparse): norm = sparse.todense(norm) norm = jnp.where(norm == 0, eps, norm) - qs_backward_cond = (qs_j / norm).T - qs_joint = qs_backward_cond * qs_smooth - qs_smooth = qs_joint.sum(-1) + qs_backward_cond = qs_j / norm + qs_joint = qs_backward_cond * jnp.expand_dims(qs_smooth, -1) + qs_smooth = qs_joint.sum(-2) if isinstance(qs_smooth, JAXSparse): qs_smooth = sparse.todense(qs_smooth) @@ -106,9 +106,11 @@ def smoothing_ovf(filtered_post, B, past_actions): joint = lambda b, qs, f: joint_dist_factor(b, qs, past_actions[..., f]) - marginals_and_joints = [] + marginals_and_joints = ([], []) for b, qs, f in zip(B, filtered_post, list(range(nf))): - marginals_and_joints.append( joint(b, qs, f) ) + marginals, joints = joint(b, qs, f) + marginals_and_joints[0].append(marginals) + marginals_and_joints[1].append(joints) return marginals_and_joints diff --git a/pymdp/jax/learning.py b/pymdp/jax/learning.py index 6c681751..32ada016 100644 --- a/pymdp/jax/learning.py +++ b/pymdp/jax/learning.py @@ -4,6 +4,7 @@ from .maths import multidimensional_outer, dirichlet_expected_value from jax.tree_util import tree_map +from jaxtyping import Array from jax import vmap, nn def update_obs_likelihood_dirichlet_m(pA_m, obs_m, qs, dependencies_m, lr=1.0): @@ -58,6 +59,7 @@ def update_state_transition_dirichlet_f(pB_f, actions_f, joint_qs_f, lr=1.0): # \otimes is a multidimensional outer product, not just a outer product of two vectors # \kappa is an optional learning rate + joint_qs_f = [joint_qs_f] if isinstance(joint_qs_f, Array) else joint_qs_f dfdb = vmap(multidimensional_outer)(joint_qs_f + [actions_f]).sum(axis=0) qB_f = pB_f + lr * dfdb