From 1f485449ce01f8d6a0ac50de5f5d3b683218e91f Mon Sep 17 00:00:00 2001 From: Chaoming Wang Date: Sun, 3 Mar 2024 00:04:04 +0800 Subject: [PATCH] fix random dtype inconsistency --- brainpy/_src/math/__init__.py | 1 - brainpy/_src/math/random.py | 39 +++++++++++++++++++++-------------- brainpy/math/__init__.py | 7 ------- 3 files changed, 24 insertions(+), 23 deletions(-) diff --git a/brainpy/_src/math/__init__.py b/brainpy/_src/math/__init__.py index 3102bc1d0..de559de56 100644 --- a/brainpy/_src/math/__init__.py +++ b/brainpy/_src/math/__init__.py @@ -49,7 +49,6 @@ # operators from .op_register import * from .pre_syn_post import * -from .surrogate._compt import * from . import surrogate, event, sparse, jitconn # Variable and Objects for object-oriented JAX transformations diff --git a/brainpy/_src/math/random.py b/brainpy/_src/math/random.py index d0f74bf23..9ae012bc4 100644 --- a/brainpy/_src/math/random.py +++ b/brainpy/_src/math/random.py @@ -1232,9 +1232,10 @@ def zipf(self, a, size: Optional[Union[int, Sequence[int]]] = None, key: Optiona a = _check_py_seq(_as_jax_array(a)) if size is None: size = jnp.shape(a) - r = call(lambda x: np.random.zipf(x, size), + dtype = jax.dtypes.canonicalize_dtype(jnp.int_) + r = call(lambda x: np.random.zipf(x, size).astype(dtype), a, - result_shape=jax.ShapeDtypeStruct(size, jnp.int_)) + result_shape=jax.ShapeDtypeStruct(size, dtype)) return _return(r) def power(self, a, size: Optional[Union[int, Sequence[int]]] = None, key: Optional[Union[int, JAX_RAND_KEY]] = None): @@ -1242,8 +1243,10 @@ def power(self, a, size: Optional[Union[int, Sequence[int]]] = None, key: Option if size is None: size = jnp.shape(a) size = _size2shape(size) - r = call(lambda a: np.random.power(a=a, size=size), - a, result_shape=jax.ShapeDtypeStruct(size, jnp.float_)) + dtype = jax.dtypes.canonicalize_dtype(jnp.float_) + r = call(lambda a: np.random.power(a=a, size=size).astype(dtype), + a, + result_shape=jax.ShapeDtypeStruct(size, dtype)) return _return(r) def f(self, dfnum, dfden, size: Optional[Union[int, Sequence[int]]] = None, @@ -1256,11 +1259,12 @@ def f(self, dfnum, dfden, size: Optional[Union[int, Sequence[int]]] = None, size = jnp.broadcast_shapes(jnp.shape(dfnum), jnp.shape(dfden)) size = _size2shape(size) d = {'dfnum': dfnum, 'dfden': dfden} + dtype = jax.dtypes.canonicalize_dtype(jnp.float_) r = call(lambda x: np.random.f(dfnum=x['dfnum'], dfden=x['dfden'], - size=size), + size=size).astype(dtype), d, - result_shape=jax.ShapeDtypeStruct(size, jnp.float_)) + result_shape=jax.ShapeDtypeStruct(size, dtype)) return _return(r) def hypergeometric(self, ngood, nbad, nsample, size: Optional[Union[int, Sequence[int]]] = None, @@ -1274,12 +1278,14 @@ def hypergeometric(self, ngood, nbad, nsample, size: Optional[Union[int, Sequenc jnp.shape(nbad), jnp.shape(nsample)) size = _size2shape(size) + dtype = jax.dtypes.canonicalize_dtype(jnp.int_) d = {'ngood': ngood, 'nbad': nbad, 'nsample': nsample} - r = call(lambda x: np.random.hypergeometric(ngood=x['ngood'], - nbad=x['nbad'], - nsample=x['nsample'], - size=size), - d, result_shape=jax.ShapeDtypeStruct(size, jnp.int_)) + r = call(lambda d: np.random.hypergeometric(ngood=d['ngood'], + nbad=d['nbad'], + nsample=d['nsample'], + size=size).astype(dtype), + d, + result_shape=jax.ShapeDtypeStruct(size, dtype)) return _return(r) def logseries(self, p, size: Optional[Union[int, Sequence[int]]] = None, @@ -1288,8 +1294,10 @@ def logseries(self, p, size: Optional[Union[int, Sequence[int]]] = None, if size is None: size = jnp.shape(p) size = _size2shape(size) - r = call(lambda p: np.random.logseries(p=p, size=size), - p, result_shape=jax.ShapeDtypeStruct(size, jnp.int_)) + dtype = jax.dtypes.canonicalize_dtype(jnp.int_) + r = call(lambda p: np.random.logseries(p=p, size=size).astype(dtype), + p, + result_shape=jax.ShapeDtypeStruct(size, dtype)) return _return(r) def noncentral_f(self, dfnum, dfden, nonc, size: Optional[Union[int, Sequence[int]]] = None, @@ -1303,11 +1311,12 @@ def noncentral_f(self, dfnum, dfden, nonc, size: Optional[Union[int, Sequence[in jnp.shape(nonc)) size = _size2shape(size) d = {'dfnum': dfnum, 'dfden': dfden, 'nonc': nonc} + dtype = jax.dtypes.canonicalize_dtype(jnp.float_) r = call(lambda x: np.random.noncentral_f(dfnum=x['dfnum'], dfden=x['dfden'], nonc=x['nonc'], - size=size), - d, result_shape=jax.ShapeDtypeStruct(size, jnp.float_)) + size=size).astype(dtype), + d, result_shape=jax.ShapeDtypeStruct(size, dtype)) return _return(r) # PyTorch compatibility # diff --git a/brainpy/math/__init__.py b/brainpy/math/__init__.py index 9a64f9f25..08a070f02 100644 --- a/brainpy/math/__init__.py +++ b/brainpy/math/__init__.py @@ -44,13 +44,6 @@ del jnp, config -from brainpy._src.math.surrogate._compt import ( - spike_with_sigmoid_grad as spike_with_sigmoid_grad, - spike_with_linear_grad as spike_with_linear_grad, - spike_with_gaussian_grad as spike_with_gaussian_grad, - spike_with_mg_grad as spike_with_mg_grad, -) - from brainpy._src.math import defaults from brainpy._src.deprecations import deprecation_getattr from brainpy._src.dependency_check import import_taichi, import_numba