Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[math] Replace math operators with braintaichi #698

Merged
merged 12 commits into from
Nov 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 28 additions & 1 deletion brainpy/_src/dependency_check.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
__all__ = [
'import_taichi',
'raise_taichi_not_found',
'import_braintaichi',
'raise_braintaichi_not_found',
'import_numba',
'raise_numba_not_found',
'import_cupy',
Expand All @@ -16,10 +18,11 @@
]

_minimal_brainpylib_version = '0.2.6'
_minimal_taichi_version = (1, 7, 0)
_minimal_taichi_version = (1, 7, 2)

numba = None
taichi = None
braintaichi = None
cupy = None
cupy_jit = None
brainpylib_cpu_ops = None
Expand All @@ -33,6 +36,10 @@
cupy_install_info = ('We need cupy. Please install cupy by pip . \n'
'For CUDA v11.2 ~ 11.8 > pip install cupy-cuda11x\n'
'For CUDA v12.x > pip install cupy-cuda12x\n')
braintaichi_install_info = ('We need braintaichi. Please install braintaichi by pip . \n'
'> pip install braintaichi -U')


os.environ["TI_LOG_LEVEL"] = "error"


Expand Down Expand Up @@ -69,6 +76,26 @@ def import_taichi(error_if_not_found=True):
def raise_taichi_not_found(*args, **kwargs):
raise ModuleNotFoundError(taichi_install_info)

def import_braintaichi(error_if_not_found=True):
"""Internal API to import braintaichi.

If braintaichi is not found, it will raise a ModuleNotFoundError if error_if_not_found is True,
otherwise it will return None.
"""
global braintaichi
if braintaichi is None:
try:
import braintaichi as braintaichi
except ModuleNotFoundError:
if error_if_not_found:
raise_braintaichi_not_found()
else:
return None
return braintaichi

def raise_braintaichi_not_found():
raise ModuleNotFoundError(braintaichi_install_info)


def import_numba(error_if_not_found=True):
"""
Expand Down
55 changes: 47 additions & 8 deletions brainpy/_src/dnn/linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from brainpy import math as bm
from brainpy._src import connect, initialize as init
from brainpy._src.context import share
from brainpy._src.dependency_check import import_taichi
from brainpy._src.dependency_check import import_taichi, import_braintaichi
from brainpy._src.dnn.base import Layer
from brainpy._src.mixin import SupportOnline, SupportOffline, SupportSTDP
from brainpy.check import is_initializer
Expand All @@ -20,6 +20,7 @@
from brainpy.initialize import XavierNormal, ZeroInit, Initializer, parameter
from brainpy.types import ArrayType, Sharding

bti = import_braintaichi(error_if_not_found=False)
ti = import_taichi(error_if_not_found=False)

__all__ = [
Expand Down Expand Up @@ -238,7 +239,7 @@ def update(self, x):
return x


if ti is not None:
if ti is not None and bti is not None:

# @numba.njit(nogil=True, fastmath=True, parallel=False)
# def _cpu_dense_on_post(weight, spike, trace, w_min, w_max, out_w):
Expand Down Expand Up @@ -273,7 +274,7 @@ def _dense_on_post(
out_w[i, j] = old_w[i, j]


dense_on_post_prim = bm.XLACustomOp(cpu_kernel=_dense_on_post, gpu_kernel=_dense_on_post)
dense_on_post_prim = bti.XLACustomOp(cpu_kernel=_dense_on_post, gpu_kernel=_dense_on_post)


# @numba.njit(nogil=True, fastmath=True, parallel=False)
Expand Down Expand Up @@ -309,7 +310,7 @@ def _dense_on_pre(
out_w[i, j] = old_w[i, j]


dense_on_pre_prim = bm.XLACustomOp(cpu_kernel=_dense_on_pre, gpu_kernel=_dense_on_pre)
dense_on_pre_prim = bti.XLACustomOp(cpu_kernel=_dense_on_pre, gpu_kernel=_dense_on_pre)

else:
dense_on_pre_prim = None
Expand All @@ -326,6 +327,12 @@ def dense_on_pre(weight, spike, trace, w_min, w_max):
w_max = np.inf
w_min = jnp.atleast_1d(w_min)
w_max = jnp.atleast_1d(w_max)

weight = bm.as_jax(weight)
spike = bm.as_jax(spike)
trace = bm.as_jax(trace)
w_min = bm.as_jax(w_min)
w_max = bm.as_jax(w_max)
return dense_on_pre_prim(weight, spike, trace, w_min, w_max,
outs=[jax.ShapeDtypeStruct(weight.shape, weight.dtype)])[0]

Expand All @@ -340,6 +347,12 @@ def dense_on_post(weight, spike, trace, w_min, w_max):
w_max = np.inf
w_min = jnp.atleast_1d(w_min)
w_max = jnp.atleast_1d(w_max)

weight = bm.as_jax(weight)
spike = bm.as_jax(spike)
trace = bm.as_jax(trace)
w_min = bm.as_jax(w_min)
w_max = bm.as_jax(w_max)
return dense_on_post_prim(weight, spike, trace, w_min, w_max,
outs=[jax.ShapeDtypeStruct(weight.shape, weight.dtype)])[0]

Expand Down Expand Up @@ -735,7 +748,7 @@ def _csr_on_pre_update(
out_w[i_syn] = old_w[i_syn]


csr_on_pre_update_prim = bm.XLACustomOp(cpu_kernel=_csr_on_pre_update, gpu_kernel=_csr_on_pre_update)
csr_on_pre_update_prim = bti.XLACustomOp(cpu_kernel=_csr_on_pre_update, gpu_kernel=_csr_on_pre_update)


@ti.kernel
Expand All @@ -759,7 +772,7 @@ def _coo_on_pre_update(
out_w[i_syn] = old_w[i_syn]


coo_on_pre_update_prim = bm.XLACustomOp(cpu_kernel=_coo_on_pre_update, gpu_kernel=_coo_on_pre_update)
coo_on_pre_update_prim = bti.XLACustomOp(cpu_kernel=_coo_on_pre_update, gpu_kernel=_coo_on_pre_update)


@ti.kernel
Expand All @@ -783,7 +796,7 @@ def _coo_on_post_update(
out_w[i_syn] = old_w[i_syn]


coo_on_post_update_prim = bm.XLACustomOp(cpu_kernel=_coo_on_post_update, gpu_kernel=_coo_on_post_update)
coo_on_post_update_prim = bti.XLACustomOp(cpu_kernel=_coo_on_post_update, gpu_kernel=_coo_on_post_update)


# @numba.njit(nogil=True, fastmath=True, parallel=False)
Expand Down Expand Up @@ -824,7 +837,7 @@ def _csc_on_post_update(
out_w[i_syn] = old_w[i_syn]


csc_on_post_update_prim = bm.XLACustomOp(cpu_kernel=_csc_on_post_update, gpu_kernel=_csc_on_post_update)
csc_on_post_update_prim = bti.XLACustomOp(cpu_kernel=_csc_on_post_update, gpu_kernel=_csc_on_post_update)


else:
Expand All @@ -843,6 +856,14 @@ def csr_on_pre_update(w, indices, indptr, spike, trace, w_min=None, w_max=None):
w_max = np.inf
w_min = jnp.atleast_1d(w_min)
w_max = jnp.atleast_1d(w_max)

w = bm.as_jax(w)
indices = bm.as_jax(indices)
indptr = bm.as_jax(indptr)
spike = bm.as_jax(spike)
trace = bm.as_jax(trace)
w_min = bm.as_jax(w_min)
w_max = bm.as_jax(w_max)
return csr_on_pre_update_prim(w, indices, indptr, spike, trace, w_min, w_max,
outs=[jax.ShapeDtypeStruct(w.shape, w.dtype)])[0]

Expand All @@ -857,6 +878,15 @@ def coo_on_pre_update(w, pre_ids, post_ids, spike, trace, w_min=None, w_max=None
w_max = np.inf
w_min = jnp.atleast_1d(w_min)
w_max = jnp.atleast_1d(w_max)

w = bm.as_jax(w)
pre_ids = bm.as_jax(pre_ids)
post_ids = bm.as_jax(post_ids)
spike = bm.as_jax(spike)
trace = bm.as_jax(trace)
w_min = bm.as_jax(w_min)
w_max = bm.as_jax(w_max)

return coo_on_pre_update_prim(w, pre_ids, post_ids, spike, trace, w_min, w_max,
outs=[jax.ShapeDtypeStruct(w.shape, w.dtype)])[0]

Expand All @@ -871,6 +901,15 @@ def csc_on_post_update(w, post_ids, indptr, w_ids, post_spike, pre_trace, w_min=
w_max = np.inf
w_min = jnp.atleast_1d(w_min)
w_max = jnp.atleast_1d(w_max)

w = bm.as_jax(w)
post_ids = bm.as_jax(post_ids)
indptr = bm.as_jax(indptr)
w_ids = bm.as_jax(w_ids)
post_spike = bm.as_jax(post_spike)
pre_trace = bm.as_jax(pre_trace)
w_min = bm.as_jax(w_min)
w_max = bm.as_jax(w_max)
return csc_on_post_update_prim(w, post_ids, indptr, w_ids, post_spike, pre_trace, w_min, w_max,
outs=[jax.ShapeDtypeStruct(w.shape, w.dtype)])[0]

Expand Down
13 changes: 5 additions & 8 deletions brainpy/_src/dnn/tests/test_linear.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,11 @@
import pytest
from absl.testing import absltest
from absl.testing import parameterized
import jax.numpy as jnp

import brainpy as bp
import brainpy.math as bm

from brainpy._src.dependency_check import import_taichi

if import_taichi(error_if_not_found=False) is None:
pytest.skip('no taichi', allow_module_level=True)


class TestLinear(parameterized.TestCase):
Expand Down Expand Up @@ -104,11 +101,11 @@ def test_CSRLinear(self, conn):
bm.random.seed()
f = bp.dnn.CSRLinear(conn, weight=bp.init.Normal())
x = bm.random.random((16, 100))
y = f(x)
y = f(jnp.asarray(x))
self.assertTrue(y.shape == (16, 100))

x = bm.random.random((100,))
y = f(x)
y = f(jnp.asarray(x))
self.assertTrue(y.shape == (100,))
bm.clear_buffer_memory()

Expand All @@ -123,10 +120,10 @@ def test_EventCSRLinear(self, conn):
bm.random.seed()
f = bp.layers.EventCSRLinear(conn, weight=bp.init.Normal())
x = bm.random.random((16, 100))
y = f(x)
y = f(jnp.asarray(x))
self.assertTrue(y.shape == (16, 100))
x = bm.random.random((100,))
y = f(x)
y = f(jnp.asarray(x))
self.assertTrue(y.shape == (100,))
bm.clear_buffer_memory()

Expand Down
4 changes: 0 additions & 4 deletions brainpy/_src/dnn/tests/test_mode.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,6 @@

import brainpy as bp
import brainpy.math as bm
from brainpy._src.dependency_check import import_taichi

if import_taichi(error_if_not_found=False) is None:
pytest.skip('no taichi', allow_module_level=True)


class Test_Conv(parameterized.TestCase):
Expand Down
4 changes: 0 additions & 4 deletions brainpy/_src/dyn/projections/tests/test_STDP.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,6 @@

import brainpy as bp
import brainpy.math as bm
from brainpy._src.dependency_check import import_taichi

if import_taichi(error_if_not_found=False) is None:
pytest.skip('no taichi', allow_module_level=True)

bm.set_platform('cpu')

Expand Down
4 changes: 0 additions & 4 deletions brainpy/_src/dyn/projections/tests/test_aligns.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,6 @@
import brainpy as bp
import brainpy.math as bm

from brainpy._src.dependency_check import import_taichi

if import_taichi(error_if_not_found=False) is None:
pytest.skip('no taichi', allow_module_level=True)

neu_pars = dict(V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5.,
V_initializer=bp.init.Normal(-55., 2.))
Expand Down
4 changes: 0 additions & 4 deletions brainpy/_src/dynold/synapses/tests/test_abstract_synapses.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,6 @@
import brainpy as bp
import brainpy.math as bm
from brainpy._src.dynold.synapses import abstract_models
from brainpy._src.dependency_check import import_taichi

if import_taichi(error_if_not_found=False) is None:
pytest.skip('no taichi', allow_module_level=True)


class Test_Abstract_Synapse(parameterized.TestCase):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,6 @@
import brainpy as bp
import brainpy.math as bm

from brainpy._src.dependency_check import import_taichi

if import_taichi(error_if_not_found=False) is None:
pytest.skip('no taichi', allow_module_level=True)

biological_models = [
bp.synapses.AMPA,
Expand Down
2 changes: 1 addition & 1 deletion brainpy/_src/math/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@
from .compat_numpy import *
from .compat_tensorflow import *
from .others import *
from . import random, linalg, fft, tifunc
from . import random, linalg, fft

# operators
from .op_register import *
Expand Down
Loading