diff --git a/brainpy/_src/dnn/tests/test_linear.py b/brainpy/_src/dnn/tests/test_linear.py index a0d1d535..2fd7df2d 100644 --- a/brainpy/_src/dnn/tests/test_linear.py +++ b/brainpy/_src/dnn/tests/test_linear.py @@ -1,6 +1,7 @@ import pytest from absl.testing import absltest from absl.testing import parameterized +import jax.numpy as jnp import brainpy as bp import brainpy.math as bm @@ -100,11 +101,11 @@ def test_CSRLinear(self, conn): bm.random.seed() f = bp.dnn.CSRLinear(conn, weight=bp.init.Normal()) x = bm.random.random((16, 100)) - y = f(x) + y = f(jnp.asarray(x)) self.assertTrue(y.shape == (16, 100)) x = bm.random.random((100,)) - y = f(x) + y = f(jnp.asarray(x)) self.assertTrue(y.shape == (100,)) bm.clear_buffer_memory() @@ -119,10 +120,10 @@ def test_EventCSRLinear(self, conn): bm.random.seed() f = bp.layers.EventCSRLinear(conn, weight=bp.init.Normal()) x = bm.random.random((16, 100)) - y = f(x) + y = f(jnp.asarray(x)) self.assertTrue(y.shape == (16, 100)) x = bm.random.random((100,)) - y = f(x) + y = f(jnp.asarray(x)) self.assertTrue(y.shape == (100,)) bm.clear_buffer_memory() diff --git a/brainpy/_src/math/__init__.py b/brainpy/_src/math/__init__.py index de559de5..01159883 100644 --- a/brainpy/_src/math/__init__.py +++ b/brainpy/_src/math/__init__.py @@ -44,7 +44,7 @@ from .compat_numpy import * from .compat_tensorflow import * from .others import * -from . import random, linalg, fft, tifunc +from . import random, linalg, fft # operators from .op_register import * diff --git a/brainpy/_src/math/op_register/base.py b/brainpy/_src/math/op_register/base.py index 20a48778..a6dd5a5b 100644 --- a/brainpy/_src/math/op_register/base.py +++ b/brainpy/_src/math/op_register/base.py @@ -9,20 +9,16 @@ from brainpy._src.math.ndarray import Array from brainpy._src.math.object_transform.base import BrainPyObject +is_version_right = False if jax.__version__ >= '0.4.16': from .numba_based import register_numba_mlir_cpu_translation_rule as register_numba_cpu_translation_rule - from .taichi_aot_based import (register_taichi_aot_mlir_cpu_translation_rule as register_taichi_cpu_translation_rule, + from braintaichi._primitive._mlir_translation_rule import (register_taichi_aot_mlir_cpu_translation_rule as register_taichi_cpu_translation_rule, register_taichi_aot_mlir_gpu_translation_rule as register_taichi_gpu_translation_rule) from .cupy_based import ( register_cupy_raw_module_mlir_gpu_translation_rule as register_cupy_raw_module_gpu_translation_rule, register_cupy_jit_kernel_mlir_gpu_translation_rule as register_cupy_jit_kernel_gpu_translation_rule) -else: - from .numba_based import register_numba_xla_cpu_translation_rule as register_numba_cpu_translation_rule - from .taichi_aot_based import (register_taichi_aot_xla_cpu_translation_rule as register_taichi_cpu_translation_rule, - register_taichi_aot_xla_gpu_translation_rule as register_taichi_gpu_translation_rule) - from .cupy_based import ( - register_cupy_raw_module_xla_gpu_translation_rule as register_cupy_raw_module_gpu_translation_rule, - register_cupy_jit_kernel_xla_gpu_translation_rule as register_cupy_jit_kernel_gpu_translation_rule) + is_version_right = True + from .utils import register_general_batching from brainpy._src.math.op_register.ad_support import defjvp @@ -73,6 +69,8 @@ def __init__( outs: Optional[Callable] = None, name: str = None, ): + if not is_version_right: + raise ImportError('XLA Custom Op is only supported in JAX>=0.4.16') super().__init__(name) # set cpu_kernel and gpu_kernel diff --git a/brainpy/math/__init__.py b/brainpy/math/__init__.py index 08a070f0..139ec08a 100644 --- a/brainpy/math/__init__.py +++ b/brainpy/math/__init__.py @@ -33,8 +33,6 @@ from . import linalg from . import random -# taichi operations -from . import tifunc # others from . import sharding diff --git a/brainpy/math/op_register.py b/brainpy/math/op_register.py index f383c1a2..8ec7f5e1 100644 --- a/brainpy/math/op_register.py +++ b/brainpy/math/op_register.py @@ -2,8 +2,6 @@ from brainpy._src.math.op_register import ( CustomOpByNumba, compile_cpu_signature_with_numba, - clear_taichi_aot_caches, - count_taichi_aot_kernels, ) from brainpy._src.math.op_register.base import XLACustomOp diff --git a/brainpy/math/tifunc.py b/brainpy/math/tifunc.py deleted file mode 100644 index bea49c22..00000000 --- a/brainpy/math/tifunc.py +++ /dev/null @@ -1,25 +0,0 @@ -# -*- coding: utf-8 -*- - -from brainpy._src.math.tifunc import ( - - # warp reduction primitives - warp_reduce_sum, - - # random number generator - lfsr88_key, - lfsr88_next_key, - lfsr88_normal, - lfsr88_randn, - lfsr88_random_integers, - lfsr88_randint, - lfsr88_uniform, - lfsr88_rand, - lfsr113_key, - lfsr113_next_key, - lfsr113_normal, - lfsr113_randn, - lfsr113_random_integers, - lfsr113_randint, - lfsr113_uniform, - lfsr113_rand -)