Skip to content

Commit

Permalink
Update csr_matmat_VS_cusparse_csr_matmat.py
Browse files Browse the repository at this point in the history
  • Loading branch information
Routhleck committed Mar 8, 2024
1 parent f355d18 commit c92ace4
Showing 1 changed file with 92 additions and 12 deletions.
104 changes: 92 additions & 12 deletions brainpy/_src/math/sparse/tests/csr_matmat_VS_cusparse_csr_matmat.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,21 @@
# from jax_taichi import jax_taichi_call

import os
import time
from functools import partial
import os

import brainpy as bp
import brainpy.math as bm
import jax
import jax.numpy as jnp
import numpy as np
import pandas as pd
import taichi as ti
from jax.experimental.sparse import csr

bm.set_platform('cpu')
import brainpy as bp
import brainpy.math as bm
from brainpy._src.dependency_check import import_taichi
from brainpy._src.math.interoperability import as_jax
from brainpy._src.math.op_register import XLACustomOp

SPARSITY = 0.05
ti = import_taichi(error_if_not_found=False)

bm.set_platform('cpu')

size = [
(100, 100, 100),
Expand All @@ -39,27 +40,106 @@
True,
False
]

ITERATION = 10
SPARSITY = 0.05

if bm.get_platform() == 'cpu':
ITERATION = 3

print(bm.get_platform())


@ti.kernel
def _csr_matmat_transpose_homo_cpu(col_indices: ti.types.ndarray(ndim=1),
row_ptr: ti.types.ndarray(ndim=1),
matrix: ti.types.ndarray(ndim=2),
out: ti.types.ndarray(ndim=2)):
# matrix: (k, n)
# sparse matrix: (m, k)
n = out.shape[1]
m = row_ptr.shape[0] - 1
for j in range(n): # parallize along the n dimension
for row_i in range(m): # loop along the m dimension
for i in range(row_ptr[row_i], row_ptr[row_i + 1]):
out[col_indices[i], j] += matrix[row_i, j]


@ti.kernel
def _csr_matmat_transpose_homo_gpu(col_indices: ti.types.ndarray(ndim=1),
row_ptr: ti.types.ndarray(ndim=1),
matrix: ti.types.ndarray(ndim=2),
out: ti.types.ndarray(ndim=2)):
m = row_ptr.shape[0] - 1
n = matrix.shape[1]
for j, row_i in ti.ndrange(n, m): # paralleize along the (n and m) dimensions
for i in range(row_ptr[row_i], row_ptr[row_i + 1]):
out[col_indices[i], j] += matrix[row_i, j]


@ti.kernel
def _csr_matmat_homo(col_indices: ti.types.ndarray(ndim=1),
row_ptr: ti.types.ndarray(ndim=1),
matrix: ti.types.ndarray(ndim=2),
out: ti.types.ndarray(ndim=2)):
# matrix: (k, n)
# sparse matrix: (m, k)
m, n = out.shape
for row_i, col_k in ti.ndrange(m, n):
r = 0.
for row_j in range(row_ptr[row_i], row_ptr[row_i + 1]):
r += matrix[col_indices[row_j], col_k]
out[row_i, col_k] = r


# transpose homo
_csr_matmat_transpose_homo_p = XLACustomOp(cpu_kernel=_csr_matmat_transpose_homo_cpu,
gpu_kernel=_csr_matmat_transpose_homo_gpu)

# no transpose homo
_csr_matmat_homo_p = XLACustomOp(cpu_kernel=_csr_matmat_homo, gpu_kernel=_csr_matmat_homo)


def taichi_csrmm(weight, indices, indptr, matrix, shape, transpose):
indices = as_jax(indices)
indptr = as_jax(indptr)
matrix = as_jax(matrix)
weight = jnp.atleast_1d(weight)
out_shape = shape[1] if transpose else shape[0]
result_shape = (out_shape, matrix.shape[1])
if transpose:
prim = _csr_matmat_transpose_homo_p
else:
prim = _csr_matmat_homo_p
r = prim(indices,
indptr,
matrix,
outs=[jax.ShapeDtypeStruct(result_shape, dtype=matrix.dtype)],
transpose=transpose,
shape=shape)
return r[0] * weight


def jaxlib_csrmm(weight, indices, indptr, matrix, shape, transpose):
indices = as_jax(indices)
indptr = as_jax(indptr)
matrix = as_jax(matrix)
weight = jnp.atleast_1d(weight)
return csr.csr_matmat_p.bind(weight, indices, indptr, matrix, shape=shape, transpose=transpose)


@partial(jax.jit, static_argnums=(4, 5))
def csrmm_taichi(weight, indices, indptr, matrix, shape, transpose):
r = 0
for i in range(ITERATION):
r += bm.sparse.csrmm(weight, indices, indptr, matrix, shape=shape, transpose=transpose, method=None)
r += taichi_csrmm(weight, indices, indptr, matrix, shape=shape, transpose=transpose)
return r


@partial(jax.jit, static_argnums=(4, 5))
def csrmm(weight, indices, indptr, matrix, shape, transpose):
r = 0
for i in range(ITERATION):
r += bm.sparse.csrmm(weight, indices, indptr, matrix, shape=shape, transpose=transpose, method='jaxlib')
r += jaxlib_csrmm(weight, indices, indptr, matrix, shape=shape, transpose=transpose)
return r


Expand Down

0 comments on commit c92ace4

Please sign in to comment.