Skip to content

Commit

Permalink
Merge branch 'master' into master
Browse files Browse the repository at this point in the history
  • Loading branch information
chaoming0625 authored Dec 7, 2023
2 parents 7d70ef9 + a3263fd commit bff6154
Show file tree
Hide file tree
Showing 8 changed files with 46 additions and 113 deletions.
27 changes: 18 additions & 9 deletions brainpy/_src/checkpoints/serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,21 +19,19 @@
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union

import jax
import msgpack
import numpy as np
from jax import monitoring
from jax import process_index
from jax.experimental.multihost_utils import sync_global_devices

try:
from jax import monitoring
except (ModuleNotFoundError, ImportError):
monitoring = None
from jax.experimental.array_serialization import get_tensorstore_spec, GlobalAsyncCheckpointManager # noqa
except:
get_tensorstore_spec = GlobalAsyncCheckpointManager = None

try:
from jax.experimental.array_serialization import get_tensorstore_spec, GlobalAsyncCheckpointManager # noqa
except (ModuleNotFoundError, ImportError):
get_tensorstore_spec = None
GlobalAsyncCheckpointManager = None
import msgpack
except ModuleNotFoundError:
msgpack = None

from brainpy._src.math.ndarray import Array
from brainpy.errors import (AlreadyExistsError,
Expand Down Expand Up @@ -116,6 +114,12 @@ def _record_path(name):
_error_context.path.pop()


def check_msgpack():
if msgpack is None:
raise ModuleNotFoundError('\nbrainpy.checkpoints needs "msgpack" package. Please install msgpack via:\n'
'> pip install msgpack')


def current_path():
"""Current state_dict path during deserialization for error messages."""
return '/'.join(_error_context.path)
Expand Down Expand Up @@ -1126,6 +1130,7 @@ def save(
out: str
Filename of saved checkpoint.
"""
check_msgpack()
start_time = time.time()
# Make sure all saves are finished before the logic of checking and removing
# outdated checkpoints happens.
Expand Down Expand Up @@ -1257,6 +1262,7 @@ def save_pytree(
out: str
Filename of saved checkpoint.
"""
check_msgpack()
if verbose:
print(f'Saving checkpoint into {filename}')
start_time = time.time()
Expand Down Expand Up @@ -1344,6 +1350,7 @@ def multiprocess_save(
out: str
Filename of saved checkpoint.
"""
check_msgpack()
start_time = time.time()
# Make sure all saves are finished before the logic of checking and removing
# outdated checkpoints happens.
Expand Down Expand Up @@ -1488,6 +1495,7 @@ def load(
returned. This is to match the behavior of the case where a directory path
is specified but the directory has not yet been created.
"""
check_msgpack()
start_time = time.time()

ckpt_dir = os.fspath(ckpt_dir) # Pathlib -> str
Expand Down Expand Up @@ -1582,6 +1590,7 @@ def load_pytree(
returned. This is to match the behavior of the case where a directory path
is specified but the directory has not yet been created.
"""
check_msgpack()
start_time = time.time()
if not os.path.exists(filename):
raise ValueError(f'Checkpoint not found: {filename}')
Expand Down
3 changes: 1 addition & 2 deletions brainpy/_src/math/object_transform/controls.py
Original file line number Diff line number Diff line change
Expand Up @@ -678,8 +678,7 @@ def ifelse(
raise TypeError(msg)
cache_stack(tuple(branches), dyn_vars)
if current_transform_number():
return _if_else_return2(conditions, rets)

return rets[0]
branches = [_cond_transform_fun(fun, dyn_vars) for fun in branches]

code_scope = {'conditions': conditions, 'branches': branches}
Expand Down
25 changes: 25 additions & 0 deletions brainpy/_src/math/object_transform/tests/test_controls.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,6 +208,28 @@ def f2():

self.assertTrue(f2().size == 200)

def test_grad1(self):
def F2(x):
return bm.ifelse(conditions=(x >= 10,),
branches=[lambda x: x,
lambda x: x ** 2, ],
operands=x)

self.assertTrue(bm.grad(F2)(9.0) == 18.)
self.assertTrue(bm.grad(F2)(11.0) == 1.)


def test_grad2(self):
def F3(x):
return bm.ifelse(conditions=(x >= 10, x >= 0),
branches=[lambda x: x,
lambda x: x ** 2,
lambda x: x ** 4, ],
operands=x)

self.assertTrue(bm.grad(F3)(9.0) == 18.)
self.assertTrue(bm.grad(F3)(11.0) == 1.)


class TestWhile(unittest.TestCase):
def test1(self):
Expand Down Expand Up @@ -481,3 +503,6 @@ def body(a):
file.seek(0)
out6 = file.read().strip()
self.assertTrue(out5 == out6)



7 changes: 0 additions & 7 deletions brainpy/_src/tools/install.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,3 @@
For more detail installation instructions, please see https://brainpy.readthedocs.io/en/latest/quickstart/installation.html#dependency-2-jax
'''


brainpylib_install = '''
'''


7 changes: 0 additions & 7 deletions brainpy/_src/tools/package.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,20 +9,13 @@


__all__ = [
'import_numba',
'numba_jit',
'numba_seed',
'numba_range',
'SUPPORT_NUMBA',
]


def import_numba():
if numba is None:
raise ModuleNotFoundError('Numba is needed. Please install numba through:\n\n'
'> pip install numba')
return numba


SUPPORT_NUMBA = numba is not None

Expand Down
87 changes: 1 addition & 86 deletions docs/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -2,90 +2,11 @@ BrainPy documentation
=====================

`BrainPy`_ is a highly flexible and extensible framework targeting on the
general-purpose Brain Dynamics Programming (BDP). Among its key ingredients, BrainPy supports:
general-purpose Brain Dynamics Programming (BDP).

.. _BrainPy: https://github.com/brainpy/BrainPy


Features
^^^^^^^^^

.. grid::

.. grid-item::
:columns: 12 12 12 6

.. card:: OO Transformations
:class-card: sd-border-0
:shadow: none
:class-title: sd-fs-5

.. div:: sd-font-normal

BrainPy supports object-oriented transformations, including
JIT compilation, Autograd.

.. grid-item::
:columns: 12 12 12 6

.. card:: Numerical Integrators
:class-card: sd-border-0
:shadow: none
:class-title: sd-fs-5

.. div:: sd-font-normal

BrainPy provides various numerical integration methods for ODEs, SDEs, DDEs, FDEs, etc.

.. grid-item::
:columns: 12 12 12 6

.. card:: Model Building
:class-card: sd-border-0
:shadow: none
:class-title: sd-fs-5

.. div:: sd-font-normal

BrainPy provides a modular and composable programming interface for building dynamics.

.. grid-item::
:columns: 12 12 12 6

.. card:: Model Simulation
:class-card: sd-border-0
:shadow: none
:class-title: sd-fs-5

.. div:: sd-font-normal

BrainPy supports dynamics simulation for various brain objects with parallel supports.


.. grid-item::
:columns: 12 12 12 6

.. card:: Model Training
:class-card: sd-border-0
:shadow: none
:class-title: sd-fs-5

.. div:: sd-font-normal

BrainPy supports dynamics training with various machine learning algorithms, like FORCE learning, ridge regression, back-propagation, etc.

.. grid-item::
:columns: 12 12 12 6

.. card:: Model Analysis
:class-card: sd-border-0
:shadow: none
:class-title: sd-fs-5

.. div:: sd-font-normal

BrainPy supports dynamics analysis for low- and high-dimensional systems, including phase plane, bifurcation, linearization, and fixed/slow point analysis.

----

Installation
Expand All @@ -96,24 +17,18 @@ Installation

.. code-block:: bash
pip install -U "jax[cpu]"
pip install -U brainpy brainpylib # windows, linux, macos
.. tab-item:: GPU (CUDA-11x)

.. code-block:: bash
pip install -U "jax[cuda11_local]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
pip install -U brainpy brainpylib-cu11x # only on linux
.. tab-item:: GPU (CUDA-12x)

.. code-block:: bash
pip install --upgrade "jax[cuda12_local]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
pip install -U brainpy brainpylib-cu12x # only on linux
For more information about supported accelerators and platforms, and for other installation details, please see `installation <quickstart/installation.html>`_ section.
Expand Down
1 change: 0 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
numpy
jax
tqdm
msgpack
numba
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@
author_email='[email protected]',
packages=packages,
python_requires='>=3.8',
install_requires=['numpy>=1.15', 'jax>=0.4.13', 'tqdm', 'msgpack', 'numba'],
install_requires=['numpy>=1.15', 'jax>=0.4.13', 'tqdm', 'numba'],
url='https://github.com/brainpy/BrainPy',
project_urls={
"Bug Tracker": "https://github.com/brainpy/BrainPy/issues",
Expand Down

0 comments on commit bff6154

Please sign in to comment.