diff --git a/pymdp/control.py b/pymdp/control.py
index 307e72de..80bfd16b 100644
--- a/pymdp/control.py
+++ b/pymdp/control.py
@@ -492,9 +492,9 @@ def sample_action(q_pi, policies, num_controls, action_selection="deterministic"
 
     # weight each action according to its integrated posterior probability over policies and timesteps
     for pol_idx, policy in enumerate(policies):
-        for t in range(policy.shape[0]):
-            for factor_i, action_i in enumerate(policy[t, :]):
-                action_marginals[factor_i][action_i] += q_pi[pol_idx]
+        for factor_i, action_i in enumerate(policy[0, :]):
+            # to get the marginals we just want to add up the actions at time 0
+            action_marginals[factor_i][action_i] += q_pi[pol_idx]
 
     selected_policy = np.zeros(num_factors)
     for factor_i in range(num_factors):
diff --git a/pymdp/utils.py b/pymdp/utils.py
index 7d823082..1a0b9ea0 100644
--- a/pymdp/utils.py
+++ b/pymdp/utils.py
@@ -13,7 +13,7 @@
 import itertools
 
 def sample(probabilities):
-    sample_onehot = np.random.multinomial(1, probabilities.squeeze())
+    sample_onehot = np.random.multinomial(1, probabilities)
     return np.where(sample_onehot == 1)[0][0]
 
 def sample_obj_array(arr):
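
Note (not part of the patch): a minimal sketch of what the revised marginalisation does, assuming each policy is an array of shape (num_timesteps, num_factors) and q_pi is a normalised posterior over policies; the toy values below are hypothetical and purely illustrative. Only the action each policy selects at t = 0 now contributes to the per-factor action marginals.

    import numpy as np

    # two toy policies over 2 timesteps and 2 control factors (assumed shapes)
    policies = [np.array([[0, 1], [1, 0]]),
                np.array([[1, 1], [0, 0]])]
    q_pi = np.array([0.75, 0.25])   # posterior probability of each policy
    num_controls = [2, 2]           # number of allowable actions per control factor

    action_marginals = [np.zeros(n) for n in num_controls]
    for pol_idx, policy in enumerate(policies):
        # as in the patched loop: only the action chosen at t = 0 is counted
        for factor_i, action_i in enumerate(policy[0, :]):
            action_marginals[factor_i][action_i] += q_pi[pol_idx]

    print(action_marginals)   # factor 0 -> [0.75, 0.25], factor 1 -> [0.0, 1.0]

After normalising, each marginal could be passed to the patched sample(), which now expects a plain 1-D probability vector rather than squeezing its input.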