From 90dbc84855a15efb9a4189724e6f5ec64d4a39ba Mon Sep 17 00:00:00 2001 From: chaoming Date: Mon, 11 Sep 2023 16:38:27 +0800 Subject: [PATCH 1/2] update installation, code change date, and others --- brainpy/_src/dynsys.py | 4 +- brainpy/_src/math/object_transform/base.py | 2 + brainpy/_src/math/random.py | 13 +- brainpy/_src/mixin.py | 133 +++++++++++---------- brainpy/mixin.py | 2 +- docs/quickstart/installation.rst | 54 +++++++-- 6 files changed, 132 insertions(+), 76 deletions(-) diff --git a/brainpy/_src/dynsys.py b/brainpy/_src/dynsys.py index a7e7d86d9..770d4bf30 100644 --- a/brainpy/_src/dynsys.py +++ b/brainpy/_src/dynsys.py @@ -10,7 +10,7 @@ from brainpy import tools, math as bm from brainpy._src.initialize import parameter, variable_ -from brainpy._src.mixin import SupportAutoDelay, Container, ReceiveInputProj, DelayRegister, global_delay_data +from brainpy._src.mixin import SupportAutoDelay, Container, SupportInputProj, DelayRegister, global_delay_data from brainpy.errors import NoImplementationError, UnsupportedError from brainpy.types import ArrayType, Shape from brainpy._src.deprecations import _update_deprecate_msg @@ -70,7 +70,7 @@ def update(self, x): return func -class DynamicalSystem(bm.BrainPyObject, DelayRegister, ReceiveInputProj): +class DynamicalSystem(bm.BrainPyObject, DelayRegister, SupportInputProj): """Base Dynamical System class. .. note:: diff --git a/brainpy/_src/math/object_transform/base.py b/brainpy/_src/math/object_transform/base.py index daa8a55bb..061bfe472 100644 --- a/brainpy/_src/math/object_transform/base.py +++ b/brainpy/_src/math/object_transform/base.py @@ -141,6 +141,8 @@ def fun(self): # that has been created. a = self.tracing_variable('a', bm.zeros, (10,)) + .. versionadded:: 2.4.5 + Args: name: str. The variable name. init: callable, Array. The data to be initialized as a ``Variable``. diff --git a/brainpy/_src/math/random.py b/brainpy/_src/math/random.py index eb04c5d2e..e989908a0 100644 --- a/brainpy/_src/math/random.py +++ b/brainpy/_src/math/random.py @@ -22,7 +22,7 @@ __all__ = [ 'RandomState', 'Generator', 'DEFAULT', - 'seed', 'default_rng', 'split_key', + 'seed', 'default_rng', 'split_key', 'split_keys', # numpy compatibility 'rand', 'randint', 'random_integers', 'randn', 'random', @@ -1258,6 +1258,8 @@ def split_keys(n): internally by `pmap` and `vmap` to ensure that random numbers are different in parallel threads. + .. versionadded:: 2.4.5 + Parameters ---------- n : int @@ -1267,6 +1269,15 @@ def split_keys(n): def clone_rng(seed_or_key=None, clone: bool = True) -> RandomState: + """Clone the random state according to the given setting. + + Args: + seed_or_key: The seed (an integer) or the random key. + clone: Bool. Whether clone the default random state. + + Returns: + The random state. + """ if seed_or_key is None: return DEFAULT.clone() if clone else DEFAULT else: diff --git a/brainpy/_src/mixin.py b/brainpy/_src/mixin.py index 124bf3d20..575cc87aa 100644 --- a/brainpy/_src/mixin.py +++ b/brainpy/_src/mixin.py @@ -1,6 +1,5 @@ import numbers import sys -import warnings from dataclasses import dataclass from typing import Union, Dict, Callable, Sequence, Optional, TypeVar, Any from typing import (_SpecialForm, _type_check, _remove_dups_flatten) @@ -46,59 +45,6 @@ class MixIn(object): pass -class ReceiveInputProj(MixIn): - """The :py:class:`~.MixIn` that receives the input projections. - - Note that the subclass should define a ``cur_inputs`` attribute. - - """ - cur_inputs: bm.node_dict - - def add_inp_fun(self, key: Any, fun: Callable): - """Add an input function. - - Args: - key: The dict key. - fun: The function to generate inputs. - """ - if not callable(fun): - raise TypeError('Must be a function.') - if key in self.cur_inputs: - raise ValueError(f'Key "{key}" has been defined and used.') - self.cur_inputs[key] = fun - - def get_inp_fun(self, key): - """Get the input function. - - Args: - key: The key. - - Returns: - The input function which generates currents. - """ - return self.cur_inputs.get(key) - - def sum_inputs(self, *args, init=0., label=None, **kwargs): - """Summarize all inputs by the defined input functions ``.cur_inputs``. - - Args: - *args: The arguments for input functions. - init: The initial input data. - **kwargs: The arguments for input functions. - - Returns: - The total currents. - """ - if label is None: - for key, out in self.cur_inputs.items(): - init = init + out(*args, **kwargs) - else: - for key, out in self.cur_inputs.items(): - if key.startswith(label + ' // '): - init = init + out(*args, **kwargs) - return init - - class ParamDesc(MixIn): """:py:class:`~.MixIn` indicates the function for describing initialization parameters. @@ -207,13 +153,6 @@ def get_data(self): return init -class SupportAutoDelay(MixIn): - """``MixIn`` to support the automatic delay in synaptic projection :py:class:`~.SynProj`.""" - - def return_info(self) -> Union[bm.Variable, ReturnInfo]: - raise NotImplementedError('Must implement the "return_info()" function.') - - class Container(MixIn): """Container :py:class:`~.MixIn` which wrap a group of objects. """ @@ -549,8 +488,71 @@ def get_delay_var(self, name): return global_delay_data[name] +class SupportInputProj(MixIn): + """The :py:class:`~.MixIn` that receives the input projections. + + Note that the subclass should define a ``cur_inputs`` attribute. + + """ + cur_inputs: bm.node_dict + + def add_inp_fun(self, key: Any, fun: Callable): + """Add an input function. + + Args: + key: The dict key. + fun: The function to generate inputs. + """ + if not callable(fun): + raise TypeError('Must be a function.') + if key in self.cur_inputs: + raise ValueError(f'Key "{key}" has been defined and used.') + self.cur_inputs[key] = fun + + def get_inp_fun(self, key): + """Get the input function. + + Args: + key: The key. + + Returns: + The input function which generates currents. + """ + return self.cur_inputs.get(key) + + def sum_inputs(self, *args, init=0., label=None, **kwargs): + """Summarize all inputs by the defined input functions ``.cur_inputs``. + + Args: + *args: The arguments for input functions. + init: The initial input data. + **kwargs: The arguments for input functions. + + Returns: + The total currents. + """ + if label is None: + for key, out in self.cur_inputs.items(): + init = init + out(*args, **kwargs) + else: + for key, out in self.cur_inputs.items(): + if key.startswith(label + ' // '): + init = init + out(*args, **kwargs) + return init + + +class SupportAutoDelay(MixIn): + """``MixIn`` to support the automatic delay in synaptic projection :py:class:`~.SynProj`.""" + + def return_info(self) -> Union[bm.Variable, ReturnInfo]: + raise NotImplementedError('Must implement the "return_info()" function.') + + class SupportOnline(MixIn): - """:py:class:`~.MixIn` to support the online training methods.""" + """:py:class:`~.MixIn` to support the online training methods. + + .. versionadded:: 2.4.5 + """ online_fit_by: Optional # methods for online fitting @@ -562,7 +564,10 @@ def online_fit(self, target: ArrayType, fit_record: Dict[str, ArrayType]): class SupportOffline(MixIn): - """:py:class:`~.MixIn` to support the offline training methods.""" + """:py:class:`~.MixIn` to support the offline training methods. + + .. versionadded:: 2.4.5 + """ offline_fit_by: Optional # methods for offline fitting @@ -572,6 +577,8 @@ def offline_fit(self, target: ArrayType, fit_record: Dict[str, ArrayType]): class BindCondData(MixIn): """Bind temporary conductance data. + + """ _conductance: Optional diff --git a/brainpy/mixin.py b/brainpy/mixin.py index 82fd9f6ff..e1e79cdc5 100644 --- a/brainpy/mixin.py +++ b/brainpy/mixin.py @@ -1,7 +1,7 @@ from brainpy._src.mixin import ( MixIn as MixIn, - ReceiveInputProj as ReceiveInputProj, + SupportInputProj as ReceiveInputProj, AlignPost as AlignPost, SupportAutoDelay as AutoDelaySupp, ParamDesc as ParamDesc, diff --git a/docs/quickstart/installation.rst b/docs/quickstart/installation.rst index a3f0ce495..e0d5138aa 100644 --- a/docs/quickstart/installation.rst +++ b/docs/quickstart/installation.rst @@ -78,8 +78,8 @@ BrainPy relies on `JAX`_. JAX is a high-performance JIT compiler which enables users to run Python code on CPU, GPU, and TPU devices. Core functionalities of BrainPy (>=2.0.0) have been migrated to the JAX backend. -Linux & MacOS -^^^^^^^^^^^^^ +Linux +^^^^^ Currently, JAX supports **Linux** (Ubuntu 16.04 or later) and **macOS** (10.12 or later) platforms. The provided binary releases of `jax` and `jaxlib` for Linux and macOS @@ -108,6 +108,7 @@ If you want to install JAX with both CPU and NVidia GPU support, you must first # Note: wheels only available on linux. pip install --upgrade "jax[cuda11_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html + Alternatively, you can download the preferred release ".whl" file for jaxlib from the above release links, and install it via ``pip``: @@ -121,14 +122,46 @@ from the above release links, and install it via ``pip``: Note that the versions of jaxlib and jax should be consistent. - For example, if you are using jax==0.4.15, you would better install -jax==0.4.15. + For example, if you are using jax==0.4.15, you would better install jax==0.4.15. + + +MacOS +^^^^^ + +If you are using macOS Intel, we recommend you first to install the Miniconda Intel installer: + +1. Download the package in the link https://repo.anaconda.com/miniconda/Miniconda3-latest-MacOSX-x86_64.pkg +2. Then click the downloaded package and install it. + + +If you are using the latest M1 macOS version, you'd better to install the Miniconda M1 installer: + + +1. Download the package in the link https://repo.anaconda.com/miniconda/Miniconda3-latest-MacOSX-arm64.pkg +2. Then click the downloaded package and install it. + + +Finally, you can install `jax` and `jaxlib` as the same as the Linux platform. + +.. code-block:: bash + + pip install --upgrade "jax[cpu]" + + Windows ^^^^^^^ -For **Windows** users, `jax` and `jaxlib` can be installed from the community supports. -Specifically, you can install `jax` and `jaxlib` through: +For **Windows** users with Python >= 3.9, `jax` and `jaxlib` can be installed +directly from the PyPi channel. + +.. code-block:: bash + + pip install jax jaxlib + + +For **Windows** users with Python <= 3.8, `jax` and `jaxlib` can be installed +from the community supports. Specifically, you can install `jax` and `jaxlib` through: .. code-block:: bash @@ -141,7 +174,8 @@ If you are using GPU, you can install GPU-versioned wheels through: pip install "jax[cuda111]" -f https://whls.blob.core.windows.net/unstable/index.html Alternatively, you can manually install you favourite version of `jax` and `jaxlib` by -downloading binary releases of JAX for Windows from https://whls.blob.core.windows.net/unstable/index.html . +downloading binary releases of JAX for Windows from +https://whls.blob.core.windows.net/unstable/index.html . Then install it via ``pip``: .. code-block:: bash @@ -180,8 +214,9 @@ For windows, Linux and MacOS users, ``brainpylib`` supports CPU operators. For CUDA users, ``brainpylib`` only support GPU on Linux platform. You can install GPU version ``brainpylib`` on Linux through ``pip install brainpylib`` too. + Installation from docker -======================== +------------------------ If you want to use BrainPy in docker, you can use the following command to install BrainPy: @@ -190,8 +225,9 @@ to install BrainPy: docker pull ztqakita/brainpy + Running BrainPy online with binder -================================== +---------------------------------- Click on the following link to launch the Binder environment with the BrainPy repository: From d323d3d174299fbc1b8947e91bb0fa4565684b9e Mon Sep 17 00:00:00 2001 From: chaoming Date: Mon, 11 Sep 2023 21:57:55 +0800 Subject: [PATCH 2/2] update STDP errors --- brainpy/_src/dyn/projections/plasticity.py | 63 +++++----------------- 1 file changed, 13 insertions(+), 50 deletions(-) diff --git a/brainpy/_src/dyn/projections/plasticity.py b/brainpy/_src/dyn/projections/plasticity.py index f3a87aa97..3a3eff608 100644 --- a/brainpy/_src/dyn/projections/plasticity.py +++ b/brainpy/_src/dyn/projections/plasticity.py @@ -1,61 +1,18 @@ from typing import Optional, Callable, Union -from brainpy.types import ArrayType from brainpy import math as bm, check from brainpy._src.delay import DelayAccess, delay_identifier, init_delay_by_return +from brainpy._src.dyn.synapses.abstract_models import Expon from brainpy._src.dynsys import DynamicalSystem, Projection -from brainpy._src.mixin import (JointType, ParamDescInit, ReturnInfo, - AutoDelaySupp, BindCondData, AlignPost, SupportSTDP) from brainpy._src.initialize import parameter -from brainpy._src.dyn.synapses.abstract_models import Expon +from brainpy._src.mixin import (JointType, ParamDescInit, SupportAutoDelay, BindCondData, AlignPost, SupportSTDP) +from brainpy.types import ArrayType +from .aligns import _AlignPost, _AlignPreMg, _get_return __all__ = [ - 'STDP_Song2000' + 'STDP_Song2000', ] -class _AlignPre(DynamicalSystem): - def __init__(self, syn, delay=None): - super().__init__() - self.syn = syn - self.delay = delay - - def update(self, x): - if self.delay is None: - return x >> self.syn - else: - return x >> self.syn >> self.delay - - -class _AlignPost(DynamicalSystem): - def __init__(self, - syn: Callable, - out: JointType[DynamicalSystem, BindCondData]): - super().__init__() - self.syn = syn - self.out = out - - def update(self, *args, **kwargs): - self.out.bind_cond(self.syn(*args, **kwargs)) - - -class _AlignPreMg(DynamicalSystem): - def __init__(self, access, syn): - super().__init__() - self.access = access - self.syn = syn - - def update(self, *args, **kwargs): - return self.syn(self.access()) - - -def _get_return(return_info): - if isinstance(return_info, bm.Variable): - return return_info.value - elif isinstance(return_info, ReturnInfo): - return return_info.get_data() - else: - raise NotImplementedError - class STDP_Song2000(Projection): r"""Synaptic output with spike-time-dependent plasticity. @@ -135,9 +92,10 @@ class STDP_Song2000(Projection): A2: float, ArrayType, Callable. The increment of :math:`A_{post}` produced by a spike. %s """ + def __init__( self, - pre: JointType[DynamicalSystem, AutoDelaySupp], + pre: JointType[DynamicalSystem, SupportAutoDelay], delay: Union[None, int, float], syn: ParamDescInit[DynamicalSystem], comm: DynamicalSystem, @@ -148,6 +106,7 @@ def __init__( tau_t: Union[float, ArrayType, Callable] = 33.7, A1: Union[float, ArrayType, Callable] = 0.96, A2: Union[float, ArrayType, Callable] = 0.53, + # others out_label: Optional[str] = None, name: Optional[str] = None, mode: Optional[bm.Mode] = None, @@ -155,7 +114,7 @@ def __init__( super().__init__(name=name, mode=mode) # synaptic models - check.is_instance(pre, JointType[DynamicalSystem, AutoDelaySupp]) + check.is_instance(pre, JointType[DynamicalSystem, SupportAutoDelay]) check.is_instance(syn, ParamDescInit[DynamicalSystem]) check.is_instance(comm, JointType[DynamicalSystem, SupportSTDP]) check.is_instance(out, ParamDescInit[JointType[DynamicalSystem, BindCondData]]) @@ -252,6 +211,7 @@ def calculate_trace( return delay_cls.get_bef_update(_syn_id).syn def update(self): + # pre spikes, and pre-synaptic variables if issubclass(self.syn.cls, AlignPost): pre_spike = self.refs['delay'].at(self.name) x = pre_spike @@ -259,13 +219,16 @@ def update(self): pre_spike = self.refs['delay'].access() x = _get_return(self.refs['syn'].return_info()) + # post spikes post_spike = self.refs['post'].spike + # weight updates Apre = self.refs['pre_trace'].g Apost = self.refs['post_trace'].g delta_w = - bm.outer(pre_spike, Apost * self.A2) + bm.outer(Apre * self.A1, post_spike) self.comm.update_STDP(delta_w) + # currents current = self.comm(x) if issubclass(self.syn.cls, AlignPost): self.refs['syn'].add_current(current) # synapse post current