diff --git a/brainpy/_src/integrators/runner.py b/brainpy/_src/integrators/runner.py index dae638e1..35557b60 100644 --- a/brainpy/_src/integrators/runner.py +++ b/brainpy/_src/integrators/runner.py @@ -244,7 +244,7 @@ def _step_fun_integrator(self, static_args, dyn_args, t, i): # progress bar if self.progress_bar: - jax.pure_callback(lambda *args: self._pbar.update(), ()) + jax.debug.callback(lambda *args: self._pbar.update(), ()) # return of function monitors shared = dict(t=t + self.dt, dt=self.dt, i=i) diff --git a/brainpy/_src/math/object_transform/controls.py b/brainpy/_src/math/object_transform/controls.py index 126ca15c..32358512 100644 --- a/brainpy/_src/math/object_transform/controls.py +++ b/brainpy/_src/math/object_transform/controls.py @@ -726,7 +726,7 @@ def fun2scan(carry, x): dyn_vars[k]._value = carry[k] results = body_fun(*x, **unroll_kwargs) if progress_bar: - jax.pure_callback(lambda *arg: bar.update(), ()) + jax.debug.callback(lambda *args: bar.update(), ()) return dyn_vars.dict_data(), results if remat: diff --git a/brainpy/_src/math/random.py b/brainpy/_src/math/random.py index 74190cb2..3f3a8446 100644 --- a/brainpy/_src/math/random.py +++ b/brainpy/_src/math/random.py @@ -67,10 +67,9 @@ def _size2shape(size): def _check_shape(name, shape, *param_shapes): - shape = core.as_named_shape(shape) if param_shapes: - shape_ = lax.broadcast_shapes(shape.positional, *param_shapes) - if shape.positional != shape_: + shape_ = lax.broadcast_shapes(shape, *param_shapes) + if shape != shape_: msg = ("{} parameter shapes must be broadcast-compatible with shape " "argument, and the result of broadcasting the shapes must equal " "the shape argument, but got result {} for shape argument {}.") diff --git a/brainpy/_src/runners.py b/brainpy/_src/runners.py index 80609608..73cee508 100644 --- a/brainpy/_src/runners.py +++ b/brainpy/_src/runners.py @@ -631,7 +631,7 @@ def _step_func_predict(self, i, *x, shared_args=None): # finally if self.progress_bar: - jax.pure_callback(lambda: self._pbar.update(), ()) + jax.debug.callback(lambda *args: self._pbar.update(), ()) # share.clear_shargs() clear_input(self.target) diff --git a/brainpy/_src/train/offline.py b/brainpy/_src/train/offline.py index e801a29e..36ed3c2b 100644 --- a/brainpy/_src/train/offline.py +++ b/brainpy/_src/train/offline.py @@ -219,7 +219,7 @@ def _fun_train(self, targets = target_data[node.name] node.offline_fit(targets, fit_record) if self.progress_bar: - jax.pure_callback(lambda *args: self._pbar.update(), ()) + jax.debug.callback(lambda *args: self._pbar.update(), ()) def _step_func_monitor(self): res = dict() diff --git a/brainpy/_src/train/online.py b/brainpy/_src/train/online.py index 862db8df..d8e185c3 100644 --- a/brainpy/_src/train/online.py +++ b/brainpy/_src/train/online.py @@ -252,7 +252,7 @@ def _step_func_fit(self, i, xs: Sequence, ys: Dict, shared_args=None): # finally if self.progress_bar: - jax.pure_callback(lambda *arg: self._pbar.update(), ()) + jax.debug.callback(lambda *args: self._pbar.update(), ()) return out, monitors def _check_interface(self):