From 12977af4490e81ae0d6d6f6fce2725124c8dd717 Mon Sep 17 00:00:00 2001 From: He Sichao <1310722434@qq.com> Date: Thu, 24 Oct 2024 20:51:12 +0800 Subject: [PATCH 1/6] Add code for processing and computing with progress bar --- brainpy/_src/integrators/runner.py | 7 +++++-- brainpy/_src/runners.py | 7 +++++-- brainpy/_src/train/offline.py | 7 +++++-- brainpy/_src/train/online.py | 7 +++++-- 4 files changed, 20 insertions(+), 8 deletions(-) diff --git a/brainpy/_src/integrators/runner.py b/brainpy/_src/integrators/runner.py index 35557b60..bf1ecb1f 100644 --- a/brainpy/_src/integrators/runner.py +++ b/brainpy/_src/integrators/runner.py @@ -318,8 +318,11 @@ def run( hists = self._run_fun_integration(args, dyn_args, times, indices) if eval_time: running_time = time.time() - t0 - if self.progress_bar: - self._pbar.close() + + # due to jax 0.4.32 enable the async dispatch(https://github.com/jax-ml/jax/blob/main/CHANGELOG.md#jax-0432-september-11-2024), + # we temporarily do not close the progress bar + # if self.progress_bar: + # self._pbar.close() # post-running times += self.dt diff --git a/brainpy/_src/runners.py b/brainpy/_src/runners.py index 73cee508..d98bc58a 100644 --- a/brainpy/_src/runners.py +++ b/brainpy/_src/runners.py @@ -486,8 +486,11 @@ def predict( running_time = time.time() - t0 # close the progress bar - if self.progress_bar: - self._pbar.close() + + # due to jax 0.4.32 enable the async dispatch(https://github.com/jax-ml/jax/blob/main/CHANGELOG.md#jax-0432-september-11-2024), + # we temporarily do not close the progress bar + # if self.progress_bar: + # self._pbar.close() # post-running for monitors if self._memory_efficient: diff --git a/brainpy/_src/train/offline.py b/brainpy/_src/train/offline.py index 36ed3c2b..53ff7e56 100644 --- a/brainpy/_src/train/offline.py +++ b/brainpy/_src/train/offline.py @@ -193,8 +193,11 @@ def fit( del monitor_data # close the progress bar - if self.progress_bar: - self._pbar.close() + + # due to jax 0.4.32 enable the async dispatch(https://github.com/jax-ml/jax/blob/main/CHANGELOG.md#jax-0432-september-11-2024), + # we temporarily do not close the progress bar + # if self.progress_bar: + # self._pbar.close() # final things for node in self.train_nodes: diff --git a/brainpy/_src/train/online.py b/brainpy/_src/train/online.py index d8e185c3..60799cb3 100644 --- a/brainpy/_src/train/online.py +++ b/brainpy/_src/train/online.py @@ -191,8 +191,11 @@ def fit( outs, hists = self._fit(indices, xs=xs, ys=ys, shared_args=shared_args) # close the progress bar - if self.progress_bar: - self._pbar.close() + + # due to jax 0.4.32 enable the async dispatch(https://github.com/jax-ml/jax/blob/main/CHANGELOG.md#jax-0432-september-11-2024), + # we temporarily do not close the progress bar + # if self.progress_bar: + # self._pbar.close() # post-running for monitors if self.numpy_mon_after_run: From 38f222636c631ae89d4ebcc2a2370c14777a94b8 Mon Sep 17 00:00:00 2001 From: Chaoming Wang Date: Fri, 25 Oct 2024 16:07:49 +0800 Subject: [PATCH 2/6] fix --- .../_src/math/object_transform/controls.py | 6 ++--- examples/dynamics_simulation/hh_model.py | 24 +++++++++---------- 2 files changed, 15 insertions(+), 15 deletions(-) diff --git a/brainpy/_src/math/object_transform/controls.py b/brainpy/_src/math/object_transform/controls.py index 32358512..8861bf6e 100644 --- a/brainpy/_src/math/object_transform/controls.py +++ b/brainpy/_src/math/object_transform/controls.py @@ -894,8 +894,8 @@ def for_loop( dyn_vals, out_vals = transform(operands) for key in stack.keys(): stack[key]._value = dyn_vals[key] - if progress_bar: - bar.close() + # if progress_bar: + # bar.close() del dyn_vals, stack return out_vals @@ -915,7 +915,7 @@ def fun2scan(carry, x): dyn_vars[k]._value = dyn_vars_data[k] carry, results = body_fun(carry, x) if progress_bar: - jax.pure_callback(lambda *arg: bar.update(), ()) + jax.debug.callback(lambda *arg: bar.update(), ()) carry = jax.tree.map(_as_jax_array_, carry, is_leaf=lambda a: isinstance(a, Array)) return (dyn_vars.dict_data(), carry), results diff --git a/examples/dynamics_simulation/hh_model.py b/examples/dynamics_simulation/hh_model.py index 0343ae89..4b3f5f81 100644 --- a/examples/dynamics_simulation/hh_model.py +++ b/examples/dynamics_simulation/hh_model.py @@ -43,16 +43,16 @@ def __init__(self, size): self.KNa.add_elem() -# hh = HH(1) -# I, length = bp.inputs.section_input(values=[0, 5, 0], -# durations=[100, 500, 100], -# return_length=True) -# runner = bp.DSRunner( -# hh, -# monitors=['V', 'INa.p', 'INa.q', 'IK.p'], -# inputs=[hh.input, I, 'iter'], -# ) -# runner.run(length) -# -# bp.visualize.line_plot(runner.mon.ts, runner.mon.V, show=True) +hh = HH(1) +I, length = bp.inputs.section_input(values=[0, 5, 0], + durations=[100, 500, 100], + return_length=True) +runner = bp.DSRunner( + hh, + monitors=['V', 'INa.p', 'INa.q', 'IK.p'], + inputs=[hh.input, I, 'iter'], +) +runner.run(length) + +bp.visualize.line_plot(runner.mon.ts, runner.mon.V, show=True) From b8e04d8141095fdb5c0eb25b04660e8227b2f1f1 Mon Sep 17 00:00:00 2001 From: He Sichao <1310722434@qq.com> Date: Fri, 25 Oct 2024 16:11:06 +0800 Subject: [PATCH 3/6] Revert "Add code for processing and computing with progress bar" This reverts commit 12977af4490e81ae0d6d6f6fce2725124c8dd717. --- brainpy/_src/integrators/runner.py | 7 ++----- brainpy/_src/runners.py | 7 ++----- brainpy/_src/train/offline.py | 7 ++----- brainpy/_src/train/online.py | 7 ++----- 4 files changed, 8 insertions(+), 20 deletions(-) diff --git a/brainpy/_src/integrators/runner.py b/brainpy/_src/integrators/runner.py index bf1ecb1f..35557b60 100644 --- a/brainpy/_src/integrators/runner.py +++ b/brainpy/_src/integrators/runner.py @@ -318,11 +318,8 @@ def run( hists = self._run_fun_integration(args, dyn_args, times, indices) if eval_time: running_time = time.time() - t0 - - # due to jax 0.4.32 enable the async dispatch(https://github.com/jax-ml/jax/blob/main/CHANGELOG.md#jax-0432-september-11-2024), - # we temporarily do not close the progress bar - # if self.progress_bar: - # self._pbar.close() + if self.progress_bar: + self._pbar.close() # post-running times += self.dt diff --git a/brainpy/_src/runners.py b/brainpy/_src/runners.py index d98bc58a..73cee508 100644 --- a/brainpy/_src/runners.py +++ b/brainpy/_src/runners.py @@ -486,11 +486,8 @@ def predict( running_time = time.time() - t0 # close the progress bar - - # due to jax 0.4.32 enable the async dispatch(https://github.com/jax-ml/jax/blob/main/CHANGELOG.md#jax-0432-september-11-2024), - # we temporarily do not close the progress bar - # if self.progress_bar: - # self._pbar.close() + if self.progress_bar: + self._pbar.close() # post-running for monitors if self._memory_efficient: diff --git a/brainpy/_src/train/offline.py b/brainpy/_src/train/offline.py index 53ff7e56..36ed3c2b 100644 --- a/brainpy/_src/train/offline.py +++ b/brainpy/_src/train/offline.py @@ -193,11 +193,8 @@ def fit( del monitor_data # close the progress bar - - # due to jax 0.4.32 enable the async dispatch(https://github.com/jax-ml/jax/blob/main/CHANGELOG.md#jax-0432-september-11-2024), - # we temporarily do not close the progress bar - # if self.progress_bar: - # self._pbar.close() + if self.progress_bar: + self._pbar.close() # final things for node in self.train_nodes: diff --git a/brainpy/_src/train/online.py b/brainpy/_src/train/online.py index 60799cb3..d8e185c3 100644 --- a/brainpy/_src/train/online.py +++ b/brainpy/_src/train/online.py @@ -191,11 +191,8 @@ def fit( outs, hists = self._fit(indices, xs=xs, ys=ys, shared_args=shared_args) # close the progress bar - - # due to jax 0.4.32 enable the async dispatch(https://github.com/jax-ml/jax/blob/main/CHANGELOG.md#jax-0432-september-11-2024), - # we temporarily do not close the progress bar - # if self.progress_bar: - # self._pbar.close() + if self.progress_bar: + self._pbar.close() # post-running for monitors if self.numpy_mon_after_run: From 179269912c014775eeda026c06b389a0637e1642 Mon Sep 17 00:00:00 2001 From: He Sichao <1310722434@qq.com> Date: Fri, 25 Oct 2024 16:16:44 +0800 Subject: [PATCH 4/6] Add JAX configuration for disabling async dispatch --- brainpy/__init__.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/brainpy/__init__.py b/brainpy/__init__.py index 837efaf1..e69837fd 100644 --- a/brainpy/__init__.py +++ b/brainpy/__init__.py @@ -153,3 +153,8 @@ del deprecation_getattr2 +# jax config +import os +os.environ['XLA_FLAGS'] = '--xla_cpu_use_thunk_runtime=false' +import jax +jax.config.update('jax_cpu_enable_async_dispatch', False) From 32e485780f951ea6ccd51248dee69e75a96cd906 Mon Sep 17 00:00:00 2001 From: He Sichao <1310722434@qq.com> Date: Fri, 25 Oct 2024 16:37:41 +0800 Subject: [PATCH 5/6] Update controls.py --- brainpy/_src/math/object_transform/controls.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/brainpy/_src/math/object_transform/controls.py b/brainpy/_src/math/object_transform/controls.py index 8861bf6e..ff302333 100644 --- a/brainpy/_src/math/object_transform/controls.py +++ b/brainpy/_src/math/object_transform/controls.py @@ -894,8 +894,8 @@ def for_loop( dyn_vals, out_vals = transform(operands) for key in stack.keys(): stack[key]._value = dyn_vals[key] - # if progress_bar: - # bar.close() + if progress_bar: + bar.close() del dyn_vals, stack return out_vals From 5bb3fb91ec9f5f8ea18a574d11517fb0d92b6cf1 Mon Sep 17 00:00:00 2001 From: He Sichao <1310722434@qq.com> Date: Fri, 25 Oct 2024 18:27:17 +0800 Subject: [PATCH 6/6] Update brainpy-changelog.md --- brainpy-changelog.md | 74 ++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 74 insertions(+) diff --git a/brainpy-changelog.md b/brainpy-changelog.md index c949b701..77e112b5 100644 --- a/brainpy-changelog.md +++ b/brainpy-changelog.md @@ -2,6 +2,80 @@ ## brainpy>2.3.x +### Version 2.6.1 +#### Breaking Changes +- Fixing compatibility issues between `numpy` and `jax` + +#### What's Changed +* [doc] Add Chinese version of `operator_custom_with_cupy.ipynb` and Rename it's title by @Routhleck in https://github.com/brainpy/BrainPy/pull/659 +* Fix "amsgrad" is used before being defined when initializing the AdamW optimizer by @CloudyDory in https://github.com/brainpy/BrainPy/pull/660 +* fix issue #661 by @chaoming0625 in https://github.com/brainpy/BrainPy/pull/662 +* fix flax RNN interoperation, fix #663 by @chaoming0625 in https://github.com/brainpy/BrainPy/pull/665 +* [fix] Replace jax.experimental.host_callback with jax.pure_callback by @Routhleck in https://github.com/brainpy/BrainPy/pull/670 +* [math] Update `CustomOpByNumba` to support JAX version >= 0.4.24 by @Routhleck in https://github.com/brainpy/BrainPy/pull/669 +* [math] Fix `CustomOpByNumba` on `multiple_results=True` by @Routhleck in https://github.com/brainpy/BrainPy/pull/671 +* [math] Implementing event-driven sparse matrix @ matrix operators by @Routhleck in https://github.com/brainpy/BrainPy/pull/613 +* [math] Add getting JIT connect matrix method for `brainpy.dnn.linear` by @Routhleck in https://github.com/brainpy/BrainPy/pull/672 +* [math] Add get JIT weight matrix methods(Uniform & Normal) for `brainpy.dnn.linear` by @Routhleck in https://github.com/brainpy/BrainPy/pull/673 +* support `Integrator.to_math_expr()` by @chaoming0625 in https://github.com/brainpy/BrainPy/pull/674 +* [bug] Replace `collections.Iterable` with `collections.abc.Iterable` by @Routhleck in https://github.com/brainpy/BrainPy/pull/677 +* Fix surrogate gradient function and numpy 2.0 compatibility by @chaoming0625 in https://github.com/brainpy/BrainPy/pull/679 +* :arrow_up: Bump docker/build-push-action from 5 to 6 by @dependabot in https://github.com/brainpy/BrainPy/pull/678 +* fix the incorrect verbose of `clear_name_cache()` by @chaoming0625 in https://github.com/brainpy/BrainPy/pull/681 +* [bug] Fix prograss bar is not displayed and updated as expected by @Routhleck in https://github.com/brainpy/BrainPy/pull/683 +* Fix autograd by @chaoming0625 in https://github.com/brainpy/BrainPy/pull/687 + + +**Full Changelog**: https://github.com/brainpy/BrainPy/compare/V2.6.0...V2.6.1 + +### Version 2.6.0 + +#### New Features + +This release provides several new features, including: + +- ``MLIR`` registered operator customization interface in ``brainpy.math.XLACustomOp``. +- Operator customization with CuPy JIT interface. +- Bug fixes. + + + +#### What's Changed +* [doc] Fix the wrong path of more examples of `operator customized with taichi.ipynb` by @Routhleck in https://github.com/brainpy/BrainPy/pull/612 +* [docs] Add colab link for documentation notebooks by @Routhleck in https://github.com/brainpy/BrainPy/pull/614 +* Update requirements-doc.txt to fix doc building temporally by @chaoming0625 in https://github.com/brainpy/BrainPy/pull/617 +* [math] Rebase operator customization using MLIR registration interface by @chaoming0625 in https://github.com/brainpy/BrainPy/pull/618 +* [docs] Add kaggle link for documentation notebooks by @Routhleck in https://github.com/brainpy/BrainPy/pull/619 +* update requirements by @chaoming0625 in https://github.com/brainpy/BrainPy/pull/620 +* require `brainpylib>=0.2.6` for `jax>=0.4.24` by @chaoming0625 in https://github.com/brainpy/BrainPy/pull/622 +* [tools] add `brainpy.tools.compose` and `brainpy.tools.pipe` by @chaoming0625 in https://github.com/brainpy/BrainPy/pull/624 +* doc hierarchy update by @chaoming0625 in https://github.com/brainpy/BrainPy/pull/630 +* Standardizing and generalizing object-oriented transformations by @chaoming0625 in https://github.com/brainpy/BrainPy/pull/628 +* fix #626 by @chaoming0625 in https://github.com/brainpy/BrainPy/pull/631 +* Fix delayvar not correct in concat mode by @CloudyDory in https://github.com/brainpy/BrainPy/pull/632 +* [dependency] remove hard dependency of `taichi` and `numba` by @Routhleck in https://github.com/brainpy/BrainPy/pull/635 +* `clear_buffer_memory()` support clearing `array`, `compilation`, and `names` by @chaoming0625 in https://github.com/brainpy/BrainPy/pull/639 +* add `brainpy.math.surrogate..Surrogate` by @chaoming0625 in https://github.com/brainpy/BrainPy/pull/638 +* Enable brainpy object as pytree so that it can be applied with ``jax.jit`` etc. directly by @chaoming0625 in https://github.com/brainpy/BrainPy/pull/625 +* Fix ci by @chaoming0625 in https://github.com/brainpy/BrainPy/pull/640 +* Clean taichi AOT caches by @chaoming0625 in https://github.com/brainpy/BrainPy/pull/643 +* [ci] Fix windows pytest fatal exception by @Routhleck in https://github.com/brainpy/BrainPy/pull/644 +* [math] Support more than 8 parameters of taichi gpu custom operator definition by @Routhleck in https://github.com/brainpy/BrainPy/pull/642 +* Doc for ``brainpylib>=0.3.0`` by @chaoming0625 in https://github.com/brainpy/BrainPy/pull/645 +* Find back updates by @chaoming0625 in https://github.com/brainpy/BrainPy/pull/646 +* Update installation instruction by @chaoming0625 in https://github.com/brainpy/BrainPy/pull/651 +* Fix delay bug by @chaoming0625 in https://github.com/brainpy/BrainPy/pull/650 +* update doc by @chaoming0625 in https://github.com/brainpy/BrainPy/pull/652 +* [math] Add new customize operators with `cupy` by @Routhleck in https://github.com/brainpy/BrainPy/pull/653 +* [math] Fix taichi custom operator on gpu backend by @Routhleck in https://github.com/brainpy/BrainPy/pull/655 +* update cupy operator custom doc by @chaoming0625 in https://github.com/brainpy/BrainPy/pull/656 +* version 2.6.0 by @chaoming0625 in https://github.com/brainpy/BrainPy/pull/657 +* Upgrade CI by @chaoming0625 in https://github.com/brainpy/BrainPy/pull/658 + +## New Contributors +* @CloudyDory made their first contribution in https://github.com/brainpy/BrainPy/pull/632 + +**Full Changelog**: https://github.com/brainpy/BrainPy/compare/V2.5.0...V2.6.0 ### Version 2.5.0