diff --git a/brainpy/_src/integrators/runner.py b/brainpy/_src/integrators/runner.py index 35557b60..bf1ecb1f 100644 --- a/brainpy/_src/integrators/runner.py +++ b/brainpy/_src/integrators/runner.py @@ -318,8 +318,11 @@ def run( hists = self._run_fun_integration(args, dyn_args, times, indices) if eval_time: running_time = time.time() - t0 - if self.progress_bar: - self._pbar.close() + + # due to jax 0.4.32 enable the async dispatch(https://github.com/jax-ml/jax/blob/main/CHANGELOG.md#jax-0432-september-11-2024), + # we temporarily do not close the progress bar + # if self.progress_bar: + # self._pbar.close() # post-running times += self.dt diff --git a/brainpy/_src/runners.py b/brainpy/_src/runners.py index 73cee508..d98bc58a 100644 --- a/brainpy/_src/runners.py +++ b/brainpy/_src/runners.py @@ -486,8 +486,11 @@ def predict( running_time = time.time() - t0 # close the progress bar - if self.progress_bar: - self._pbar.close() + + # due to jax 0.4.32 enable the async dispatch(https://github.com/jax-ml/jax/blob/main/CHANGELOG.md#jax-0432-september-11-2024), + # we temporarily do not close the progress bar + # if self.progress_bar: + # self._pbar.close() # post-running for monitors if self._memory_efficient: diff --git a/brainpy/_src/train/offline.py b/brainpy/_src/train/offline.py index 36ed3c2b..53ff7e56 100644 --- a/brainpy/_src/train/offline.py +++ b/brainpy/_src/train/offline.py @@ -193,8 +193,11 @@ def fit( del monitor_data # close the progress bar - if self.progress_bar: - self._pbar.close() + + # due to jax 0.4.32 enable the async dispatch(https://github.com/jax-ml/jax/blob/main/CHANGELOG.md#jax-0432-september-11-2024), + # we temporarily do not close the progress bar + # if self.progress_bar: + # self._pbar.close() # final things for node in self.train_nodes: diff --git a/brainpy/_src/train/online.py b/brainpy/_src/train/online.py index d8e185c3..60799cb3 100644 --- a/brainpy/_src/train/online.py +++ b/brainpy/_src/train/online.py @@ -191,8 +191,11 @@ def fit( outs, hists = self._fit(indices, xs=xs, ys=ys, shared_args=shared_args) # close the progress bar - if self.progress_bar: - self._pbar.close() + + # due to jax 0.4.32 enable the async dispatch(https://github.com/jax-ml/jax/blob/main/CHANGELOG.md#jax-0432-september-11-2024), + # we temporarily do not close the progress bar + # if self.progress_bar: + # self._pbar.close() # post-running for monitors if self.numpy_mon_after_run: