diff --git a/brainpy/_src/math/event/csr_matmat.py b/brainpy/_src/math/event/csr_matmat.py index 8cdc750b..da3902c8 100644 --- a/brainpy/_src/math/event/csr_matmat.py +++ b/brainpy/_src/math/event/csr_matmat.py @@ -234,33 +234,33 @@ def _event_csr_matmat_bool_homo(values: ti.types.ndarray(ndim=1), @ti.kernel -def _event_csr_matmat_sum(values: ti.types.ndarray(ndim=1), - col_indices: ti.types.ndarray(ndim=1), - row_ptr: ti.types.ndarray(ndim=1), - matrix: ti.types.ndarray(ndim=2), - out: ti.types.ndarray(ndim=1)): - for row_i in ti.ndrange(out.shape[0]): +def _event_csr_matmat_dw(dy: ti.types.ndarray(ndim=1), + col_indices: ti.types.ndarray(ndim=1), + row_ptr: ti.types.ndarray(ndim=1), + matrix: ti.types.ndarray(ndim=2), + dw: ti.types.ndarray(ndim=1)): + for row_i in ti.ndrange(dw.shape[0]): temp_sum = 0. for row_j in range(row_ptr[row_i], row_ptr[row_i + 1]): for col_k in range(matrix.shape[1]): if matrix[col_indices[row_j], col_k] != 0.: - temp_sum += values[row_j] * matrix[col_indices[row_j], col_k] - out[row_i] = temp_sum + temp_sum += dy[row_j] * matrix[col_indices[row_j], col_k] + dw[row_i] = temp_sum @ti.kernel -def _event_csr_matmat_bool_sum(values: ti.types.ndarray(ndim=1), - col_indices: ti.types.ndarray(ndim=1), - row_ptr: ti.types.ndarray(ndim=1), - matrix: ti.types.ndarray(ndim=2), - out: ti.types.ndarray(ndim=1)): - for row_i in ti.ndrange(out.shape[0]): +def _event_csr_matmat_bool_dw(dy: ti.types.ndarray(ndim=1), + col_indices: ti.types.ndarray(ndim=1), + row_ptr: ti.types.ndarray(ndim=1), + matrix: ti.types.ndarray(ndim=2), + dw: ti.types.ndarray(ndim=1)): + for row_i in ti.ndrange(dw.shape[0]): temp_sum = 0. for row_j in range(row_ptr[row_i], row_ptr[row_i + 1]): for col_k in range(matrix.shape[1]): if matrix[col_indices[row_j], col_k]: - temp_sum += values[row_j] * matrix[col_indices[row_j], col_k] - out[row_i] = temp_sum + temp_sum += dy[row_j] * matrix[col_indices[row_j], col_k] + dw[row_i] = temp_sum def _event_csr_matmat_jvp_values(val_dot, values, col_indices, row_ptr, matrix, *, outs, transpose, shape): @@ -291,13 +291,11 @@ def _event_csr_matmat_transpose( else: # heter matrix = jnp.asarray(matrix) if matrix.dtype == jnp.bool_: - prim = _event_csr_matmat_bool_sum + prim = _event_csr_matmat_bool_dw else: - prim = _event_csr_matmat_sum + prim = _event_csr_matmat_dw ct_data = prim(ct[0], indices, indptr, matrix, out=jax.ShapeDtypeStruct((data.aval.shape[0], matrix.shape[1]), data.aval.dtype))[0] - # row, col = csr_to_coo(indices, indptr) - # ct_data = (ct[0][row] * matrix[col]).sum(1) return ct_data, indices, indptr, matrix @@ -340,11 +338,11 @@ def _define_op(cpu_kernel, gpu_kernel): _event_csr_matmat_bool_homo_p = _define_op(cpu_kernel=_event_csr_matmat_bool_homo, gpu_kernel=_event_csr_matmat_bool_homo) -_event_csr_matmat_sum_p = XLACustomOp(cpu_kernel=_event_csr_matmat_sum, - gpu_kernel=_event_csr_matmat_sum) +_event_csr_matmat_dw_p = XLACustomOp(cpu_kernel=_event_csr_matmat_dw, + gpu_kernel=_event_csr_matmat_dw) -_event_csr_matmat_bool_sum_p = XLACustomOp(cpu_kernel=_event_csr_matmat_bool_sum, - gpu_kernel=_event_csr_matmat_bool_sum) +_event_csr_matmat_bool_dw_p = XLACustomOp(cpu_kernel=_event_csr_matmat_bool_dw, + gpu_kernel=_event_csr_matmat_bool_dw) # heter CUSPARSE _csr_matmat_cusparse_p = csr.csr_matmat_p