diff --git a/pymdp/agent.py b/pymdp/agent.py
index 9250c5bb..b28d26e9 100644
--- a/pymdp/agent.py
+++ b/pymdp/agent.py
@@ -60,11 +60,18 @@ def __init__(
         factors_to_learn="all",
         lr_pB=1.0,
         lr_pD=1.0,
-        use_BMA = True,
+        use_BMA=True,
         policy_sep_prior=False,
         save_belief_hist=False,
         A_factor_list=None,
-        B_factor_list=None
+        B_factor_list=None,
+        sophisticated=False,
+        si_horizon=3,
+        si_policy_prune_threshold=1/16,
+        si_state_prune_threshold=1/16,
+        si_prune_penalty=512,
+        ii_depth=10,
+        ii_threshold=1/16,
     ):

         ### Constant parameters ###
@@ -86,6 +93,15 @@ def __init__(
         self.lr_pB = lr_pB
         self.lr_pD = lr_pD

+        # sophisticated inference parameters
+        self.sophisticated = sophisticated
+        if self.sophisticated:
+            assert self.policy_len == 1, "Sophisticated inference only works with policy_len = 1"
+        self.si_horizon = si_horizon
+        self.si_policy_prune_threshold = si_policy_prune_threshold
+        self.si_state_prune_threshold = si_state_prune_threshold
+        self.si_prune_penalty = si_prune_penalty
+
         # Initialise observation model (A matrices)
         if not isinstance(A, np.ndarray):
             raise TypeError(
@@ -129,6 +145,7 @@ def __init__(
         self.num_controls = num_controls

         # checking that `A_factor_list` and `B_factor_list` are consistent with `num_factors`, `num_states`, and lagging dimensions of `A` and `B` tensors
+        self.factorized = False
         if A_factor_list == None:
             self.A_factor_list = self.num_modalities * [list(range(self.num_factors))] # defaults to having all modalities depend on all factors
             for m in range(self.num_modalities):
@@ -137,6 +154,7 @@ def __init__(
                 if self.pA is not None:
                     assert self.pA[m].shape[1:] == factor_dims, f"Please input an `A_factor_list` whose {m}-th indices pick out the hidden state factors that line up with lagging dimensions of pA{m}..."
         else:
+            self.factorized = True
             for m in range(self.num_modalities):
                 assert max(A_factor_list[m]) <= (self.num_factors - 1), f"Check modality {m} of A_factor_list - must be consistent with `num_states` and `num_factors`..."
                 factor_dims = tuple([self.num_states[f] for f in A_factor_list[m]])
@@ -164,6 +182,7 @@ def __init__(
                 if self.pB is not None:
                     assert self.pB[f].shape[1:-1] == factor_dims, f"Please input a `B_factor_list` whose {f}-th indices pick out the hidden state factors that line up with the all-but-final lagging dimensions of pB{f}..."
         else:
+            self.factorized = True
             for f in range(self.num_factors):
                 assert max(B_factor_list[f]) <= (self.num_factors - 1), f"Check factor {f} of B_factor_list - must be consistent with `num_states` and `num_factors`..."
                 factor_dims = tuple([self.num_states[f] for f in B_factor_list[f]])
@@ -186,7 +205,7 @@ def __init__(

         # Again, the use can specify a set of possible policies, or
         # all possible combinations of actions and timesteps will be considered
-        if policies == None:
+        if policies is None:
             policies = self._construct_policies()
         self.policies = policies

@@ -251,8 +270,10 @@ def __init__(

         # Construct I for backwards induction (if H specified)
         if H is not None:
-            self.I = control.backwards_induction(H, B, B_factor_list, threshold=1/16, depth=5)
+            self.H = H
+            self.I = control.backwards_induction(H, B, B_factor_list, threshold=ii_threshold, depth=ii_depth)
         else:
+            self.H = None
             self.I = None

         self.edge_handling_params = {}
@@ -616,6 +637,12 @@ def infer_policies_old(self):
                 gamma=self.gamma
             )
         elif self.inference_algo == "MMP":
+            if self.factorized:
+                raise NotImplementedError("Factorized inference not implemented for MMP")
+
+            if self.sophisticated:
+                raise NotImplementedError("Sophisticated inference not implemented for MMP")
+

             future_qs_seq = self.get_future_qs()

@@ -664,23 +691,42 @@ def infer_policies(self):
         """

         if self.inference_algo == "VANILLA":
-            q_pi, G = control.update_posterior_policies_factorized(
-                self.qs,
-                self.A,
-                self.B,
-                self.C,
-                self.A_factor_list,
-                self.B_factor_list,
-                self.policies,
-                self.use_utility,
-                self.use_states_info_gain,
-                self.use_param_info_gain,
-                self.pA,
-                self.pB,
-                E=self.E,
-                I=self.I,
-                gamma=self.gamma
-            )
+            if self.sophisticated:
+                q_pi, G = control.sophisticated_inference_search(
+                    self.qs,
+                    self.policies,
+                    self.A,
+                    self.B,
+                    self.C,
+                    self.A_factor_list,
+                    self.B_factor_list,
+                    self.I,
+                    self.si_horizon,
+                    self.si_policy_prune_threshold,
+                    self.si_state_prune_threshold,
+                    self.si_prune_penalty,
+                    1.0,
+                    self.inference_params,
+                    n=0
+                )
+            else:
+                q_pi, G = control.update_posterior_policies_factorized(
+                    self.qs,
+                    self.A,
+                    self.B,
+                    self.C,
+                    self.A_factor_list,
+                    self.B_factor_list,
+                    self.policies,
+                    self.use_utility,
+                    self.use_states_info_gain,
+                    self.use_param_info_gain,
+                    self.pA,
+                    self.pB,
+                    E = self.E,
+                    I = self.I,
+                    gamma = self.gamma
+                )
         elif self.inference_algo == "MMP":

             future_qs_seq = self.get_future_qs()
diff --git a/pymdp/control.py b/pymdp/control.py
index 758427eb..ba7a218f 100644
--- a/pymdp/control.py
+++ b/pymdp/control.py
@@ -5,7 +5,8 @@

 import itertools
 import numpy as np
-from pymdp.maths import softmax, softmax_obj_arr, spm_dot, spm_wnorm, spm_MDP_G, spm_log_single, spm_log_obj_array
+from pymdp.maths import softmax, softmax_obj_arr, spm_dot, spm_wnorm, spm_MDP_G, spm_log_single, kl_div, entropy
+from pymdp.inference import update_posterior_states_factorized, average_states_over_policies
 from pymdp import utils
 import copy
@@ -106,7 +107,7 @@ def update_posterior_policies_full(

     if I is not None:
         init_qs_all_pi = [qs_seq_pi[p][0] for p in range(num_policies)]
-        qs_bma = inference.average_states_over_policies(init_qs_all_pi, softmax(E))
+        qs_bma = average_states_over_policies(init_qs_all_pi, softmax(E))

     for p_idx, policy in enumerate(policies):

@@ -236,7 +237,7 @@ def update_posterior_policies_full_factorized(

     if I is not None:
         init_qs_all_pi = [qs_seq_pi[p][0] for p in range(num_policies)]
-        qs_bma = inference.average_states_over_policies(init_qs_all_pi, softmax(E))
+        qs_bma = average_states_over_policies(init_qs_all_pi, softmax(E))

     for p_idx, policy in enumerate(policies):

@@ -250,9 +251,9 @@ def update_posterior_policies_full_factorized(

         if use_param_info_gain:
             if pA is not None:
-                G[idx] += calc_pA_info_gain_factorized(pA, qo_seq_pi[p_idx], qs_seq_pi[p_idx], A_factor_list)
+                G[p_idx] += calc_pA_info_gain_factorized(pA, qo_seq_pi[p_idx], qs_seq_pi[p_idx], A_factor_list)
             if pB is not None:
-                G[idx] += calc_pB_info_gain_interactions(pB, qs_seq_pi[p_idx], qs, B_factor_list, policy)
+                G[p_idx] += calc_pB_info_gain_interactions(pB, qs_seq_pi[p_idx], qs_seq_pi[p_idx], B_factor_list, policy)

         if I is not None:
             G[p_idx] += calc_inductive_cost(qs_bma, qs_seq_pi[p_idx], I)
@@ -943,7 +944,7 @@ def calc_inductive_cost(qs, qs_pi, I, epsilon=1e-3):
             m = np.where(I[factor][:, idx] == 1)[0]
             # we might find no path to goal (i.e. when no goal specified)
             if len(m) > 0:
-                m = np.max(m[0]-1, 0)
+                m = max(m[0]-1, 0)
                 I_m = (1-I[factor][m, :]) * np.log(epsilon)
                 inductive_cost += I_m.dot(qs_pi[t][factor])
@@ -1308,7 +1309,158 @@ def backwards_induction(H, B, B_factor_list, threshold, depth):
         for i in range(1, depth):
             I[factor][i, :] = np.dot(b, I[factor][i-1, :])
             I[factor][i, :] = np.where(I[factor][i, :] > 0.1, 1.0, 0.0)
+            # TODO stop when all 1s?

     return I
-
+def calc_ambiguity_factorized(qs_pi, A, A_factor_list):
+    """
+    Computes the ambiguity term (the expected entropy of the likelihood under the predicted posterior over hidden states).
+
+    Parameters
+    ----------
+    qs_pi: ``list`` of ``numpy.ndarray`` of dtype object
+        Predictive posterior beliefs over hidden states expected under the policy, where ``qs_pi[t]`` stores the beliefs about
+        hidden states expected under the policy at time ``t``
+    A: ``numpy.ndarray`` of dtype object
+        Sensory likelihood mapping or 'observation model', mapping from hidden states to observations. Each element ``A[m]``
+        stores a ``numpy.ndarray`` multidimensional array for observation modality ``m``, whose entries ``A[m][i, j, k, ...]`` store
+        the probability of observation level ``i`` given hidden state levels ``j, k, ...``
+    A_factor_list: ``list`` of ``list`` of ``int``
+        List of lists, where ``A_factor_list[m]`` is a list of the hidden state factor indices that observation modality with the index ``m`` depends on
+
+    Returns
+    -------
+    ambiguity: float
+    """
+
+    n_steps = len(qs_pi)
+
+    ambiguity = 0
+    # TODO check if we do this correctly!
+    H = entropy(A)
+    for t in range(n_steps):
+        for m, H_m in enumerate(H):
+            factor_idx = A_factor_list[m]
+            # TODO why does spm_dot return an array here?
+            # joint_x = maths.spm_cross(qs_pi[t][factor_idx])
+            # ambiguity += (H_m * joint_x).sum()
+            ambiguity += np.sum(spm_dot(H_m, qs_pi[t][factor_idx]))
+
+    return ambiguity
+
+
+def sophisticated_inference_search(qs, policies, A, B, C, A_factor_list, B_factor_list, I=None, horizon=1,
+                                   policy_prune_threshold=1/16, state_prune_threshold=1/16, prune_penalty=512, gamma=16,
+                                   inference_params = {"num_iter": 10, "dF": 1.0, "dF_tol": 0.001, "compute_vfe": False}, n=0):
+    """
+    Performs sophisticated inference to find the optimal policy for a given generative model and prior preferences.
+
+    Parameters
+    ----------
+    qs: ``numpy.ndarray`` of dtype object
+        Marginal posterior beliefs over hidden states at a given timepoint.
+    policies: ``list`` of 1D ``numpy.ndarray``
+        ``list`` that stores each policy as a 1D array in ``policies[p_idx]``. Shape of ``policies[p_idx]``
+        is ``(num_factors)`` where ``num_factors`` is the number of control factors.
+    A: ``numpy.ndarray`` of dtype object
+        Sensory likelihood mapping or 'observation model', mapping from hidden states to observations. Each element ``A[m]``
+        stores a ``numpy.ndarray`` multidimensional array for observation modality ``m``, whose entries ``A[m][i, j, k, ...]`` store
+        the probability of observation level ``i`` given hidden state levels ``j, k, ...``
+    B: ``numpy.ndarray`` of dtype object
+        Dynamics likelihood mapping or 'transition model', mapping from hidden states at ``t`` to hidden states at ``t+1``, given some control state ``u``.
+        Each element ``B[f]`` of this object array stores a 3-D tensor for hidden state factor ``f``, whose entries ``B[f][s, v, u]`` store the probability
+        of hidden state level ``s`` at the current time, given hidden state level ``v`` and action ``u`` at the previous time.
+    C: ``numpy.ndarray`` of dtype object
+        Prior over observations or 'prior preferences', storing the "value" of each outcome in terms of relative log probabilities.
+        This is softmaxed to form a proper probability distribution before being used to compute the expected utility term of the expected free energy.
+    A_factor_list: ``list`` of ``list`` of ``int``
+        List of lists, where ``A_factor_list[m]`` is a list of the hidden state factor indices that observation modality with the index ``m`` depends on
+    B_factor_list: ``list`` of ``list`` of ``int``
+        List of lists of hidden state factors each hidden state factor depends on. Each element ``B_factor_list[i]`` is a list of the factor indices that factor i's dynamics depend on.
+    I: ``numpy.ndarray`` of dtype object
+        For each state factor, contains a 2D ``numpy.ndarray`` whose element i,j yields the probability
+        of reaching the goal state backwards from state j after i steps.
+    horizon: ``int``
+        The temporal depth of the policy
+    policy_prune_threshold: ``float``
+        The threshold for pruning policies that are below a certain probability
+    state_prune_threshold: ``float``
+        The threshold for pruning states in the expectation that are below a certain probability
+    prune_penalty: ``float``
+        Penalty to add to the EFE when a policy is pruned
+    gamma: ``float``, default 16.0
+        Prior precision over policies, scales the contribution of the expected free energy to the posterior over policies
+    n: ``int``
+        Timestep in the future that is currently being evaluated (the recursion depth of the search)
+
+    Returns
+    ----------
+    q_pi: 1D ``numpy.ndarray``
+        Posterior beliefs over policies, i.e. a vector containing one posterior probability per policy.
+
+    G: 1D ``numpy.ndarray``
+        Negative expected free energies of each policy, i.e. a vector containing one negative expected free energy per policy.
+    """
+
+    n_policies = len(policies)
+    G = np.zeros(n_policies)
+    q_pi = np.zeros((n_policies, 1))
+    qs_pi = utils.obj_array(n_policies)
+    qo_pi = utils.obj_array(n_policies)
+
+    for idx, policy in enumerate(policies):
+        qs_pi[idx] = get_expected_states_interactions(qs, B, B_factor_list, policy)
+        qo_pi[idx] = get_expected_obs_factorized(qs_pi[idx], A, A_factor_list)
+
+        G[idx] += calc_expected_utility(qo_pi[idx], C)
+        G[idx] += calc_states_info_gain_factorized(A, qs_pi[idx], A_factor_list)
+
+        if I is not None:
+            G[idx] += calc_inductive_cost(qs, qs_pi[idx], I)
+
+    q_pi = softmax(G * gamma)
+
+    if n < horizon - 1:
+        # ignore low probability actions in the search tree
+        # TODO shouldn't we have to add an extra penalty for branches no longer considered?
+        # or assume these are already low EFE (high NEFE) anyway?
+        policies_to_consider = list(np.where(q_pi >= policy_prune_threshold)[0])
+        for idx in range(n_policies):
+            if idx not in policies_to_consider:
+                G[idx] -= prune_penalty
+            else:
+                # average over outcomes
+                qo_next = qo_pi[idx][0]
+                for k in itertools.product(*[range(s.shape[0]) for s in qo_next]):
+                    prob = 1.0
+                    for i in range(len(k)):
+                        prob *= qo_pi[idx][0][i][k[i]]

+                    # ignore low probability states in the search tree
+                    if prob < state_prune_threshold:
+                        continue
+
+                    qo_one_hot = utils.obj_array(len(qo_next))
+                    for i in range(len(qo_one_hot)):
+                        qo_one_hot[i] = utils.onehot(k[i], qo_next[i].shape[0])
+
+                    num_obs = [A[m].shape[0] for m in range(len(A))]
+                    num_states = [B[f].shape[0] for f in range(len(B))]
+                    A_modality_list = []
+                    for f in range(len(B)):
+                        A_modality_list.append( [m for m in range(len(A)) if f in A_factor_list[m]] )
+                    mb_dict = {
+                        'A_factor_list': A_factor_list,
+                        'A_modality_list': A_modality_list
+                    }
+                    qs_next = update_posterior_states_factorized(A, qo_one_hot, num_obs, num_states, mb_dict, qs_pi[idx][0], **inference_params)
+
+                    q_pi_next, G_next = sophisticated_inference_search(qs_next, policies, A, B, C, A_factor_list, B_factor_list, I,
+                                                                       horizon, policy_prune_threshold, state_prune_threshold,
+                                                                       prune_penalty, gamma, inference_params, n+1)
+                    G_weighted = np.dot(q_pi_next, G_next) * prob
+                    G[idx] += G_weighted
+
+        q_pi = softmax(G * gamma)
+    return q_pi, G
\ No newline at end of file
diff --git a/pymdp/maths.py b/pymdp/maths.py
index 68000fb4..1d5e9e4d 100644
--- a/pymdp/maths.py
+++ b/pymdp/maths.py
@@ -572,3 +572,37 @@ def spm_MDP_G(A, x):

     return G

+def kl_div(P,Q):
+    """
+    Parameters
+    ----------
+    P : Categorical probability distribution
+    Q : Categorical probability distribution
+
+    Returns
+    -------
+    The KL-divergence of P and Q
+
+    """
+    dkl = 0
+    for i in range(len(P)):
+        dkl += np.dot(P[i], np.log(P[i] + EPS_VAL) - np.log(Q[i] + EPS_VAL))
+    return(dkl)
+
+def entropy(A):
+    """
+    Compute the entropy term H of the likelihood matrix,
+    i.e. one entropy value per column
+    """
+    entropies = np.empty(len(A), dtype=object)
+    for i in range(len(A)):
+        if len(A[i].shape) > 2:
+            obs_dim = A[i].shape[0]
+            s_dim = A[i].size // obs_dim
+            A_merged = A[i].reshape(obs_dim, s_dim)
+        else:
+            A_merged = A[i]
+
+        H = - np.diag(np.matmul(A_merged.T, np.log(A_merged + EPS_VAL)))
+        entropies[i] = H.reshape(*A[i].shape[1:])
+    return entropies
\ No newline at end of file
diff --git a/test/test_wrappers.py b/test/test_wrappers.py
index 86ff5996..cf405e56 100644
--- a/test/test_wrappers.py
+++ b/test/test_wrappers.py
@@ -1,17 +1,8 @@
 import os
 import unittest
 from pathlib import Path
-
-import pandas as pd
-from pandas.testing import assert_frame_equal
-
 from pymdp.utils import Dimensions, get_model_dimensions_from_labels

-tmp_path = Path('tmp_dir')
-
-if not os.path.isdir(tmp_path):
-    os.mkdir(tmp_path)
-
 class TestWrappers(unittest.TestCase):

     def test_get_model_dimensions_from_labels(self):
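
For orientation, the sketch below shows how the options introduced in this diff (the `sophisticated` flag and the `si_*` keywords of `Agent.__init__`, which dispatch to the recursive `control.sophisticated_inference_search`) might be driven end-to-end. The model dimensions, the random A/B arrays, and the preference values are illustrative assumptions and are not taken from the diff; only the constructor keywords and the agent methods used are.

    from pymdp import utils
    from pymdp.agent import Agent

    # Illustrative dimensions (assumptions for this sketch): two observation
    # modalities, two hidden state factors, only factor 0 controllable.
    num_obs = [3, 2]
    num_states = [3, 2]
    num_controls = [3, 1]

    # Random likelihood (A) and transition (B) object arrays via existing utils helpers.
    A = utils.random_A_matrix(num_obs, num_states)
    B = utils.random_B_matrix(num_states, num_controls)

    # Prior preferences: favour the last outcome level of the first modality.
    C = utils.obj_array_zeros(num_obs)
    C[0][-1] = 2.0

    # Switch on the recursive sophisticated-inference search added in this diff.
    # policy_len stays at its default of 1, which the constructor asserts.
    # No H is supplied here, so I is None and ii_depth / ii_threshold are unused.
    agent = Agent(
        A=A,
        B=B,
        C=C,
        sophisticated=True,
        si_horizon=3,
        si_policy_prune_threshold=1/16,
        si_state_prune_threshold=1/16,
        si_prune_penalty=512,
    )

    obs = [0, 0]                      # one observation index per modality
    qs = agent.infer_states(obs)      # posterior over hidden states
    q_pi, G = agent.infer_policies()  # recursive tree search to depth si_horizon
    action = agent.sample_action()

During the search, one-step policies whose posterior probability falls below si_policy_prune_threshold are not expanded further and instead have si_prune_penalty subtracted from their (negative) expected free energy, while predicted outcome combinations with probability below si_state_prune_threshold are skipped when branching on observations.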