Skip to content

Commit

Permalink
Update
Browse files Browse the repository at this point in the history
  • Loading branch information
Routhleck committed Jun 8, 2024
1 parent be405e9 commit 4596c5e
Show file tree
Hide file tree
Showing 3 changed files with 64 additions and 21 deletions.
66 changes: 55 additions & 11 deletions brainpy/_src/math/event/csr_matvec.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,15 +131,21 @@ def raw_csrmv_taichi(
else:
prim = _event_csrmv_transpose_bool_heter_p
else:
return normal_csrmv_taichi(data, indices, indptr, events, shape=shape, transpose=transpose)
if data.shape[0] == 1:
prim = _event_csrmv_transpose_homo_p
else:
prim = _event_csrmv_transpose_heter_p
else:
if events.dtype == jnp.bool_:
if data.shape[0] == 1:
prim = _event_csrmv_bool_homo_p
else:
prim = _event_csrmv_bool_heter_p
else:
return normal_csrmv_taichi(data, indices, indptr, events, shape=shape, transpose=transpose)
if data.shape[0] == 1:
prim = _event_csrmv_homo_p
else:
prim = _event_csrmv_heter_p

# computing
return prim(data,
Expand Down Expand Up @@ -442,6 +448,30 @@ def _event_csr_matvec_heter_gpu(values: ti.types.ndarray(ndim=1),
out[row_i] += r # TODO: warp-level primitive


@ti.kernel
def _event_csr_matvec_dW_transpose(values: ti.types.ndarray(),
indices: ti.types.ndarray(),
indptr: ti.types.ndarray(),
events: ti.types.ndarray(),
out: ti.types.ndarray()):
for i in range(events.shape[0]):
if events[i] != 0.:
for j in range(indptr[i], indptr[i + 1]):
out[j] = values[indices[j]]


@ti.kernel
def _event_csr_matvec_dW_bool_transpose(values: ti.types.ndarray(),
indices: ti.types.ndarray(),
indptr: ti.types.ndarray(),
events: ti.types.ndarray(),
out: ti.types.ndarray()):
for i in range(events.shape[0]):
if events[i]:
for j in range(indptr[i], indptr[i + 1]):
out[j] = values[indices[j]]


@ti.kernel
def _event_csr_matvec_dW(values: ti.types.ndarray(),
indices: ti.types.ndarray(),
Expand All @@ -451,7 +481,7 @@ def _event_csr_matvec_dW(values: ti.types.ndarray(),
for i in range(events.shape[0]):
if events[i] != 0.:
for j in range(indptr[i], indptr[i + 1]):
out[j] = values[indices[j]]
out[indices[j]] += values[j] * events[i]


@ti.kernel
Expand All @@ -463,11 +493,11 @@ def _event_csr_matvec_dW_bool(values: ti.types.ndarray(),
for i in range(events.shape[0]):
if events[i]:
for j in range(indptr[i], indptr[i + 1]):
out[j] = values[indices[j]]
out[indices[j]] += values[j]


def _event_csr_matvec_jvp_values_taichi(val_dot, values, indices, indptr, events, *, outs, transpose, shape):
return normal_csrmv_taichi(val_dot, indices, indptr, events, shape=shape, transpose=transpose)
return raw_csrmv_taichi(val_dot, indices, indptr, events, shape=shape, transpose=transpose)


def _event_csr_matvec_jvp_events_taichi(evt_dot, values, indices, indptr, events, *, outs, transpose, shape):
Expand All @@ -492,12 +522,18 @@ def _event_csr_matvec_transpose_taichi(
else: # heterogeneous values
# row, col = csr_to_coo(indices, indptr)
# ct_values = events[row] * ct[0][col] if transpose else events[col] * ct[0][row]
if events.dtype == jnp.bool_:
ct_values = _event_csr_matvec_dW_bool(ct[0], indices, indptr, events,
outs=[jax.ShapeDtypeStruct((values.shape[0],), values.dtype)])[0]
if transpose:
if events.dtype == jnp.bool_:
prim = _event_csrmv_dW_bool_p_transpose
else:
prim = _event_csrmv_dW_p_transpose
else:
ct_values = _event_csr_matvec_dW(ct[0], indices, indptr, events,
outs=[jax.ShapeDtypeStruct((values.shape[0],), values.dtype)])[0]
if events.dtype == jnp.bool_:
prim = _event_csrmv_dW_bool_p
else:
prim = _event_csrmv_dW_p
ct_values = prim(ct[0], indices, indptr, events,
outs=[jax.ShapeDtypeStruct((values.aval.shape[0],), values.aval.dtype)])[0]
return ct_values, indices, indptr, events


Expand Down Expand Up @@ -540,10 +576,18 @@ def _define_op(cpu_kernel, gpu_kernel):
_event_csrmv_heter_p = _define_op(_event_csr_matvec_heter_cpu,
_event_csr_matvec_heter_gpu)

# compute dW transpose
_event_csrmv_dW_p_transpose = XLACustomOp(cpu_kernel=_event_csr_matvec_dW_transpose,
gpu_kernel=_event_csr_matvec_dW_transpose)

# compute dW bool transpose
_event_csrmv_dW_bool_p_transpose = XLACustomOp(cpu_kernel=_event_csr_matvec_dW_bool_transpose,
gpu_kernel=_event_csr_matvec_dW_bool_transpose)

# compute dW
_event_csrmv_dW_p = XLACustomOp(cpu_kernel=_event_csr_matvec_dW,
gpu_kernel=_event_csr_matvec_dW)

# compute dW bool
_event_csrmv_dW_bool_p = XLACustomOp(cpu_kernel=_event_csr_matvec_dW_bool,
gpu_kernel=_event_csr_matvec_dW_bool)
gpu_kernel=_event_csr_matvec_dW_bool)
14 changes: 6 additions & 8 deletions brainpy/_src/math/event/tests/event_dW_VS_normal_dW.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,14 @@
# from jax_taichi import jax_taichi_call

import os
import time
from functools import partial
import os

import brainpy as bp
import brainpy.math as bm
import jax
import jax.numpy as jnp
import numpy as np
import pandas as pd
import taichi as ti

import brainpy as bp
import brainpy.math as bm

bm.set_platform('cpu')

Expand Down Expand Up @@ -80,7 +78,7 @@ def normal_csrmv_grad(weight, indices, indptr, vector, shape, transpose):
return r


def test_event_csrmv_dW(shape, values_type, events_type, transpose):
def _test_event_csrmv_dW(shape, values_type, events_type, transpose):
rng = bm.random.RandomState(1234)
indices, indptr = bp.conn.FixedProb(p, seed=1234, allow_multi_conn=True)(*shape).require('pre2post')
vector = rng.random(shape[0] if transpose else shape[1]) < 0.1
Expand Down Expand Up @@ -259,7 +257,7 @@ def test_event_csrmv_dW(shape, values_type, events_type, transpose):
event_time_1, event_time_2, event_time_3, event_time_4, event_time_5, \
event_time_6, event_time_7, event_time_8, event_time_9, event_time_10, \
normal_time_1, normal_time_2, normal_time_3, normal_time_4, normal_time_5, \
normal_time_6, normal_time_7, normal_time_8, normal_time_9, normal_time_10 = test_event_csrmv_dW(
normal_time_6, normal_time_7, normal_time_8, normal_time_9, normal_time_10 = _test_event_csrmv_dW(
(shape1, shape2), _values_type, _events_type, _transpose)
# append to dataframe
df.loc[df.shape[0]] = [(shape1, shape2), 0.5, shape1, shape2, bm.get_platform(), _values_type, _events_type, _transpose,
Expand Down
5 changes: 3 additions & 2 deletions brainpy/_src/math/event/tests/test_event_csrmv.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,8 @@

import platform
force_test = False # turn on to force test on windows locally
if platform.system() == 'Windows' and not force_test:
pytest.skip('skip windows', allow_module_level=True)
# if platform.system() == 'Windows' and not force_test:
# pytest.skip('skip windows', allow_module_level=True)


seed = 1234
Expand Down Expand Up @@ -218,6 +218,7 @@ def test_heter_grad(self, shape, transpose):
data, indices, indptr, events, shape=shape, transpose=transpose)
r2 = jax.grad(sum_op(bm.event.csrmv))(
data, indices, indptr, events, shape=shape, transpose=transpose)
print(r1 - r2)
self.assertTrue(bm.allclose(r1, r2))

# grad 'events'
Expand Down

0 comments on commit 4596c5e

Please sign in to comment.