Skip to content

Commit

Permalink
Merge pull request #122 from infer-actively/norm_dist_fix_121
Browse files Browse the repository at this point in the history
Address #121 by simplifying `utils.norm_dist`
  • Loading branch information
conorheins authored Jun 23, 2023
2 parents b960ef3 + f2ce8f6 commit 213a6ef
Showing 1 changed file with 3 additions and 8 deletions.
11 changes: 3 additions & 8 deletions pymdp/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,16 +216,11 @@ def get_model_dimensions_from_labels(model_labels):

def norm_dist(dist):
""" Normalizes a Categorical probability distribution (or set of them) assuming sufficient statistics are stored in leading dimension"""
if dist.ndim == 3:
new_dist = np.zeros_like(dist)
for c in range(dist.shape[2]):
new_dist[:, :, c] = np.divide(dist[:, :, c], dist[:, :, c].sum(axis=0))
return new_dist
else:
return np.divide(dist, dist.sum(axis=0))
return np.divide(dist, dist.sum(axis=0))

def norm_dist_obj_arr(obj_arr):

""" Normalizes a multi-factor or -modality collection of Categorical probability distributions, assuming sufficient statistics of each conditional distribution
are stored in the leading dimension"""
normed_obj_array = obj_array(len(obj_arr))
for i, arr in enumerate(obj_arr):
normed_obj_array[i] = norm_dist(arr)
Expand Down

0 comments on commit 213a6ef

Please sign in to comment.