Ovf with example #146

Merged · 7 commits · Jun 17, 2024
397 changes: 397 additions & 0 deletions examples/inference_and_learning/inference_methods_comparison.ipynb

Large diffs are not rendered by default.

49 changes: 48 additions & 1 deletion pymdp/jax/agent.py
@@ -298,8 +298,55 @@ def learning(self, beliefs_A, outcomes, actions, beliefs_B=None, lr_pA=1., lr_pB

        return agent


    @vmap
    def infer_parameters(self, beliefs_A, outcomes, actions, beliefs_B=None, lr_pA=1., lr_pB=1., **kwargs):
        agent = self
        beliefs_B = beliefs_A if beliefs_B is None else beliefs_B
        if self.inference_algo == 'ovf':
            smoothed_marginals_and_joints = inference.smoothing_ovf(beliefs_A, self.B, actions)
            marginal_beliefs = [x[0] for x in smoothed_marginals_and_joints]
            joint_beliefs = [x[1] for x in smoothed_marginals_and_joints]
        else:
            marginal_beliefs = beliefs_A
            if self.learn_B:
                nf = len(beliefs_B)
                joint_fn = lambda f: [beliefs_B[f][1:]] + [beliefs_B[f_idx][:-1] for f_idx in self.B_dependencies[f]]
                joint_beliefs = jtu.tree_map(joint_fn, list(range(nf)))

        if self.learn_A:
            qA, E_qA = learning.update_obs_likelihood_dirichlet(
                self.pA,
                outcomes,
                marginal_beliefs,
                A_dependencies=self.A_dependencies,
                num_obs=self.num_obs,
                onehot_obs=self.onehot_obs,
                lr=lr_pA,
            )

            agent = tree_at(lambda x: (x.A, x.pA), agent, (E_qA, qA))

        if self.learn_B:
            assert beliefs_B[0].shape[0] == actions.shape[0] + 1
            qB, E_qB = learning.update_state_transition_dirichlet(
                self.pB,
                joint_beliefs,
                actions,
                num_controls=self.num_controls,
                lr=lr_pB
            )

            # if beliefs about transitions were updated, recompute the I matrix used for inductive inference
            if self.use_inductive and self.H is not None:
                I_updated = control.generate_I_matrix(self.H, E_qB, self.inductive_threshold, self.inductive_depth)
            else:
                I_updated = self.I

            agent = tree_at(lambda x: (x.B, x.pB, x.I), agent, (E_qB, qB, I_updated))

        return agent
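For orientation, a minimal usage sketch of the new method (hypothetical variable names and shapes, not part of the diff): `agent` is a pymdp `Agent` built with `learn_A`/`learn_B` enabled, and `beliefs`, `outcomes`, and `actions` come from a preceding rollout. Because the method is wrapped in `@vmap`, every argument carries a leading batch dimension.

# Hypothetical sketch -- names and shapes are illustrative only.
# beliefs:  per-factor state posteriors from infer_states, leading batch axis
# outcomes: observations gathered over the rollout
# actions:  action indices, one row per transition
updated_agent = agent.infer_parameters(
    beliefs_A=beliefs,
    outcomes=outcomes,
    actions=actions,
    lr_pA=1.0,  # learning rate on the Dirichlet prior over A
    lr_pB=1.0,  # learning rate on the Dirichlet prior over B
)
# updated_agent carries the new A/pA and B/pB leaves (and a refreshed
# I matrix when inductive inference is enabled)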

    @vmap
    def infer_states(self, observations, past_actions, empirical_prior, qs_hist, mask=None):
    def infer_states(self, observations, empirical_prior, *, past_actions=None, qs_hist=None, mask=None):
        """
        Update the approximate posterior over hidden states by solving a variational inference problem, given an observation.

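A hedged call sketch for the reworked `infer_states` signature (the observation and prior variables are placeholders): `past_actions` and `qs_hist` are now keyword-only and default to `None`, so a first-timestep call can simply omit them.

# Hypothetical usage -- obs_0, obs_t, prior_t, etc. are illustrative names.
qs = agent.infer_states(observations=obs_0, empirical_prior=agent.D)
# On later timesteps the histories are threaded through by keyword:
qs = agent.infer_states(
    observations=obs_t,
    empirical_prior=prior_t,
    past_actions=action_history,
    qs_hist=qs_history,
)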
92 changes: 90 additions & 2 deletions pymdp/jax/inference.py
@@ -4,7 +4,9 @@

import jax.numpy as jnp
from .algos import run_factorized_fpi, run_mmp, run_vmp
from jax import tree_util as jtu
from jax import tree_util as jtu, lax
from jax.experimental.sparse._base import JAXSparse
from jaxtyping import Array

def update_posterior_states(
    A,
@@ -55,4 +57,90 @@ def update_posterior_states(
    qs_hist = qs

    return qs_hist


def joint_dist_factor_dense(b: Array, filtered_qs: list[Array], actions: Array):
    qs_last = filtered_qs[-1]
    qs_filter = filtered_qs[:-1]

    # conditional dist - timestep x s_{t+1} | s_{t}
    time_b = jnp.moveaxis(b[..., actions], -1, 0)

    # joint dist - timestep x s_{t+1} x s_{t}
    qs_joint = time_b * jnp.expand_dims(qs_filter, -1)

    # cond dist - timestep x s_{t} | s_{t+1}
    qs_backward_cond = jnp.moveaxis(
        qs_joint / qs_joint.sum(-2, keepdims=True), -2, -1
    )

    def step_fn(qs_smooth_past, backward_b):
        qs_joint = backward_b * qs_smooth_past
        qs_smooth = qs_joint.sum(-1)

        return qs_smooth, (qs_smooth, qs_joint)

    # seq_qs will contain the sequence of smoothed marginals and joints
    _, seq_qs = lax.scan(
        step_fn,
        qs_last,
        qs_backward_cond,
        reverse=True,
        unroll=2
    )

    # append the last filtered belief to the smoothed beliefs
    qs_smooth_all = jnp.concatenate([seq_qs[0], jnp.expand_dims(qs_last, 0)], 0)
    return qs_smooth_all, seq_qs[1]
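For reference, the backward pass above is the standard discrete-state fixed-interval smoother, restated here in notation assumed for this note (not taken from the PR): with filtered marginals $q(s_t \mid o_{1:t})$ and transition entries $p(s_{t+1} \mid s_t, a_t)$,

$$q(s_t \mid s_{t+1}, o_{1:t}) = \frac{p(s_{t+1} \mid s_t, a_t)\, q(s_t \mid o_{1:t})}{\sum_{s_t} p(s_{t+1} \mid s_t, a_t)\, q(s_t \mid o_{1:t})},$$

$$q(s_t \mid o_{1:T}) = \sum_{s_{t+1}} q(s_t \mid s_{t+1}, o_{1:t})\, q(s_{t+1} \mid o_{1:T}).$$

`step_fn` forms the summand as a pairwise joint before marginalizing, which is why the scan emits both the smoothed marginals and the joints.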

def joint_dist_factor_sparse(b: JAXSparse, filtered_qs: list[Array], actions: Array):
    qs_last = filtered_qs[-1]
    qs_filter = filtered_qs[:-1]

    # conditional dist - timestep x s_{t+1} | s_{t}
    time_b = b[..., actions].transpose([b.ndim - 1] + list(range(b.ndim - 1)))

    # joint dist - timestep x s_{t+1} x s_{t}
    qs_joint = time_b * jnp.expand_dims(qs_filter, -1)

    # cond dist - timestep x s_{t} | s_{t+1}
    transpose_idx = list(range(len(qs_joint.shape[:-2]))) + [qs_joint.ndim - 1, qs_joint.ndim - 2]
    qs_backward_cond = (qs_joint / qs_joint.sum(-2, keepdims=True).todense()).transpose(transpose_idx)

    def step_fn(qs_smooth_past, t):
        qs_joint = qs_backward_cond[t] * qs_smooth_past
        qs_smooth = qs_joint.sum(-1)

        return qs_smooth.todense(), (qs_smooth.todense(), qs_joint)

    # seq_qs will contain the sequence of smoothed marginals and joints
    _, seq_qs = lax.scan(
        step_fn,
        qs_last,
        jnp.arange(qs_backward_cond.shape[0]),
        reverse=True,
        unroll=2
    )

    # append the last filtered belief to the smoothed beliefs
    qs_smooth_all = jnp.concatenate([seq_qs[0], jnp.expand_dims(qs_last, 0)], 0)
    return qs_smooth_all, seq_qs[1]
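A note on the sparse variant's structure (inferred from the code, not stated in the PR): where the dense version scans directly over the stacked `qs_backward_cond`, this one scans over `jnp.arange(...)` and indexes `qs_backward_cond[t]` inside `step_fn`, and it densifies the carry with `.todense()`. That is consistent with `lax.scan` requiring a fixed, dense carry structure, and with sparse `JAXSparse` arrays not being directly usable as stacked scan inputs.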


def smoothing_ovf(filtered_post, B, past_actions):
    assert len(filtered_post) == len(B)
    nf = len(B)  # number of factors

    # dispatch on the storage format of each factor's transition tensor
    def joint(b, qs, f):
        if isinstance(b, JAXSparse):
            return joint_dist_factor_sparse(b, qs, past_actions[..., f])
        return joint_dist_factor_dense(b, qs, past_actions[..., f])

    marginals_and_joints = []
    for b, qs, f in zip(B, filtered_post, range(nf)):
        marginals_and_joints.append(joint(b, qs, f))

    return marginals_and_joints
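A usage sketch under assumed shapes (illustrative, not from the diff): `filtered_post[f]` holds `T+1` filtered marginals for hidden state factor `f`, `B[f]` is that factor's transition tensor, and `past_actions` is a `T x num_factors` array of control indices.

# Hypothetical example -- shapes and variable names assumed for illustration.
marginals_and_joints = smoothing_ovf(filtered_post, B, past_actions)
for f, (qs_smooth, qs_joint) in enumerate(marginals_and_joints):
    # qs_smooth: (T + 1, num_states[f]) smoothed marginals
    # qs_joint:  (T, num_states[f], num_states[f]) pairwise joints
    print(f, qs_smooth.shape)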


