Skip to content

Commit

Permalink
fix conflicts
Browse files Browse the repository at this point in the history
  • Loading branch information
chaoming0625 committed Sep 11, 2023
2 parents d323d3d + 90dbc84 commit 46bd161
Show file tree
Hide file tree
Showing 6 changed files with 137 additions and 78 deletions.
4 changes: 2 additions & 2 deletions brainpy/_src/dynsys.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

from brainpy import tools, math as bm
from brainpy._src.initialize import parameter, variable_
from brainpy._src.mixin import SupportAutoDelay, Container, ReceiveInputProj, DelayRegister, global_delay_data
from brainpy._src.mixin import SupportAutoDelay, Container, SupportInputProj, DelayRegister, global_delay_data
from brainpy.errors import NoImplementationError, UnsupportedError
from brainpy.types import ArrayType, Shape
from brainpy._src.deprecations import _update_deprecate_msg
Expand Down Expand Up @@ -70,7 +70,7 @@ def update(self, x):
return func


class DynamicalSystem(bm.BrainPyObject, DelayRegister, ReceiveInputProj):
class DynamicalSystem(bm.BrainPyObject, DelayRegister, SupportInputProj):
"""Base Dynamical System class.
.. note::
Expand Down
2 changes: 2 additions & 0 deletions brainpy/_src/math/object_transform/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,8 @@ def fun(self):
# that has been created.
a = self.tracing_variable('a', bm.zeros, (10,))
.. versionadded:: 2.4.5
Args:
name: str. The variable name.
init: callable, Array. The data to be initialized as a ``Variable``.
Expand Down
13 changes: 12 additions & 1 deletion brainpy/_src/math/random.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
__all__ = [
'RandomState', 'Generator', 'DEFAULT',

'seed', 'default_rng', 'split_key',
'seed', 'default_rng', 'split_key', 'split_keys',

# numpy compatibility
'rand', 'randint', 'random_integers', 'randn', 'random',
Expand Down Expand Up @@ -1258,6 +1258,8 @@ def split_keys(n):
internally by `pmap` and `vmap` to ensure that random numbers
are different in parallel threads.
.. versionadded:: 2.4.5
Parameters
----------
n : int
Expand All @@ -1267,6 +1269,15 @@ def split_keys(n):


def clone_rng(seed_or_key=None, clone: bool = True) -> RandomState:
"""Clone the random state according to the given setting.
Args:
seed_or_key: The seed (an integer) or the random key.
clone: Bool. Whether clone the default random state.
Returns:
The random state.
"""
if seed_or_key is None:
return DEFAULT.clone() if clone else DEFAULT
else:
Expand Down
138 changes: 74 additions & 64 deletions brainpy/_src/mixin.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import numbers
import sys
import warnings
from dataclasses import dataclass
from typing import Union, Dict, Callable, Sequence, Optional, TypeVar, Any
from typing import (_SpecialForm, _type_check, _remove_dups_flatten)
Expand Down Expand Up @@ -28,12 +27,15 @@
'ParamDesc',
'ParamDescInit',
'AlignPost',
'SupportAutoDelay',
'Container',
'TreeNode',
'BindCondData',
'JointType',
'SupportSTDP',
'SupportAutoDelay',
'SupportInputProj',
'SupportOnline',
'SupportOffline',
]

global_delay_data = dict()
Expand All @@ -47,59 +49,6 @@ class MixIn(object):
pass


class ReceiveInputProj(MixIn):
  """The :py:class:`~.MixIn` that receives the input projections.

  Note that the subclass should define a ``cur_inputs`` attribute.
  """

  # Mapping from a user-chosen key to a callable producing an input current.
  # NOTE(review): assumed to be initialized by the subclass — confirm.
  cur_inputs: bm.node_dict

  def add_inp_fun(self, key: Any, fun: Callable):
    """Add an input function.

    Args:
      key: The dict key.
      fun: The function to generate inputs.

    Raises:
      TypeError: If ``fun`` is not callable.
      ValueError: If ``key`` has already been registered.
    """
    if not callable(fun):
      raise TypeError('Must be a function.')
    if key in self.cur_inputs:
      raise ValueError(f'Key "{key}" has been defined and used.')
    self.cur_inputs[key] = fun

  def get_inp_fun(self, key):
    """Get the input function registered under ``key``.

    Args:
      key: The key.

    Returns:
      The input function which generates currents, or ``None`` when ``key``
      has not been registered.
    """
    return self.cur_inputs.get(key)

  def sum_inputs(self, *args, init=0., label=None, **kwargs):
    """Summarize all inputs by the defined input functions ``.cur_inputs``.

    Args:
      *args: The arguments for input functions.
      init: The initial input data.
      label: When given, only input functions whose key starts with
        ``label + ' // '`` are accumulated; otherwise all inputs are summed.
      **kwargs: The arguments for input functions.

    Returns:
      The total currents.
    """
    if label is None:
      for key, out in self.cur_inputs.items():
        init = init + out(*args, **kwargs)
    else:
      for key, out in self.cur_inputs.items():
        if key.startswith(label + ' // '):
          init = init + out(*args, **kwargs)
    return init


class ParamDesc(MixIn):
""":py:class:`~.MixIn` indicates the function for describing initialization parameters.
Expand Down Expand Up @@ -208,13 +157,6 @@ def get_data(self):
return init


class SupportAutoDelay(MixIn):
  """``MixIn`` to support the automatic delay in synaptic projection :py:class:`~.SynProj`."""

  def return_info(self) -> Union[bm.Variable, ReturnInfo]:
    """Return the variable (or its description) to be registered for automatic delay.

    Raises:
      NotImplementedError: Subclasses must override this method.
    """
    raise NotImplementedError('Must implement the "return_info()" function.')


class Container(MixIn):
"""Container :py:class:`~.MixIn` which wrap a group of objects.
"""
Expand Down Expand Up @@ -550,8 +492,71 @@ def get_delay_var(self, name):
return global_delay_data[name]


class SupportInputProj(MixIn):
  """The :py:class:`~.MixIn` that receives the input projections.

  Note that the subclass should define a ``cur_inputs`` attribute.
  """

  # Registry of input-current generators, keyed by a user-chosen name.
  cur_inputs: bm.node_dict

  def add_inp_fun(self, key: Any, fun: Callable):
    """Register ``fun`` as a new input function under ``key``.

    Args:
      key: The dict key.
      fun: The function to generate inputs.

    Raises:
      TypeError: If ``fun`` is not callable.
      ValueError: If ``key`` is already registered.
    """
    if not callable(fun):
      raise TypeError('Must be a function.')
    if key in self.cur_inputs:
      raise ValueError(f'Key "{key}" has been defined and used.')
    self.cur_inputs[key] = fun

  def get_inp_fun(self, key):
    """Look up the input function registered under ``key``.

    Args:
      key: The key.

    Returns:
      The registered input function, or ``None`` when ``key`` is unknown.
    """
    return self.cur_inputs.get(key)

  def sum_inputs(self, *args, init=0., label=None, **kwargs):
    """Summarize all inputs by the defined input functions ``.cur_inputs``.

    Args:
      *args: The arguments for input functions.
      init: The initial input data.
      label: When given, only functions whose key starts with
        ``label + ' // '`` contribute to the sum.
      **kwargs: The arguments for input functions.

    Returns:
      The total currents.
    """
    if label is None:
      selected = self.cur_inputs.values()
    else:
      prefix = label + ' // '
      selected = (fn for name, fn in self.cur_inputs.items()
                  if name.startswith(prefix))
    for fn in selected:
      init = init + fn(*args, **kwargs)
    return init


class SupportAutoDelay(MixIn):
  """``MixIn`` to support the automatic delay in synaptic projection :py:class:`~.SynProj`."""

  def return_info(self) -> Union[bm.Variable, ReturnInfo]:
    """Return the variable (or its description) to be registered for automatic delay.

    Raises:
      NotImplementedError: Subclasses must override this method.
    """
    raise NotImplementedError('Must implement the "return_info()" function.')


class SupportOnline(MixIn):
""":py:class:`~.MixIn` to support the online training methods."""
""":py:class:`~.MixIn` to support the online training methods.
.. versionadded:: 2.4.5
"""

online_fit_by: Optional # methods for online fitting

Expand All @@ -563,7 +568,10 @@ def online_fit(self, target: ArrayType, fit_record: Dict[str, ArrayType]):


class SupportOffline(MixIn):
""":py:class:`~.MixIn` to support the offline training methods."""
""":py:class:`~.MixIn` to support the offline training methods.
.. versionadded:: 2.4.5
"""

offline_fit_by: Optional # methods for offline fitting

Expand All @@ -573,6 +581,8 @@ def offline_fit(self, target: ArrayType, fit_record: Dict[str, ArrayType]):

class BindCondData(MixIn):
"""Bind temporary conductance data.
"""
_conductance: Optional

Expand Down
4 changes: 2 additions & 2 deletions brainpy/mixin.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@

from brainpy._src.mixin import (
MixIn as MixIn,
ReceiveInputProj as ReceiveInputProj,
SupportInputProj as SupportInputProj,
AlignPost as AlignPost,
AutoDelaySupp as AutoDelaySupp,
SupportAutoDelay as SupportAutoDelay,
ParamDesc as ParamDesc,
ParamDescInit as ParamDescInit,
BindCondData as BindCondData,
Expand Down
54 changes: 45 additions & 9 deletions docs/quickstart/installation.rst
Original file line number Diff line number Diff line change
Expand Up @@ -78,8 +78,8 @@ BrainPy relies on `JAX`_. JAX is a high-performance JIT compiler which enables
users to run Python code on CPU, GPU, and TPU devices. Core functionalities of
BrainPy (>=2.0.0) have been migrated to the JAX backend.

Linux & MacOS
^^^^^^^^^^^^^
Linux
^^^^^

Currently, JAX supports **Linux** (Ubuntu 16.04 or later) and **macOS** (10.12 or
later) platforms. The provided binary releases of `jax` and `jaxlib` for Linux and macOS
Expand Down Expand Up @@ -108,6 +108,7 @@ If you want to install JAX with both CPU and NVidia GPU support, you must first
# Note: wheels only available on linux.
pip install --upgrade "jax[cuda11_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
Alternatively, you can download the preferred release ".whl" file for jaxlib
from the above release links, and install it via ``pip``:

Expand All @@ -121,14 +122,46 @@ from the above release links, and install it via ``pip``:

Note that the versions of jaxlib and jax should be consistent.

For example, if you are using jax==0.4.15, you would better install
jax==0.4.15.
For example, if you are using jax==0.4.15, you should install jaxlib==0.4.15.


MacOS
^^^^^

If you are using an Intel-based macOS machine, we recommend that you first install the Miniconda Intel installer:

1. Download the package in the link https://repo.anaconda.com/miniconda/Miniconda3-latest-MacOSX-x86_64.pkg
2. Then click the downloaded package and install it.


If you are using an Apple-silicon (M1) macOS machine, you should install the Miniconda M1 installer:


1. Download the package in the link https://repo.anaconda.com/miniconda/Miniconda3-latest-MacOSX-arm64.pkg
2. Then click the downloaded package and install it.


Finally, you can install `jax` and `jaxlib` in the same way as on the Linux platform.

.. code-block:: bash
pip install --upgrade "jax[cpu]"
Windows
^^^^^^^

For **Windows** users, `jax` and `jaxlib` can be installed from the community supports.
Specifically, you can install `jax` and `jaxlib` through:
For **Windows** users with Python >= 3.9, `jax` and `jaxlib` can be installed
directly from the PyPi channel.

.. code-block:: bash
pip install jax jaxlib
For **Windows** users with Python <= 3.8, `jax` and `jaxlib` can be installed
from community-supported builds. Specifically, you can install `jax` and `jaxlib` through:

.. code-block:: bash
Expand All @@ -141,7 +174,8 @@ If you are using GPU, you can install GPU-versioned wheels through:
pip install "jax[cuda111]" -f https://whls.blob.core.windows.net/unstable/index.html
Alternatively, you can manually install you favourite version of `jax` and `jaxlib` by
downloading binary releases of JAX for Windows from https://whls.blob.core.windows.net/unstable/index.html .
downloading binary releases of JAX for Windows from
https://whls.blob.core.windows.net/unstable/index.html .
Then install it via ``pip``:

.. code-block:: bash
Expand Down Expand Up @@ -180,8 +214,9 @@ For windows, Linux and MacOS users, ``brainpylib`` supports CPU operators.
For CUDA users, ``brainpylib`` only support GPU on Linux platform. You can install GPU version ``brainpylib``
on Linux through ``pip install brainpylib`` too.


Installation from docker
========================
------------------------

If you want to use BrainPy in docker, you can use the following command
to install BrainPy:
Expand All @@ -190,8 +225,9 @@ to install BrainPy:
docker pull ztqakita/brainpy
Running BrainPy online with binder
==================================
----------------------------------

Click on the following link to launch the Binder environment with the
BrainPy repository:
Expand Down

0 comments on commit 46bd161

Please sign in to comment.