Skip to content

Commit

Permalink
format the code
Browse files Browse the repository at this point in the history
  • Loading branch information
chaoming0625 committed Jan 29, 2024
1 parent 12d045d commit e1f4005
Show file tree
Hide file tree
Showing 4 changed files with 48 additions and 29 deletions.
23 changes: 13 additions & 10 deletions brainpy/_src/math/event/_csr_matvec.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,9 @@
"""


from functools import partial
from typing import Union, Tuple

import brainpy.math as bm
import jax
import jax.numpy as jnp
import numba
Expand All @@ -23,6 +21,7 @@
from jax.interpreters import ad, xla
from jax.lib import xla_client

from brainpy._src.dependency_check import (import_brainpylib_gpu_ops)
from brainpy._src.dependency_check import import_taichi
from brainpy._src.math.interoperability import as_jax
from brainpy._src.math.op_register import (compile_cpu_signature_with_numba,
Expand All @@ -31,7 +30,6 @@
from brainpy._src.math.sparse._csr_mv import csrmv_brainpylib as normal_csrmv
from brainpy._src.math.sparse._csr_mv import raw_csrmv_taichi as normal_csrmv_taichi
from brainpy._src.math.sparse._utils import csr_to_coo
from brainpy._src.dependency_check import (import_brainpylib_gpu_ops)
from brainpy.errors import GPUOperatorNotFound

__all__ = [
Expand Down Expand Up @@ -159,6 +157,7 @@ def csrmv_brainpylib(
# computing
return event_csr_matvec_p.bind(data, indices, indptr, events, shape=shape, transpose=transpose)


# ----------------------------------------------------------
# event csr matvec
# ----------------------------------------------------------
Expand Down Expand Up @@ -600,9 +599,12 @@ def _event_csr_matvec_transpose_brainpylib(ct, values, indices, indptr, events,
event_csr_matvec_p.def_impl(partial(xla.apply_primitive, event_csr_matvec_p))
xla.backend_specific_translations['cpu'][event_csr_matvec_p] = _event_csr_matvec_cpu_translation
xla.backend_specific_translations['gpu'][event_csr_matvec_p] = _event_csr_matvec_gpu_translation
ad.defjvp(event_csr_matvec_p, _event_csr_matvec_jvp_values_brainpylib, None, None, _event_csr_matvec_jvp_events_brainpylib)
ad.defjvp(event_csr_matvec_p, _event_csr_matvec_jvp_values_brainpylib, None, None,
_event_csr_matvec_jvp_events_brainpylib)
ad.primitive_transposes[event_csr_matvec_p] = _event_csr_matvec_transpose_brainpylib
register_general_batching(event_csr_matvec_p)


# batching.primitive_batchers[event_csr_matvec_p] = _event_csr_matvec_batching_rule


Expand Down Expand Up @@ -688,6 +690,7 @@ def csrmv_taichi(

return raw_csrmv_taichi(data, indices, indptr, events, shape=shape, transpose=transpose)[0]


# -------------
# CPU operators
# -------------
Expand Down Expand Up @@ -958,7 +961,7 @@ def _event_csr_matvec_bool_heter_gpu(values: ti.types.ndarray(ndim=1),
if events[indices[j]]:
r += values[j]
j += 32
out[row_i] += r # TODO: warp-level primitive
out[row_i] += r # TODO: warp-level primitive


@ti.kernel
Expand All @@ -977,7 +980,8 @@ def _event_csr_matvec_heter_gpu(values: ti.types.ndarray(ndim=1),
if events[indices[j]] != 0.:
r += values[j]
j += 32
out[row_i] += r # TODO: warp-level primitive
out[row_i] += r # TODO: warp-level primitive


def raw_csrmv_taichi(
data: Union[float, jax.Array],
Expand Down Expand Up @@ -1020,6 +1024,7 @@ def raw_csrmv_taichi(
transpose=transpose,
shape=shape)


def _event_csr_matvec_jvp_values_taichi(val_dot, values, indices, indptr, events, *, outs, transpose, shape):
  """JVP rule w.r.t. the CSR ``values`` operand.

  The matvec is linear in the nonzero values, so the output tangent is the
  plain (non-event) CSR matvec evaluated with the tangent values ``val_dot``
  and the unchanged sparsity pattern and events.
  """
  # ``values`` and ``outs`` are required by the defjvp signature but unused.
  return normal_csrmv_taichi(val_dot, indices, indptr, events,
                             shape=shape, transpose=transpose)

Expand Down Expand Up @@ -1047,6 +1052,8 @@ def _event_csr_matvec_transpose_taichi(
row, col = csr_to_coo(indices, indptr)
ct_values = events[row] * ct[0][col] if transpose else events[col] * ct[0][row]
return ct_values, indices, indptr, events


def _define_op(cpu_kernel, gpu_kernel):
prim = XLACustomOp(cpu_kernel=cpu_kernel, gpu_kernel=gpu_kernel)
prim.defjvp(_event_csr_matvec_jvp_values_taichi, None, None, _event_csr_matvec_jvp_events_taichi)
Expand Down Expand Up @@ -1080,7 +1087,3 @@ def _define_op(cpu_kernel, gpu_kernel):

# not transpose heter
_event_csrmv_heter_p = _define_op(_event_csr_matvec_heter_cpu, _event_csr_matvec_heter_gpu)




28 changes: 19 additions & 9 deletions brainpy/_src/math/jitconn/_event_matvec.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
from functools import partial
from typing import Tuple, Optional

import brainpy.math as bm
import jax
import numpy as np
from jax import numpy as jnp, dtypes
Expand All @@ -19,12 +18,12 @@
mv_prob_homo,
mv_prob_uniform,
mv_prob_normal,
_general_checking,
raw_mv_prob_homo,
raw_mv_prob_uniform,
_general_checking,
raw_mv_prob_homo,
raw_mv_prob_uniform,
raw_mv_prob_normal,
_mv_prob_homo_transpose,
_mv_prob_uniform_transpose,
_mv_prob_homo_transpose,
_mv_prob_uniform_transpose,
_mv_prob_normal_transpose,
_reverse)
from brainpy._src.math.ndarray import _get_dtype
Expand All @@ -51,7 +50,9 @@ def event_mv_prob_homo(
transpose: bool = False,
outdim_parallel: bool = True,
) -> jax.Array:
return event_mv_prob_homo_taichi(events, weight, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel)
return event_mv_prob_homo_taichi(events, weight, conn_prob, seed, shape=shape, transpose=transpose,
outdim_parallel=outdim_parallel)


event_mv_prob_homo.__doc__ = mv_prob_homo.__doc__

Expand All @@ -67,7 +68,9 @@ def event_mv_prob_uniform(
transpose: bool = False,
outdim_parallel: bool = True,
) -> jax.Array:
return event_mv_prob_uniform_taichi(events, w_low, w_high, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel)
return event_mv_prob_uniform_taichi(events, w_low, w_high, conn_prob, seed, shape=shape, transpose=transpose,
outdim_parallel=outdim_parallel)


event_mv_prob_uniform.__doc__ = mv_prob_uniform.__doc__

Expand All @@ -83,7 +86,9 @@ def event_mv_prob_normal(
transpose: bool = False,
outdim_parallel: bool = True,
) -> jax.Array:
return event_mv_prob_uniform_taichi(events, w_mu, w_sigma, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel)
return event_mv_prob_uniform_taichi(events, w_mu, w_sigma, conn_prob, seed, shape=shape, transpose=transpose,
outdim_parallel=outdim_parallel)


### BRAINPYLIB ###

Expand Down Expand Up @@ -180,6 +185,7 @@ def event_mv_prob_normal_brainpylib(
transpose=transpose,
outdim_parallel=outdim_parallel)[0]


event_mv_prob_normal_brainpylib.__doc__ = mv_prob_normal.__doc__


Expand Down Expand Up @@ -872,6 +878,7 @@ def event_mv_prob_uniform_taichi(
return raw_event_mv_prob_uniform(events, w_low, w_high, conn_len, seed, shape=shape,
transpose=transpose, outdim_parallel=outdim_parallel)[0]


def event_mv_prob_normal_taichi(
events: jax.Array,
w_mu: float,
Expand Down Expand Up @@ -947,6 +954,7 @@ def event_mv_prob_normal_taichi(
return raw_event_mv_prob_normal(events, w_mu, w_sigma, conn_len, seed, shape=shape,
transpose=transpose, outdim_parallel=outdim_parallel)[0]


# -------------
# CPU function
# -------------
Expand Down Expand Up @@ -1075,9 +1083,11 @@ def _event_mv_prob_homo_outdim_parallel_bool_gpu(
i_col += inc
out[i_row] += r # TODO: warp-level reduction


def _reverse(shape):
return shape[::-1]


# -------------
# CPU function
# -------------
Expand Down
15 changes: 9 additions & 6 deletions brainpy/_src/math/jitconn/_matvec.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
from functools import partial
from typing import Tuple, Optional, Union

import brainpy.math as bm
import jax
import numpy as np
from jax import numpy as jnp, dtypes
Expand Down Expand Up @@ -86,8 +85,8 @@ def mv_prob_homo(
out: Array, ndarray
The output of :math:`y = M @ v`.
"""
return mv_prob_homo_taichi(vector, weight, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel)

return mv_prob_homo_taichi(vector, weight, conn_prob, seed, shape=shape, transpose=transpose,
outdim_parallel=outdim_parallel)


def mv_prob_uniform(
Expand Down Expand Up @@ -151,7 +150,8 @@ def mv_prob_uniform(
out: Array, ndarray
The output of :math:`y = M @ v`.
"""
return mv_prob_uniform_taichi(vector, w_low, w_high, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel)
return mv_prob_uniform_taichi(vector, w_low, w_high, conn_prob, seed, shape=shape, transpose=transpose,
outdim_parallel=outdim_parallel)


def mv_prob_normal(
Expand Down Expand Up @@ -215,7 +215,8 @@ def mv_prob_normal(
out: Array, ndarray
The output of :math:`y = M @ v`.
"""
return mv_prob_uniform_taichi(vector, w_mu, w_sigma, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel)
return mv_prob_uniform_taichi(vector, w_mu, w_sigma, conn_prob, seed, shape=shape, transpose=transpose,
outdim_parallel=outdim_parallel)


### BRAINYPLIB ###
Expand Down Expand Up @@ -456,7 +457,6 @@ def mv_prob_normal_brainpylib(
outdim_parallel=outdim_parallel)[0]



def _matvec_prob_homo_abstract(
vector, weight, clen, seed, *, shape, transpose, outdim_parallel
):
Expand Down Expand Up @@ -1095,6 +1095,7 @@ def mv_prob_homo_taichi(
return raw_mv_prob_homo(vector, weight, clen, seed, shape=shape,
transpose=transpose, outdim_parallel=outdim_parallel)[0]


def mv_prob_uniform_taichi(
vector: jax.Array,
w_low: float,
Expand Down Expand Up @@ -1170,6 +1171,7 @@ def mv_prob_uniform_taichi(
return raw_mv_prob_uniform(vector, w_low, w_high, conn_len, seed, shape=shape,
transpose=transpose, outdim_parallel=outdim_parallel)[0]


def mv_prob_normal_taichi(
vector: jax.Array,
w_mu: float,
Expand Down Expand Up @@ -1245,6 +1247,7 @@ def mv_prob_normal_taichi(
return raw_mv_prob_normal(vector, w_mu, w_sigma, conn_len, seed, shape=shape,
transpose=transpose, outdim_parallel=outdim_parallel)[0]


def _reverse(shape):
return shape[::-1]

Expand Down
11 changes: 7 additions & 4 deletions brainpy/_src/math/sparse/_csr_mv.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,8 +81,9 @@ def csrmv(
else:
return csrmv_brainpylib(data, indices, indptr, vector, shape=shape, transpose=transpose, method=method)


### BRAINPYLIB ###

def csrmv_brainpylib(
data: Union[float, jnp.ndarray, Array],
indices: Union[jnp.ndarray, Array],
Expand Down Expand Up @@ -164,6 +165,7 @@ def csrmv_brainpylib(
else:
raise ValueError(f'Only support methods: cusparse, scalar, vector, and adaptive. But we got {method}.')


def _csrmv_abstract(data, indices, indptr, vector, *, shape, transpose):
if data.dtype not in [jnp.float32, jnp.float64]:
raise TypeError(f'Only support float32 and float64. But we got {data.dtype}.')
Expand Down Expand Up @@ -587,7 +589,7 @@ def csrmv_taichi(
# if the shape of indices is (0,), then we return a zero vector
if indices.shape[0] == 0:
return jnp.zeros(shape[1] if transpose else shape[0], dtype=data.dtype)

return raw_csrmv_taichi(data, indices, indptr, vector, shape=shape, transpose=transpose)[0]


Expand Down Expand Up @@ -755,6 +757,7 @@ def _sparse_csr_matvec_transpose(

return ct_data, indices, indptr, vector


def raw_csrmv_taichi(
data: Union[float, jnp.ndarray, Array],
indices: Union[jnp.ndarray, Array],
Expand Down Expand Up @@ -783,7 +786,7 @@ def raw_csrmv_taichi(
outs=[jax.ShapeDtypeStruct((out_shape,), dtype=data.dtype)],
transpose=transpose,
shape=shape)


def _define_op(cpu_kernel, gpu_kernel):
prim = XLACustomOp(cpu_kernel=cpu_kernel, gpu_kernel=gpu_kernel)
Expand All @@ -806,4 +809,4 @@ def _define_op(cpu_kernel, gpu_kernel):

# no transpose heter
_csr_matvec_heter_p = _define_op(cpu_kernel=_sparse_csr_matvec_heter_cpu,
gpu_kernel=_sparse_csr_matvec_heter_gpu)
gpu_kernel=_sparse_csr_matvec_heter_gpu)

0 comments on commit e1f4005

Please sign in to comment.