Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[bug] fix compatible bug #508

Merged
merged 4 commits into from
Sep 21, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 6 additions & 8 deletions brainpy/_src/math/object_transform/autograd.py
Original file line number Diff line number Diff line change
Expand Up @@ -915,9 +915,8 @@ def _valid_jaxtype(arg):

def _check_output_dtype_revderiv(name, holomorphic, x):
aval = core.get_aval(x)
if core.is_opaque_dtype(aval.dtype):
raise TypeError(
f"{name} with output element type {aval.dtype.name}")
# if jnp.issubdtype(aval.dtype, dtypes.extended):
# raise TypeError(f"{name} with output element type {aval.dtype.name}")
if holomorphic:
if not dtypes.issubdtype(aval.dtype, np.complexfloating):
raise TypeError(f"{name} with holomorphic=True requires outputs with complex dtype, "
Expand All @@ -938,9 +937,8 @@ def _check_output_dtype_revderiv(name, holomorphic, x):
def _check_input_dtype_revderiv(name, holomorphic, allow_int, x):
_check_arg(x)
aval = core.get_aval(x)
if core.is_opaque_dtype(aval.dtype):
raise TypeError(
f"{name} with input element type {aval.dtype.name}")
# if jnp.issubdtype(aval.dtype, dtypes.extended):
# raise TypeError(f"{name} with input element type {aval.dtype.name}")
if holomorphic:
if not dtypes.issubdtype(aval.dtype, np.complexfloating):
raise TypeError(f"{name} with holomorphic=True requires inputs with complex dtype, "
Expand Down Expand Up @@ -972,8 +970,8 @@ def _check_output_dtype_jacfwd(holomorphic, x):
def _check_input_dtype_jacfwd(holomorphic: bool, x: Any) -> None:
_check_arg(x)
aval = core.get_aval(x)
if core.is_opaque_dtype(aval.dtype):
raise TypeError(f"jacfwd with input element type {aval.dtype.name}")
# if jnp.issubdtype(aval.dtype, dtypes.extended):
# raise TypeError(f"jacfwd with input element type {aval.dtype.name}")
if holomorphic:
if not dtypes.issubdtype(aval.dtype, np.complexfloating):
raise TypeError("jacfwd with holomorphic=True requires inputs with complex "
Expand Down
39 changes: 15 additions & 24 deletions brainpy/_src/math/op_registers/tests/test_ei_net.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,6 @@
from jax.core import ShapedArray


bm.set_platform('cpu')


def abs_eval(events, indices, indptr, *, weight, post_num):
return [ShapedArray((post_num,), bm.float32), ]

Expand All @@ -25,7 +22,7 @@ def con_compute(outs, ins):
event_sum = bm.XLACustomOp(eval_shape=abs_eval, cpu_func=con_compute)


class ExponentialV2(bp.TwoEndConn):
class ExponentialV2(bp.synapses.TwoEndConn):
"""Exponential synapse model using customized operator written in C++."""

def __init__(self, pre, post, conn, g_max=1., delay=0., tau=8.0, E=0.):
Expand All @@ -46,8 +43,8 @@ def __init__(self, pre, post, conn, g_max=1., delay=0., tau=8.0, E=0.):
# function
self.integral = bp.odeint(lambda g, t: -g / self.tau, method='exp_auto')

def update(self, tdi):
self.g.value = self.integral(self.g, tdi.t, tdi.dt)
def update(self):
self.g.value = self.integral(self.g, bp.share['t'])
self.g += event_sum(self.pre.spike,
self.pre2post[0],
self.pre2post[1],
Expand All @@ -56,31 +53,25 @@ def update(self, tdi):
self.post.input += self.g * (self.E - self.post.V)


class EINet(bp.Network):
class EINet(bp.DynSysGroup):
def __init__(self, scale):
super().__init__()
# neurons
bm.random.seed()
pars = dict(V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5.,
V_initializer=bp.init.Normal(-55., 2.))
E = bp.neurons.LIF(int(3200 * scale), **pars, method='exp_auto')
I = bp.neurons.LIF(int(800 * scale), **pars, method='exp_auto')
V_initializer=bp.init.Normal(-55., 2.), method='exp_auto')
self.E = bp.neurons.LIF(int(3200 * scale), **pars)
self.I = bp.neurons.LIF(int(800 * scale), **pars)

# synapses
E2E = ExponentialV2(E, E, bp.conn.FixedProb(prob=0.02), E=0., g_max=0.6 / scale, tau=5.)
E2I = ExponentialV2(E, I, bp.conn.FixedProb(prob=0.02), E=0., g_max=0.6 / scale, tau=5.)
I2E = ExponentialV2(I, E, bp.conn.FixedProb(prob=0.02), E=-80., g_max=6.7 / scale, tau=10.)
I2I = ExponentialV2(I, I, bp.conn.FixedProb(prob=0.02), E=-80., g_max=6.7 / scale, tau=10.)

super(EINet, self).__init__(E2E, E2I, I2E, I2I, E=E, I=I)
self.E2E = ExponentialV2(self.E, self.E, bp.conn.FixedProb(prob=0.02), E=0., g_max=0.6 / scale, tau=5.)
self.E2I = ExponentialV2(self.E, self.I, bp.conn.FixedProb(prob=0.02), E=0., g_max=0.6 / scale, tau=5.)
self.I2E = ExponentialV2(self.I, self.E, bp.conn.FixedProb(prob=0.02), E=-80., g_max=6.7 / scale, tau=10.)
self.I2I = ExponentialV2(self.I, self.I, bp.conn.FixedProb(prob=0.02), E=-80., g_max=6.7 / scale, tau=10.)


def test1():
bm.random.seed()
bm.set_platform('cpu')
net2 = EINet(scale=0.1)
runner2 = bp.DSRunner(net2, inputs=[('E.input', 20.), ('I.input', 20.)])
r = runner2.predict(100., eval_time=True)
runner = bp.DSRunner(net2, inputs=[('E.input', 20.), ('I.input', 20.)])
r = runner.predict(100., eval_time=True)
bm.clear_buffer_memory()




4 changes: 2 additions & 2 deletions requirements-dev.txt
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
numpy
numba
brainpylib
jax>=0.4.1
jaxlib>=0.4.1
jax>=0.4.1, <0.4.16
jaxlib>=0.4.1, <0.4.16
matplotlib>=3.4
msgpack
tqdm
Expand Down
4 changes: 2 additions & 2 deletions requirements-doc.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@ numpy
tqdm
msgpack
numba
jax>=0.4.1
jaxlib>=0.4.1
jax>=0.4.1, <0.4.16
jaxlib>=0.4.1, <0.4.16
matplotlib>=3.4
scipy>=1.1.0
numba
Expand Down
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
numpy
jax>=0.4.1
jax>=0.4.1, <0.4.16
tqdm
msgpack
numba
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@
author_email='[email protected]',
packages=packages,
python_requires='>=3.8',
install_requires=['numpy>=1.15', 'jax>=0.4.1', 'tqdm', 'msgpack', 'numba'],
install_requires=['numpy>=1.15', 'jax>=0.4.1, <0.4.16', 'tqdm', 'msgpack', 'numba'],
url='https://github.com/brainpy/BrainPy',
project_urls={
"Bug Tracker": "https://github.com/brainpy/BrainPy/issues",
Expand Down