From 2171061dad2a8b7355022fb0ccb7ce854200c4f4 Mon Sep 17 00:00:00 2001
From: He Sichao <1310722434@qq.com>
Date: Sat, 23 Nov 2024 17:22:55 +0800
Subject: [PATCH] Remove op register

---
 brainpy/_src/math/__init__.py                 |   1 -
 brainpy/_src/math/jitconn/matvec.py           |   6 +-
 brainpy/_src/math/op_register/__init__.py     |   7 -
 brainpy/_src/math/op_register/ad_support.py   |  56 ---
 brainpy/_src/math/op_register/base.py         | 224 ---------
 brainpy/_src/math/op_register/cupy_based.py   | 279 -----------
 .../op_register/numba_approach/__init__.py    | 295 -----------
 .../numba_approach/cpu_translation.py         | 228 ---------
 .../tests/test_numba_approach.py              |  48 --
 brainpy/_src/math/op_register/numba_based.py  | 181 -------
 .../math/op_register/tests/test_ad_support.py | 143 ------
 .../math/op_register/tests/test_cupy_based.py |  79 ---
 .../op_register/tests/test_numba_based.py     |  55 ---
 brainpy/_src/math/op_register/utils.py        |  42 --
 brainpy/_src/math/sparse/__init__.py          |   2 -
 brainpy/_src/math/sparse/bsr_mm.py            | 462 ------------------
 brainpy/_src/math/sparse/bsr_mv.py            | 210 --------
 brainpy/_src/math/sparse/utils.py             |  32 +-
 brainpy/math/__init__.py                      |   1 -
 brainpy/math/op_register.py                   |  10 -
 20 files changed, 30 insertions(+), 2331 deletions(-)
 delete mode 100644 brainpy/_src/math/op_register/__init__.py
 delete mode 100644 brainpy/_src/math/op_register/ad_support.py
 delete mode 100644 brainpy/_src/math/op_register/base.py
 delete mode 100644 brainpy/_src/math/op_register/cupy_based.py
 delete mode 100644 brainpy/_src/math/op_register/numba_approach/__init__.py
 delete mode 100644 brainpy/_src/math/op_register/numba_approach/cpu_translation.py
 delete mode 100644 brainpy/_src/math/op_register/numba_approach/tests/test_numba_approach.py
 delete mode 100644 brainpy/_src/math/op_register/numba_based.py
 delete mode 100644 brainpy/_src/math/op_register/tests/test_ad_support.py
 delete mode 100644 brainpy/_src/math/op_register/tests/test_cupy_based.py
 delete mode 100644 brainpy/_src/math/op_register/tests/test_numba_based.py
 delete mode 100644 brainpy/_src/math/op_register/utils.py
 delete mode 100644 brainpy/_src/math/sparse/bsr_mm.py
 delete mode 100644 brainpy/_src/math/sparse/bsr_mv.py
 delete mode 100644 brainpy/math/op_register.py

diff --git a/brainpy/_src/math/__init__.py b/brainpy/_src/math/__init__.py
index 011598837..a28ba7d84 100644
--- a/brainpy/_src/math/__init__.py
+++ b/brainpy/_src/math/__init__.py
@@ -47,7 +47,6 @@
 from . import random, linalg, fft
 
 # operators
-from .op_register import *
 from .pre_syn_post import *
 from . import surrogate, event, sparse, jitconn
 
diff --git a/brainpy/_src/math/jitconn/matvec.py b/brainpy/_src/math/jitconn/matvec.py
index 4d4dd25a5..be4b19d19 100644
--- a/brainpy/_src/math/jitconn/matvec.py
+++ b/brainpy/_src/math/jitconn/matvec.py
@@ -4,14 +4,12 @@
 
 import jax
 import numpy as np
-from jax import numpy as jnp
-
+from brainpy._src.dependency_check import import_braintaichi, raise_braintaichi_not_found
 from brainpy._src.math import defaults
 from brainpy._src.math.interoperability import as_jax
 from brainpy._src.math.ndarray import Array
-from brainpy._src.math.op_register import XLACustomOp
 from brainpy.errors import PackageMissingError
-from brainpy._src.dependency_check import import_braintaichi, raise_braintaichi_not_found
+from jax import numpy as jnp
 
 bti = import_braintaichi(error_if_not_found=False)
 
diff --git a/brainpy/_src/math/op_register/__init__.py b/brainpy/_src/math/op_register/__init__.py
deleted file mode 100644
index 19160708c..000000000
--- a/brainpy/_src/math/op_register/__init__.py
+++ /dev/null
@@ -1,7 +0,0 @@
-from .numba_approach import (CustomOpByNumba,
-                             register_op_with_numba_xla,
-                             compile_cpu_signature_with_numba)
-from .base import XLACustomOp
-from .utils import register_general_batching
-from .base import XLACustomOp
-from .utils import register_general_batching
diff --git a/brainpy/_src/math/op_register/ad_support.py b/brainpy/_src/math/op_register/ad_support.py
deleted file mode 100644
index 54a3c9be2..000000000
--- a/brainpy/_src/math/op_register/ad_support.py
+++ /dev/null
@@ -1,56 +0,0 @@
-import functools
-from functools import partial
-
-from jax import tree_util
-from jax.core import Primitive
-from jax.interpreters import ad
-
-__all__ = [
-  'defjvp',
-]
-
-
-def defjvp(primitive, *jvp_rules):
-  """Define JVP rules for any JAX primitive.
-
-  This function is similar to ``jax.interpreters.ad.defjvp``.
-  However, the JAX one only supports primitive with ``multiple_results=False``.
-  ``brainpy.math.defjvp`` enables to define the independent JVP rule for
-  each input parameter no matter ``multiple_results=False/True``.
-
-  For examples, please see ``test_ad_support.py``.
-
-  Args:
-    primitive: Primitive, XLACustomOp.
-    *jvp_rules: The JVP translation rule for each primal.
-  """
-  assert isinstance(primitive, Primitive)
-  if primitive.multiple_results:
-    ad.primitive_jvps[primitive] = partial(_standard_jvp, jvp_rules, primitive)
-  else:
-    ad.primitive_jvps[primitive] = partial(ad.standard_jvp, jvp_rules, primitive)
-
-
-def _standard_jvp(jvp_rules, primitive: Primitive, primals, tangents, **params):
-  assert primitive.multiple_results
-  val_out = tuple(primitive.bind(*primals, **params))
-  tree = tree_util.tree_structure(val_out)
-  tangents_out = []
-  for rule, t in zip(jvp_rules, tangents):
-    if rule is not None and type(t) is not ad.Zero:
-      r = tuple(rule(t, *primals, **params))
-      tangents_out.append(r)
-      assert tree_util.tree_structure(r) == tree
-  try:
-    return val_out, functools.reduce(_add_tangents,
-                                   tangents_out,
-                                   tree_util.tree_map(lambda a: ad.Zero.from_primal_value(a), val_out))
-  except:
-    return val_out, functools.reduce(_add_tangents,
-                                     tangents_out,
-                                     tree_util.tree_map(lambda a: ad.Zero.from_value(a), val_out))
-
-
-def _add_tangents(xs, ys):
-  return tree_util.tree_map(ad.add_tangents, xs, ys, is_leaf=lambda a: isinstance(a, ad.Zero))
-
diff --git a/brainpy/_src/math/op_register/base.py b/brainpy/_src/math/op_register/base.py
deleted file mode 100644
index a6dd5a5b8..000000000
--- a/brainpy/_src/math/op_register/base.py
+++ /dev/null
@@ -1,224 +0,0 @@
-from functools import partial
-from typing import Callable, Sequence, Tuple, Protocol, Optional, Union
-
-import jax
-import numpy as np
-from jax.interpreters import xla, batching, ad, mlir
-
-from brainpy._src.dependency_check import import_numba, import_cupy_jit
-from brainpy._src.math.ndarray import Array
-from brainpy._src.math.object_transform.base import BrainPyObject
-
-is_version_right = False
-if jax.__version__ >= '0.4.16':
-  from .numba_based import register_numba_mlir_cpu_translation_rule as register_numba_cpu_translation_rule
-  from braintaichi._primitive._mlir_translation_rule import (register_taichi_aot_mlir_cpu_translation_rule as register_taichi_cpu_translation_rule,
-                                 register_taichi_aot_mlir_gpu_translation_rule as register_taichi_gpu_translation_rule)
-  from .cupy_based import (
-    register_cupy_raw_module_mlir_gpu_translation_rule as register_cupy_raw_module_gpu_translation_rule,
-    register_cupy_jit_kernel_mlir_gpu_translation_rule as register_cupy_jit_kernel_gpu_translation_rule)
-  is_version_right = True
-
-from .utils import register_general_batching
-from brainpy._src.math.op_register.ad_support import defjvp
-
-numba = import_numba(error_if_not_found=False)
-cp_jit = import_cupy_jit(error_if_not_found=False)
-
-__all__ = [
-  'XLACustomOp',
-]
-
-
-class ShapeDtype(Protocol):
-
-  @property
-  def shape(self) -> Tuple[int, ...]:
-    ...
-
-  @property
-  def dtype(self) -> np.dtype:
-    ...
-
-
-class XLACustomOp(BrainPyObject):
-  """Creating a XLA custom call operator.
-
-  For more information, please refer to the tutorials above:
-  Numba Custom Op: https://brainpy.tech/docs/tutorial_advanced/operator_custom_with_numba.html
-  Taichi Custom Op: https://brainpy.tech/docs/tutorial_advanced/operator_custom_with_taichi.html
-  CuPy Custom Op: https://brainpy.tech/docs/tutorial_advanced/operator_custom_with_cupy.html
-
-  Args:
-    cpu_kernel: Callable. The function defines the computation on CPU backend.
-    gpu_kernel: Callable. The function defines the computation on GPU backend.
-    batching_translation: Callable. The batching translation rule of JAX.
-    jvp_translation: Callable. The JVP translation rule of JAX.
-    transpose_translation: Callable. The transpose translation rule of JAX.
-    outs: optional. The output information.
-    name: str. The primitive name.
-  """
-
-  def __init__(
-      self,
-      cpu_kernel: Callable = None,
-      gpu_kernel: Union[Callable, str] = None,
-      batching_translation: Callable = None,
-      jvp_translation: Callable = None,
-      transpose_translation: Callable = None,
-      outs: Optional[Callable] = None,
-      name: str = None,
-  ):
-    if not is_version_right:
-      raise ImportError('XLA Custom Op is only supported in JAX>=0.4.16')
-    super().__init__(name)
-
-    # set cpu_kernel and gpu_kernel
-    self.cpu_kernel = cpu_kernel
-    self.gpu_kernel = gpu_kernel
-
-    # primitive
-    self.primitive = jax.core.Primitive(self.name)
-    self.primitive.multiple_results = True
-
-    # abstract evaluation
-    self.outs = outs
-    self.primitive.def_abstract_eval(_abstract_eval)
-    self.primitive.def_impl(partial(xla.apply_primitive, self.primitive))
-
-    # cpu function
-    cpu_checked = False
-    if cpu_kernel is None:
-      cpu_checked = True
-    if numba is not None:  # numba
-      from numba.core.dispatcher import Dispatcher
-      if isinstance(cpu_kernel, Dispatcher):
-        register_numba_cpu_translation_rule(self.primitive, cpu_kernel)
-        cpu_checked = True
-    if hasattr(cpu_kernel, '_is_wrapped_kernel') and cpu_kernel._is_wrapped_kernel:  # taichi
-      register_taichi_cpu_translation_rule(self.primitive, cpu_kernel)
-      cpu_checked = True
-    if not cpu_checked:
-      raise ValueError(f'"cpu_kernel" must be a numba jitted function or a taichi kernel function. '
-                       f'But we got {cpu_kernel}')
-
-    # gpu function
-    gpu_checked = False
-    if gpu_kernel is None:
-      gpu_checked = True
-    elif hasattr(gpu_kernel, 'kernel'):  # cupy RawModule
-      register_cupy_raw_module_gpu_translation_rule(self.primitive, gpu_kernel)
-      gpu_checked = True
-    elif hasattr(gpu_kernel, '_mode'):  # cupy JIT Kernel
-      register_cupy_jit_kernel_gpu_translation_rule(self.primitive, gpu_kernel)
-      gpu_checked = True
-    elif hasattr(gpu_kernel, '_is_wrapped_kernel') and gpu_kernel._is_wrapped_kernel:  # taichi
-      register_taichi_gpu_translation_rule(self.primitive, gpu_kernel)
-      gpu_checked = True
-    if not gpu_checked:
-      raise ValueError(
-        f'"gpu_kernel" must be a taichi kernel function, cupy raw module or cupy jit kernel. But we got {gpu_kernel}')
-
-    # batching rule
-    if batching_translation is None:
-      register_general_batching(self.primitive)
-    else:
-      batching.primitive_batchers[self.primitive] = batching_translation
-
-    # jvp rule
-    if jvp_translation is not None:
-      ad.primitive_jvps[self.primitive] = jvp_translation
-
-    # transpose rule
-    if transpose_translation is not None:
-      ad.primitive_transposes[self.primitive] = transpose_translation
-
-  def __call__(self, *ins, outs: Optional[Sequence[ShapeDtype]] = None, **kwargs):
-    if outs is None:
-      if self.outs is None:
-        raise ValueError('The output information is not defined.')
-      outs = self.outs(*ins, **kwargs)
-    assert outs is not None
-    outs = tuple([_transform_to_shapedarray(o) for o in outs])
-    ins = jax.tree_util.tree_map(_transform_to_array, ins, is_leaf=_is_bp_array)
-    return self.primitive.bind(*ins, outs=outs, **kwargs)
-
-  def def_abstract_eval(self, fun):
-    """Define the abstract evaluation function.
-
-    Args:
-      fun: The abstract evaluation function.
-    """
-    self.primitive.def_abstract_eval(fun)
-
-  def def_batching_rule(self, fun):
-    """Define the batching rule.
-
-    Args:
-      fun: The batching rule.
-    """
-    batching.primitive_batchers[self.primitive] = fun
-
-  def def_jvp_rule(self, fun):
-    """Define the JVP rule.
-
-    Args:
-      fun: The JVP rule.
-    """
-    ad.primitive_jvps[self.primitive] = fun
-
-  def defjvp(self, *jvp_rules):
-    """Define the JVP rule. Similar to ``jax.interpreters.ad.defjvp``, but supports the Primitive with multiple results.
-
-    Args:
-      jvp_rules: The JVP rules.
-    """
-    defjvp(self.primitive, *jvp_rules)
-
-  def def_transpose_rule(self, fun):
-    """Define the transpose rule.
-
-    Args:
-      fun: The transpose rule.
-    """
-    ad.primitive_transposes[self.primitive] = fun
-
-  def def_xla_translation(self, platform, fun):
-    """Define the XLA translation rule.
-
-    Args:
-      platform: str. The computing platform.
-      fun: The XLA translation rule.
-    """
-    xla.backend_specific_translations[platform][self.primitive] = fun
-
-  def def_mlir_lowering(self, platform, fun):
-    """Define the MLIR lowering rule.
-
-    Args:
-      platform: str. The computing platform.
-      fun: The lowering rule.
-    """
-    mlir.register_lowering(self.primitive, fun, platform)
-
-
-def _abstract_eval(*args, **kwargs):
-  return [jax.core.ShapedArray(out_shape.shape, out_shape.dtype)
-          for out_shape in kwargs['outs']]
-
-
-def _is_bp_array(a):
-  return isinstance(a, Array)
-
-
-def _transform_to_array(a):
-  if isinstance(a, Array):
-    return a.value
-  elif isinstance(a, jax.Array):
-    return a
-  else:
-    return jax.numpy.asarray(a)
-
-
-def _transform_to_shapedarray(a):
-  return jax.core.ShapedArray(a.shape, a.dtype)
diff --git a/brainpy/_src/math/op_register/cupy_based.py b/brainpy/_src/math/op_register/cupy_based.py
deleted file mode 100644
index ad6befecf..000000000
--- a/brainpy/_src/math/op_register/cupy_based.py
+++ /dev/null
@@ -1,279 +0,0 @@
-from functools import partial, reduce
-from typing import List, Tuple
-
-import jax
-import numpy as np
-from jax.interpreters import xla, mlir
-from jax.lib import xla_client
-from jaxlib.hlo_helpers import custom_call
-
-from brainpy._src.dependency_check import (import_cupy,
-                                           import_cupy_jit,
-                                           import_brainpylib_gpu_ops)
-from brainpy._src.math.op_register.utils import _shape_to_layout
-from brainpy.errors import PackageMissingError
-
-cp = import_cupy(error_if_not_found=False)
-cp_jit = import_cupy_jit(error_if_not_found=False)
-
-# convert type to number
-type_number_map = {
-  int: 0,
-  float: 1,
-  bool: 2,
-  np.dtype('int32'): 0,
-  np.dtype('float32'): 1,
-  np.dtype('bool'): 2,
-  np.dtype('uint8'): 3,
-  np.dtype('uint16'): 4,
-  np.dtype('uint32'): 5,
-  np.dtype('uint64'): 6,
-  np.dtype('int8'): 7,
-  np.dtype('int16'): 8,
-  np.dtype('int64'): 9,
-  np.dtype('float16'): 10,
-  np.dtype('float64'): 11,
-}
-
-
-def _preprocess_kernel_call_gpu(
-    grid: Tuple[int],
-    block: Tuple[int],
-    func_ptr: int,
-    shared_mem: int,
-    *ins,
-    outs: List[jax.ShapeDtypeStruct],
-):
-  grid = (grid + (1, 1))[:3]
-  block = (block + (1, 1))[:3]
-  in_num = len(ins)
-  out_num = len(outs)
-  in_out_num = [in_num, out_num]
-
-  out_type_list = [0] * out_num
-  out_elem_count_list = [0] * out_num
-
-  for i, value in enumerate(outs):
-    out_type_list[i] = type_number_map[value.dtype]
-    out_elem_count_list[i] = reduce(lambda x, y: x * y, value.shape)
-
-  grid = ",".join(str(i) for i in grid)
-  block = ",".join(str(i) for i in block)
-  in_out_num_str = ",".join(str(i) for i in in_out_num)
-  out_type_list_str = ",".join(str(i) for i in out_type_list)
-  out_elem_count_list_str = ",".join(str(i) for i in out_elem_count_list)
-
-  opaque = (bytes(str(func_ptr), encoding='utf-8') + b';' +
-            bytes(str(shared_mem), encoding='utf-8') + b';' +
-            bytes(in_out_num_str, encoding='utf-8') + b';' +
-            bytes(grid, encoding='utf-8') + b';' +
-            bytes(block, encoding='utf-8') + b';' +
-            bytes(out_type_list_str, encoding='utf-8') + b';' +
-            bytes(out_elem_count_list_str, encoding='utf-8') + b';')
-  return opaque
-
-
-def _cupy_raw_module_xla_gpu_translation_rule(kernel, c, *ins, **kwargs):
-  grid = kwargs.get('grid', None)
-  block = kwargs.get('block', None)
-  shared_mem = kwargs.get('shared_mem', 0)
-  if grid is None or block is None:
-    raise ValueError('The grid and block should be specified for the cupy kernel.')
-
-  # preprocess
-  import_brainpylib_gpu_ops()
-  # THE KEY:
-  # - using the kernel pointer at "kernel.kernel.ptr"
-  opaque = _preprocess_kernel_call_gpu(grid, block, kernel.kernel.ptr, shared_mem, *ins, outs=kwargs['outs'])
-
-  # create custom call
-  return xla_client.ops.CustomCallWithLayout(
-    c,
-    b'cupy_kernel_call_gpu',
-    operands=ins,
-    operand_shapes_with_layout=tuple(c.get_shape(value) for value in ins),
-    shape_with_layout=xla_client.Shape.tuple_shape(
-      [xla_client.Shape.array_shape(value.dtype, value.shape, _shape_to_layout(value.shape))
-       for value in kwargs['outs']]
-    ),
-    opaque=opaque,
-  )
-
-
-def register_cupy_raw_module_xla_gpu_translation_rule(primitive, gpu_kernel):
-  xla.backend_specific_translations['gpu'][primitive] = partial(_cupy_raw_module_xla_gpu_translation_rule, gpu_kernel)
-
-
-def _cupy_raw_module_mlir_gpu_translation_rule(kernel, c, *ins, **kwargs):
-  grid = kwargs.get('grid', None)
-  block = kwargs.get('block', None)
-  shared_mem = kwargs.get('shared_mem', 0)
-  if grid is None or block is None:
-    raise ValueError('The grid and block should be specified for the cupy kernel.')
-
-  # preprocess
-  import_brainpylib_gpu_ops()
-  opaque = _preprocess_kernel_call_gpu(grid, block, kernel.kernel.ptr, shared_mem, *ins, outs=kwargs['outs'])
-
-  input_layouts = [_shape_to_layout(a.shape) for a in c.avals_in]
-  result_types = [mlir.aval_to_ir_type(out) for out in c.avals_out]
-  output_layouts = [_shape_to_layout(a.shape) for a in c.avals_out]
-
-  return custom_call(
-    call_target_name='cupy_kernel_call_gpu',
-    operands=ins,
-    operand_layouts=list(input_layouts),
-    result_layouts=list(output_layouts),
-    result_types=list(result_types),
-    backend_config=opaque,
-    has_side_effect=False,
-  ).results
-
-
-def register_cupy_raw_module_mlir_gpu_translation_rule(primitive, gpu_kernel):
-  if cp is None:
-    raise PackageMissingError("cupy", 'register cupy mlir gpu translation rule')
-
-  rule = partial(_cupy_raw_module_mlir_gpu_translation_rule, gpu_kernel)
-  mlir.register_lowering(primitive, rule, platform='gpu')
-
-
-def _to_cupy_array_or_scalar(dtype, ndim):
-  # THE KEY
-  # - using the cupy jit compiler to get the type
-  if ndim != 0:
-    t = cp_jit._cuda_types.CArray(dtype=dtype,
-                                  ndim=ndim,
-                                  is_c_contiguous=True,
-                                  index_32_bits=True)
-  else:
-    t = cp_jit._cuda_types.Scalar(dtype=dtype)
-  return t
-
-
-def _compile_kernel_xla(kernel, in_types):
-  # THE KEY
-  # - get the kernel function from the cache
-  device_id = cp.cuda.get_device_id()
-  kern, enable_cg = kernel._cache.get((in_types, device_id), (None, None))
-
-  if kern is None:
-    # THE KEY:
-    # - compile the kernel function
-    result = kernel._cached_codes.get(in_types)
-    if result is None:
-      result = cp_jit._compile.transpile(
-        kernel._func,
-        ['extern "C"', '__global__'],
-        'cuda',
-        in_types,
-        cp_jit._cuda_types.void,
-      )
-      kernel._cached_codes[in_types] = result
-    fname = result.func_name
-    enable_cg = result.enable_cooperative_groups
-    options = result.options
-    backend = result.backend
-    if backend == 'nvcc':
-      options += ('-DCUPY_JIT_NVCC',)
-    jitify = result.jitify
-    module = cp._core.core.compile_with_cache(
-      source=result.code,
-      options=options,
-      backend=backend,
-      jitify=jitify,
-    )
-    kern = module.get_function(fname)
-    kernel._cache[(in_types, device_id)] = (kern, enable_cg)
-
-  return kern
-
-
-def get_jit_kernel_xla(kernel, c, *ins, outs):
-  # get the input types
-  in_types = []
-  for x in ins:
-    x = c.get_shape(x)
-    in_types.append(_to_cupy_array_or_scalar(x.element_type(), len(x.dimensions())))
-  for x in outs:
-    in_types.append(_to_cupy_array_or_scalar(x.dtype, x.ndim))
-  in_types = tuple(in_types)
-  # compile the kernel
-  return _compile_kernel_xla(kernel, in_types)
-
-
-def get_jit_kernel_mlir(kernel, c):
-  # get the input types
-  in_types = []
-  for x in c.avals_in:
-    in_types.append(_to_cupy_array_or_scalar(x.dtype, x.ndim))
-  for x in c.avals_out:
-    in_types.append(_to_cupy_array_or_scalar(x.dtype, x.ndim))
-  in_types = tuple(in_types)
-  # compile the kernel
-  return _compile_kernel_xla(kernel, in_types)
-
-
-def _cupy_jit_kernel_xla_gpu_translation_rule(kernel, c, *ins, **kwargs):
-  kernel_func = get_jit_kernel_xla(kernel, c, *ins, outs=kwargs['outs'])
-  grid = kwargs.get('grid', None)
-  block = kwargs.get('block', None)
-  shared_mem = kwargs.get('shared_mem', 0)
-  if grid is None or block is None:
-    raise ValueError('The grid and block should be specified for the cupy kernel.')
-
-  # preprocess
-  import_brainpylib_gpu_ops()
-  opaque = _preprocess_kernel_call_gpu(grid, block, kernel_func.ptr, shared_mem, *ins, outs=kwargs['outs'])
-
-  # create custom call
-  return xla_client.ops.CustomCallWithLayout(
-    c,
-    b'cupy_kernel_call_gpu',
-    operands=ins,
-    operand_shapes_with_layout=tuple(c.get_shape(value) for value in ins),
-    shape_with_layout=xla_client.Shape.tuple_shape(
-      [xla_client.Shape.array_shape(value.dtype, value.shape, _shape_to_layout(value.shape))
-       for value in kwargs['outs']]
-    ),
-    opaque=opaque,
-  )
-
-
-def register_cupy_jit_kernel_xla_gpu_translation_rule(primitive, gpu_kernel):
-  xla.backend_specific_translations['gpu'][primitive] = partial(_cupy_jit_kernel_xla_gpu_translation_rule, gpu_kernel)
-
-
-def _cupy_jit_kernel_mlir_gpu_translation_rule(kernel, c, *ins, **kwargs):
-  kernel_func = get_jit_kernel_mlir(kernel, c)
-  grid = kwargs.get('grid', None)
-  block = kwargs.get('block', None)
-  shared_mem = kwargs.get('shared_mem', 0)
-  if grid is None or block is None:
-    raise ValueError('The grid and block should be specified for the cupy kernel.')
-
-  # preprocess
-  import_brainpylib_gpu_ops()
-  opaque = _preprocess_kernel_call_gpu(grid, block, kernel_func.ptr, shared_mem, *ins, outs=kwargs['outs'])
-
-  input_layouts = [_shape_to_layout(a.shape) for a in c.avals_in]
-  result_types = [mlir.aval_to_ir_type(out) for out in c.avals_out]
-  output_layouts = [_shape_to_layout(a.shape) for a in c.avals_out]
-
-  return custom_call(
-    call_target_name='cupy_kernel_call_gpu',
-    operands=ins,
-    operand_layouts=list(input_layouts),
-    result_layouts=list(output_layouts),
-    result_types=list(result_types),
-    backend_config=opaque,
-    has_side_effect=False,
-  ).results
-
-
-def register_cupy_jit_kernel_mlir_gpu_translation_rule(primitive, gpu_kernel):
-  if cp is None:
-    raise PackageMissingError("cupy", 'register cupy mlir gpu translation rule')
-
-  rule = partial(_cupy_jit_kernel_mlir_gpu_translation_rule, gpu_kernel)
-  mlir.register_lowering(primitive, rule, platform='gpu')
diff --git a/brainpy/_src/math/op_register/numba_approach/__init__.py b/brainpy/_src/math/op_register/numba_approach/__init__.py
deleted file mode 100644
index 35c9beef6..000000000
--- a/brainpy/_src/math/op_register/numba_approach/__init__.py
+++ /dev/null
@@ -1,295 +0,0 @@
-# -*- coding: utf-8 -*-
-import ctypes
-import ctypes
-from functools import partial
-from typing import Callable
-from typing import Union, Sequence
-
-import jax
-from jax.interpreters import xla, batching, ad, mlir
-
-from jax.tree_util import tree_map
-from jaxlib.hlo_helpers import custom_call
-
-from brainpy._src.dependency_check import import_numba
-from brainpy._src.math.ndarray import Array
-from brainpy._src.math.object_transform.base import BrainPyObject
-
-from brainpy.errors import PackageMissingError
-from .cpu_translation import _cpu_translation, compile_cpu_signature_with_numba, _numba_mlir_cpu_translation_rule
-
-numba = import_numba(error_if_not_found=False)
-if numba is not None:
-  from numba import types, carray, cfunc
-
-__all__ = [
-  'CustomOpByNumba',
-  'register_op_with_numba_xla',
-  'compile_cpu_signature_with_numba',
-]
-
-
-def _transform_to_shapedarray(a):
-  return jax.core.ShapedArray(a.shape, a.dtype)
-
-
-def convert_shapedarray_to_shapedtypestruct(shaped_array):
-  return jax.ShapeDtypeStruct(shape=shaped_array.shape, dtype=shaped_array.dtype)
-
-
-class CustomOpByNumba(BrainPyObject):
-  """Creating a XLA custom call operator with Numba JIT on CPU backend.
-
-  Parameters
-  ----------
-  name: str
-    The name of operator.
-  eval_shape: callable
-    The function to evaluate the shape and dtype of the output according to the input.
-    This function should receive the abstract information of inputs, and return the
-    abstract information of the outputs. For example:
-
-    >>> def eval_shape(inp1_info, inp2_info, inp3_info, ...):
-    >>>   return out1_info, out2_info
-  con_compute: callable
-    The function to make the concrete computation. This function receives inputs,
-    and returns outputs. For example:
-
-    >>> def con_compute(inp1, inp2, inp3, ..., out1, out2, ...):
-    >>>   pass
-  """
-
-  def __init__(
-      self,
-      eval_shape: Callable = None,
-      con_compute: Callable = None,
-      name: str = None,
-      batching_translation: Callable = None,
-      jvp_translation: Callable = None,
-      transpose_translation: Callable = None,
-      multiple_results: bool = True,
-  ):
-    super().__init__(name=name)
-
-    # abstract evaluation function
-    if eval_shape is None:
-      raise ValueError('Must provide "eval_shape" for abstract evaluation.')
-    self.eval_shape = eval_shape
-
-    # cpu function
-    cpu_func = con_compute
-
-    # register OP
-    if jax.__version__ > '0.4.23':
-      self.op_method = 'mlir'
-      self.op = register_op_with_numba_mlir(
-        self.name,
-        cpu_func=cpu_func,
-        out_shapes=eval_shape,
-        gpu_func_translation=None,
-        batching_translation=batching_translation,
-        jvp_translation=jvp_translation,
-        transpose_translation=transpose_translation,
-        multiple_results=multiple_results,
-      )
-    else:
-      self.op_method = 'xla'
-      self.op = register_op_with_numba_xla(
-        self.name,
-        cpu_func=cpu_func,
-        out_shapes=eval_shape,
-        batching_translation=batching_translation,
-        jvp_translation=jvp_translation,
-        transpose_translation=transpose_translation,
-        multiple_results=multiple_results,
-      )
-
-  def __call__(self, *args, **kwargs):
-    args = tree_map(lambda a: a.value if isinstance(a, Array) else a,
-                    args, is_leaf=lambda a: isinstance(a, Array))
-    kwargs = tree_map(lambda a: a.value if isinstance(a, Array) else a,
-                      kwargs, is_leaf=lambda a: isinstance(a, Array))
-    res = self.op.bind(*args, **kwargs)
-    return res
-
-
-def register_op_with_numba_xla(
-    op_name: str,
-    cpu_func: Callable,
-    out_shapes: Union[Callable, jax.core.ShapedArray, Sequence[jax.core.ShapedArray]],
-    gpu_func_translation: Callable = None,
-    batching_translation: Callable = None,
-    jvp_translation: Callable = None,
-    transpose_translation: Callable = None,
-    multiple_results: bool = False,
-):
-  """
-  Converting the numba-jitted function in a Jax/XLA compatible primitive.
-
-  Parameters
-  ----------
-  op_name: str
-    Name of the operators.
-
-  cpu_func: Callable
-    A callable numba-jitted function or pure function (can be lambda function) running on CPU.
-
-  out_shapes: Callable, ShapedArray, Sequence[ShapedArray], default = None
-    Outputs shapes of target function. `out_shapes` can be a `ShapedArray` or
-    a sequence of `ShapedArray`. If it is a function, it takes as input the argument
-    shapes and dtypes and should return correct output shapes of `ShapedArray`.
-
-  gpu_func_translation: Callable
-    A callable cuda-jitted kernel running on GPU.
-
-  batching_translation: Callable
-    The batching translation for the primitive.
-
-  jvp_translation: Callable
-    The forward autodiff translation rule.
-
-  transpose_translation: Callable
-    The backward autodiff translation rule.
-
-  multiple_results: bool
-    Whether the primitive returns multiple results. Default is False.
-
-  Returns
-  -------
-  op: core.Primitive
-    A JAX Primitive object.
-  """
-
-  if numba is None:
-    raise PackageMissingError.by_purpose('numba', 'custom op with numba')
-
-  if out_shapes is None:
-    raise RuntimeError('out_shapes cannot be None. It can be a `ShapedArray` or '
-                       'a sequence of `ShapedArray`. If it is a function, it takes as input the argument '
-                       'shapes and dtypes and should return correct output shapes of `ShapedArray`.')
-
-  prim = jax.core.Primitive(op_name)
-  prim.multiple_results = multiple_results
-
-  # user defined function
-  from numba.core.dispatcher import Dispatcher
-  if not isinstance(cpu_func, Dispatcher):
-    cpu_func = numba.jit(fastmath=True, nopython=True)(cpu_func)
-
-  # output shape evaluation function
-  def abs_eval_rule(*input_shapes, **info):
-    if callable(out_shapes):
-      shapes = out_shapes(*input_shapes, **info)
-    else:
-      shapes = out_shapes
-
-    if isinstance(shapes, jax.core.ShapedArray):
-      assert not multiple_results, "multiple_results is True, while the abstract evaluation returns only one data."
-    elif isinstance(shapes, (tuple, list)):
-      assert multiple_results, "multiple_results is False, while the abstract evaluation returns multiple data."
-      for elem in shapes:
-        if not isinstance(elem, jax.core.ShapedArray):
-          raise ValueError(f'Elements in "out_shapes" must be instances of '
-                           f'jax.abstract_arrays.ShapedArray, but we got '
-                           f'{type(elem)}: {elem}')
-    else:
-      raise ValueError(f'Unknown type {type(shapes)}, only '
-                       f'supports function, ShapedArray or '
-                       f'list/tuple of ShapedArray.')
-    return shapes
-
-  # cpu function
-  prim.def_abstract_eval(abs_eval_rule)
-  prim.def_impl(partial(xla.apply_primitive, prim))
-  xla.backend_specific_translations['cpu'][prim] = partial(_cpu_translation,
-                                                           cpu_func,
-                                                           abs_eval_rule,
-                                                           multiple_results)
-
-  # gpu function
-  if gpu_func_translation is not None:
-    xla.backend_specific_translations['gpu'][prim] = gpu_func_translation
-
-  # batching
-  if batching_translation is not None:
-    batching.primitive_batchers[prim] = batching_translation
-
-  # jvp
-  if jvp_translation is not None:
-    ad.primitive_jvps[prim] = jvp_translation
-
-  # transpose
-  if transpose_translation is not None:
-    ad.primitive_transposes[prim] = transpose_translation
-
-  return prim
-
-
-def register_op_with_numba_mlir(
-    op_name: str,
-    cpu_func: Callable,
-    out_shapes: Union[Callable, jax.core.ShapedArray, Sequence[jax.core.ShapedArray]],
-    gpu_func_translation: Callable = None,
-    batching_translation: Callable = None,
-    jvp_translation: Callable = None,
-    transpose_translation: Callable = None,
-    multiple_results: bool = False,
-):
-  if numba is None:
-    raise PackageMissingError.by_purpose('numba', 'custom op with numba')
-
-  if out_shapes is None:
-    raise RuntimeError('out_shapes cannot be None. It can be a `ShapedArray` or '
-                       'a sequence of `ShapedArray`. If it is a function, it takes as input the argument '
-                       'shapes and dtypes and should return correct output shapes of `ShapedArray`.')
-
-  prim = jax.core.Primitive(op_name)
-  prim.multiple_results = multiple_results
-
-  from numba.core.dispatcher import Dispatcher
-  if not isinstance(cpu_func, Dispatcher):
-    cpu_func = numba.jit(fastmath=True, nopython=True)(cpu_func)
-
-  def abs_eval_rule(*input_shapes, **info):
-    if callable(out_shapes):
-      shapes = out_shapes(*input_shapes, **info)
-    else:
-      shapes = out_shapes
-
-    if isinstance(shapes, jax.core.ShapedArray):
-      assert not multiple_results, "multiple_results is True, while the abstract evaluation returns only one data."
-    elif isinstance(shapes, (tuple, list)):
-      assert multiple_results, "multiple_results is False, while the abstract evaluation returns multiple data."
-      for elem in shapes:
-        if not isinstance(elem, jax.core.ShapedArray):
-          raise ValueError(f'Elements in "out_shapes" must be instances of '
-                           f'jax.abstract_arrays.ShapedArray, but we got '
-                           f'{type(elem)}: {elem}')
-    else:
-      raise ValueError(f'Unknown type {type(shapes)}, only '
-                       f'supports function, ShapedArray or '
-                       f'list/tuple of ShapedArray.')
-    return shapes
-
-  prim.def_abstract_eval(abs_eval_rule)
-  prim.def_impl(partial(xla.apply_primitive, prim))
-
-  cpu_translation_rule = partial(_numba_mlir_cpu_translation_rule,
-                                 cpu_func,
-                                 False)
-
-  mlir.register_lowering(prim, cpu_translation_rule, platform='cpu')
-
-  if gpu_func_translation is not None:
-    mlir.register_lowering(prim, gpu_func_translation, platform='gpu')
-
-  if batching_translation is not None:
-    jax.interpreters.batching.primitive_batchers[prim] = batching_translation
-
-  if jvp_translation is not None:
-    jax.interpreters.ad.primitive_jvps[prim] = jvp_translation
-
-  if transpose_translation is not None:
-    jax.interpreters.ad.primitive_transposes[prim] = transpose_translation
-
-  return prim
diff --git a/brainpy/_src/math/op_register/numba_approach/cpu_translation.py b/brainpy/_src/math/op_register/numba_approach/cpu_translation.py
deleted file mode 100644
index 363ce6b17..000000000
--- a/brainpy/_src/math/op_register/numba_approach/cpu_translation.py
+++ /dev/null
@@ -1,228 +0,0 @@
-# -*- coding: utf-8 -*-
-
-import ctypes
-
-from jax import dtypes, numpy as jnp
-from jax.core import ShapedArray
-from jax.lib import xla_client
-from jaxlib.hlo_helpers import custom_call
-from jax.interpreters import mlir
-
-from brainpy._src.dependency_check import import_numba
-from brainpy._src.math.op_register.utils import _shape_to_layout
-
-numba = import_numba(error_if_not_found=False)
-ctypes.pythonapi.PyCapsule_New.argtypes = [
-  ctypes.c_void_p,  # void* pointer
-  ctypes.c_char_p,  # const char *name
-  ctypes.c_void_p,  # PyCapsule_Destructor destructor
-]
-ctypes.pythonapi.PyCapsule_New.restype = ctypes.py_object
-
-__all__ = [
-  '_cpu_translation',
-  'compile_cpu_signature_with_numba',
-  '_numba_mlir_cpu_translation_rule',
-]
-
-if numba is not None:
-  from numba import types, carray, cfunc
-
-
-def _cpu_translation(func, abs_eval_fn, multiple_results, c, *inputs, **info):
-  target_name, inputs, input_shapes, xla_output_shapes = \
-    compile_cpu_signature_with_numba(c, func, abs_eval_fn, multiple_results, inputs, info)
-  return xla_client.ops.CustomCallWithLayout(
-    c,
-    target_name,
-    operands=inputs,
-    operand_shapes_with_layout=input_shapes,
-    shape_with_layout=xla_output_shapes,
-  )
-
-
-def _cpu_signature(
-    func,
-    input_dtypes,
-    input_shapes,
-    output_dtypes,
-    output_shapes,
-    multiple_results: bool,
-    debug: bool = False
-):
-  code_scope = dict(
-    func_to_call=func,
-    input_shapes=input_shapes,
-    input_dtypes=input_dtypes,
-    output_shapes=output_shapes,
-    output_dtypes=output_dtypes,
-    carray=carray,
-  )
-
-  # inputs
-  if len(input_shapes) > 1:
-    args_in = [
-      f'carray(input_ptrs[{i}], input_shapes[{i}], dtype=input_dtypes[{i}]),'
-      for i in range(len(input_shapes))
-    ]
-    args_in = '(\n    ' + "\n    ".join(args_in) + '\n  )'
-  else:
-    args_in = 'carray(input_ptrs[0], input_shapes[0], dtype=input_dtypes[0])'
-
-  # outputs
-  if multiple_results:
-    args_out = [
-      f'carray(output_ptrs[{i}], output_shapes[{i}], dtype=output_dtypes[{i}]),'
-      for i in range(len(output_shapes))
-    ]
-    args_out = '(\n    ' + "\n    ".join(args_out) + '\n  )'
-  else:
-    args_out = 'carray(output_ptrs, output_shapes[0], dtype=output_dtypes[0])'
-
-  # function body
-  code_string = '''
-def xla_cpu_custom_call_target(output_ptrs, input_ptrs):
-  args_out = {args_out}
-  args_in = {args_in}
-  func_to_call(args_out, args_in)
-    '''.format(args_in=args_in,
-               args_out=args_out)
-  if debug: print(code_string)
-  exec(compile(code_string.strip(), '', 'exec'), code_scope)
-
-  new_f = code_scope['xla_cpu_custom_call_target']
-  if multiple_results:
-    xla_c_rule = cfunc(types.void(types.CPointer(types.voidptr),
-                                  types.CPointer(types.voidptr)))(new_f)
-  else:
-    xla_c_rule = cfunc(types.void(types.voidptr, types.CPointer(types.voidptr)))(new_f)
-  target_name = xla_c_rule.native_name.encode("ascii")
-  capsule = ctypes.pythonapi.PyCapsule_New(
-    xla_c_rule.address,  # A CFFI pointer to a function
-    b"xla._CUSTOM_CALL_TARGET",  # A binary string
-    None  # PyCapsule object run at destruction
-  )
-  xla_client.register_custom_call_target(target_name, capsule, "cpu")
-  return target_name
-
-
-def compile_cpu_signature_with_numba(
-    c,
-    func,
-    abs_eval_fn,
-    multiple_results,
-    inputs: tuple,
-    description: dict = None,
-):
-  input_layouts = [c.get_shape(arg) for arg in inputs]
-  info_inputs = []
-  if description is None: description = dict()
-  for v in description.values():
-    if isinstance(v, (int, float)):
-      input_layouts.append(xla_client.Shape.array_shape(dtypes.canonicalize_dtype(type(v)), (), ()))
-      info_inputs.append(xla_client.ops.ConstantLiteral(c, v))
-    elif isinstance(v, (tuple, list)):
-      v = jnp.asarray(v)
-      input_layouts.append(xla_client.Shape.array_shape(v.dtype, v.shape, tuple(range(len(v.shape) - 1, -1, -1))))
-      info_inputs.append(xla_client.ops.Constant(c, v))
-    else:
-      raise TypeError
-  input_layouts = tuple(input_layouts)
-  input_dtypes = tuple(shape.element_type() for shape in input_layouts)
-  input_dimensions = tuple(shape.dimensions() for shape in input_layouts)
-  output_abstract_arrays = abs_eval_fn(*tuple(ShapedArray(shape.dimensions(), shape.element_type())
-                                              for shape in input_layouts[:len(inputs)]),
-                                       **description)
-  if isinstance(output_abstract_arrays, ShapedArray):
-    output_abstract_arrays = (output_abstract_arrays,)
-    assert not multiple_results
-  else:
-    assert multiple_results
-  output_shapes = tuple(array.shape for array in output_abstract_arrays)
-  output_dtypes = tuple(array.dtype for array in output_abstract_arrays)
-  output_layouts = map(lambda shape: range(len(shape) - 1, -1, -1), output_shapes)
-  target_name = _cpu_signature(func,
-                               input_dtypes,
-                               input_dimensions,
-                               output_dtypes,
-                               output_shapes,
-                               multiple_results,
-                               debug=False)
-  output_layouts = [xla_client.Shape.array_shape(*arg)
-                    for arg in zip(output_dtypes, output_shapes, output_layouts)]
-  output_layouts = (xla_client.Shape.tuple_shape(output_layouts)
-                    if multiple_results else
-                    output_layouts[0])
-  return target_name, tuple(inputs) + tuple(info_inputs), input_layouts, output_layouts
-
-
-def _numba_mlir_cpu_translation_rule(
-    cpu_func,
-    debug,
-    ctx,
-    *ins,
-    **kwargs
-):
-  # output information
-  outs = ctx.avals_out
-  output_shapes = tuple([out.shape for out in outs])
-  output_dtypes = tuple([out.dtype for out in outs])
-  output_layouts = tuple([_shape_to_layout(out.shape) for out in outs])
-  result_types = [mlir.aval_to_ir_type(out) for out in outs]
-
-  # input information
-  avals_in = ctx.avals_in
-  input_layouts = [_shape_to_layout(a.shape) for a in avals_in]
-  input_dtypes = tuple(inp.dtype for inp in avals_in)
-  input_shapes = tuple(inp.shape for inp in avals_in)
-
-  # compiling function
-  code_scope = dict(func_to_call=cpu_func, input_shapes=input_shapes, input_dtypes=input_dtypes,
-                    output_shapes=output_shapes, output_dtypes=output_dtypes, carray=carray)
-  if len(input_shapes) > 1:
-    args_in = [
-      f'carray(input_ptrs[{i}], input_shapes[{i}], dtype=input_dtypes[{i}]),'
-      for i in range(len(input_shapes))
-    ]
-    args_in = '(\n    ' + "\n    ".join(args_in) + '\n  )'
-  else:
-    args_in = 'carray(input_ptrs[0], input_shapes[0], dtype=input_dtypes[0])'
-  if len(output_shapes) > 1:
-    args_out = [
-      f'carray(output_ptrs[{i}], output_shapes[{i}], dtype=output_dtypes[{i}]),'
-      for i in range(len(output_shapes))
-    ]
-    args_out = '(\n    ' + "\n    ".join(args_out) + '\n  )'
-    sig = types.void(types.CPointer(types.voidptr), types.CPointer(types.voidptr))
-  else:
-    args_out = 'carray(output_ptrs, output_shapes[0], dtype=output_dtypes[0])'
-    sig = types.void(types.voidptr, types.CPointer(types.voidptr))
-  # args_call = [f'out{i}' for i in range(len(output_shapes))] + [f'in{i}' for i in range(len(input_shapes))]
-  code_string = '''
-def numba_cpu_custom_call_target(output_ptrs, input_ptrs):
-    args_out = {args_out}
-    args_in = {args_in}
-    func_to_call(args_out, args_in)
-  '''.format(args_in=args_in,
-             args_out=args_out)
-
-  if debug:
-    print(code_string)
-  exec(compile(code_string.strip(), '', 'exec'), code_scope)
-  new_f = code_scope['numba_cpu_custom_call_target']
-
-  # register
-  xla_c_rule = cfunc(sig)(new_f)
-  target_name = f'numba_custom_call_{str(xla_c_rule.address)}'
-  capsule = ctypes.pythonapi.PyCapsule_New(xla_c_rule.address, b"xla._CUSTOM_CALL_TARGET", None)
-  xla_client.register_custom_call_target(target_name, capsule, "cpu")
-
-  # call
-  return custom_call(
-    call_target_name=target_name,
-    operands=ins,
-    operand_layouts=list(input_layouts),
-    result_layouts=list(output_layouts),
-    result_types=list(result_types),
-    has_side_effect=False,
-  ).results
diff --git a/brainpy/_src/math/op_register/numba_approach/tests/test_numba_approach.py b/brainpy/_src/math/op_register/numba_approach/tests/test_numba_approach.py
deleted file mode 100644
index 21099cb61..000000000
--- a/brainpy/_src/math/op_register/numba_approach/tests/test_numba_approach.py
+++ /dev/null
@@ -1,48 +0,0 @@
-import jax.core
-import pytest
-from jax.core import ShapedArray
-
-import brainpy.math as bm
-from brainpy._src.dependency_check import import_numba
-
-numba = import_numba(error_if_not_found=False)
-if numba is None:
-  pytest.skip('no numba', allow_module_level=True)
-
-bm.set_platform('cpu')
-
-
-def eval_shape(a):
-  b = ShapedArray(a.shape, dtype=a.dtype)
-  return b
-
-@numba.njit(parallel=True)
-def con_compute(outs, ins):
-  b = outs
-  a = ins
-  b[:] = a + 1
-
-def test_CustomOpByNumba_single_result():
-  op = bm.CustomOpByNumba(eval_shape, con_compute, multiple_results=False)
-  print(op(bm.zeros(10)))
-
-def eval_shape2(a, b):
-  c = ShapedArray(a.shape, dtype=a.dtype)
-  d = ShapedArray(b.shape, dtype=b.dtype)
-  return c, d
-
-def con_compute2(outs, ins):
-  c = outs[0]  # take out all the outputs
-  d = outs[1]
-  a = ins[0]  # take out all the inputs
-  b = ins[1]
-  # c, d = outs
-  # a, b = ins
-  c[:] = a + 1
-  d[:] = b * 2
-
-def test_CustomOpByNumba_multiple_results():
-  op2 = bm.CustomOpByNumba(eval_shape2, con_compute2, multiple_results=True)
-  print(op2(bm.zeros(10), bm.ones(10)))
-
-test_CustomOpByNumba_multiple_results()
\ No newline at end of file
diff --git a/brainpy/_src/math/op_register/numba_based.py b/brainpy/_src/math/op_register/numba_based.py
deleted file mode 100644
index f461f4277..000000000
--- a/brainpy/_src/math/op_register/numba_based.py
+++ /dev/null
@@ -1,181 +0,0 @@
-# -*- coding: utf-8 -*-
-
-import ctypes
-from functools import partial
-
-from jax.interpreters import xla, mlir
-from jax.lib import xla_client
-from jaxlib.hlo_helpers import custom_call
-
-from brainpy._src.dependency_check import import_numba
-from brainpy.errors import PackageMissingError
-from .utils import _shape_to_layout
-
-numba = import_numba(error_if_not_found=False)
-if numba is not None:
-  from numba import types, carray, cfunc
-
-__all__ = [
-  'register_numba_xla_cpu_translation_rule',
-  'register_numba_mlir_cpu_translation_rule',
-]
-
-#                                         [void* pointer,
-#                                          const char *name,
-#                                          PyCapsule_Destructor destructor]
-ctypes.pythonapi.PyCapsule_New.argtypes = [ctypes.c_void_p, ctypes.c_char_p, ctypes.c_void_p]
-ctypes.pythonapi.PyCapsule_New.restype = ctypes.py_object
-
-
-def _cpu_signature(
-    kernel,
-    input_dtypes,
-    input_shapes,
-    output_dtypes,
-    output_shapes,
-    debug: bool = False
-):
-  code_scope = dict(
-    func_to_call=kernel,
-    input_shapes=input_shapes,
-    input_dtypes=input_dtypes,
-    output_shapes=output_shapes,
-    output_dtypes=output_dtypes,
-    carray=carray,
-  )
-
-  # inputs, outputs, arguments
-  args_in = [f'in{i} = carray(input_ptrs[{i}], input_shapes[{i}], dtype=input_dtypes[{i}])'
-             for i in range(len(input_shapes))]
-  args_out = [f'out{i} = carray(output_ptrs[{i}], output_shapes[{i}], dtype=output_dtypes[{i}])'
-              for i in range(len(output_shapes))]
-  args_call = [f'in{i}' for i in range(len(input_shapes))] + [f'out{i}' for i in range(len(output_shapes))]
-
-  # function body
-  code_string = '''
-  def xla_cpu_custom_call_target(output_ptrs, input_ptrs):
-    {args_in}
-    {args_out}
-    func_to_call({args_call})
-    '''.format(args_in="\n    ".join(args_in),
-               args_out="\n    ".join(args_out),
-               args_call=", ".join(args_call))
-  if debug: print(code_string)
-  exec(compile(code_string.strip(), '', 'exec'), code_scope)
-
-  # register
-  new_f = code_scope['xla_cpu_custom_call_target']
-  xla_c_rule = cfunc(types.void(types.CPointer(types.voidptr), types.CPointer(types.voidptr)))(new_f)
-  target_name = f'numba_custom_call_{str(xla_c_rule.address)}'
-  capsule = ctypes.pythonapi.PyCapsule_New(xla_c_rule.address, b"xla._CUSTOM_CALL_TARGET", None)
-  xla_client.register_custom_call_target(target_name, capsule, "cpu")
-
-  return target_name
-
-
-def _numba_xla_cpu_translation_rule(kernel, debug: bool, c, *ins, **kwargs):
-  outs = kwargs['outs']
-
-  # output information
-  output_shapes = tuple(out.shape for out in outs)
-  output_dtypes = tuple(out.dtype for out in outs)
-  output_layouts = map(lambda shape: range(len(shape) - 1, -1, -1), output_shapes)
-  output_infos = [xla_client.Shape.array_shape(*arg) for arg in zip(output_dtypes, output_shapes, output_layouts)]
-  output_infos = xla_client.Shape.tuple_shape(output_infos)
-
-  # input information
-  input_layouts = tuple(c.get_shape(arg) for arg in ins)
-  input_dtypes = tuple(inp.element_type() for inp in input_layouts)
-  input_shapes = tuple(inp.dimensions() for inp in input_layouts)
-
-  # compiling
-  target_name = _cpu_signature(kernel,
-                               input_dtypes,
-                               input_shapes,
-                               output_dtypes,
-                               output_shapes,
-                               debug=debug)
-
-  # call
-  return xla_client.ops.CustomCallWithLayout(
-    c,
-    target_name.encode("ascii"),
-    operands=tuple(ins),
-    operand_shapes_with_layout=input_layouts,
-    shape_with_layout=output_infos,
-  )
-
-
-def register_numba_xla_cpu_translation_rule(primitive, cpu_kernel, debug=False):
-  if numba is None:
-    raise PackageMissingError.by_purpose("numba", 'register numba xla cpu translation rule')
-
-  # do not support after jax >= 0.4.24
-  xla.backend_specific_translations['cpu'][primitive] = partial(_numba_xla_cpu_translation_rule,
-                                                                cpu_kernel,
-                                                                debug)
-
-
-def _numba_mlir_cpu_translation_rule(kernel, debug: bool, ctx, *ins, **kwargs):
-  # output information
-  outs = ctx.avals_out
-  output_shapes = tuple([out.shape for out in outs])
-  output_dtypes = tuple([out.dtype for out in outs])
-  output_layouts = tuple([_shape_to_layout(out.shape) for out in outs])
-  result_types = [mlir.aval_to_ir_type(out) for out in outs]
-
-  # input information
-  avals_in = ctx.avals_in
-  input_layouts = [_shape_to_layout(a.shape) for a in avals_in]
-  input_dtypes = tuple(inp.dtype for inp in avals_in)
-  input_shapes = tuple(inp.shape for inp in avals_in)
-
-  # compiling function
-  code_scope = dict(func_to_call=kernel, input_shapes=input_shapes, input_dtypes=input_dtypes,
-                    output_shapes=output_shapes, output_dtypes=output_dtypes, carray=carray)
-  args_in = [f'in{i} = carray(input_ptrs[{i}], input_shapes[{i}], dtype=input_dtypes[{i}])'
-             for i in range(len(input_shapes))]
-  if len(output_shapes) > 1:
-    args_out = [f'out{i} = carray(output_ptrs[{i}], output_shapes[{i}], dtype=output_dtypes[{i}])'
-                for i in range(len(output_shapes))]
-    sig = types.void(types.CPointer(types.voidptr), types.CPointer(types.voidptr))
-  else:
-    args_out = [f'out0 = carray(output_ptrs, output_shapes[0], dtype=output_dtypes[0])']
-    sig = types.void(types.voidptr, types.CPointer(types.voidptr))
-  args_call = [f'in{i}' for i in range(len(input_shapes))] + [f'out{i}' for i in range(len(output_shapes))]
-  code_string = '''
-def numba_cpu_custom_call_target(output_ptrs, input_ptrs):
-    {args_in}
-    {args_out}
-    func_to_call({args_call})
-      '''.format(args_in="\n    ".join(args_in),
-                 args_out="\n    ".join(args_out),
-                 args_call=", ".join(args_call))
-  if debug:
-    print(code_string)
-  exec(compile(code_string.strip(), '', 'exec'), code_scope)
-  new_f = code_scope['numba_cpu_custom_call_target']
-
-  # register
-  xla_c_rule = cfunc(sig)(new_f)
-  target_name = f'numba_custom_call_{str(xla_c_rule.address)}'
-  capsule = ctypes.pythonapi.PyCapsule_New(xla_c_rule.address, b"xla._CUSTOM_CALL_TARGET", None)
-  xla_client.register_custom_call_target(target_name, capsule, "cpu")
-
-  # call
-  return custom_call(
-    call_target_name=target_name,
-    operands=ins,
-    operand_layouts=list(input_layouts),
-    result_layouts=list(output_layouts),
-    result_types=list(result_types),
-    has_side_effect=False,
-  ).results
-
-
-def register_numba_mlir_cpu_translation_rule(primitive, cpu_kernel, debug=False):
-  if numba is None:
-    raise PackageMissingError.by_purpose("numba", 'register numba xla cpu translation rule')
-
-  rule = partial(_numba_mlir_cpu_translation_rule, cpu_kernel, debug)
-  mlir.register_lowering(primitive, rule, platform='cpu')
diff --git a/brainpy/_src/math/op_register/tests/test_ad_support.py b/brainpy/_src/math/op_register/tests/test_ad_support.py
deleted file mode 100644
index 2c9f09724..000000000
--- a/brainpy/_src/math/op_register/tests/test_ad_support.py
+++ /dev/null
@@ -1,143 +0,0 @@
-import pytest
-from typing import Tuple
-
-import jax
-from jax import core
-from jax import numpy as jnp
-from jax.interpreters import ad
-
-import brainpy as bp
-import brainpy.math as bm
-from brainpy._src.dependency_check import import_numba
-
-numba = import_numba(error_if_not_found=False)
-if numba is None:
-  pytest.skip('no numba', allow_module_level=True)
-
-bm.set_platform('cpu')
-
-
-def csrmv(data, indices, indptr, vector, *, shape: Tuple[int, int], transpose: bool = False, ):
-  data = jnp.atleast_1d(bm.as_jax(data))
-  indices = bm.as_jax(indices)
-  indptr = bm.as_jax(indptr)
-  vector = bm.as_jax(vector)
-  if vector.dtype == jnp.bool_:
-    vector = bm.as_jax(vector, dtype=data.dtype)
-  outs = [core.ShapedArray([shape[1] if transpose else shape[0]], data.dtype)]
-  if transpose:
-    return prim_trans(data, indices, indptr, vector, outs=outs, shape=shape, transpose=transpose)
-  else:
-    return prim(data, indices, indptr, vector, outs=outs, shape=shape, transpose=transpose)
-
-
-@numba.njit(fastmath=True)
-def _csr_matvec_transpose_numba_imp(values, col_indices, row_ptr, vector, res_val):
-  res_val.fill(0)
-  if values.shape[0] == 1:
-    values = values[0]
-    for row_i in range(vector.shape[0]):
-      v = vector[row_i]
-      for j in range(row_ptr[row_i], row_ptr[row_i + 1]):
-        res_val[col_indices[j]] += values * v
-  else:
-    for row_i in range(vector.shape[0]):
-      v = vector[row_i]
-      for j in range(row_ptr[row_i], row_ptr[row_i + 1]):
-        res_val[col_indices[j]] += v * values[j]
-
-
-@numba.njit(fastmath=True, parallel=True, nogil=True)
-def _csr_matvec_numba_imp(values, col_indices, row_ptr, vector, res_val):
-  res_val.fill(0)
-  # csr mat @ vec
-  if values.shape[0] == 1:
-    values = values[0]
-    for row_i in numba.prange(res_val.shape[0]):
-      r = 0.
-      for j in range(row_ptr[row_i], row_ptr[row_i + 1]):
-        r += values * vector[col_indices[j]]
-      res_val[row_i] = r
-  else:
-    for row_i in numba.prange(res_val.shape[0]):
-      r = 0.
-      for j in range(row_ptr[row_i], row_ptr[row_i + 1]):
-        r += values[j] * vector[col_indices[j]]
-      res_val[row_i] = r
-
-
-def _csrmv_jvp_mat(data_dot, data, indices, indptr, v, *, shape, transpose, **kwargs):
-  return csrmv(data_dot, indices, indptr, v, shape=shape, transpose=transpose)
-
-
-def _csrmv_jvp_vec(v_dot, data, indices, indptr, v, *, shape, transpose, **kwargs):
-  return csrmv(data, indices, indptr, v_dot, shape=shape, transpose=transpose)
-
-
-def _csrmv_cusparse_transpose(ct, data, indices, indptr, vector, *, shape, transpose, **kwargs):
-  if ad.is_undefined_primal(indices) or ad.is_undefined_primal(indptr):
-    raise ValueError("Cannot transpose with respect to sparse indices.")
-
-  ct = ct[0]
-  if ad.is_undefined_primal(vector):
-    ct_vector = csrmv(data, indices, indptr, ct, shape=shape, transpose=not transpose)
-    return data, indices, indptr, (ad.Zero(vector) if type(ct) is ad.Zero else ct_vector)
-
-  else:
-    if type(ct) is ad.Zero:
-      ct_data = ad.Zero(data)
-    else:
-      if data.aval.shape[0] == 1:  # scalar
-        ct_data = csrmv(jnp.ones(1), indices, indptr, vector, shape=shape, transpose=transpose)
-        ct_data = jnp.inner(ct, ct_data)
-      else:  # heterogeneous values
-        row, col = bm.sparse.csr_to_coo(indices, indptr)
-        ct_data = vector[row] * ct[col] if transpose else vector[col] * ct[row]
-    return ct_data, indices, indptr, vector
-
-
-prim_trans = bm.XLACustomOp(_csr_matvec_transpose_numba_imp)
-prim_trans.defjvp(_csrmv_jvp_mat, None, None, _csrmv_jvp_vec)
-prim_trans.def_transpose_rule(_csrmv_cusparse_transpose)
-
-prim = bm.XLACustomOp(_csr_matvec_numba_imp)
-prim.defjvp(_csrmv_jvp_mat, None, None, _csrmv_jvp_vec)
-prim.def_transpose_rule(_csrmv_cusparse_transpose)
-
-
-def sum_op(op):
-  def func(*args, **kwargs):
-    r = op(*args, **kwargs)
-    return r.sum()
-
-  return func
-
-
-def try_a_trial(transpose, shape):
-  rng = bm.random.RandomState()
-  conn = bp.conn.FixedProb(0.1)
-  indices, indptr = conn(*shape).require('pre2post')
-  indices = bm.as_jax(indices)
-  indptr = bm.as_jax(indptr)
-  heter_data = rng.random(indices.shape)
-  heter_data = bm.as_jax(heter_data)
-  vector = rng.random(shape[0] if transpose else shape[1])
-  vector = bm.as_jax(vector)
-
-  r5 = jax.grad(sum_op(lambda *args, **kwargs: bm.sparse.csrmv(*args, **kwargs)), argnums=(0, 3))(
-    heter_data, indices, indptr, vector.astype(float), shape=shape, transpose=transpose)
-  r6 = jax.grad(sum_op(lambda *args, **kwargs: csrmv(*args, **kwargs)[0]), argnums=(0, 3))(
-    heter_data, indices, indptr, vector.astype(float), shape=shape, transpose=transpose)
-  print(r5)
-  print(r6)
-  assert bm.allclose(r5[0], r6[0])
-  assert bm.allclose(r5[1], r6[1][0])
-
-
-def test():
-  transposes = [True, False]
-  shapes = [(100, 200), (10, 1000), (2, 2000)]
-
-  for transpose in transposes:
-    for shape in shapes:
-      try_a_trial(transpose, shape)
diff --git a/brainpy/_src/math/op_register/tests/test_cupy_based.py b/brainpy/_src/math/op_register/tests/test_cupy_based.py
deleted file mode 100644
index 772b61607..000000000
--- a/brainpy/_src/math/op_register/tests/test_cupy_based.py
+++ /dev/null
@@ -1,79 +0,0 @@
-import jax
-import pytest
-
-import brainpy.math as bm
-from brainpy._src.dependency_check import import_cupy, import_cupy_jit, import_taichi
-
-cp = import_cupy(error_if_not_found=False)
-cp_jit = import_cupy_jit(error_if_not_found=False)
-ti = import_taichi(error_if_not_found=False)
-if cp is None or ti is None:
-  pytest.skip('no cupy or taichi', allow_module_level=True)
-bm.set_platform('cpu')
-
-
-def test_cupy_based():
-  bm.op_register.clear_taichi_aot_caches()
-  # Raw Module
-
-  @ti.kernel
-  def simpleAdd(x1: ti.types.ndarray(ndim=2),
-                x2: ti.types.ndarray(ndim=2),
-                n: ti.types.ndarray(ndim=0),
-                y: ti.types.ndarray(ndim=2)):
-    for i, j in y:
-      y[i, j] = x1[i, j] + x2[i, j]
-
-  source_code = r'''
-  extern "C"{
-  
-  __global__ void kernel(const float* x1, const float* x2, unsigned int N, float* y)
-  {
-      unsigned int tid = blockDim.x * blockIdx.x + threadIdx.x;
-      if (tid < N)
-      {
-          y[tid] = x1[tid] + x2[tid];
-      }
-  }
-  }
-  '''
-  N = 10
-  x1 = bm.ones((N, N))
-  x2 = bm.ones((N, N))
-
-  mod = cp.RawModule(code=source_code)
-  kernel = mod.get_function('kernel')
-
-  prim1 = bm.XLACustomOp(cpu_kernel=simpleAdd, gpu_kernel=kernel)
-
-  y = prim1(x1, x2, N**2, grid=(N,), block=(N,), outs=[jax.ShapeDtypeStruct((N, N), dtype=bm.float32)])[0]
-
-  print(y)
-  assert bm.allclose(y, x1 + x2)
-
-  # JIT Kernel
-  @ti.kernel
-  def elementwise_copy_taichi(x: ti.types.ndarray(ndim=1),
-                              size: ti.types.ndarray(ndim=1),
-                              y: ti.types.ndarray(ndim=1)):
-    for i in y:
-      y[i] = x[i]
-
-  @cp_jit.rawkernel()
-  def elementwise_copy(x, size, y):
-    tid = cp_jit.blockIdx.x * cp_jit.blockDim.x + cp_jit.threadIdx.x
-    ntid = cp_jit.gridDim.x * cp_jit.blockDim.x
-    for i in range(tid, size, ntid):
-      y[i] = x[i]
-
-  size = 100
-  x = bm.ones((size,))
-
-  prim2 = bm.XLACustomOp(cpu_kernel=elementwise_copy_taichi, gpu_kernel=elementwise_copy)
-
-  y = prim2(x, size, grid=(10,), block=(10,), outs=[jax.ShapeDtypeStruct((size,), dtype=bm.float32)])[0]
-
-  print(y)
-  assert bm.allclose(y, x)
-
-# test_cupy_based()
diff --git a/brainpy/_src/math/op_register/tests/test_numba_based.py b/brainpy/_src/math/op_register/tests/test_numba_based.py
deleted file mode 100644
index f7adc695c..000000000
--- a/brainpy/_src/math/op_register/tests/test_numba_based.py
+++ /dev/null
@@ -1,55 +0,0 @@
-import jax.core
-import pytest
-from jax.core import ShapedArray
-
-import brainpy.math as bm
-from brainpy._src.dependency_check import import_numba
-
-numba = import_numba(error_if_not_found=False)
-if numba is None:
-  pytest.skip('no numba', allow_module_level=True)
-
-bm.set_platform('cpu')
-
-
-@numba.njit(fastmath=True)
-def numba_event_csrmv(weight, indices, vector, outs):
-  outs.fill(0)
-  weight = weight[()]  # 0d
-  for row_i in range(vector.shape[0]):
-    if vector[row_i]:
-      for j in indices[row_i]:
-        outs[j] += weight
-
-
-prim = bm.XLACustomOp(numba_event_csrmv)
-
-
-def call(s=100):
-  indices = bm.random.randint(0, s, (s, 80))
-  vector = bm.random.rand(s) < 0.1
-  out = prim(1., indices, vector, outs=[jax.ShapeDtypeStruct([s], dtype=bm.float32)])
-  print(out[0].shape)
-
-
-def test_event_ELL():
-  call(1000)
-  call(100)
-  bm.clear_buffer_memory()
-
-# CustomOpByNumba Test
-
-def eval_shape(a):
-  b = ShapedArray(a.shape, dtype=a.dtype)
-  return b
-
-@numba.njit(parallel=True)
-def con_compute(outs, ins):
-  b = outs
-  a = ins
-  b[:] = a + 1
-
-def test_CustomOpByNumba():
-  op = bm.CustomOpByNumba(eval_shape, con_compute, multiple_results=False)
-  print(op(bm.zeros(10)))
-  assert bm.allclose(op(bm.zeros(10)), bm.ones(10))
\ No newline at end of file
diff --git a/brainpy/_src/math/op_register/utils.py b/brainpy/_src/math/op_register/utils.py
deleted file mode 100644
index 2a10443db..000000000
--- a/brainpy/_src/math/op_register/utils.py
+++ /dev/null
@@ -1,42 +0,0 @@
-# -*- coding: utf-8 -*-
-
-
-from functools import partial
-
-import jax.numpy as jnp
-from jax import lax
-from jax.interpreters import batching
-from jax.tree_util import tree_flatten, tree_unflatten
-
-__all__ = [
-  'register_general_batching',
-]
-
-
-def _general_batching_rule(prim, args, axes, **kwargs):
-  batch_axes, batch_args, non_batch_args = [], {}, {}
-  for ax_i, ax in enumerate(axes):
-    if ax is None:
-      non_batch_args[f'ax{ax_i}'] = args[ax_i]
-    else:
-      batch_args[f'ax{ax_i}'] = args[ax_i] if ax == 0 else jnp.moveaxis(args[ax_i], ax, 0)
-      batch_axes.append(ax_i)
-
-  def f(_, x):
-    pars = tuple([(x[f'ax{i}'] if i in batch_axes else non_batch_args[f'ax{i}'])
-                  for i in range(len(axes))])
-    return 0, prim.bind(*pars, **kwargs)
-
-  _, outs = lax.scan(f, 0, batch_args)
-  out_vals, out_tree = tree_flatten(outs)
-  out_dim = tree_unflatten(out_tree, (0,) * len(out_vals))
-  return outs, out_dim
-
-
-def register_general_batching(prim):
-  batching.primitive_batchers[prim] = partial(_general_batching_rule, prim)
-
-
-def _shape_to_layout(shape):
-  return tuple(range(len(shape) - 1, -1, -1))
-
diff --git a/brainpy/_src/math/sparse/__init__.py b/brainpy/_src/math/sparse/__init__.py
index 13c9e1e28..eec5f53c0 100644
--- a/brainpy/_src/math/sparse/__init__.py
+++ b/brainpy/_src/math/sparse/__init__.py
@@ -1,9 +1,7 @@
 # from ._coo_mv import *
-# from ._bsr_mv import *
 from .csr_mv import *
 from .csr_mm import *
 from .utils import *
-from .bsr_mm import *
 from .jax_prim import *
 
 
diff --git a/brainpy/_src/math/sparse/bsr_mm.py b/brainpy/_src/math/sparse/bsr_mm.py
deleted file mode 100644
index 19800749d..000000000
--- a/brainpy/_src/math/sparse/bsr_mm.py
+++ /dev/null
@@ -1,462 +0,0 @@
-# -*- coding: utf-8 -*-
-
-from functools import partial
-from typing import Tuple
-
-import jax.lax
-import numpy as np
-from jax import numpy as jnp
-from jax.core import Primitive, ShapedArray
-from jax.interpreters import ad, xla
-from jax.lib import xla_client
-
-from brainpy._src.dependency_check import import_brainpylib_gpu_ops, import_numba
-from brainpy._src.math.interoperability import as_jax
-from brainpy._src.math.op_register import (compile_cpu_signature_with_numba,
-                                           register_general_batching)
-from brainpy.errors import GPUOperatorNotFound
-
-numba = import_numba(error_if_not_found=False)
-
-__all__ = [
-  'bcsrmm',
-]
-
-
-def get_mask(dense_b, blockshape, blockcount):
-  mask = jnp.zeros(blockcount[0] * blockcount[1], dtype=jnp.bool_)
-
-  for i in range(blockcount[1]):
-    for j in range(blockcount[0]):
-      if jnp.abs(dense_b[i * blockshape[1]: (i + 1) * blockshape[1],
-                 j * blockshape[0]: (j + 1) * blockshape[0]]).sum() != 0:
-        mask = mask.at[i * blockcount[0] + j].set(True)
-  mask = mask.reshape(blockcount[1], blockcount[0])
-  return mask
-
-
-def get_mask_from_ptr_indices(ptr, indices, blockcount):
-  mask = jnp.zeros((blockcount[1], blockcount[0]), dtype=jnp.bool_)
-  for idx, indice in enumerate(indices):
-    row_index = 0
-    for ptr_ in ptr[1:]:
-      if idx < ptr_:
-        break
-      row_index += 1
-    mask = mask.at[row_index, indice].set(True)
-  return mask
-
-
-def get_data(dense_b, mask, blockshape, blockcount, n_blocks):
-  data = jnp.zeros(
-    shape=(n_blocks * blockshape[1], blockshape[0]),
-    dtype=jnp.float32
-  )
-
-  assignment_count = 0
-  for i in range(blockcount[1]):
-    for j in range(blockcount[0]):
-      if mask[i, j]:
-        data = data.at[assignment_count * blockshape[1]: (assignment_count + 1) * blockshape[1],
-               :].set(dense_b[i * blockshape[1]: (i + 1) * blockshape[1],
-                      j * blockshape[0]: (j + 1) * blockshape[0]])
-        assignment_count += 1
-  return data
-
-
-def get_ptr_indices(mask, blockcount, n_blocks, block_ptr=None):
-  nnz = jnp.nonzero(mask)
-
-  if block_ptr is None:
-    block_ptr = jnp.arange(0, len(nnz[0]))
-
-  indices = jnp.argsort(block_ptr)
-  _ = jnp.take(block_ptr, indices)
-
-  blocks = nnz[0][jnp.array(indices)], nnz[1][jnp.array(indices)]
-  blocks = jnp.stack([nnz[0][jnp.array(indices)], nnz[1][jnp.array(indices)]], axis=-1).astype(
-    dtype=jnp.int32
-  )
-  blocks = jnp.flip(blocks, axis=-1).flatten()
-
-  X = blockcount[1]
-  Y = blockcount[0]
-
-  rows = nnz[0][:]
-  cols = nnz[1][:]
-
-  block_indices = jnp.zeros(X * Y, dtype=jnp.int32)
-  positions = rows * Y + cols
-  block_indices = block_indices.at[positions].set(block_ptr + 1)
-  block_indices = block_indices.reshape(X, Y).transpose().reshape(X * Y)
-
-  block_ptr = block_indices[jnp.nonzero(block_indices)[0]] - 1
-
-  X, Y = Y, X
-  rows = cols
-  nnztt = jnp.nonzero(mask.transpose())
-  cols = nnztt[:][1]
-
-  rows.astype(jnp.int32)
-
-  ptr_b = jnp.zeros((X + 1,), dtype=jnp.int32)
-  for row in rows:
-    ptr_b = ptr_b.at[row + 1].set(ptr_b[row + 1] + 1)
-  ptr_b = ptr_b.cumsum(0).astype(dtype=jnp.int32)
-
-  indices_b = jnp.stack([cols, block_ptr], axis=1).astype(dtype=jnp.int32)
-
-  return ptr_b, indices_b
-
-
-def get_dense(ptr, indices, data, shape, blockshape):
-  mask = get_mask_from_ptr_indices(ptr, indices, blockshape)
-  dense_data = jnp.zeros(shape, dtype=jnp.float32)
-  mask_count = 0
-  for i in range(mask.shape[1]):
-    for j in range(mask.shape[0]):
-      if mask[i, j]:
-        dense_data = dense_data.at[
-                     i * blockshape[0]: (i + 1) * blockshape[0],
-                     j * blockshape[1]: (j + 1) * blockshape[1],
-                     ].set(data[mask_count * blockshape[0]: (mask_count + 1) * blockshape[0], :])
-        mask_count += 1
-  return dense_data
-
-
-def blocksparse_matmat_multiply(dense_a,
-                                ptr_b=None,
-                                indices_b=None,
-                                data_b=None,
-                                shape_b=None,
-                                dense_b=None,
-                                blockshape=(32, 32),
-                                device='cpu'):
-  if dense_b is not None:
-    # m, n, k
-    m = dense_a.shape[0]
-    k = dense_a.shape[1]
-    n = dense_b.shape[1]
-
-    # blockcount
-    blockcount = (n // blockshape[0], k // blockshape[1])
-
-    # mask
-    mask = get_mask(dense_b, blockshape, blockcount)
-
-    # n_blocks
-    n_blocks = mask.sum()
-
-    # data_b
-    data_b = get_data(dense_b, mask, blockshape, blockcount, n_blocks)
-
-    # ptr_b, indices_b
-    ptr_b, indices_b = get_ptr_indices(mask, blockcount, n_blocks)
-  else:
-    # m, n, k
-    m = dense_a.shape[0]
-    n = shape_b[1]
-    k = dense_a.shape[1]
-
-    # blockcount
-    blockcount = (n // blockshape[0], k // blockshape[1])
-
-    mask = get_mask_from_ptr_indices(ptr_b, indices_b, blockcount)
-
-    n_blocks = mask.sum()
-
-    ptr_b, indices_b = get_ptr_indices(mask, blockcount, n_blocks)
-
-  # out
-  # out = jnp.zeros((n, m))
-
-  # verbose
-  print('data_b: ', data_b)
-  print('ptr:', ptr_b)
-  print('indices:', indices_b)
-
-  '''out = blocksparse_matmat_cpu_test(dense_a,
-          ptr_b,
-          indices_b,
-          data_b,
-          out,
-          m=m,
-          n=n,
-          k=k,
-          block_size_k=blockshape[0],
-          block_size_n=blockshape[1])
-  return out'''
-
-  if device == 'cpu':
-    out = bcsrmm(
-      dense_a,
-      ptr_b,
-      indices_b,
-      data_b,
-      m=m,
-      n=n,
-      k=k,
-      block_size_k=blockshape[0],
-      block_size_n=blockshape[1],
-    )
-    return out
-  elif device == 'gpu':
-    out = bcsrmm(
-      dense_a,
-      ptr_b,
-      indices_b,
-      data_b,
-      m=m,
-      n=n,
-      k=k,
-      block_size_k=blockshape[0],
-      block_size_n=blockshape[1],
-    )
-    return out.transpose()
-  else:
-    raise Exception('Invalid device: ', device)
-
-
-def bcsrmm(
-    A_data: jax.Array,
-    B_data: jax.Array,
-    B_indices: jax.Array,
-    B_ptr: jax.Array,
-    *,
-    shape: Tuple[int, int],
-    block_size: Tuple[int, int],
-    transpose: bool = False,
-    method: str = 'cutlass'
-) -> jax.Array:
-  """Perform the matrix multiplication :math:`C = A @ B` with BSR data structure.
-
-  Args:
-    A_data: The dense matrix :math:`A`.
-    B_data: The data at each block of :math:`B`.
-    B_indices: The sparse indices of :math:`B`.
-    B_ptr: The connection pointer of :math:`B`.
-    shape: a tuple of int, indicating the array shape of :math:`B`.
-    block_size: a tuple of int, indicating the block size for portioning :math:`B`.
-    transpose: boolean. If True, perform :math:`A @ B^T`; otherwise, perform :math:`A @ B`.
-    method: a sting for denoting the BSR sparse computing method.
-
-  Returns:
-    The dense array :math:`C`.
-  """
-  A_data = as_jax(A_data)
-  B_data = as_jax(B_data)
-  B_indices = as_jax(B_indices)
-  B_ptr = as_jax(B_ptr)
-  assert A_data.shape[1] == shape[0]
-
-  if method == 'cutlass':
-    C = _bcsrmm_cutlass_p.bind(A_data,
-                               B_data,
-                               B_indices,
-                               B_ptr,
-                               m=A_data.shape[0],
-                               k=shape[0],
-                               n=shape[1],
-                               transpose=transpose,
-                               block_size_k=block_size[0],
-                               block_size_n=block_size[1])[0]
-    return C.T
-  else:
-    raise ValueError
-
-
-if numba is not None:
-  @numba.njit(fastmath=True, parallel=True, nogil=True)
-  def _bcsrmm_cutlass_imp_transpose(outs, ins):  # dense(m, k) @ bcsr(n, k) -> dense(n, m)
-    res_val = outs[0]
-    # B_data: (num_block, block_size_k, block_size_n)
-    A_data, B_data, B_indices, B_inptr, m, k, n, block_size_k, block_size_n = ins
-    block_size_k = block_size_k[()]
-    block_size_n = block_size_n[()]
-    n_block = n // block_size_n
-
-    for ni in numba.prange(n_block):
-      C_tmp = np.zeros((block_size_n, m), dtype=A_data.dtype)
-      start, end = B_inptr[ni], B_inptr[ni + 1]
-      ns = ni * block_size_n
-      ne = ns + block_size_n
-      for i in range(start, end):
-        ki = B_indices[i, 0]
-        ks = ki * block_size_k
-        ke = ki + block_size_k
-        bi = B_indices[i, 1]
-        C_tmp += np.matmul(B_data[bi], A_data[:, ks: ke].T)
-      res_val[ns: ne] = C_tmp
-    return res_val
-
-
-  @numba.njit(fastmath=True, parallel=True, nogil=True)
-  def _bcsrmm_cutlass_imp2(outs, ins):  # dense(m, k) @ bcsr(k, n) -> dense(n, m)
-    res_val = outs[0]
-    # B_data: (num_block, block_size_n, block_size_k)
-    A_data, B_data, B_indices, B_inptr, m, k, n, block_size_k, block_size_n = ins
-    block_size_k = block_size_k[()]
-    block_size_n = block_size_n[()]
-    n_block = n // block_size_n
-
-    for ni in numba.prange(n_block):
-      C_tmp = np.zeros((block_size_n, m), dtype=A_data.dtype)
-      start, end = B_inptr[ni], B_inptr[ni + 1]
-      ns = ni * block_size_n
-      ne = ns + block_size_n
-      for i in range(start, end):
-        ki = B_indices[i, 0]
-        ks = ki * block_size_k
-        ke = ki + block_size_k
-        bi = B_indices[i, 1]
-        C_tmp += np.matmul(B_data[bi], A_data[:, ks: ke].T)
-      res_val[ns: ne] = C_tmp
-    return res_val
-
-
-def _bcsrmm_cutlass_abstract(
-    A_data, B_data, B_indices, B_ptr, *, m, k, n, block_size_k, block_size_n
-):
-  assert block_size_k == 32, 'cutlass based block-sparse mm only support block size (32, 32)'
-  assert block_size_n == 32, 'cutlass based block-sparse mm only support block size (32, 32)'
-  assert B_indices.shape[0] * block_size_n == B_data.shape[0]
-  assert block_size_k == B_data.shape[1]
-  assert A_data.shape[0] == m
-  assert A_data.shape[1] == k
-  assert A_data.dtype == B_data.dtype
-  assert n // block_size_n + 1 == B_ptr.shape[0]
-  return [ShapedArray(dtype=A_data.dtype, shape=(n, m))]
-
-
-def _bcsrmm_cutlass_cpu_translation(
-    c, A_data, B_data, B_indices, B_ptr, *,
-    m, k, n, block_size_k, block_size_n
-):
-  inputs = (A_data, B_ptr, B_indices, B_data)
-  description = dict(m=m,
-                     n=n,
-                     k=k,
-                     block_size_k=block_size_k,
-                     block_size_n=block_size_n)
-  name, inputs, in_layouts, out_layouts = compile_cpu_signature_with_numba(
-    c,
-    _bcsrmm_cutlass_imp2,
-    abs_eval_fn=_bcsrmm_cutlass_abstract,
-    multiple_results=True,
-    inputs=inputs,
-    description=description
-  )
-  return xla_client.ops.CustomCallWithLayout(
-    c, name,
-    operands=inputs,
-    operand_shapes_with_layout=in_layouts,
-    shape_with_layout=out_layouts,
-  )
-
-
-def _bcsrmm_cutlass_gpu_translation(c, A_data, B_data, B_indices, B_ptr, *, m, k, n, block_size_k, block_size_n):
-  gpu_ops = import_brainpylib_gpu_ops()
-  if gpu_ops is None:
-    raise GPUOperatorNotFound(_bcsrmm_cutlass_p.name)
-
-  matrix_info = c.get_shape(A_data)
-  dtype = matrix_info.element_type()
-
-  opaque = gpu_ops.build_blocksparse_format_descriptor(m,
-                                                       n,
-                                                       k,
-                                                       block_size_k,
-                                                       block_size_n)
-
-  fn = b'gpu_blocksparse_matmat'
-
-  return xla_client.ops.CustomCallWithLayout(
-    c,
-    fn,
-    operands=(A_data, B_ptr, B_indices, B_data,),
-    operand_shapes_with_layout=(c.get_shape(A_data),
-                                c.get_shape(B_ptr),
-                                c.get_shape(B_indices),
-                                c.get_shape(B_data),),
-    shape_with_layout=xla_client.Shape.tuple_shape(
-      (xla_client.Shape.array_shape(dtype, (n, m), (1, 0)),)
-    ),
-    opaque=opaque
-  )
-
-
-def _bcsrmm_cutlass_jvp_dense_a(dense_a_dot, A_data, B_ptr, B_indices, B_data, *, m, n, k, block_size_k,
-                                block_size_n):
-  return bcsrmm(dense_a_dot, B_ptr, B_indices, B_data, m=m, n=n, k=k, block_size_k=block_size_k,
-                block_size_n=block_size_n)
-
-
-def _bcsrmm_cutlass_jvp_data_b(data_b_dot, A_data, B_ptr, B_indices, B_data, *, m, n, k, block_size_k,
-                               block_size_n):
-  return bcsrmm(A_data, B_ptr, B_indices, data_b_dot, m=m, n=n, k=k, block_size_k=block_size_k,
-                block_size_n=block_size_n)
-
-
-def _bcsrmm_cutlass_jvp_transpose():
-  # TODO: implement
-  pass
-
-
-_bcsrmm_cutlass_p = Primitive('bcsrmm_cutlass_pim')
-_bcsrmm_cutlass_p.multiple_results = True
-_bcsrmm_cutlass_p.def_abstract_eval(_bcsrmm_cutlass_abstract)
-_bcsrmm_cutlass_p.def_impl(partial(xla.apply_primitive, _bcsrmm_cutlass_p))
-# xla.backend_specific_translations['cpu'][_bcsrmm_cutlass_p] = _bcsrmm_cutlass_cpu_translation
-# xla.backend_specific_translations['gpu'][_bcsrmm_cutlass_p] = _bcsrmm_cutlass_gpu_translation
-ad.primitive_jvps[_bcsrmm_cutlass_p] = _bcsrmm_cutlass_jvp_transpose
-ad.primitive_transposes[_bcsrmm_cutlass_p] = _bcsrmm_cutlass_jvp_transpose
-register_general_batching(bcsrmm)
-
-
-def _blocksparse_matmat_back_abstract(
-    A_data, B_data, blocks, *, m, n, k, transpose, block_size_k, block_size_n, blocks_len
-):
-  shape = (n, k)
-  dtype = A_data.dtype
-  out = ShapedArray(dtype=dtype, shape=shape)
-  return [out]
-
-
-def _blocksparse_matmat_back_gpu_translation(
-    c, A_data, B_data, blocks, *, m, n, k, transpose, block_size_k, block_size_n, blocks_len
-):
-  gpu_ops = import_brainpylib_gpu_ops()
-  if gpu_ops is None:
-    raise GPUOperatorNotFound(_bcsrmm_cutlass_back_p.name)
-  matrix_info = c.get_shape(A_data)
-  dtype = matrix_info.element_type()
-
-  opaque = gpu_ops.build_blocksparse_back_format_descriptor(m,
-                                                            n,
-                                                            k,
-                                                            block_size_k,
-                                                            block_size_n,
-                                                            blocks_len)
-
-  fn = b'gpu_blocksparse_matmat_back'
-
-  return xla_client.ops.CustomCallWithLayout(
-    c,
-    fn,
-    operands=(A_data, B_data, blocks,),
-    operand_shape_with_layout=(c.get_shape(A_data),
-                               c.get_shape(B_data),
-                               c.get_shape(blocks),),
-    shape_with_layout=xla_client.Shape.tuple_shape(
-      (xla_client.Shape.array_shape(dtype, (k, n), (1, 0)),)
-    ),
-    opaque=opaque
-  )
-
-
-_bcsrmm_cutlass_back_p = Primitive('bcsrmm_cutlass_back_prim')
-_bcsrmm_cutlass_back_p.multiple_results = True
-_bcsrmm_cutlass_back_p.def_abstract_eval(_blocksparse_matmat_back_abstract)
-_bcsrmm_cutlass_back_p.def_impl(partial(xla.apply_primitive, _bcsrmm_cutlass_back_p))
-# xla.backend_specific_translations['gpu'][_bcsrmm_cutlass_back_p] = _blocksparse_matmat_back_gpu_translation
-register_general_batching(_bcsrmm_cutlass_back_p)
diff --git a/brainpy/_src/math/sparse/bsr_mv.py b/brainpy/_src/math/sparse/bsr_mv.py
deleted file mode 100644
index 7dc0b683d..000000000
--- a/brainpy/_src/math/sparse/bsr_mv.py
+++ /dev/null
@@ -1,210 +0,0 @@
-from functools import partial
-from typing import Union, Tuple
-
-import numba
-import numpy as np
-from jax import numpy as jnp
-from jax.core import ShapedArray, Primitive
-from jax.interpreters import ad, xla
-from jax.lib import xla_client
-
-from brainpy._src.math.interoperability import as_jax
-from brainpy._src.math.op_register import (compile_cpu_signature_with_numba,
-                                           register_general_batching)
-from brainpy._src.math.sparse.utils import csr_to_coo
-from brainpy._src.dependency_check import import_brainpylib_gpu_ops
-from brainpy.errors import GPUOperatorNotFound
-
-__all__ = [
-  'cusparse_bcsr_matvec'
-]
-
-
-@numba.njit(fastmath=True, parallel=True, nogil=True)
-def _cusparse_bcsr_matvec_bsr_matvec_numba_imp(outs, ins):
-  data, indices, indptr, vector, blocksize, shape, nnzb, transpose = ins
-  blocksize = blocksize[()]
-  outs.fill(0)
-  for i in range(shape[0]):
-    tmp = np.zeros(blocksize, dtype=data.dtype)
-    for j in range(indptr[i], indptr[i + 1]):
-      start = indices[j] * blocksize
-      end = start + blocksize
-      tmp += data[start: end] @ vector[start: end]
-    outs[i * blocksize: (i + 1) * blocksize] = tmp
-
-
-# @numba.njit(fastmath=True, parallel=True, nogil=True)
-# def _cusparse_bcsr_matvec_bsr_matvec_numba_imp(outs, ins):
-#   data, indices, indptr, vector,  blocksize , shape,nnzb,transpose = ins
-#   blocksize = blocksize[()]
-#   outs.fill(0)
-
-#   cnt=0
-#   for i in range(0,shape[0]):
-#       outs.fill(0.0)
-#       tmp=[0.0]*blocksize
-#       for j in range(indptr[i], indptr[i + 1]):
-#         for p in range(0,blocksize):
-#           for q in range(0,blocksize):
-#             tmp[p] += vector[indices[j]*blocksize+q]*data[j*blocksize+p][q]
-#       for j in range(0,blocksize):
-#         outs[cnt] = tmp[j]
-#         cnt+=1
-
-
-def _cusprase_bcsr_matvec_values(values, indices, indptr, vector, *, blocksize, nnzb, shape, transpose):
-  return cusparse_bcsr_matvec(values, indices, indptr, vector, blocksize, nnzb=nnzb, shape=shape, transpose=transpose)
-
-
-def cusparse_bcsr_matvec(
-    data: Union[float, jnp.ndarray],
-    indices: jnp.ndarray,
-    indptr: jnp.ndarray,
-    vector: jnp.ndarray,
-    *,
-    blocksize: int,
-    nnzb: int,
-    shape: Tuple[int, int],
-    method: str = 'vector',
-    transpose: bool = False
-) -> jnp.ndarray:
-  data = as_jax(data)
-  indices = as_jax(indices)
-  indptr = as_jax(indptr)
-  vector = as_jax(vector)
-  if method not in ['scalar', 'vector', 'adaptive']:
-    raise ValueError('Only support methods: scalar, vector, and adaptive. '
-                     f'But we got {method}.')
-
-  data = jnp.atleast_1d(data)
-  if not isinstance(data, jnp.ndarray):
-    raise TypeError(f'data must a ndarray. But we got {type(data)}')
-  if data.dtype not in [jnp.float32, jnp.float64]:
-    raise TypeError(f'Only support float32 and float64. But we got {data.dtype}.')
-  if data.dtype != vector.dtype:
-    raise TypeError('The types of data and vector should be the same. '
-                    f'But we got {data.dtype} != {vector.dtype}.')
-  # assert data.ndim == indices.ndim == indptr.ndim == vector.ndim == 1
-
-  return cusparse_bcsr_matvec_vector_p.bind(data, indices, indptr, vector, blocksize=blocksize, shape=shape, nnzb=nnzb,
-                                            transpose=transpose)
-
-
-def _cusparse_bcsr_matvec_vector_cpu_translation(c, data, indices, indptr, vector, *, blocksize, shape, nnzb,
-                                                 transpose):
-  inputs = (data, indices, indptr, vector)
-  print(c.get_shape(data))
-  description = dict(blocksize=blocksize, shape=shape, nnzb=nnzb, transpose=transpose, )
-  if transpose:
-    skip = 1
-  else:
-    name, inputs, in_layouts, out_layouts = compile_cpu_signature_with_numba(
-      c,
-      _cusparse_bcsr_matvec_bsr_matvec_numba_imp,
-      abs_eval_fn=_cusparse_bcsr_matvec_abstract,
-      multiple_results=False,
-      inputs=inputs,
-      description=description
-    )
-  return xla_client.ops.CustomCallWithLayout(
-    c, name,
-    operands=inputs,
-    operand_shapes_with_layout=in_layouts,
-    shape_with_layout=out_layouts,
-  )
-
-
-def _cusparse_bcsr_matvec_vector_gpu_translation(c, data, indices, indptr, vector, *, blocksize, shape, nnzb):
-  gpu_ops = import_brainpylib_gpu_ops()
-  if gpu_ops is None:
-    raise GPUOperatorNotFound(cusparse_bcsr_matvec_vector_p.name)
-
-  data_shape = c.get_shape(data)
-  if data_shape.element_type() == np.float32:
-    type_name = b'float'
-  elif data_shape.element_type() == np.double:
-    type_name = b'double'
-  else:
-    raise ValueError('data_type not support(except float/double)')
-  # 有可能不是这个
-
-  opaque = gpu_ops.build_bcsrcusparsespmv_descriptor(shape[0], shape[1], blocksize, nnzb)
-  return xla_client.ops.CustomCallWithLayout(
-    c,
-    b'gpu_bcsr_cusparse_spmv_' + type_name,
-    operands=(data, indices, indptr, vector),
-    operand_shapes_with_layout=(c.get_shape(data),
-                                c.get_shape(indices),
-                                c.get_shape(indptr),
-                                c.get_shape(vector),
-                                ),
-    shape_with_layout=xla_client.Shape.array_shape(data_shape.element_type(), (shape[0] * blocksize,), (0,)),
-    opaque=opaque,
-  )
-
-
-# def _bcsr_matvec_abstract(*args, **kwargs):
-#   data = args[0]
-#   assert len(kwargs) == 1
-#   shape = kwargs['shape']
-#   return ShapedArray(dtype=data.dtype, shape=(shape[0],))
-
-# bcsr_matvec_vector_p = register_op_with_numba(
-#   'bcsr_matvec_vector',
-#   cpu_func=None,
-#   out_shapes=_bcsr_matvec_abstract,
-#   gpu_func_translation=_bcsr_matvec_vector_gpu_translation,
-# )
-
-
-# def _batch_bcsr_matvec_abstract(
-#     values, indices, indptr, vector,block_size, *, shape, transpose=False
-# ):
-#   return ShapedArray(dtype=values.dtype, shape=(batch_size, shape[1] if transpose else shape[0]))
-
-def _cusparse_bcsr_matvec_abstract(data, indices, indptr, vector, *, blocksize, shape, nnzb, transpose=False):
-  return ShapedArray(dtype=data.dtype, shape=(shape[0] * blocksize,))
-
-
-def _cusparse_bcsr_matvec_jvp_values(data_dot, data, indices, indptr, vector, *, blocksize, shape, nnzb, transpose):
-  return cusparse_bcsr_matvec(data_dot, indices, indptr, vector, blocksize=blocksize, nnzb=nnzb, shape=shape,
-                              transpose=transpose)
-
-
-def _cusparse_bcsr_transpose(ct, data, indices, indptr, vector, *, blocksize, shape, transpose):
-  if ad.is_undefined_primal(indices) or ad.is_undefined_primal(indptr):
-    raise ValueError("Cannot transpose with respect to sparse indices.")
-  if ad.is_undefined_primal(vector):
-    ct_events = cusparse_bcsr_matvec(data, indices, indptr, ct, shape=shape, transpose=not transpose)
-    return data, indices, indptr, (ad.Zero(vector) if type(ct) is ad.Zero else ct_events)
-  else:
-    if type(ct) is ad.Zero:
-      ct_values = ad.Zero(data)
-    else:
-      row, col = csr_to_coo(indices, indptr)
-      cnt = 0
-      ct_values = []
-      for i in row:
-        for j in col:
-          for p in range(0, blocksize):
-            cntq = 0
-            for q in range(0, blocksize):
-              if transpose:
-                ct_values[cnt][cntq] = vector[i * blocksize + p] * ct[j * blocksize + q]
-              else:
-                ct_values[cnt][cntq] = vector[j * blocksize + q] * ct[i * blocksize + p]
-              cntq += 1
-            cnt += 1
-    return ct_values, indices, indptr, vector
-
-
-cusparse_bcsr_matvec_vector_p = Primitive('cusparse_block_spmv')
-cusparse_bcsr_matvec_vector_p.def_abstract_eval(_cusparse_bcsr_matvec_abstract)
-cusparse_bcsr_matvec_vector_p.def_impl(partial(xla.apply_primitive, cusparse_bcsr_matvec_vector_p))
-# xla.backend_specific_translations['gpu'][cusparse_bcsr_matvec_vector_p] = _cusparse_bcsr_matvec_vector_gpu_translation
-# xla.backend_specific_translations['cpu'][cusparse_bcsr_matvec_vector_p] = _cusparse_bcsr_matvec_vector_cpu_translation
-ad.defjvp(cusparse_bcsr_matvec_vector_p, _cusparse_bcsr_matvec_jvp_values)
-ad.primitive_transposes[cusparse_bcsr_matvec_vector_p] = _cusparse_bcsr_transpose
-register_general_batching(cusparse_bcsr_matvec_vector_p)
-# batching.primitive_batchers[event_csr_matvec_p] = _event_csr_matvec_batching_rule
diff --git a/brainpy/_src/math/sparse/utils.py b/brainpy/_src/math/sparse/utils.py
index f5b74e5eb..38cfdb7b9 100644
--- a/brainpy/_src/math/sparse/utils.py
+++ b/brainpy/_src/math/sparse/utils.py
@@ -1,22 +1,46 @@
 # -*- coding: utf-8 -*-
 
 import warnings
+from functools import partial
 from typing import Tuple
 
 import numpy as np
+from brainpy._src.math.interoperability import as_jax
 from jax import core, numpy as jnp
+from jax import lax
+from jax.interpreters import batching
 from jax.interpreters import mlir, ad
+from jax.tree_util import tree_flatten, tree_unflatten
 from jaxlib import gpu_sparse
 
-from brainpy._src.math.interoperability import as_jax
-from brainpy._src.math.op_register import register_general_batching
-
 __all__ = [
   'coo_to_csr',
   'csr_to_coo',
   'csr_to_dense'
 ]
 
+def _general_batching_rule(prim, args, axes, **kwargs):
+  batch_axes, batch_args, non_batch_args = [], {}, {}
+  for ax_i, ax in enumerate(axes):
+    if ax is None:
+      non_batch_args[f'ax{ax_i}'] = args[ax_i]
+    else:
+      batch_args[f'ax{ax_i}'] = args[ax_i] if ax == 0 else jnp.moveaxis(args[ax_i], ax, 0)
+      batch_axes.append(ax_i)
+
+  def f(_, x):
+    pars = tuple([(x[f'ax{i}'] if i in batch_axes else non_batch_args[f'ax{i}'])
+                  for i in range(len(axes))])
+    return 0, prim.bind(*pars, **kwargs)
+
+  _, outs = lax.scan(f, 0, batch_args)
+  out_vals, out_tree = tree_flatten(outs)
+  out_dim = tree_unflatten(out_tree, (0,) * len(out_vals))
+  return outs, out_dim
+
+def _register_general_batching(prim):
+  batching.primitive_batchers[prim] = partial(_general_batching_rule, prim)
+
 
 def coo_to_csr(
     pre_ids: jnp.ndarray,
@@ -153,6 +177,6 @@ def _csr_to_dense_transpose(ct, data, indices, indptr, *, shape):
 ad.defjvp(csr_to_dense_p, _csr_to_dense_jvp, None, None)
 ad.primitive_transposes[csr_to_dense_p] = _csr_to_dense_transpose
 mlir.register_lowering(csr_to_dense_p, _csr_to_dense_lowering)
-register_general_batching(csr_to_dense_p)
+_register_general_batching(csr_to_dense_p)
 if gpu_sparse.cuda_is_supported:
   mlir.register_lowering(csr_to_dense_p, _csr_to_dense_gpu_lowering, platform='cuda')
diff --git a/brainpy/math/__init__.py b/brainpy/math/__init__.py
index 139ec08af..562c1cc18 100644
--- a/brainpy/math/__init__.py
+++ b/brainpy/math/__init__.py
@@ -16,7 +16,6 @@
 
 # operators
 from .pre_syn_post import *
-from .op_register import *
 from . import surrogate, event, sparse, jitconn
 
 # Variable and Objects for object-oriented JAX transformations
diff --git a/brainpy/math/op_register.py b/brainpy/math/op_register.py
deleted file mode 100644
index 8ec7f5e11..000000000
--- a/brainpy/math/op_register.py
+++ /dev/null
@@ -1,10 +0,0 @@
-# -*- coding: utf-8 -*-
-from brainpy._src.math.op_register import (
-  CustomOpByNumba,
-  compile_cpu_signature_with_numba,
-)
-
-from brainpy._src.math.op_register.base import XLACustomOp
-from brainpy._src.math.op_register.ad_support import defjvp
-
-