Skip to content

Commit

Permalink
unit tests in test_agent.py and test_control.py for new version o…
Browse files Browse the repository at this point in the history
…f `utils.sample()` within `control.sample_action`and `self.sample_action()` method of `Agent`, when action_selection is "stochastic"
  • Loading branch information
conorheins committed Jul 27, 2022
1 parent 584caf5 commit 644af07
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 0 deletions.
26 changes: 26 additions & 0 deletions test/test_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -469,6 +469,32 @@ def test_agent_with_sampling_mode(self):
agent.infer_policies()
chosen_action, p_actions = agent._sample_action_test()
self.assertEqual(len(p_actions[0]), num_controls[0])

def test_agent_with_stochastic_action_unidimensional_control(self):
"""
Test stochastic action sampling in case that one of the control states is one-dimensional, within the agent
method `sample_action()`.
Due to a call to probabilities.squeeze() in an earlier version of utils.sample(), this was throwing an
error due to the inability to use np.random.multinomial on an array with undefined length (an 'unsized' array)
"""

num_obs = [2]
num_states = [2, 2]
num_controls = [2, 1]

A = utils.random_A_matrix(num_obs, num_states)
B = utils.random_B_matrix(num_states, num_controls)

agent = Agent(A=A, B=B, action_selection = "stochastic")
agent.infer_policies()
chosen_action = agent.sample_action()
self.assertEqual(chosen_action[1], 0)

agent = Agent(A=A, B=B, action_selection = "deterministic")
agent.infer_policies()
chosen_action = agent.sample_action()
self.assertEqual(chosen_action[1], 0)




Expand Down
17 changes: 17 additions & 0 deletions test/test_control.py
Original file line number Diff line number Diff line change
Expand Up @@ -1375,6 +1375,23 @@ def test_update_posterior_policies_withE_vector(self):

self.assertGreater(q_pi[0], q_pi[1])
self.assertGreater(q_pi[2], q_pi[1])

def test_stochastic_action_unidimensional_control(self):
"""
Test stochastic action sampling in case that one of the control states is one-dimensional.
Due to a call to probabilities.squeeze() in an earlier version of utils.sample(), this was throwing an
error due to the inability to use np.random.multinomial on an array with undefined length (an 'unsized' array)
"""

num_states = [2, 2]
num_controls = [2, 1]
policies = control.construct_policies(num_states, num_controls = num_controls, policy_len=1)
q_pi = utils.norm_dist(np.random.rand(len(policies)))
sampled_action = control.sample_action(q_pi, policies, num_controls, action_selection="stochastic")
self.assertEqual(sampled_action[1], 0)

sampled_action = control.sample_action(q_pi, policies, num_controls, action_selection="deterministic")
self.assertEqual(sampled_action[1], 0)


if __name__ == "__main__":
Expand Down

0 comments on commit 644af07

Please sign in to comment.