[math] Rebase operator customization using MLIR registration interface (#618)

* [array] fix array dependency

* [operator] upgrade numba operator registration using MLIR interface

* [operator] upgrade numba operator registration using MLIR interface

* update

* [operator] raise error when using `CustomOpByNumba` on jax>=0.4.24

* [operator] taichi aot operator customization upgrade; needs further changes

* [math] temporarily fix the error; needs upgrades

* upgrades

* [math] Add new taichi call function for single result

* Update test_taichi_based.py

* [math] Fix bugs

* updates

---------

Co-authored-by: He Sichao <[email protected]>
chaoming0625 and Routhleck authored Feb 15, 2024
1 parent 7cdf768 commit 8f4803e
Showing 15 changed files with 377 additions and 243 deletions.
6 changes: 3 additions & 3 deletions brainpy/_src/math/event/_csr_matvec.py
@@ -352,7 +352,7 @@ def _f(ct, indices, indptr, events, *, transpose):
event_csr_matvec_batching_p = Primitive('event_csr_matvec_batching')
event_csr_matvec_batching_p.def_abstract_eval(_batch_event_csr_matvec_abstract)
event_csr_matvec_batching_p.def_impl(partial(xla.apply_primitive, event_csr_matvec_batching_p))
xla.backend_specific_translations['cpu'][event_csr_matvec_batching_p] = _batch_event_csr_matvec_cpu_translation
# xla.backend_specific_translations['cpu'][event_csr_matvec_batching_p] = _batch_event_csr_matvec_cpu_translation
ad.defjvp(event_csr_matvec_batching_p, _batch_event_csr_matvec_jvp_values,
None, None, _batch_event_csr_matvec_jvp_events)
ad.primitive_transposes[event_csr_matvec_batching_p] = _batch_event_csr_matvec_transpose
@@ -597,8 +597,8 @@ def _event_csr_matvec_transpose_brainpylib(ct, values, indices, indptr, events,
event_csr_matvec_p = Primitive('event_csr_matvec')
event_csr_matvec_p.def_abstract_eval(_event_csr_matvec_abstract)
event_csr_matvec_p.def_impl(partial(xla.apply_primitive, event_csr_matvec_p))
xla.backend_specific_translations['cpu'][event_csr_matvec_p] = _event_csr_matvec_cpu_translation
xla.backend_specific_translations['gpu'][event_csr_matvec_p] = _event_csr_matvec_gpu_translation
# xla.backend_specific_translations['cpu'][event_csr_matvec_p] = _event_csr_matvec_cpu_translation
# xla.backend_specific_translations['gpu'][event_csr_matvec_p] = _event_csr_matvec_gpu_translation
ad.defjvp(event_csr_matvec_p, _event_csr_matvec_jvp_values_brainpylib, None, None,
_event_csr_matvec_jvp_events_brainpylib)
ad.primitive_transposes[event_csr_matvec_p] = _event_csr_matvec_transpose_brainpylib
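Note on the commented-out registrations above: the `xla.backend_specific_translations` dictionaries are no longer available in recent JAX releases (the version gate in `op_register/base.py` below switches at jax 0.4.16), so backend rules now have to be registered as MLIR lowerings instead. A minimal sketch of that replacement pattern, using an illustrative primitive and lowering rather than BrainPy's actual ones:

```python
# Sketch only: MLIR-style CPU lowering registration for a custom primitive,
# replacing xla.backend_specific_translations['cpu'][...] used above.
from functools import partial

from jax import core
from jax.interpreters import mlir, xla

demo_p = core.Primitive('demo_double')  # illustrative primitive, not from this diff
demo_p.def_abstract_eval(lambda x: core.ShapedArray(x.shape, x.dtype))
demo_p.def_impl(partial(xla.apply_primitive, demo_p))

# Lower by tracing an equivalent JAX function; a real kernel would instead emit
# a custom call into compiled Numba/Taichi code.
mlir.register_lowering(demo_p,
                       mlir.lower_fun(lambda x: x + x, multiple_results=False),
                       platform='cpu')
```

Under `jax.jit`, `demo_p.bind(x)` is then compiled through this lowering on CPU.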
71 changes: 49 additions & 22 deletions brainpy/_src/math/event/_info_collection.py
@@ -6,15 +6,16 @@
import numba
from jax import dtypes, numpy as jnp
from jax.core import ShapedArray
from jax.interpreters import batching
from jax.lib import xla_client

from brainpy._src.dependency_check import import_brainpylib_gpu_ops
from brainpy._src.dependency_check import import_taichi
from brainpy._src.math.interoperability import as_jax
from brainpy._src.math.op_register import register_op_with_numba
from brainpy._src.math.ndarray import Array
from brainpy._src.dependency_check import import_brainpylib_gpu_ops
from brainpy._src.math.op_register.base import XLACustomOp
from brainpy.errors import GPUOperatorNotFound

ti = import_taichi()

__all__ = [
'info'
@@ -40,7 +41,7 @@ def info(events: Union[Array, jax.Array]) -> Tuple[jax.Array, jax.Array]:
events = as_jax(events)
if events.ndim != 1:
raise TypeError('Only support 1D boolean vector.')
return event_info_p.bind(events)
return event_info_p(events)


def _batch_event_info_abstract(events):
@@ -66,11 +67,26 @@ def _batch_event_info(outs, ins):
event_num[batch_idx] = num


@ti.kernel
def _batch_event_info_taichi(events: ti.types.ndarray(ndim=2),
event_ids: ti.types.ndarray(ndim=2),
event_num: ti.types.ndarray(ndim=1)):
for i, j in ti.grouped(ti.ndrange(event_ids.shape)):
event_ids[i, j] = -1
for batch_idx in range(event_ids.shape[0]):
num = 0
for i in range(event_ids.shape[1]):
if events[batch_idx, i]:
event_ids[batch_idx, num] = i
num += 1
event_num[batch_idx] = num


def _batch_event_info_batching_rule(args, axes):
arg = jnp.moveaxis(args[0], axes[0], 0)
shape = arg.shape
arg = jnp.reshape(arg, (shape[0] * shape[1], shape[2]))
event_ids, event_num = batch_event_info_p.bind(arg)
event_ids, event_num = batch_event_info_p(arg)
return ((jnp.reshape(event_ids, shape), jnp.reshape(event_num, shape[:2])),
(0, 0))

@@ -121,17 +137,15 @@ def _event_info_gpu_translation(c, events):
)


batch_event_info_p = register_op_with_numba(
op_name='event_info',
cpu_func=_batch_event_info,
out_shapes=_batch_event_info_abstract,
gpu_func_translation=_event_info_gpu_translation,
multiple_results=True
batch_event_info_p = XLACustomOp(
name='batched_event_info',
cpu_kernel=_batch_event_info_taichi,
outs=_batch_event_info_abstract,
)
batching.primitive_batchers[batch_event_info_p] = _batch_event_info_batching_rule
batch_event_info_p.def_batching_rule(_batch_event_info_batching_rule)


def _event_info_abstract(events):
def _event_info_abstract(events, **kwargs):
assert events.ndim == 1
# assert events.dtype == jnp.bool_
event_ids = ShapedArray(dtype=dtypes.canonicalize_dtype(int), shape=events.shape)
@@ -140,7 +154,7 @@ def _event_info_abstract(events):


# TODO: first parallel evaluate the sub-sections, then serially event the sub-results.
@numba.njit(fastmath=True)
@numba.jit(fastmath=True)
def _event_info(outs, ins):
event_ids, event_num = outs
event_num.fill(0)
@@ -154,16 +168,29 @@ def _event_info(outs, ins):
event_num[0] = num


@ti.kernel
def _event_info_taichi(events: ti.types.ndarray(ndim=1),
event_ids: ti.types.ndarray(ndim=1),
event_num: ti.types.ndarray(ndim=1)):
for i in range(event_ids.shape[0]):
event_ids[i] = -1
num = 0
for i in range(event_ids.shape[0]):
if events[i]:
event_ids[num] = i
num += 1
event_num[0] = num


def _event_info_batching_rule(args, axes):
arg = jnp.moveaxis(args[0], axes[0], 0)
return (batch_event_info_p.bind(arg), (0, 0))
return (batch_event_info_p(arg), (0, 0))


event_info_p = register_op_with_numba(
op_name='event_info',
cpu_func=_event_info,
out_shapes=_event_info_abstract,
gpu_func_translation=_event_info_gpu_translation,
multiple_results=True
event_info_p = XLACustomOp(
name='event_info',
cpu_kernel=_event_info_taichi,
outs=_event_info_abstract,
# gpu_func_translation=_event_info_gpu_translation,
)
batching.primitive_batchers[event_info_p] = _event_info_batching_rule
event_info_p.def_batching_rule(_event_info_batching_rule)
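A side effect of moving from `register_op_with_numba` to `XLACustomOp` is visible throughout this file: the primitive object is now called directly (`event_info_p(events)`) rather than bound via `event_info_p.bind(events)`. A hedged usage sketch of the rewritten operator; the import path is the private module touched in this diff, and any public alias may differ:

```python
# Sketch: exercising the Taichi-backed event-info operator after this rewrite.
import jax.numpy as jnp

from brainpy._src.math.event._info_collection import info  # private path from this diff

events = jnp.arange(16) % 4 == 0      # 1-D boolean spike vector
event_ids, event_num = info(events)
# event_ids holds the indices of the True entries (unused slots remain -1),
# event_num[0] holds how many True entries there are.
```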
12 changes: 6 additions & 6 deletions brainpy/_src/math/jitconn/_event_matvec.py
@@ -358,8 +358,8 @@ def _event_matvec_prob_homo_transpose(
event_mv_prob_homo_p.multiple_results = True
event_mv_prob_homo_p.def_abstract_eval(_event_matvec_prob_homo_abstract)
event_mv_prob_homo_p.def_impl(partial(xla.apply_primitive, event_mv_prob_homo_p))
xla.backend_specific_translations['cpu'][event_mv_prob_homo_p] = _event_matvec_prob_homo_cpu_translation
xla.backend_specific_translations['gpu'][event_mv_prob_homo_p] = _event_matvec_prob_homo_gpu_translation
# xla.backend_specific_translations['cpu'][event_mv_prob_homo_p] = _event_matvec_prob_homo_cpu_translation
# xla.backend_specific_translations['gpu'][event_mv_prob_homo_p] = _event_matvec_prob_homo_gpu_translation
ad.primitive_jvps[event_mv_prob_homo_p] = _event_matvec_prob_homo_jvp
ad.primitive_transposes[event_mv_prob_homo_p] = _event_matvec_prob_homo_transpose
register_general_batching(event_mv_prob_homo_p)
@@ -529,8 +529,8 @@ def _event_matvec_prob_uniform_transpose(
event_mv_prob_uniform_p.multiple_results = True
event_mv_prob_uniform_p.def_abstract_eval(_event_matvec_prob_uniform_abstract)
event_mv_prob_uniform_p.def_impl(partial(xla.apply_primitive, event_mv_prob_uniform_p))
xla.backend_specific_translations['cpu'][event_mv_prob_uniform_p] = _event_matvec_prob_uniform_cpu_translation
xla.backend_specific_translations['gpu'][event_mv_prob_uniform_p] = _event_matvec_prob_uniform_gpu_translation
# xla.backend_specific_translations['cpu'][event_mv_prob_uniform_p] = _event_matvec_prob_uniform_cpu_translation
# xla.backend_specific_translations['gpu'][event_mv_prob_uniform_p] = _event_matvec_prob_uniform_gpu_translation
register_general_batching(event_mv_prob_uniform_p)
ad.primitive_jvps[event_mv_prob_uniform_p] = _event_matvec_prob_uniform_jvp
ad.primitive_transposes[event_mv_prob_uniform_p] = _event_matvec_prob_uniform_transpose
@@ -723,8 +723,8 @@ def _event_matvec_prob_normal_transpose(
event_mv_prob_normal_p.multiple_results = True
event_mv_prob_normal_p.def_abstract_eval(_event_matvec_prob_normal_abstract)
event_mv_prob_normal_p.def_impl(partial(xla.apply_primitive, event_mv_prob_normal_p))
xla.backend_specific_translations['cpu'][event_mv_prob_normal_p] = _event_matvec_prob_normal_cpu_translation
xla.backend_specific_translations['gpu'][event_mv_prob_normal_p] = _event_matvec_prob_normal_gpu_translation
# xla.backend_specific_translations['cpu'][event_mv_prob_normal_p] = _event_matvec_prob_normal_cpu_translation
# xla.backend_specific_translations['gpu'][event_mv_prob_normal_p] = _event_matvec_prob_normal_gpu_translation
register_general_batching(event_mv_prob_normal_p)
ad.primitive_jvps[event_mv_prob_normal_p] = _event_matvec_prob_normal_jvp
ad.primitive_transposes[event_mv_prob_normal_p] = _event_matvec_prob_normal_transpose
12 changes: 6 additions & 6 deletions brainpy/_src/math/jitconn/_matvec.py
@@ -640,8 +640,8 @@ def _matvec_prob_homo_transpose(
mv_prob_homo_p.multiple_results = True
mv_prob_homo_p.def_abstract_eval(_matvec_prob_homo_abstract)
mv_prob_homo_p.def_impl(partial(xla.apply_primitive, mv_prob_homo_p))
xla.backend_specific_translations['cpu'][mv_prob_homo_p] = _matvec_prob_homo_cpu_translation
xla.backend_specific_translations['gpu'][mv_prob_homo_p] = _matvec_prob_homo_gpu_translation
# xla.backend_specific_translations['cpu'][mv_prob_homo_p] = _matvec_prob_homo_cpu_translation
# xla.backend_specific_translations['gpu'][mv_prob_homo_p] = _matvec_prob_homo_gpu_translation
register_general_batching(mv_prob_homo_p)
ad.primitive_jvps[mv_prob_homo_p] = _matvec_prob_homo_jvp
ad.primitive_transposes[mv_prob_homo_p] = _matvec_prob_homo_transpose
@@ -823,8 +823,8 @@ def _matvec_prob_uniform_transpose(
mv_prob_uniform_p.multiple_results = True
mv_prob_uniform_p.def_abstract_eval(_matvec_prob_uniform_abstract)
mv_prob_uniform_p.def_impl(partial(xla.apply_primitive, mv_prob_uniform_p))
xla.backend_specific_translations['cpu'][mv_prob_uniform_p] = _matvec_prob_uniform_cpu_translation
xla.backend_specific_translations['gpu'][mv_prob_uniform_p] = _matvec_prob_uniform_gpu_translation
# xla.backend_specific_translations['cpu'][mv_prob_uniform_p] = _matvec_prob_uniform_cpu_translation
# xla.backend_specific_translations['gpu'][mv_prob_uniform_p] = _matvec_prob_uniform_gpu_translation
register_general_batching(mv_prob_uniform_p)
ad.primitive_jvps[mv_prob_uniform_p] = _matvec_prob_uniform_jvp
ad.primitive_transposes[mv_prob_uniform_p] = _matvec_prob_uniform_transpose
@@ -1009,8 +1009,8 @@ def _matvec_prob_normal_transpose(
mv_prob_normal_p.multiple_results = True
mv_prob_normal_p.def_abstract_eval(_matvec_prob_normal_abstract)
mv_prob_normal_p.def_impl(partial(xla.apply_primitive, mv_prob_normal_p))
xla.backend_specific_translations['cpu'][mv_prob_normal_p] = _matvec_prob_normal_cpu_translation
xla.backend_specific_translations['gpu'][mv_prob_normal_p] = _matvec_prob_normal_gpu_translation
# xla.backend_specific_translations['cpu'][mv_prob_normal_p] = _matvec_prob_normal_cpu_translation
# xla.backend_specific_translations['gpu'][mv_prob_normal_p] = _matvec_prob_normal_gpu_translation
register_general_batching(mv_prob_normal_p)
ad.primitive_jvps[mv_prob_normal_p] = _matvec_prob_normal_jvp
ad.primitive_transposes[mv_prob_normal_p] = _matvec_prob_normal_transpose
20 changes: 15 additions & 5 deletions brainpy/_src/math/ndarray.py
@@ -9,9 +9,11 @@
from jax.dtypes import canonicalize_dtype
from jax.tree_util import register_pytree_node_class

import brainpy.math
from brainpy.errors import MathError

bm = None


__all__ = [
'Array', 'ndarray', 'JaxArray', # alias of Array
'ShardedArray',
@@ -1039,7 +1041,9 @@ def __jax_array__(self):

def as_variable(self):
"""As an instance of Variable."""
return brainpy.math.Variable(self)
global bm
if bm is None: from brainpy import math as bm
return bm.Variable(self)

def __format__(self, specification):
return self.value.__format__(specification)
@@ -1473,7 +1477,9 @@ def fill_(self, value):
return self

def uniform_(self, low=0., high=1.):
self.value = brainpy.math.random.uniform(low, high, self.shape)
global bm
if bm is None: from brainpy import math as bm
self.value = bm.random.uniform(low, high, self.shape)
return self

def log_normal_(self, mean=1, std=2):
@@ -1489,14 +1495,18 @@ def log_normal_(self, mean=1, std=2):
mean: the mean value.
std: the standard deviation.
"""
self.value = brainpy.math.random.lognormal(mean, std, self.shape)
global bm
if bm is None: from brainpy import math as bm
self.value = bm.random.lognormal(mean, std, self.shape)
return self

def normal_(self, ):
"""
Fills self tensor with elements samples from the normal distribution parameterized by mean and std.
"""
self.value = brainpy.math.random.randn(*self.shape)
global bm
if bm is None: from brainpy import math as bm
self.value = bm.random.randn(*self.shape)
return self

def cuda(self):
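The repeated `global bm` blocks above replace the old module-level `import brainpy.math`, presumably because `brainpy.math` itself depends on this `ndarray` module and the eager import created a circular dependency (see the `[array] fix array dependency` entry in the commit message). The pattern, isolated as a minimal sketch that mirrors `uniform_` above:

```python
# Deferred-import pattern: resolve brainpy.math on first use instead of at module
# load time, so this module no longer participates in a circular import.
bm = None  # module-level cache for the lazily imported brainpy.math


def _random_fill(shape, low=0., high=1.):
  """Illustrative helper mirroring Array.uniform_ above; not part of this diff."""
  global bm
  if bm is None:
    from brainpy import math as bm  # safe: brainpy.math is fully initialized by call time
  return bm.random.uniform(low, high, shape)
```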
2 changes: 2 additions & 0 deletions brainpy/_src/math/object_transform/controls.py
@@ -732,6 +732,7 @@ def _get_for_loop_transform(
unroll: int,
unroll_kwargs: tools.DotDict
):
@functools.wraps(body_fun)
def fun2scan(carry, x):
for k in dyn_vars.keys():
dyn_vars[k]._value = carry[k]
@@ -912,6 +913,7 @@ def for_loop(
dyn_vars[key]._value = dyn_vals[key]
if progress_bar:
bar.close()
del dyn_vals, dyn_vars
return out_vals


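For reference, the added `@functools.wraps(body_fun)` copies the user function's metadata onto the wrapper that actually gets scanned, so names and docstrings derived from it refer to the user's loop body rather than an anonymous `fun2scan`. A small general-Python illustration, not BrainPy code:

```python
import functools


def body_fun(x):
  """User-supplied loop body."""
  return x + 1


@functools.wraps(body_fun)
def fun2scan(carry, x):
  # the wrapper handed to the scan machinery
  return carry, body_fun(x)


print(fun2scan.__name__)  # -> 'body_fun'
print(fun2scan.__doc__)   # -> 'User-supplied loop body.'
```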
35 changes: 16 additions & 19 deletions brainpy/_src/math/op_register/base.py
@@ -8,18 +8,18 @@

from brainpy._src.math.ndarray import Array
from brainpy._src.math.object_transform.base import BrainPyObject
# if jax.__version__ >= '0.4.16':
# from .numba_based import register_numba_mlir_cpu_translation_rule as register_numba_cpu_translation_rule
# else:
# from .numba_based import register_numba_xla_cpu_translation_rule as register_numba_cpu_translation_rule
from .numba_based import register_numba_xla_cpu_translation_rule as register_numba_cpu_translation_rule
from .taichi_aot_based import (register_taichi_cpu_translation_rule,
register_taichi_gpu_translation_rule,
clean_caches)

if jax.__version__ >= '0.4.16':
from .numba_based import register_numba_mlir_cpu_translation_rule as register_numba_cpu_translation_rule
from .taichi_aot_based import (register_taichi_aot_mlir_cpu_translation_rule as register_taichi_cpu_translation_rule,
register_taichi_aot_mlir_gpu_translation_rule as register_taichi_gpu_translation_rule)
else:
from .numba_based import register_numba_xla_cpu_translation_rule as register_numba_cpu_translation_rule
from .taichi_aot_based import (register_taichi_aot_xla_cpu_translation_rule as register_taichi_cpu_translation_rule,
register_taichi_aot_xla_gpu_translation_rule as register_taichi_gpu_translation_rule)
from .utils import register_general_batching
from brainpy._src.math.op_register.ad_support import defjvp


__all__ = [
'XLACustomOp',
]
@@ -64,8 +64,8 @@ class XLACustomOp(BrainPyObject):
>>>
>>> # option 2
>>> prim2 = XLACustomOp(cpu_kernel=numba_cpu_fun, gpu_kernel=taichi_gpu_fun,
>>> outs=[jax.ShapeDtypeStruct(1000, dtype=np.float32),
>>> jax.ShapeDtypeStruct(1000, dtype=np.float32)])
>>> outs=lambda a, b, **kwargs: [jax.ShapeDtypeStruct(a.shape, dtype=a.dtype),
>>> jax.ShapeDtypeStruct(b.shape, dtype=b.dtype)])
>>> a3, b3 = prim2(np.random.random(1000), np.random.random(1000))
Args:
@@ -74,7 +74,7 @@ class XLACustomOp(BrainPyObject):
batching_translation: Callable. The batching translation rule of JAX.
jvp_translation: Callable. The JVP translation rule of JAX.
transpose_translation: Callable. The transpose translation rule of JAX.
outs: optional, sequence of `ShapeDtype`. The output information.
outs: optional. The output information.
name: str. The primitive name.
"""

@@ -85,7 +85,7 @@ def __init__(
batching_translation: Callable = None,
jvp_translation: Callable = None,
transpose_translation: Callable = None,
outs: Optional[Sequence[ShapeDtype]] = None,
outs: Optional[Callable] = None,
name: str = None,
):
super().__init__(name)
@@ -99,8 +99,6 @@ def __init__(
self.primitive.multiple_results = True

# abstract evaluation
if outs is not None:
outs = tuple([_transform_to_shapedarray(o) for o in outs])
self.outs = outs
self.primitive.def_abstract_eval(_abstract_eval)
self.primitive.def_impl(partial(xla.apply_primitive, self.primitive))
@@ -139,10 +137,11 @@ def __init__(
if transpose_translation is not None:
ad.primitive_transposes[self.primitive] = transpose_translation


def __call__(self, *ins, outs: Optional[Sequence[ShapeDtype]] = None, **kwargs):
if outs is None:
outs = self.outs
if self.outs is None:
raise ValueError('The output information is not defined.')
outs = self.outs(*ins, **kwargs)
assert outs is not None
outs = tuple([_transform_to_shapedarray(o) for o in outs])
ins = jax.tree_util.tree_map(_transform_to_array, ins, is_leaf=_is_bp_array)
@@ -227,5 +226,3 @@ def _transform_to_array(a):

def _transform_to_shapedarray(a):
return jax.core.ShapedArray(a.shape, a.dtype)
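With `outs` now a callable resolved inside `__call__`, output shapes and dtypes can follow the concrete inputs of each call. A hedged sketch of the new usage, modeled on the class docstring above; the kernel is a toy placeholder and its argument convention (inputs first, then output buffers) is an assumption for illustration, not a statement of BrainPy's exact contract:

```python
# Sketch of the call-time `outs` contract after this change.
import jax
import numba
import numpy as np

import brainpy.math as bm


@numba.njit
def copy_kernel(a, b, out_a, out_b):  # hypothetical kernel, not from this diff
  out_a[:] = a
  out_b[:] = b


prim = bm.XLACustomOp(
  cpu_kernel=copy_kernel,
  # `outs` receives the call-time inputs and returns the output ShapeDtype specs.
  outs=lambda a, b, **kwargs: [jax.ShapeDtypeStruct(a.shape, a.dtype),
                               jax.ShapeDtypeStruct(b.shape, b.dtype)],
)

a_out, b_out = prim(np.random.random(1000), np.random.random(1000))
```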

