
Commit

Added unit tests to test_agent.py to test out new distributional observation functionality
conorheins committed Dec 8, 2022
1 parent 3bc77e1 commit cdbc95b
Showing 1 changed file with 84 additions and 1 deletion.
85 changes: 84 additions & 1 deletion test/test_agent.py
@@ -129,7 +129,7 @@ def test_agent_infer_states(self):
        for f in range(len(num_states)):
            self.assertTrue(np.isclose(qs_validation[f], qs_out[f]).all())

        ''' Marginal message passing inference with multiple hidden state factors and multiple observation modalities '''
        ''' Marginal message passing inference with one hidden state factor and one observation modality '''
        num_obs = [5]
        num_states = [3]
        num_controls = [1]
@@ -494,6 +494,89 @@ def test_agent_with_stochastic_action_unidimensional_control(self):
        agent.infer_policies()
        chosen_action = agent.sample_action()
        self.assertEqual(chosen_action[1], 0)

    def test_agent_distributional_obs(self):

        ''' VANILLA method (fixed point iteration) with one hidden state factor and one observation modality '''
        num_obs = [5]
        num_states = [3]
        num_controls = [1]
        A = utils.random_A_matrix(num_obs, num_states)
        B = utils.random_B_matrix(num_states, num_controls)

        agent = Agent(A=A, B=B, inference_algo = "VANILLA")

        p_o = utils.obj_array_zeros(num_obs) # use a distributional observation
        # @NOTE: `utils.obj_array_from_list` will make a nested list of object arrays if you only put in a list with one vector!!! Makes me think we should remove utils.obj_array_from_list potentially
        p_o[0] = A[0][:,0]
        qs_out = agent.infer_states(p_o, distr_obs=True)

        qs_validation = inference.update_posterior_states(A, p_o, prior=agent.D)

        for f in range(len(num_states)):
            self.assertTrue(np.isclose(qs_validation[f], qs_out[f]).all())

        ''' VANILLA method (fixed point iteration) with multiple hidden state factors and multiple observation modalities '''
        num_obs = [2, 4]
        num_states = [2, 3]
        num_controls = [2, 3]
        A = utils.random_A_matrix(num_obs, num_states)
        B = utils.random_B_matrix(num_states, num_controls)

        agent = Agent(A=A, B=B, inference_algo = "VANILLA")

        p_o = utils.obj_array_from_list([A[0][:,0,0], A[1][:,1,1]]) # use a distributional observation
        qs_out = agent.infer_states(p_o, distr_obs=True)

        qs_validation = inference.update_posterior_states(A, p_o, prior=agent.D)

        for f in range(len(num_states)):
            self.assertTrue(np.isclose(qs_validation[f], qs_out[f]).all())

        ''' Marginal message passing inference with one hidden state factor and one observation modality '''
        num_obs = [5]
        num_states = [3]
        num_controls = [1]
        A = utils.random_A_matrix(num_obs, num_states)
        B = utils.random_B_matrix(num_states, num_controls)

        agent = Agent(A=A, B=B, inference_algo = "MMP")

        p_o = utils.obj_array_zeros(num_obs) # use a distributional observation
        # @NOTE: `utils.obj_array_from_list` will make a nested list of object arrays if you only put in a list with one vector!!! Makes me think we should remove utils.obj_array_from_list potentially
        p_o[0] = A[0][:,0]
        qs_pi_out = agent.infer_states(p_o, distr_obs=True)

        policies = control.construct_policies(num_states, num_controls, policy_len = 1)

        qs_pi_validation, _ = inference.update_posterior_states_full(A, B, [p_o], policies, prior = agent.D, policy_sep_prior = False)

        for p_idx in range(len(policies)):
            for f in range(len(num_states)):
                self.assertTrue(np.isclose(qs_pi_validation[p_idx][0][f], qs_pi_out[p_idx][0][f]).all())

        ''' Marginal message passing inference with multiple hidden state factors and multiple observation modalities '''
        num_obs = [2, 4]
        num_states = [2, 2]
        num_controls = [2, 2]
        A = utils.random_A_matrix(num_obs, num_states)
        B = utils.random_B_matrix(num_states, num_controls)

        planning_horizon = 3
        backwards_horizon = 1
        agent = Agent(A=A, B=B, inference_algo="MMP", policy_len=planning_horizon, inference_horizon=backwards_horizon)
        p_o = utils.obj_array_from_list([A[0][:,0,0], A[1][:,1,1]]) # use a distributional observation
        qs_pi_out = agent.infer_states(p_o, distr_obs=True)

        policies = control.construct_policies(num_states, num_controls, policy_len = planning_horizon)

        qs_pi_validation, _ = inference.update_posterior_states_full(A, B, [p_o], policies, prior = agent.D, policy_sep_prior = False)

        for p_idx in range(len(policies)):
            for t in range(planning_horizon+backwards_horizon):
                for f in range(len(num_states)):
                    self.assertTrue(np.isclose(qs_pi_validation[p_idx][t][f], qs_pi_out[p_idx][t][f]).all())
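
For reference, the `distr_obs` flag exercised above lets `Agent.infer_states` accept a probability distribution over outcomes for each modality, instead of the usual integer observation index. Below is a minimal usage sketch, assuming pymdp's standard import paths (`pymdp.agent.Agent`, `pymdp.utils`); the probability values are illustrative only, not taken from the tests.

    import numpy as np
    from pymdp import utils
    from pymdp.agent import Agent

    num_obs, num_states, num_controls = [5], [3], [1]
    A = utils.random_A_matrix(num_obs, num_states)       # observation model, one modality
    B = utils.random_B_matrix(num_states, num_controls)  # transition model, one factor
    agent = Agent(A=A, B=B, inference_algo="VANILLA")

    # Usual discrete observation: an index into the 5 outcomes of modality 0
    qs_discrete = agent.infer_states([0])

    # Distributional ("soft") observation: a normalized probability vector per modality.
    # Building p_o with obj_array_zeros avoids the single-vector quirk of
    # obj_array_from_list noted in the test comments above.
    p_o = utils.obj_array_zeros(num_obs)
    p_o[0] = np.array([0.7, 0.1, 0.1, 0.05, 0.05])
    qs_soft = agent.infer_states(p_o, distr_obs=True)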




