[math] Refactor taichi operators #598

Merged (27 commits, Jan 29, 2024)

Commits (27)
b24a544  [dnn] Add dnn.linear taichi implementation (Routhleck, Jan 22, 2024)
65305b2  [math] Remove multiple results of event csrmv and csrmv (Routhleck, Jan 22, 2024)
8710ad8  [dnn] Fix bugs (Routhleck, Jan 22, 2024)
61c2a07  [dnn] Update jitconn event atomic=True (Routhleck, Jan 24, 2024)
f1ac501  [dnn] Replace brainpylib operators with taichi customized operators (Routhleck, Jan 25, 2024)
7499c1e  Update linear.py (Routhleck, Jan 25, 2024)
be87e9f  Update test_linear.py (Routhleck, Jan 25, 2024)
ae95fad  [dnn, math] Fix bugs (Routhleck, Jan 25, 2024)
b993e13  [math] Fix bugs (Routhleck, Jan 25, 2024)
420cbba  Update linear.py (Routhleck, Jan 26, 2024)
33f21b9  Refactor operators (Routhleck, Jan 28, 2024)
bcd9afb  [math] Fix bugs (Routhleck, Jan 28, 2024)
ee018b0  [dnn] Fix bugs (Routhleck, Jan 28, 2024)
85afa71  [math] Fix bugs (Routhleck, Jan 28, 2024)
ff846ae  [math] Fix jitconn matvec bugs (Routhleck, Jan 28, 2024)
97f7e7a  Update linear.py (Routhleck, Jan 28, 2024)
8a517f0  [math] Update operators (Routhleck, Jan 29, 2024)
046dbea  [math] Update pytests (Routhleck, Jan 29, 2024)
df8c0bf  [math] Fix pytest bugs (Routhleck, Jan 29, 2024)
c0d7561  Update test_csrmv.py (Routhleck, Jan 29, 2024)
b49ae90  Update test_matvec.py (Routhleck, Jan 29, 2024)
21b8426  Update test_event_matvec.py (Routhleck, Jan 29, 2024)
c43bead  Update test_event_csrmv.py (Routhleck, Jan 29, 2024)
efc8923  [math] Update pytests (Routhleck, Jan 29, 2024)
c027817  [math] Fix test case bugs (Routhleck, Jan 29, 2024)
12d045d  [math] Add more tolerance for jitconn operators (Routhleck, Jan 29, 2024)
e1f4005  format the code (chaoming0625, Jan 29, 2024)
brainpy/_src/dnn/linear.py (19 changes: 7 additions & 12 deletions)

@@ -570,7 +570,7 @@ def __init__(
       sharding: Optional[Sharding] = None,
       mode: Optional[bm.Mode] = None,
       name: Optional[str] = None,
-      method: str = 'cusparse',
+      method: str = None,
       transpose: bool = True,
   ):
     super().__init__(name=name, mode=mode, conn=conn, weight=weight, sharding=sharding, transpose=transpose)
@@ -580,8 +580,7 @@ def update(self, x):
     if x.ndim == 1:
       return bm.sparse.csrmv(self.weight, self.indices, self.indptr, x,
                              shape=(self.conn.pre_num, self.conn.post_num),
-                             transpose=self.transpose,
-                             method=self.method)
+                             method=self.method, transpose=self.transpose)
     elif x.ndim > 1:
       shapes = x.shape[:-1]
       x = bm.flatten(x, end_dim=-2)
@@ -593,9 +592,7 @@
   def _batch_csrmv(self, x):
     return bm.sparse.csrmv(self.weight, self.indices, self.indptr, x,
                            shape=(self.conn.pre_num, self.conn.post_num),
-                           transpose=self.transpose,
-                           method=self.method)
-
+                           method=self.method, transpose=self.transpose)

 class EventCSRLinear(_CSRLayer):
   r"""Synaptic matrix multiplication with event CSR sparse computation.
@@ -646,7 +643,6 @@ def _batch_csrmv(self, x):
                            shape=(self.conn.pre_num, self.conn.post_num),
                            transpose=self.transpose)

-
 @numba.njit(nogil=True, fastmath=True, parallel=False)
 def _cpu_csr_on_pre_update(w, indices, indptr, spike, trace, w_min, w_max, out_w):
   out_w[:] = w
@@ -659,7 +655,6 @@ def _cpu_csr_on_pre_update(w, indices, indptr, spike, trace, w_min, w_max, out_w):
         # out_w[k] = np.clip(out_w[k] + trace[j], w_min, w_max)
         out_w[k] = np.minimum(np.maximum(out_w[k] + trace[j], w_min), w_max)

-
 csr_on_pre_update_prim = bm.XLACustomOp(_cpu_csr_on_pre_update)


@@ -671,7 +666,6 @@ def csr_on_pre_update(w, indices, indptr, spike, trace, w_min=None, w_max=None):
   return csr_on_pre_update_prim(w, indices, indptr, spike, trace, w_min, w_max,
                                 outs=[jax.ShapeDtypeStruct(w.shape, w.dtype)])[0]

-
 @numba.njit(nogil=True, fastmath=True, parallel=False)
 def _cpu_csc_on_pre_update(w, post_ids, indptr, w_ids, spike, trace, w_min, w_max, out_w):
   out_w[:] = w
@@ -697,6 +691,7 @@ def csc_on_post_update(w, post_ids, indptr, w_ids, spike, trace, w_min=None, w_max=None):
                                 outs=[jax.ShapeDtypeStruct(w.shape, w.dtype)])[0]


+
 class CSCLinear(Layer):
   r"""Synaptic matrix multiplication with CSC sparse computation.

@@ -1080,7 +1075,7 @@ def __init__(
       mode: Optional[bm.Mode] = None,
       name: Optional[str] = None,
       transpose: bool = False,
-      atomic: bool = False,
+      atomic: bool = True,
   ):
     super().__init__(name=name, mode=mode)

@@ -1161,7 +1156,7 @@ def __init__(
       mode: Optional[bm.Mode] = None,
       name: Optional[str] = None,
       transpose: bool = False,
-      atomic: bool = False,
+      atomic: bool = True,
   ):
     super().__init__(name=name, mode=mode)

@@ -1239,7 +1234,7 @@ def __init__(
       seed: Optional[int] = None,
       sharding: Optional[Sharding] = None,
       transpose: bool = False,
-      atomic: bool = False,
+      atomic: bool = True,
       mode: Optional[bm.Mode] = None,
       name: Optional[str] = None,
   ):
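Context for the first three hunks: CSRLinear now leaves method as None, so bm.sparse.csrmv can select its own backend (previously the 'cusparse' brainpylib kernel was hard-coded as the default). A minimal sketch of the call, assuming the argument order shown in the diff (data, indices, indptr, vector) plus the shape/method/transpose keywords; the tiny matrix and values are illustrative only:

import brainpy.math as bm

# Toy 2x3 CSR matrix: rows [[1., 0., 2.], [0., 3., 0.]].
data = bm.asarray([1., 2., 3.])   # nonzero values
indices = bm.asarray([0, 2, 1])   # column index of each nonzero
indptr = bm.asarray([0, 2, 3])    # row k owns data[indptr[k]:indptr[k+1]]
v = bm.asarray([1., 1., 1.])      # dense vector, length = number of columns

# method=None (the new default) lets brainpy choose the implementation,
# e.g. the taichi kernels this PR introduces, instead of forcing 'cusparse'.
y = bm.sparse.csrmv(data, indices, indptr, v,
                    shape=(2, 3), method=None, transpose=False)
print(y)  # expected: [3., 3.]  (1*1 + 2*1 and 3*1)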
brainpy/_src/dnn/tests/test_linear.py (1 change: 0 additions & 1 deletion)

@@ -213,6 +213,5 @@ def test_EventJitFPNormalLinear(self, prob, w_mu, w_sigma, shape):
     self.assertTrue(y2.shape == shape + (200,))
     bm.clear_buffer_memory()

-
 if __name__ == '__main__':
   absltest.main()
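For reference, a minimal sketch of the shape check this test performs, assuming brainpy.dnn.EventJitFPNormalLinear accepts the parameters visible in this PR (prob, w_mu, w_sigma, seed); the concrete sizes are illustrative assumptions:

import brainpy as bp
import brainpy.math as bm

# Illustrative sizes; prob/w_mu/w_sigma mirror the test's parametrization.
layer = bp.dnn.EventJitFPNormalLinear(100, 200, prob=0.1,
                                      w_mu=0., w_sigma=1., seed=42)
events = bm.random.rand(100) < 0.1  # boolean presynaptic events
y = layer(events)
assert y.shape == (200,)            # one value per postsynaptic neuron
bm.clear_buffer_memory()            # as in the test, release operator buffers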
brainpy/_src/math/event/__init__.py (1 change: 0 additions & 1 deletion)

@@ -1,5 +1,4 @@
-
 from ._info_collection import *
 from ._csr_matvec import *
 from ._csr_matvec_taichi import *
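Both the brainpylib and taichi event kernels remain importable from brainpy.math.event after this change. A rough usage sketch of the event-driven matrix-vector product they expose, assuming bm.event.csrmv follows the same CSR convention as bm.sparse.csrmv but consumes boolean spike events; the values are illustrative:

import brainpy.math as bm

data = 1.0                          # homogeneous synaptic weight
indices = bm.asarray([0, 2, 1])     # CSR column indices (postsynaptic ids)
indptr = bm.asarray([0, 2, 3])      # row pointers, one row per presynaptic cell
events = bm.asarray([True, False])  # presynaptic spikes, length = pre_num

# Event-driven SpMV: only rows whose event is True contribute, the pattern
# EventCSRLinear relies on (its base class defaults to transpose=True).
out = bm.event.csrmv(data, indices, indptr, events,
                     shape=(2, 3), transpose=True)
print(out)  # expected: [1., 0., 1.]  (cell 0 spiked onto posts 0 and 2)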