Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Clean taichi AOT caches #643

Merged
merged 8 commits into from
Mar 3, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion brainpy/_src/_delay.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,7 @@ def register_entry(
delay_type = 'homo'
else:
delay_type = 'heter'
delay_step = bm.Array(delay_step)
delay_step = delay_step
elif callable(delay_step):
delay_step = delay_step(self.delay_target_shape)
delay_type = 'heter'
Expand Down
214 changes: 108 additions & 106 deletions brainpy/_src/connect/tests/test_random_conn_visualize.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,176 +2,178 @@

import pytest

pytest.skip('skip', allow_module_level=True)

import brainpy as bp


def test_random_fix_pre1():
for num in [0.4, 20]:
conn1 = bp.connect.FixedPreNum(num, seed=1234)(pre_size=(10, 15), post_size=(10, 20))
mat1 = conn1.require(bp.connect.CONN_MAT)
for num in [0.4, 20]:
conn1 = bp.connect.FixedPreNum(num, seed=1234)(pre_size=(10, 15), post_size=(10, 20))
mat1 = conn1.require(bp.connect.CONN_MAT)

conn2 = bp.connect.FixedPreNum(num, seed=1234)(pre_size=(10, 15), post_size=(10, 20))
mat2 = conn2.require(bp.connect.CONN_MAT)
conn2 = bp.connect.FixedPreNum(num, seed=1234)(pre_size=(10, 15), post_size=(10, 20))
mat2 = conn2.require(bp.connect.CONN_MAT)

print()
print(f'num = {num}')
print('conn_mat 1\n', mat1)
print(mat1.sum())
print('conn_mat 2\n', mat2)
print(mat2.sum())
print()
print(f'num = {num}')
print('conn_mat 1\n', mat1)
print(mat1.sum())
print('conn_mat 2\n', mat2)
print(mat2.sum())

assert bp.math.array_equal(mat1, mat2)
bp.connect.visualizeMat(mat1, 'FixedPreNum: num=%s pre_size=(10, 15), post_size=(10, 20)' % num)
assert bp.math.array_equal(mat1, mat2)
bp.connect.visualizeMat(mat1, 'FixedPreNum: num=%s pre_size=(10, 15), post_size=(10, 20)' % num)


def test_random_fix_pre2():
for num in [0.5, 3]:
conn1 = bp.connect.FixedPreNum(num, seed=1234)(pre_size=5, post_size=4)
mat1 = conn1.require(bp.connect.CONN_MAT)
print()
print(mat1)
for num in [0.5, 3]:
conn1 = bp.connect.FixedPreNum(num, seed=1234)(pre_size=5, post_size=4)
mat1 = conn1.require(bp.connect.CONN_MAT)
print()
print(mat1)

bp.connect.visualizeMat(mat1, 'FixedPreNum: num=%s pre_size=5, post_size=4' % num)
bp.connect.visualizeMat(mat1, 'FixedPreNum: num=%s pre_size=5, post_size=4' % num)


def test_random_fix_pre3():
with pytest.raises(bp.errors.ConnectorError):
conn1 = bp.connect.FixedPreNum(num=6, seed=1234)(pre_size=3, post_size=4)
conn1.require(bp.connect.CONN_MAT)
with pytest.raises(bp.errors.ConnectorError):
conn1 = bp.connect.FixedPreNum(num=6, seed=1234)(pre_size=3, post_size=4)
conn1.require(bp.connect.CONN_MAT)

bp.connect.visualizeMat(conn1, 'FixedPreNum: num=6, pre_size=3, post_size=4')
bp.connect.visualizeMat(conn1, 'FixedPreNum: num=6, pre_size=3, post_size=4')


def test_random_fix_post1():
for num in [0.4, 20]:
conn1 = bp.connect.FixedPostNum(num, seed=1234)(pre_size=(10, 15), post_size=(10, 20))
mat1 = conn1.require(bp.connect.CONN_MAT)
for num in [0.4, 20]:
conn1 = bp.connect.FixedPostNum(num, seed=1234)(pre_size=(10, 15), post_size=(10, 20))
mat1 = conn1.require(bp.connect.CONN_MAT)

conn2 = bp.connect.FixedPostNum(num, seed=1234)(pre_size=(10, 15), post_size=(10, 20))
mat2 = conn2.require(bp.connect.CONN_MAT)
conn2 = bp.connect.FixedPostNum(num, seed=1234)(pre_size=(10, 15), post_size=(10, 20))
mat2 = conn2.require(bp.connect.CONN_MAT)

print()
print('conn_mat 1\n', mat1)
print('conn_mat 2\n', mat2)
print()
print('conn_mat 1\n', mat1)
print('conn_mat 2\n', mat2)

assert bp.math.array_equal(mat1, mat2)
bp.connect.visualizeMat(mat1, 'FixedPostNum: num=%s pre_size=(10, 15), post_size=(10, 20)' % num)
assert bp.math.array_equal(mat1, mat2)
bp.connect.visualizeMat(mat1, 'FixedPostNum: num=%s pre_size=(10, 15), post_size=(10, 20)' % num)


def test_random_fix_post2():
for num in [0.5, 3]:
conn1 = bp.connect.FixedPostNum(num, seed=1234)(pre_size=5, post_size=4)
mat1 = conn1.require(bp.connect.CONN_MAT)
print(mat1)
bp.connect.visualizeMat(mat1, 'FixedPostNum: num=%s pre_size=5, post_size=4' % num)
for num in [0.5, 3]:
conn1 = bp.connect.FixedPostNum(num, seed=1234)(pre_size=5, post_size=4)
mat1 = conn1.require(bp.connect.CONN_MAT)
print(mat1)
bp.connect.visualizeMat(mat1, 'FixedPostNum: num=%s pre_size=5, post_size=4' % num)


def test_random_fix_post3():
with pytest.raises(bp.errors.ConnectorError):
conn1 = bp.connect.FixedPostNum(num=6, seed=1234)(pre_size=3, post_size=4)
conn1.require(bp.connect.CONN_MAT)
bp.connect.visualizeMat(conn1, 'FixedPostNum: num=6, pre_size=3, post_size=4')
with pytest.raises(bp.errors.ConnectorError):
conn1 = bp.connect.FixedPostNum(num=6, seed=1234)(pre_size=3, post_size=4)
conn1.require(bp.connect.CONN_MAT)
bp.connect.visualizeMat(conn1, 'FixedPostNum: num=6, pre_size=3, post_size=4')


def test_gaussian_prob1():
conn = bp.connect.GaussianProb(sigma=1., include_self=False)(pre_size=100)
mat = conn.require(bp.connect.CONN_MAT)
conn = bp.connect.GaussianProb(sigma=1., include_self=False)(pre_size=100)
mat = conn.require(bp.connect.CONN_MAT)

print()
print('conn_mat', mat)
bp.connect.visualizeMat(mat, 'GaussianProb: sigma=1., include_self=False, pre_size=100')
print()
print('conn_mat', mat)
bp.connect.visualizeMat(mat, 'GaussianProb: sigma=1., include_self=False, pre_size=100')


def test_gaussian_prob2():
conn = bp.connect.GaussianProb(sigma=4)(pre_size=(50, 50))
mat = conn.require(bp.connect.CONN_MAT)
conn = bp.connect.GaussianProb(sigma=4)(pre_size=(50, 50))
mat = conn.require(bp.connect.CONN_MAT)

print()
print('conn_mat', mat)
bp.connect.visualizeMat(mat, 'GaussianProb: sigma=4, pre_size=(50, 50)')
print()
print('conn_mat', mat)
bp.connect.visualizeMat(mat, 'GaussianProb: sigma=4, pre_size=(50, 50)')


def test_gaussian_prob3():
conn = bp.connect.GaussianProb(sigma=4, periodic_boundary=True)(pre_size=(50, 50))
mat = conn.require(bp.connect.CONN_MAT)
conn = bp.connect.GaussianProb(sigma=4, periodic_boundary=True)(pre_size=(50, 50))
mat = conn.require(bp.connect.CONN_MAT)

print()
print('conn_mat', mat)
bp.connect.visualizeMat(mat, 'GaussianProb: sigma=4, periodic_boundary=True, pre_size=(50, 50)')
print()
print('conn_mat', mat)
bp.connect.visualizeMat(mat, 'GaussianProb: sigma=4, periodic_boundary=True, pre_size=(50, 50)')


def test_gaussian_prob4():
conn = bp.connect.GaussianProb(sigma=4, periodic_boundary=True)(pre_size=(10, 10, 10))
conn.require(bp.connect.CONN_MAT,
bp.connect.PRE_IDS, bp.connect.POST_IDS,
bp.connect.PRE2POST, bp.connect.POST_IDS)
mat = conn.require(bp.connect.CONN_MAT)
bp.connect.visualizeMat(mat, 'GaussianProb: sigma=4, periodic_boundary=True, pre_size=(10, 10, 10)')
conn = bp.connect.GaussianProb(sigma=4, periodic_boundary=True)(pre_size=(10, 10, 10))
conn.require(bp.connect.CONN_MAT,
bp.connect.PRE_IDS, bp.connect.POST_IDS,
bp.connect.PRE2POST, bp.connect.POST_IDS)
mat = conn.require(bp.connect.CONN_MAT)
bp.connect.visualizeMat(mat, 'GaussianProb: sigma=4, periodic_boundary=True, pre_size=(10, 10, 10)')


def test_SmallWorld1():
conn = bp.connect.SmallWorld(num_neighbor=2, prob=0.5, include_self=False)
conn(pre_size=10, post_size=10)
conn = bp.connect.SmallWorld(num_neighbor=2, prob=0.5, include_self=False)
conn(pre_size=10, post_size=10)

mat = conn.require(bp.connect.CONN_MAT)
mat = conn.require(bp.connect.CONN_MAT)

print('conn_mat', mat)
bp.connect.visualizeMat(mat, 'SmallWorld: num_neighbor=2, prob=0.5, include_self=False, pre_size=10, post_size=10')
print('conn_mat', mat)
bp.connect.visualizeMat(mat, 'SmallWorld: num_neighbor=2, prob=0.5, include_self=False, pre_size=10, post_size=10')


def test_SmallWorld3():
conn = bp.connect.SmallWorld(num_neighbor=2, prob=0.5, include_self=True)
conn(pre_size=20, post_size=20)
conn = bp.connect.SmallWorld(num_neighbor=2, prob=0.5, include_self=True)
conn(pre_size=20, post_size=20)

mat = conn.require(bp.connect.CONN_MAT)
mat = conn.require(bp.connect.CONN_MAT)

print('conn_mat', mat)
print('conn_mat', mat)

bp.connect.visualizeMat(mat, 'SmallWorld: num_neighbor=2, prob=0.5, include_self=True, pre_size=20, post_size=20')
bp.connect.visualizeMat(mat, 'SmallWorld: num_neighbor=2, prob=0.5, include_self=True, pre_size=20, post_size=20')


def test_SmallWorld2():
conn = bp.connect.SmallWorld(num_neighbor=2, prob=0.5)
conn(pre_size=(100,), post_size=(100,))
conn = bp.connect.SmallWorld(num_neighbor=2, prob=0.5)
conn(pre_size=(100,), post_size=(100,))
mat, _, _, _, _ = conn.require(bp.connect.CONN_MAT,
bp.connect.PRE_IDS, bp.connect.POST_IDS,
bp.connect.PRE2POST, bp.connect.POST_IDS)
print()
print('conn_mat', mat)
bp.connect.visualizeMat(mat, 'SmallWorld: num_neighbor=2, prob=0.5, pre_size=(100,), post_size=(100,)')


def test_ScaleFreeBA():
conn = bp.connect.ScaleFreeBA(m=2)
for size in [100, (10, 20), (2, 10, 20)]:
conn(pre_size=size, post_size=size)
mat, _, _, _, _ = conn.require(bp.connect.CONN_MAT,
bp.connect.PRE_IDS, bp.connect.POST_IDS,
bp.connect.PRE2POST, bp.connect.POST_IDS)
print()
print('conn_mat', mat)
bp.connect.visualizeMat(mat, 'SmallWorld: num_neighbor=2, prob=0.5, pre_size=(100,), post_size=(100,)')


def test_ScaleFreeBA():
conn = bp.connect.ScaleFreeBA(m=2)
for size in [100, (10, 20), (2, 10, 20)]:
conn(pre_size=size, post_size=size)
mat, _, _, _, _ = conn.require(bp.connect.CONN_MAT,
bp.connect.PRE_IDS, bp.connect.POST_IDS,
bp.connect.PRE2POST, bp.connect.POST_IDS)
print()
print('conn_mat', mat)
bp.connect.visualizeMat(mat, 'ScaleFreeBA: m=2, pre_size=%s, post_size=%s' % (size, size))
bp.connect.visualizeMat(mat, 'ScaleFreeBA: m=2, pre_size=%s, post_size=%s' % (size, size))


def test_ScaleFreeBADual():
conn = bp.connect.ScaleFreeBADual(m1=2, m2=3, p=0.4)
for size in [100, (10, 20), (2, 10, 20)]:
conn(pre_size=size, post_size=size)
mat, _, _, _, _ = conn.require(bp.connect.CONN_MAT,
bp.connect.PRE_IDS, bp.connect.POST_IDS,
bp.connect.PRE2POST, bp.connect.POST_IDS)
print()
print('conn_mat', mat)
bp.connect.visualizeMat(mat, 'ScaleFreeBADual: m1=2, m2=3, p=0.4, pre_size=%s, post_size=%s' % (size, size))
conn = bp.connect.ScaleFreeBADual(m1=2, m2=3, p=0.4)
for size in [100, (10, 20), (2, 10, 20)]:
conn(pre_size=size, post_size=size)
mat, _, _, _, _ = conn.require(bp.connect.CONN_MAT,
bp.connect.PRE_IDS, bp.connect.POST_IDS,
bp.connect.PRE2POST, bp.connect.POST_IDS)
print()
print('conn_mat', mat)
bp.connect.visualizeMat(mat, 'ScaleFreeBADual: m1=2, m2=3, p=0.4, pre_size=%s, post_size=%s' % (size, size))


def test_PowerLaw():
conn = bp.connect.PowerLaw(m=3, p=0.4)
for size in [100, (10, 20), (2, 10, 20)]:
conn(pre_size=size, post_size=size)
mat, _, _, _, _ = conn.require(bp.connect.CONN_MAT,
bp.connect.PRE_IDS, bp.connect.POST_IDS,
bp.connect.PRE2POST, bp.connect.POST_IDS)
print()
print('conn_mat', mat)
bp.connect.visualizeMat(mat, 'PowerLaw: m=3, p=0.4, pre_size=%s, post_size=%s' % (size, size))
conn = bp.connect.PowerLaw(m=3, p=0.4)
for size in [100, (10, 20), (2, 10, 20)]:
conn(pre_size=size, post_size=size)
mat, _, _, _, _ = conn.require(bp.connect.CONN_MAT,
bp.connect.PRE_IDS, bp.connect.POST_IDS,
bp.connect.PRE2POST, bp.connect.POST_IDS)
print()
print('conn_mat', mat)
bp.connect.visualizeMat(mat, 'PowerLaw: m=3, p=0.4, pre_size=%s, post_size=%s' % (size, size))
1 change: 0 additions & 1 deletion brainpy/_src/math/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,6 @@
# operators
from .op_register import *
from .pre_syn_post import *
from .surrogate._compt import *
from . import surrogate, event, sparse, jitconn

# Variable and Objects for object-oriented JAX transformations
Expand Down
28 changes: 16 additions & 12 deletions brainpy/_src/math/compat_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,10 @@
_max = max


def _return(a):
return Array(a)


def fill_diagonal(a, val, inplace=True):
if a.ndim < 2:
raise ValueError(f'Only support tensor has dimension >= 2, but got {a.shape}')
Expand All @@ -120,30 +124,30 @@ def fill_diagonal(a, val, inplace=True):


def zeros(shape, dtype=None):
return Array(jnp.zeros(shape, dtype=dtype))
return _return(jnp.zeros(shape, dtype=dtype))


def ones(shape, dtype=None):
return Array(jnp.ones(shape, dtype=dtype))
return _return(jnp.ones(shape, dtype=dtype))


def empty(shape, dtype=None):
return Array(jnp.zeros(shape, dtype=dtype))
return _return(jnp.zeros(shape, dtype=dtype))


def zeros_like(a, dtype=None, shape=None):
a = _as_jax_array_(a)
return Array(jnp.zeros_like(a, dtype=dtype, shape=shape))
return _return(jnp.zeros_like(a, dtype=dtype, shape=shape))


def ones_like(a, dtype=None, shape=None):
a = _as_jax_array_(a)
return Array(jnp.ones_like(a, dtype=dtype, shape=shape))
return _return(jnp.ones_like(a, dtype=dtype, shape=shape))


def empty_like(a, dtype=None, shape=None):
a = _as_jax_array_(a)
return Array(jnp.zeros_like(a, dtype=dtype, shape=shape))
return _return(jnp.zeros_like(a, dtype=dtype, shape=shape))


def array(a, dtype=None, copy=True, order="K", ndmin=0) -> Array:
Expand All @@ -155,7 +159,7 @@ def array(a, dtype=None, copy=True, order="K", ndmin=0) -> Array:
leaves = [_as_jax_array_(l) for l in leaves]
a = tree_unflatten(tree, leaves)
res = jnp.array(a, dtype=dtype, copy=copy, order=order, ndmin=ndmin)
return Array(res)
return _return(res)


def asarray(a, dtype=None, order=None):
Expand All @@ -167,29 +171,29 @@ def asarray(a, dtype=None, order=None):
leaves = [_as_jax_array_(l) for l in leaves]
arrays = tree_unflatten(tree, leaves)
res = jnp.asarray(a=arrays, dtype=dtype, order=order)
return Array(res)
return _return(res)


def arange(*args, **kwargs):
args = [_as_jax_array_(a) for a in args]
kwargs = {k: _as_jax_array_(v) for k, v in kwargs.items()}
return Array(jnp.arange(*args, **kwargs))
return _return(jnp.arange(*args, **kwargs))


def linspace(*args, **kwargs):
args = [_as_jax_array_(a) for a in args]
kwargs = {k: _as_jax_array_(v) for k, v in kwargs.items()}
res = jnp.linspace(*args, **kwargs)
if isinstance(res, tuple):
return Array(res[0]), res[1]
return _return(res[0]), res[1]
else:
return Array(res)
return _return(res)


def logspace(*args, **kwargs):
args = [_as_jax_array_(a) for a in args]
kwargs = {k: _as_jax_array_(v) for k, v in kwargs.items()}
return Array(jnp.logspace(*args, **kwargs))
return _return(jnp.logspace(*args, **kwargs))


def asanyarray(a, dtype=None, order=None):
Expand Down
Loading
Loading