Skip to content

Commit

Permalink
Update random.py
Browse files Browse the repository at this point in the history
use `lax.nextafter` to get small values.
  • Loading branch information
charlielam0615 committed Jan 4, 2024
1 parent 2f33156 commit a59be4b
Showing 1 changed file with 3 additions and 1 deletion.
4 changes: 3 additions & 1 deletion brainpy/_src/math/random.py
Original file line number Diff line number Diff line change
Expand Up @@ -829,7 +829,9 @@ def truncated_normal(self, lower, upper, size=None, loc=0., scale=1., dtype=floa
# [2l-1, 2u-1].
key = self.split_key() if key is None else _formalize_key(key)
# add a small value to avoid inf values after lax.erf_inv
out = jr.uniform(key, size, dtype, minval=2 * l - 1 + 1e-7, maxval=2 * u - 1 - 1e-7)
eps = lax.nextafter(0., np.array(np.inf, dtype=dtype))
out = jr.uniform(key, size, dtype, minval=2 * l - 1 + eps , maxval=2 * u - 1 - eps)


# Use inverse cdf transform for normal distribution to get truncated
# standard normal
Expand Down

0 comments on commit a59be4b

Please sign in to comment.