Skip to content

Commit

Permalink
Update __init__.py
Browse files Browse the repository at this point in the history
  • Loading branch information
Routhleck committed May 14, 2024
1 parent 8094c07 commit c70b3dd
Showing 1 changed file with 0 additions and 35 deletions.
35 changes: 0 additions & 35 deletions brainpy/_src/math/op_register/numba_approach/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,25 +21,14 @@
numba = import_numba(error_if_not_found=False)
if numba is not None:
from numba import types, carray, cfunc
if numba is not None:
from numba import types, carray, cfunc

__all__ = [
'CustomOpByNumba',
'register_op_with_numba_xla',
'register_op_with_numba_xla',
'compile_cpu_signature_with_numba',
]


def _transform_to_shapedarray(a):
return jax.core.ShapedArray(a.shape, a.dtype)


def convert_shapedarray_to_shapedtypestruct(shaped_array):
return jax.ShapeDtypeStruct(shape=shaped_array.shape, dtype=shaped_array.dtype)


def _transform_to_shapedarray(a):
return jax.core.ShapedArray(a.shape, a.dtype)

Expand Down Expand Up @@ -86,7 +75,6 @@ def __init__(
if eval_shape is None:
raise ValueError('Must provide "eval_shape" for abstract evaluation.')
self.eval_shape = eval_shape
self.eval_shape = eval_shape

# cpu function
cpu_func = con_compute
Expand Down Expand Up @@ -115,29 +103,6 @@ def __init__(
transpose_translation=transpose_translation,
multiple_results=multiple_results,
)
if jax.__version__ > '0.4.23':
self.op_method = 'mlir'
self.op = register_op_with_numba_mlir(
self.name,
cpu_func=cpu_func,
out_shapes=eval_shape,
gpu_func_translation=None,
batching_translation=batching_translation,
jvp_translation=jvp_translation,
transpose_translation=transpose_translation,
multiple_results=multiple_results,
)
else:
self.op_method = 'xla'
self.op = register_op_with_numba_xla(
self.name,
cpu_func=cpu_func,
out_shapes=eval_shape,
batching_translation=batching_translation,
jvp_translation=jvp_translation,
transpose_translation=transpose_translation,
multiple_results=multiple_results,
)

def __call__(self, *args, **kwargs):
args = tree_map(lambda a: a.value if isinstance(a, Array) else a,
Expand Down

0 comments on commit c70b3dd

Please sign in to comment.