Skip to content

Commit

Permalink
Merge pull request #104 from infer-actively/utils_helpers
Browse files Browse the repository at this point in the history
Utils helpers
  • Loading branch information
conorheins authored Dec 8, 2022
2 parents 016c93a + cdbc95b commit 41d5b20
Show file tree
Hide file tree
Showing 4 changed files with 135 additions and 34 deletions.
6 changes: 3 additions & 3 deletions pymdp/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def __init__(
control_fac_idx=None,
policies=None,
gamma=16.0,
alpha = 16.0,
alpha=16.0,
use_utility=True,
use_states_info_gain=True,
use_param_info_gain=False,
Expand Down Expand Up @@ -394,7 +394,7 @@ def get_future_qs(self):
return future_qs_seq


def infer_states(self, observation):
def infer_states(self, observation, distr_obs = False):
"""
Update approximate posterior over hidden states by solving variational inference problem, given an observation.
Expand All @@ -414,7 +414,7 @@ def infer_states(self, observation):
at timepoint ``t_idx``.
"""

observation = tuple(observation)
observation = tuple(observation) if not distr_obs else observation

if not hasattr(self, "qs"):
self.reset()
Expand Down
76 changes: 47 additions & 29 deletions pymdp/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,24 @@ def obj_array_zeros(shape_list):
arr[i] = np.zeros(shape)
return arr

def initialize_empty_A(num_obs, num_states):
    """
    Builds an all-zeros observation likelihood (`A`) object array, with one sub-array
    per observation modality. Each modality's sub-array has shape
    ``[num_obs[m]] + num_states``, i.e. observation dimension leading, followed by
    all hidden state factor dimensions.
    """
    shapes = []
    for obs_dim in num_obs:
        shapes.append([obs_dim] + num_states)
    return obj_array_zeros(shapes)

def initialize_empty_B(num_states, num_controls):
    """
    Builds an all-zeros (controllable) transition likelihood (`B`) object array, with one
    sub-array per hidden state factor. Each factor's sub-array has shape
    ``[num_states[f], num_states[f], num_controls[f]]``, i.e. next-state x previous-state
    x control dimension for that factor (`num_controls`).
    """
    shapes = []
    for factor_idx, state_dim in enumerate(num_states):
        shapes.append([state_dim, state_dim, num_controls[factor_idx]])
    return obj_array_zeros(shapes)

def obj_array_uniform(shape_list):
"""
Creates a numpy object array whose sub-arrays are uniform Categorical
Expand Down Expand Up @@ -559,34 +577,34 @@ def convert_B_stubs_to_ndarray(B_stubs, model_labels):

return B

def build_belief_array(qx):

"""
This function constructs array-ified (not nested) versions
of the posterior belief arrays, that are separated
by policy, timepoint, and hidden state factor
"""

num_policies = len(qx)
num_timesteps = len(qx[0])
num_factors = len(qx[0][0])

if num_factors > 1:
belief_array = utils.obj_array(num_factors)
for factor in range(num_factors):
belief_array[factor] = np.zeros( (num_policies, qx[0][0][factor].shape[0], num_timesteps) )
for policy_i in range(num_policies):
for timestep in range(num_timesteps):
for factor in range(num_factors):
belief_array[factor][policy_i, :, timestep] = qx[policy_i][timestep][factor]
else:
num_states = qx[0][0][0].shape[0]
belief_array = np.zeros( (num_policies, num_states, num_timesteps) )
for policy_i in range(num_policies):
for timestep in range(num_timesteps):
belief_array[policy_i, :, timestep] = qx[policy_i][timestep][0]
# def build_belief_array(qx):

# """
# This function constructs array-ified (not nested) versions
# of the posterior belief arrays, that are separated
# by policy, timepoint, and hidden state factor
# """

# num_policies = len(qx)
# num_timesteps = len(qx[0])
# num_factors = len(qx[0][0])

# if num_factors > 1:
# belief_array = obj_array(num_factors)
# for factor in range(num_factors):
# belief_array[factor] = np.zeros( (num_policies, qx[0][0][factor].shape[0], num_timesteps) )
# for policy_i in range(num_policies):
# for timestep in range(num_timesteps):
# for factor in range(num_factors):
# belief_array[factor][policy_i, :, timestep] = qx[policy_i][timestep][factor]
# else:
# num_states = qx[0][0][0].shape[0]
# belief_array = np.zeros( (num_policies, num_states, num_timesteps) )
# for policy_i in range(num_policies):
# for timestep in range(num_timesteps):
# belief_array[policy_i, :, timestep] = qx[policy_i][timestep][0]

return belief_array
# return belief_array

def build_xn_vn_array(xn):

Expand All @@ -601,9 +619,9 @@ def build_xn_vn_array(xn):
num_factors = len(xn[0][0])

if num_factors > 1:
xn_array = utils.obj_array(num_factors)
xn_array = obj_array(num_factors)
for factor in range(num_factors):
num_states, infer_len = xn[0][0][f].shape
num_states, infer_len = xn[0][0][factor].shape
xn_array[factor] = np.zeros( (num_itr, num_states, infer_len, num_policies) )
for policy_i in range(num_policies):
for itr in range(num_itr):
Expand Down
2 changes: 1 addition & 1 deletion test/test_SPM_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from scipy.io import loadmat

from pymdp.agent import Agent
from pymdp.utils import to_obj_array, build_belief_array, build_xn_vn_array, get_model_dimensions, convert_observation_array
from pymdp.utils import to_obj_array, build_xn_vn_array, get_model_dimensions, convert_observation_array
from pymdp.maths import dirichlet_log_evidence

DATA_PATH = "test/matlab_crossval/output/"
Expand Down
85 changes: 84 additions & 1 deletion test/test_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ def test_agent_infer_states(self):
for f in range(len(num_states)):
self.assertTrue(np.isclose(qs_validation[f], qs_out[f]).all())

''' Marginal message passing inference with multiple hidden state factors and multiple observation modalities '''
''' Marginal message passing inference with one hidden state factor and one observation modality '''
num_obs = [5]
num_states = [3]
num_controls = [1]
Expand Down Expand Up @@ -494,6 +494,89 @@ def test_agent_with_stochastic_action_unidimensional_control(self):
agent.infer_policies()
chosen_action = agent.sample_action()
self.assertEqual(chosen_action[1], 0)

def test_agent_distributional_obs(self):
    """
    Checks that ``Agent.infer_states(observation, distr_obs=True)`` accepts a
    distributional (non-one-hot) observation and produces the same posterior
    beliefs as calling the underlying inference routines directly, for both the
    VANILLA (fixed point iteration) and MMP (marginal message passing)
    inference algorithms, each with single- and multi-factor/-modality models.
    """

    ''' VANILLA method (fixed point iteration) with one hidden state factor and one observation modality '''
    num_obs = [5]
    num_states = [3]
    num_controls = [1]
    A = utils.random_A_matrix(num_obs, num_states)
    B = utils.random_B_matrix(num_states, num_controls)

    agent = Agent(A=A, B=B, inference_algo = "VANILLA")

    p_o = utils.obj_array_zeros(num_obs) # use a distributional observation
    # @NOTE: `utils.obj_array_from_list` will make a nested list of object arrays if you only put in a list with one vector!!! Makes me think we should remove utils.obj_array_from_list potentially
    # the observation here is a full likelihood column of A, not a one-hot vector
    p_o[0] = A[0][:,0]
    qs_out = agent.infer_states(p_o, distr_obs=True)

    # validation: run the same inference directly, with the agent's prior over states
    qs_validation = inference.update_posterior_states(A, p_o, prior=agent.D)

    # posteriors must agree factor-by-factor
    for f in range(len(num_states)):
        self.assertTrue(np.isclose(qs_validation[f], qs_out[f]).all())

    ''' VANILLA method (fixed point iteration) with multiple hidden state factors and multiple observation modalities '''
    num_obs = [2, 4]
    num_states = [2, 3]
    num_controls = [2, 3]
    A = utils.random_A_matrix(num_obs, num_states)
    B = utils.random_B_matrix(num_states, num_controls)

    agent = Agent(A=A, B=B, inference_algo = "VANILLA")

    p_o = utils.obj_array_from_list([A[0][:,0,0], A[1][:,1,1]]) # use a distributional observation
    qs_out = agent.infer_states(p_o, distr_obs=True)

    # validation: direct call to the fixed-point-iteration update
    qs_validation = inference.update_posterior_states(A, p_o, prior=agent.D)

    for f in range(len(num_states)):
        self.assertTrue(np.isclose(qs_validation[f], qs_out[f]).all())

    ''' Marginal message passing inference with one hidden state factor and one observation modality '''
    num_obs = [5]
    num_states = [3]
    num_controls = [1]
    A = utils.random_A_matrix(num_obs, num_states)
    B = utils.random_B_matrix(num_states, num_controls)

    agent = Agent(A=A, B=B, inference_algo = "MMP")

    p_o = utils.obj_array_zeros(num_obs) # use a distributional observation
    # @NOTE: `utils.obj_array_from_list` will make a nested list of object arrays if you only put in a list with one vector!!! Makes me think we should remove utils.obj_array_from_list potentially
    p_o[0] = A[0][:,0]
    qs_pi_out = agent.infer_states(p_o, distr_obs=True)

    # MMP returns policy-conditioned posteriors, so validation needs the policy set
    policies = control.construct_policies(num_states, num_controls, policy_len = 1)

    qs_pi_validation, _ = inference.update_posterior_states_full(A, B, [p_o], policies, prior = agent.D, policy_sep_prior = False)

    # compare policy-conditioned posteriors at the single timestep, per factor
    for p_idx in range(len(policies)):
        for f in range(len(num_states)):
            self.assertTrue(np.isclose(qs_pi_validation[p_idx][0][f], qs_pi_out[p_idx][0][f]).all())

    ''' Marginal message passing inference with multiple hidden state factors and multiple observation modalities '''
    num_obs = [2, 4]
    num_states = [2, 2]
    num_controls = [2, 2]
    A = utils.random_A_matrix(num_obs, num_states)
    B = utils.random_B_matrix(num_states, num_controls)

    # non-trivial planning/inference horizons so posteriors span multiple timesteps
    planning_horizon = 3
    backwards_horizon = 1
    agent = Agent(A=A, B=B, inference_algo="MMP", policy_len=planning_horizon, inference_horizon=backwards_horizon)
    p_o = utils.obj_array_from_list([A[0][:,0,0], A[1][:,1,1]]) # use a distributional observation
    qs_pi_out = agent.infer_states(p_o, distr_obs=True)

    policies = control.construct_policies(num_states, num_controls, policy_len = planning_horizon)

    qs_pi_validation, _ = inference.update_posterior_states_full(A, B, [p_o], policies, prior = agent.D, policy_sep_prior = False)

    # compare across every policy, every timestep in the combined horizon, every factor
    for p_idx in range(len(policies)):
        for t in range(planning_horizon+backwards_horizon):
            for f in range(len(num_states)):
                self.assertTrue(np.isclose(qs_pi_validation[p_idx][t][f], qs_pi_out[p_idx][t][f]).all())




Expand Down

0 comments on commit 41d5b20

Please sign in to comment.