Skip to content

Commit

Permalink
Test event csr matvec using taichi customized op
Browse files Browse the repository at this point in the history
  • Loading branch information
Routhleck committed Nov 30, 2023
1 parent 26e9820 commit 613f28b
Show file tree
Hide file tree
Showing 3 changed files with 153 additions and 8 deletions.
44 changes: 38 additions & 6 deletions brainpy/_src/math/event/_csr_matvec_taichi.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from brainpy._src.math.interoperability import as_jax
from brainpy._src.math.op_register import (XLACustomOp,
register_general_batching)
from brainpy._src.math.sparse._csr_mv import csrmv as normal_csrmv
from brainpy._src.math.sparse._csr_mv_taichi import csrmv_taichi as normal_csrmv_taichi
from brainpy._src.math.sparse._utils import csr_to_coo
from brainpy._src.dependency_check import (import_brainpylib_cpu_ops,
import_brainpylib_gpu_ops)
Expand Down Expand Up @@ -207,8 +207,8 @@ def _event_csr_matvec_gpu(values: ti.types.ndarray(),



def _event_matvec_jvp(
primals, tangents,
def _event_csr_matvec_jvp(
primals, tangents, *, outs
):
values, indices, indptr, events, bool_param_list, shape_list = primals
values_dot, indices_dot, indptr_dot, events_dot, bool_param_list_dot, shape_list_dot = tangents
Expand All @@ -222,15 +222,20 @@ def _event_matvec_jvp(
outs=[jax.ShapeDtypeStruct(shape=(shape_list[1] if bool_param_list[0] else shape_list[0],),
dtype=values.dtype)])

assert type(values_dot) is ad.Zero
assert type(indices_dot) is ad.Zero
assert type(indptr_dot) is ad.Zero
assert type(events_dot) is ad.Zero
assert type(bool_param_list_dot) is ad.Zero
assert type(shape_list_dot) is ad.Zero

if type(values_dot) is ad.Zero:
if type(events_dot) is ad.Zero:
raise ValueError
# TODO: implement sparse csr matvec first
dr = normal_csrmv_taichi(values,
indices,
indptr,
events_dot,
shape=shape_list,
transpose=bool_param_list[0])

elif type(events_dot) is ad.Zero:
dr = _event_csr_matvec_p(values_dot,
Expand All @@ -245,6 +250,31 @@ def _event_matvec_jvp(

return r, dr

def _event_csr_matvec_transpose(ct,
values,
indices,
indptr,
events,
bool_param_list,
shape_list,
*,
outs):
if ad.is_undefined_primal(indices) or ad.is_undefined_primal(indptr):
raise ValueError("Cannot transpose with respect to sparse indices.")
if ad.is_undefined_primal(events):
ct_events = normal_csrmv_taichi(values, indices, indptr, ct[0], shape=shape_list, transpose = not bool_param_list[0])[0]
return values, indices, indptr, (ad.Zero(events) if type(ct[0]) is ad.Zero else ct_events)
else:
if type(ct[0]) is ad.Zero:
ct_values = ad.Zero(values)
else:
if values.aval.shape[0] == 1: # scalar
ct_values = csrmv_taichi(jnp.ones(1), indices, indptr, events, shape=shape_list, transpose = bool_param_list[0])[0]
ct_values = jnp.inner(ct[0], ct_values)
else: # heterogeneous values
row, col = csr_to_coo(indices, indptr)
ct_values = events[row] * ct[0][col] if bool_param_list[0] else events[col] * ct[0][row]
return ct_values, indices, indptr, events

def csrmv_taichi(
data: Union[float, jax.Array],
Expand Down Expand Up @@ -333,6 +363,8 @@ def csrmv_taichi(
else:
_event_csr_matvec_p = XLACustomOp(cpu_kernel=_event_csr_matvec_cpu,
gpu_kernel=_event_csr_matvec_gpu)
_event_csr_matvec_p.def_jvp_rule(_event_csr_matvec_jvp)
_event_csr_matvec_p.def_transpose_rule(_event_csr_matvec_transpose)

# computing
return _event_csr_matvec_p(data,
Expand Down
67 changes: 67 additions & 0 deletions brainpy/_src/math/event/tests/test_events_csrmv_taichi_grad.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
# -*- coding: utf-8 -*-


from functools import partial

import jax

import brainpy as bp
import brainpy.math as bm
import platform

import pytest

def sum_op(op):
def func(*args, **kwargs):
r = op(*args, **kwargs)
return r.sum()

return func

def sum_op2(op):
def func(*args, **kwargs):
r = op(*args, **kwargs)[0]
return r.sum()
return func

transposes = [True, False]
shapes = [(100, 200),
(200, 200),
(200, 100),
(10, 1000),
(2, 10000),
(1000, 10),
(100000, 2)]
homo_datas = [-1., 0., 1.]

def test_homo_grad(shape, transpose, homo_data):
print(f'test_homo_grad: shape = {shape}, transpose = {transpose}, homo_data = {homo_data}')

rng = bm.random.RandomState()
indices, indptr = bp.conn.FixedProb(0.4)(*shape).require('pre2post')
indices = bm.as_jax(indices)
indptr = bm.as_jax(indptr)
events = bm.as_jax(rng.random(shape[0] if transpose else shape[1])) < 0.1
dense_conn = bm.sparse.csr_to_dense(bm.ones(indices.shape).value, indices, indptr, shape=shape)

# grad 'data'
r1 = jax.grad(sum_op(bm.event.csrmv))(homo_data,
indices,
indptr,
events,
shape=shape,
transpose=transpose)

r2 = jax.grad(sum_op2(bm.event.csrmv_taichi))(homo_data,
indices,
indptr,
events,
shape=shape,
transpose=transpose)

assert(bm.allclose(r1, r2))

for transpose in transposes:
for shape in shapes:
for homo_data in homo_datas:
test_homo_grad(shape, transpose, homo_data)
50 changes: 48 additions & 2 deletions brainpy/_src/math/sparse/_csr_mv_taichi.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,8 @@
import numba
import numpy as np
import taichi as ti
from jax import core, dtypes
from jax import numpy as jnp
from jax.interpreters import ad, mlir, xla
from jax.interpreters import ad, xla
from jax.lib import xla_client
from jaxlib import gpu_sparse

Expand Down Expand Up @@ -116,6 +115,51 @@ def _sparse_csr_matvec_gpu(values: ti.types.ndarray(ndim=1),
r += values[j] * vector[col_indices[j]]
out[row_i] = r

def _sparse_csr_matvec_jvp(
primals, tangents,
):
values, col_indices, row_ptr, vector, shape, transpose = primals
values_dot, col_indices_dot, row_ptr_dot, vector_dot, shape_dot, transpose_dot = tangents

r = _event_csr_matvec_p(values,
col_indices,
row_ptr,
vector,
shape,
transpose,
outs=[jax.ShapeDtypeStruct(
shape=(shape[1] if transpose else shape[0],),
dtype=values.dtype)])

assert type(values_dot) is ad.Zero
assert type(col_indices_dot) is ad.Zero
assert type(row_ptr_dot) is ad.Zero
assert type(vector_dot) is ad.Zero

if type(values_dot) is ad.Zero:
if type(vector_dot) is ad.Zero:
raise ValueError
dr = _event_csr_matvec_p(values,
col_indices,
row_ptr,
vector_dot,
shape,
transpose,
outs=[jax.ShapeDtypeStruct(
shape=(shape[1] if transpose else shape[0],),
dtype=values.dtype)])
elif type(vector_dot) is ad.Zero:
dr = _event_csr_matvec_p(values_dot,
col_indices_dot,
row_ptr_dot,
vector,
shape,
transpose,
outs=[jax.ShapeDtypeStruct(
shape=(shape[1] if transpose else shape[0],),
dtype=values.dtype)])

return r, dr

def csrmv_taichi(
data: Union[float, jnp.ndarray, Array],
Expand Down Expand Up @@ -183,6 +227,8 @@ def csrmv_taichi(
else:
_event_csr_matvec_p = XLACustomOp(cpu_kernel=_sparse_csr_matvec_cpu,
gpu_kernel=_sparse_csr_matvec_gpu)

_event_csr_matvec_p.def_jvp_rule(_sparse_csr_matvec_jvp)

shape_list = jnp.array(shape)
is_transpose = jnp.array(transpose)
Expand Down

0 comments on commit 613f28b

Please sign in to comment.