[math] Rebase operator customization using MLIR registration interface #618

Merged: 12 commits (Feb 15, 2024)
6 changes: 3 additions & 3 deletions brainpy/_src/math/event/_csr_matvec.py
@@ -352,7 +352,7 @@ def _f(ct, indices, indptr, events, *, transpose):
event_csr_matvec_batching_p = Primitive('event_csr_matvec_batching')
event_csr_matvec_batching_p.def_abstract_eval(_batch_event_csr_matvec_abstract)
event_csr_matvec_batching_p.def_impl(partial(xla.apply_primitive, event_csr_matvec_batching_p))
xla.backend_specific_translations['cpu'][event_csr_matvec_batching_p] = _batch_event_csr_matvec_cpu_translation
# xla.backend_specific_translations['cpu'][event_csr_matvec_batching_p] = _batch_event_csr_matvec_cpu_translation
ad.defjvp(event_csr_matvec_batching_p, _batch_event_csr_matvec_jvp_values,
None, None, _batch_event_csr_matvec_jvp_events)
ad.primitive_transposes[event_csr_matvec_batching_p] = _batch_event_csr_matvec_transpose
@@ -597,8 +597,8 @@ def _event_csr_matvec_transpose_brainpylib(ct, values, indices, indptr, events,
event_csr_matvec_p = Primitive('event_csr_matvec')
event_csr_matvec_p.def_abstract_eval(_event_csr_matvec_abstract)
event_csr_matvec_p.def_impl(partial(xla.apply_primitive, event_csr_matvec_p))
xla.backend_specific_translations['cpu'][event_csr_matvec_p] = _event_csr_matvec_cpu_translation
xla.backend_specific_translations['gpu'][event_csr_matvec_p] = _event_csr_matvec_gpu_translation
# xla.backend_specific_translations['cpu'][event_csr_matvec_p] = _event_csr_matvec_cpu_translation
# xla.backend_specific_translations['gpu'][event_csr_matvec_p] = _event_csr_matvec_gpu_translation
ad.defjvp(event_csr_matvec_p, _event_csr_matvec_jvp_values_brainpylib, None, None,
_event_csr_matvec_jvp_events_brainpylib)
ad.primitive_transposes[event_csr_matvec_p] = _event_csr_matvec_transpose_brainpylib
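
Note: `xla.backend_specific_translations` no longer exists in recent JAX releases, which is why the per-backend translation hooks above are commented out; lowering registration now goes through `jax.interpreters.mlir` (in this PR, indirectly via `XLACustomOp` and the MLIR-based Numba/Taichi rules). A minimal sketch of the direct MLIR-era registration, where `_event_csr_matvec_cpu_lowering` is a hypothetical rule, not code from this PR:

```python
# Sketch only: `_event_csr_matvec_cpu_lowering` is a hypothetical MLIR lowering
# rule; the PR itself routes real kernels through XLACustomOp instead.
from jax.interpreters import mlir

def _event_csr_matvec_cpu_lowering(ctx, values, indices, indptr, events, *, transpose):
    # Build and return the MLIR ops (e.g. a custom_call) for the CPU backend.
    raise NotImplementedError

mlir.register_lowering(event_csr_matvec_p,
                       _event_csr_matvec_cpu_lowering,
                       platform='cpu')
```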
71 changes: 49 additions & 22 deletions brainpy/_src/math/event/_info_collection.py
@@ -6,15 +6,16 @@
import numba
from jax import dtypes, numpy as jnp
from jax.core import ShapedArray
from jax.interpreters import batching
from jax.lib import xla_client

from brainpy._src.dependency_check import import_brainpylib_gpu_ops
from brainpy._src.dependency_check import import_taichi
from brainpy._src.math.interoperability import as_jax
from brainpy._src.math.op_register import register_op_with_numba
from brainpy._src.math.ndarray import Array
from brainpy._src.dependency_check import import_brainpylib_gpu_ops
from brainpy._src.math.op_register.base import XLACustomOp
from brainpy.errors import GPUOperatorNotFound

ti = import_taichi()

__all__ = [
'info'
@@ -40,7 +41,7 @@ def info(events: Union[Array, jax.Array]) -> Tuple[jax.Array, jax.Array]:
events = as_jax(events)
if events.ndim != 1:
raise TypeError('Only support 1D boolean vector.')
return event_info_p.bind(events)
return event_info_p(events)


def _batch_event_info_abstract(events):
@@ -66,11 +67,26 @@ def _batch_event_info(outs, ins):
event_num[batch_idx] = num


@ti.kernel
def _batch_event_info_taichi(events: ti.types.ndarray(ndim=2),
event_ids: ti.types.ndarray(ndim=2),
event_num: ti.types.ndarray(ndim=1)):
# reset every slot to -1 before collecting the event indices
for i, j in ti.ndrange(event_ids.shape[0], event_ids.shape[1]):
event_ids[i, j] = -1
for batch_idx in range(event_ids.shape[0]):
num = 0
for i in range(event_ids.shape[1]):
if events[batch_idx, i]:
event_ids[batch_idx, num] = i
num += 1
event_num[batch_idx] = num


def _batch_event_info_batching_rule(args, axes):
arg = jnp.moveaxis(args[0], axes[0], 0)
shape = arg.shape
arg = jnp.reshape(arg, (shape[0] * shape[1], shape[2]))
event_ids, event_num = batch_event_info_p.bind(arg)
event_ids, event_num = batch_event_info_p(arg)
return ((jnp.reshape(event_ids, shape), jnp.reshape(event_num, shape[:2])),
(0, 0))
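
For readers unfamiliar with the batching rule above: it folds the extra `vmap` dimension into the kernel's existing batch dimension, runs the 2-D kernel once, and unfolds the results. A small shape walk-through with illustrative sizes (not taken from the PR):

```python
# Illustrative sizes only: the un-batched operand of batch_event_info_p is 2-D
# (batch, n); under vmap a third dimension appears and is folded away here.
import jax.numpy as jnp

arg = jnp.zeros((3, 5, 100), dtype=bool)   # suppose the vmapped axis is axis 1
moved = jnp.moveaxis(arg, 1, 0)            # -> (5, 3, 100), vmapped axis leads
flat = jnp.reshape(moved, (5 * 3, 100))    # fold both leading dims into one batch
# the kernel runs once on `flat`; event_ids is reshaped back to (5, 3, 100) and
# event_num back to (5, 3), and both outputs are reported as batched on axis 0.
```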

@@ -121,17 +137,15 @@ def _event_info_gpu_translation(c, events):
)


batch_event_info_p = register_op_with_numba(
op_name='event_info',
cpu_func=_batch_event_info,
out_shapes=_batch_event_info_abstract,
gpu_func_translation=_event_info_gpu_translation,
multiple_results=True
batch_event_info_p = XLACustomOp(
name='batched_event_info',
cpu_kernel=_batch_event_info_taichi,
outs=_batch_event_info_abstract,
)
batching.primitive_batchers[batch_event_info_p] = _batch_event_info_batching_rule
batch_event_info_p.def_batching_rule(_batch_event_info_batching_rule)


def _event_info_abstract(events):
def _event_info_abstract(events, **kwargs):
assert events.ndim == 1
# assert events.dtype == jnp.bool_
event_ids = ShapedArray(dtype=dtypes.canonicalize_dtype(int), shape=events.shape)
@@ -140,7 +154,7 @@ def _event_info_abstract(events):


# TODO: first evaluate the sub-sections in parallel, then merge the sub-results serially.
@numba.njit(fastmath=True)
@numba.jit(fastmath=True)
def _event_info(outs, ins):
event_ids, event_num = outs
event_num.fill(0)
@@ -154,16 +168,29 @@ def _event_info(outs, ins):
event_num[0] = num


@ti.kernel
def _event_info_taichi(events: ti.types.ndarray(ndim=1),
event_ids: ti.types.ndarray(ndim=1),
event_num: ti.types.ndarray(ndim=1)):
for i in range(event_ids.shape[0]):
event_ids[i] = -1
num = 0
ti.loop_config(serialize=True)  # keep the counting loop sequential so `num` and the write order stay consistent
for i in range(event_ids.shape[0]):
if events[i]:
event_ids[num] = i
num += 1
event_num[0] = num


def _event_info_batching_rule(args, axes):
arg = jnp.moveaxis(args[0], axes[0], 0)
return (batch_event_info_p.bind(arg), (0, 0))
return (batch_event_info_p(arg), (0, 0))


event_info_p = register_op_with_numba(
op_name='event_info',
cpu_func=_event_info,
out_shapes=_event_info_abstract,
gpu_func_translation=_event_info_gpu_translation,
multiple_results=True
event_info_p = XLACustomOp(
name='event_info',
cpu_kernel=_event_info_taichi,
outs=_event_info_abstract,
# gpu_func_translation=_event_info_gpu_translation,
)
batching.primitive_batchers[event_info_p] = _event_info_batching_rule
event_info_p.def_batching_rule(_event_info_batching_rule)
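
End to end, the public `info` routine now dispatches through the Taichi-backed `XLACustomOp`, and `jax.vmap` picks up the batching rule registered just above. A usage sketch; the public alias `brainpy.math.event.info` and the shapes are assumptions for illustration:

```python
# Usage sketch; assumes `bm.event.info` re-exports the `info` function above.
import jax
import jax.numpy as jnp
import brainpy.math as bm

events = jnp.array([True, False, True, True, False])
ids, num = bm.event.info(events)        # ids padded with -1, num[0] == 3

batched = jnp.stack([events, ~events])
ids_b, num_b = jax.vmap(bm.event.info)(batched)   # exercises the batching rule
```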
12 changes: 6 additions & 6 deletions brainpy/_src/math/jitconn/_event_matvec.py
@@ -358,8 +358,8 @@ def _event_matvec_prob_homo_transpose(
event_mv_prob_homo_p.multiple_results = True
event_mv_prob_homo_p.def_abstract_eval(_event_matvec_prob_homo_abstract)
event_mv_prob_homo_p.def_impl(partial(xla.apply_primitive, event_mv_prob_homo_p))
xla.backend_specific_translations['cpu'][event_mv_prob_homo_p] = _event_matvec_prob_homo_cpu_translation
xla.backend_specific_translations['gpu'][event_mv_prob_homo_p] = _event_matvec_prob_homo_gpu_translation
# xla.backend_specific_translations['cpu'][event_mv_prob_homo_p] = _event_matvec_prob_homo_cpu_translation
# xla.backend_specific_translations['gpu'][event_mv_prob_homo_p] = _event_matvec_prob_homo_gpu_translation
ad.primitive_jvps[event_mv_prob_homo_p] = _event_matvec_prob_homo_jvp
ad.primitive_transposes[event_mv_prob_homo_p] = _event_matvec_prob_homo_transpose
register_general_batching(event_mv_prob_homo_p)
@@ -529,8 +529,8 @@ def _event_matvec_prob_uniform_transpose(
event_mv_prob_uniform_p.multiple_results = True
event_mv_prob_uniform_p.def_abstract_eval(_event_matvec_prob_uniform_abstract)
event_mv_prob_uniform_p.def_impl(partial(xla.apply_primitive, event_mv_prob_uniform_p))
xla.backend_specific_translations['cpu'][event_mv_prob_uniform_p] = _event_matvec_prob_uniform_cpu_translation
xla.backend_specific_translations['gpu'][event_mv_prob_uniform_p] = _event_matvec_prob_uniform_gpu_translation
# xla.backend_specific_translations['cpu'][event_mv_prob_uniform_p] = _event_matvec_prob_uniform_cpu_translation
# xla.backend_specific_translations['gpu'][event_mv_prob_uniform_p] = _event_matvec_prob_uniform_gpu_translation
register_general_batching(event_mv_prob_uniform_p)
ad.primitive_jvps[event_mv_prob_uniform_p] = _event_matvec_prob_uniform_jvp
ad.primitive_transposes[event_mv_prob_uniform_p] = _event_matvec_prob_uniform_transpose
@@ -723,8 +723,8 @@ def _event_matvec_prob_normal_transpose(
event_mv_prob_normal_p.multiple_results = True
event_mv_prob_normal_p.def_abstract_eval(_event_matvec_prob_normal_abstract)
event_mv_prob_normal_p.def_impl(partial(xla.apply_primitive, event_mv_prob_normal_p))
xla.backend_specific_translations['cpu'][event_mv_prob_normal_p] = _event_matvec_prob_normal_cpu_translation
xla.backend_specific_translations['gpu'][event_mv_prob_normal_p] = _event_matvec_prob_normal_gpu_translation
# xla.backend_specific_translations['cpu'][event_mv_prob_normal_p] = _event_matvec_prob_normal_cpu_translation
# xla.backend_specific_translations['gpu'][event_mv_prob_normal_p] = _event_matvec_prob_normal_gpu_translation
register_general_batching(event_mv_prob_normal_p)
ad.primitive_jvps[event_mv_prob_normal_p] = _event_matvec_prob_normal_jvp
ad.primitive_transposes[event_mv_prob_normal_p] = _event_matvec_prob_normal_transpose
12 changes: 6 additions & 6 deletions brainpy/_src/math/jitconn/_matvec.py
@@ -640,8 +640,8 @@ def _matvec_prob_homo_transpose(
mv_prob_homo_p.multiple_results = True
mv_prob_homo_p.def_abstract_eval(_matvec_prob_homo_abstract)
mv_prob_homo_p.def_impl(partial(xla.apply_primitive, mv_prob_homo_p))
xla.backend_specific_translations['cpu'][mv_prob_homo_p] = _matvec_prob_homo_cpu_translation
xla.backend_specific_translations['gpu'][mv_prob_homo_p] = _matvec_prob_homo_gpu_translation
# xla.backend_specific_translations['cpu'][mv_prob_homo_p] = _matvec_prob_homo_cpu_translation
# xla.backend_specific_translations['gpu'][mv_prob_homo_p] = _matvec_prob_homo_gpu_translation
register_general_batching(mv_prob_homo_p)
ad.primitive_jvps[mv_prob_homo_p] = _matvec_prob_homo_jvp
ad.primitive_transposes[mv_prob_homo_p] = _matvec_prob_homo_transpose
@@ -823,8 +823,8 @@ def _matvec_prob_uniform_transpose(
mv_prob_uniform_p.multiple_results = True
mv_prob_uniform_p.def_abstract_eval(_matvec_prob_uniform_abstract)
mv_prob_uniform_p.def_impl(partial(xla.apply_primitive, mv_prob_uniform_p))
xla.backend_specific_translations['cpu'][mv_prob_uniform_p] = _matvec_prob_uniform_cpu_translation
xla.backend_specific_translations['gpu'][mv_prob_uniform_p] = _matvec_prob_uniform_gpu_translation
# xla.backend_specific_translations['cpu'][mv_prob_uniform_p] = _matvec_prob_uniform_cpu_translation
# xla.backend_specific_translations['gpu'][mv_prob_uniform_p] = _matvec_prob_uniform_gpu_translation
register_general_batching(mv_prob_uniform_p)
ad.primitive_jvps[mv_prob_uniform_p] = _matvec_prob_uniform_jvp
ad.primitive_transposes[mv_prob_uniform_p] = _matvec_prob_uniform_transpose
@@ -1009,8 +1009,8 @@ def _matvec_prob_normal_transpose(
mv_prob_normal_p.multiple_results = True
mv_prob_normal_p.def_abstract_eval(_matvec_prob_normal_abstract)
mv_prob_normal_p.def_impl(partial(xla.apply_primitive, mv_prob_normal_p))
xla.backend_specific_translations['cpu'][mv_prob_normal_p] = _matvec_prob_normal_cpu_translation
xla.backend_specific_translations['gpu'][mv_prob_normal_p] = _matvec_prob_normal_gpu_translation
# xla.backend_specific_translations['cpu'][mv_prob_normal_p] = _matvec_prob_normal_cpu_translation
# xla.backend_specific_translations['gpu'][mv_prob_normal_p] = _matvec_prob_normal_gpu_translation
register_general_batching(mv_prob_normal_p)
ad.primitive_jvps[mv_prob_normal_p] = _matvec_prob_normal_jvp
ad.primitive_transposes[mv_prob_normal_p] = _matvec_prob_normal_transpose
20 changes: 15 additions & 5 deletions brainpy/_src/math/ndarray.py
@@ -9,9 +9,11 @@
from jax.dtypes import canonicalize_dtype
from jax.tree_util import register_pytree_node_class

import brainpy.math
from brainpy.errors import MathError

bm = None


__all__ = [
'Array', 'ndarray', 'JaxArray', # alias of Array
'ShardedArray',
@@ -1039,7 +1041,9 @@ def __jax_array__(self):

def as_variable(self):
"""As an instance of Variable."""
return brainpy.math.Variable(self)
global bm
if bm is None: from brainpy import math as bm
return bm.Variable(self)

def __format__(self, specification):
return self.value.__format__(specification)
@@ -1473,7 +1477,9 @@ def fill_(self, value):
return self

def uniform_(self, low=0., high=1.):
self.value = brainpy.math.random.uniform(low, high, self.shape)
global bm
if bm is None: from brainpy import math as bm
self.value = bm.random.uniform(low, high, self.shape)
return self

def log_normal_(self, mean=1, std=2):
@@ -1489,14 +1495,18 @@ def log_normal_(self, mean=1, std=2):
mean: the mean value.
std: the standard deviation.
"""
self.value = brainpy.math.random.lognormal(mean, std, self.shape)
global bm
if bm is None: from brainpy import math as bm
self.value = bm.random.lognormal(mean, std, self.shape)
return self

def normal_(self, ):
"""
Fills self tensor with elements sampled from the standard normal distribution.
"""
self.value = brainpy.math.random.randn(*self.shape)
global bm
if bm is None: from brainpy import math as bm
self.value = bm.random.randn(*self.shape)
return self

def cuda(self):
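
The edits in this file replace the module-level `import brainpy.math` with a lazily populated `bm` global, breaking the circular import between `ndarray.py` and the top-level `brainpy.math` package. The pattern in isolation, as a sketch rather than the PR's exact code:

```python
# Deferred-import pattern: resolve the module on first use and cache it globally,
# so importing this file no longer requires `brainpy.math` to be fully initialized.
bm = None

def _get_bm():
    global bm
    if bm is None:
        from brainpy import math as bm  # binds the module-level `bm` because of `global bm`
    return bm
```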
2 changes: 2 additions & 0 deletions brainpy/_src/math/object_transform/controls.py
@@ -732,6 +732,7 @@ def _get_for_loop_transform(
unroll: int,
unroll_kwargs: tools.DotDict
):
@functools.wraps(body_fun)
def fun2scan(carry, x):
for k in dyn_vars.keys():
dyn_vars[k]._value = carry[k]
@@ -912,6 +913,7 @@ def for_loop(
dyn_vars[key]._value = dyn_vals[key]
if progress_bar:
bar.close()
del dyn_vals, dyn_vars
return out_vals


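
Two small hygiene changes in `for_loop`: `functools.wraps(body_fun)` copies the user function's name and docstring onto the traced wrapper, so JAX error messages and any name-based bookkeeping refer to the user's function rather than `fun2scan`, and the trailing `del` drops references to the variable dictionaries once the loop finishes. A toy illustration of what `functools.wraps` preserves (not PR code):

```python
# Toy example: metadata of the wrapped function is copied onto the wrapper.
import functools

def body_fun(x):
    """User-provided loop body."""
    return x + 1

@functools.wraps(body_fun)
def fun2scan(carry, x):
    return carry, body_fun(x)

assert fun2scan.__name__ == 'body_fun'
assert fun2scan.__doc__ == 'User-provided loop body.'
```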
35 changes: 16 additions & 19 deletions brainpy/_src/math/op_register/base.py
@@ -8,18 +8,18 @@

from brainpy._src.math.ndarray import Array
from brainpy._src.math.object_transform.base import BrainPyObject
# if jax.__version__ >= '0.4.16':
# from .numba_based import register_numba_mlir_cpu_translation_rule as register_numba_cpu_translation_rule
# else:
# from .numba_based import register_numba_xla_cpu_translation_rule as register_numba_cpu_translation_rule
from .numba_based import register_numba_xla_cpu_translation_rule as register_numba_cpu_translation_rule
from .taichi_aot_based import (register_taichi_cpu_translation_rule,
register_taichi_gpu_translation_rule,
clean_caches)

if jax.__version__ >= '0.4.16':
from .numba_based import register_numba_mlir_cpu_translation_rule as register_numba_cpu_translation_rule
from .taichi_aot_based import (register_taichi_aot_mlir_cpu_translation_rule as register_taichi_cpu_translation_rule,
register_taichi_aot_mlir_gpu_translation_rule as register_taichi_gpu_translation_rule)
else:
from .numba_based import register_numba_xla_cpu_translation_rule as register_numba_cpu_translation_rule
from .taichi_aot_based import (register_taichi_aot_xla_cpu_translation_rule as register_taichi_cpu_translation_rule,
register_taichi_aot_xla_gpu_translation_rule as register_taichi_gpu_translation_rule)
from .utils import register_general_batching
from brainpy._src.math.op_register.ad_support import defjvp
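
One caveat about the gate above: comparing `jax.__version__` as a plain string orders versions lexicographically, so for example `'0.4.9' >= '0.4.16'` evaluates to `True`. A more robust check, as a sketch that assumes `packaging` is importable (it is in most pip-based environments):

```python
# Version gate sketch using packaging's PEP 440-aware comparison.
import jax
from packaging.version import Version

if Version(jax.__version__) >= Version('0.4.16'):
    from .numba_based import register_numba_mlir_cpu_translation_rule as register_numba_cpu_translation_rule
else:
    from .numba_based import register_numba_xla_cpu_translation_rule as register_numba_cpu_translation_rule
```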


__all__ = [
'XLACustomOp',
]
@@ -64,8 +64,8 @@ class XLACustomOp(BrainPyObject):
>>>
>>> # option 2
>>> prim2 = XLACustomOp(cpu_kernel=numba_cpu_fun, gpu_kernel=taichi_gpu_fun,
>>> outs=[jax.ShapeDtypeStruct(1000, dtype=np.float32),
>>> jax.ShapeDtypeStruct(1000, dtype=np.float32)])
>>> outs=lambda a, b, **kwargs: [jax.ShapeDtypeStruct(a.shape, dtype=a.dtype),
>>> jax.ShapeDtypeStruct(b.shape, dtype=b.dtype)])
>>> a3, b3 = prim2(np.random.random(1000), np.random.random(1000))

Args:
@@ -74,7 +74,7 @@ class XLACustomOp(BrainPyObject):
batching_translation: Callable. The batching translation rule of JAX.
jvp_translation: Callable. The JVP translation rule of JAX.
transpose_translation: Callable. The transpose translation rule of JAX.
outs: optional, sequence of `ShapeDtype`. The output information.
outs: optional. A callable mapping the call-time inputs to the output information (a sequence of `ShapeDtype`).
name: str. The primitive name.
"""

@@ -85,7 +85,7 @@ def __init__(
batching_translation: Callable = None,
jvp_translation: Callable = None,
transpose_translation: Callable = None,
outs: Optional[Sequence[ShapeDtype]] = None,
outs: Optional[Callable] = None,
name: str = None,
):
super().__init__(name)
@@ -99,8 +99,6 @@ def __init__(
self.primitive.multiple_results = True

# abstract evaluation
if outs is not None:
outs = tuple([_transform_to_shapedarray(o) for o in outs])
self.outs = outs
self.primitive.def_abstract_eval(_abstract_eval)
self.primitive.def_impl(partial(xla.apply_primitive, self.primitive))
@@ -139,10 +137,11 @@ def __init__(
if transpose_translation is not None:
ad.primitive_transposes[self.primitive] = transpose_translation


def __call__(self, *ins, outs: Optional[Sequence[ShapeDtype]] = None, **kwargs):
if outs is None:
outs = self.outs
if self.outs is None:
raise ValueError('The output information is not defined.')
outs = self.outs(*ins, **kwargs)
assert outs is not None
outs = tuple([_transform_to_shapedarray(o) for o in outs])
ins = jax.tree_util.tree_map(_transform_to_array, ins, is_leaf=_is_bp_array)
@@ -227,5 +226,3 @@ def _transform_to_array(a):

def _transform_to_shapedarray(a):
return jax.core.ShapedArray(a.shape, a.dtype)
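
With `outs` now a callable, output shapes can depend on the concrete operands at call time: `__call__` evaluates it with the inputs (and keyword parameters) and converts the result to `ShapedArray`s for abstract evaluation. A usage sketch; `my_cpu_kernel` is a hypothetical placeholder for a registered Numba or Taichi kernel, not something defined in this PR:

```python
# Sketch of the outs-as-callable convention; `my_cpu_kernel` is a placeholder.
import jax
import numpy as np
from brainpy._src.math.op_register.base import XLACustomOp

prim = XLACustomOp(
    cpu_kernel=my_cpu_kernel,
    outs=lambda x, **kwargs: [jax.ShapeDtypeStruct(x.shape, x.dtype)],
)

x = np.random.random(16).astype(np.float32)
y, = prim(x)                                                  # shapes inferred from `x`
y, = prim(x, outs=[jax.ShapeDtypeStruct((16,), np.float32)])  # per-call override
```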

