diff --git a/brainpy/_src/math/op_register/base.py b/brainpy/_src/math/op_register/base.py index 4dd176519..74fa1188c 100644 --- a/brainpy/_src/math/op_register/base.py +++ b/brainpy/_src/math/op_register/base.py @@ -227,3 +227,4 @@ def _transform_to_array(a): def _transform_to_shapedarray(a): return jax.core.ShapedArray(a.shape, a.dtype) +