From 5f1e1dc3568b8826be516e75f58bbdfdd70fa0d4 Mon Sep 17 00:00:00 2001 From: chaoming Date: Wed, 20 Sep 2023 22:23:12 +0800 Subject: [PATCH 1/4] [bug] compatible with `jax>=0.4.16` --- .../_src/math/object_transform/autograd.py | 12 +++--- .../math/op_registers/tests/test_ei_net.py | 41 ++++++++----------- 2 files changed, 21 insertions(+), 32 deletions(-) diff --git a/brainpy/_src/math/object_transform/autograd.py b/brainpy/_src/math/object_transform/autograd.py index 97f26712c..f8164e615 100644 --- a/brainpy/_src/math/object_transform/autograd.py +++ b/brainpy/_src/math/object_transform/autograd.py @@ -915,9 +915,8 @@ def _valid_jaxtype(arg): def _check_output_dtype_revderiv(name, holomorphic, x): aval = core.get_aval(x) - if core.is_opaque_dtype(aval.dtype): - raise TypeError( - f"{name} with output element type {aval.dtype.name}") + if jnp.issubdtype(aval.dtype, dtypes.extended): + raise TypeError(f"{name} with output element type {aval.dtype.name}") if holomorphic: if not dtypes.issubdtype(aval.dtype, np.complexfloating): raise TypeError(f"{name} with holomorphic=True requires outputs with complex dtype, " @@ -938,9 +937,8 @@ def _check_output_dtype_revderiv(name, holomorphic, x): def _check_input_dtype_revderiv(name, holomorphic, allow_int, x): _check_arg(x) aval = core.get_aval(x) - if core.is_opaque_dtype(aval.dtype): - raise TypeError( - f"{name} with input element type {aval.dtype.name}") + if jnp.issubdtype(aval.dtype, dtypes.extended): + raise TypeError(f"{name} with input element type {aval.dtype.name}") if holomorphic: if not dtypes.issubdtype(aval.dtype, np.complexfloating): raise TypeError(f"{name} with holomorphic=True requires inputs with complex dtype, " @@ -972,7 +970,7 @@ def _check_output_dtype_jacfwd(holomorphic, x): def _check_input_dtype_jacfwd(holomorphic: bool, x: Any) -> None: _check_arg(x) aval = core.get_aval(x) - if core.is_opaque_dtype(aval.dtype): + if jnp.issubdtype(aval.dtype, dtypes.extended): raise TypeError(f"jacfwd with input element type {aval.dtype.name}") if holomorphic: if not dtypes.issubdtype(aval.dtype, np.complexfloating): diff --git a/brainpy/_src/math/op_registers/tests/test_ei_net.py b/brainpy/_src/math/op_registers/tests/test_ei_net.py index 24a1a6a6c..817d26de7 100644 --- a/brainpy/_src/math/op_registers/tests/test_ei_net.py +++ b/brainpy/_src/math/op_registers/tests/test_ei_net.py @@ -1,9 +1,6 @@ import brainpy.math as bm import brainpy as bp -from jax.core import ShapedArray - - -bm.set_platform('cpu') +from jax import ShapedArray def abs_eval(events, indices, indptr, *, weight, post_num): @@ -25,7 +22,7 @@ def con_compute(outs, ins): event_sum = bm.XLACustomOp(eval_shape=abs_eval, cpu_func=con_compute) -class ExponentialV2(bp.TwoEndConn): +class ExponentialV2(bp.synapses.TwoEndConn): """Exponential synapse model using customized operator written in C++.""" def __init__(self, pre, post, conn, g_max=1., delay=0., tau=8.0, E=0.): @@ -46,8 +43,8 @@ def __init__(self, pre, post, conn, g_max=1., delay=0., tau=8.0, E=0.): # function self.integral = bp.odeint(lambda g, t: -g / self.tau, method='exp_auto') - def update(self, tdi): - self.g.value = self.integral(self.g, tdi.t, tdi.dt) + def update(self): + self.g.value = self.integral(self.g, bp.share['t']) self.g += event_sum(self.pre.spike, self.pre2post[0], self.pre2post[1], @@ -56,31 +53,25 @@ def update(self, tdi): self.post.input += self.g * (self.E - self.post.V) -class EINet(bp.Network): +class EINet(bp.DynSysGroup): def __init__(self, scale): + super().__init__() # neurons - 
bm.random.seed() pars = dict(V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5., - V_initializer=bp.init.Normal(-55., 2.)) - E = bp.neurons.LIF(int(3200 * scale), **pars, method='exp_auto') - I = bp.neurons.LIF(int(800 * scale), **pars, method='exp_auto') + V_initializer=bp.init.Normal(-55., 2.), method='exp_auto') + self.E = bp.neurons.LIF(int(3200 * scale), **pars) + self.I = bp.neurons.LIF(int(800 * scale), **pars) # synapses - E2E = ExponentialV2(E, E, bp.conn.FixedProb(prob=0.02), E=0., g_max=0.6 / scale, tau=5.) - E2I = ExponentialV2(E, I, bp.conn.FixedProb(prob=0.02), E=0., g_max=0.6 / scale, tau=5.) - I2E = ExponentialV2(I, E, bp.conn.FixedProb(prob=0.02), E=-80., g_max=6.7 / scale, tau=10.) - I2I = ExponentialV2(I, I, bp.conn.FixedProb(prob=0.02), E=-80., g_max=6.7 / scale, tau=10.) - - super(EINet, self).__init__(E2E, E2I, I2E, I2I, E=E, I=I) + self.E2E = ExponentialV2(self.E, self.E, bp.conn.FixedProb(prob=0.02), E=0., g_max=0.6 / scale, tau=5.) + self.E2I = ExponentialV2(self.E, self.I, bp.conn.FixedProb(prob=0.02), E=0., g_max=0.6 / scale, tau=5.) + self.I2E = ExponentialV2(self.I, self.E, bp.conn.FixedProb(prob=0.02), E=-80., g_max=6.7 / scale, tau=10.) + self.I2I = ExponentialV2(self.I, self.I, bp.conn.FixedProb(prob=0.02), E=-80., g_max=6.7 / scale, tau=10.) def test1(): - bm.random.seed() + bm.set_platform('cpu') net2 = EINet(scale=0.1) - runner2 = bp.DSRunner(net2, inputs=[('E.input', 20.), ('I.input', 20.)]) - r = runner2.predict(100., eval_time=True) + runner = bp.DSRunner(net2, inputs=[('E.input', 20.), ('I.input', 20.)]) + r = runner.predict(100., eval_time=True) bm.clear_buffer_memory() - - - - From 1db3cee31e5ce6aba0e7ecb9410060c31aae7b85 Mon Sep 17 00:00:00 2001 From: chaoming Date: Thu, 21 Sep 2023 16:05:33 +0800 Subject: [PATCH 2/4] fix --- brainpy/_src/math/op_registers/tests/test_ei_net.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/brainpy/_src/math/op_registers/tests/test_ei_net.py b/brainpy/_src/math/op_registers/tests/test_ei_net.py index 817d26de7..28d106cb2 100644 --- a/brainpy/_src/math/op_registers/tests/test_ei_net.py +++ b/brainpy/_src/math/op_registers/tests/test_ei_net.py @@ -1,6 +1,6 @@ import brainpy.math as bm import brainpy as bp -from jax import ShapedArray +from jax.core import ShapedArray def abs_eval(events, indices, indptr, *, weight, post_num): From 5aed3bd2342bfd640518c95a287dcb8777a223fd Mon Sep 17 00:00:00 2001 From: chaoming Date: Thu, 21 Sep 2023 16:13:00 +0800 Subject: [PATCH 3/4] update requirements --- requirements-dev.txt | 4 ++-- requirements-doc.txt | 4 ++-- requirements.txt | 2 +- setup.py | 2 +- 4 files changed, 6 insertions(+), 6 deletions(-) diff --git a/requirements-dev.txt b/requirements-dev.txt index 01184540a..49fa49722 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -1,8 +1,8 @@ numpy numba brainpylib -jax>=0.4.1 -jaxlib>=0.4.1 +jax>=0.4.1, <0.4.16 +jaxlib>=0.4.1, <0.4.16 matplotlib>=3.4 msgpack tqdm diff --git a/requirements-doc.txt b/requirements-doc.txt index d88a0c02a..e6e498937 100644 --- a/requirements-doc.txt +++ b/requirements-doc.txt @@ -2,8 +2,8 @@ numpy tqdm msgpack numba -jax>=0.4.1 -jaxlib>=0.4.1 +jax>=0.4.1, <0.4.16 +jaxlib>=0.4.1, <0.4.16 matplotlib>=3.4 scipy>=1.1.0 numba diff --git a/requirements.txt b/requirements.txt index 74db0a68a..ebf85b86e 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,5 @@ numpy -jax>=0.4.1 +jax>=0.4.1, <0.4.16 tqdm msgpack numba \ No newline at end of file diff --git a/setup.py b/setup.py index 343ca3a89..68debcdee 
100644 --- a/setup.py +++ b/setup.py @@ -51,7 +51,7 @@ author_email='chao.brain@qq.com', packages=packages, python_requires='>=3.8', - install_requires=['numpy>=1.15', 'jax>=0.4.1', 'tqdm', 'msgpack', 'numba'], + install_requires=['numpy>=1.15', 'jax>=0.4.1, <0.4.16', 'tqdm', 'msgpack', 'numba'], url='https://github.com/brainpy/BrainPy', project_urls={ "Bug Tracker": "https://github.com/brainpy/BrainPy/issues", From 8da1271701d6af67553ad7b73cf630e646069d09 Mon Sep 17 00:00:00 2001 From: chaoming Date: Thu, 21 Sep 2023 17:34:50 +0800 Subject: [PATCH 4/4] updates --- brainpy/_src/math/object_transform/autograd.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/brainpy/_src/math/object_transform/autograd.py b/brainpy/_src/math/object_transform/autograd.py index f8164e615..5f06b4e67 100644 --- a/brainpy/_src/math/object_transform/autograd.py +++ b/brainpy/_src/math/object_transform/autograd.py @@ -915,8 +915,8 @@ def _valid_jaxtype(arg): def _check_output_dtype_revderiv(name, holomorphic, x): aval = core.get_aval(x) - if jnp.issubdtype(aval.dtype, dtypes.extended): - raise TypeError(f"{name} with output element type {aval.dtype.name}") + # if jnp.issubdtype(aval.dtype, dtypes.extended): + # raise TypeError(f"{name} with output element type {aval.dtype.name}") if holomorphic: if not dtypes.issubdtype(aval.dtype, np.complexfloating): raise TypeError(f"{name} with holomorphic=True requires outputs with complex dtype, " @@ -937,8 +937,8 @@ def _check_output_dtype_revderiv(name, holomorphic, x): def _check_input_dtype_revderiv(name, holomorphic, allow_int, x): _check_arg(x) aval = core.get_aval(x) - if jnp.issubdtype(aval.dtype, dtypes.extended): - raise TypeError(f"{name} with input element type {aval.dtype.name}") + # if jnp.issubdtype(aval.dtype, dtypes.extended): + # raise TypeError(f"{name} with input element type {aval.dtype.name}") if holomorphic: if not dtypes.issubdtype(aval.dtype, np.complexfloating): raise TypeError(f"{name} with holomorphic=True requires inputs with complex dtype, " @@ -970,8 +970,8 @@ def _check_output_dtype_jacfwd(holomorphic, x): def _check_input_dtype_jacfwd(holomorphic: bool, x: Any) -> None: _check_arg(x) aval = core.get_aval(x) - if jnp.issubdtype(aval.dtype, dtypes.extended): - raise TypeError(f"jacfwd with input element type {aval.dtype.name}") + # if jnp.issubdtype(aval.dtype, dtypes.extended): + # raise TypeError(f"jacfwd with input element type {aval.dtype.name}") if holomorphic: if not dtypes.issubdtype(aval.dtype, np.complexfloating): raise TypeError("jacfwd with holomorphic=True requires inputs with complex "
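
---

Reviewer note (not part of the patch series): PATCH 1/4 replaces `core.is_opaque_dtype(aval.dtype)` with `jnp.issubdtype(aval.dtype, dtypes.extended)`, and PATCH 4/4 then comments those checks out entirely while the requirements pin `jax<0.4.16`. If the checks are to be kept while still tolerating both older and newer jax releases, a version-tolerant predicate along the following lines is one option. This is a minimal sketch only; the helper name `_is_extended_dtype` is hypothetical and not part of the BrainPy code base.

    import jax.numpy as jnp
    from jax import core, dtypes

    def _is_extended_dtype(dtype) -> bool:
        """Return True if ``dtype`` is an extended (formerly "opaque") JAX dtype."""
        if hasattr(dtypes, 'extended'):
            # Newer jax releases expose the ``extended`` dtype category.
            return jnp.issubdtype(dtype, dtypes.extended)
        # Older jax releases still provide the opaque-dtype predicate.
        return core.is_opaque_dtype(dtype)

    # Usage inside the checks touched by PATCH 1/4 and PATCH 4/4, e.g.:
    #   if _is_extended_dtype(aval.dtype):
    #       raise TypeError(f"{name} with output element type {aval.dtype.name}")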
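Reviewer note (not part of the patch series): PATCH 1/4 changes the test import to `from jax import ShapedArray` and PATCH 2/4 reverts it to `from jax.core import ShapedArray`, since the public location of `ShapedArray` has moved between jax releases. If the test must import cleanly across several releases, a guarded import is a possible alternative. A minimal sketch, assuming `jax.core.ShapedArray` remains available for the pinned `jax<0.4.16` range:

    try:
        from jax.core import ShapedArray
    except ImportError:
        # Some jax releases re-exported ShapedArray at the package top level.
        from jax import ShapedArray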