From 2fc1c6943bc048706e7b3ffc5141830686ecdfa7 Mon Sep 17 00:00:00 2001 From: Tim Verbelen Date: Tue, 21 Nov 2023 08:16:41 +0100 Subject: [PATCH 01/12] add sophisticated policy search method --- pymdp/control.py | 137 ++++++++++++++++++++++++++++++++++++++++++++++- pymdp/maths.py | 34 ++++++++++++ 2 files changed, 170 insertions(+), 1 deletion(-) diff --git a/pymdp/control.py b/pymdp/control.py index 892c02f3..d3fbc463 100644 --- a/pymdp/control.py +++ b/pymdp/control.py @@ -5,7 +5,7 @@ import itertools import numpy as np -from pymdp.maths import softmax, softmax_obj_arr, spm_dot, spm_wnorm, spm_MDP_G, spm_log_single, spm_log_obj_array +from pymdp.maths import softmax, softmax_obj_arr, spm_dot, spm_wnorm, spm_MDP_G, spm_log_single, kl_div, entropy from pymdp import utils import copy @@ -1310,5 +1310,140 @@ def backwards_induction(H, B, B_factor_list, threshold, depth): I[factor][i, :] = np.where(I[factor][i, :] > 0.1, 1.0, 0.0) return I + +def calc_ambiguity_factorized(qs_pi, A, A_factor_list): + """ + Computes the Ambiguity term. + + Parameters + ---------- + qs_pi: ``list`` of ``numpy.ndarray`` of dtype object + Predictive posterior beliefs over hidden states expected under the policy, where ``qs_pi[t]`` stores the beliefs about + hidden states expected under the policy at time ``t`` + A: ``numpy.ndarray`` of dtype object + Sensory likelihood mapping or 'observation model', mapping from hidden states to observations. Each element ``A[m]`` of + stores an ``numpy.ndarray`` multidimensional array for observation modality ``m``, whose entries ``A[m][i, j, k, ...]`` store + the probability of observation level ``i`` given hidden state levels ``j, k, ...`` + A_factor_list: ``list`` of ``list`` of ``int`` + List of lists, where ``A_factor_list[m]`` is a list of the hidden state factor indices that observation modality with the index ``m`` depends on + + Returns + ------- + ambiguity: float + """ + + n_steps = len(qs_pi) + + ambiguity = 0 + # TODO check if we do this correctly! + H = entropy(A) + for t in range(n_steps): + for m, H_m in enumerate(H): + factor_idx = A_factor_list[m] + # TODO why does spm_dot return an array here? + # joint_x = maths.spm_cross(qs_pi[t][factor_idx]) + # ambiguity += (H_m * joint_x).sum() + ambiguity += np.sum(spm_dot(H_m, qs_pi[t][factor_idx])) + + return ambiguity +def sophisticated_inference_search(qs, policies, A, B, C, A_factor_list, B_factor_list, I=None, horizon=1, + policy_prune_threshold=1/16, state_prune_threshold=1/16, prune_penalty=512, gamma=16, n=0): + """ + Performs sophisticated inference to find the optimal policy for a given generative model and prior preferences. + Parameters + ---------- + qs: ``numpy.ndarray`` of dtype object + Marginal posterior beliefs over hidden states at a given timepoint. + policies: ``list`` of 1D ``numpy.ndarray`` + ``list`` that stores each policy as a 1D array in ``policies[p_idx]``. Shape of ``policies[p_idx]`` + is ``(num_factors)`` where ``num_factors`` is the number of control factors. + A: ``numpy.ndarray`` of dtype object + Sensory likelihood mapping or 'observation model', mapping from hidden states to observations. 
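As an aside to the new ambiguity helper above, a minimal sketch of what `entropy`, `kl_div` and `calc_ambiguity_factorized` reduce to for a single modality and a single hidden state factor (toy numbers; the variable names below are illustrative and not part of the patch):

import numpy as np

A_m = np.array([[0.9, 0.5],    # p(o = 0 | s)
                [0.1, 0.5]])   # p(o = 1 | s)
qs  = np.array([0.25, 0.75])   # posterior over the single state factor

H_m = -np.sum(A_m * np.log(A_m + 1e-16), axis=0)   # column-wise entropy, as in maths.entropy
ambiguity = np.dot(H_m, qs)                        # expectation under q(s), as in spm_dot

qo = A_m @ qs                                      # predicted outcome distribution
C_m = np.array([1.0, 0.0])
C_prob = np.exp(C_m) / np.exp(C_m).sum()           # softmaxed preferences
risk = np.dot(qo, np.log(qo + 1e-16) - np.log(C_prob + 1e-16))   # kl_div(qo, C_prob)
G_step = -risk - ambiguity                         # one-step negative EFE, as accumulated below
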
Each element ``A[m]`` of + stores an ``numpy.ndarray`` multidimensional array for observation modality ``m``, whose entries ``A[m][i, j, k, ...]`` store + the probability of observation level ``i`` given hidden state levels ``j, k, ...`` + B: ``numpy.ndarray`` of dtype object + Dynamics likelihood mapping or 'transition model', mapping from hidden states at ``t`` to hidden states at ``t+1``, given some control state ``u``. + Each element ``B[f]`` of this object array stores a 3-D tensor for hidden state factor ``f``, whose entries ``B[f][s, v, u]`` store the probability + of hidden state level ``s`` at the current time, given hidden state level ``v`` and action ``u`` at the previous time. + C: ``numpy.ndarray`` of dtype object + Prior over observations or 'prior preferences', storing the "value" of each outcome in terms of relative log probabilities. + This is softmaxed to form a proper probability distribution before being used to compute the expected utility term of the expected free energy. + A_factor_list: ``list`` of ``list`` of ``int`` + List of lists, where ``A_factor_list[m]`` is a list of the hidden state factor indices that observation modality with the index ``m`` depends on + B_factor_list: ``list`` of ``list`` of ``int`` + List of lists of hidden state factors each hidden state factor depends on. Each element ``B_factor_list[i]`` is a list of the factor indices that factor i's dynamics depend on. + I: ``numpy.ndarray`` of dtype object + For each state factor, contains a 2D ``numpy.ndarray`` whose element i,j yields the probability + of reaching the goal state backwards from state j after i steps. + horizon: ``int`` + The temporal depth of the policy + policy_prune_threshold: ``float`` + The threshold for pruning policies that are below a certain probability + state_prune_threshold: ``float`` + The threshold for pruning states in the expectation that are below a certain probability + prune_penalty: ``float`` + Penalty to add to the EFE when a policy is pruned + gamma: ``float``, default 16.0 + Prior precision over policies, scales the contribution of the expected free energy to the posterior over policies + n: ``int`` + timestep in the future we are calculating + + Returns + ---------- + q_pi: 1D ``numpy.ndarray`` + Posterior beliefs over policies, i.e. a vector containing one posterior probability per policy. + + G: 1D ``numpy.ndarray`` + Negative expected free energies of each policy, i.e. a vector containing one negative expected free energy per policy. + """ + + n_policies = len(policies) + G = np.zeros(n_policies) + q_pi = np.zeros((n_policies, 1)) + qs_pi = utils.obj_array(n_policies) + + for idx, policy in enumerate(policies): + qs_pi[idx] = get_expected_states_interactions(qs, B, B_factor_list, policy) + qo_pi = get_expected_obs_factorized(qs_pi[idx], A, A_factor_list) + + C_prob = softmax_obj_arr(C) + G[idx] += -kl_div(qo_pi[0], C_prob) + G[idx] += -calc_ambiguity_factorized(qs_pi[idx], A, A_factor_list) + if I is not None: + G[idx] += calc_inductive_cost(qs, qs_pi[idx], I) + + q_pi = softmax(G * gamma) + + if n < horizon - 1: + # ignore low probability actions in the search tree + # TODO shouldnt we have to add extra penalty for branches no longer considered? + # or assume these are already low EFE (high NEFE) anyway? 
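Before the pruning and branching code that follows, a compact sketch of the backup rule it implements: the expected free energy of a branch is its one-step negative EFE plus the EFE of the continuation, averaged over the predicted next states and over the softmaxed continuation policies (toy values; not patch code):

import numpy as np

prob      = 0.6                    # q(s' | policy) for one enumerated next state
q_pi_next = np.array([0.7, 0.3])   # softmax(G_next * gamma) returned by the recursive call
G_next    = np.array([-1.2, -3.0]) # negative EFEs of the continuation policies

G_increment = prob * np.dot(q_pi_next, G_next)   # matches G[idx] += np.dot(q_pi_next, G_next) * prob
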
+ policies_to_consider = list(np.where(q_pi >= policy_prune_threshold)[0]) + for idx in range(n_policies): + if idx not in policies_to_consider: + G[idx] -= prune_penalty + else : + # average over states + qs_next = qs_pi[idx][0] + for k in itertools.product(*[range(s.shape[0]) for s in qs_next]): + prob = 1.0 + for i in range(len(k)): + prob *= qs_pi[idx][0][i][k[i]] + + # ignore low probability states in the search tree + if prob < state_prune_threshold: + continue + + qs_one_hot = utils.obj_array(len(qs)) + for i in range(len(qs)): + qs_one_hot[i] = utils.onehot(k[i], qs_next[i].shape[0]) + + q_pi_next, G_next = sophisticated_inference_search(qs_one_hot, policies, A, B, C, A_factor_list, B_factor_list, I, + horizon, policy_prune_threshold, state_prune_threshold, n=n+1) + G_weighted = np.dot(q_pi_next, G_next) * prob + G[idx] += G_weighted + + q_pi = softmax(G * gamma) + return q_pi, G \ No newline at end of file diff --git a/pymdp/maths.py b/pymdp/maths.py index 6f2fd3b8..7b2eacdb 100644 --- a/pymdp/maths.py +++ b/pymdp/maths.py @@ -552,3 +552,37 @@ def spm_MDP_G(A, x): return G +def kl_div(P,Q): + """ + Parameters + ---------- + P : Categorical probability distribution + Q : Categorical probability distribution + + Returns + ------- + The KL-divergence of P and Q + + """ + dkl = 0 + for i in range(len(P)): + dkl += np.dot(P[i], np.log(P[i] + EPS_VAL) - np.log(Q[i] + EPS_VAL)) + return(dkl) + +def entropy(A): + """ + Compute the entropy term H of the likelihood matrix, + i.e. one entropy value per column + """ + entropies = np.empty(len(A), dtype=object) + for i in range(len(A)): + if len(A[i].shape) > 2: + obs_dim = A[i].shape[0] + s_dim = A[i].size // obs_dim + A_merged = A[i].reshape(obs_dim, s_dim) + else: + A_merged = A[i] + + H = - np.diag(np.matmul(A_merged.T, np.log(A_merged + EPS_VAL))) + entropies[i] = H.reshape(*A[i].shape[1:]) + return entropies \ No newline at end of file From 6672861bb343d3476542edc1526766a81106efcd Mon Sep 17 00:00:00 2001 From: Tim Verbelen Date: Tue, 21 Nov 2023 08:17:06 +0100 Subject: [PATCH 02/12] fix is None check --- pymdp/agent.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pymdp/agent.py b/pymdp/agent.py index 75bafc46..ce5f9dad 100644 --- a/pymdp/agent.py +++ b/pymdp/agent.py @@ -186,7 +186,7 @@ def __init__( # Again, the use can specify a set of possible policies, or # all possible combinations of actions and timesteps will be considered - if policies == None: + if policies is None: policies = self._construct_policies() self.policies = policies From 417482b1004f47d9a9fb43eedd3684e11d16a9cc Mon Sep 17 00:00:00 2001 From: Tim Verbelen Date: Tue, 21 Nov 2023 08:21:35 +0100 Subject: [PATCH 03/12] only keep infer_policies function in Agent, which calls factorized if implemented --- pymdp/agent.py | 15 +++++++++++---- 1 file changed, 11 insertions(+), 4 deletions(-) diff --git a/pymdp/agent.py b/pymdp/agent.py index ce5f9dad..e8c5687c 100644 --- a/pymdp/agent.py +++ b/pymdp/agent.py @@ -129,6 +129,7 @@ def __init__( self.num_controls = num_controls # checking that `A_factor_list` and `B_factor_list` are consistent with `num_factors`, `num_states`, and lagging dimensions of `A` and `B` tensors + self.factorized = False if A_factor_list == None: self.A_factor_list = self.num_modalities * [list(range(self.num_factors))] # defaults to having all modalities depend on all factors for m in range(self.num_modalities): @@ -137,6 +138,7 @@ def __init__( if self.pA != None: assert self.pA[m].shape[1:] == factor_dims, f"Please input an 
`A_factor_list` whose {m}-th indices pick out the hidden state factors that line up with lagging dimensions of pA{m}..." else: + self.factorized = True for m in range(self.num_modalities): assert max(A_factor_list[m]) <= (self.num_factors - 1), f"Check modality {m} of A_factor_list - must be consistent with `num_states` and `num_factors`..." factor_dims = tuple([self.num_states[f] for f in A_factor_list[m]]) @@ -164,6 +166,7 @@ def __init__( if self.pB != None: assert self.pB[f].shape[1:-1] == factor_dims, f"Please input a `B_factor_list` whose {f}-th indices pick out the hidden state factors that line up with the all-but-final lagging dimensions of pB{f}..." else: + self.factorized = True for f in range(self.num_factors): assert max(B_factor_list[f]) <= (self.num_factors - 1), f"Check factor {f} of B_factor_list - must be consistent with `num_states` and `num_factors`..." factor_dims = tuple([self.num_states[f] for f in B_factor_list[f]]) @@ -600,22 +603,26 @@ def infer_policies_old(self): """ if self.inference_algo == "VANILLA": - q_pi, G = control.update_posterior_policies( + q_pi, G = control.update_posterior_policies_factorized( self.qs, self.A, self.B, self.C, + self.A_factor_list, + self.B_factor_list, self.policies, self.use_utility, self.use_states_info_gain, self.use_param_info_gain, self.pA, self.pB, - E=self.E, - I=self.I, - gamma=self.gamma + E = self.E, + I = self.I, + gamma = self.gamma ) elif self.inference_algo == "MMP": + if self.factorized: + raise NotImplementedError("Factorized inference not implemented for MMP") future_qs_seq = self.get_future_qs() From a4546f3c75e10150a2cffb56d47b9e836eeaa733 Mon Sep 17 00:00:00 2001 From: Tim Verbelen Date: Tue, 21 Nov 2023 08:32:13 +0100 Subject: [PATCH 04/12] add sophisticated inference as a flag for pymdp Agent --- pymdp/agent.py | 84 +++++++++++++++++++++++++++++++++++--------------- 1 file changed, 59 insertions(+), 25 deletions(-) diff --git a/pymdp/agent.py b/pymdp/agent.py index e8c5687c..b9cb70c6 100644 --- a/pymdp/agent.py +++ b/pymdp/agent.py @@ -60,11 +60,16 @@ def __init__( factors_to_learn="all", lr_pB=1.0, lr_pD=1.0, - use_BMA = True, + use_BMA=True, policy_sep_prior=False, save_belief_hist=False, A_factor_list=None, - B_factor_list=None + B_factor_list=None, + sophisticated=False, + si_horizon=3, + si_policy_prune_threshold=1/16, + si_state_prune_threshold=1/16, + si_prune_penalty=512 ): ### Constant parameters ### @@ -86,6 +91,15 @@ def __init__( self.lr_pB = lr_pB self.lr_pD = lr_pD + # sophisticated inference parameters + self.sophisticated = sophisticated + if self.sophisticated: + assert self.policy_len == 1, "Sophisticated inference only works with policy_len = 1" + self.si_horizon = si_horizon + self.si_policy_prune_threshold = si_policy_prune_threshold + self.si_state_prune_threshold = si_state_prune_threshold + self.si_prune_penalty = si_prune_penalty + # Initialise observation model (A matrices) if not isinstance(A, np.ndarray): raise TypeError( @@ -603,26 +617,28 @@ def infer_policies_old(self): """ if self.inference_algo == "VANILLA": - q_pi, G = control.update_posterior_policies_factorized( + q_pi, G = control.update_posterior_policies( self.qs, self.A, self.B, self.C, - self.A_factor_list, - self.B_factor_list, self.policies, self.use_utility, self.use_states_info_gain, self.use_param_info_gain, self.pA, self.pB, - E = self.E, - I = self.I, - gamma = self.gamma + E=self.E, + I=self.I, + gamma=self.gamma ) elif self.inference_algo == "MMP": if self.factorized: raise NotImplementedError("Factorized 
inference not implemented for MMP") + + if self.sophisticated: + raise NotImplementedError("Sophisticated inference not implemented for MMP") + future_qs_seq = self.get_future_qs() @@ -671,23 +687,41 @@ def infer_policies(self): """ if self.inference_algo == "VANILLA": - q_pi, G = control.update_posterior_policies_factorized( - self.qs, - self.A, - self.B, - self.C, - self.A_factor_list, - self.B_factor_list, - self.policies, - self.use_utility, - self.use_states_info_gain, - self.use_param_info_gain, - self.pA, - self.pB, - E=self.E, - I=self.I, - gamma=self.gamma - ) + if self.sophisticated: + q_pi, G = control.sophisticated_inference_search( + self.qs, + self.policies, + self.A, + self.B, + self.C, + self.A_factor_list, + self.B_factor_list, + self.I, + self.si_horizon, + self.si_policy_prune_threshold, + self.si_state_prune_threshold, + self.si_prune_penalty, + self.gamma, + n=0 + ) + else: + q_pi, G = control.update_posterior_policies_factorized( + self.qs, + self.A, + self.B, + self.C, + self.A_factor_list, + self.B_factor_list, + self.policies, + self.use_utility, + self.use_states_info_gain, + self.use_param_info_gain, + self.pA, + self.pB, + E = self.E, + I = self.I, + gamma = self.gamma + ) elif self.inference_algo == "MMP": future_qs_seq = self.get_future_qs() From b52903a22f3469f474bd2fdddf6427b34d8f9ff2 Mon Sep 17 00:00:00 2001 From: Tim Verbelen Date: Wed, 22 Nov 2023 11:48:26 +0100 Subject: [PATCH 05/12] implement si by explicitly branching observations --- pymdp/agent.py | 31 ++++++++++++- pymdp/control.py | 117 ++++++++++++++++++++++++++++++++++++++++++++++- 2 files changed, 145 insertions(+), 3 deletions(-) diff --git a/pymdp/agent.py b/pymdp/agent.py index b9cb70c6..7ed644a2 100644 --- a/pymdp/agent.py +++ b/pymdp/agent.py @@ -268,8 +268,10 @@ def __init__( # Construct I for backwards induction (if H specified) if H is not None: + self.H = H self.I = control.backwards_induction(H, B, B_factor_list, threshold=1/16, depth=5) else: + self.H = None self.I = None self.edge_handling_params = {} @@ -688,7 +690,28 @@ def infer_policies(self): if self.inference_algo == "VANILLA": if self.sophisticated: - q_pi, G = control.sophisticated_inference_search( + # q_pi, G = control.sophisticated_inference_search( + # self.qs, + # self.policies, + # self.A, + # self.B, + # self.C, + # self.A_factor_list, + # self.B_factor_list, + # self.I, + # self.si_horizon, + # self.si_policy_prune_threshold, + # self.si_state_prune_threshold, + # self.si_prune_penalty, + # 1.0, + # n=0 + # ) + + # print("Sophisticated 1") + # for i in range(len(self.policies)): + # print(G[i], [(p[0], p[1]) for p in self.policies[i]]) + + q_pi, G = control.sophisticated_inference_search2( self.qs, self.policies, self.A, @@ -701,9 +724,13 @@ def infer_policies(self): self.si_policy_prune_threshold, self.si_state_prune_threshold, self.si_prune_penalty, - self.gamma, + 1.0, n=0 ) + + # print("Sophisticated 2") + # for i in range(len(self.policies)): + # print(G[i], [(p[0], p[1]) for p in self.policies[i]]) else: q_pi, G = control.update_posterior_policies_factorized( self.qs, diff --git a/pymdp/control.py b/pymdp/control.py index d3fbc463..993b0c8f 100644 --- a/pymdp/control.py +++ b/pymdp/control.py @@ -6,6 +6,7 @@ import itertools import numpy as np from pymdp.maths import softmax, softmax_obj_arr, spm_dot, spm_wnorm, spm_MDP_G, spm_log_single, kl_div, entropy +from pymdp.inference import update_posterior_states_factorized from pymdp import utils import copy @@ -1414,7 +1415,7 @@ def sophisticated_inference_search(qs, 
policies, A, B, C, A_factor_list, B_facto if I is not None: G[idx] += calc_inductive_cost(qs, qs_pi[idx], I) - q_pi = softmax(G * gamma) + q_pi = softmax(G * gamma) if n < horizon - 1: # ignore low probability actions in the search tree @@ -1445,5 +1446,119 @@ def sophisticated_inference_search(qs, policies, A, B, C, A_factor_list, B_facto G_weighted = np.dot(q_pi_next, G_next) * prob G[idx] += G_weighted + q_pi = softmax(G * gamma) + return q_pi, G + + +def sophisticated_inference_search2(qs, policies, A, B, C, A_factor_list, B_factor_list, I=None, horizon=1, + policy_prune_threshold=1/16, state_prune_threshold=1/16, prune_penalty=512, gamma=16, n=0): + """ + Performs sophisticated inference to find the optimal policy for a given generative model and prior preferences. + + Parameters + ---------- + qs: ``numpy.ndarray`` of dtype object + Marginal posterior beliefs over hidden states at a given timepoint. + policies: ``list`` of 1D ``numpy.ndarray`` + ``list`` that stores each policy as a 1D array in ``policies[p_idx]``. Shape of ``policies[p_idx]`` + is ``(num_factors)`` where ``num_factors`` is the number of control factors. + A: ``numpy.ndarray`` of dtype object + Sensory likelihood mapping or 'observation model', mapping from hidden states to observations. Each element ``A[m]`` of + stores an ``numpy.ndarray`` multidimensional array for observation modality ``m``, whose entries ``A[m][i, j, k, ...]`` store + the probability of observation level ``i`` given hidden state levels ``j, k, ...`` + B: ``numpy.ndarray`` of dtype object + Dynamics likelihood mapping or 'transition model', mapping from hidden states at ``t`` to hidden states at ``t+1``, given some control state ``u``. + Each element ``B[f]`` of this object array stores a 3-D tensor for hidden state factor ``f``, whose entries ``B[f][s, v, u]`` store the probability + of hidden state level ``s`` at the current time, given hidden state level ``v`` and action ``u`` at the previous time. + C: ``numpy.ndarray`` of dtype object + Prior over observations or 'prior preferences', storing the "value" of each outcome in terms of relative log probabilities. + This is softmaxed to form a proper probability distribution before being used to compute the expected utility term of the expected free energy. + A_factor_list: ``list`` of ``list`` of ``int`` + List of lists, where ``A_factor_list[m]`` is a list of the hidden state factor indices that observation modality with the index ``m`` depends on + B_factor_list: ``list`` of ``list`` of ``int`` + List of lists of hidden state factors each hidden state factor depends on. Each element ``B_factor_list[i]`` is a list of the factor indices that factor i's dynamics depend on. + I: ``numpy.ndarray`` of dtype object + For each state factor, contains a 2D ``numpy.ndarray`` whose element i,j yields the probability + of reaching the goal state backwards from state j after i steps. 
+ horizon: ``int`` + The temporal depth of the policy + policy_prune_threshold: ``float`` + The threshold for pruning policies that are below a certain probability + state_prune_threshold: ``float`` + The threshold for pruning states in the expectation that are below a certain probability + prune_penalty: ``float`` + Penalty to add to the EFE when a policy is pruned + gamma: ``float``, default 16.0 + Prior precision over policies, scales the contribution of the expected free energy to the posterior over policies + n: ``int`` + timestep in the future we are calculating + + Returns + ---------- + q_pi: 1D ``numpy.ndarray`` + Posterior beliefs over policies, i.e. a vector containing one posterior probability per policy. + + G: 1D ``numpy.ndarray`` + Negative expected free energies of each policy, i.e. a vector containing one negative expected free energy per policy. + """ + + n_policies = len(policies) + G = np.zeros(n_policies) + q_pi = np.zeros((n_policies, 1)) + qs_pi = utils.obj_array(n_policies) + qo_pi = utils.obj_array(n_policies) + + for idx, policy in enumerate(policies): + qs_pi[idx] = get_expected_states_interactions(qs, B, B_factor_list, policy) + qo_pi[idx] = get_expected_obs_factorized(qs_pi[idx], A, A_factor_list) + + G[idx] += calc_expected_utility(qo_pi[idx], C) + G[idx] += calc_states_info_gain_factorized(A, qs_pi[idx], A_factor_list) + + if I is not None: + G[idx] += calc_inductive_cost(qs, qs_pi[idx], I) + + q_pi = softmax(G * gamma) + + if n < horizon - 1: + # ignore low probability actions in the search tree + # TODO shouldnt we have to add extra penalty for branches no longer considered? + # or assume these are already low EFE (high NEFE) anyway? + policies_to_consider = list(np.where(q_pi >= policy_prune_threshold)[0]) + for idx in range(n_policies): + if idx not in policies_to_consider: + G[idx] -= prune_penalty + else : + # average over outcomes + qo_next = qo_pi[idx][0] + for k in itertools.product(*[range(s.shape[0]) for s in qo_next]): + prob = 1.0 + for i in range(len(k)): + prob *= qo_pi[idx][0][i][k[i]] + + # ignore low probability states in the search tree + if prob < state_prune_threshold: + continue + + qo_one_hot = utils.obj_array(len(qo_next)) + for i in range(len(qo_one_hot)): + qo_one_hot[i] = utils.onehot(k[i], qo_next[i].shape[0]) + + num_obs = [A[m].shape[0] for m in range(len(A))] + num_states = [B[f].shape[0] for f in range(len(B))] + A_modality_list = [] + for f in range(len(B)): + A_modality_list.append( [m for m in range(len(A)) if f in A_factor_list[m]] ) + mb_dict = { + 'A_factor_list': A_factor_list, + 'A_modality_list': A_modality_list + } + inference_params = {"num_iter": 10, "dF": 1.0, "dF_tol": 0.001, "compute_vfe": False} + qs_next = update_posterior_states_factorized(A, qo_one_hot, num_obs, num_states, mb_dict, qs_pi[idx][0], **inference_params) + q_pi_next, G_next = sophisticated_inference_search2(qs_next, policies, A, B, C, A_factor_list, B_factor_list, I, + horizon, policy_prune_threshold, state_prune_threshold, n=n+1) + G_weighted = np.dot(q_pi_next, G_next) * prob + G[idx] += G_weighted + q_pi = softmax(G * gamma) return q_pi, G \ No newline at end of file From e9357ac70c6fe9ab8ab5dcf6ca68da35382b7d97 Mon Sep 17 00:00:00 2001 From: Tim Verbelen Date: Mon, 4 Dec 2023 08:38:56 +0100 Subject: [PATCH 06/12] expose parameters for inductive inference --- pymdp/agent.py | 6 ++++-- pymdp/control.py | 1 + 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/pymdp/agent.py b/pymdp/agent.py index 7ed644a2..078d73fe 100644 --- 
a/pymdp/agent.py +++ b/pymdp/agent.py @@ -69,7 +69,9 @@ def __init__( si_horizon=3, si_policy_prune_threshold=1/16, si_state_prune_threshold=1/16, - si_prune_penalty=512 + si_prune_penalty=512, + ii_depth=10, + ii_threshold=1/16, ): ### Constant parameters ### @@ -269,7 +271,7 @@ def __init__( # Construct I for backwards induction (if H specified) if H is not None: self.H = H - self.I = control.backwards_induction(H, B, B_factor_list, threshold=1/16, depth=5) + self.I = control.backwards_induction(H, B, B_factor_list, threshold=ii_threshold, depth=ii_depth) else: self.H = None self.I = None diff --git a/pymdp/control.py b/pymdp/control.py index 993b0c8f..1e6a2fc3 100644 --- a/pymdp/control.py +++ b/pymdp/control.py @@ -1309,6 +1309,7 @@ def backwards_induction(H, B, B_factor_list, threshold, depth): for i in range(1, depth): I[factor][i, :] = np.dot(b, I[factor][i-1, :]) I[factor][i, :] = np.where(I[factor][i, :] > 0.1, 1.0, 0.0) + # TODO stop when all 1s? return I From 3591c76ef23a1c39abef7dd187600b8514e23c3f Mon Sep 17 00:00:00 2001 From: Tim Verbelen Date: Fri, 12 Jan 2024 10:41:51 +0100 Subject: [PATCH 07/12] cleanup si a bit --- pymdp/agent.py | 28 +----------- pymdp/control.py | 115 ++++------------------------------------------- 2 files changed, 10 insertions(+), 133 deletions(-) diff --git a/pymdp/agent.py b/pymdp/agent.py index 078d73fe..c7748bde 100644 --- a/pymdp/agent.py +++ b/pymdp/agent.py @@ -692,28 +692,7 @@ def infer_policies(self): if self.inference_algo == "VANILLA": if self.sophisticated: - # q_pi, G = control.sophisticated_inference_search( - # self.qs, - # self.policies, - # self.A, - # self.B, - # self.C, - # self.A_factor_list, - # self.B_factor_list, - # self.I, - # self.si_horizon, - # self.si_policy_prune_threshold, - # self.si_state_prune_threshold, - # self.si_prune_penalty, - # 1.0, - # n=0 - # ) - - # print("Sophisticated 1") - # for i in range(len(self.policies)): - # print(G[i], [(p[0], p[1]) for p in self.policies[i]]) - - q_pi, G = control.sophisticated_inference_search2( + q_pi, G = control.sophisticated_inference_search( self.qs, self.policies, self.A, @@ -727,12 +706,9 @@ def infer_policies(self): self.si_state_prune_threshold, self.si_prune_penalty, 1.0, + self.inference_params, n=0 ) - - # print("Sophisticated 2") - # for i in range(len(self.policies)): - # print(G[i], [(p[0], p[1]) for p in self.policies[i]]) else: q_pi, G = control.update_posterior_policies_factorized( self.qs, diff --git a/pymdp/control.py b/pymdp/control.py index 1e6a2fc3..6a7231ce 100644 --- a/pymdp/control.py +++ b/pymdp/control.py @@ -1348,111 +1348,11 @@ def calc_ambiguity_factorized(qs_pi, A, A_factor_list): ambiguity += np.sum(spm_dot(H_m, qs_pi[t][factor_idx])) return ambiguity - -def sophisticated_inference_search(qs, policies, A, B, C, A_factor_list, B_factor_list, I=None, horizon=1, - policy_prune_threshold=1/16, state_prune_threshold=1/16, prune_penalty=512, gamma=16, n=0): - """ - Performs sophisticated inference to find the optimal policy for a given generative model and prior preferences. - - Parameters - ---------- - qs: ``numpy.ndarray`` of dtype object - Marginal posterior beliefs over hidden states at a given timepoint. - policies: ``list`` of 1D ``numpy.ndarray`` - ``list`` that stores each policy as a 1D array in ``policies[p_idx]``. Shape of ``policies[p_idx]`` - is ``(num_factors)`` where ``num_factors`` is the number of control factors. 
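A rough, conceptual sketch of the reachability matrix I that `backwards_induction` builds and that the newly exposed `ii_depth` / `ii_threshold` parameters control. The toy transition `b` and the update below are illustrative only; the actual function works on the action- and factor-dependent B tensors:

import numpy as np

b = np.array([[0., 0., 0.],
              [1., 0., 0.],
              [0., 1., 1.]])        # toy p(next | prev): move right, goal state 2 absorbing
reach = [np.array([0., 0., 1.])]    # row 0: a one-hot goal prior H over states (assumed here)
for i in range(1, 3):               # ii_depth rows in total
    row = b.T @ reach[-1]           # chance of hitting the goal within one more step
    reach.append(np.where(row > 0.1, 1.0, 0.0))

# reach[i][j] plays the role of I[factor][i, j]: can state j reach the goal within i steps.
# Here reach == [[0,0,1], [0,1,1], [1,1,1]]; once a row is all ones the recursion could stop,
# which is what the "stop when all 1s?" TODO in the patch is pointing at.
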
- A: ``numpy.ndarray`` of dtype object - Sensory likelihood mapping or 'observation model', mapping from hidden states to observations. Each element ``A[m]`` of - stores an ``numpy.ndarray`` multidimensional array for observation modality ``m``, whose entries ``A[m][i, j, k, ...]`` store - the probability of observation level ``i`` given hidden state levels ``j, k, ...`` - B: ``numpy.ndarray`` of dtype object - Dynamics likelihood mapping or 'transition model', mapping from hidden states at ``t`` to hidden states at ``t+1``, given some control state ``u``. - Each element ``B[f]`` of this object array stores a 3-D tensor for hidden state factor ``f``, whose entries ``B[f][s, v, u]`` store the probability - of hidden state level ``s`` at the current time, given hidden state level ``v`` and action ``u`` at the previous time. - C: ``numpy.ndarray`` of dtype object - Prior over observations or 'prior preferences', storing the "value" of each outcome in terms of relative log probabilities. - This is softmaxed to form a proper probability distribution before being used to compute the expected utility term of the expected free energy. - A_factor_list: ``list`` of ``list`` of ``int`` - List of lists, where ``A_factor_list[m]`` is a list of the hidden state factor indices that observation modality with the index ``m`` depends on - B_factor_list: ``list`` of ``list`` of ``int`` - List of lists of hidden state factors each hidden state factor depends on. Each element ``B_factor_list[i]`` is a list of the factor indices that factor i's dynamics depend on. - I: ``numpy.ndarray`` of dtype object - For each state factor, contains a 2D ``numpy.ndarray`` whose element i,j yields the probability - of reaching the goal state backwards from state j after i steps. - horizon: ``int`` - The temporal depth of the policy - policy_prune_threshold: ``float`` - The threshold for pruning policies that are below a certain probability - state_prune_threshold: ``float`` - The threshold for pruning states in the expectation that are below a certain probability - prune_penalty: ``float`` - Penalty to add to the EFE when a policy is pruned - gamma: ``float``, default 16.0 - Prior precision over policies, scales the contribution of the expected free energy to the posterior over policies - n: ``int`` - timestep in the future we are calculating - - Returns - ---------- - q_pi: 1D ``numpy.ndarray`` - Posterior beliefs over policies, i.e. a vector containing one posterior probability per policy. - - G: 1D ``numpy.ndarray`` - Negative expected free energies of each policy, i.e. a vector containing one negative expected free energy per policy. - """ - - n_policies = len(policies) - G = np.zeros(n_policies) - q_pi = np.zeros((n_policies, 1)) - qs_pi = utils.obj_array(n_policies) - - for idx, policy in enumerate(policies): - qs_pi[idx] = get_expected_states_interactions(qs, B, B_factor_list, policy) - qo_pi = get_expected_obs_factorized(qs_pi[idx], A, A_factor_list) - - C_prob = softmax_obj_arr(C) - G[idx] += -kl_div(qo_pi[0], C_prob) - G[idx] += -calc_ambiguity_factorized(qs_pi[idx], A, A_factor_list) - if I is not None: - G[idx] += calc_inductive_cost(qs, qs_pi[idx], I) - - q_pi = softmax(G * gamma) - - if n < horizon - 1: - # ignore low probability actions in the search tree - # TODO shouldnt we have to add extra penalty for branches no longer considered? - # or assume these are already low EFE (high NEFE) anyway? 
- policies_to_consider = list(np.where(q_pi >= policy_prune_threshold)[0]) - for idx in range(n_policies): - if idx not in policies_to_consider: - G[idx] -= prune_penalty - else : - # average over states - qs_next = qs_pi[idx][0] - for k in itertools.product(*[range(s.shape[0]) for s in qs_next]): - prob = 1.0 - for i in range(len(k)): - prob *= qs_pi[idx][0][i][k[i]] - - # ignore low probability states in the search tree - if prob < state_prune_threshold: - continue - qs_one_hot = utils.obj_array(len(qs)) - for i in range(len(qs)): - qs_one_hot[i] = utils.onehot(k[i], qs_next[i].shape[0]) - - q_pi_next, G_next = sophisticated_inference_search(qs_one_hot, policies, A, B, C, A_factor_list, B_factor_list, I, - horizon, policy_prune_threshold, state_prune_threshold, n=n+1) - G_weighted = np.dot(q_pi_next, G_next) * prob - G[idx] += G_weighted - q_pi = softmax(G * gamma) - return q_pi, G - - -def sophisticated_inference_search2(qs, policies, A, B, C, A_factor_list, B_factor_list, I=None, horizon=1, - policy_prune_threshold=1/16, state_prune_threshold=1/16, prune_penalty=512, gamma=16, n=0): +def sophisticated_inference_search(qs, policies, A, B, C, A_factor_list, B_factor_list, I=None, horizon=1, + policy_prune_threshold=1/16, state_prune_threshold=1/16, prune_penalty=512, gamma=16, + inference_params = {"num_iter": 10, "dF": 1.0, "dF_tol": 0.001, "compute_vfe": False}, n=0): """ Performs sophisticated inference to find the optimal policy for a given generative model and prior preferences. @@ -1460,7 +1360,8 @@ def sophisticated_inference_search2(qs, policies, A, B, C, A_factor_list, B_fact ---------- qs: ``numpy.ndarray`` of dtype object Marginal posterior beliefs over hidden states at a given timepoint. - policies: ``list`` of 1D ``numpy.ndarray`` + policies: ``list`` of 1D ``numpy.ndarray`` inference_params = {"num_iter": 10, "dF": 1.0, "dF_tol": 0.001, "compute_vfe": False} + ``list`` that stores each policy as a 1D array in ``policies[p_idx]``. Shape of ``policies[p_idx]`` is ``(num_factors)`` where ``num_factors`` is the number of control factors. 
A: ``numpy.ndarray`` of dtype object @@ -1554,10 +1455,10 @@ def sophisticated_inference_search2(qs, policies, A, B, C, A_factor_list, B_fact 'A_factor_list': A_factor_list, 'A_modality_list': A_modality_list } - inference_params = {"num_iter": 10, "dF": 1.0, "dF_tol": 0.001, "compute_vfe": False} qs_next = update_posterior_states_factorized(A, qo_one_hot, num_obs, num_states, mb_dict, qs_pi[idx][0], **inference_params) - q_pi_next, G_next = sophisticated_inference_search2(qs_next, policies, A, B, C, A_factor_list, B_factor_list, I, - horizon, policy_prune_threshold, state_prune_threshold, n=n+1) + q_pi_next, G_next = sophisticated_inference_search(qs_next, policies, A, B, C, A_factor_list, B_factor_list, I, + horizon, policy_prune_threshold, state_prune_threshold, + prune_penalty, gamma, inference_params, n+1) G_weighted = np.dot(q_pi_next, G_next) * prob G[idx] += G_weighted From 7200f3f5ac78f047de7d3761f627e87b51b83cb9 Mon Sep 17 00:00:00 2001 From: Tim Verbelen Date: Fri, 12 Jan 2024 11:40:47 +0100 Subject: [PATCH 08/12] fix index in update_posterior_policies_full_factorized --- pymdp/control.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pymdp/control.py b/pymdp/control.py index 6a7231ce..3489ce03 100644 --- a/pymdp/control.py +++ b/pymdp/control.py @@ -251,9 +251,9 @@ def update_posterior_policies_full_factorized( if use_param_info_gain: if pA is not None: - G[idx] += calc_pA_info_gain_factorized(pA, qo_seq_pi[p_idx], qs_seq_pi[p_idx], A_factor_list) + G[p_idx] += calc_pA_info_gain_factorized(pA, qo_seq_pi[p_idx], qs_seq_pi[p_idx], A_factor_list) if pB is not None: - G[idx] += calc_pB_info_gain_interactions(pB, qs_seq_pi[p_idx], qs, B_factor_list, policy) + G[p_idx] += calc_pB_info_gain_interactions(pB, qs_seq_pi[p_idx], qs_seq_pi[p_idx], B_factor_list, policy) if I is not None: G[p_idx] += calc_inductive_cost(qs_bma, qs_seq_pi[p_idx], I) From 9e5c366bd5af2b58ed66c903218613596c6a9426 Mon Sep 17 00:00:00 2001 From: Tim Verbelen Date: Fri, 12 Jan 2024 11:41:55 +0100 Subject: [PATCH 09/12] fix import average_states_over_policies --- pymdp/control.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/pymdp/control.py b/pymdp/control.py index 3489ce03..c5497964 100644 --- a/pymdp/control.py +++ b/pymdp/control.py @@ -6,7 +6,7 @@ import itertools import numpy as np from pymdp.maths import softmax, softmax_obj_arr, spm_dot, spm_wnorm, spm_MDP_G, spm_log_single, kl_div, entropy -from pymdp.inference import update_posterior_states_factorized +from pymdp.inference import update_posterior_states_factorized, average_states_over_policies from pymdp import utils import copy @@ -107,7 +107,7 @@ def update_posterior_policies_full( if I is not None: init_qs_all_pi = [qs_seq_pi[p][0] for p in range(num_policies)] - qs_bma = inference.average_states_over_policies(init_qs_all_pi, softmax(E)) + qs_bma = average_states_over_policies(init_qs_all_pi, softmax(E)) for p_idx, policy in enumerate(policies): @@ -237,7 +237,7 @@ def update_posterior_policies_full_factorized( if I is not None: init_qs_all_pi = [qs_seq_pi[p][0] for p in range(num_policies)] - qs_bma = inference.average_states_over_policies(init_qs_all_pi, softmax(E)) + qs_bma = average_states_over_policies(init_qs_all_pi, softmax(E)) for p_idx, policy in enumerate(policies): From a9874f10f1ebbbe38f84170776f031aa7b711673 Mon Sep 17 00:00:00 2001 From: Tim Verbelen Date: Fri, 3 May 2024 12:44:05 +0200 Subject: [PATCH 10/12] fix max --- pymdp/control.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) 
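The one-line fix below addresses a subtle numpy pitfall: the second positional argument of `np.max` is `axis`, so `np.max(m[0]-1, 0)` never compares the value against zero. Clamping a scalar at zero needs the builtin `max` (as in the fix) or `np.maximum`; a small illustration with toy values:

import numpy as np

m0 = 0
step = max(m0 - 1, 0)            # 0, the intended clamp
step_np = np.maximum(m0 - 1, 0)  # also 0: element-wise maximum of two values
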
diff --git a/pymdp/control.py b/pymdp/control.py index c5497964..a2379a7c 100644 --- a/pymdp/control.py +++ b/pymdp/control.py @@ -944,7 +944,7 @@ def calc_inductive_cost(qs, qs_pi, I, epsilon=1e-3): m = np.where(I[factor][:, idx] == 1)[0] # we might find no path to goal (i.e. when no goal specified) if len(m) > 0: - m = np.max(m[0]-1, 0) + m = max(m[0]-1, 0) I_m = (1-I[factor][m, :]) * np.log(epsilon) inductive_cost += I_m.dot(qs_pi[t][factor]) From b59cfcd2e0974ce6ce0e6da2a27f7b07726eafd6 Mon Sep 17 00:00:00 2001 From: Arun-Niranjan Date: Sat, 30 Dec 2023 16:56:01 +0000 Subject: [PATCH 11/12] Refactor get model dimensions from labels --- .gitignore | 3 +- pymdp/utils.py | 69 +++++++++++++++++++++++++++++----------------- test/test_utils.py | 58 ++++++++++++++++++++++++++++++++++++++ 3 files changed, 103 insertions(+), 27 deletions(-) create mode 100644 test/test_utils.py diff --git a/.gitignore b/.gitignore index 778d69dd..5f24acf9 100644 --- a/.gitignore +++ b/.gitignore @@ -7,4 +7,5 @@ __pycache__ .ipynb_checkpoints/ .pytest_cache env/ -pymdp.egg-info \ No newline at end of file +pymdp.egg-info +inferactively_pymdp.egg-info diff --git a/pymdp/utils.py b/pymdp/utils.py index 60bc41e0..bd985cf5 100644 --- a/pymdp/utils.py +++ b/pymdp/utils.py @@ -16,6 +16,27 @@ EPS_VAL = 1e-16 # global constant for use in norm_dist() +class Dimensions(object): + """ + The Dimensions class stores all data related to the size and shape of a model. + """ + def __init__( + self, + num_observations=None, + num_observation_modalities=0, + num_states=None, + num_state_factors=0, + num_controls=None, + num_control_factors=0, + ): + self.num_observations=num_observations + self.num_observation_modalities=num_observation_modalities + self.num_states=num_states + self.num_state_factors=num_state_factors + self.num_controls=num_controls + self.num_control_factors=num_control_factors + + def sample(probabilities): probabilities = probabilities.squeeze() if len(probabilities) > 1 else probabilities sample_onehot = np.random.multinomial(1, probabilities) @@ -211,22 +232,22 @@ def get_model_dimensions(A=None, B=None, factorized=False): def get_model_dimensions_from_labels(model_labels): modalities = model_labels['observations'] - num_modalities = len(modalities.keys()) - num_obs = [len(modalities[modality]) for modality in modalities.keys()] - factors = model_labels['states'] - num_factors = len(factors.keys()) - num_states = [len(factors[factor]) for factor in factors.keys()] - if 'actions' in model_labels.keys(): + res = Dimensions( + num_observations=[len(modalities[modality]) for modality in modalities.keys()], + num_observation_modalities=len(modalities.keys()), + num_states=[len(factors[factor]) for factor in factors.keys()], + num_state_factors=len(factors.keys()), + ) + if 'actions' in model_labels.keys(): controls = model_labels['actions'] - num_control_fac = len(controls.keys()) - num_controls = [len(controls[cfac]) for cfac in controls.keys()] + res.num_controls=[len(controls[cfac]) for cfac in controls.keys()] + res.num_control_factors=len(controls.keys()) + + return res - return num_obs, num_modalities, num_states, num_factors, num_controls, num_control_fac - else: - return num_obs, num_modalities, num_states, num_factors def norm_dist(dist): @@ -464,21 +485,18 @@ def construct_full_a(A_reduced, original_factor_idx, num_states): def create_A_matrix_stub(model_labels): - num_obs, _, num_states, _= get_model_dimensions_from_labels(model_labels) + dimensions = get_model_dimensions_from_labels(model_labels) 
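# Note (illustrative values, not patch code): for a stub with two observation modalities
# of sizes 2 and 3 and one two-level state factor, as in the new unit test, the Dimensions
# object returned above would hold:
#   dimensions.num_observations           -> [2, 3]
#   dimensions.num_observation_modalities -> 2
#   dimensions.num_states                 -> [2]
#   dimensions.num_state_factors          -> 1
# so sum(dimensions.num_observations) below is the number of rows in the A-matrix stub.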
obs_labels, state_labels = model_labels['observations'], model_labels['states'] state_combinations = pd.MultiIndex.from_product(list(state_labels.values()), names=list(state_labels.keys())) - num_state_combos = np.prod(num_states) - # num_rows = (np.array(num_obs) * num_state_combos).sum() - num_rows = sum(num_obs) + num_rows = sum(dimensions.num_observations) cell_values = np.zeros((num_rows, len(state_combinations))) obs_combinations = [] for modality in obs_labels.keys(): levels_to_combine = [[modality]] + [obs_labels[modality]] - # obs_combinations += num_state_combos * list(itertools.product(*levels_to_combine)) obs_combinations += list(itertools.product(*levels_to_combine)) @@ -490,7 +508,7 @@ def create_A_matrix_stub(model_labels): def create_B_matrix_stubs(model_labels): - _, _, num_states, _, num_controls, _ = get_model_dimensions_from_labels(model_labels) + dimensions = get_model_dimensions_from_labels(model_labels) state_labels = model_labels['states'] action_labels = model_labels['actions'] @@ -504,9 +522,9 @@ def create_B_matrix_stubs(model_labels): prev_state_action_combos = pd.MultiIndex.from_product(factor_list, names=[factor, list(action_labels.keys())[f_idx]]) - num_state_action_combos = num_states[f_idx] * num_controls[f_idx] + num_state_action_combos = dimensions.num_states[f_idx] * dimensions.num_controls[f_idx] - num_rows = num_states[f_idx] + num_rows = dimensions.num_states[f_idx] cell_values = np.zeros((num_rows, num_state_action_combos)) @@ -559,13 +577,12 @@ def convert_A_stub_to_ndarray(A_stub, model_labels): This function converts a multi-index pandas dataframe `A_stub` into an object array of different A matrices, one per observation modality. """ + dimensions = get_model_dimensions_from_labels(model_labels) - num_obs, num_modalities, num_states, num_factors = get_model_dimensions_from_labels(model_labels) - - A = obj_array(num_modalities) + A = obj_array(dimensions.num_observation_modalities) for g, modality_name in enumerate(model_labels['observations'].keys()): - A[g] = A_stub.loc[modality_name].to_numpy().reshape(num_obs[g], *num_states) + A[g] = A_stub.loc[modality_name].to_numpy().reshape(dimensions.num_observations[g], *dimensions.num_states) assert (A[g].sum(axis=0) == 1.0).all(), 'A matrix not normalized! Check your initialization....\n' return A @@ -576,13 +593,13 @@ def convert_B_stubs_to_ndarray(B_stubs, model_labels): of different B matrices, one per hidden state factor """ - _, _, num_states, num_factors, num_controls, num_control_fac = get_model_dimensions_from_labels(model_labels) + dimensions = get_model_dimensions_from_labels(model_labels) - B = obj_array(num_factors) + B = obj_array(dimensions.num_control_factors) for f, factor_name in enumerate(B_stubs.keys()): - B[f] = B_stubs[factor_name].to_numpy().reshape(num_states[f], num_states[f], num_controls[f]) + B[f] = B_stubs[factor_name].to_numpy().reshape(dimensions.num_states[f], dimensions.num_states[f], dimensions.num_controls[f]) assert (B[f].sum(axis=0) == 1.0).all(), 'B matrix not normalized! Check your initialization....\n' return B diff --git a/test/test_utils.py b/test/test_utils.py new file mode 100644 index 00000000..0a1ed066 --- /dev/null +++ b/test/test_utils.py @@ -0,0 +1,58 @@ + +import unittest + +from pymdp.utils import get_model_dimensions_from_labels, Dimensions + +class TestUtils(unittest.TestCase): + def test_get_model_dimensions_from_labels(self): + """ + Tests model dimension extraction from labels including observations, states and actions. 
+ """ + model_labels = { + "observations": { + "species_observation": [ + "absent", + "present", + ], + "budget_observation": [ + "high", + "medium", + "low", + ], + }, + "states": { + "species_state": [ + "extant", + "extinct", + ], + }, + "actions": { + "conservation_action": [ + "manage", + "survey", + "stop", + ], + }, + } + + want = Dimensions( + num_observations=[2, 3], + num_observation_modalities=2, + num_states=[2], + num_state_factors=1, + num_controls=[3], + num_control_factors=1, + ) + + got = get_model_dimensions_from_labels(model_labels) + + self.assertEqual(want.num_observations, got.num_observations) + self.assertEqual(want.num_observation_modalities, got.num_observation_modalities) + self.assertEqual(want.num_states, got.num_states) + self.assertEqual(want.num_state_factors, got.num_state_factors) + self.assertEqual(want.num_controls, got.num_controls) + self.assertEqual(want.num_control_factors, got.num_control_factors) + + +if __name__ == "__main__": + unittest.main() \ No newline at end of file From 1b5071083154a4235bf89abdd14378ce44bb5d94 Mon Sep 17 00:00:00 2001 From: Arun-Niranjan Date: Sat, 30 Dec 2023 17:03:50 +0000 Subject: [PATCH 12/12] Move unit test and update existing test for creating A matrix stub --- test/test_utils.py | 58 ---------------------------------------- test/test_wrappers.py | 62 ++++++++++++++++++++++++++++++++++++++----- 2 files changed, 56 insertions(+), 64 deletions(-) delete mode 100644 test/test_utils.py diff --git a/test/test_utils.py b/test/test_utils.py deleted file mode 100644 index 0a1ed066..00000000 --- a/test/test_utils.py +++ /dev/null @@ -1,58 +0,0 @@ - -import unittest - -from pymdp.utils import get_model_dimensions_from_labels, Dimensions - -class TestUtils(unittest.TestCase): - def test_get_model_dimensions_from_labels(self): - """ - Tests model dimension extraction from labels including observations, states and actions. 
- """ - model_labels = { - "observations": { - "species_observation": [ - "absent", - "present", - ], - "budget_observation": [ - "high", - "medium", - "low", - ], - }, - "states": { - "species_state": [ - "extant", - "extinct", - ], - }, - "actions": { - "conservation_action": [ - "manage", - "survey", - "stop", - ], - }, - } - - want = Dimensions( - num_observations=[2, 3], - num_observation_modalities=2, - num_states=[2], - num_state_factors=1, - num_controls=[3], - num_control_factors=1, - ) - - got = get_model_dimensions_from_labels(model_labels) - - self.assertEqual(want.num_observations, got.num_observations) - self.assertEqual(want.num_observation_modalities, got.num_observation_modalities) - self.assertEqual(want.num_states, got.num_states) - self.assertEqual(want.num_state_factors, got.num_state_factors) - self.assertEqual(want.num_controls, got.num_controls) - self.assertEqual(want.num_control_factors, got.num_control_factors) - - -if __name__ == "__main__": - unittest.main() \ No newline at end of file diff --git a/test/test_wrappers.py b/test/test_wrappers.py index db25c984..254d902b 100644 --- a/test/test_wrappers.py +++ b/test/test_wrappers.py @@ -1,15 +1,11 @@ import os import unittest from pathlib import Path -import shutil -import tempfile -import numpy as np -import itertools import pandas as pd from pandas.testing import assert_frame_equal -from pymdp.utils import create_A_matrix_stub, read_A_matrix, create_B_matrix_stubs, read_B_matrices +from pymdp.utils import Dimensions, get_model_dimensions_from_labels, create_A_matrix_stub, read_A_matrix, create_B_matrix_stubs, read_B_matrices tmp_path = Path('tmp_dir') @@ -18,11 +14,62 @@ class TestWrappers(unittest.TestCase): + def test_get_model_dimensions_from_labels(self): + """ + Tests model dimension extraction from labels including observations, states and actions. + """ + model_labels = { + "observations": { + "species_observation": [ + "absent", + "present", + ], + "budget_observation": [ + "high", + "medium", + "low", + ], + }, + "states": { + "species_state": [ + "extant", + "extinct", + ], + }, + "actions": { + "conservation_action": [ + "manage", + "survey", + "stop", + ], + }, + } + + want = Dimensions( + num_observations=[2, 3], + num_observation_modalities=2, + num_states=[2], + num_state_factors=1, + num_controls=[3], + num_control_factors=1, + ) + + got = get_model_dimensions_from_labels(model_labels) + + self.assertEqual(want.num_observations, got.num_observations) + self.assertEqual(want.num_observation_modalities, got.num_observation_modalities) + self.assertEqual(want.num_states, got.num_states) + self.assertEqual(want.num_state_factors, got.num_state_factors) + self.assertEqual(want.num_controls, got.num_controls) + self.assertEqual(want.num_control_factors, got.num_control_factors) + def test_A_matrix_stub(self): """ This tests the construction of a 2-modality, 2-hidden state factor pandas MultiIndex dataframe using the `model_labels` dictionary, which contains the modality- and factor-specific levels, labeled with string - identifiers + identifiers. + + Note: actions are ignored when creating an A matrix stub """ model_labels = { @@ -41,6 +88,9 @@ def test_A_matrix_stub(self): "weather_state": ["raining", "clear"], "sprinkler_state": ["on", "off"], }, + "actions": { + "actions": ["something", "nothing"], + } } num_hidden_state_factors = len(model_labels["states"])