diff --git a/pymdp/jax/maths.py b/pymdp/jax/maths.py index ac8173be..dc41144b 100644 --- a/pymdp/jax/maths.py +++ b/pymdp/jax/maths.py @@ -134,9 +134,8 @@ def spm_wnorm(A): Returns Expectation of logarithm of Dirichlet parameters over a set of Categorical distributions, stored in the columns of A. """ - A = jnp.clip(A, min=MINVAL) norm = 1. / A.sum(axis=0) - avg = 1. / A + avg = 1. / (A + MINVAL) wA = norm - avg return wA