diff --git a/pymdp/utils.py b/pymdp/utils.py
index b3ca83dc..18a5a69e 100644
--- a/pymdp/utils.py
+++ b/pymdp/utils.py
@@ -15,7 +15,10 @@ EPS_VAL = 1e-16 # global constant for use in norm_dist()
 
 
 def sample(probabilities):
-    sample_onehot = np.random.multinomial(1, probabilities.squeeze())
+    # Drop singleton axes but keep the result at least 1-D: squeezing a length-1
+    # distribution yields an unsized (0-d) array, which np.random.multinomial cannot use.
+    probabilities = np.atleast_1d(np.squeeze(probabilities))
+    sample_onehot = np.random.multinomial(1, probabilities)
     return np.where(sample_onehot == 1)[0][0]
 
 def sample_obj_array(arr):
diff --git a/test/test_agent.py b/test/test_agent.py
index a120433a..4d3b94a1 100644
--- a/test/test_agent.py
+++ b/test/test_agent.py
@@ -469,6 +469,32 @@ def test_agent_with_sampling_mode(self):
         agent.infer_policies()
         chosen_action, p_actions = agent._sample_action_test()
         self.assertEqual(len(p_actions[0]), num_controls[0])
+
+    def test_agent_with_stochastic_action_unidimensional_control(self):
+        """
+        Test stochastic action sampling when one of the control states is one-dimensional, within the agent
+        method `sample_action()`.
+        An earlier version of utils.sample() called probabilities.squeeze() unconditionally, which threw an
+        error because np.random.multinomial cannot be used on an array with undefined length (an 'unsized' array)
+        """
+
+        num_obs = [2]
+        num_states = [2, 2]
+        num_controls = [2, 1]
+
+        A = utils.random_A_matrix(num_obs, num_states)
+        B = utils.random_B_matrix(num_states, num_controls)
+
+        agent = Agent(A=A, B=B, action_selection="stochastic")
+        agent.infer_policies()
+        chosen_action = agent.sample_action()
+        self.assertEqual(chosen_action[1], 0)
+
+        agent = Agent(A=A, B=B, action_selection="deterministic")
+        agent.infer_policies()
+        chosen_action = agent.sample_action()
+        self.assertEqual(chosen_action[1], 0)
+
 
 
 
diff --git a/test/test_control.py b/test/test_control.py
index 552fb4be..54423496 100644
--- a/test/test_control.py
+++ b/test/test_control.py
@@ -1375,6 +1375,23 @@ def test_update_posterior_policies_withE_vector(self):
 
         self.assertGreater(q_pi[0], q_pi[1])
         self.assertGreater(q_pi[2], q_pi[1])
+
+    def test_stochastic_action_unidimensional_control(self):
+        """
+        Test stochastic action sampling when one of the control states is one-dimensional.
+        An earlier version of utils.sample() called probabilities.squeeze() unconditionally, which threw an
+        error because np.random.multinomial cannot be used on an array with undefined length (an 'unsized' array)
+        """
+
+        num_states = [2, 2]
+        num_controls = [2, 1]
+        policies = control.construct_policies(num_states, num_controls=num_controls, policy_len=1)
+        q_pi = utils.norm_dist(np.random.rand(len(policies)))
+        sampled_action = control.sample_action(q_pi, policies, num_controls, action_selection="stochastic")
+        self.assertEqual(sampled_action[1], 0)
+
+        sampled_action = control.sample_action(q_pi, policies, num_controls, action_selection="deterministic")
+        self.assertEqual(sampled_action[1], 0)
 
 
 if __name__ == "__main__":