Skip to content

Commit

Permalink
fix smoothing with ovf and learning
Browse files Browse the repository at this point in the history
  • Loading branch information
dimarkov committed Jul 3, 2024
1 parent b30a508 commit fe4313d
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 5 deletions.
12 changes: 7 additions & 5 deletions pymdp/jax/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,9 +73,9 @@ def step_fn(qs_smooth, xs):
if isinstance(norm, JAXSparse):
norm = sparse.todense(norm)
norm = jnp.where(norm == 0, eps, norm)
qs_backward_cond = (qs_j / norm).T
qs_joint = qs_backward_cond * qs_smooth
qs_smooth = qs_joint.sum(-1)
qs_backward_cond = qs_j / norm
qs_joint = qs_backward_cond * jnp.expand_dims(qs_smooth, -1)
qs_smooth = qs_joint.sum(-2)
if isinstance(qs_smooth, JAXSparse):
qs_smooth = sparse.todense(qs_smooth)

Expand Down Expand Up @@ -106,9 +106,11 @@ def smoothing_ovf(filtered_post, B, past_actions):

joint = lambda b, qs, f: joint_dist_factor(b, qs, past_actions[..., f])

marginals_and_joints = []
marginals_and_joints = ([], [])
for b, qs, f in zip(B, filtered_post, list(range(nf))):
marginals_and_joints.append( joint(b, qs, f) )
marginals, joints = joint(b, qs, f)
marginals_and_joints[0].append(marginals)
marginals_and_joints[1].append(joints)

return marginals_and_joints

Expand Down
2 changes: 2 additions & 0 deletions pymdp/jax/learning.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

from .maths import multidimensional_outer, dirichlet_expected_value
from jax.tree_util import tree_map
from jaxtyping import Array
from jax import vmap, nn

def update_obs_likelihood_dirichlet_m(pA_m, obs_m, qs, dependencies_m, lr=1.0):
Expand Down Expand Up @@ -58,6 +59,7 @@ def update_state_transition_dirichlet_f(pB_f, actions_f, joint_qs_f, lr=1.0):
# \otimes is a multidimensional outer product, not just an outer product of two vectors
# \kappa is an optional learning rate

joint_qs_f = [joint_qs_f] if isinstance(joint_qs_f, Array) else joint_qs_f
dfdb = vmap(multidimensional_outer)(joint_qs_f + [actions_f]).sum(axis=0)
qB_f = pB_f + lr * dfdb

Expand Down

0 comments on commit fe4313d

Please sign in to comment.