Commit

Merge branch 'master' into add-matmat-op

Routhleck committed May 25, 2024
2 parents 68e25f9 + e3a854a commit 23ef2fc
Showing 45 changed files with 1,598 additions and 237 deletions.
3 changes: 3 additions & 0 deletions .github/workflows/CI-models.yml
@@ -32,6 +32,7 @@ jobs:
python-version: ${{ matrix.python-version }}
- name: Install dependencies
run: |
python -m pip install --upgrade pip
if [ -f requirements-dev.txt ]; then pip install -r requirements-dev.txt; fi
pip uninstall brainpy -y
python setup.py install
@@ -79,6 +80,7 @@ jobs:
python-version: ${{ matrix.python-version }}
- name: Install dependencies
run: |
python -m pip install --upgrade pip
if [ -f requirements-dev.txt ]; then pip install -r requirements-dev.txt; fi
pip uninstall brainpy -y
python setup.py install
@@ -127,6 +129,7 @@ jobs:
python-version: ${{ matrix.python-version }}
- name: Install dependencies
run: |
python -m pip install --upgrade pip
python -m pip install numpy>=1.21.0
python -m pip install -r requirements-dev.txt
python -m pip install tqdm brainpylib
2 changes: 1 addition & 1 deletion brainpy/__init__.py
@@ -1,7 +1,7 @@
# -*- coding: utf-8 -*-


__version__ = "2.5.0"
__version__ = "2.6.0"

# fundamental supporting modules
from brainpy import errors, check, tools
2 changes: 1 addition & 1 deletion brainpy/_src/analysis/highdim/slow_points.py
@@ -329,7 +329,7 @@ def find_fps_with_gd_method(
"""
# optimization settings
if optimizer is None:
-      optimizer = optim.Adam(lr=optim.ExponentialDecay(0.2, 1, 0.9999),
+      optimizer = optim.Adam(lr=optim.ExponentialDecayLR(0.2, 1, 0.9999),
beta1=0.9, beta2=0.999, eps=1e-8)
else:
if not isinstance(optimizer, optim.Optimizer):
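
As a side note on the change above, here is a minimal sketch of how the renamed scheduler is constructed; it mirrors the defaults in the call shown in the diff, and the positional arguments are an assumption based on that call:

```python
import brainpy as bp

# Learning rate starts at 0.2 and decays by a factor of 0.9999 every step,
# matching the Adam defaults used in find_fps_with_gd_method above.
lr = bp.optim.ExponentialDecayLR(0.2, 1, 0.9999)
opt = bp.optim.Adam(lr=lr, beta1=0.9, beta2=0.999, eps=1e-8)
```
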
66 changes: 60 additions & 6 deletions brainpy/_src/dependency_check.py
@@ -8,6 +8,9 @@
'raise_taichi_not_found',
'import_numba',
'raise_numba_not_found',
'import_cupy',
'import_cupy_jit',
'raise_cupy_not_found',
'import_brainpylib_cpu_ops',
'import_brainpylib_gpu_ops',
]
@@ -17,14 +20,19 @@

numba = None
taichi = None
cupy = None
cupy_jit = None
brainpylib_cpu_ops = None
brainpylib_gpu_ops = None

-taichi_install_info = (f'We need taichi=={_minimal_taichi_version}. '
-                       f'Currently you can install taichi=={_minimal_taichi_version} through:\n\n'
-                       '> pip install taichi==1.7.0')
+taichi_install_info = (f'We need taichi>={_minimal_taichi_version}. '
+                       f'Currently you can install taichi=={_minimal_taichi_version} by pip . \n'
+                       '> pip install taichi -U')
numba_install_info = ('We need numba. Please install numba by pip . \n'
'> pip install numba')
cupy_install_info = ('We need cupy. Please install cupy by pip . \n'
'For CUDA v11.2 ~ 11.8 > pip install cupy-cuda11x\n'
'For CUDA v12.x > pip install cupy-cuda12x\n')
os.environ["TI_LOG_LEVEL"] = "error"


@@ -49,9 +57,13 @@ def import_taichi(error_if_not_found=True):

if taichi is None:
return None
-  if taichi.__version__ != _minimal_taichi_version:
-    raise RuntimeError(taichi_install_info)
-  return taichi
+  taichi_version = taichi.__version__[0] * 10000 + taichi.__version__[1] * 100 + taichi.__version__[2]
+  minimal_taichi_version = _minimal_taichi_version[0] * 10000 + _minimal_taichi_version[1] * 100 + \
+                           _minimal_taichi_version[2]
+  if taichi_version >= minimal_taichi_version:
+    return taichi
+  else:
+    raise ModuleNotFoundError(taichi_install_info)


def raise_taichi_not_found(*args, **kwargs):
Expand Down Expand Up @@ -81,6 +93,48 @@ def raise_numba_not_found():
raise ModuleNotFoundError(numba_install_info)


def import_cupy(error_if_not_found=True):
"""
Internal API to import cupy.
If cupy is not found, it will raise a ModuleNotFoundError if error_if_not_found is True,
otherwise it will return None.
"""
global cupy
if cupy is None:
try:
import cupy as cupy
except ModuleNotFoundError:
if error_if_not_found:
raise_cupy_not_found()
else:
return None
return cupy


def import_cupy_jit(error_if_not_found=True):
"""
Internal API to import cupy.
If cupy is not found, it will raise a ModuleNotFoundError if error_if_not_found is True,
otherwise it will return None.
"""
global cupy_jit
if cupy_jit is None:
try:
from cupyx import jit as cupy_jit
except ModuleNotFoundError:
if error_if_not_found:
raise_cupy_not_found()
else:
return None
return cupy_jit


def raise_cupy_not_found():
raise ModuleNotFoundError(cupy_install_info)


def is_brainpylib_gpu_installed():
return False if brainpylib_gpu_ops is None else True

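
For context, a small sketch of how these lazy importers are meant to be consumed downstream. The module-level globals cache the import so it is only attempted once, and error_if_not_found=False lets callers degrade gracefully; the fallback path below is illustrative, not part of the commit:

```python
from brainpy._src.dependency_check import import_cupy, import_taichi

cupy = import_cupy(error_if_not_found=False)      # None when cupy is absent
taichi = import_taichi(error_if_not_found=False)  # None when taichi is missing or too old

def double(x):
  # Prefer a CuPy implementation when the GPU stack is available,
  # otherwise fall back to a plain (hypothetical) CPU path.
  if cupy is not None:
    return cupy.asnumpy(cupy.asarray(x) * 2)
  return x * 2

# The new version check encodes (major, minor, patch) as a single integer,
# e.g. (1, 7, 0) -> 1*10000 + 7*100 + 0 = 10700, so versions compare numerically.
```
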
19 changes: 11 additions & 8 deletions brainpy/_src/dnn/interoperation_flax.py
@@ -1,7 +1,7 @@

import jax
import dataclasses
from typing import Dict
from typing import Dict, Tuple
from jax.tree_util import tree_flatten, tree_map, tree_unflatten

from brainpy import math as bm
@@ -77,16 +77,16 @@ class ToFlaxRNNCell(RNNCellBase):
model: DynamicalSystem
train_params: Dict[str, jax.Array] = dataclasses.field(init=False)

-  def initialize_carry(self, rng, batch_dims, size=None, init_fn=None):
-    if len(batch_dims) == 0:
-      batch_dims = 1
-    elif len(batch_dims) == 1:
+  def initialize_carry(self, rng, input_shape: Tuple[int, ...]):
+    batch_dims = input_shape[:-1]
+    if len(batch_dims) == 1:
       batch_dims = batch_dims[0]
+    elif len(batch_dims) == 0:
+      batch_dims = None
     else:
-      raise NotImplementedError
-
+      raise ValueError(f'Invalid input shape: {input_shape}')
     _state_vars = self.model.vars().unique().not_subset(bm.TrainVar)
-    self.model.reset(batch_size=batch_dims)
+    self.model.reset(batch_dims)
     return [_state_vars.dict(), 0, 0.]

def setup(self):
@@ -131,6 +131,9 @@ def __call__(self, carry, *inputs):
# carray and output
return [_state_vars.dict(), i + 1, t + share.dt], out

@property
def num_feature_axes(self) -> int:
return 1

else:
class ToFlaxRNNCell(object):
44 changes: 44 additions & 0 deletions brainpy/_src/dnn/tests/test_flax.py
@@ -0,0 +1,44 @@
# Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================


import pytest

pytest.skip('Skip this test because it is not implemented yet.',
allow_module_level=True)

import jax
import jax.numpy as jnp
import flax.linen as nn

import brainpy as bp
import brainpy.math as bm

bm.set_platform('cpu')
bm.set_mode(bm.training_mode)

cell = bp.dnn.ToFlaxRNNCell(bp.dyn.RNNCell(num_in=1, num_out=1, ))


class myRNN(nn.Module):
@nn.compact
def __call__(self, x): # x:(batch, time, features)
x = nn.RNN(cell)(x) # Use nn.RNN to unfold the recurrent cell
return x


def test_init():
model = myRNN()
model.init(jax.random.PRNGKey(0), jnp.ones([1, 10, 1])) # batch,time,feature
2 changes: 1 addition & 1 deletion brainpy/_src/dyn/rates/tests/test_nvar.py
@@ -11,7 +11,7 @@ class Test_NVAR(parameterized.TestCase):
def test_NVAR(self,mode):
bm.random.seed()
input=bm.random.randn(1,5)
-    layer=bp.dnn.NVAR(num_in=5,
+    layer=bp.dyn.NVAR(num_in=5,
delay=10,
mode=mode)
if mode in [bm.NonBatchingMode()]:
4 changes: 2 additions & 2 deletions brainpy/_src/initialize/tests/test_decay_inits.py
@@ -14,8 +14,8 @@
# visualization
def mat_visualize(matrix, cmap=None):
if cmap is None:
-    cmap = plt.cm.get_cmap('coolwarm')
-  plt.cm.get_cmap('coolwarm')
+    cmap = plt.colormaps.get_cmap('coolwarm')
+  plt.colormaps.get_cmap('coolwarm')
im = plt.matshow(matrix, cmap=cmap)
plt.colorbar(mappable=im, shrink=0.8, aspect=15)
plt.show()
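
A brief aside on the Matplotlib change above: plt.cm.get_cmap is deprecated in recent Matplotlib releases in favour of the colormap registry, which is what the new spelling resolves to. A hedged sketch of the registry API, assuming a reasonably new Matplotlib (>= 3.6):

```python
import matplotlib

cmap = matplotlib.colormaps['coolwarm']           # registry lookup
cmap = matplotlib.colormaps.get_cmap('coolwarm')  # equivalent registry call
```
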
6 changes: 6 additions & 0 deletions brainpy/_src/integrators/ode/exponential.py
@@ -106,6 +106,9 @@
"""

from functools import wraps

import jax.numpy as jnp

from brainpy import errors
from brainpy._src import math as bm
from brainpy._src.integrators import constants as C, utils, joint_eq
@@ -356,6 +359,9 @@ def _build_integrator(self, eq):
# integration function
def integral(*args, **kwargs):
assert len(args) > 0
if args[0].dtype not in [jnp.float32, jnp.float64, jnp.float16, jnp.bfloat16]:
raise ValueError('The input data type should be float32, float64, float16, or bfloat16 when using Exponential Euler method.'
f'But we got {args[0].dtype}.')
dt = kwargs.pop(C.DT, self.dt)
linear, derivative = value_and_grad(*args, **kwargs)
phi = bm.exprel(dt * linear)
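
To make the reason for the new dtype guard concrete, here is a rough, self-contained sketch of one Exponential Euler step in plain JAX; exprel and exp_euler_step are illustrative stand-ins rather than brainpy's internals:

```python
import jax
import jax.numpy as jnp

def exprel(x):
  # (exp(x) - 1) / x with a guard near x = 0 (brainpy ships its own bm.exprel).
  safe = jnp.where(jnp.abs(x) < 1e-8, 1.0, x)
  return jnp.where(jnp.abs(x) < 1e-8, 1.0 + x / 2.0, jnp.expm1(safe) / safe)

def exp_euler_step(f, v, t, dt=0.01):
  # One Exponential Euler step for dv/dt = f(v, t):
  #   v_{n+1} = v_n + dt * exprel(dt * df/dv) * f(v_n, t)
  # Differentiating f and evaluating exprel only make sense for floating-point
  # state, which is what the ValueError added above enforces.
  f_val, df_dv = jax.value_and_grad(f)(v, t)
  return v + dt * exprel(dt * df_dv) * f_val

decay = lambda v, t: -v / 10.0                         # simple leaky dynamics
v_next = exp_euler_step(decay, jnp.float32(1.0), 0.0)  # fine: float state
# Passing an integer state to brainpy's integrator would now raise the ValueError above.
```
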
@@ -94,8 +94,8 @@ def dV(self, V, t, h, n, Iext):

return dVdt

-  def update(self, tdi):
-    t, dt = tdi.t, tdi.dt
+  def update(self):
+    t, dt = bp.share['t'], bp.share['dt']
V, h, n = self.integral(self.V, self.h, self.n, t, self.input, dt=dt)
self.spike.value = bm.logical_and(self.V < self.V_th, V >= self.V_th)
self.V.value = V
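
The rewritten update above reflects the newer calling convention: step functions no longer receive a tdi dict but read the global clock from the shared context. A tiny sketch, with the assumption that bp.share.save is how the runner populates it each step:

```python
import brainpy as bp

bp.share.save(t=0.0, dt=0.1, i=0)      # normally done by the runner every step
t, dt = bp.share['t'], bp.share['dt']  # what update() now reads internally
```
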
3 changes: 1 addition & 2 deletions brainpy/_src/integrators/runner.py
@@ -9,7 +9,6 @@
import jax.numpy as jnp
import numpy as np
import tqdm.auto
-from jax.experimental.host_callback import id_tap
from jax.tree_util import tree_flatten

from brainpy import math as bm
@@ -245,7 +244,7 @@ def _step_fun_integrator(self, static_args, dyn_args, t, i):

# progress bar
if self.progress_bar:
-      id_tap(lambda *args: self._pbar.update(), ())
+      jax.pure_callback(lambda *args: self._pbar.update(), ())

# return of function monitors
shared = dict(t=t + self.dt, dt=self.dt, i=i)
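
Because jax.experimental.host_callback is deprecated in recent JAX releases, the progress-bar hook now goes through jax.pure_callback. A minimal sketch of that API with a stand-in callback; note that JAX treats the callback as pure and may drop it when its result is unused:

```python
import jax
import jax.numpy as jnp
import numpy as np

def tick(_):
  print('step done')   # host-side side effect, e.g. a tqdm pbar.update()
  return np.int32(0)   # must match the ShapeDtypeStruct declared below

@jax.jit
def step(x):
  # pure_callback(callback, result_shape_dtypes, *args)
  jax.pure_callback(tick, jax.ShapeDtypeStruct((), jnp.int32), x)
  return x + 1

step(jnp.arange(3))
```
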
12 changes: 5 additions & 7 deletions brainpy/_src/math/ndarray.py
@@ -660,27 +660,25 @@ def searchsorted(self, v, side='left', sorter=None):
"""
return _return(self.value.searchsorted(v=_as_jax_array_(v), side=side, sorter=sorter))

-  def sort(self, axis=-1, kind='quicksort', order=None):
+  def sort(self, axis=-1, stable=True, order=None):
"""Sort an array in-place.
Parameters
----------
axis : int, optional
Axis along which to sort. Default is -1, which means sort along the
last axis.
-    kind : {'quicksort', 'mergesort', 'heapsort', 'stable'}
-      Sorting algorithm. The default is 'quicksort'. Note that both 'stable'
-      and 'mergesort' use timsort under the covers and, in general, the
-      actual implementation will vary with datatype. The 'mergesort' option
-      is retained for backwards compatibility.
+    stable : bool, optional
+      Whether to use a stable sorting algorithm. The default is True.
order : str or list of str, optional
When `a` is an array with fields defined, this argument specifies
which fields to compare first, second, etc. A single field can
be specified as a string, and not all fields need be specified,
but unspecified fields will still be used, in the order in which
they come up in the dtype, to break ties.
"""
-    self.value = self.value.sort(axis=axis, kind=kind, order=order)
+    self.value = self.value.sort(axis=axis, stable=stable, order=order)


def squeeze(self, axis=None):
"""Remove axes of length one from ``a``."""
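
Usage of the updated in-place sort follows the new signature (the kind= keyword is gone, replaced by a stable flag, mirroring newer JAX array sorting); a short sketch:

```python
import brainpy.math as bm

a = bm.asarray([3., 1., 2.])
a.sort(axis=-1, stable=True)   # sorts a.value in place via the underlying JAX array
print(a)                       # contains [1., 2., 3.] after sorting
```
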