Skip to content

Commit

Permalink
remove dependence on multimethod
Browse files Browse the repository at this point in the history
  • Loading branch information
dimarkov committed Jun 17, 2024
1 parent 1219167 commit 475dc51
Showing 1 changed file with 5 additions and 10 deletions.
15 changes: 5 additions & 10 deletions pymdp/jax/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
import jax.numpy as jnp
from .algos import run_factorized_fpi, run_mmp, run_vmp
from jax import tree_util as jtu, lax
from multimethod import multimethod
from jax.experimental.sparse._base import JAXSparse
from jaxtyping import Array

Expand Down Expand Up @@ -59,8 +58,7 @@ def update_posterior_states(

return qs_hist

@multimethod
def joint_dist_factor(b: Array, filtered_qs, actions):
def joint_dist_factor_dense(b: Array, filtered_qs: list[Array], actions: Array):
qs_last = filtered_qs[-1]
qs_filter = filtered_qs[:-1]

Expand Down Expand Up @@ -97,8 +95,7 @@ def step_fn(qs_smooth_past, backward_b):
qs_smooth_all = jnp.concatenate([seq_qs[0], jnp.expand_dims(qs_last, 0)], 0)
return qs_smooth_all, seq_qs[1]

@multimethod
def joint_dist_factor(b: JAXSparse, filtered_qs, actions):
def joint_dist_factor_sparse(b: JAXSparse, filtered_qs: list[Array], actions: Array):
qs_last = filtered_qs[-1]
qs_filter = filtered_qs[:-1]

Expand Down Expand Up @@ -136,14 +133,12 @@ def step_fn(qs_smooth_past, t):
def smoothing_ovf(filtered_post, B, past_actions):
assert len(filtered_post) == len(B)
nf = len(B) # number of factors
joint = lambda b, qs, f: joint_dist_factor(b, qs, past_actions[..., f])
# marginals_and_joints = jtu.tree_map(
# joint, B, filtered_post, list(range(nf)))

joint = lambda b, qs, f: joint_dist_factor_sparse(b, qs, past_actions[..., f]) if isinstance(b, JAXSparse) else joint_dist_factor_dense(b, qs, past_actions[..., f])

marginals_and_joints = []
for b, qs, f in zip(B, filtered_post, list(range(nf))):
marginals_and_joints_f = joint(b, qs, f)
marginals_and_joints.append(marginals_and_joints_f)
marginals_and_joints.append( joint(b, qs, f) )

return marginals_and_joints

Expand Down

0 comments on commit 475dc51

Please sign in to comment.