diff --git a/brainunit/_base.py b/brainunit/_base.py index c1ca3e7..47a4d0f 100644 --- a/brainunit/_base.py +++ b/brainunit/_base.py @@ -4468,7 +4468,7 @@ def new_f(*args, **kwds): elif specific_unit == 1: if isinstance(v, Quantity): newkeyset[n] = v.to_decimal() - elif isinstance(v, jax.typing.ArrayLike): + elif isinstance(v, (jax.Array, np.ndarray, int, float, complex)): newkeyset[n] = v else: specific_unit = jax.typing.ArrayLike @@ -4502,7 +4502,7 @@ def new_f(*args, **kwds): elif specific_unit == 1: if isinstance(result, Quantity): result = result.to_decimal() - elif isinstance(result, jax.typing.ArrayLike): + elif isinstance(result, (jax.Array, np.ndarray, int, float, complex)): result = jnp.asarray(result) else: specific_unit = jax.typing.ArrayLike