From 584caf537640e816e7df24e6f861f6e00d7a610b Mon Sep 17 00:00:00 2001 From: conorheins Date: Wed, 27 Jul 2022 13:09:43 +0200 Subject: [PATCH 1/2] Conditionally `.squeeze()` probability distribution, handles cases of 1-D probability distributions --- pymdp/utils.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/pymdp/utils.py b/pymdp/utils.py index b3ca83dc..18a5a69e 100644 --- a/pymdp/utils.py +++ b/pymdp/utils.py @@ -15,7 +15,8 @@ EPS_VAL = 1e-16 # global constant for use in norm_dist() def sample(probabilities): - sample_onehot = np.random.multinomial(1, probabilities.squeeze()) + probabilities = probabilities.squeeze() if len(probabilities) > 1 else probabilities + sample_onehot = np.random.multinomial(1, probabilities) return np.where(sample_onehot == 1)[0][0] def sample_obj_array(arr): From 644af0731e596985b77c4c629fd4097528467334 Mon Sep 17 00:00:00 2001 From: conorheins Date: Wed, 27 Jul 2022 13:10:39 +0200 Subject: [PATCH 2/2] unit tests in `test_agent.py` and `test_control.py` for new version of `utils.sample()` within `control.sample_action`and `self.sample_action()` method of `Agent`, when action_selection is "stochastic" --- test/test_agent.py | 26 ++++++++++++++++++++++++++ test/test_control.py | 17 +++++++++++++++++ 2 files changed, 43 insertions(+) diff --git a/test/test_agent.py b/test/test_agent.py index a120433a..4d3b94a1 100644 --- a/test/test_agent.py +++ b/test/test_agent.py @@ -469,6 +469,32 @@ def test_agent_with_sampling_mode(self): agent.infer_policies() chosen_action, p_actions = agent._sample_action_test() self.assertEqual(len(p_actions[0]), num_controls[0]) + + def test_agent_with_stochastic_action_unidimensional_control(self): + """ + Test stochastic action sampling in case that one of the control states is one-dimensional, within the agent + method `sample_action()`. + Due to a call to probabilities.squeeze() in an earlier version of utils.sample(), this was throwing an + error due to the inability to use np.random.multinomial on an array with undefined length (an 'unsized' array) + """ + + num_obs = [2] + num_states = [2, 2] + num_controls = [2, 1] + + A = utils.random_A_matrix(num_obs, num_states) + B = utils.random_B_matrix(num_states, num_controls) + + agent = Agent(A=A, B=B, action_selection = "stochastic") + agent.infer_policies() + chosen_action = agent.sample_action() + self.assertEqual(chosen_action[1], 0) + + agent = Agent(A=A, B=B, action_selection = "deterministic") + agent.infer_policies() + chosen_action = agent.sample_action() + self.assertEqual(chosen_action[1], 0) + diff --git a/test/test_control.py b/test/test_control.py index 552fb4be..54423496 100644 --- a/test/test_control.py +++ b/test/test_control.py @@ -1375,6 +1375,23 @@ def test_update_posterior_policies_withE_vector(self): self.assertGreater(q_pi[0], q_pi[1]) self.assertGreater(q_pi[2], q_pi[1]) + + def test_stochastic_action_unidimensional_control(self): + """ + Test stochastic action sampling in case that one of the control states is one-dimensional. + Due to a call to probabilities.squeeze() in an earlier version of utils.sample(), this was throwing an + error due to the inability to use np.random.multinomial on an array with undefined length (an 'unsized' array) + """ + + num_states = [2, 2] + num_controls = [2, 1] + policies = control.construct_policies(num_states, num_controls = num_controls, policy_len=1) + q_pi = utils.norm_dist(np.random.rand(len(policies))) + sampled_action = control.sample_action(q_pi, policies, num_controls, action_selection="stochastic") + self.assertEqual(sampled_action[1], 0) + + sampled_action = control.sample_action(q_pi, policies, num_controls, action_selection="deterministic") + self.assertEqual(sampled_action[1], 0) if __name__ == "__main__":