diff --git a/brainpy/_src/math/random.py b/brainpy/_src/math/random.py index 702e845c..d7b6d434 100644 --- a/brainpy/_src/math/random.py +++ b/brainpy/_src/math/random.py @@ -829,7 +829,9 @@ def truncated_normal(self, lower, upper, size=None, loc=0., scale=1., dtype=floa # [2l-1, 2u-1]. key = self.split_key() if key is None else _formalize_key(key) # add a small value to avoid inf values after lax.erf_inv - out = jr.uniform(key, size, dtype, minval=2 * l - 1 + 1e-7, maxval=2 * u - 1 - 1e-7) + eps = lax.nextafter(0., np.array(np.inf, dtype=dtype)) + out = jr.uniform(key, size, dtype, minval=2 * l - 1 + eps , maxval=2 * u - 1 - eps) + # Use inverse cdf transform for normal distribution to get truncated # standard normal