diff --git a/brainpy/_src/runners.py b/brainpy/_src/runners.py index 80609608..73cee508 100644 --- a/brainpy/_src/runners.py +++ b/brainpy/_src/runners.py @@ -631,7 +631,7 @@ def _step_func_predict(self, i, *x, shared_args=None): # finally if self.progress_bar: - jax.pure_callback(lambda: self._pbar.update(), ()) + jax.debug.callback(lambda *args: self._pbar.update(), ()) # share.clear_shargs() clear_input(self.target)