Skip to content

Commit

Permalink
[math] taichi operators as default customized operators (#598)
Browse files Browse the repository at this point in the history
* [dnn] Add dnn.linear taichi implmentation

* [math] Remove multiple results of event csrmv and csrmv

* [dnn] Fix bugs

* [dnn] Update jitconn event atomic=True

* [dnn] Replace brainpylib opeartors with taichi customized operators

* Update linear.py

* Update test_linear.py

* [dnn, math] Fix bugs

* [math] Fix bugs

* Update linear.py

* Refactor operators

* [math] Fix bugs

* [dnn] Fix bugs

* [math] Fix bugs

* [math] Fix jitconn matvec bugs

* Update linear.py

* [math] Update operators

* [math] Update pytests

* [math] Fix pytest bugs

* Update test_csrmv.py

* Update test_matvec.py

* Update test_event_matvec.py

* Update test_event_csrmv.py

* [math] Update pytests

* [math] Fix test case bugs

* [math] Add more tolerance for jitconn operators

* format the code

---------

Co-authored-by: Chaoming Wang <[email protected]>
  • Loading branch information
Routhleck and chaoming0625 authored Jan 29, 2024
1 parent 8c57f66 commit 7e8dd81
Show file tree
Hide file tree
Showing 31 changed files with 5,368 additions and 5,390 deletions.
19 changes: 7 additions & 12 deletions brainpy/_src/dnn/linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -570,7 +570,7 @@ def __init__(
sharding: Optional[Sharding] = None,
mode: Optional[bm.Mode] = None,
name: Optional[str] = None,
method: str = 'cusparse',
method: str = None,
transpose: bool = True,
):
super().__init__(name=name, mode=mode, conn=conn, weight=weight, sharding=sharding, transpose=transpose)
Expand All @@ -580,8 +580,7 @@ def update(self, x):
if x.ndim == 1:
return bm.sparse.csrmv(self.weight, self.indices, self.indptr, x,
shape=(self.conn.pre_num, self.conn.post_num),
transpose=self.transpose,
method=self.method)
method=self.method, transpose=self.transpose)
elif x.ndim > 1:
shapes = x.shape[:-1]
x = bm.flatten(x, end_dim=-2)
Expand All @@ -593,9 +592,7 @@ def update(self, x):
def _batch_csrmv(self, x):
return bm.sparse.csrmv(self.weight, self.indices, self.indptr, x,
shape=(self.conn.pre_num, self.conn.post_num),
transpose=self.transpose,
method=self.method)

method=self.method, transpose=self.transpose)

class EventCSRLinear(_CSRLayer):
r"""Synaptic matrix multiplication with event CSR sparse computation.
Expand Down Expand Up @@ -646,7 +643,6 @@ def _batch_csrmv(self, x):
shape=(self.conn.pre_num, self.conn.post_num),
transpose=self.transpose)


@numba.njit(nogil=True, fastmath=True, parallel=False)
def _cpu_csr_on_pre_update(w, indices, indptr, spike, trace, w_min, w_max, out_w):
out_w[:] = w
Expand All @@ -659,7 +655,6 @@ def _cpu_csr_on_pre_update(w, indices, indptr, spike, trace, w_min, w_max, out_w
# out_w[k] = np.clip(out_w[k] + trace[j], w_min, w_max)
out_w[k] = np.minimum(np.maximum(out_w[k] + trace[j], w_min), w_max)


csr_on_pre_update_prim = bm.XLACustomOp(_cpu_csr_on_pre_update)


Expand All @@ -671,7 +666,6 @@ def csr_on_pre_update(w, indices, indptr, spike, trace, w_min=None, w_max=None):
return csr_on_pre_update_prim(w, indices, indptr, spike, trace, w_min, w_max,
outs=[jax.ShapeDtypeStruct(w.shape, w.dtype)])[0]


@numba.njit(nogil=True, fastmath=True, parallel=False)
def _cpu_csc_on_pre_update(w, post_ids, indptr, w_ids, spike, trace, w_min, w_max, out_w):
out_w[:] = w
Expand All @@ -697,6 +691,7 @@ def csc_on_post_update(w, post_ids, indptr, w_ids, spike, trace, w_min=None, w_m
outs=[jax.ShapeDtypeStruct(w.shape, w.dtype)])[0]



class CSCLinear(Layer):
r"""Synaptic matrix multiplication with CSC sparse computation.
Expand Down Expand Up @@ -1080,7 +1075,7 @@ def __init__(
mode: Optional[bm.Mode] = None,
name: Optional[str] = None,
transpose: bool = False,
atomic: bool = False,
atomic: bool = True,
):
super().__init__(name=name, mode=mode)

Expand Down Expand Up @@ -1161,7 +1156,7 @@ def __init__(
mode: Optional[bm.Mode] = None,
name: Optional[str] = None,
transpose: bool = False,
atomic: bool = False,
atomic: bool = True,
):
super().__init__(name=name, mode=mode)

Expand Down Expand Up @@ -1239,7 +1234,7 @@ def __init__(
seed: Optional[int] = None,
sharding: Optional[Sharding] = None,
transpose: bool = False,
atomic: bool = False,
atomic: bool = True,
mode: Optional[bm.Mode] = None,
name: Optional[str] = None,
):
Expand Down
1 change: 0 additions & 1 deletion brainpy/_src/dnn/tests/test_linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,6 +213,5 @@ def test_EventJitFPNormalLinear(self, prob, w_mu, w_sigma, shape):
self.assertTrue(y2.shape == shape + (200,))
bm.clear_buffer_memory()


if __name__ == '__main__':
absltest.main()
1 change: 0 additions & 1 deletion brainpy/_src/math/event/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@

from ._info_collection import *
from ._csr_matvec import *
from ._csr_matvec_taichi import *

Loading

0 comments on commit 7e8dd81

Please sign in to comment.