From a59be4b2245fdffe38278cb6047b8745fc0ad3b8 Mon Sep 17 00:00:00 2001 From: charlielam0615 Date: Thu, 4 Jan 2024 16:46:44 +0800 Subject: [PATCH] Update random.py use `lax.nextafter` to get small values. --- brainpy/_src/math/random.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/brainpy/_src/math/random.py b/brainpy/_src/math/random.py index 702e845c..d7b6d434 100644 --- a/brainpy/_src/math/random.py +++ b/brainpy/_src/math/random.py @@ -829,7 +829,9 @@ def truncated_normal(self, lower, upper, size=None, loc=0., scale=1., dtype=floa # [2l-1, 2u-1]. key = self.split_key() if key is None else _formalize_key(key) # add a small value to avoid inf values after lax.erf_inv - out = jr.uniform(key, size, dtype, minval=2 * l - 1 + 1e-7, maxval=2 * u - 1 - 1e-7) + eps = lax.nextafter(0., np.array(np.inf, dtype=dtype)) + out = jr.uniform(key, size, dtype, minval=2 * l - 1 + eps , maxval=2 * u - 1 - eps) + # Use inverse cdf transform for normal distribution to get truncated # standard normal