diff --git a/brainpy/_src/math/object_transform/controls.py b/brainpy/_src/math/object_transform/controls.py index 32358512..8861bf6e 100644 --- a/brainpy/_src/math/object_transform/controls.py +++ b/brainpy/_src/math/object_transform/controls.py @@ -894,8 +894,8 @@ def for_loop( dyn_vals, out_vals = transform(operands) for key in stack.keys(): stack[key]._value = dyn_vals[key] - if progress_bar: - bar.close() + # if progress_bar: + # bar.close() del dyn_vals, stack return out_vals @@ -915,7 +915,7 @@ def fun2scan(carry, x): dyn_vars[k]._value = dyn_vars_data[k] carry, results = body_fun(carry, x) if progress_bar: - jax.pure_callback(lambda *arg: bar.update(), ()) + jax.debug.callback(lambda *arg: bar.update(), ()) carry = jax.tree.map(_as_jax_array_, carry, is_leaf=lambda a: isinstance(a, Array)) return (dyn_vars.dict_data(), carry), results diff --git a/examples/dynamics_simulation/hh_model.py b/examples/dynamics_simulation/hh_model.py index 0343ae89..4b3f5f81 100644 --- a/examples/dynamics_simulation/hh_model.py +++ b/examples/dynamics_simulation/hh_model.py @@ -43,16 +43,16 @@ def __init__(self, size): self.KNa.add_elem() -# hh = HH(1) -# I, length = bp.inputs.section_input(values=[0, 5, 0], -# durations=[100, 500, 100], -# return_length=True) -# runner = bp.DSRunner( -# hh, -# monitors=['V', 'INa.p', 'INa.q', 'IK.p'], -# inputs=[hh.input, I, 'iter'], -# ) -# runner.run(length) -# -# bp.visualize.line_plot(runner.mon.ts, runner.mon.V, show=True) +hh = HH(1) +I, length = bp.inputs.section_input(values=[0, 5, 0], + durations=[100, 500, 100], + return_length=True) +runner = bp.DSRunner( + hh, + monitors=['V', 'INa.p', 'INa.q', 'IK.p'], + inputs=[hh.input, I, 'iter'], +) +runner.run(length) + +bp.visualize.line_plot(runner.mon.ts, runner.mon.V, show=True)