From 2a5adea09aeb1b71e60f2ef31a2f7091deefef92 Mon Sep 17 00:00:00 2001 From: Sichao He <1310722434@qq.com> Date: Fri, 25 Oct 2024 21:03:56 +0800 Subject: [PATCH] [bug] Fixing compatibility issues with `jax` (#691) * Add code for processing and computing with progress bar * fix * Revert "Add code for processing and computing with progress bar" This reverts commit 12977af4490e81ae0d6d6f6fce2725124c8dd717. * Add JAX configuration for disabling async dispatch * Update controls.py * Update brainpy-changelog.md --------- Co-authored-by: Chaoming Wang --- brainpy-changelog.md | 74 +++++++++++++++++++ brainpy/__init__.py | 5 ++ .../_src/math/object_transform/controls.py | 2 +- examples/dynamics_simulation/hh_model.py | 24 +++--- 4 files changed, 92 insertions(+), 13 deletions(-) diff --git a/brainpy-changelog.md b/brainpy-changelog.md index c949b7010..77e112b5a 100644 --- a/brainpy-changelog.md +++ b/brainpy-changelog.md @@ -2,6 +2,80 @@ ## brainpy>2.3.x +### Version 2.6.1 +#### Breaking Changes +- Fixing compatibility issues between `numpy` and `jax` + +#### What's Changed +* [doc] Add Chinese version of `operator_custom_with_cupy.ipynb` and Rename it's title by @Routhleck in https://github.com/brainpy/BrainPy/pull/659 +* Fix "amsgrad" is used before being defined when initializing the AdamW optimizer by @CloudyDory in https://github.com/brainpy/BrainPy/pull/660 +* fix issue #661 by @chaoming0625 in https://github.com/brainpy/BrainPy/pull/662 +* fix flax RNN interoperation, fix #663 by @chaoming0625 in https://github.com/brainpy/BrainPy/pull/665 +* [fix] Replace jax.experimental.host_callback with jax.pure_callback by @Routhleck in https://github.com/brainpy/BrainPy/pull/670 +* [math] Update `CustomOpByNumba` to support JAX version >= 0.4.24 by @Routhleck in https://github.com/brainpy/BrainPy/pull/669 +* [math] Fix `CustomOpByNumba` on `multiple_results=True` by @Routhleck in https://github.com/brainpy/BrainPy/pull/671 +* [math] Implementing event-driven sparse matrix @ matrix operators by @Routhleck in https://github.com/brainpy/BrainPy/pull/613 +* [math] Add getting JIT connect matrix method for `brainpy.dnn.linear` by @Routhleck in https://github.com/brainpy/BrainPy/pull/672 +* [math] Add get JIT weight matrix methods(Uniform & Normal) for `brainpy.dnn.linear` by @Routhleck in https://github.com/brainpy/BrainPy/pull/673 +* support `Integrator.to_math_expr()` by @chaoming0625 in https://github.com/brainpy/BrainPy/pull/674 +* [bug] Replace `collections.Iterable` with `collections.abc.Iterable` by @Routhleck in https://github.com/brainpy/BrainPy/pull/677 +* Fix surrogate gradient function and numpy 2.0 compatibility by @chaoming0625 in https://github.com/brainpy/BrainPy/pull/679 +* :arrow_up: Bump docker/build-push-action from 5 to 6 by @dependabot in https://github.com/brainpy/BrainPy/pull/678 +* fix the incorrect verbose of `clear_name_cache()` by @chaoming0625 in https://github.com/brainpy/BrainPy/pull/681 +* [bug] Fix prograss bar is not displayed and updated as expected by @Routhleck in https://github.com/brainpy/BrainPy/pull/683 +* Fix autograd by @chaoming0625 in https://github.com/brainpy/BrainPy/pull/687 + + +**Full Changelog**: https://github.com/brainpy/BrainPy/compare/V2.6.0...V2.6.1 + +### Version 2.6.0 + +#### New Features + +This release provides several new features, including: + +- ``MLIR`` registered operator customization interface in ``brainpy.math.XLACustomOp``. +- Operator customization with CuPy JIT interface. +- Bug fixes. + + + +#### What's Changed +* [doc] Fix the wrong path of more examples of `operator customized with taichi.ipynb` by @Routhleck in https://github.com/brainpy/BrainPy/pull/612 +* [docs] Add colab link for documentation notebooks by @Routhleck in https://github.com/brainpy/BrainPy/pull/614 +* Update requirements-doc.txt to fix doc building temporally by @chaoming0625 in https://github.com/brainpy/BrainPy/pull/617 +* [math] Rebase operator customization using MLIR registration interface by @chaoming0625 in https://github.com/brainpy/BrainPy/pull/618 +* [docs] Add kaggle link for documentation notebooks by @Routhleck in https://github.com/brainpy/BrainPy/pull/619 +* update requirements by @chaoming0625 in https://github.com/brainpy/BrainPy/pull/620 +* require `brainpylib>=0.2.6` for `jax>=0.4.24` by @chaoming0625 in https://github.com/brainpy/BrainPy/pull/622 +* [tools] add `brainpy.tools.compose` and `brainpy.tools.pipe` by @chaoming0625 in https://github.com/brainpy/BrainPy/pull/624 +* doc hierarchy update by @chaoming0625 in https://github.com/brainpy/BrainPy/pull/630 +* Standardizing and generalizing object-oriented transformations by @chaoming0625 in https://github.com/brainpy/BrainPy/pull/628 +* fix #626 by @chaoming0625 in https://github.com/brainpy/BrainPy/pull/631 +* Fix delayvar not correct in concat mode by @CloudyDory in https://github.com/brainpy/BrainPy/pull/632 +* [dependency] remove hard dependency of `taichi` and `numba` by @Routhleck in https://github.com/brainpy/BrainPy/pull/635 +* `clear_buffer_memory()` support clearing `array`, `compilation`, and `names` by @chaoming0625 in https://github.com/brainpy/BrainPy/pull/639 +* add `brainpy.math.surrogate..Surrogate` by @chaoming0625 in https://github.com/brainpy/BrainPy/pull/638 +* Enable brainpy object as pytree so that it can be applied with ``jax.jit`` etc. directly by @chaoming0625 in https://github.com/brainpy/BrainPy/pull/625 +* Fix ci by @chaoming0625 in https://github.com/brainpy/BrainPy/pull/640 +* Clean taichi AOT caches by @chaoming0625 in https://github.com/brainpy/BrainPy/pull/643 +* [ci] Fix windows pytest fatal exception by @Routhleck in https://github.com/brainpy/BrainPy/pull/644 +* [math] Support more than 8 parameters of taichi gpu custom operator definition by @Routhleck in https://github.com/brainpy/BrainPy/pull/642 +* Doc for ``brainpylib>=0.3.0`` by @chaoming0625 in https://github.com/brainpy/BrainPy/pull/645 +* Find back updates by @chaoming0625 in https://github.com/brainpy/BrainPy/pull/646 +* Update installation instruction by @chaoming0625 in https://github.com/brainpy/BrainPy/pull/651 +* Fix delay bug by @chaoming0625 in https://github.com/brainpy/BrainPy/pull/650 +* update doc by @chaoming0625 in https://github.com/brainpy/BrainPy/pull/652 +* [math] Add new customize operators with `cupy` by @Routhleck in https://github.com/brainpy/BrainPy/pull/653 +* [math] Fix taichi custom operator on gpu backend by @Routhleck in https://github.com/brainpy/BrainPy/pull/655 +* update cupy operator custom doc by @chaoming0625 in https://github.com/brainpy/BrainPy/pull/656 +* version 2.6.0 by @chaoming0625 in https://github.com/brainpy/BrainPy/pull/657 +* Upgrade CI by @chaoming0625 in https://github.com/brainpy/BrainPy/pull/658 + +## New Contributors +* @CloudyDory made their first contribution in https://github.com/brainpy/BrainPy/pull/632 + +**Full Changelog**: https://github.com/brainpy/BrainPy/compare/V2.5.0...V2.6.0 ### Version 2.5.0 diff --git a/brainpy/__init__.py b/brainpy/__init__.py index 837efaf1d..e69837fda 100644 --- a/brainpy/__init__.py +++ b/brainpy/__init__.py @@ -153,3 +153,8 @@ del deprecation_getattr2 +# jax config +import os +os.environ['XLA_FLAGS'] = '--xla_cpu_use_thunk_runtime=false' +import jax +jax.config.update('jax_cpu_enable_async_dispatch', False) diff --git a/brainpy/_src/math/object_transform/controls.py b/brainpy/_src/math/object_transform/controls.py index 323585121..ff3023339 100644 --- a/brainpy/_src/math/object_transform/controls.py +++ b/brainpy/_src/math/object_transform/controls.py @@ -915,7 +915,7 @@ def fun2scan(carry, x): dyn_vars[k]._value = dyn_vars_data[k] carry, results = body_fun(carry, x) if progress_bar: - jax.pure_callback(lambda *arg: bar.update(), ()) + jax.debug.callback(lambda *arg: bar.update(), ()) carry = jax.tree.map(_as_jax_array_, carry, is_leaf=lambda a: isinstance(a, Array)) return (dyn_vars.dict_data(), carry), results diff --git a/examples/dynamics_simulation/hh_model.py b/examples/dynamics_simulation/hh_model.py index 0343ae89c..4b3f5f811 100644 --- a/examples/dynamics_simulation/hh_model.py +++ b/examples/dynamics_simulation/hh_model.py @@ -43,16 +43,16 @@ def __init__(self, size): self.KNa.add_elem() -# hh = HH(1) -# I, length = bp.inputs.section_input(values=[0, 5, 0], -# durations=[100, 500, 100], -# return_length=True) -# runner = bp.DSRunner( -# hh, -# monitors=['V', 'INa.p', 'INa.q', 'IK.p'], -# inputs=[hh.input, I, 'iter'], -# ) -# runner.run(length) -# -# bp.visualize.line_plot(runner.mon.ts, runner.mon.V, show=True) +hh = HH(1) +I, length = bp.inputs.section_input(values=[0, 5, 0], + durations=[100, 500, 100], + return_length=True) +runner = bp.DSRunner( + hh, + monitors=['V', 'INa.p', 'INa.q', 'IK.p'], + inputs=[hh.input, I, 'iter'], +) +runner.run(length) + +bp.visualize.line_plot(runner.mon.ts, runner.mon.V, show=True)