Skip to content

Commit

Permalink
Update new transpose taichi kernels of csrmm and event csrmm
Browse files Browse the repository at this point in the history
  • Loading branch information
Routhleck committed Mar 7, 2024
1 parent 861b340 commit fff64db
Show file tree
Hide file tree
Showing 2 changed files with 93 additions and 323 deletions.
270 changes: 59 additions & 211 deletions brainpy/_src/math/event/csr_matmat.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,21 +93,15 @@ def raw_event_csrmm_taichi(
else:
prim = _event_csr_matmat_transpose_bool_heter_p
else:
if data.shape[0] == 1:
prim = _event_csr_matmat_transpose_homo_p
else:
prim = _event_csr_matmat_transpose_heter_p
return normal_csrmm(data, indices, indptr, matrix, shape=shape, transpose=transpose)
else:
if matrix.dtype == jnp.bool_:
if data.shape[0] == 1:
prim = _event_csr_matmat_bool_homo_p
else:
prim = _event_csr_matmat_bool_heter_p
else:
if data.shape[0] == 1:
prim = _event_csr_matmat_homo_p
else:
prim = _event_csr_matmat_heter_p
return normal_csrmm(data, indices, indptr, matrix, shape=shape, transpose=transpose)
return prim(data,
indices,
indptr,
Expand All @@ -117,50 +111,42 @@ def raw_event_csrmm_taichi(
shape=shape)


# CPU kernels
# taichi kernels

@ti.kernel
def _event_csr_matmat_transpose_heter_cpu(values: ti.types.ndarray(ndim=1),
col_indices: ti.types.ndarray(ndim=1),
row_ptr: ti.types.ndarray(ndim=1),
matrix: ti.types.ndarray(ndim=2),
out: ti.types.ndarray(ndim=2)):
def _event_csr_matmat_transpose_heter(values: ti.types.ndarray(ndim=1),
col_indices: ti.types.ndarray(ndim=1),
row_ptr: ti.types.ndarray(ndim=1),
matrix: ti.types.ndarray(ndim=2),
out: ti.types.ndarray(ndim=2)):
for col_i, row_k in ti.ndrange(out.shape[1], out.shape[0]):
r = 0.
for row_j in range(matrix.shape[0]):
if matrix[row_j, col_i] != 0.:
val = 0.
for j in range(row_ptr[row_j], row_ptr[row_j + 1]):
if col_indices[j] == row_k:
val = values[j]
r += val * matrix[row_j, col_i]
out[row_k, col_i] = r
out[row_k, col_i] += values[j] * matrix[row_j, col_i]


@ti.kernel
def _event_csr_matmat_transpose_bool_heter_cpu(values: ti.types.ndarray(ndim=1),
col_indices: ti.types.ndarray(ndim=1),
row_ptr: ti.types.ndarray(ndim=1),
matrix: ti.types.ndarray(ndim=2),
out: ti.types.ndarray(ndim=2)):
def _event_csr_matmat_transpose_bool_heter(values: ti.types.ndarray(ndim=1),
col_indices: ti.types.ndarray(ndim=1),
row_ptr: ti.types.ndarray(ndim=1),
matrix: ti.types.ndarray(ndim=2),
out: ti.types.ndarray(ndim=2)):
for col_i, row_k in ti.ndrange(out.shape[1], out.shape[0]):
r = 0.
for row_j in range(matrix.shape[0]):
if matrix[row_j, col_i]:
val = 0.
for j in range(row_ptr[row_j], row_ptr[row_j + 1]):
if col_indices[j] == row_k:
val = values[j]
r += val * matrix[row_j, col_i]
out[row_k, col_i] = r
out[row_k, col_i] += values[j] * matrix[row_j, col_i]


@ti.kernel
def _event_csr_matmat_heter_cpu(values: ti.types.ndarray(ndim=1),
col_indices: ti.types.ndarray(ndim=1),
row_ptr: ti.types.ndarray(ndim=1),
matrix: ti.types.ndarray(ndim=2),
out: ti.types.ndarray(ndim=2)):
def _event_csr_matmat_heter(values: ti.types.ndarray(ndim=1),
col_indices: ti.types.ndarray(ndim=1),
row_ptr: ti.types.ndarray(ndim=1),
matrix: ti.types.ndarray(ndim=2),
out: ti.types.ndarray(ndim=2)):
for row_i, col_k in ti.ndrange(out.shape[0], out.shape[1]):
r = 0.
for row_j in range(row_ptr[row_i], row_ptr[row_i + 1]):
Expand All @@ -170,11 +156,11 @@ def _event_csr_matmat_heter_cpu(values: ti.types.ndarray(ndim=1),


@ti.kernel
def _event_csr_matmat_bool_heter_cpu(values: ti.types.ndarray(ndim=1),
col_indices: ti.types.ndarray(ndim=1),
row_ptr: ti.types.ndarray(ndim=1),
matrix: ti.types.ndarray(ndim=2),
out: ti.types.ndarray(ndim=2)):
def _event_csr_matmat_bool_heter(values: ti.types.ndarray(ndim=1),
col_indices: ti.types.ndarray(ndim=1),
row_ptr: ti.types.ndarray(ndim=1),
matrix: ti.types.ndarray(ndim=2),
out: ti.types.ndarray(ndim=2)):
for row_i, col_k in ti.ndrange(out.shape[0], out.shape[1]):
r = 0.
for row_j in range(row_ptr[row_i], row_ptr[row_i + 1]):
Expand All @@ -184,179 +170,41 @@ def _event_csr_matmat_bool_heter_cpu(values: ti.types.ndarray(ndim=1),


@ti.kernel
def _event_csr_matmat_transpose_homo_cpu(values: ti.types.ndarray(ndim=1),
col_indices: ti.types.ndarray(ndim=1),
row_ptr: ti.types.ndarray(ndim=1),
matrix: ti.types.ndarray(ndim=2),
out: ti.types.ndarray(ndim=2)):
value = values[0]
for col_i, row_k in ti.ndrange(out.shape[1], out.shape[0]):
r = 0.
for row_j in range(matrix.shape[0]):
if matrix[row_j, col_i] != 0.:
for j in range(row_ptr[row_j], row_ptr[row_j + 1]):
if col_indices[j] == row_k:
r += matrix[row_j, col_i]
break
out[row_k, col_i] = r * value


@ti.kernel
def _event_csr_matmat_transpose_bool_homo_cpu(values: ti.types.ndarray(ndim=1),
col_indices: ti.types.ndarray(ndim=1),
row_ptr: ti.types.ndarray(ndim=1),
matrix: ti.types.ndarray(ndim=2),
out: ti.types.ndarray(ndim=2)):
value = values[0]
for col_i, row_k in ti.ndrange(out.shape[1], out.shape[0]):
r = 0.
for row_j in range(matrix.shape[0]):
if matrix[row_j, col_i]:
for j in range(row_ptr[row_j], row_ptr[row_j + 1]):
if col_indices[j] == row_k:
r += matrix[row_j, col_i]
break
out[row_k, col_i] = r * value


@ti.kernel
def _event_csr_matmat_homo_cpu(values: ti.types.ndarray(ndim=1),
col_indices: ti.types.ndarray(ndim=1),
row_ptr: ti.types.ndarray(ndim=1),
matrix: ti.types.ndarray(ndim=2),
out: ti.types.ndarray(ndim=2)):
value = values[0]
for row_i, col_k in ti.ndrange(out.shape[0], out.shape[1]):
r = 0.
for row_j in range(row_ptr[row_i], row_ptr[row_i + 1]):
if matrix[col_indices[row_j], col_k] != 0.:
r += matrix[col_indices[row_j], col_k]
out[row_i, col_k] = r * value


@ti.kernel
def _event_csr_matmat_bool_homo_cpu(values: ti.types.ndarray(ndim=1),
col_indices: ti.types.ndarray(ndim=1),
row_ptr: ti.types.ndarray(ndim=1),
matrix: ti.types.ndarray(ndim=2),
out: ti.types.ndarray(ndim=2)):
value = values[0]
for row_i, col_k in ti.ndrange(out.shape[0], out.shape[1]):
r = 0.
for row_j in range(row_ptr[row_i], row_ptr[row_i + 1]):
if matrix[col_indices[row_j], col_k]:
r += matrix[col_indices[row_j], col_k]
out[row_i, col_k] = r * value


# GPU kernels

@ti.kernel
def _event_csr_matmat_transpose_heter_gpu(values: ti.types.ndarray(ndim=1),
col_indices: ti.types.ndarray(ndim=1),
row_ptr: ti.types.ndarray(ndim=1),
matrix: ti.types.ndarray(ndim=2),
out: ti.types.ndarray(ndim=2)):
for col_i, row_k in ti.ndrange(out.shape[1], out.shape[0]):
r = 0.
for row_j in range(matrix.shape[0]):
if matrix[row_j, col_i] != 0.:
val = 0.
for j in range(row_ptr[row_j], row_ptr[row_j + 1]):
if col_indices[j] == row_k:
val = values[j] * matrix[row_j, col_i]
r += val
out[row_k, col_i] = r


@ti.kernel
def _event_csr_matmat_transpose_bool_heter_gpu(values: ti.types.ndarray(ndim=1),
col_indices: ti.types.ndarray(ndim=1),
row_ptr: ti.types.ndarray(ndim=1),
matrix: ti.types.ndarray(ndim=2),
out: ti.types.ndarray(ndim=2)):
for col_i, row_k in ti.ndrange(out.shape[1], out.shape[0]):
r = 0.
for row_j in range(matrix.shape[0]):
if matrix[row_j, col_i]:
val = 0.
for j in range(row_ptr[row_j], row_ptr[row_j + 1]):
if col_indices[j] == row_k:
val = values[j] * matrix[row_j, col_i]
r += val
out[row_k, col_i] = r


@ti.kernel
def _event_csr_matmat_heter_gpu(values: ti.types.ndarray(ndim=1),
col_indices: ti.types.ndarray(ndim=1),
row_ptr: ti.types.ndarray(ndim=1),
matrix: ti.types.ndarray(ndim=2),
out: ti.types.ndarray(ndim=2)):
for row_i, col_k in ti.ndrange(out.shape[0], out.shape[1]):
r = 0.
for row_j in range(row_ptr[row_i], row_ptr[row_i + 1]):
if matrix[col_indices[row_j], col_k] != 0.:
r += values[row_j] * matrix[col_indices[row_j], col_k]
out[row_i, col_k] = r


@ti.kernel
def _event_csr_matmat_bool_heter_gpu(values: ti.types.ndarray(ndim=1),
def _event_csr_matmat_transpose_homo(values: ti.types.ndarray(ndim=1),
col_indices: ti.types.ndarray(ndim=1),
row_ptr: ti.types.ndarray(ndim=1),
matrix: ti.types.ndarray(ndim=2),
out: ti.types.ndarray(ndim=2)):
for row_i, col_k in ti.ndrange(out.shape[0], out.shape[1]):
r = 0.
for row_j in range(row_ptr[row_i], row_ptr[row_i + 1]):
if matrix[col_indices[row_j], col_k]:
r += values[row_j] * matrix[col_indices[row_j], col_k]
out[row_i, col_k] = r


@ti.kernel
def _event_csr_matmat_transpose_homo_gpu(values: ti.types.ndarray(ndim=1),
col_indices: ti.types.ndarray(ndim=1),
row_ptr: ti.types.ndarray(ndim=1),
matrix: ti.types.ndarray(ndim=2),
out: ti.types.ndarray(ndim=2)):
value = values[0]
for col_i, row_k in ti.ndrange(out.shape[1], out.shape[0]):
r = 0.
for row_j in range(matrix.shape[0]):
if matrix[row_j, col_i] != 0.:
for j in range(row_ptr[row_j], row_ptr[row_j + 1]):
if col_indices[j] == row_k:
r += matrix[row_j, col_i]
break
out[row_k, col_i] = r * value
out[row_k, col_i] += value * matrix[row_j, col_i]


@ti.kernel
def _event_csr_matmat_transpose_bool_homo_gpu(values: ti.types.ndarray(ndim=1),
col_indices: ti.types.ndarray(ndim=1),
row_ptr: ti.types.ndarray(ndim=1),
matrix: ti.types.ndarray(ndim=2),
out: ti.types.ndarray(ndim=2)):
def _event_csr_matmat_transpose_bool_homo(values: ti.types.ndarray(ndim=1),
col_indices: ti.types.ndarray(ndim=1),
row_ptr: ti.types.ndarray(ndim=1),
matrix: ti.types.ndarray(ndim=2),
out: ti.types.ndarray(ndim=2)):
value = values[0]
for col_i, row_k in ti.ndrange(out.shape[1], out.shape[0]):
r = 0.
for row_j in range(matrix.shape[0]):
if matrix[row_j, col_i]:
for j in range(row_ptr[row_j], row_ptr[row_j + 1]):
if col_indices[j] == row_k:
r += matrix[row_j, col_i]
break
out[row_k, col_i] = r * value
out[row_k, col_i] += value * matrix[row_j, col_i]


@ti.kernel
def _event_csr_matmat_homo_gpu(values: ti.types.ndarray(ndim=1),
col_indices: ti.types.ndarray(ndim=1),
row_ptr: ti.types.ndarray(ndim=1),
matrix: ti.types.ndarray(ndim=2),
out: ti.types.ndarray(ndim=2)):
def _event_csr_matmat_homo(values: ti.types.ndarray(ndim=1),
col_indices: ti.types.ndarray(ndim=1),
row_ptr: ti.types.ndarray(ndim=1),
matrix: ti.types.ndarray(ndim=2),
out: ti.types.ndarray(ndim=2)):
value = values[0]
for row_i, col_k in ti.ndrange(out.shape[0], out.shape[1]):
r = 0.
Expand All @@ -367,11 +215,11 @@ def _event_csr_matmat_homo_gpu(values: ti.types.ndarray(ndim=1),


@ti.kernel
def _event_csr_matmat_bool_homo_gpu(values: ti.types.ndarray(ndim=1),
col_indices: ti.types.ndarray(ndim=1),
row_ptr: ti.types.ndarray(ndim=1),
matrix: ti.types.ndarray(ndim=2),
out: ti.types.ndarray(ndim=2)):
def _event_csr_matmat_bool_homo(values: ti.types.ndarray(ndim=1),
col_indices: ti.types.ndarray(ndim=1),
row_ptr: ti.types.ndarray(ndim=1),
matrix: ti.types.ndarray(ndim=2),
out: ti.types.ndarray(ndim=2)):
value = values[0]
for row_i, col_k in ti.ndrange(out.shape[0], out.shape[1]):
r = 0.
Expand Down Expand Up @@ -421,33 +269,33 @@ def _define_op(cpu_kernel, gpu_kernel):


# transpose heter
_event_csr_matmat_transpose_heter_p = _define_op(cpu_kernel=_event_csr_matmat_transpose_heter_cpu,
gpu_kernel=_event_csr_matmat_transpose_heter_gpu)
_event_csr_matmat_transpose_heter_p = _define_op(cpu_kernel=_event_csr_matmat_transpose_heter,
gpu_kernel=_event_csr_matmat_transpose_heter)

# no transpose heter
_event_csr_matmat_heter_p = _define_op(cpu_kernel=_event_csr_matmat_heter_cpu,
gpu_kernel=_event_csr_matmat_heter_gpu)
_event_csr_matmat_heter_p = _define_op(cpu_kernel=_event_csr_matmat_heter,
gpu_kernel=_event_csr_matmat_heter)

# transpose homo
_event_csr_matmat_transpose_homo_p = _define_op(cpu_kernel=_event_csr_matmat_transpose_homo_cpu,
gpu_kernel=_event_csr_matmat_transpose_homo_gpu)
_event_csr_matmat_transpose_homo_p = _define_op(cpu_kernel=_event_csr_matmat_transpose_homo,
gpu_kernel=_event_csr_matmat_transpose_homo)

# no transpose homo
_event_csr_matmat_homo_p = _define_op(cpu_kernel=_event_csr_matmat_homo_cpu,
gpu_kernel=_event_csr_matmat_homo_gpu)
_event_csr_matmat_homo_p = _define_op(cpu_kernel=_event_csr_matmat_homo,
gpu_kernel=_event_csr_matmat_homo)

# bool transpose heter
_event_csr_matmat_transpose_bool_heter_p = _define_op(cpu_kernel=_event_csr_matmat_transpose_bool_heter_cpu,
gpu_kernel=_event_csr_matmat_transpose_bool_heter_gpu)
_event_csr_matmat_transpose_bool_heter_p = _define_op(cpu_kernel=_event_csr_matmat_transpose_bool_heter,
gpu_kernel=_event_csr_matmat_transpose_bool_heter)

# bool no transpose heter
_event_csr_matmat_bool_heter_p = _define_op(cpu_kernel=_event_csr_matmat_bool_heter_cpu,
gpu_kernel=_event_csr_matmat_bool_heter_gpu)
_event_csr_matmat_bool_heter_p = _define_op(cpu_kernel=_event_csr_matmat_bool_heter,
gpu_kernel=_event_csr_matmat_bool_heter)

# bool transpose homo
_event_csr_matmat_transpose_bool_homo_p = _define_op(cpu_kernel=_event_csr_matmat_transpose_bool_homo_cpu,
gpu_kernel=_event_csr_matmat_transpose_bool_homo_gpu)
_event_csr_matmat_transpose_bool_homo_p = _define_op(cpu_kernel=_event_csr_matmat_transpose_bool_homo,
gpu_kernel=_event_csr_matmat_transpose_bool_homo)

# bool no transpose homo
_event_csr_matmat_bool_homo_p = _define_op(cpu_kernel=_event_csr_matmat_bool_homo_cpu,
gpu_kernel=_event_csr_matmat_bool_homo_gpu)
_event_csr_matmat_bool_homo_p = _define_op(cpu_kernel=_event_csr_matmat_bool_homo,
gpu_kernel=_event_csr_matmat_bool_homo)
Loading

0 comments on commit fff64db

Please sign in to comment.