diff --git a/pymdp/utils.py b/pymdp/utils.py index b3ca83dc..18a5a69e 100644 --- a/pymdp/utils.py +++ b/pymdp/utils.py @@ -15,7 +15,8 @@ EPS_VAL = 1e-16 # global constant for use in norm_dist() def sample(probabilities): - sample_onehot = np.random.multinomial(1, probabilities.squeeze()) + probabilities = probabilities.squeeze() if len(probabilities) > 1 else probabilities + sample_onehot = np.random.multinomial(1, probabilities) return np.where(sample_onehot == 1)[0][0] def sample_obj_array(arr):