diff --git a/brainpy/_src/math/ndarray.py b/brainpy/_src/math/ndarray.py index b435415d..47b81d18 100644 --- a/brainpy/_src/math/ndarray.py +++ b/brainpy/_src/math/ndarray.py @@ -98,7 +98,7 @@ def _check_tracer(self): self_value = self.value if hasattr(self_value, '_trace') and hasattr(self_value._trace.main, 'jaxpr_stack'): if len(self_value._trace.main.jaxpr_stack) == 0: - raise RuntimeError('This Array is modified during the transformation. ' + raise jax.errors.UnexpectedTracerError('This Array is modified during the transformation. ' 'BrainPy only supports transformations for Variable. ' 'Please declare it as a Variable.') from jax.core.escaped_tracer_error(self_value, None) return self_value diff --git a/brainpy/_src/math/tests/test_ndarray.py b/brainpy/_src/math/tests/test_ndarray.py index a0912912..e9acff35 100644 --- a/brainpy/_src/math/tests/test_ndarray.py +++ b/brainpy/_src/math/tests/test_ndarray.py @@ -62,7 +62,7 @@ def _f(self, b): def test_tracing(self): print(self.f(1.)) - with self.assertRaises(RuntimeError): + with self.assertRaises(jax.errors.UnexpectedTracerError): print(self.f(bm.ones(10)))