[math] Refactor taichi operators #598

Merged · 27 commits · Jan 29, 2024
Changes shown from 9 of 27 commits

Commits
b24a544
[dnn] Add dnn.linear taichi implementation
Routhleck Jan 22, 2024
65305b2
[math] Remove multiple results of event csrmv and csrmv
Routhleck Jan 22, 2024
8710ad8
[dnn] Fix bugs
Routhleck Jan 22, 2024
61c2a07
[dnn] Update jitconn event atomic=True
Routhleck Jan 24, 2024
f1ac501
[dnn] Replace brainpylib operators with taichi customized operators
Routhleck Jan 25, 2024
7499c1e
Update linear.py
Routhleck Jan 25, 2024
be87e9f
Update test_linear.py
Routhleck Jan 25, 2024
ae95fad
[dnn, math] Fix bugs
Routhleck Jan 25, 2024
b993e13
[math] Fix bugs
Routhleck Jan 25, 2024
420cbba
Update linear.py
Routhleck Jan 26, 2024
33f21b9
Refactor operators
Routhleck Jan 28, 2024
bcd9afb
[math] Fix bugs
Routhleck Jan 28, 2024
ee018b0
[dnn] Fix bugs
Routhleck Jan 28, 2024
85afa71
[math] Fix bugs
Routhleck Jan 28, 2024
ff846ae
[math] Fix jitconn matvec bugs
Routhleck Jan 28, 2024
97f7e7a
Update linear.py
Routhleck Jan 28, 2024
8a517f0
[math] Update operators
Routhleck Jan 29, 2024
046dbea
[math] Update pytests
Routhleck Jan 29, 2024
df8c0bf
[math] Fix pytest bugs
Routhleck Jan 29, 2024
c0d7561
Update test_csrmv.py
Routhleck Jan 29, 2024
b49ae90
Update test_matvec.py
Routhleck Jan 29, 2024
21b8426
Update test_event_matvec.py
Routhleck Jan 29, 2024
c43bead
Update test_event_csrmv.py
Routhleck Jan 29, 2024
efc8923
[math] Update pytests
Routhleck Jan 29, 2024
c027817
[math] Fix test case bugs
Routhleck Jan 29, 2024
12d045d
[math] Add more tolerance for jitconn operators
Routhleck Jan 29, 2024
e1f4005
format the code
chaoming0625 Jan 29, 2024
52 changes: 23 additions & 29 deletions brainpy/_src/dnn/linear.py
@@ -8,6 +8,7 @@
import jax.numpy as jnp
import numba
import numpy as np
import taichi as ti

from brainpy import math as bm
from brainpy._src import connect, initialize as init
@@ -570,18 +571,15 @@ def __init__(
sharding: Optional[Sharding] = None,
mode: Optional[bm.Mode] = None,
name: Optional[str] = None,
method: str = 'cusparse',
transpose: bool = True,
):
super().__init__(name=name, mode=mode, conn=conn, weight=weight, sharding=sharding, transpose=transpose)
self.method = method

def update(self, x):
if x.ndim == 1:
return bm.sparse.csrmv(self.weight, self.indices, self.indptr, x,
return bm.sparse.csrmv_taichi(self.weight, self.indices, self.indptr, x,
shape=(self.conn.pre_num, self.conn.post_num),
transpose=self.transpose,
method=self.method)
transpose=self.transpose)
elif x.ndim > 1:
shapes = x.shape[:-1]
x = bm.flatten(x, end_dim=-2)
@@ -591,11 +589,9 @@ def update(self, x):
raise ValueError

def _batch_csrmv(self, x):
return bm.sparse.csrmv(self.weight, self.indices, self.indptr, x,
return bm.sparse.csrmv_taichi(self.weight, self.indices, self.indptr, x,
shape=(self.conn.pre_num, self.conn.post_num),
transpose=self.transpose,
method=self.method)

transpose=self.transpose)

class EventCSRLinear(_CSRLayer):
r"""Synaptic matrix multiplication with event CSR sparse computation.
@@ -630,7 +626,7 @@ def __init__(

def update(self, x):
if x.ndim == 1:
return bm.event.csrmv(self.weight, self.indices, self.indptr, x,
return bm.event.csrmv_taichi(self.weight, self.indices, self.indptr, x,
shape=(self.conn.pre_num, self.conn.post_num),
transpose=self.transpose)
elif x.ndim > 1:
@@ -642,11 +638,10 @@ def update(self, x):
raise ValueError

def _batch_csrmv(self, x):
return bm.event.csrmv(self.weight, self.indices, self.indptr, x,
return bm.event.csrmv_taichi(self.weight, self.indices, self.indptr, x,
shape=(self.conn.pre_num, self.conn.post_num),
transpose=self.transpose)
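
For orientation, here is a minimal sketch of how the taichi-backed operators that replace `bm.sparse.csrmv` and `bm.event.csrmv` in the two layers above can be called directly. The argument order and keyword names follow the changed lines in this diff; the toy CSR data and the assumption that the output length is `shape[0]` when `transpose=False` are illustrative, not part of the PR.

```python
import jax.numpy as jnp
import brainpy.math as bm

# Toy 2x3 CSR matrix:
#   [[1., 0., 2.],
#    [0., 3., 0.]]
data = jnp.asarray([1., 2., 3.])
indices = jnp.asarray([0, 2, 1])
indptr = jnp.asarray([0, 2, 3])

x = jnp.asarray([1., 1., 1.])              # dense input vector (length shape[1])
spikes = jnp.asarray([True, False, True])  # boolean event vector for the event kernel

# Plain sparse matrix-vector product, taichi backend
# (replaces bm.sparse.csrmv; note the old `method=` argument is gone).
y = bm.sparse.csrmv_taichi(data, indices, indptr, x,
                           shape=(2, 3), transpose=False)

# Event-driven product, taichi backend (replaces bm.event.csrmv).
z = bm.event.csrmv_taichi(data, indices, indptr, spikes,
                          shape=(2, 3), transpose=False)
```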


@numba.njit(nogil=True, fastmath=True, parallel=False)
def _cpu_csr_on_pre_update(w, indices, indptr, spike, trace, w_min, w_max, out_w):
out_w[:] = w
@@ -659,7 +654,6 @@ def _cpu_csr_on_pre_update(w, indices, indptr, spike, trace, w_min, w_max, out_w
# out_w[k] = np.clip(out_w[k] + trace[j], w_min, w_max)
out_w[k] = np.minimum(np.maximum(out_w[k] + trace[j], w_min), w_max)


csr_on_pre_update_prim = bm.XLACustomOp(_cpu_csr_on_pre_update)


@@ -671,7 +665,6 @@ def csr_on_pre_update(w, indices, indptr, spike, trace, w_min=None, w_max=None):
return csr_on_pre_update_prim(w, indices, indptr, spike, trace, w_min, w_max,
outs=[jax.ShapeDtypeStruct(w.shape, w.dtype)])[0]


@numba.njit(nogil=True, fastmath=True, parallel=False)
def _cpu_csc_on_pre_update(w, post_ids, indptr, w_ids, spike, trace, w_min, w_max, out_w):
out_w[:] = w
@@ -697,6 +690,7 @@ def csc_on_post_update(w, post_ids, indptr, w_ids, spike, trace, w_min=None, w_m
outs=[jax.ShapeDtypeStruct(w.shape, w.dtype)])[0]



class CSCLinear(Layer):
r"""Synaptic matrix multiplication with CSC sparse computation.

@@ -860,7 +854,7 @@ def __init__(

def update(self, x):
if x.ndim == 1:
return bm.jitconn.mv_prob_homo(x, self.weight, self.prob, self.seed,
return bm.jitconn.mv_prob_homo_taichi(x, self.weight, self.prob, self.seed,
shape=(self.num_out, self.num_in),
transpose=self.transpose,
outdim_parallel=not self.atomic)
@@ -875,7 +869,7 @@ def update(self, x):
raise ValueError

def _batch_mv(self, x):
return bm.jitconn.mv_prob_homo(x, self.weight, self.prob, self.seed,
return bm.jitconn.mv_prob_homo_taichi(x, self.weight, self.prob, self.seed,
shape=(self.num_out, self.num_in),
transpose=self.transpose,
outdim_parallel=not self.atomic)
@@ -940,7 +934,7 @@ def __init__(

def update(self, x):
if x.ndim == 1:
return bm.jitconn.mv_prob_uniform(x, self.w_low, self.w_high, self.prob, self.seed,
return bm.jitconn.mv_prob_uniform_taichi(x, self.w_low, self.w_high, self.prob, self.seed,
shape=(self.num_out, self.num_in),
transpose=self.transpose,
outdim_parallel=not self.atomic)
@@ -955,7 +949,7 @@ def update(self, x):
raise ValueError

def _batch_mv(self, x):
return bm.jitconn.mv_prob_uniform(x, self.w_low, self.w_high, self.prob, self.seed,
return bm.jitconn.mv_prob_uniform_taichi(x, self.w_low, self.w_high, self.prob, self.seed,
shape=(self.num_out, self.num_in),
transpose=self.transpose,
outdim_parallel=not self.atomic)
@@ -1020,7 +1014,7 @@ def __init__(

def update(self, x):
if x.ndim == 1:
return bm.jitconn.mv_prob_normal(x, self.w_mu, self.w_sigma, self.prob, self.seed,
return bm.jitconn.mv_prob_normal_taichi(x, self.w_mu, self.w_sigma, self.prob, self.seed,
shape=(self.num_out, self.num_in),
transpose=self.transpose,
outdim_parallel=not self.atomic)
@@ -1035,7 +1029,7 @@ def update(self, x):
raise ValueError

def _batch_mv(self, x):
return bm.jitconn.mv_prob_normal(x, self.w_mu, self.w_sigma, self.prob, self.seed,
return bm.jitconn.mv_prob_normal_taichi(x, self.w_mu, self.w_sigma, self.prob, self.seed,
shape=(self.num_out, self.num_in),
transpose=self.transpose,
outdim_parallel=not self.atomic)
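
The three `JitFP*Linear` layers above now route through the `*_taichi` just-in-time connectivity operators, with `outdim_parallel` derived from the layer's `atomic` flag (`outdim_parallel = not atomic`). Below is a rough sketch of the underlying call, mirroring the `x.ndim == 1` branch of `JitFPHomoLinear.update`; the concrete sizes and the convention that the output length is `shape[0]` when `transpose=False` are assumptions for illustration, not statements from this diff.

```python
import jax.numpy as jnp
import brainpy.math as bm

num_in, num_out = 200, 100
x = jnp.ones(num_in)    # dense input vector

weight = 1.5    # homogeneous connection weight
prob = 0.1      # connection probability
seed = 42       # seed fixing the just-in-time generated connectivity
atomic = False  # layer option; the operator receives outdim_parallel = not atomic

y = bm.jitconn.mv_prob_homo_taichi(
    x, weight, prob, seed,
    shape=(num_out, num_in),
    transpose=False,
    outdim_parallel=not atomic,
)
```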
@@ -1080,7 +1074,7 @@ def __init__(
mode: Optional[bm.Mode] = None,
name: Optional[str] = None,
transpose: bool = False,
atomic: bool = False,
atomic: bool = True,
):
super().__init__(name=name, mode=mode)

@@ -1099,7 +1093,7 @@ def __init__(

def update(self, x):
if x.ndim == 1:
return bm.jitconn.event_mv_prob_homo(x, self.weight, self.prob, self.seed,
return bm.jitconn.event_mv_prob_homo_taichi(x, self.weight, self.prob, self.seed,
shape=(self.num_out, self.num_in),
transpose=self.transpose,
outdim_parallel=not self.atomic)
@@ -1114,7 +1108,7 @@ def update(self, x):
raise ValueError

def _batch_mv(self, x):
return bm.jitconn.event_mv_prob_homo(x, self.weight, self.prob, self.seed,
return bm.jitconn.event_mv_prob_homo_taichi(x, self.weight, self.prob, self.seed,
shape=(self.num_out, self.num_in),
transpose=self.transpose,
outdim_parallel=not self.atomic)
@@ -1161,7 +1155,7 @@ def __init__(
mode: Optional[bm.Mode] = None,
name: Optional[str] = None,
transpose: bool = False,
atomic: bool = False,
atomic: bool = True,
):
super().__init__(name=name, mode=mode)

@@ -1179,7 +1173,7 @@ def __init__(

def update(self, x):
if x.ndim == 1:
return bm.jitconn.event_mv_prob_uniform(x, self.w_low, self.w_high, self.prob, self.seed,
return bm.jitconn.event_mv_prob_uniform_taichi(x, self.w_low, self.w_high, self.prob, self.seed,
shape=(self.num_out, self.num_in),
transpose=self.transpose,
outdim_parallel=not self.atomic)
@@ -1194,7 +1188,7 @@ def update(self, x):
raise ValueError

def _batch_mv(self, x):
return bm.jitconn.event_mv_prob_uniform(x, self.w_low, self.w_high, self.prob, self.seed,
return bm.jitconn.event_mv_prob_uniform_taichi(x, self.w_low, self.w_high, self.prob, self.seed,
shape=(self.num_out, self.num_in),
transpose=self.transpose,
outdim_parallel=not self.atomic)
@@ -1239,7 +1233,7 @@ def __init__(
seed: Optional[int] = None,
sharding: Optional[Sharding] = None,
transpose: bool = False,
atomic: bool = False,
atomic: bool = True,
mode: Optional[bm.Mode] = None,
name: Optional[str] = None,
):
@@ -1259,7 +1253,7 @@ def __init__(

def update(self, x):
if x.ndim == 1:
return bm.jitconn.event_mv_prob_normal(x, self.w_mu, self.w_sigma, self.prob, self.seed,
return bm.jitconn.event_mv_prob_normal_taichi(x, self.w_mu, self.w_sigma, self.prob, self.seed,
shape=(self.num_out, self.num_in),
transpose=self.transpose,
outdim_parallel=not self.atomic)
@@ -1274,7 +1268,7 @@ def update(self, x):
raise ValueError

def _batch_mv(self, x):
return bm.jitconn.event_mv_prob_normal(x, self.w_mu, self.w_sigma, self.prob, self.seed,
return bm.jitconn.event_mv_prob_normal_taichi(x, self.w_mu, self.w_sigma, self.prob, self.seed,
shape=(self.num_out, self.num_in),
transpose=self.transpose,
outdim_parallel=not self.atomic)
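
For the event-driven counterparts (`EventJitFPHomoLinear`, `EventJitFPUniformLinear`, `EventJitFPNormalLinear`), the diff both switches to the `event_mv_prob_*_taichi` operators and flips the default to `atomic=True`, so `outdim_parallel=False` is passed unless a layer is built with `atomic=False`. A sketch of the call with an illustrative spike vector; as above, the exact shape convention is an assumption rather than something stated in the diff.

```python
import jax.numpy as jnp
import brainpy.math as bm

num_in, num_out = 200, 100
# A mostly-silent boolean spike vector with a few active entries.
spikes = jnp.zeros(num_in, dtype=bool).at[jnp.asarray([3, 17, 58])].set(True)

atomic = True  # new default in this PR for the event layers

y = bm.jitconn.event_mv_prob_homo_taichi(
    spikes, 1.5, 0.1, 42,  # events, weight, connection probability, seed
    shape=(num_out, num_in),
    transpose=False,
    outdim_parallel=not atomic,
)
```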
1 change: 0 additions & 1 deletion brainpy/_src/dnn/tests/test_linear.py
@@ -213,6 +213,5 @@ def test_EventJitFPNormalLinear(self, prob, w_mu, w_sigma, shape):
self.assertTrue(y2.shape == shape + (200,))
bm.clear_buffer_memory()


if __name__ == '__main__':
absltest.main()
76 changes: 43 additions & 33 deletions brainpy/_src/math/event/_csr_matvec_taichi.py
@@ -10,7 +10,7 @@
from brainpy._src.dependency_check import import_taichi
from brainpy._src.math.interoperability import as_jax
from brainpy._src.math.op_register import XLACustomOp
from brainpy._src.math.sparse._csr_mv_taichi import csrmv_taichi as normal_csrmv_taichi
from brainpy._src.math.sparse._csr_mv_taichi import raw_csrmv_taichi as normal_csrmv_taichi
from brainpy._src.math.sparse._utils import csr_to_coo

ti = import_taichi()
@@ -333,13 +333,53 @@ def _event_csr_matvec_transpose(
ct_values = ad.Zero(values)
else:
if values.aval.shape[0] == 1: # scalar
ct_values = csrmv_taichi(jnp.ones(1), indices, indptr, events, shape=shape, transpose=transpose)[0]
ct_values = raw_csrmv_taichi(jnp.ones(1), indices, indptr, events, shape=shape, transpose=transpose)[0]
ct_values = jnp.inner(ct[0], ct_values)
else: # heterogeneous values
row, col = csr_to_coo(indices, indptr)
ct_values = events[row] * ct[0][col] if transpose else events[col] * ct[0][row]
return ct_values, indices, indptr, events

def raw_csrmv_taichi(
data: Union[float, jax.Array],
indices: jax.Array,
indptr: jax.Array,
events: jax.Array,
*,
shape: Tuple[int, int],
transpose: bool = False
):
if transpose:
if events.dtype == jnp.bool_:
if data.shape[0] == 1:
prim = _event_csrmv_transpose_bool_homo_p
else:
prim = _event_csrmv_transpose_bool_heter_p
else:
if data.shape[0] == 1:
prim = _event_csrmv_transpose_homo_p
else:
prim = _event_csrmv_transpose_heter_p
else:
if events.dtype == jnp.bool_:
if data.shape[0] == 1:
prim = _event_csrmv_bool_homo_p
else:
prim = _event_csrmv_bool_heter_p
else:
if data.shape[0] == 1:
prim = _event_csrmv_homo_p
else:
prim = _event_csrmv_heter_p

# computing
return prim(data,
indices,
indptr,
events,
outs=[jax.ShapeDtypeStruct(shape=(shape[1] if transpose else shape[0],), dtype=data.dtype)],
transpose=transpose,
shape=shape)

def csrmv_taichi(
data: Union[float, jax.Array],
@@ -419,37 +459,7 @@ def csrmv_taichi(
if indices.shape[0] == 0:
return jnp.zeros(shape[1] if transpose else shape[0], dtype=data.dtype)

if transpose:
if events.dtype == jnp.bool_:
if data.shape[0] == 1:
prim = _event_csrmv_transpose_bool_homo_p
else:
prim = _event_csrmv_transpose_bool_heter_p
else:
if data.shape[0] == 1:
prim = _event_csrmv_transpose_homo_p
else:
prim = _event_csrmv_transpose_heter_p
else:
if events.dtype == jnp.bool_:
if data.shape[0] == 1:
prim = _event_csrmv_bool_homo_p
else:
prim = _event_csrmv_bool_heter_p
else:
if data.shape[0] == 1:
prim = _event_csrmv_homo_p
else:
prim = _event_csrmv_heter_p

# computing
return prim(data,
indices,
indptr,
events,
outs=[jax.ShapeDtypeStruct(shape=(shape[1] if transpose else shape[0],), dtype=data.dtype)],
transpose=transpose,
shape=shape)
return raw_csrmv_taichi(data, indices, indptr, events, shape=shape, transpose=transpose)[0]


def _define_op(cpu_kernel, gpu_kernel):
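
To summarize the refactor in `_csr_matvec_taichi.py`: the primitive-selection logic (transpose × event dtype × homogeneous/heterogeneous weights) now lives in `raw_csrmv_taichi`, which returns the raw primitive output as a list, while the public `csrmv_taichi` keeps the input validation and returns element `[0]`, and the transpose rule dispatches through a raw_* helper directly. A rough usage sketch under those assumptions; in particular, it presumes the operator is differentiable with respect to `data`, which is what the transpose rule exists for, and the toy connectivity below is purely illustrative.

```python
import jax
import jax.numpy as jnp
import brainpy.math as bm

# Toy 3x3 CSR connectivity with heterogeneous weights.
data = jnp.asarray([0.5, 1.0, 0.25, 2.0])
indices = jnp.asarray([1, 2, 0, 2])
indptr = jnp.asarray([0, 2, 3, 4])
events = jnp.asarray([True, False, True])

# Forward pass: event-driven CSR matvec through the public wrapper.
y = bm.event.csrmv_taichi(data, indices, indptr, events,
                          shape=(3, 3), transpose=True)

# Gradient w.r.t. the nonzero weights; the backward pass goes through the
# transpose rule, which now selects the primitive via the raw_* dispatcher.
loss = lambda d: bm.event.csrmv_taichi(d, indices, indptr, events,
                                       shape=(3, 3), transpose=True).sum()
grads = jax.grad(loss)(data)
```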