diff --git a/brainpy/_src/math/interoperability.py b/brainpy/_src/math/interoperability.py index 766d4f8e1..22fe25caf 100644 --- a/brainpy/_src/math/interoperability.py +++ b/brainpy/_src/math/interoperability.py @@ -7,7 +7,7 @@ __all__ = [ - 'as_device_array', 'as_jax', 'as_ndarray', 'as_numpy', 'as_variable', + 'as_device_array', 'as_jax', 'as_ndarray', 'as_numpy', 'as_variable', 'is_bp_array' ] @@ -15,6 +15,12 @@ def _as_jax_array_(obj): return obj.value if isinstance(obj, Array) else obj +def is_bp_array(x): + """Check if the input is a ``brainpy.math.Array``. + """ + return isinstance(x, Array) + + def as_device_array(tensor, dtype=None): """Convert the input to a ``jax.numpy.DeviceArray``. diff --git a/brainpy/math/interoperability.py b/brainpy/math/interoperability.py index 9bf4aee80..f6356bca7 100644 --- a/brainpy/math/interoperability.py +++ b/brainpy/math/interoperability.py @@ -6,5 +6,6 @@ as_ndarray as as_ndarray, as_numpy as as_numpy, as_variable as as_variable, + is_bp_array as is_bp_array, )