From 4596c5e2c5a10a564456c8f8eea5b961e022d6d0 Mon Sep 17 00:00:00 2001 From: He Sichao <1310722434@qq.com> Date: Sat, 8 Jun 2024 08:10:30 +0800 Subject: [PATCH] Update --- brainpy/_src/math/event/csr_matvec.py | 66 +++++++++++++++---- .../math/event/tests/event_dW_VS_normal_dW.py | 14 ++-- .../_src/math/event/tests/test_event_csrmv.py | 5 +- 3 files changed, 64 insertions(+), 21 deletions(-) diff --git a/brainpy/_src/math/event/csr_matvec.py b/brainpy/_src/math/event/csr_matvec.py index 5dda9504..d579f14b 100644 --- a/brainpy/_src/math/event/csr_matvec.py +++ b/brainpy/_src/math/event/csr_matvec.py @@ -131,7 +131,10 @@ def raw_csrmv_taichi( else: prim = _event_csrmv_transpose_bool_heter_p else: - return normal_csrmv_taichi(data, indices, indptr, events, shape=shape, transpose=transpose) + if data.shape[0] == 1: + prim = _event_csrmv_transpose_homo_p + else: + prim = _event_csrmv_transpose_heter_p else: if events.dtype == jnp.bool_: if data.shape[0] == 1: @@ -139,7 +142,10 @@ def raw_csrmv_taichi( else: prim = _event_csrmv_bool_heter_p else: - return normal_csrmv_taichi(data, indices, indptr, events, shape=shape, transpose=transpose) + if data.shape[0] == 1: + prim = _event_csrmv_homo_p + else: + prim = _event_csrmv_heter_p # computing return prim(data, @@ -442,6 +448,30 @@ def _event_csr_matvec_heter_gpu(values: ti.types.ndarray(ndim=1), out[row_i] += r # TODO: warp-level primitive + @ti.kernel + def _event_csr_matvec_dW_transpose(values: ti.types.ndarray(), + indices: ti.types.ndarray(), + indptr: ti.types.ndarray(), + events: ti.types.ndarray(), + out: ti.types.ndarray()): + for i in range(events.shape[0]): + if events[i] != 0.: + for j in range(indptr[i], indptr[i + 1]): + out[j] = values[indices[j]] + + + @ti.kernel + def _event_csr_matvec_dW_bool_transpose(values: ti.types.ndarray(), + indices: ti.types.ndarray(), + indptr: ti.types.ndarray(), + events: ti.types.ndarray(), + out: ti.types.ndarray()): + for i in range(events.shape[0]): + if events[i]: + for j in range(indptr[i], indptr[i + 1]): + out[j] = values[indices[j]] + + @ti.kernel def _event_csr_matvec_dW(values: ti.types.ndarray(), indices: ti.types.ndarray(), @@ -451,7 +481,7 @@ def _event_csr_matvec_dW(values: ti.types.ndarray(), for i in range(events.shape[0]): if events[i] != 0.: for j in range(indptr[i], indptr[i + 1]): - out[j] = values[indices[j]] + out[indices[j]] += values[j] * events[i] @ti.kernel @@ -463,11 +493,11 @@ def _event_csr_matvec_dW_bool(values: ti.types.ndarray(), for i in range(events.shape[0]): if events[i]: for j in range(indptr[i], indptr[i + 1]): - out[j] = values[indices[j]] + out[indices[j]] += values[j] def _event_csr_matvec_jvp_values_taichi(val_dot, values, indices, indptr, events, *, outs, transpose, shape): - return normal_csrmv_taichi(val_dot, indices, indptr, events, shape=shape, transpose=transpose) + return raw_csrmv_taichi(val_dot, indices, indptr, events, shape=shape, transpose=transpose) def _event_csr_matvec_jvp_events_taichi(evt_dot, values, indices, indptr, events, *, outs, transpose, shape): @@ -492,12 +522,18 @@ def _event_csr_matvec_transpose_taichi( else: # heterogeneous values # row, col = csr_to_coo(indices, indptr) # ct_values = events[row] * ct[0][col] if transpose else events[col] * ct[0][row] - if events.dtype == jnp.bool_: - ct_values = _event_csr_matvec_dW_bool(ct[0], indices, indptr, events, - outs=[jax.ShapeDtypeStruct((values.shape[0],), values.dtype)])[0] + if transpose: + if events.dtype == jnp.bool_: + prim = _event_csrmv_dW_bool_p_transpose + else: + prim = _event_csrmv_dW_p_transpose else: - ct_values = _event_csr_matvec_dW(ct[0], indices, indptr, events, - outs=[jax.ShapeDtypeStruct((values.shape[0],), values.dtype)])[0] + if events.dtype == jnp.bool_: + prim = _event_csrmv_dW_bool_p + else: + prim = _event_csrmv_dW_p + ct_values = prim(ct[0], indices, indptr, events, + outs=[jax.ShapeDtypeStruct((values.aval.shape[0],), values.aval.dtype)])[0] return ct_values, indices, indptr, events @@ -540,10 +576,18 @@ def _define_op(cpu_kernel, gpu_kernel): _event_csrmv_heter_p = _define_op(_event_csr_matvec_heter_cpu, _event_csr_matvec_heter_gpu) + # compute dW transpose + _event_csrmv_dW_p_transpose = XLACustomOp(cpu_kernel=_event_csr_matvec_dW_transpose, + gpu_kernel=_event_csr_matvec_dW_transpose) + + # compute dW bool transpose + _event_csrmv_dW_bool_p_transpose = XLACustomOp(cpu_kernel=_event_csr_matvec_dW_bool_transpose, + gpu_kernel=_event_csr_matvec_dW_bool_transpose) + # compute dW _event_csrmv_dW_p = XLACustomOp(cpu_kernel=_event_csr_matvec_dW, gpu_kernel=_event_csr_matvec_dW) # compute dW bool _event_csrmv_dW_bool_p = XLACustomOp(cpu_kernel=_event_csr_matvec_dW_bool, - gpu_kernel=_event_csr_matvec_dW_bool) + gpu_kernel=_event_csr_matvec_dW_bool) \ No newline at end of file diff --git a/brainpy/_src/math/event/tests/event_dW_VS_normal_dW.py b/brainpy/_src/math/event/tests/event_dW_VS_normal_dW.py index a99c74f7..b5601696 100644 --- a/brainpy/_src/math/event/tests/event_dW_VS_normal_dW.py +++ b/brainpy/_src/math/event/tests/event_dW_VS_normal_dW.py @@ -1,16 +1,14 @@ # from jax_taichi import jax_taichi_call +import os import time from functools import partial -import os -import brainpy as bp -import brainpy.math as bm import jax -import jax.numpy as jnp -import numpy as np import pandas as pd -import taichi as ti + +import brainpy as bp +import brainpy.math as bm bm.set_platform('cpu') @@ -80,7 +78,7 @@ def normal_csrmv_grad(weight, indices, indptr, vector, shape, transpose): return r -def test_event_csrmv_dW(shape, values_type, events_type, transpose): +def _test_event_csrmv_dW(shape, values_type, events_type, transpose): rng = bm.random.RandomState(1234) indices, indptr = bp.conn.FixedProb(p, seed=1234, allow_multi_conn=True)(*shape).require('pre2post') vector = rng.random(shape[0] if transpose else shape[1]) < 0.1 @@ -259,7 +257,7 @@ def test_event_csrmv_dW(shape, values_type, events_type, transpose): event_time_1, event_time_2, event_time_3, event_time_4, event_time_5, \ event_time_6, event_time_7, event_time_8, event_time_9, event_time_10, \ normal_time_1, normal_time_2, normal_time_3, normal_time_4, normal_time_5, \ - normal_time_6, normal_time_7, normal_time_8, normal_time_9, normal_time_10 = test_event_csrmv_dW( + normal_time_6, normal_time_7, normal_time_8, normal_time_9, normal_time_10 = _test_event_csrmv_dW( (shape1, shape2), _values_type, _events_type, _transpose) # append to dataframe df.loc[df.shape[0]] = [(shape1, shape2), 0.5, shape1, shape2, bm.get_platform(), _values_type, _events_type, _transpose, diff --git a/brainpy/_src/math/event/tests/test_event_csrmv.py b/brainpy/_src/math/event/tests/test_event_csrmv.py index 181ee552..1feef052 100644 --- a/brainpy/_src/math/event/tests/test_event_csrmv.py +++ b/brainpy/_src/math/event/tests/test_event_csrmv.py @@ -16,8 +16,8 @@ import platform force_test = False # turn on to force test on windows locally -if platform.system() == 'Windows' and not force_test: - pytest.skip('skip windows', allow_module_level=True) +# if platform.system() == 'Windows' and not force_test: +# pytest.skip('skip windows', allow_module_level=True) seed = 1234 @@ -218,6 +218,7 @@ def test_heter_grad(self, shape, transpose): data, indices, indptr, events, shape=shape, transpose=transpose) r2 = jax.grad(sum_op(bm.event.csrmv))( data, indices, indptr, events, shape=shape, transpose=transpose) + print(r1 - r2) self.assertTrue(bm.allclose(r1, r2)) # grad 'events'