diff --git a/brainpy/_src/dynsys.py b/brainpy/_src/dynsys.py
index a7e7d86d9..770d4bf30 100644
--- a/brainpy/_src/dynsys.py
+++ b/brainpy/_src/dynsys.py
@@ -10,7 +10,7 @@

 from brainpy import tools, math as bm
 from brainpy._src.initialize import parameter, variable_
-from brainpy._src.mixin import SupportAutoDelay, Container, ReceiveInputProj, DelayRegister, global_delay_data
+from brainpy._src.mixin import SupportAutoDelay, Container, SupportInputProj, DelayRegister, global_delay_data
 from brainpy.errors import NoImplementationError, UnsupportedError
 from brainpy.types import ArrayType, Shape
 from brainpy._src.deprecations import _update_deprecate_msg
@@ -70,7 +70,7 @@ def update(self, x):
     return func


-class DynamicalSystem(bm.BrainPyObject, DelayRegister, ReceiveInputProj):
+class DynamicalSystem(bm.BrainPyObject, DelayRegister, SupportInputProj):
   """Base Dynamical System class.

   .. note::
diff --git a/brainpy/_src/math/object_transform/base.py b/brainpy/_src/math/object_transform/base.py
index daa8a55bb..061bfe472 100644
--- a/brainpy/_src/math/object_transform/base.py
+++ b/brainpy/_src/math/object_transform/base.py
@@ -141,6 +141,8 @@ def fun(self):
       #       that has been created.
       a = self.tracing_variable('a', bm.zeros, (10,))

+    .. versionadded:: 2.4.5
+
     Args:
       name: str. The variable name.
       init: callable, Array. The data to be initialized as a ``Variable``.
diff --git a/brainpy/_src/math/random.py b/brainpy/_src/math/random.py
index eb04c5d2e..e989908a0 100644
--- a/brainpy/_src/math/random.py
+++ b/brainpy/_src/math/random.py
@@ -22,7 +22,7 @@

 __all__ = [
   'RandomState', 'Generator', 'DEFAULT',
-  'seed', 'default_rng', 'split_key',
+  'seed', 'default_rng', 'split_key', 'split_keys',

   # numpy compatibility
   'rand', 'randint', 'random_integers', 'randn', 'random',
@@ -1258,6 +1258,8 @@ def split_keys(n):
   internally by `pmap` and `vmap` to ensure that random numbers
   are different in parallel threads.

+  .. versionadded:: 2.4.5
+
   Parameters
   ----------
   n : int
@@ -1267,6 +1269,15 @@


 def clone_rng(seed_or_key=None, clone: bool = True) -> RandomState:
+  """Clone the random state according to the given setting.
+
+  Args:
+    seed_or_key: The seed (an integer) or the random key.
+    clone: Bool. Whether to clone the default random state.
+
+  Returns:
+    The random state.
+  """
   if seed_or_key is None:
     return DEFAULT.clone() if clone else DEFAULT
   else:
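The newly exported `split_keys` and the documented `clone_rng` above are small helpers around the default random state. A minimal usage sketch, assuming the public paths `brainpy.math.random.seed` / `default_rng` / `split_keys` implied by the `__all__` change (illustrative, not authoritative):

.. code-block:: python

    import brainpy.math as bm

    bm.random.seed(42)                 # make the default stream reproducible

    # Derive several independent keys from the default random state,
    # e.g. one key per parallel thread under `vmap`/`pmap`.
    keys = bm.random.split_keys(4)

    # Draw numbers from an explicitly seeded state instead of the global one.
    rng = bm.random.default_rng(123)
    x = rng.random((3,))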
diff --git a/brainpy/_src/mixin.py b/brainpy/_src/mixin.py
index d533e3df4..23cd703bf 100644
--- a/brainpy/_src/mixin.py
+++ b/brainpy/_src/mixin.py
@@ -1,6 +1,5 @@
 import numbers
 import sys
-import warnings
 from dataclasses import dataclass
 from typing import Union, Dict, Callable, Sequence, Optional, TypeVar, Any
 from typing import (_SpecialForm, _type_check, _remove_dups_flatten)
@@ -28,12 +27,15 @@
   'ParamDesc',
   'ParamDescInit',
   'AlignPost',
-  'SupportAutoDelay',
   'Container',
   'TreeNode',
   'BindCondData',
   'JointType',
   'SupportSTDP',
+  'SupportAutoDelay',
+  'SupportInputProj',
+  'SupportOnline',
+  'SupportOffline',
 ]

 global_delay_data = dict()
@@ -47,59 +49,6 @@ class MixIn(object):
   pass


-class ReceiveInputProj(MixIn):
-  """The :py:class:`~.MixIn` that receives the input projections.
-
-  Note that the subclass should define a ``cur_inputs`` attribute.
-
-  """
-  cur_inputs: bm.node_dict
-
-  def add_inp_fun(self, key: Any, fun: Callable):
-    """Add an input function.
-
-    Args:
-      key: The dict key.
-      fun: The function to generate inputs.
-    """
-    if not callable(fun):
-      raise TypeError('Must be a function.')
-    if key in self.cur_inputs:
-      raise ValueError(f'Key "{key}" has been defined and used.')
-    self.cur_inputs[key] = fun
-
-  def get_inp_fun(self, key):
-    """Get the input function.
-
-    Args:
-      key: The key.
-
-    Returns:
-      The input function which generates currents.
-    """
-    return self.cur_inputs.get(key)
-
-  def sum_inputs(self, *args, init=0., label=None, **kwargs):
-    """Summarize all inputs by the defined input functions ``.cur_inputs``.
-
-    Args:
-      *args: The arguments for input functions.
-      init: The initial input data.
-      **kwargs: The arguments for input functions.
-
-    Returns:
-      The total currents.
-    """
-    if label is None:
-      for key, out in self.cur_inputs.items():
-        init = init + out(*args, **kwargs)
-    else:
-      for key, out in self.cur_inputs.items():
-        if key.startswith(label + ' // '):
-          init = init + out(*args, **kwargs)
-    return init
-
-
 class ParamDesc(MixIn):
   """:py:class:`~.MixIn` indicates the function for describing initialization parameters.
@@ -208,13 +157,6 @@ def get_data(self):
     return init


-class SupportAutoDelay(MixIn):
-  """``MixIn`` to support the automatic delay in synaptic projection :py:class:`~.SynProj`."""
-
-  def return_info(self) -> Union[bm.Variable, ReturnInfo]:
-    raise NotImplementedError('Must implement the "return_info()" function.')
-
-
 class Container(MixIn):
   """Container :py:class:`~.MixIn` which wrap a group of objects.
   """
@@ -550,8 +492,71 @@ def get_delay_var(self, name):
     return global_delay_data[name]


+class SupportInputProj(MixIn):
+  """The :py:class:`~.MixIn` that receives the input projections.
+
+  Note that the subclass should define a ``cur_inputs`` attribute.
+
+  """
+  cur_inputs: bm.node_dict
+
+  def add_inp_fun(self, key: Any, fun: Callable):
+    """Add an input function.
+
+    Args:
+      key: The dict key.
+      fun: The function to generate inputs.
+    """
+    if not callable(fun):
+      raise TypeError('Must be a function.')
+    if key in self.cur_inputs:
+      raise ValueError(f'Key "{key}" has been defined and used.')
+    self.cur_inputs[key] = fun
+
+  def get_inp_fun(self, key):
+    """Get the input function.
+
+    Args:
+      key: The key.
+
+    Returns:
+      The input function which generates currents.
+    """
+    return self.cur_inputs.get(key)
+
+  def sum_inputs(self, *args, init=0., label=None, **kwargs):
+    """Summarize all inputs by the defined input functions ``.cur_inputs``.
+
+    Args:
+      *args: The arguments for input functions.
+      init: The initial input data.
+      **kwargs: The arguments for input functions.
+
+    Returns:
+      The total currents.
+    """
+    if label is None:
+      for key, out in self.cur_inputs.items():
+        init = init + out(*args, **kwargs)
+    else:
+      for key, out in self.cur_inputs.items():
+        if key.startswith(label + ' // '):
+          init = init + out(*args, **kwargs)
+    return init
+
+
+class SupportAutoDelay(MixIn):
+  """``MixIn`` to support the automatic delay in synaptic projection :py:class:`~.SynProj`."""
+
+  def return_info(self) -> Union[bm.Variable, ReturnInfo]:
+    raise NotImplementedError('Must implement the "return_info()" function.')
+
+
 class SupportOnline(MixIn):
-  """:py:class:`~.MixIn` to support the online training methods."""
+  """:py:class:`~.MixIn` to support the online training methods.
+
+  .. versionadded:: 2.4.5
+  """

   online_fit_by: Optional  # methods for online fitting

@@ -563,7 +568,10 @@ def online_fit(self, target: ArrayType, fit_record: Dict[str, ArrayType]):
     raise NotImplementedError


 class SupportOffline(MixIn):
-  """:py:class:`~.MixIn` to support the offline training methods."""
+  """:py:class:`~.MixIn` to support the offline training methods.
+
+  .. versionadded:: 2.4.5
+  """

   offline_fit_by: Optional  # methods for offline fitting

@@ -573,6 +581,8 @@ def offline_fit(self, target: ArrayType, fit_record: Dict[str, ArrayType]):
     raise NotImplementedError


 class BindCondData(MixIn):
   """Bind temporary conductance data.
+
+
   """
   _conductance: Optional
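`ReceiveInputProj` is renamed to `SupportInputProj` with an unchanged body, so existing subclasses only need the new name. A minimal sketch of the `cur_inputs` / `add_inp_fun` / `sum_inputs` contract, using a hypothetical `Pool` class and assuming `bm.node_dict()` matches the attribute annotation above:

.. code-block:: python

    import brainpy.math as bm
    from brainpy.mixin import SupportInputProj

    class Pool(SupportInputProj):
      """Hypothetical receiver that aggregates projected currents."""

      def __init__(self):
        self.cur_inputs = bm.node_dict()  # required by the mixin

      def update(self, t):
        # Sum every registered input function evaluated at time ``t``.
        return self.sum_inputs(t)

    pool = Pool()
    pool.add_inp_fun('noise', lambda t: 0.1)
    pool.add_inp_fun('drive', lambda t: 2.0)
    total = pool.update(0.)  # 0.1 + 2.0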
diff --git a/brainpy/mixin.py b/brainpy/mixin.py
index b2c03793f..ab3c3cd37 100644
--- a/brainpy/mixin.py
+++ b/brainpy/mixin.py
@@ -1,9 +1,9 @@
 from brainpy._src.mixin import (
   MixIn as MixIn,
-  ReceiveInputProj as ReceiveInputProj,
+  SupportInputProj as SupportInputProj,
   AlignPost as AlignPost,
-  AutoDelaySupp as AutoDelaySupp,
+  SupportAutoDelay as SupportAutoDelay,
   ParamDesc as ParamDesc,
   ParamDescInit as ParamDescInit,
   BindCondData as BindCondData,
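For downstream code, the practical effect of the re-export fix above is that the mixins must be imported under the new public names; the old aliases `ReceiveInputProj` and `AutoDelaySupp` are no longer re-exported from this module:

.. code-block:: python

    # New public names after this change.
    from brainpy.mixin import SupportInputProj, SupportAutoDelay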
diff --git a/docs/quickstart/installation.rst b/docs/quickstart/installation.rst
index a3f0ce495..e0d5138aa 100644
--- a/docs/quickstart/installation.rst
+++ b/docs/quickstart/installation.rst
@@ -78,8 +78,8 @@ BrainPy relies on `JAX`_. JAX is a high-performance JIT compiler which enables
 users to run Python code on CPU, GPU, and TPU devices. Core functionalities of
 BrainPy (>=2.0.0) have been migrated to the JAX backend.

-Linux & MacOS
-^^^^^^^^^^^^^
+Linux
+^^^^^

 Currently, JAX supports **Linux** (Ubuntu 16.04 or later) and **macOS** (10.12 or later)
 platforms. The provided binary releases of `jax` and `jaxlib` for Linux and macOS
@@ -108,6 +108,7 @@ If you want to install JAX with both CPU and NVidia GPU support, you must first
     # Note: wheels only available on linux.
     pip install --upgrade "jax[cuda11_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html

+
 Alternatively, you can download the preferred release ".whl" file for jaxlib
 from the above release links, and install it via ``pip``:

@@ -121,14 +122,46 @@ from the above release links, and install it via ``pip``:

 .. code-block:: bash

 Note that the versions of jaxlib and jax should be consistent.

-For example, if you are using jax==0.4.15, you would better install
-jax==0.4.15.
+For example, if you are using jax==0.4.15, you should also install jaxlib==0.4.15.
+
+
+MacOS
+^^^^^
+
+If you are using an Intel-based macOS machine, we recommend first installing the Miniconda Intel installer:
+
+1. Download the package from https://repo.anaconda.com/miniconda/Miniconda3-latest-MacOSX-x86_64.pkg
+2. Then open the downloaded package and install it.
+
+
+If you are using an Apple Silicon (M1) macOS machine, we recommend installing the Miniconda M1 installer:
+
+
+1. Download the package from https://repo.anaconda.com/miniconda/Miniconda3-latest-MacOSX-arm64.pkg
+2. Then open the downloaded package and install it.
+
+
+Finally, you can install `jax` and `jaxlib` in the same way as on the Linux platform:
+
+.. code-block:: bash
+
+    pip install --upgrade "jax[cpu]"
+
+
 Windows
 ^^^^^^^

-For **Windows** users, `jax` and `jaxlib` can be installed from the community supports.
-Specifically, you can install `jax` and `jaxlib` through:
+For **Windows** users with Python >= 3.9, `jax` and `jaxlib` can be installed
+directly from PyPI.
+
+.. code-block:: bash
+
+    pip install jax jaxlib
+
+
+For **Windows** users with Python <= 3.8, `jax` and `jaxlib` can be installed
+from community-supported wheels. Specifically, you can install `jax` and `jaxlib` through:

 .. code-block:: bash

    pip install "jax[cpu]" -f https://whls.blob.core.windows.net/unstable/index.html

@@ -141,7 +174,8 @@ If you are using GPU, you can install GPU-versioned wheels through:

 .. code-block:: bash

    pip install "jax[cuda111]" -f https://whls.blob.core.windows.net/unstable/index.html

 Alternatively, you can manually install you favourite version of `jax` and `jaxlib` by
-downloading binary releases of JAX for Windows from https://whls.blob.core.windows.net/unstable/index.html .
+downloading binary releases of JAX for Windows from
+https://whls.blob.core.windows.net/unstable/index.html .
 Then install it via ``pip``:

 .. code-block:: bash

@@ -180,8 +214,9 @@ For windows, Linux and MacOS users, ``brainpylib`` supports CPU operators.
 For CUDA users, ``brainpylib`` only support GPU on Linux platform. You can install GPU version
 ``brainpylib`` on Linux through ``pip install brainpylib`` too.

+
 Installation from docker
-========================
+------------------------

 If you want to use BrainPy in docker, you can use the following command
 to install BrainPy:
@@ -190,8 +225,9 @@ to install BrainPy:

 .. code-block:: bash

    docker pull ztqakita/brainpy

+
 Running BrainPy online with binder
-==================================
+----------------------------------

 Click on the following link to launch the Binder environment with the
 BrainPy repository:
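Whichever platform section above is followed, a quick way to confirm the resulting environment works (illustrative) is to compare the `jax`/`jaxlib` versions and list the devices JAX can see:

.. code-block:: python

    import jax, jaxlib
    import brainpy as bp

    print(jax.__version__, jaxlib.__version__)  # the two versions should be consistent
    print(jax.devices())                        # lists the CPU/GPU devices JAX found
    print(bp.__version__)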