Merge pull request #90 from infer-actively/action_sampling_bug_86
Fix action sampling bug when there's a uni-dimensional control state (to address Issue #86)
conorheins authored Jul 27, 2022
2 parents 21b0a52 + 644af07 commit a9fe7bd
Showing 3 changed files with 45 additions and 1 deletion.
3 changes: 2 additions & 1 deletion pymdp/utils.py
@@ -15,7 +15,8 @@
 EPS_VAL = 1e-16 # global constant for use in norm_dist()

 def sample(probabilities):
-    sample_onehot = np.random.multinomial(1, probabilities.squeeze())
+    probabilities = probabilities.squeeze() if len(probabilities) > 1 else probabilities
+    sample_onehot = np.random.multinomial(1, probabilities)
     return np.where(sample_onehot == 1)[0][0]

 def sample_obj_array(arr):
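A minimal sketch of the failure this change addresses (not part of the commit; the exact error text may differ across numpy versions): squeezing a length-1 probability vector produces a 0-dimensional ("unsized") array, which np.random.multinomial cannot consume, whereas the patched guard only squeezes arrays that still have a length afterwards.

import numpy as np

p = np.array([1.0])                       # action distribution for a one-dimensional control factor
print(p.squeeze().shape)                  # () -- squeeze() collapses the length-1 vector to a 0-d array
# np.random.multinomial(1, p.squeeze())   # raises, e.g. "TypeError: len() of unsized object"

p = p.squeeze() if len(p) > 1 else p      # the patched guard leaves length-1 arrays intact
print(np.random.multinomial(1, p))        # [1] -- sampling succeeds for the single-action factor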
26 changes: 26 additions & 0 deletions test/test_agent.py
@@ -469,6 +469,32 @@ def test_agent_with_sampling_mode(self):
        agent.infer_policies()
        chosen_action, p_actions = agent._sample_action_test()
        self.assertEqual(len(p_actions[0]), num_controls[0])

    def test_agent_with_stochastic_action_unidimensional_control(self):
        """
        Test stochastic action sampling, via the agent method `sample_action()`, in the case where one of the
        control factors is one-dimensional.
        A call to probabilities.squeeze() in an earlier version of utils.sample() collapsed length-1 probability
        vectors into 'unsized' (0-dimensional) arrays, which np.random.multinomial cannot handle, so this used
        to raise an error.
        """

        num_obs = [2]
        num_states = [2, 2]
        num_controls = [2, 1]

        A = utils.random_A_matrix(num_obs, num_states)
        B = utils.random_B_matrix(num_states, num_controls)

        agent = Agent(A=A, B=B, action_selection = "stochastic")
        agent.infer_policies()
        chosen_action = agent.sample_action()
        self.assertEqual(chosen_action[1], 0)

        agent = Agent(A=A, B=B, action_selection = "deterministic")
        agent.infer_policies()
        chosen_action = agent.sample_action()
        self.assertEqual(chosen_action[1], 0)




17 changes: 17 additions & 0 deletions test/test_control.py
@@ -1375,6 +1375,23 @@ def test_update_posterior_policies_withE_vector(self):

        self.assertGreater(q_pi[0], q_pi[1])
        self.assertGreater(q_pi[2], q_pi[1])

    def test_stochastic_action_unidimensional_control(self):
        """
        Test stochastic action sampling in the case where one of the control factors is one-dimensional.
        A call to probabilities.squeeze() in an earlier version of utils.sample() collapsed length-1 probability
        vectors into 'unsized' (0-dimensional) arrays, which np.random.multinomial cannot handle, so this used
        to raise an error.
        """

        num_states = [2, 2]
        num_controls = [2, 1]
        policies = control.construct_policies(num_states, num_controls = num_controls, policy_len=1)
        q_pi = utils.norm_dist(np.random.rand(len(policies)))
        sampled_action = control.sample_action(q_pi, policies, num_controls, action_selection="stochastic")
        self.assertEqual(sampled_action[1], 0)

        sampled_action = control.sample_action(q_pi, policies, num_controls, action_selection="deterministic")
        self.assertEqual(sampled_action[1], 0)


if __name__ == "__main__":
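The assertions above check that the sampled action for the second control factor is always 0: a factor with a single allowed control state only admits action index 0. A small illustration (not part of the commit), reusing the construct_policies call from the test above and assuming the module is imported as `from pymdp import control`:

from pymdp import control

policies = control.construct_policies([2, 2], num_controls=[2, 1], policy_len=1)
for policy in policies:
    print(policy)   # the second column is 0 in every policy, so any sampled action has action[1] == 0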
