diff --git a/brainpy/_src/integrators/runner.py b/brainpy/_src/integrators/runner.py index dae638e1..35557b60 100644 --- a/brainpy/_src/integrators/runner.py +++ b/brainpy/_src/integrators/runner.py @@ -244,7 +244,7 @@ def _step_fun_integrator(self, static_args, dyn_args, t, i): # progress bar if self.progress_bar: - jax.pure_callback(lambda *args: self._pbar.update(), ()) + jax.debug.callback(lambda *args: self._pbar.update(), ()) # return of function monitors shared = dict(t=t + self.dt, dt=self.dt, i=i) diff --git a/brainpy/_src/math/object_transform/controls.py b/brainpy/_src/math/object_transform/controls.py index 126ca15c..32358512 100644 --- a/brainpy/_src/math/object_transform/controls.py +++ b/brainpy/_src/math/object_transform/controls.py @@ -726,7 +726,7 @@ def fun2scan(carry, x): dyn_vars[k]._value = carry[k] results = body_fun(*x, **unroll_kwargs) if progress_bar: - jax.pure_callback(lambda *arg: bar.update(), ()) + jax.debug.callback(lambda *args: bar.update(), ()) return dyn_vars.dict_data(), results if remat: diff --git a/brainpy/_src/train/offline.py b/brainpy/_src/train/offline.py index e801a29e..36ed3c2b 100644 --- a/brainpy/_src/train/offline.py +++ b/brainpy/_src/train/offline.py @@ -219,7 +219,7 @@ def _fun_train(self, targets = target_data[node.name] node.offline_fit(targets, fit_record) if self.progress_bar: - jax.pure_callback(lambda *args: self._pbar.update(), ()) + jax.debug.callback(lambda *args: self._pbar.update(), ()) def _step_func_monitor(self): res = dict() diff --git a/brainpy/_src/train/online.py b/brainpy/_src/train/online.py index 862db8df..d8e185c3 100644 --- a/brainpy/_src/train/online.py +++ b/brainpy/_src/train/online.py @@ -252,7 +252,7 @@ def _step_func_fit(self, i, xs: Sequence, ys: Dict, shared_args=None): # finally if self.progress_bar: - jax.pure_callback(lambda *arg: self._pbar.update(), ()) + jax.debug.callback(lambda *args: self._pbar.update(), ()) return out, monitors def _check_interface(self):