Skip to content

Commit

Permalink
Conditionally .squeeze() probability distribution, handles cases of…
Browse files Browse the repository at this point in the history
… 1-D probability distributions
  • Loading branch information
conorheins committed Jul 27, 2022
1 parent 21b0a52 commit 584caf5
Showing 1 changed file with 2 additions and 1 deletion.
3 changes: 2 additions & 1 deletion pymdp/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@
EPS_VAL = 1e-16 # global constant for use in norm_dist()

def sample(probabilities):
sample_onehot = np.random.multinomial(1, probabilities.squeeze())
probabilities = probabilities.squeeze() if len(probabilities) > 1 else probabilities
sample_onehot = np.random.multinomial(1, probabilities)
return np.where(sample_onehot == 1)[0][0]

def sample_obj_array(arr):
Expand Down

0 comments on commit 584caf5

Please sign in to comment.