From 4f4b7c4e730bc976961235810a73d5f2aaf519f9 Mon Sep 17 00:00:00 2001 From: conorheins Date: Wed, 5 Jun 2024 17:48:48 +0200 Subject: [PATCH] changed the backward messages in `run_mmp_factorized` to match how they are computed in `pymdp.jax.algos.get_mmp_messages` --- pymdp/algos/mmp.py | 28 +++++++++++++++++++++++++--- 1 file changed, 25 insertions(+), 3 deletions(-) diff --git a/pymdp/algos/mmp.py b/pymdp/algos/mmp.py index 036e3ea3..019e81df 100644 --- a/pymdp/algos/mmp.py +++ b/pymdp/algos/mmp.py @@ -4,7 +4,7 @@ import numpy as np from pymdp.utils import to_obj_array, get_model_dimensions, obj_array, obj_array_zeros, obj_array_uniform -from pymdp.maths import spm_dot, spm_norm, softmax, calc_free_energy, spm_log_single +from pymdp.maths import spm_dot, spm_norm, softmax, calc_free_energy, spm_log_single, factor_dot_flex import copy def run_mmp( @@ -224,6 +224,8 @@ def run_mmp_factorized( joint_loglikelihood += lh_seq[t][m].reshape(reshape_dims) # add up all the log-likelihoods after reshaping them to the global common dimensions of all hidden state factors joint_lh_seq[t] = joint_loglikelihood + # compute inverse B dependencies, which is a list that for each hidden state factor, lists the indices of the other hidden state factors that it 'drives' or is a parent of in the HMM graphical model + inv_B_deps = [[i for i, d in enumerate(B_factor_list) if f in d] for f in range(num_factors)] for itr in range(num_iter): F = 0.0 # reset variational free energy (accumulated over time and factors, but reset per iteration) for t in range(infer_len): @@ -246,8 +248,28 @@ def run_mmp_factorized( if t >= future_cutoff: lnB_future = qs_T[f] else: - future_msg = spm_dot(trans_B[f][...,int(policy[t, f])], qs_seq[t+1][B_factor_list[f]]) - lnB_future = spm_log_single(future_msg) + # list of future_msgs, one for each of the factors that factor f is driving + + B_marg_list = [] # list of the marginalized B matrices, that correspond to mapping between the factor of interest `f` and each of its children factors `i` + for i in inv_B_deps[f]: #loop over all the hidden state factors that are driven by f + b = B[i][...,int(policy[t,i])] + keep_dims = (0,1+B_factor_list[i].index(f)) + dims = [] + idxs = [] + for j, d in enumerate(B_factor_list[i]): # loop over the list of factors that drive each child `i` of factor-of-interest `f` (i.e. the co-parents of `f`, with respect to child `i`) + if f != d: + dims.append((1 + j,)) + idxs.append(d) + xs = [qs_seq[t+1][f_i] for f_i in idxs] + B_marg_list.append( factor_dot_flex(b, xs, tuple(dims), keep_dims=keep_dims) ) # marginalize out all parents of `i` besides `f` + + lnB_future = np.zeros(num_states[f]) + for i, b in enumerate(B_marg_list): + b_norm_T = spm_norm(b.T) + lnB_future += spm_log_single(b_norm_T.dot(qs_seq[t + 1][inv_B_deps[f][i]])) + + + lnB_future *= 0.5 # inference if grad_descent: