From c70b3dd2bfd7ccd5e491745b3250a5261f250e1d Mon Sep 17 00:00:00 2001 From: He Sichao <1310722434@qq.com> Date: Tue, 14 May 2024 23:21:47 +0800 Subject: [PATCH] Update __init__.py --- .../op_register/numba_approach/__init__.py | 35 ------------------- 1 file changed, 35 deletions(-) diff --git a/brainpy/_src/math/op_register/numba_approach/__init__.py b/brainpy/_src/math/op_register/numba_approach/__init__.py index 5ac07191..4d9b284f 100644 --- a/brainpy/_src/math/op_register/numba_approach/__init__.py +++ b/brainpy/_src/math/op_register/numba_approach/__init__.py @@ -21,25 +21,14 @@ numba = import_numba(error_if_not_found=False) if numba is not None: from numba import types, carray, cfunc -if numba is not None: - from numba import types, carray, cfunc __all__ = [ 'CustomOpByNumba', 'register_op_with_numba_xla', - 'register_op_with_numba_xla', 'compile_cpu_signature_with_numba', ] -def _transform_to_shapedarray(a): - return jax.core.ShapedArray(a.shape, a.dtype) - - -def convert_shapedarray_to_shapedtypestruct(shaped_array): - return jax.ShapeDtypeStruct(shape=shaped_array.shape, dtype=shaped_array.dtype) - - def _transform_to_shapedarray(a): return jax.core.ShapedArray(a.shape, a.dtype) @@ -86,7 +75,6 @@ def __init__( if eval_shape is None: raise ValueError('Must provide "eval_shape" for abstract evaluation.') self.eval_shape = eval_shape - self.eval_shape = eval_shape # cpu function cpu_func = con_compute @@ -115,29 +103,6 @@ def __init__( transpose_translation=transpose_translation, multiple_results=multiple_results, ) - if jax.__version__ > '0.4.23': - self.op_method = 'mlir' - self.op = register_op_with_numba_mlir( - self.name, - cpu_func=cpu_func, - out_shapes=eval_shape, - gpu_func_translation=None, - batching_translation=batching_translation, - jvp_translation=jvp_translation, - transpose_translation=transpose_translation, - multiple_results=multiple_results, - ) - else: - self.op_method = 'xla' - self.op = register_op_with_numba_xla( - self.name, - cpu_func=cpu_func, - out_shapes=eval_shape, - batching_translation=batching_translation, - jvp_translation=jvp_translation, - transpose_translation=transpose_translation, - multiple_results=multiple_results, - ) def __call__(self, *args, **kwargs): args = tree_map(lambda a: a.value if isinstance(a, Array) else a,