From f58fe5b82df07c841822033328ee36f9e25ff18c Mon Sep 17 00:00:00 2001 From: chaoming Date: Fri, 3 Nov 2023 21:21:16 +0800 Subject: [PATCH 1/3] [math] add `__array_priority__` --- brainpy/_src/math/ndarray.py | 2 ++ .../math/tests/{test_jaxarray.py => test_ndarray.py} | 11 +++++++++++ 2 files changed, 13 insertions(+) rename brainpy/_src/math/tests/{test_jaxarray.py => test_ndarray.py} (89%) diff --git a/brainpy/_src/math/ndarray.py b/brainpy/_src/math/ndarray.py index 0c9bf8f54..b5d12d9ce 100644 --- a/brainpy/_src/math/ndarray.py +++ b/brainpy/_src/math/ndarray.py @@ -1518,6 +1518,8 @@ def float(self): return jnp.asarray(self.value, dtype=jnp.float32) def double(self): return jnp.asarray(self.value, dtype=jnp.float64) +setattr(Array, "__array_priority__", 100) + JaxArray = Array ndarray = Array diff --git a/brainpy/_src/math/tests/test_jaxarray.py b/brainpy/_src/math/tests/test_ndarray.py similarity index 89% rename from brainpy/_src/math/tests/test_jaxarray.py rename to brainpy/_src/math/tests/test_ndarray.py index 9a227a071..09a6f791c 100644 --- a/brainpy/_src/math/tests/test_jaxarray.py +++ b/brainpy/_src/math/tests/test_ndarray.py @@ -111,3 +111,14 @@ def test_update(self): ) self.assertTrue(view.sum() == bm.sum(bm.arange(5) + 10)) + + +class TestArrayPriority(unittest.TestCase): + def test1(self): + a = bm.Array(bm.zeros(10)) + assert isinstance(a + bm.ones(1).value, bm.Array) + assert isinstance(a + np.ones(1), bm.Array) + assert isinstance(a * np.ones(1), bm.Array) + assert isinstance(np.ones(1) + a, bm.Array) + assert isinstance(np.ones(1) * a, bm.Array) + From c8cb9d5dd5bfe24da76a962dcdee1bcced817a95 Mon Sep 17 00:00:00 2001 From: chaoming Date: Fri, 3 Nov 2023 21:21:38 +0800 Subject: [PATCH 2/3] [math] change `spk_type` to `spk_dtype` --- brainpy/_src/dyn/neurons/base.py | 10 ++--- brainpy/_src/dyn/neurons/lif.py | 76 ++++++++++++++++---------------- 2 files changed, 43 insertions(+), 43 deletions(-) diff --git a/brainpy/_src/dyn/neurons/base.py b/brainpy/_src/dyn/neurons/base.py index 02a457d0a..264ce8865 100644 --- a/brainpy/_src/dyn/neurons/base.py +++ b/brainpy/_src/dyn/neurons/base.py @@ -29,7 +29,7 @@ def __init__( scaling: Optional[bm.Scaling] = None, spk_fun: Callable = bm.surrogate.InvSquareGrad(), - spk_type: Any = None, + spk_dtype: Any = None, spk_reset: str = 'soft', detach_spk: bool = False, ): @@ -43,18 +43,18 @@ def __init__( self.spk_reset = spk_reset self.spk_fun = is_callable(spk_fun) self.detach_spk = detach_spk - self._spk_type = spk_type + self._spk_dtype = spk_dtype if scaling is None: self.scaling = bm.get_membrane_scaling() else: self.scaling = scaling @property - def spk_type(self): - if self._spk_type is None: + def spk_dtype(self): + if self._spk_dtype is None: return bm.float_ if isinstance(self.mode, bm.TrainingMode) else bm.bool_ else: - return self._spk_type + return self._spk_dtype def offset_scaling(self, x, bias=None, scale=None): s = self.scaling.offset_scaling(x, bias=bias, scale=scale) diff --git a/brainpy/_src/dyn/neurons/lif.py b/brainpy/_src/dyn/neurons/lif.py index 018ad24a9..988c915ac 100644 --- a/brainpy/_src/dyn/neurons/lif.py +++ b/brainpy/_src/dyn/neurons/lif.py @@ -77,7 +77,7 @@ def __init__( mode: Optional[bm.Mode] = None, name: Optional[str] = None, spk_fun: Callable = bm.surrogate.InvSquareGrad(), - spk_type: Any = None, + spk_dtype: Any = None, spk_reset: str = 'soft', detach_spk: bool = False, method: str = 'exp_auto', @@ -99,7 +99,7 @@ def __init__( spk_fun=spk_fun, detach_spk=detach_spk, method=method, - spk_type=spk_type, + 
spk_dtype=spk_dtype, spk_reset=spk_reset, scaling=scaling) @@ -124,7 +124,7 @@ def derivative(self, V, t, I): def reset_state(self, batch_size=None, **kwargs): self.V = self.offset_scaling(self.init_variable(self._V_initializer, batch_size)) - self.spike = self.init_variable(partial(bm.zeros, dtype=self.spk_type), batch_size) + self.spike = self.init_variable(partial(bm.zeros, dtype=self.spk_dtype), batch_size) def update(self, x=None): t = share.load('t') @@ -206,7 +206,7 @@ def __init__( mode: Optional[bm.Mode] = None, name: Optional[str] = None, spk_fun: Callable = bm.surrogate.InvSquareGrad(), - spk_type: Any = None, + spk_dtype: Any = None, spk_reset: str = 'soft', detach_spk: bool = False, method: str = 'exp_auto', @@ -230,7 +230,7 @@ def __init__( spk_fun=spk_fun, detach_spk=detach_spk, method=method, - spk_type=spk_type, + spk_dtype=spk_dtype, spk_reset=spk_reset, scaling=scaling) @@ -257,7 +257,7 @@ def derivative(self, V, t, I): def reset_state(self, batch_size=None, **kwargs): self.V = self.offset_scaling(self.init_variable(self._V_initializer, batch_size)) - self.spike = self.init_variable(partial(bm.zeros, dtype=self.spk_type), batch_size) + self.spike = self.init_variable(partial(bm.zeros, dtype=self.spk_dtype), batch_size) def update(self, x=None): t = share.load('t') @@ -399,7 +399,7 @@ def __init__( keep_size: bool = False, mode: Optional[bm.Mode] = None, spk_fun: Callable = bm.surrogate.InvSquareGrad(), - spk_type: Any = None, + spk_dtype: Any = None, detach_spk: bool = False, spk_reset: str = 'soft', method: str = 'exp_auto', @@ -429,7 +429,7 @@ def __init__( sharding=sharding, spk_fun=spk_fun, detach_spk=detach_spk, - spk_type=spk_type, + spk_dtype=spk_dtype, spk_reset=spk_reset, init_var=False, @@ -673,7 +673,7 @@ def __init__( mode: Optional[bm.Mode] = None, name: Optional[str] = None, spk_fun: Callable = bm.surrogate.InvSquareGrad(), - spk_type: Any = None, + spk_dtype: Any = None, spk_reset: str = 'soft', detach_spk: bool = False, method: str = 'exp_auto', @@ -699,7 +699,7 @@ def __init__( spk_fun=spk_fun, detach_spk=detach_spk, method=method, - spk_type=spk_type, + spk_dtype=spk_dtype, spk_reset=spk_reset, scaling=scaling) @@ -730,7 +730,7 @@ def derivative(self, V, t, I): def reset_state(self, batch_size=None, **kwargs): self.V = self.offset_scaling(self.init_variable(self._V_initializer, batch_size)) - self.spike = self.init_variable(partial(bm.zeros, dtype=self.spk_type), batch_size) + self.spike = self.init_variable(partial(bm.zeros, dtype=self.spk_dtype), batch_size) def update(self, x=None): t = share.load('t') @@ -1001,7 +1001,7 @@ def __init__( keep_size: bool = False, mode: Optional[bm.Mode] = None, spk_fun: Callable = bm.surrogate.InvSquareGrad(), - spk_type: Any = None, + spk_dtype: Any = None, detach_spk: bool = False, spk_reset: str = 'soft', method: str = 'exp_auto', @@ -1033,7 +1033,7 @@ def __init__( sharding=sharding, spk_fun=spk_fun, detach_spk=detach_spk, - spk_type=spk_type, + spk_dtype=spk_dtype, spk_reset=spk_reset, init_var=False, @@ -1343,7 +1343,7 @@ def __init__( mode: Optional[bm.Mode] = None, name: Optional[str] = None, spk_fun: Callable = bm.surrogate.InvSquareGrad(), - spk_type: Any = None, + spk_dtype: Any = None, spk_reset: str = 'soft', detach_spk: bool = False, method: str = 'exp_auto', @@ -1373,7 +1373,7 @@ def __init__( spk_fun=spk_fun, detach_spk=detach_spk, method=method, - spk_type=spk_type, + spk_dtype=spk_dtype, spk_reset=spk_reset, scaling=scaling) # parameters @@ -1416,7 +1416,7 @@ def derivative(self): def 
reset_state(self, batch_size=None, **kwargs): self.V = self.offset_scaling(self.init_variable(self._V_initializer, batch_size)) self.w = self.std_scaling(self.init_variable(self._w_initializer, batch_size)) - self.spike = self.init_variable(partial(bm.zeros, dtype=self.spk_type), batch_size) + self.spike = self.init_variable(partial(bm.zeros, dtype=self.spk_dtype), batch_size) def update(self, x=None): t = share.load('t') @@ -1672,7 +1672,7 @@ def __init__( keep_size: bool = False, mode: Optional[bm.Mode] = None, spk_fun: Callable = bm.surrogate.InvSquareGrad(), - spk_type: Any = None, + spk_dtype: Any = None, spk_reset: str = 'soft', detach_spk: bool = False, method: str = 'exp_auto', @@ -1708,7 +1708,7 @@ def __init__( sharding=sharding, spk_fun=spk_fun, detach_spk=detach_spk, - spk_type=spk_type, + spk_dtype=spk_dtype, spk_reset=spk_reset, init_var=False, @@ -1991,7 +1991,7 @@ def __init__( mode: Optional[bm.Mode] = None, name: Optional[str] = None, spk_fun: Callable = bm.surrogate.InvSquareGrad(), - spk_type: Any = None, + spk_dtype: Any = None, spk_reset: str = 'soft', detach_spk: bool = False, method: str = 'exp_auto', @@ -2017,7 +2017,7 @@ def __init__( spk_fun=spk_fun, detach_spk=detach_spk, method=method, - spk_type=spk_type, + spk_dtype=spk_dtype, spk_reset=spk_reset, scaling=scaling) # parameters @@ -2046,7 +2046,7 @@ def derivative(self, V, t, I): def reset_state(self, batch_size=None, **kwargs): self.V = self.offset_scaling(self.init_variable(self._V_initializer, batch_size)) - self.spike = self.init_variable(partial(bm.zeros, dtype=self.spk_type), batch_size) + self.spike = self.init_variable(partial(bm.zeros, dtype=self.spk_dtype), batch_size) def update(self, x=None): t = share.load('t') @@ -2255,7 +2255,7 @@ def __init__( keep_size: bool = False, mode: Optional[bm.Mode] = None, spk_fun: Callable = bm.surrogate.InvSquareGrad(), - spk_type: Any = None, + spk_dtype: Any = None, spk_reset: str = 'soft', detach_spk: bool = False, method: str = 'exp_auto', @@ -2287,7 +2287,7 @@ def __init__( sharding=sharding, spk_fun=spk_fun, detach_spk=detach_spk, - spk_type=spk_type, + spk_dtype=spk_dtype, spk_reset=spk_reset, init_var=False, @@ -2554,7 +2554,7 @@ def __init__( mode: Optional[bm.Mode] = None, name: Optional[str] = None, spk_fun: Callable = bm.surrogate.InvSquareGrad(), - spk_type: Any = None, + spk_dtype: Any = None, spk_reset: str = 'soft', detach_spk: bool = False, method: str = 'exp_auto', @@ -2583,7 +2583,7 @@ def __init__( spk_fun=spk_fun, detach_spk=detach_spk, method=method, - spk_type=spk_type, + spk_dtype=spk_dtype, spk_reset=spk_reset, scaling=scaling) # parameters @@ -2624,7 +2624,7 @@ def derivative(self): def reset_state(self, batch_size=None, **kwargs): self.V = self.offset_scaling(self.init_variable(self._V_initializer, batch_size)) self.w = self.std_scaling(self.init_variable(self._w_initializer, batch_size)) - self.spike = self.init_variable(partial(bm.zeros, dtype=self.spk_type), batch_size) + self.spike = self.init_variable(partial(bm.zeros, dtype=self.spk_dtype), batch_size) def update(self, x=None): t = share.load('t') @@ -2856,7 +2856,7 @@ def __init__( keep_size: bool = False, mode: Optional[bm.Mode] = None, spk_fun: Callable = bm.surrogate.InvSquareGrad(), - spk_type: Any = None, + spk_dtype: Any = None, spk_reset: str = 'soft', detach_spk: bool = False, method: str = 'exp_auto', @@ -2891,7 +2891,7 @@ def __init__( sharding=sharding, spk_fun=spk_fun, detach_spk=detach_spk, - spk_type=spk_type, + spk_dtype=spk_dtype, spk_reset=spk_reset, init_var=False, 
@@ -3201,7 +3201,7 @@ def __init__( mode: Optional[bm.Mode] = None, name: Optional[str] = None, spk_fun: Callable = bm.surrogate.InvSquareGrad(), - spk_type: Any = None, + spk_dtype: Any = None, spk_reset: str = 'soft', detach_spk: bool = False, method: str = 'exp_auto', @@ -3237,7 +3237,7 @@ def __init__( spk_fun=spk_fun, detach_spk=detach_spk, method=method, - spk_type=spk_type, + spk_dtype=spk_dtype, spk_reset=spk_reset, scaling=scaling) # parameters @@ -3291,7 +3291,7 @@ def reset_state(self, batch_size=None, **kwargs): self.V_th = self.offset_scaling(self.init_variable(self._Vth_initializer, batch_size)) self.I1 = self.std_scaling(self.init_variable(self._I1_initializer, batch_size)) self.I2 = self.std_scaling(self.init_variable(self._I2_initializer, batch_size)) - self.spike = self.init_variable(partial(bm.zeros, dtype=self.spk_type), batch_size) + self.spike = self.init_variable(partial(bm.zeros, dtype=self.spk_dtype), batch_size) def update(self, x=None): t = share.load('t') @@ -3581,7 +3581,7 @@ def __init__( keep_size: bool = False, mode: Optional[bm.Mode] = None, spk_fun: Callable = bm.surrogate.InvSquareGrad(), - spk_type: Any = None, + spk_dtype: Any = None, spk_reset: str = 'soft', detach_spk: bool = False, method: str = 'exp_auto', @@ -3623,7 +3623,7 @@ def __init__( sharding=sharding, spk_fun=spk_fun, detach_spk=detach_spk, - spk_type=spk_type, + spk_dtype=spk_dtype, spk_reset=spk_reset, init_var=False, @@ -3952,7 +3952,7 @@ def __init__( mode: Optional[bm.Mode] = None, name: Optional[str] = None, spk_fun: Callable = bm.surrogate.InvSquareGrad(), - spk_type: Any = None, + spk_dtype: Any = None, spk_reset: str = 'soft', detach_spk: bool = False, method: str = 'exp_auto', @@ -3982,7 +3982,7 @@ def __init__( spk_fun=spk_fun, detach_spk=detach_spk, method=method, - spk_type=spk_type, + spk_dtype=spk_dtype, spk_reset=spk_reset, scaling=scaling) # parameters @@ -4031,7 +4031,7 @@ def reset_state(self, batch_size=None, **kwargs): self.V = self.offset_scaling(self.V) self.u = self.offset_scaling(self.init_variable(self._u_initializer, batch_size), bias=self.b * self.scaling.bias, scale=self.scaling.scale) - self.spike = self.init_variable(partial(bm.zeros, dtype=self.spk_type), batch_size) + self.spike = self.init_variable(partial(bm.zeros, dtype=self.spk_dtype), batch_size) def update(self, x=None): t = share.load('t') @@ -4266,7 +4266,7 @@ def __init__( keep_size: bool = False, mode: Optional[bm.Mode] = None, spk_fun: Callable = bm.surrogate.InvSquareGrad(), - spk_type: Any = None, + spk_dtype: Any = None, spk_reset: str = 'soft', detach_spk: bool = False, method: str = 'exp_auto', @@ -4302,7 +4302,7 @@ def __init__( sharding=sharding, spk_fun=spk_fun, detach_spk=detach_spk, - spk_type=spk_type, + spk_dtype=spk_dtype, spk_reset=spk_reset, init_var=False, From 48f77ee08bbb3a79d1bbc1f035d91e10af87cfc3 Mon Sep 17 00:00:00 2001 From: chaoming Date: Sat, 4 Nov 2023 16:10:13 +0800 Subject: [PATCH 3/3] [math] new abstract function for `XLACustomOp` --- brainpy/_src/math/op_register/base.py | 26 +++++++++---------- brainpy/_src/math/op_register/numba_based.py | 7 +++-- .../_src/math/op_register/taichi_aot_based.py | 10 +++---- brainpy/_src/math/tests/test_ndarray.py | 5 ++++ 4 files changed, 25 insertions(+), 23 deletions(-) diff --git a/brainpy/_src/math/op_register/base.py b/brainpy/_src/math/op_register/base.py index ad240a6a8..779b5aa2d 100644 --- a/brainpy/_src/math/op_register/base.py +++ b/brainpy/_src/math/op_register/base.py @@ -22,7 +22,6 @@ preprocess_kernel_call_cpu, ) from 
.utils import register_general_batching - __all__ = [ 'XLACustomOp', ] @@ -92,7 +91,7 @@ def __init__( name: str = None, ): super().__init__(name) - + # set cpu_kernel and gpu_kernel self.cpu_kernel = cpu_kernel self.gpu_kernel = gpu_kernel @@ -105,7 +104,7 @@ def __init__( if outs is not None: outs = tuple([_transform_to_shapedarray(o) for o in outs]) self.outs = outs - self.primitive.def_abstract_eval(self._abstract_eval) + self.primitive.def_abstract_eval(_abstract_eval) self.primitive.def_impl(partial(xla.apply_primitive, self.primitive)) # cpu function @@ -142,15 +141,11 @@ def __init__( if transpose_translation is not None: ad.primitive_transposes[self.primitive] = transpose_translation - def _abstract_eval(self, *args, **kwargs): - if self.outs is None: - raise ValueError('"self.outs" must be defined, but got None.') - return self.outs - def __call__(self, *ins, outs: Optional[Sequence[ShapeDtype]] = None): - # _set_taichi_envir() - if outs is not None: - self.outs = tuple([_transform_to_shapedarray(o) for o in outs]) + if outs is None: + outs = self.outs + assert outs is not None + outs = tuple([_transform_to_shapedarray(o) for o in outs]) cpu_kernel = getattr(self, "cpu_kernel", None) if hasattr(cpu_kernel, '_is_wrapped_kernel') and cpu_kernel._is_wrapped_kernel: # taichi source_md5_encode = encode_md5('cpu' + inspect.getsource(cpu_kernel) + \ @@ -160,7 +155,7 @@ def __call__(self, *ins, outs: Optional[Sequence[ShapeDtype]] = None): new_ins.extend(ins) ins = new_ins ins = jax.tree_util.tree_map(_transform_to_array, ins, is_leaf=_is_bp_array) - return self.primitive.bind(*ins) + return self.primitive.bind(*ins, outs=outs) def def_abstract_eval(self, fun): """Define the abstract evaluation function. @@ -213,6 +208,11 @@ def def_mlir_lowering(self, platform, fun): mlir.register_lowering(self.primitive, fun, platform) +def _abstract_eval(*args, **kwargs): + return [jax.core.ShapedArray(out_shape.shape, out_shape.dtype) + for out_shape in kwargs['outs']] + + def _is_bp_array(a): return isinstance(a, Array) @@ -229,6 +229,7 @@ def _transform_to_array(a): def _transform_to_shapedarray(a): return jax.core.ShapedArray(a.shape, a.dtype) + def _set_taichi_envir(): # find the path of taichi in python site_packages taichi_path = ti.__path__[0] @@ -238,4 +239,3 @@ def _set_taichi_envir(): 'TAICHI_C_API_INSTALL_DIR': taichi_c_api_install_dir, 'TI_LIB_DIR': taichi_lib_dir }) - diff --git a/brainpy/_src/math/op_register/numba_based.py b/brainpy/_src/math/op_register/numba_based.py index 4f801be8d..fb51b5dbf 100644 --- a/brainpy/_src/math/op_register/numba_based.py +++ b/brainpy/_src/math/op_register/numba_based.py @@ -66,8 +66,8 @@ def xla_cpu_custom_call_target(output_ptrs, input_ptrs): return target_name -def _numba_xla_cpu_translation_rule(prim, kernel, debug: bool, c, *ins): - outs = prim.abstract_eval()[0] +def _numba_xla_cpu_translation_rule(kernel, debug: bool, c, *ins, **kwargs): + outs = kwargs['outs'] # output information output_shapes = tuple(out.shape for out in outs) @@ -101,12 +101,11 @@ def _numba_xla_cpu_translation_rule(prim, kernel, debug: bool, c, *ins): def register_numba_xla_cpu_translation_rule(primitive, cpu_kernel, debug=False): xla.backend_specific_translations['cpu'][primitive] = partial(_numba_xla_cpu_translation_rule, - primitive, cpu_kernel, debug) -def _numba_mlir_cpu_translation_rule(kernel, debug: bool, ctx, *ins): +def _numba_mlir_cpu_translation_rule(kernel, debug: bool, ctx, *ins, **kwargs): # output information outs = ctx.avals_out output_shapes = 
tuple([out.shape for out in outs]) diff --git a/brainpy/_src/math/op_register/taichi_aot_based.py b/brainpy/_src/math/op_register/taichi_aot_based.py index 328252845..75bc34087 100644 --- a/brainpy/_src/math/op_register/taichi_aot_based.py +++ b/brainpy/_src/math/op_register/taichi_aot_based.py @@ -354,8 +354,8 @@ def _taichi_cpu_translation_rule(prim, kernel, c, *ins): ) -def _taichi_gpu_translation_rule(prim, kernel, c, *ins): - outs = prim.abstract_eval()[0] +def _taichi_gpu_translation_rule(kernel, c, *ins, **kwargs): + outs = kwargs['outs'] output_shapes = tuple(out.shape for out in outs) output_dtypes = tuple(out.dtype for out in outs) @@ -401,10 +401,8 @@ def _taichi_gpu_translation_rule(prim, kernel, c, *ins): def register_taichi_cpu_translation_rule(primitive, cpu_kernel): - xla.backend_specific_translations['cpu'][primitive] = partial(_taichi_cpu_translation_rule, - primitive, cpu_kernel) + xla.backend_specific_translations['cpu'][primitive] = partial(_taichi_cpu_translation_rule, cpu_kernel) def register_taichi_gpu_translation_rule(primitive, gpu_kernel): - xla.backend_specific_translations['gpu'][primitive] = partial(_taichi_gpu_translation_rule, - primitive, gpu_kernel) + xla.backend_specific_translations['gpu'][primitive] = partial(_taichi_gpu_translation_rule, gpu_kernel) diff --git a/brainpy/_src/math/tests/test_ndarray.py b/brainpy/_src/math/tests/test_ndarray.py index 09a6f791c..a09129129 100644 --- a/brainpy/_src/math/tests/test_ndarray.py +++ b/brainpy/_src/math/tests/test_ndarray.py @@ -121,4 +121,9 @@ def test1(self): assert isinstance(a * np.ones(1), bm.Array) assert isinstance(np.ones(1) + a, bm.Array) assert isinstance(np.ones(1) * a, bm.Array) + b = bm.Variable(bm.zeros(10)) + assert isinstance(b + bm.ones(1).value, bm.Array) + assert isinstance(b + np.ones(1), bm.Array) + assert isinstance(np.ones(1) + b, bm.Array) + assert isinstance(np.ones(1) * b, bm.Array)
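
Illustration (not part of the patches above) of what PATCH 1/3 changes in
practice. NumPy consults `__array_priority__` when a `numpy.ndarray` is the
left operand of a binary operation; giving `brainpy.math.Array` a priority of
100 makes NumPy defer to the Array's reflected operators, so mixed expressions
keep the Array type instead of decaying to a plain `numpy.ndarray`. The snippet
mirrors the new test case in test_ndarray.py:

    import numpy as np
    import brainpy.math as bm

    a = bm.Array(bm.zeros(10))

    # Left-hand Array: handled by Array.__add__ as before.
    assert isinstance(a + np.ones(1), bm.Array)

    # Left-hand numpy.ndarray: with __array_priority__ = 100, NumPy now
    # defers to Array.__radd__ / __rmul__, so the result stays a bm.Array.
    assert isinstance(np.ones(1) + a, bm.Array)
    assert isinstance(np.ones(1) * a, bm.Array)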
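
A standalone toy sketch (plain JAX, names made up, not code from this
repository) of the pattern PATCH 3/3 adopts for `XLACustomOp`: the output
`ShapedArray`s are passed through `Primitive.bind(..., outs=...)` and read
back inside a module-level abstract-eval function, rather than being stored
on the op instance and read from `self.outs`:

    import jax
    import jax.numpy as jnp
    from jax import core

    toy_p = core.Primitive('toy_custom_op')   # hypothetical primitive
    toy_p.multiple_results = True

    def _abstract_eval(*args, **kwargs):
        # Same idea as the patched XLACustomOp: the output specs travel
        # with every bind call as the 'outs' parameter.
        return [core.ShapedArray(o.shape, o.dtype) for o in kwargs['outs']]

    def _impl(x, *, outs):
        # Dummy eager implementation, standing in for a real CPU/GPU kernel.
        return [jnp.zeros(o.shape, o.dtype) for o in outs]

    toy_p.def_abstract_eval(_abstract_eval)
    toy_p.def_impl(_impl)

    outs = (core.ShapedArray((4,), jnp.float32),)
    y, = toy_p.bind(jnp.ones(4), outs=outs)    # eager path uses _impl
    jaxpr = jax.make_jaxpr(lambda x: toy_p.bind(x, outs=outs))(jnp.ones(4))
    # Tracing uses _abstract_eval, which needs no per-instance state.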