Skip to content

Commit

Permalink
changed the backward messages in run_mmp_factorized to match how th…
Browse files Browse the repository at this point in the history
…ey are computed in `pymdp.jax.algos.get_mmp_messages`
  • Loading branch information
conorheins committed Jun 5, 2024
1 parent e97d6c7 commit 4f4b7c4
Showing 1 changed file with 25 additions and 3 deletions.
28 changes: 25 additions & 3 deletions pymdp/algos/mmp.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import numpy as np

from pymdp.utils import to_obj_array, get_model_dimensions, obj_array, obj_array_zeros, obj_array_uniform
from pymdp.maths import spm_dot, spm_norm, softmax, calc_free_energy, spm_log_single
from pymdp.maths import spm_dot, spm_norm, softmax, calc_free_energy, spm_log_single, factor_dot_flex
import copy

def run_mmp(
Expand Down Expand Up @@ -224,6 +224,8 @@ def run_mmp_factorized(
joint_loglikelihood += lh_seq[t][m].reshape(reshape_dims) # add up all the log-likelihoods after reshaping them to the global common dimensions of all hidden state factors
joint_lh_seq[t] = joint_loglikelihood

# compute inverse B dependencies, which is a list that for each hidden state factor, lists the indices of the other hidden state factors that it 'drives' or is a parent of in the HMM graphical model
inv_B_deps = [[i for i, d in enumerate(B_factor_list) if f in d] for f in range(num_factors)]
for itr in range(num_iter):
F = 0.0 # reset variational free energy (accumulated over time and factors, but reset per iteration)
for t in range(infer_len):
Expand All @@ -246,8 +248,28 @@ def run_mmp_factorized(
if t >= future_cutoff:
lnB_future = qs_T[f]
else:
future_msg = spm_dot(trans_B[f][...,int(policy[t, f])], qs_seq[t+1][B_factor_list[f]])
lnB_future = spm_log_single(future_msg)
# list of future_msgs, one for each of the factors that factor f is driving

B_marg_list = [] # list of the marginalized B matrices, that correspond to mapping between the factor of interest `f` and each of its children factors `i`
for i in inv_B_deps[f]: #loop over all the hidden state factors that are driven by f
b = B[i][...,int(policy[t,i])]
keep_dims = (0,1+B_factor_list[i].index(f))
dims = []
idxs = []
for j, d in enumerate(B_factor_list[i]): # loop over the list of factors that drive each child `i` of factor-of-interest `f` (i.e. the co-parents of `f`, with respect to child `i`)
if f != d:
dims.append((1 + j,))
idxs.append(d)
xs = [qs_seq[t+1][f_i] for f_i in idxs]
B_marg_list.append( factor_dot_flex(b, xs, tuple(dims), keep_dims=keep_dims) ) # marginalize out all parents of `i` besides `f`

lnB_future = np.zeros(num_states[f])
for i, b in enumerate(B_marg_list):
b_norm_T = spm_norm(b.T)
lnB_future += spm_log_single(b_norm_T.dot(qs_seq[t + 1][inv_B_deps[f][i]]))


lnB_future *= 0.5

# inference
if grad_descent:
Expand Down

0 comments on commit 4f4b7c4

Please sign in to comment.