Skip to content

Commit

Permalink
[bug] Fix progress bar not being displayed and updated as expected (#683)
Browse files Browse the repository at this point in the history
* Update runners.py

* Update random.py

* Fix
  • Loading branch information
Routhleck authored Aug 3, 2024
1 parent 5e75f78 commit c464025
Show file tree
Hide file tree
Showing 6 changed files with 7 additions and 8 deletions.
2 changes: 1 addition & 1 deletion brainpy/_src/integrators/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -244,7 +244,7 @@ def _step_fun_integrator(self, static_args, dyn_args, t, i):

# progress bar
if self.progress_bar:
jax.pure_callback(lambda *args: self._pbar.update(), ())
jax.debug.callback(lambda *args: self._pbar.update(), ())

# return of function monitors
shared = dict(t=t + self.dt, dt=self.dt, i=i)
Expand Down
2 changes: 1 addition & 1 deletion brainpy/_src/math/object_transform/controls.py
Original file line number Diff line number Diff line change
Expand Up @@ -726,7 +726,7 @@ def fun2scan(carry, x):
dyn_vars[k]._value = carry[k]
results = body_fun(*x, **unroll_kwargs)
if progress_bar:
jax.pure_callback(lambda *arg: bar.update(), ())
jax.debug.callback(lambda *args: bar.update(), ())
return dyn_vars.dict_data(), results

if remat:
Expand Down
5 changes: 2 additions & 3 deletions brainpy/_src/math/random.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,10 +67,9 @@ def _size2shape(size):


def _check_shape(name, shape, *param_shapes):
shape = core.as_named_shape(shape)
if param_shapes:
shape_ = lax.broadcast_shapes(shape.positional, *param_shapes)
if shape.positional != shape_:
shape_ = lax.broadcast_shapes(shape, *param_shapes)
if shape != shape_:
msg = ("{} parameter shapes must be broadcast-compatible with shape "
"argument, and the result of broadcasting the shapes must equal "
"the shape argument, but got result {} for shape argument {}.")
Expand Down
2 changes: 1 addition & 1 deletion brainpy/_src/runners.py
Original file line number Diff line number Diff line change
Expand Up @@ -631,7 +631,7 @@ def _step_func_predict(self, i, *x, shared_args=None):

# finally
if self.progress_bar:
jax.pure_callback(lambda: self._pbar.update(), ())
jax.debug.callback(lambda *args: self._pbar.update(), ())
# share.clear_shargs()
clear_input(self.target)

Expand Down
2 changes: 1 addition & 1 deletion brainpy/_src/train/offline.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,7 +219,7 @@ def _fun_train(self,
targets = target_data[node.name]
node.offline_fit(targets, fit_record)
if self.progress_bar:
jax.pure_callback(lambda *args: self._pbar.update(), ())
jax.debug.callback(lambda *args: self._pbar.update(), ())

def _step_func_monitor(self):
res = dict()
Expand Down
2 changes: 1 addition & 1 deletion brainpy/_src/train/online.py
Original file line number Diff line number Diff line change
Expand Up @@ -252,7 +252,7 @@ def _step_func_fit(self, i, xs: Sequence, ys: Dict, shared_args=None):

# finally
if self.progress_bar:
jax.pure_callback(lambda *arg: self._pbar.update(), ())
jax.debug.callback(lambda *args: self._pbar.update(), ())
return out, monitors

def _check_interface(self):
Expand Down

0 comments on commit c464025

Please sign in to comment.