diff --git a/pymdp/agent.py b/pymdp/agent.py
index 2bec1a16..94d093e2 100644
--- a/pymdp/agent.py
+++ b/pymdp/agent.py
@@ -46,7 +46,7 @@ def __init__(
         control_fac_idx=None,
         policies=None,
         gamma=16.0,
-        alpha = 16.0,
+        alpha=16.0,
         use_utility=True,
         use_states_info_gain=True,
         use_param_info_gain=False,
@@ -394,7 +394,7 @@ def get_future_qs(self):
 
         return future_qs_seq
 
-    def infer_states(self, observation):
+    def infer_states(self, observation, distr_obs=False):
         """
         Update approximate posterior over hidden states by solving variational inference problem, given an observation.
 
@@ -414,7 +414,7 @@ def infer_states(self, observation):
             at timepoint ``t_idx``.
         """
 
-        observation = tuple(observation)
+        observation = tuple(observation) if not distr_obs else observation
 
         if not hasattr(self, "qs"):
             self.reset()
diff --git a/pymdp/utils.py b/pymdp/utils.py
index 93e75cfd..3934f09b 100644
--- a/pymdp/utils.py
+++ b/pymdp/utils.py
@@ -46,6 +46,24 @@ def obj_array_zeros(shape_list):
         arr[i] = np.zeros(shape)
     return arr
 
+def initialize_empty_A(num_obs, num_states):
+    """
+    Initializes an empty observation likelihood array or `A` array using a list of observation-modality dimensions (`num_obs`)
+    and hidden state factor dimensions (`num_states`)
+    """
+
+    A_shape_list = [ [no] + num_states for no in num_obs]
+    return obj_array_zeros(A_shape_list)
+
+def initialize_empty_B(num_states, num_controls):
+    """
+    Initializes an empty (controllable) transition likelihood array or `B` array using a list of hidden state factor dimensions (`num_states`)
+    and control factor dimensions (`num_controls`)
+    """
+
+    B_shape_list = [ [ns, ns, num_controls[f]] for f, ns in enumerate(num_states)]
+    return obj_array_zeros(B_shape_list)
+
 def obj_array_uniform(shape_list):
     """
     Creates a numpy object array whose sub-arrays are uniform Categorical
@@ -559,34 +577,34 @@ def convert_B_stubs_to_ndarray(B_stubs, model_labels):
 
     return B
 
-def build_belief_array(qx):
-
-    """
-    This function constructs array-ified (not nested) versions
-    of the posterior belief arrays, that are separated
-    by policy, timepoint, and hidden state factor
-    """
-
-    num_policies = len(qx)
-    num_timesteps = len(qx[0])
-    num_factors = len(qx[0][0])
-
-    if num_factors > 1:
-        belief_array = utils.obj_array(num_factors)
-        for factor in range(num_factors):
-            belief_array[factor] = np.zeros( (num_policies, qx[0][0][factor].shape[0], num_timesteps) )
-        for policy_i in range(num_policies):
-            for timestep in range(num_timesteps):
-                for factor in range(num_factors):
-                    belief_array[factor][policy_i, :, timestep] = qx[policy_i][timestep][factor]
-    else:
-        num_states = qx[0][0][0].shape[0]
-        belief_array = np.zeros( (num_policies, num_states, num_timesteps) )
-        for policy_i in range(num_policies):
-            for timestep in range(num_timesteps):
-                belief_array[policy_i, :, timestep] = qx[policy_i][timestep][0]
+# def build_belief_array(qx):
+
+#     """
+#     This function constructs array-ified (not nested) versions
+#     of the posterior belief arrays, that are separated
+#     by policy, timepoint, and hidden state factor
+#     """
+
+#     num_policies = len(qx)
+#     num_timesteps = len(qx[0])
+#     num_factors = len(qx[0][0])
+
+#     if num_factors > 1:
+#         belief_array = obj_array(num_factors)
+#         for factor in range(num_factors):
+#             belief_array[factor] = np.zeros( (num_policies, qx[0][0][factor].shape[0], num_timesteps) )
+#         for policy_i in range(num_policies):
+#             for timestep in range(num_timesteps):
+#                 for factor in range(num_factors):
+#                     belief_array[factor][policy_i, :, timestep] = qx[policy_i][timestep][factor]
+#     else:
+#         num_states = qx[0][0][0].shape[0]
+#         belief_array = np.zeros( (num_policies, num_states, num_timesteps) )
+#         for policy_i in range(num_policies):
+#             for timestep in range(num_timesteps):
+#                 belief_array[policy_i, :, timestep] = qx[policy_i][timestep][0]
 
-    return belief_array
+#     return belief_array
 
 def build_xn_vn_array(xn):
 
@@ -601,9 +619,9 @@ def build_xn_vn_array(xn):
     num_factors = len(xn[0][0])
 
     if num_factors > 1:
-        xn_array = utils.obj_array(num_factors)
+        xn_array = obj_array(num_factors)
         for factor in range(num_factors):
-            num_states, infer_len = xn[0][0][f].shape
+            num_states, infer_len = xn[0][0][factor].shape
             xn_array[factor] = np.zeros( (num_itr, num_states, infer_len, num_policies) )
         for policy_i in range(num_policies):
            for itr in range(num_itr):
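A minimal usage sketch for the two helpers added above (assuming this patch is applied; the modality and factor dimensions below are illustrative, not taken from the tests):

    from pymdp import utils

    num_obs = [3, 5]       # two observation modalities
    num_states = [2, 4]    # two hidden state factors
    num_controls = [2, 1]  # one control dimension per hidden state factor

    # each A[m] is a zero array of shape (num_obs[m],) + tuple(num_states)
    A = utils.initialize_empty_A(num_obs, num_states)
    # each B[f] is a zero array of shape (num_states[f], num_states[f], num_controls[f])
    B = utils.initialize_empty_B(num_states, num_controls)

    assert A[0].shape == (3, 2, 4)
    assert B[1].shape == (4, 4, 1)

The sub-arrays start as all-zeros, so the caller is expected to fill them in and normalize them before handing them to an Agent.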
diff --git a/test/test_SPM_validation.py b/test/test_SPM_validation.py
index 537cab2c..8c0bc136 100644
--- a/test/test_SPM_validation.py
+++ b/test/test_SPM_validation.py
@@ -5,7 +5,7 @@
 from scipy.io import loadmat
 
 from pymdp.agent import Agent
-from pymdp.utils import to_obj_array, build_belief_array, build_xn_vn_array, get_model_dimensions, convert_observation_array
+from pymdp.utils import to_obj_array, build_xn_vn_array, get_model_dimensions, convert_observation_array
 from pymdp.maths import dirichlet_log_evidence
 
 DATA_PATH = "test/matlab_crossval/output/"
diff --git a/test/test_agent.py b/test/test_agent.py
index 4d3b94a1..ad3768ca 100644
--- a/test/test_agent.py
+++ b/test/test_agent.py
@@ -129,7 +129,7 @@ def test_agent_infer_states(self):
         for f in range(len(num_states)):
             self.assertTrue(np.isclose(qs_validation[f], qs_out[f]).all())
 
-        ''' Marginal message passing inference with multiple hidden state factors and multiple observation modalities '''
+        ''' Marginal message passing inference with one hidden state factor and one observation modality '''
         num_obs = [5]
         num_states = [3]
         num_controls = [1]
@@ -494,6 +494,89 @@ def test_agent_with_stochastic_action_unidimensional_control(self):
         agent.infer_policies()
         chosen_action = agent.sample_action()
         self.assertEqual(chosen_action[1], 0)
+
+    def test_agent_distributional_obs(self):
+
+        ''' VANILLA method (fixed point iteration) with one hidden state factor and one observation modality '''
+        num_obs = [5]
+        num_states = [3]
+        num_controls = [1]
+        A = utils.random_A_matrix(num_obs, num_states)
+        B = utils.random_B_matrix(num_states, num_controls)
+
+        agent = Agent(A=A, B=B, inference_algo = "VANILLA")
+
+        p_o = utils.obj_array_zeros(num_obs) # use a distributional observation
+        # @NOTE: `utils.obj_array_from_list` returns a nested object array if it is passed a list containing only one vector, which suggests we should consider removing `utils.obj_array_from_list`
+        p_o[0] = A[0][:,0]
+        qs_out = agent.infer_states(p_o, distr_obs=True)
+
+        qs_validation = inference.update_posterior_states(A, p_o, prior=agent.D)
+
+        for f in range(len(num_states)):
+            self.assertTrue(np.isclose(qs_validation[f], qs_out[f]).all())
+
+        ''' VANILLA method (fixed point iteration) with multiple hidden state factors and multiple observation modalities '''
+        num_obs = [2, 4]
+        num_states = [2, 3]
+        num_controls = [2, 3]
+        A = utils.random_A_matrix(num_obs, num_states)
+        B = utils.random_B_matrix(num_states, num_controls)
+
+        agent = Agent(A=A, B=B, inference_algo = "VANILLA")
+
+        p_o = utils.obj_array_from_list([A[0][:,0,0], A[1][:,1,1]]) # use a distributional observation
+        qs_out = agent.infer_states(p_o, distr_obs=True)
+
+        qs_validation = inference.update_posterior_states(A, p_o, prior=agent.D)
+
+        for f in range(len(num_states)):
+            self.assertTrue(np.isclose(qs_validation[f], qs_out[f]).all())
+
+        ''' Marginal message passing inference with one hidden state factor and one observation modality '''
+        num_obs = [5]
+        num_states = [3]
+        num_controls = [1]
+        A = utils.random_A_matrix(num_obs, num_states)
+        B = utils.random_B_matrix(num_states, num_controls)
+
+        agent = Agent(A=A, B=B, inference_algo = "MMP")
+
+        p_o = utils.obj_array_zeros(num_obs) # use a distributional observation
+        # @NOTE: see the note above about passing a single-vector list to `utils.obj_array_from_list`
+        p_o[0] = A[0][:,0]
+        qs_pi_out = agent.infer_states(p_o, distr_obs=True)
+
+        policies = control.construct_policies(num_states, num_controls, policy_len = 1)
+
+        qs_pi_validation, _ = inference.update_posterior_states_full(A, B, [p_o], policies, prior = agent.D, policy_sep_prior = False)
+
+        for p_idx in range(len(policies)):
+            for f in range(len(num_states)):
+                self.assertTrue(np.isclose(qs_pi_validation[p_idx][0][f], qs_pi_out[p_idx][0][f]).all())
+
+        ''' Marginal message passing inference with multiple hidden state factors and multiple observation modalities '''
+        num_obs = [2, 4]
+        num_states = [2, 2]
+        num_controls = [2, 2]
+        A = utils.random_A_matrix(num_obs, num_states)
+        B = utils.random_B_matrix(num_states, num_controls)
+
+        planning_horizon = 3
+        backwards_horizon = 1
+        agent = Agent(A=A, B=B, inference_algo="MMP", policy_len=planning_horizon, inference_horizon=backwards_horizon)
+        p_o = utils.obj_array_from_list([A[0][:,0,0], A[1][:,1,1]]) # use a distributional observation
+        qs_pi_out = agent.infer_states(p_o, distr_obs=True)
+
+        policies = control.construct_policies(num_states, num_controls, policy_len = planning_horizon)
+
+        qs_pi_validation, _ = inference.update_posterior_states_full(A, B, [p_o], policies, prior = agent.D, policy_sep_prior = False)
+
+        for p_idx in range(len(policies)):
+            for t in range(planning_horizon+backwards_horizon):
+                for f in range(len(num_states)):
+                    self.assertTrue(np.isclose(qs_pi_validation[p_idx][t][f], qs_pi_out[p_idx][t][f]).all())
+
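A minimal sketch of the distributional-observation usage exercised by `test_agent_distributional_obs` (assuming this patch is applied): with `distr_obs=True`, `infer_states` accepts one probability vector per observation modality instead of a tuple of observation indices:

    from pymdp import utils
    from pymdp.agent import Agent

    num_obs, num_states, num_controls = [2, 4], [2, 3], [2, 3]
    A = utils.random_A_matrix(num_obs, num_states)
    B = utils.random_B_matrix(num_states, num_controls)

    agent = Agent(A=A, B=B, inference_algo="VANILLA")

    # one probability vector per modality, here read off columns of the A arrays
    p_o = utils.obj_array_from_list([A[0][:, 0, 0], A[1][:, 1, 1]])
    qs = agent.infer_states(p_o, distr_obs=True)  # bypasses the tuple(...) conversion

The same call works with `inference_algo="MMP"`, in which case the returned posterior is additionally indexed by policy and timepoint, as the validation loops in the test show.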