diff --git a/brainpy/_src/integrators/runner.py b/brainpy/_src/integrators/runner.py index bf1ecb1f..35557b60 100644 --- a/brainpy/_src/integrators/runner.py +++ b/brainpy/_src/integrators/runner.py @@ -318,11 +318,8 @@ def run( hists = self._run_fun_integration(args, dyn_args, times, indices) if eval_time: running_time = time.time() - t0 - - # due to jax 0.4.32 enable the async dispatch(https://github.com/jax-ml/jax/blob/main/CHANGELOG.md#jax-0432-september-11-2024), - # we temporarily do not close the progress bar - # if self.progress_bar: - # self._pbar.close() + if self.progress_bar: + self._pbar.close() # post-running times += self.dt diff --git a/brainpy/_src/runners.py b/brainpy/_src/runners.py index d98bc58a..73cee508 100644 --- a/brainpy/_src/runners.py +++ b/brainpy/_src/runners.py @@ -486,11 +486,8 @@ def predict( running_time = time.time() - t0 # close the progress bar - - # due to jax 0.4.32 enable the async dispatch(https://github.com/jax-ml/jax/blob/main/CHANGELOG.md#jax-0432-september-11-2024), - # we temporarily do not close the progress bar - # if self.progress_bar: - # self._pbar.close() + if self.progress_bar: + self._pbar.close() # post-running for monitors if self._memory_efficient: diff --git a/brainpy/_src/train/offline.py b/brainpy/_src/train/offline.py index 53ff7e56..36ed3c2b 100644 --- a/brainpy/_src/train/offline.py +++ b/brainpy/_src/train/offline.py @@ -193,11 +193,8 @@ def fit( del monitor_data # close the progress bar - - # due to jax 0.4.32 enable the async dispatch(https://github.com/jax-ml/jax/blob/main/CHANGELOG.md#jax-0432-september-11-2024), - # we temporarily do not close the progress bar - # if self.progress_bar: - # self._pbar.close() + if self.progress_bar: + self._pbar.close() # final things for node in self.train_nodes: diff --git a/brainpy/_src/train/online.py b/brainpy/_src/train/online.py index 60799cb3..d8e185c3 100644 --- a/brainpy/_src/train/online.py +++ b/brainpy/_src/train/online.py @@ -191,11 +191,8 @@ def fit( outs, hists = self._fit(indices, xs=xs, ys=ys, shared_args=shared_args) # close the progress bar - - # due to jax 0.4.32 enable the async dispatch(https://github.com/jax-ml/jax/blob/main/CHANGELOG.md#jax-0432-september-11-2024), - # we temporarily do not close the progress bar - # if self.progress_bar: - # self._pbar.close() + if self.progress_bar: + self._pbar.close() # post-running for monitors if self.numpy_mon_after_run: