Merge pull request #146 from infer-actively/ovf-with-example

Ovf with example

conorheins authored Jun 17, 2024
2 parents 62a6e87 + 8ee1426 commit 473e942
Showing 7 changed files with 812 additions and 254 deletions.
397 changes: 397 additions & 0 deletions examples/inference_and_learning/inference_methods_comparison.ipynb

Large diffs are not rendered by default.
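The notebook itself is not rendered here. As rough orientation only, a sketch of the comparison it sets up (the constructor arguments and import path are assumptions; only the 'ovf' option and the method names are confirmed by the diffs below):

from pymdp.jax.agent import Agent  # assumed import path

# Hypothetical: A, B and an observation/action history are built elsewhere.
# agent_vmp = Agent(A, B, ..., inference_algo='vmp')
# agent_ovf = Agent(A, B, ..., inference_algo='ovf')  # option added in this PR
#
# Both variants expose the same infer_states / infer_parameters interface,
# so one rollout can be replayed under each algorithm and the beliefs compared.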

49 changes: 48 additions & 1 deletion pymdp/jax/agent.py
@@ -298,8 +298,55 @@ def learning(self, beliefs_A, outcomes, actions, beliefs_B=None, lr_pA=1., lr_pB

        return agent


    @vmap
    def infer_parameters(self, beliefs_A, outcomes, actions, beliefs_B=None, lr_pA=1., lr_pB=1., **kwargs):
        agent = self
        beliefs_B = beliefs_A if beliefs_B is None else beliefs_B
        if self.inference_algo == 'ovf':
            # smooth the filtered posteriors backwards through time before learning
            smoothed_marginals_and_joints = inference.smoothing_ovf(beliefs_A, self.B, actions)
            marginal_beliefs = smoothed_marginals_and_joints[0]
            joint_beliefs = smoothed_marginals_and_joints[1]
        else:
            marginal_beliefs = beliefs_A
            if self.learn_B:
                # pair each factor's beliefs at t+1 with the beliefs of its parent factors at t
                nf = len(beliefs_B)
                joint_fn = lambda f: [beliefs_B[f][1:]] + [beliefs_B[f_idx][:-1] for f_idx in self.B_dependencies[f]]
                joint_beliefs = jtu.tree_map(joint_fn, list(range(nf)))

        if self.learn_A:
            qA, E_qA = learning.update_obs_likelihood_dirichlet(
                self.pA,
                outcomes,
                marginal_beliefs,
                A_dependencies=self.A_dependencies,
                num_obs=self.num_obs,
                onehot_obs=self.onehot_obs,
                lr=lr_pA,
            )

            agent = tree_at(lambda x: (x.A, x.pA), agent, (E_qA, qA))

        if self.learn_B:
            assert beliefs_B[0].shape[0] == actions.shape[0] + 1
            qB, E_qB = learning.update_state_transition_dirichlet(
                self.pB,
                joint_beliefs,
                actions,
                num_controls=self.num_controls,
                lr=lr_pB
            )

            # if beliefs about transitions have been updated, the I matrix used
            # for inductive inference must be re-computed
            if self.use_inductive and self.H is not None:
                I_updated = control.generate_I_matrix(self.H, E_qB, self.inductive_threshold, self.inductive_depth)
            else:
                I_updated = self.I

            agent = tree_at(lambda x: (x.B, x.pB, x.I), agent, (E_qB, qB, I_updated))

        return agent
    @vmap
-    def infer_states(self, observations, past_actions, empirical_prior, qs_hist, mask=None):
+    def infer_states(self, observations, empirical_prior, *, past_actions=None, qs_hist=None, mask=None):
"""
Update approximate posterior over hidden states by solving variational inference problem, given an observation.
Expand Down
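For orientation, a hedged sketch of the call pattern for the new method (variable names and shapes below are hypothetical, not taken from the diff). Because infer_parameters is wrapped in @vmap, the agent's pytree leaves and every argument carry a leading batch axis, and since the agent is rebuilt with tree_at rather than mutated, the updated agent must be captured from the return value:

# beliefs_A : per-factor posterior marginals from infer_states,
#             e.g. each of shape (batch, T, num_states)
# outcomes  : per-modality observations, e.g. each (batch, T, num_obs)
# actions   : (batch, T-1, num_factors) integer action indices
#
# agent = agent.infer_parameters(beliefs_A, outcomes, actions, lr_pA=1., lr_pB=1.)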
92 changes: 90 additions & 2 deletions pymdp/jax/inference.py
@@ -4,7 +4,9 @@

import jax.numpy as jnp
from .algos import run_factorized_fpi, run_mmp, run_vmp
-from jax import tree_util as jtu
+from jax import tree_util as jtu, lax
+from jax.experimental.sparse._base import JAXSparse
+from jaxtyping import Array

def update_posterior_states(
    A,
@@ -55,4 +57,90 @@ def update_posterior_states(
        qs_hist = qs

    return qs_hist


def joint_dist_factor_dense(b: Array, filtered_qs: list[Array], actions: Array):
    qs_last = filtered_qs[-1]
    qs_filter = filtered_qs[:-1]

    # conditional dist - timestep x s_{t+1} | s_{t}
    time_b = jnp.moveaxis(b[..., actions], -1, 0)
    # time_b = b[..., actions].transpose([b.ndim-1] + list(range(b.ndim-1)))

    # joint dist - timestep x s_{t+1} x s_{t}
    qs_joint = time_b * jnp.expand_dims(qs_filter, -1)

    # cond dist - timestep x s_{t} | s_{t+1}
    qs_backward_cond = jnp.moveaxis(
        qs_joint / qs_joint.sum(-2, keepdims=True), -2, -1
    )
    # transpose_idx = list(range(len(qs_joint.shape[:-2]))) + [qs_joint.ndim-1, qs_joint.ndim-2]
    # qs_backward_cond = (qs_joint / qs_joint.sum(-2, keepdims=True).todense()).transpose(transpose_idx)

    def step_fn(qs_smooth_past, backward_b):
        qs_joint = backward_b * qs_smooth_past
        qs_smooth = qs_joint.sum(-1)

        return qs_smooth, (qs_smooth, qs_joint)

    # seq_qs will contain a sequence of smoothed marginals and joints
    _, seq_qs = lax.scan(
        step_fn,
        qs_last,
        qs_backward_cond,
        reverse=True,
        unroll=2
    )

    # append the last filtered belief to the smoothed beliefs
    qs_smooth_all = jnp.concatenate([seq_qs[0], jnp.expand_dims(qs_last, 0)], 0)
    return qs_smooth_all, seq_qs[1]

def joint_dist_factor_sparse(b: JAXSparse, filtered_qs: list[Array], actions: Array):
    qs_last = filtered_qs[-1]
    qs_filter = filtered_qs[:-1]

    # conditional dist - timestep x s_{t+1} | s_{t}
    time_b = b[..., actions].transpose([b.ndim-1] + list(range(b.ndim-1)))

    # joint dist - timestep x s_{t+1} x s_{t}
    qs_joint = time_b * jnp.expand_dims(qs_filter, -1)

    # cond dist - timestep x s_{t} | s_{t+1}
    transpose_idx = list(range(len(qs_joint.shape[:-2]))) + [qs_joint.ndim-1, qs_joint.ndim-2]
    qs_backward_cond = (qs_joint / qs_joint.sum(-2, keepdims=True).todense()).transpose(transpose_idx)

    def step_fn(qs_smooth_past, t):
        qs_joint = qs_backward_cond[t] * qs_smooth_past
        qs_smooth = qs_joint.sum(-1)

        return qs_smooth.todense(), (qs_smooth.todense(), qs_joint)

    # seq_qs will contain a sequence of smoothed marginals and joints
    _, seq_qs = lax.scan(
        step_fn,
        qs_last,
        jnp.arange(qs_backward_cond.shape[0]),
        reverse=True,
        unroll=2
    )

    # append the last filtered belief to the smoothed beliefs
    qs_smooth_all = jnp.concatenate([seq_qs[0], jnp.expand_dims(qs_last, 0)], 0)
    return qs_smooth_all, seq_qs[1]


def smoothing_ovf(filtered_post, B, past_actions):
    assert len(filtered_post) == len(B)
    nf = len(B)  # number of factors

    # dispatch on the storage type of each factor's transition tensor
    joint = lambda b, qs, f: joint_dist_factor_sparse(b, qs, past_actions[..., f]) if isinstance(b, JAXSparse) else joint_dist_factor_dense(b, qs, past_actions[..., f])

    # collect per-factor smoothed marginals in [0] and pairwise joints in [1],
    # the structure indexed by Agent.infer_parameters
    marginals_and_joints = ([], [])
    for f, (b, qs) in enumerate(zip(B, filtered_post)):
        marginals, joints = joint(b, qs, f)
        marginals_and_joints[0].append(marginals)
        marginals_and_joints[1].append(joints)

    return marginals_and_joints
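Taken together, a minimal runnable sketch of exercising smoothing_ovf end to end (the import path is assumed from the file location; the toy numbers are illustrative, not from the diff). A single two-state factor with two actions over T = 3 timesteps passes through joint_dist_factor_dense, which scans backward over the conditionals with lax.scan(..., reverse=True):

import jax.numpy as jnp
from pymdp.jax.inference import smoothing_ovf  # assumed import path

# One factor: transition tensor of shape (s_next, s_prev, action),
# columns normalized over s_next.
B = [jnp.array([[[0.9, 0.2], [0.3, 0.5]],
                [[0.1, 0.8], [0.7, 0.5]]])]

# Filtered posterior marginals for that factor, shape (T, num_states).
filtered_post = [jnp.array([[0.5, 0.5],
                            [0.6, 0.4],
                            [0.8, 0.2]])]

# One action per transition, shape (T-1, num_factors).
past_actions = jnp.array([[0], [1]])

marginals, joints = smoothing_ovf(filtered_post, B, past_actions)
print(marginals[0].shape)  # (3, 2): smoothed marginal at every timestep
print(joints[0].shape)     # (2, 2, 2): one pairwise joint per transition

The shapes line up with what Agent.infer_parameters consumes under inference_algo == 'ovf': the per-factor marginals feed the pA update and the pairwise joints feed the pB update.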


