diff --git a/brainpy/_src/connect/random_conn.py b/brainpy/_src/connect/random_conn.py index 9438c3306..0e4ee769c 100644 --- a/brainpy/_src/connect/random_conn.py +++ b/brainpy/_src/connect/random_conn.py @@ -1,4 +1,5 @@ # -*- coding: utf-8 -*- + from functools import partial from typing import Optional @@ -9,10 +10,8 @@ from brainpy.errors import ConnectorError from brainpy.tools import numba_seed, numba_jit, numba_range, format_seed from brainpy._src.tools.package import SUPPORT_NUMBA -from brainpy._src.dependency_check import import_numba from .base import * -numba = import_numba(error_if_not_found=False) __all__ = [ 'FixedProb', @@ -1099,171 +1098,43 @@ def __init__(self, dist=1, prob=1., pre_ratio=1., seed=None, include_self=True, rng = np.random if SUPPORT_NUMBA else self.rng - # @njit(parallel=True) - # def _connect_1d_jit_parallel(pre_pos, pre_size, post_size, n_dim): - # all_post_ids = np.zeros(post_size[0], dtype=get_idx_type()) - # all_pre_ids = np.zeros(post_size[0], dtype=get_idx_type()) - # size = 0 - # - # if rng.random() < pre_ratio: - # normalized_pos = np.zeros(n_dim) - # for i in prange(n_dim): # Use prange for potential parallelism - # pre_len = pre_size[i] - # post_len = post_size[i] - # normalized_pos[i] = pre_pos[i] * post_len / pre_len - # for i in prange(post_size[0]): - # post_pos = np.asarray((i,)) - # d = np.abs(pre_pos[0] - post_pos[0]) # Adjust the distance calculation - # if d <= dist: - # if d == 0. and not include_self: - # continue - # if rng.random() <= prob: - # all_post_ids[size] = pos2ind(post_pos, post_size) - # all_pre_ids[size] = pos2ind(pre_pos, pre_size) - # size += 1 - # return all_pre_ids[:size], all_post_ids[:size] # Return filled part of the arrays - - if numba is not None: - from numba import njit - @njit - def _connect_1d_jit(pre_pos, pre_size, post_size, n_dim): - all_post_ids = np.zeros(post_size[0], dtype=IDX_DTYPE) - all_pre_ids = np.zeros(post_size[0], dtype=IDX_DTYPE) - size = 0 - - if rng.random() < pre_ratio: - normalized_pos = np.zeros(n_dim) - for i in range(n_dim): - pre_len = pre_size[i] - post_len = post_size[i] - normalized_pos[i] = pre_pos[i] * post_len / pre_len - for i in range(post_size[0]): - post_pos = np.asarray((i,)) - d = np.abs(pre_pos[0] - post_pos[0]) - if d <= dist: - if d == 0. and not include_self: - continue - if rng.random() <= prob: - all_post_ids[size] = pos2ind(post_pos, post_size) - all_pre_ids[size] = pos2ind(pre_pos, pre_size) - size += 1 - return all_pre_ids[:size], all_post_ids[:size] - - @njit - def _connect_2d_jit(pre_pos, pre_size, post_size, n_dim): - max_size = post_size[0] * post_size[1] - all_post_ids = np.zeros(max_size, dtype=IDX_DTYPE) - all_pre_ids = np.zeros(max_size, dtype=IDX_DTYPE) - size = 0 - - if rng.random() < pre_ratio: - normalized_pos = np.zeros(n_dim) - for i in range(n_dim): - pre_len = pre_size[i] - post_len = post_size[i] - normalized_pos[i] = pre_pos[i] * post_len / pre_len - for i in range(post_size[0]): - for j in range(post_size[1]): - post_pos = np.asarray((i, j)) - d = np.sqrt(np.sum(np.square(pre_pos - post_pos))) - if d <= dist: - if d == 0. and not include_self: - continue - if rng.random() <= prob: - all_post_ids[size] = pos2ind(post_pos, post_size) - all_pre_ids[size] = pos2ind(pre_pos, pre_size) - size += 1 - return all_pre_ids[:size], all_post_ids[:size] # Return filled part of the arrays - - @njit - def _connect_3d_jit(pre_pos, pre_size, post_size, n_dim): - max_size = post_size[0] * post_size[1] * post_size[2] - all_post_ids = np.zeros(max_size, dtype=IDX_DTYPE) - all_pre_ids = np.zeros(max_size, dtype=IDX_DTYPE) - size = 0 - - if rng.random() < pre_ratio: - normalized_pos = np.zeros(n_dim) - for i in range(n_dim): - pre_len = pre_size[i] - post_len = post_size[i] - normalized_pos[i] = pre_pos[i] * post_len / pre_len - for i in range(post_size[0]): - for j in range(post_size[1]): - for k in range(post_size[2]): - post_pos = np.asarray((i, j, k)) - d = np.sqrt(np.sum(np.square(pre_pos - post_pos))) - if d <= dist: - if d == 0. and not include_self: - continue - if rng.random() <= prob: - all_post_ids[size] = pos2ind(post_pos, post_size) - all_pre_ids[size] = pos2ind(pre_pos, pre_size) - size += 1 - return all_pre_ids[:size], all_post_ids[:size] - - @njit - def _connect_4d_jit(pre_pos, pre_size, post_size, n_dim): - max_size = post_size[0] * post_size[1] * post_size[2] * post_size[3] - all_post_ids = np.zeros(max_size, dtype=IDX_DTYPE) - all_pre_ids = np.zeros(max_size, dtype=IDX_DTYPE) - size = 0 - - if rng.random() < pre_ratio: - normalized_pos = np.zeros(n_dim) - for i in range(n_dim): - pre_len = pre_size[i] - post_len = post_size[i] - normalized_pos[i] = pre_pos[i] * post_len / pre_len - for i in range(post_size[0]): - for j in range(post_size[1]): - for k in range(post_size[2]): - for l in range(post_size[3]): - post_pos = np.asarray((i, j, k, l)) - d = np.sqrt(np.sum(np.square(pre_pos - post_pos))) - if d <= dist: - if d == 0. and not include_self: - continue - if rng.random() <= prob: - all_post_ids[size] = pos2ind(post_pos, post_size) - all_pre_ids[size] = pos2ind(pre_pos, pre_size) - size += 1 - return all_pre_ids[:size], all_post_ids[:size] - - self._connect_1d_jit = _connect_1d_jit - self._connect_2d_jit = _connect_2d_jit - self._connect_3d_jit = _connect_3d_jit - self._connect_4d_jit = _connect_4d_jit - - def _connect_1d(pre_pos, pre_size, post_size, n_dim): - all_post_ids = [] - all_pre_ids = [] + @numba_jit + def _connect_1d_jit(pre_pos, pre_size, post_size, n_dim): + all_post_ids = np.zeros(post_size[0], dtype=IDX_DTYPE) + all_pre_ids = np.zeros(post_size[0], dtype=IDX_DTYPE) + size = 0 + if rng.random() < pre_ratio: - normalized_pos = [] + normalized_pos = np.zeros(n_dim) for i in range(n_dim): pre_len = pre_size[i] post_len = post_size[i] - normalized_pos.append(pre_pos[i] * post_len / pre_len) + normalized_pos[i] = pre_pos[i] * post_len / pre_len for i in range(post_size[0]): post_pos = np.asarray((i,)) - d = np.sum(np.abs(pre_pos - post_pos)) + d = np.abs(pre_pos[0] - post_pos[0]) if d <= dist: if d == 0. and not include_self: continue if rng.random() <= prob: - all_post_ids.append(pos2ind(post_pos, post_size)) - all_pre_ids.append(pos2ind(pre_pos, pre_size)) - return all_pre_ids, all_post_ids + all_post_ids[size] = pos2ind(post_pos, post_size) + all_pre_ids[size] = pos2ind(pre_pos, pre_size) + size += 1 + return all_pre_ids[:size], all_post_ids[:size] + + @numba_jit + def _connect_2d_jit(pre_pos, pre_size, post_size, n_dim): + max_size = post_size[0] * post_size[1] + all_post_ids = np.zeros(max_size, dtype=IDX_DTYPE) + all_pre_ids = np.zeros(max_size, dtype=IDX_DTYPE) + size = 0 - def _connect_2d(pre_pos, pre_size, post_size, n_dim): - all_post_ids = [] - all_pre_ids = [] if rng.random() < pre_ratio: - normalized_pos = [] + normalized_pos = np.zeros(n_dim) for i in range(n_dim): pre_len = pre_size[i] post_len = post_size[i] - normalized_pos.append(pre_pos[i] * post_len / pre_len) + normalized_pos[i] = pre_pos[i] * post_len / pre_len for i in range(post_size[0]): for j in range(post_size[1]): post_pos = np.asarray((i, j)) @@ -1271,20 +1142,25 @@ def _connect_2d(pre_pos, pre_size, post_size, n_dim): if d <= dist: if d == 0. and not include_self: continue - if np.random.random() <= prob: - all_post_ids.append(pos2ind(post_pos, post_size)) - all_pre_ids.append(pos2ind(pre_pos, pre_size)) - return all_pre_ids, all_post_ids - - def _connect_3d(pre_pos, pre_size, post_size, n_dim): - all_post_ids = [] - all_pre_ids = [] + if rng.random() <= prob: + all_post_ids[size] = pos2ind(post_pos, post_size) + all_pre_ids[size] = pos2ind(pre_pos, pre_size) + size += 1 + return all_pre_ids[:size], all_post_ids[:size] # Return filled part of the arrays + + @numba_jit + def _connect_3d_jit(pre_pos, pre_size, post_size, n_dim): + max_size = post_size[0] * post_size[1] * post_size[2] + all_post_ids = np.zeros(max_size, dtype=IDX_DTYPE) + all_pre_ids = np.zeros(max_size, dtype=IDX_DTYPE) + size = 0 + if rng.random() < pre_ratio: - normalized_pos = [] + normalized_pos = np.zeros(n_dim) for i in range(n_dim): pre_len = pre_size[i] post_len = post_size[i] - normalized_pos.append(pre_pos[i] * post_len / pre_len) + normalized_pos[i] = pre_pos[i] * post_len / pre_len for i in range(post_size[0]): for j in range(post_size[1]): for k in range(post_size[2]): @@ -1293,20 +1169,25 @@ def _connect_3d(pre_pos, pre_size, post_size, n_dim): if d <= dist: if d == 0. and not include_self: continue - if np.random.random() <= prob: - all_post_ids.append(pos2ind(post_pos, post_size)) - all_pre_ids.append(pos2ind(pre_pos, pre_size)) - return all_pre_ids, all_post_ids - - def _connect_4d(pre_pos, pre_size, post_size, n_dim): - all_post_ids = [] - all_pre_ids = [] + if rng.random() <= prob: + all_post_ids[size] = pos2ind(post_pos, post_size) + all_pre_ids[size] = pos2ind(pre_pos, pre_size) + size += 1 + return all_pre_ids[:size], all_post_ids[:size] + + @numba_jit + def _connect_4d_jit(pre_pos, pre_size, post_size, n_dim): + max_size = post_size[0] * post_size[1] * post_size[2] * post_size[3] + all_post_ids = np.zeros(max_size, dtype=IDX_DTYPE) + all_pre_ids = np.zeros(max_size, dtype=IDX_DTYPE) + size = 0 + if rng.random() < pre_ratio: - normalized_pos = [] + normalized_pos = np.zeros(n_dim) for i in range(n_dim): pre_len = pre_size[i] post_len = post_size[i] - normalized_pos.append(pre_pos[i] * post_len / pre_len) + normalized_pos[i] = pre_pos[i] * post_len / pre_len for i in range(post_size[0]): for j in range(post_size[1]): for k in range(post_size[2]): @@ -1316,15 +1197,16 @@ def _connect_4d(pre_pos, pre_size, post_size, n_dim): if d <= dist: if d == 0. and not include_self: continue - if np.random.random() <= prob: - all_post_ids.append(pos2ind(post_pos, post_size)) - all_pre_ids.append(pos2ind(pre_pos, pre_size)) - return all_pre_ids, all_post_ids + if rng.random() <= prob: + all_post_ids[size] = pos2ind(post_pos, post_size) + all_pre_ids[size] = pos2ind(pre_pos, pre_size) + size += 1 + return all_pre_ids[:size], all_post_ids[:size] - self._connect_1d = numba_jit(_connect_1d) - self._connect_2d = numba_jit(_connect_2d) - self._connect_3d = numba_jit(_connect_3d) - self._connect_4d = numba_jit(_connect_4d) + self._connect_1d_jit = _connect_1d_jit + self._connect_2d_jit = _connect_2d_jit + self._connect_3d_jit = _connect_3d_jit + self._connect_4d_jit = _connect_4d_jit def build_coo(self, isOptimized=True): if len(self.pre_size) != len(self.post_size): @@ -1336,41 +1218,16 @@ def build_coo(self, isOptimized=True): # connections n_dim = len(self.pre_size) - if not isOptimized: - if n_dim == 1: - f = self._connect_1d - elif n_dim == 2: - f = self._connect_2d - elif n_dim == 3: - f = self._connect_3d - elif n_dim == 4: - f = self._connect_4d - else: - raise NotImplementedError('Does not support the network dimension bigger than 4.') + if n_dim == 1: + f = self._connect_1d_jit + elif n_dim == 2: + f = self._connect_2d_jit + elif n_dim == 3: + f = self._connect_3d_jit + elif n_dim == 4: + f = self._connect_4d_jit else: - if numba is None: - if n_dim == 1: - f = self._connect_1d - elif n_dim == 2: - f = self._connect_2d - elif n_dim == 3: - f = self._connect_3d - elif n_dim == 4: - f = self._connect_4d - else: - raise NotImplementedError('Does not support the network dimension bigger than 4.') - else: - if n_dim == 1: - f = self._connect_1d_jit - elif n_dim == 2: - f = self._connect_2d_jit - elif n_dim == 3: - f = self._connect_3d_jit - elif n_dim == 4: - f = self._connect_4d_jit - else: - raise NotImplementedError('Does not support the network dimension bigger than 4.') - + raise NotImplementedError('Does not support the network dimension bigger than 4.') pre_size = np.asarray(self.pre_size) post_size = np.asarray(self.post_size) diff --git a/brainpy/_src/dependency_check.py b/brainpy/_src/dependency_check.py index 2babb5023..2820c7081 100644 --- a/brainpy/_src/dependency_check.py +++ b/brainpy/_src/dependency_check.py @@ -1,4 +1,3 @@ -import functools import os import sys @@ -7,12 +6,8 @@ __all__ = [ 'import_taichi', 'raise_taichi_not_found', - 'check_taichi_func', - 'check_taichi_class', 'import_numba', 'raise_numba_not_found', - 'check_numba_func', - 'check_numba_class', 'import_brainpylib_cpu_ops', 'import_brainpylib_gpu_ops', ] @@ -20,8 +15,8 @@ _minimal_brainpylib_version = '0.2.6' _minimal_taichi_version = (1, 7, 0) -taichi = None numba = None +taichi = None brainpylib_cpu_ops = None brainpylib_gpu_ops = None @@ -29,8 +24,7 @@ f'Currently you can install taichi=={_minimal_taichi_version} through:\n\n' '> pip install taichi==1.7.0') numba_install_info = ('We need numba. Please install numba by pip . \n' - '> pip install numba' - ) + '> pip install numba') os.environ["TI_LOG_LEVEL"] = "error" @@ -55,30 +49,10 @@ def import_taichi(error_if_not_found=True): return taichi -def raise_taichi_not_found(): +def raise_taichi_not_found(*args, **kwargs): raise ModuleNotFoundError(taichi_install_info) -def check_taichi_func(func): - @functools.wraps(func) - def wrapper(*args, **kwargs): - if taichi is None: - raise_taichi_not_found() - return func(*args, **kwargs) - - return wrapper - - -def check_taichi_class(cls): - class Wrapper(cls): - def __init__(self, *args, **kwargs): - if taichi is None: - raise_taichi_not_found() - super().__init__(*args, **kwargs) - - return Wrapper - - def import_numba(error_if_not_found=True): global numba if numba is None: @@ -96,26 +70,6 @@ def raise_numba_not_found(): raise ModuleNotFoundError(numba_install_info) -def check_numba_func(func): - @functools.wraps(func) - def wrapper(*args, **kwargs): - if numba is None: - raise_numba_not_found() - return func(*args, **kwargs) - - return wrapper - - -def check_numba_class(cls): - class Wrapper(cls): - def __init__(self, *args, **kwargs): - if numba is None: - raise_numba_not_found() - super().__init__(*args, **kwargs) - - return Wrapper - - def is_brainpylib_gpu_installed(): return False if brainpylib_gpu_ops is None else True diff --git a/brainpy/_src/dnn/conv.py b/brainpy/_src/dnn/conv.py index e4b6e25d2..deead1f3b 100644 --- a/brainpy/_src/dnn/conv.py +++ b/brainpy/_src/dnn/conv.py @@ -160,7 +160,7 @@ def update(self, x): nonbatching = False if x.ndim == self.num_spatial_dims + 1: nonbatching = True - x = bm.unsqueeze(x, 0) + x = x.unsqueeze(0) w = self.w.value if self.mask is not None: try: @@ -190,9 +190,6 @@ def __repr__(self): class Conv1d(_GeneralConv): """One-dimensional convolution. - The input should a 2d array with the shape of ``[H, C]``, or - a 3d array with the shape of ``[B, H, C]``, where ``H`` is the feature size. - Parameters ---------- in_channels: int @@ -285,9 +282,6 @@ def _check_input_dim(self, x): class Conv2d(_GeneralConv): """Two-dimensional convolution. - The input should a 3d array with the shape of ``[H, W, C]``, or - a 4d array with the shape of ``[B, H, W, C]``. - Parameters ---------- in_channels: int @@ -381,9 +375,6 @@ def _check_input_dim(self, x): class Conv3d(_GeneralConv): """Three-dimensional convolution. - The input should a 3d array with the shape of ``[H, W, D, C]``, or - a 4d array with the shape of ``[B, H, W, D, C]``. - Parameters ---------- in_channels: int diff --git a/brainpy/_src/dnn/linear.py b/brainpy/_src/dnn/linear.py index 7a92bc8b2..c524fb0bf 100644 --- a/brainpy/_src/dnn/linear.py +++ b/brainpy/_src/dnn/linear.py @@ -11,17 +11,16 @@ from brainpy import math as bm from brainpy._src import connect, initialize as init from brainpy._src.context import share +from brainpy._src.dependency_check import import_taichi from brainpy._src.dnn.base import Layer from brainpy._src.mixin import SupportOnline, SupportOffline, SupportSTDP -from brainpy._src.dependency_check import import_numba, import_taichi, check_numba_func, check_taichi_func from brainpy.check import is_initializer from brainpy.connect import csr2csc -from brainpy.errors import MathError +from brainpy.errors import MathError, PackageMissingError from brainpy.initialize import XavierNormal, ZeroInit, Initializer, parameter from brainpy.types import ArrayType, Sharding ti = import_taichi(error_if_not_found=False) -numba = import_numba(error_if_not_found=False) __all__ = [ 'Dense', 'Linear', @@ -239,152 +238,108 @@ def update(self, x): return x -# @numba.njit(nogil=True, fastmath=True, parallel=False) -# def _cpu_dense_on_pre(weight, spike, trace, w_min, w_max, out_w): -# out_w[:] = weight -# for i in numba.prange(spike.shape[0]): -# if spike[i]: -# out_w[i] = np.clip(out_w[i] + trace, w_min, w_max) - -dense_on_pre_prim = None if ti is not None: + + # @numba.njit(nogil=True, fastmath=True, parallel=False) + # def _cpu_dense_on_post(weight, spike, trace, w_min, w_max, out_w): + # out_w[:] = weight + # for i in numba.prange(spike.shape[0]): + # if spike[i]: + # out_w[:, i] = np.clip(out_w[:, i] + trace, w_min, w_max) + @ti.kernel - def _cpu_dense_on_pre(weight: ti.types.ndarray(ndim=2), - spike: ti.types.ndarray(ndim=1), - trace: ti.types.ndarray(ndim=1), - w_min: ti.types.ndarray(ndim=1), - w_max: ti.types.ndarray(ndim=1), - out_w: ti.types.ndarray(ndim=2)): - trace0 = trace[0] + def _dense_on_post( + old_w: ti.types.ndarray(ndim=2), + post_spike: ti.types.ndarray(ndim=1), + pre_trace: ti.types.ndarray(ndim=1), + w_min: ti.types.ndarray(ndim=1), + w_max: ti.types.ndarray(ndim=1), + out_w: ti.types.ndarray(ndim=2) + ): w_min0 = w_min[0] w_max0 = w_max[0] - for i, j in ti.ndrange(out_w.shape[0], out_w.shape[1]): - out_w[i, j] = weight[i, j] - for i in range(spike.shape[0]): - if spike[i]: - for j in range(out_w.shape[1]): - new_value = out_w[i, j] + trace0 - if new_value < w_min0: - out_w[i, j] = w_min0 - elif new_value > w_max0: - out_w[i, j] = w_max0 - else: - out_w[i, j] = new_value + num_pre, num_post = out_w.shape + + for i, j in ti.ndrange(num_pre, num_post): + if post_spike[j]: + new_value = out_w[i, j] + pre_trace[i] + if new_value < w_min0: + out_w[i, j] = w_min0 + elif new_value > w_max0: + out_w[i, j] = w_max0 + else: + out_w[i, j] = new_value + else: + out_w[i, j] = old_w[i, j] + + + dense_on_post_prim = bm.XLACustomOp(cpu_kernel=_dense_on_post, gpu_kernel=_dense_on_post) + + # @numba.njit(nogil=True, fastmath=True, parallel=False) + # def _cpu_dense_on_pre(weight, spike, trace, w_min, w_max, out_w): + # out_w[:] = weight + # for i in numba.prange(spike.shape[0]): + # if spike[i]: + # out_w[i] = np.clip(out_w[i] + trace, w_min, w_max) @ti.kernel - def _gpu_dense_on_pre(weight: ti.types.ndarray(ndim=1), - spike: ti.types.ndarray(ndim=1), - trace: ti.types.ndarray(ndim=1), - w_min: ti.types.ndarray(ndim=1), - w_max: ti.types.ndarray(ndim=1), - out_w: ti.types.ndarray(ndim=1)): - trace0 = trace[0] + def _dense_on_pre( + old_w: ti.types.ndarray(ndim=2), + pre_spike: ti.types.ndarray(ndim=1), + post_trace: ti.types.ndarray(ndim=1), + w_min: ti.types.ndarray(ndim=1), + w_max: ti.types.ndarray(ndim=1), + out_w: ti.types.ndarray(ndim=2) + ): w_min0 = w_min[0] w_max0 = w_max[0] - for i, j in ti.ndrange(out_w.shape[0], out_w.shape[1]): - out_w[i, j] = weight[i, j] - for i in range(spike.shape[0]): - if spike[i]: - for j in range(out_w.shape[1]): - new_value = out_w[i, j] + trace0 - if new_value < w_min0: - out_w[i, j] = w_min0 - elif new_value > w_max0: - out_w[i, j] = w_max0 - else: - out_w[i, j] = new_value + num_pre, num_post = out_w.shape + + for i, j in ti.ndrange(num_pre, num_post): + if pre_spike[i]: + new_value = out_w[i, j] + post_trace[j] + if new_value < w_min0: + out_w[i, j] = w_min0 + elif new_value > w_max0: + out_w[i, j] = w_max0 + else: + out_w[i, j] = new_value + else: + out_w[i, j] = old_w[i, j] - dense_on_pre_prim = bm.XLACustomOp(cpu_kernel=_cpu_dense_on_pre, - gpu_kernel=_gpu_dense_on_pre) + dense_on_pre_prim = bm.XLACustomOp(cpu_kernel=_dense_on_pre, gpu_kernel=_dense_on_pre) + +else: + dense_on_pre_prim = None + dense_on_post_prim = None -@check_taichi_func def dense_on_pre(weight, spike, trace, w_min, w_max): + if dense_on_pre_prim is None: + raise PackageMissingError.by_purpose('taichi', 'custom operators') + if w_min is None: w_min = -np.inf if w_max is None: w_max = np.inf - trace = jnp.atleast_1d(trace) w_min = jnp.atleast_1d(w_min) w_max = jnp.atleast_1d(w_max) return dense_on_pre_prim(weight, spike, trace, w_min, w_max, outs=[jax.ShapeDtypeStruct(weight.shape, weight.dtype)])[0] -# @numba.njit(nogil=True, fastmath=True, parallel=False) -# def _cpu_dense_on_post(weight, spike, trace, w_min, w_max, out_w): -# out_w[:] = weight -# for i in numba.prange(spike.shape[0]): -# if spike[i]: -# out_w[:, i] = np.clip(out_w[:, i] + trace, w_min, w_max) - -dense_on_post_prim = None -if ti is not None: - @ti.kernel - def _cpu_dense_on_post(weight: ti.types.ndarray(ndim=2), - spike: ti.types.ndarray(ndim=1), - trace: ti.types.ndarray(ndim=1), - w_min: ti.types.ndarray(ndim=1), - w_max: ti.types.ndarray(ndim=1), - out_w: ti.types.ndarray(ndim=2)): - trace0 = trace[0] - w_min0 = w_min[0] - w_max0 = w_max[0] - for i, j in ti.ndrange(out_w.shape[0], out_w.shape[1]): - out_w[i, j] = weight[i, j] - for i in range(spike.shape[0]): - if spike[i]: - for j in range(out_w.shape[0]): - new_value = out_w[j, i] + trace0 - if new_value < w_min0: - out_w[j, i] = w_min0 - elif new_value > w_max0: - out_w[j, i] = w_max0 - else: - out_w[j, i] = new_value - - - @ti.kernel - def _gpu_dense_on_post(weight: ti.types.ndarray(ndim=2), - spike: ti.types.ndarray(ndim=1), - trace: ti.types.ndarray(ndim=1), - w_min: ti.types.ndarray(ndim=1), - w_max: ti.types.ndarray(ndim=1), - out_w: ti.types.ndarray(ndim=2)): - trace0 = trace[0] - w_min0 = w_min[0] - w_max0 = w_max[0] - for i, j in ti.ndrange(out_w.shape[0], out_w.shape[1]): - out_w[i, j] = weight[i, j] - for i in range(spike.shape[0]): - if spike[i]: - for j in range(out_w.shape[0]): - new_value = out_w[j, i] + trace0 - if new_value < w_min0: - out_w[j, i] = w_min0 - elif new_value > w_max0: - out_w[j, i] = w_max0 - else: - out_w[j, i] = new_value - - - dense_on_post_prim = bm.XLACustomOp(cpu_kernel=_cpu_dense_on_post, - gpu_kernel=_gpu_dense_on_post) - - -@check_taichi_func def dense_on_post(weight, spike, trace, w_min, w_max): + if dense_on_post_prim is None: + raise PackageMissingError.by_purpose('taichi', 'custom operators') + if w_min is None: w_min = -np.inf if w_max is None: w_max = np.inf - trace = jnp.atleast_1d(trace) w_min = jnp.atleast_1d(w_min) w_max = jnp.atleast_1d(w_max) - if dense_on_post_prim is None: - import_taichi() return dense_on_post_prim(weight, spike, trace, w_min, w_max, outs=[jax.ShapeDtypeStruct(weight.shape, weight.dtype)])[0] @@ -756,107 +711,168 @@ def _batch_csrmv(self, x): transpose=self.transpose) -# @numba.njit(nogil=True, fastmath=True, parallel=False) -# def _cpu_csr_on_pre_update(w, indices, indptr, spike, trace, w_min, w_max, out_w): -# out_w[:] = w -# w_min = w_min[()] -# w_max = w_max[()] -# for i in numba.prange(spike.shape[0]): # pre id -# if spike[i]: -# for k in range(indptr[i], indptr[i + 1]): # synapse id -# j = indices[k] # post id -# # out_w[k] = np.clip(out_w[k] + trace[j], w_min, w_max) -# out_w[k] = np.minimum(np.maximum(out_w[k] + trace[j], w_min), w_max) - -csr_on_pre_update_prim = None if ti is not None: @ti.kernel - def _cpu_csr_on_pre_update(w: ti.types.ndarray(ndim=1), - indices: ti.types.ndarray(ndim=1), - indptr: ti.types.ndarray(ndim=1), - spike: ti.types.ndarray(ndim=1), - trace: ti.types.ndarray(ndim=1), - w_min: ti.types.ndarray(ndim=1), - w_max: ti.types.ndarray(ndim=1), - out_w: ti.types.ndarray(ndim=1)): - trace0 = trace[0] + def _csr_on_pre_update( + old_w: ti.types.ndarray(ndim=1), # vector with shape of (num_syn) + indices: ti.types.ndarray(ndim=1), # vector with shape of (num_syn) + indptr: ti.types.ndarray(ndim=1), # vector with shape of (num_pre + 1) + spike: ti.types.ndarray(ndim=1), # vector with shape of (num_pre,) + trace: ti.types.ndarray(ndim=1), # vector with shape of (num_post,) + w_min: ti.types.ndarray(ndim=1), # scalar + w_max: ti.types.ndarray(ndim=1), # scalar + out_w: ti.types.ndarray(ndim=1) # vector with shape of (num_syn) + ): + w_min0 = w_min[0] + w_max0 = w_max[0] + num_pre = spike.shape[0] + for i_pre in range(num_pre): + if spike[i_pre]: + for i_syn in range(indptr[i_pre], indptr[i_pre + 1]): + out_w[i_syn] = min(max(old_w[i_syn] + trace[indices[i_syn]], w_min0), w_max0) + else: + for i_syn in range(indptr[i_pre], indptr[i_pre + 1]): + out_w[i_syn] = old_w[i_syn] + + + csr_on_pre_update_prim = bm.XLACustomOp(cpu_kernel=_csr_on_pre_update, gpu_kernel=_csr_on_pre_update) + + + @ti.kernel + def _coo_on_pre_update( + old_w: ti.types.ndarray(ndim=1), # vector with shape of (num_syn) + pre_ids: ti.types.ndarray(ndim=1), # vector with shape of (num_syn) + post_ids: ti.types.ndarray(ndim=1), # vector with shape of (num_syn) + pre_spike: ti.types.ndarray(ndim=1), # vector with shape of (num_pre,) + post_trace: ti.types.ndarray(ndim=1), # vector with shape of (num_post,) + w_min: ti.types.ndarray(ndim=1), # scalar + w_max: ti.types.ndarray(ndim=1), # scalar + out_w: ti.types.ndarray(ndim=1) # vector with shape of (num_syn) + ): + w_min0 = w_min[0] + w_max0 = w_max[0] + num_syn = old_w.shape[0] + for i_syn in range(num_syn): + if pre_spike[pre_ids[i_syn]]: # pre spike + out_w[i_syn] = min(max(old_w[i_syn] + post_trace[post_ids[i_syn]], w_min0), w_max0) + else: + out_w[i_syn] = old_w[i_syn] + + + coo_on_pre_update_prim = bm.XLACustomOp(cpu_kernel=_coo_on_pre_update, gpu_kernel=_coo_on_pre_update) + + + @ti.kernel + def _coo_on_post_update( + old_w: ti.types.ndarray(ndim=1), # vector with shape of (num_syn) + pre_ids: ti.types.ndarray(ndim=1), # vector with shape of (num_syn) + post_ids: ti.types.ndarray(ndim=1), # vector with shape of (num_syn) + post_spike: ti.types.ndarray(ndim=1), # vector with shape of (num_pre,) + pre_trace: ti.types.ndarray(ndim=1), # vector with shape of (num_post,) + w_min: ti.types.ndarray(ndim=1), # scalar + w_max: ti.types.ndarray(ndim=1), # scalar + out_w: ti.types.ndarray(ndim=1) # vector with shape of (num_syn) + ): w_min0 = w_min[0] w_max0 = w_max[0] - for i in range(out_w.shape[0]): - out_w[i] = w[i] - for i in range(spike.shape[0]): - if spike[i]: - for k in range(indptr[i], indptr[i + 1]): - j = indices[k] - out_w[k] = min(max(out_w[k] + trace[j], w_min0), w_max0) + num_syn = old_w.shape[0] + for i_syn in range(num_syn): + if post_spike[post_ids[i_syn]]: # pre spike + out_w[i_syn] = min(max(old_w[i_syn] + pre_trace[pre_ids[i_syn]], w_min0), w_max0) + else: + out_w[i_syn] = old_w[i_syn] + + + coo_on_post_update_prim = bm.XLACustomOp(cpu_kernel=_coo_on_post_update, gpu_kernel=_coo_on_post_update) + # @numba.njit(nogil=True, fastmath=True, parallel=False) + # def _cpu_csc_on_pre_update(w, post_ids, indptr, w_ids, spike, trace, w_min, w_max, out_w): + # out_w[:] = w + # w_min = w_min[()] + # w_max = w_max[()] + # for i in numba.prange(spike.shape[0]): # post id + # if spike[i]: + # for k in range(indptr[i], indptr[i + 1]): + # j = post_ids[k] # pre id + # l = w_ids[k] # syn id + # out_w[l] = np.minimum(np.maximum(out_w[l] + trace[j], w_min), w_max) + @ti.kernel - def _gpu_csr_on_pre_update(w: ti.types.ndarray(ndim=1), - indices: ti.types.ndarray(ndim=1), - indptr: ti.types.ndarray(ndim=1), - spike: ti.types.ndarray(ndim=1), - trace: ti.types.ndarray(ndim=1), - w_min: ti.types.ndarray(ndim=1), - w_max: ti.types.ndarray(ndim=1), - out_w: ti.types.ndarray(ndim=1)): - trace0 = trace[0] + def _csc_on_post_update( + old_w: ti.types.ndarray(ndim=1), # vector with shape of (num_syn) + indices: ti.types.ndarray(ndim=1), # vector with shape of (num_syn) + indptr: ti.types.ndarray(ndim=1), # vector with shape of (num_post + 1) + w_ids: ti.types.ndarray(ndim=1), # vector with shape of (num_syn) + post_spike: ti.types.ndarray(ndim=1), # vector with shape of (num_post,) + pre_trace: ti.types.ndarray(ndim=1), # vector with shape of (num_pre,) + w_min: ti.types.ndarray(ndim=1), # scalar + w_max: ti.types.ndarray(ndim=1), # scalar + out_w: ti.types.ndarray(ndim=1), # vector with shape of (num_syn) + ): w_min0 = w_min[0] w_max0 = w_max[0] - for i in range(out_w.shape[0]): - out_w[i] = w[i] - for i in range(spike.shape[0]): - if spike[i]: - for k in range(indptr[i], indptr[i + 1]): - j = indices[k] - out_w[k] = min(max(out_w[k] + trace[j], w_min0), w_max0) + num_post = post_spike.shape[0] + for i_post in range(num_post): + if post_spike[i_post]: + for k in range(indptr[i_post], indptr[i_post + 1]): + i_syn = w_ids[k] # syn id + out_w[i_syn] = min(max(old_w[i_syn] + pre_trace[indices[k]], w_min0), w_max0) + else: + for k in range(indptr[i_post], indptr[i_post + 1]): + i_syn = w_ids[k] # syn id + out_w[i_syn] = old_w[i_syn] + + + csc_on_post_update_prim = bm.XLACustomOp(cpu_kernel=_csc_on_post_update, gpu_kernel=_csc_on_post_update) - csr_on_pre_update_prim = bm.XLACustomOp(cpu_kernel=_cpu_csr_on_pre_update, - gpu_kernel=_gpu_csr_on_pre_update) +else: + csr_on_pre_update_prim = None + coo_on_pre_update_prim = None + csc_on_post_update_prim = None -@check_taichi_func def csr_on_pre_update(w, indices, indptr, spike, trace, w_min=None, w_max=None): + if csr_on_pre_update_prim is None: + raise PackageMissingError.by_purpose('taichi', 'customized operators') + if w_min is None: w_min = -np.inf if w_max is None: w_max = np.inf - trace = jnp.atleast_1d(trace) w_min = jnp.atleast_1d(w_min) w_max = jnp.atleast_1d(w_max) - if csr_on_pre_update_prim is None: - import_taichi() return csr_on_pre_update_prim(w, indices, indptr, spike, trace, w_min, w_max, outs=[jax.ShapeDtypeStruct(w.shape, w.dtype)])[0] -csc_on_pre_update_prim = None -if numba is not None: - @numba.njit(nogil=True, fastmath=True, parallel=False) - def _cpu_csc_on_pre_update(w, post_ids, indptr, w_ids, spike, trace, w_min, w_max, out_w): - out_w[:] = w - w_min = w_min[()] - w_max = w_max[()] - for i in numba.prange(spike.shape[0]): # post id - if spike[i]: - for k in range(indptr[i], indptr[i + 1]): - j = post_ids[k] # pre id - l = w_ids[k] # syn id - out_w[l] = np.minimum(np.maximum(out_w[l] + trace[j], w_min), w_max) +def coo_on_pre_update(w, pre_ids, post_ids, spike, trace, w_min=None, w_max=None): + if coo_on_pre_update_prim is None: + raise PackageMissingError.by_purpose('taichi', 'customized operators') + if w_min is None: + w_min = -np.inf + if w_max is None: + w_max = np.inf + w_min = jnp.atleast_1d(w_min) + w_max = jnp.atleast_1d(w_max) + return coo_on_pre_update_prim(w, pre_ids, post_ids, spike, trace, w_min, w_max, + outs=[jax.ShapeDtypeStruct(w.shape, w.dtype)])[0] - csc_on_pre_update_prim = bm.XLACustomOp(_cpu_csc_on_pre_update) +def csc_on_post_update(w, post_ids, indptr, w_ids, post_spike, pre_trace, w_min=None, w_max=None): + if csc_on_post_update_prim is None: + raise PackageMissingError.by_purpose('taichi', 'customized operators') -@check_numba_func -def csc_on_post_update(w, post_ids, indptr, w_ids, spike, trace, w_min=None, w_max=None): if w_min is None: w_min = -np.inf if w_max is None: w_max = np.inf - return csc_on_pre_update_prim(w, post_ids, indptr, w_ids, spike, trace, w_min, w_max, - outs=[jax.ShapeDtypeStruct(w.shape, w.dtype)])[0] + w_min = jnp.atleast_1d(w_min) + w_max = jnp.atleast_1d(w_max) + return csc_on_post_update_prim(w, post_ids, indptr, w_ids, post_spike, pre_trace, w_min, w_max, + outs=[jax.ShapeDtypeStruct(w.shape, w.dtype)])[0] class CSCLinear(Layer): diff --git a/brainpy/_src/dnn/tests/test_activation.py b/brainpy/_src/dnn/tests/test_activation.py index 17054667d..ba2a49efd 100644 --- a/brainpy/_src/dnn/tests/test_activation.py +++ b/brainpy/_src/dnn/tests/test_activation.py @@ -1,6 +1,5 @@ -from absl.testing import absltest from absl.testing import parameterized - +from absl.testing import absltest import brainpy as bp import brainpy.math as bm diff --git a/brainpy/_src/dnn/tests/test_conv_layers.py b/brainpy/_src/dnn/tests/test_conv_layers.py index 05f523622..3c9fdfa87 100644 --- a/brainpy/_src/dnn/tests/test_conv_layers.py +++ b/brainpy/_src/dnn/tests/test_conv_layers.py @@ -1,15 +1,17 @@ # -*- coding: utf-8 -*- -import jax.numpy as jnp +from unittest import TestCase from absl.testing import absltest +import jax.numpy as jnp +import brainpy.math as bm from absl.testing import parameterized - import brainpy as bp import brainpy.math as bm class TestConv(parameterized.TestCase): def test_Conv2D_img(self): + bm.random.seed() img = jnp.zeros((2, 200, 198, 4)) for k in range(4): x = 30 + 60 * k @@ -22,7 +24,6 @@ def test_Conv2D_img(self): strides=(2, 1), padding='VALID', groups=4) out = net(img) print("out shape: ", out.shape) - self.assertEqual(out.shape, (2, 99, 196, 32)) # print("First output channel:") # plt.figure(figsize=(10, 10)) # plt.imshow(np.array(img)[0, :, :, 0]) @@ -30,6 +31,7 @@ def test_Conv2D_img(self): bm.clear_buffer_memory() def test_conv1D(self): + bm.random.seed() with bp.math.training_environment(): model = bp.layers.Conv1d(in_channels=3, out_channels=32, kernel_size=(3,)) @@ -37,7 +39,6 @@ def test_conv1D(self): out = model(input) print("out shape: ", out.shape) - self.assertEqual(out.shape, (2, 5, 32)) # print("First output channel:") # plt.figure(figsize=(10, 10)) # plt.imshow(np.array(out)[0, :, :]) @@ -53,7 +54,6 @@ def test_conv2D(self): out = model(input) print("out shape: ", out.shape) - self.assertEqual(out.shape, (2, 5, 5, 32)) # print("First output channel:") # plt.figure(figsize=(10, 10)) # plt.imshow(np.array(out)[0, :, :, 31]) @@ -67,7 +67,6 @@ def test_conv3D(self): input = bp.math.ones((2, 5, 5, 5, 3)) out = model(input) print("out shape: ", out.shape) - self.assertEqual(out.shape, (2, 5, 5, 5, 32)) bm.clear_buffer_memory() diff --git a/brainpy/_src/dnn/tests/test_function.py b/brainpy/_src/dnn/tests/test_function.py index 9ad15938d..269fec441 100644 --- a/brainpy/_src/dnn/tests/test_function.py +++ b/brainpy/_src/dnn/tests/test_function.py @@ -1,10 +1,12 @@ # -*- coding: utf-8 -*- +from unittest import TestCase + +import jax.numpy as jnp +import brainpy.math as bm from absl.testing import absltest from absl.testing import parameterized - import brainpy as bp -import brainpy.math as bm class TestFunction(parameterized.TestCase): diff --git a/brainpy/_src/dnn/tests/test_normalization.py b/brainpy/_src/dnn/tests/test_normalization.py index de2c9765b..fdc5b34e3 100644 --- a/brainpy/_src/dnn/tests/test_normalization.py +++ b/brainpy/_src/dnn/tests/test_normalization.py @@ -1,8 +1,7 @@ -from absl.testing import absltest +import brainpy.math as bm from absl.testing import parameterized - +from absl.testing import absltest import brainpy as bp -import brainpy.math as bm class Test_Normalization(parameterized.TestCase): diff --git a/brainpy/_src/dnn/tests/test_pooling_layers.py b/brainpy/_src/dnn/tests/test_pooling_layers.py index 5748edd8b..34f8f5cd5 100644 --- a/brainpy/_src/dnn/tests/test_pooling_layers.py +++ b/brainpy/_src/dnn/tests/test_pooling_layers.py @@ -3,8 +3,8 @@ import jax import jax.numpy as jnp import numpy as np -from absl.testing import absltest from absl.testing import parameterized +from absl.testing import absltest import brainpy as bp import brainpy.math as bm diff --git a/brainpy/_src/dyn/projections/tests/test_STDP.py b/brainpy/_src/dyn/projections/tests/test_STDP.py index 7ffc4e763..18d9d9dc9 100644 --- a/brainpy/_src/dyn/projections/tests/test_STDP.py +++ b/brainpy/_src/dyn/projections/tests/test_STDP.py @@ -1,12 +1,11 @@ # -*- coding: utf-8 -*- -import pytest import numpy as np +import pytest from absl.testing import parameterized import brainpy as bp import brainpy.math as bm - from brainpy._src.dependency_check import import_taichi if import_taichi(error_if_not_found=False) is None: @@ -18,7 +17,7 @@ class Test_STDP(parameterized.TestCase): @parameterized.product( - comm_method=['dense', 'csr', 'masked_linear', 'all2all', 'one2one'], + comm_method=['csr', 'dense', 'masked_linear', 'all2all', 'one2one'], delay=[None, 0., 2.], syn_model=['exp', 'dual_exp', 'ampa'], out_model=['cuba', 'coba', 'mg'] @@ -102,9 +101,11 @@ def update(self, I_pre, I_post): duration = 300. I_pre = bp.inputs.section_input([0, 30, 0, 30, 0, 30, 0, 30, 0, 30, 0, 30, 0], - [5, 15, 15, 15, 15, 15, 100, 15, 15, 15, 15, 15, duration - 255]) + [5, 15, 15, 15, 15, 15, 100, 15, 15, 15, 15, 15, + duration - 255]) I_post = bp.inputs.section_input([0, 30, 0, 30, 0, 30, 0, 30, 0, 30, 0, 30, 0], - [10, 15, 15, 15, 15, 15, 90, 15, 15, 15, 15, 15, duration - 250]) + [10, 15, 15, 15, 15, 15, 90, 15, 15, 15, 15, 15, + duration - 250]) net = STDPNet(1, 1) diff --git a/brainpy/_src/math/delayvars.py b/brainpy/_src/math/delayvars.py index 676e4286b..eb8e27c8f 100644 --- a/brainpy/_src/math/delayvars.py +++ b/brainpy/_src/math/delayvars.py @@ -11,7 +11,7 @@ from brainpy import check from brainpy.check import is_float, is_integer, jit_error from brainpy.errors import UnsupportedError -from .compat_numpy import broadcast_to, expand_dims, concatenate +from .compat_numpy import vstack, broadcast_to from .environment import get_dt, get_float from .interoperability import as_jax from .ndarray import ndarray, Array @@ -392,7 +392,6 @@ def reset( dtype=delay_target.dtype), batch_axis=batch_axis) else: - self.data.value self.data._value = jnp.zeros((self.num_delay_step,) + delay_target.shape, dtype=delay_target.dtype) @@ -473,7 +472,7 @@ def update(self, value: Union[numbers.Number, Array, jax.Array] = None): elif self.update_method == CONCAT_UPDATE: if self.num_delay_step >= 2: - self.data.value = concatenate([expand_dims(value, 0), self.data[:-1]], axis=0) + self.data.value = vstack([broadcast_to(value, self.data.shape[1:]), self.data[1:]]) else: self.data[:] = value diff --git a/brainpy/_src/math/event/_csr_matvec.py b/brainpy/_src/math/event/_csr_matvec.py index bac809388..6b7f7da02 100644 --- a/brainpy/_src/math/event/_csr_matvec.py +++ b/brainpy/_src/math/event/_csr_matvec.py @@ -17,7 +17,7 @@ import numpy as np from jax.interpreters import ad -from brainpy._src.dependency_check import import_taichi, check_taichi_func +from brainpy._src.dependency_check import import_taichi from brainpy._src.math.interoperability import as_jax from brainpy._src.math.op_register import XLACustomOp from brainpy._src.math.sparse._csr_mv import raw_csrmv_taichi as normal_csrmv_taichi @@ -30,7 +30,6 @@ ti = import_taichi(error_if_not_found=False) -@check_taichi_func def csrmv( data: Union[float, jax.Array], indices: jax.Array, @@ -45,48 +44,6 @@ def csrmv( This function supports JAX transformations, including `jit()`, `grad()`, `vmap()` and `pmap()`. - Parameters - ---------- - data: ndarray, float - An array of shape ``(nse,)``. - indices: ndarray - An array of shape ``(nse,)``. - indptr: ndarray - An array of shape ``(shape[0] + 1,)`` and dtype ``indices.dtype``. - events: ndarray - An array of shape ``(shape[0] if transpose else shape[1],)`` - and dtype ``data.dtype``. - shape: tuple - A length-2 tuple representing the matrix shape. - transpose: bool - A boolean specifying whether to transpose the sparse matrix - before computing. - If ``transpose=True``, the operator will compute based on the - event-driven property of the ``events`` vector. - - Returns - ------- - y : Array - The array of shape ``(shape[1] if transpose else shape[0],)`` representing - the matrix vector product. - """ - return csrmv_taichi(data, indices, indptr, events, shape=shape, transpose=transpose) - - -def csrmv_taichi( - data: Union[float, jax.Array], - indices: jax.Array, - indptr: jax.Array, - events: jax.Array, - *, - shape: Tuple[int, int], - transpose: bool = False -) -> jax.Array: - """Product of a sparse CSR matrix and a dense event vector. - - This function supports JAX transformations, including `jit()`, `grad()`, - `vmap()` and `pmap()`. - Parameters ---------- data: ndarray, float @@ -164,7 +121,7 @@ def raw_csrmv_taichi( transpose: bool = False ): if ti is None: - raise PackageMissingError(name='taichi==1.7.0', purpose='customized operators') + raise PackageMissingError.by_purpose(name='taichi==1.7.0', purpose='customized operators') if transpose: if events.dtype == jnp.bool_: diff --git a/brainpy/_src/math/jitconn/_event_matvec.py b/brainpy/_src/math/jitconn/_event_matvec.py index b2de30b01..976b72b96 100644 --- a/brainpy/_src/math/jitconn/_event_matvec.py +++ b/brainpy/_src/math/jitconn/_event_matvec.py @@ -6,10 +6,11 @@ import numpy as np from jax import numpy as jnp -from brainpy._src.dependency_check import import_taichi, check_taichi_func +from brainpy._src.dependency_check import import_taichi from brainpy._src.math.interoperability import as_jax from brainpy._src.math.jitconn._matvec import (mv_prob_homo, mv_prob_uniform, + mv_prob_normal, _general_checking, raw_mv_prob_homo, raw_mv_prob_uniform, @@ -30,64 +31,8 @@ 'event_mv_prob_normal', ] -@check_taichi_func -def event_mv_prob_homo( - events: jax.Array, - weight: float, - conn_prob: float, - seed: Optional[int] = None, - *, - shape: Tuple[int, int], - transpose: bool = False, - outdim_parallel: bool = True, -) -> jax.Array: - return event_mv_prob_homo_taichi(events, weight, conn_prob, seed, - shape=shape, - transpose=transpose, - outdim_parallel=outdim_parallel) - - -event_mv_prob_homo.__doc__ = mv_prob_homo.__doc__ - -@check_taichi_func -def event_mv_prob_uniform( - events: jax.Array, - w_low: float, - w_high: float, - conn_prob: float, - seed: Optional[int] = None, - *, - shape: Tuple[int, int], - transpose: bool = False, - outdim_parallel: bool = True, -) -> jax.Array: - return event_mv_prob_uniform_taichi(events, w_low, w_high, conn_prob, seed, - shape=shape, - transpose=transpose, - outdim_parallel=outdim_parallel) - -event_mv_prob_uniform.__doc__ = mv_prob_uniform.__doc__ - -@check_taichi_func -def event_mv_prob_normal( - events: jax.Array, - w_mu: float, - w_sigma: float, - conn_prob: float, - seed: Optional[int] = None, - *, - shape: Tuple[int, int], - transpose: bool = False, - outdim_parallel: bool = True, -) -> jax.Array: - return event_mv_prob_uniform_taichi(events, w_mu, w_sigma, conn_prob, seed, - shape=shape, - transpose=transpose, - outdim_parallel=outdim_parallel) - - -def event_mv_prob_homo_taichi( +def event_mv_prob_homo( events: jax.Array, weight: float, conn_prob: float, @@ -97,56 +42,8 @@ def event_mv_prob_homo_taichi( transpose: bool = False, outdim_parallel: bool = True, ) -> jax.Array: - r"""Perform the :math:`y=M@v` operation, - where :math:`M` is just-in-time randomly generated with a scalar `weight` at each position. - - This operator support ``jit()``, ``vmap()``, ``grad()`` and ``pmap()`` etc. transformations - on CPU and GPU devices. - - .. warning:: - - This API may change in the future. - - In this operation, :math:`M` is the random matrix with a connection probability - `conn_prob`, and at each connection the value is the same scalar `weight`. - - When ``transpose=True``, we perform an operation of :math:`y=M^T@v`. - - .. note:: - - Note that the just-in-time generated :math:`M` (`transpose=False`) is - different from the generated :math:`M^T` (`transpose=True`). - - If you pursue the same :math:`M` and :math:`M^T` when performing the just-in-time - matrix generation, you should set ``outdim_parallel=True``, with the sacrifice of - the speed compared with ``outdim_parallel=False``. - - Parameters - ---------- - events: Array, ndarray - The events. - weight: float - The value of the random matrix. - conn_prob: float - The connection probability. - shape: tuple of int - The matrix shape. - seed: int - The random number generation seed. - transpose: bool - Transpose the random matrix or not. - outdim_parallel: bool - Perform the parallel random generations along the out dimension or not. - It can be used to set the just-in-time generated :math:M^T: is the same - as the just-in-time generated :math:`M` when ``transpose=True``. - - Returns - ------- - out: Array, ndarray - The output of :math:`y = M @ v`. - """ if ti is None: - raise PackageMissingError(name='taichi==1.7.0', purpose='customized operators') + raise PackageMissingError.by_purpose('taichi', purpose='customized operators') events = as_jax(events) if isinstance(weight, float): weight = as_jax(weight) @@ -163,7 +60,10 @@ def event_mv_prob_homo_taichi( outdim_parallel=outdim_parallel)[0] -def event_mv_prob_uniform_taichi( +event_mv_prob_homo.__doc__ = mv_prob_homo.__doc__ + + +def event_mv_prob_uniform( events: jax.Array, w_low: float, w_high: float, @@ -174,58 +74,8 @@ def event_mv_prob_uniform_taichi( transpose: bool = False, outdim_parallel: bool = True, ) -> jax.Array: - r"""Perform the :math:`y=M@v` operation, - where :math:`M` is just-in-time randomly generated with a uniform distribution for its value. - - This operator support ``jit()``, ``vmap()``, ``grad()`` and ``pmap()`` etc. transformations - on CPU and GPU devices. - - .. warning:: - - This API may change in the future. - - In this operation, :math:`M` is the random matrix with a connection probability - `conn_prob`, and at each connection the value is the same scalar `weight`. - - When ``transpose=True``, we perform an operation of :math:`y=M^T@v`. - - .. note:: - - Note that the just-in-time generated :math:`M` (`transpose=False`) is - different from the generated :math:`M^T` (`transpose=True`). - - If you pursue the same :math:`M` and :math:`M^T` when performing the just-in-time - matrix generation, you should set ``outdim_parallel=True``, with the sacrifice of - the speed compared with ``outdim_parallel=False``. - - Parameters - ---------- - events: Array, ndarray - The events. - w_low: float - Lower boundary of the output interval. - w_high: float - Upper boundary of the output interval. - conn_prob: float - The connection probability. - shape: tuple of int - The matrix shape. - seed: int - The random number generation seed. - transpose: bool - Transpose the random matrix or not. - outdim_parallel: bool - Perform the parallel random generations along the out dimension or not. - It can be used to set the just-in-time generated :math:M^T: is the same - as the just-in-time generated :math:`M` when ``transpose=True``. - - Returns - ------- - out: Array, ndarray - The output of :math:`y = M @ v`. - """ if ti is None: - raise PackageMissingError(name='taichi==1.7.0', purpose='customized operators') + raise PackageMissingError.by_purpose('taichi', purpose='customized operators') events = as_jax(events) if isinstance(w_low, float): w_low = as_jax(w_low) @@ -242,7 +92,10 @@ def event_mv_prob_uniform_taichi( transpose=transpose, outdim_parallel=outdim_parallel)[0] -def event_mv_prob_normal_taichi( +event_mv_prob_uniform.__doc__ = mv_prob_uniform.__doc__ + + +def event_mv_prob_normal( events: jax.Array, w_mu: float, w_sigma: float, @@ -253,58 +106,8 @@ def event_mv_prob_normal_taichi( transpose: bool = False, outdim_parallel: bool = True, ) -> jax.Array: - r"""Perform the :math:`y=M@v` operation, - where :math:`M` is just-in-time randomly generated with a normal distribution for its value. - - This operator support ``jit()``, ``vmap()``, ``grad()`` and ``pmap()`` etc. transformations - on CPU and GPU devices. - - .. warning:: - - This API may change in the future. - - In this operation, :math:`M` is the random matrix with a connection probability - `conn_prob`, and at each connection the value is the same scalar `weight`. - - When ``transpose=True``, we perform an operation of :math:`y=M^T@v`. - - .. note:: - - Note that the just-in-time generated :math:`M` (`transpose=False`) is - different from the generated :math:`M^T` (`transpose=True`). - - If you pursue the same :math:`M` and :math:`M^T` when performing the just-in-time - matrix generation, you should set ``outdim_parallel=True``, with the sacrifice of - the speed compared with ``outdim_parallel=False``. - - Parameters - ---------- - events: Array, ndarray - The events. - w_mu: float - Mean (centre) of the distribution. - w_sigma: float - Standard deviation (spread or “width”) of the distribution. Must be non-negative. - conn_prob: float - The connection probability. - shape: tuple of int - The matrix shape. - seed: int - The random number generation seed. - transpose: bool - Transpose the random matrix or not. - outdim_parallel: bool - Perform the parallel random generations along the out dimension or not. - It can be used to set the just-in-time generated :math:M^T: is the same - as the just-in-time generated :math:`M` when ``transpose=True``. - - Returns - ------- - out: Array, ndarray - The output of :math:`y = M @ v`. - """ if ti is None: - raise PackageMissingError(name='taichi==1.7.0', purpose='customized operators') + raise PackageMissingError.by_purpose('taichi', purpose='customized operators') events = as_jax(events) if isinstance(w_mu, float): w_mu = as_jax(w_mu) @@ -321,9 +124,12 @@ def event_mv_prob_normal_taichi( transpose=transpose, outdim_parallel=outdim_parallel)[0] +event_mv_prob_normal.__doc__ = mv_prob_normal.__doc__ + if ti is not None: from brainpy._src.math.tifunc import (lfsr88_key, lfsr88_random_integers, lfsr88_uniform, lfsr88_normal) + # ------------- # CPU function # ------------- diff --git a/brainpy/_src/math/jitconn/_matvec.py b/brainpy/_src/math/jitconn/_matvec.py index 894294c79..00e5778f9 100644 --- a/brainpy/_src/math/jitconn/_matvec.py +++ b/brainpy/_src/math/jitconn/_matvec.py @@ -8,7 +8,7 @@ from jax import numpy as jnp from jax.interpreters import ad -from brainpy._src.dependency_check import import_taichi, check_taichi_func +from brainpy._src.dependency_check import import_taichi from brainpy._src.math.interoperability import as_jax from brainpy._src.math.ndarray import Array, _get_dtype from brainpy._src.math.op_register import XLACustomOp @@ -23,48 +23,6 @@ ] -def _general_checking(vector, clen, seed, shape, outdim_parallel, transpose, *weights): - if vector.ndim != 1: - raise ValueError('vector should be a 1D vector.') - if len(shape) != 2: - raise ValueError('shape should be a length-2 tuple.') - if seed.ndim != 1: - raise ValueError('seed must be a 1D scalar.') - if clen.ndim != 1: - raise ValueError('conn_prob must be a 1D scalar.') - - assert _get_dtype(clen) in [jnp.int16, jnp.int32, jnp.int64, jnp.uint16, jnp.uint32, jnp.uint64] - assert _get_dtype(seed) in [jnp.int16, jnp.int32, jnp.int64, jnp.uint16, jnp.uint32, jnp.uint64] - - for weight in weights: - if weight.ndim != 1: - raise ValueError('weight must be a 1D scalar.') - assert _get_dtype(weight) in [jnp.float16, jnp.float32, jnp.float64], '"weight" must be float valued.' - - if not isinstance(outdim_parallel, bool): - raise ValueError('outdim_parallel must be boolean value.') - if not isinstance(transpose, bool): - raise ValueError('transpose must be boolean value.') - - if transpose: - out_shape = (shape[1],) - if vector.shape[0] != shape[0]: - raise ValueError(f'Shape mismatch, vec {vector.shape} @ mat {shape}.') - shape = _reverse(shape) - else: - if vector.shape[0] != shape[1]: - raise ValueError(f'Shape mismatch, mat {shape} @ vec ({vector.shape[0]},).') - out_shape = (shape[0],) - - return shape, out_shape - - -def _non_event_checking(vector, clen, seed, shape, outdim_parallel, transpose, *weights): - assert _get_dtype(vector) in [jnp.float16, jnp.float32, jnp.float64] - return _general_checking(vector, clen, seed, shape, outdim_parallel, transpose, *weights) - - -@check_taichi_func def mv_prob_homo( vector: Union[Array, jax.Array], weight: float, @@ -123,11 +81,24 @@ def mv_prob_homo( out: Array, ndarray The output of :math:`y = M @ v`. """ - return mv_prob_homo_taichi(vector, weight, conn_prob, seed, shape=shape, transpose=transpose, - outdim_parallel=outdim_parallel) + if ti is None: + raise PackageMissingError.by_purpose('taichi', purpose='customized operators') + + vector = as_jax(vector) + if isinstance(weight, float): + weight = as_jax(weight, dtype=vector.dtype) + weight = jnp.atleast_1d(as_jax(weight)) + conn_len = jnp.ceil(1 / conn_prob) * 2 - 1 + clen = jnp.asarray(jnp.atleast_1d(conn_len), dtype=jnp.int32) + if seed is None: + with jax.ensure_compile_time_eval(): + seed = np.random.randint(0, int(1e8), 1) + seed = jnp.asarray(seed, dtype=jnp.uint32) + seed = jnp.atleast_1d(seed) + return raw_mv_prob_homo(vector, weight, clen, seed, shape=shape, + transpose=transpose, outdim_parallel=outdim_parallel)[0] -@check_taichi_func def mv_prob_uniform( vector: jax.Array, w_low: float, @@ -189,11 +160,24 @@ def mv_prob_uniform( out: Array, ndarray The output of :math:`y = M @ v`. """ - return mv_prob_uniform_taichi(vector, w_low, w_high, conn_prob, seed, shape=shape, transpose=transpose, - outdim_parallel=outdim_parallel) + if ti is None: + raise PackageMissingError.by_purpose('taichi', purpose='customized operators') + + vector = as_jax(vector) + if isinstance(w_low, float): w_low = as_jax(w_low, dtype=vector.dtype) + if isinstance(w_high, float): w_high = as_jax(w_high, dtype=vector.dtype) + w_low = jnp.atleast_1d(as_jax(w_low)) + w_high = jnp.atleast_1d(as_jax(w_high)) + conn_len = jnp.ceil(1 / conn_prob) * 2 - 1 + conn_len = jnp.asarray(jnp.atleast_1d(conn_len), dtype=jnp.int32) + if seed is None: + with jax.ensure_compile_time_eval(): + seed = np.random.randint(0, int(1e8), 1) + seed = jnp.atleast_1d(jnp.asarray(seed, dtype=jnp.uint32)) + return raw_mv_prob_uniform(vector, w_low, w_high, conn_len, seed, shape=shape, + transpose=transpose, outdim_parallel=outdim_parallel)[0] -@check_taichi_func def mv_prob_normal( vector: jax.Array, w_mu: float, @@ -255,235 +239,8 @@ def mv_prob_normal( out: Array, ndarray The output of :math:`y = M @ v`. """ - return mv_prob_uniform_taichi(vector, w_mu, w_sigma, conn_prob, seed, - shape=shape, transpose=transpose, - outdim_parallel=outdim_parallel) - - -def mv_prob_homo_taichi( - vector: Union[Array, jax.Array], - weight: float, - conn_prob: float, - seed: Optional[int] = None, - *, - shape: Tuple[int, int], - transpose: bool = False, - outdim_parallel: bool = True, -) -> jax.Array: - r"""Perform the :math:`y=M@v` operation, - where :math:`M` is just-in-time randomly generated with a scalar `weight` at each position. - - This operator support ``jit()``, ``vmap()``, ``grad()`` and ``pmap()`` etc. transformations - on CPU and GPU devices. - - .. warning:: - - This API may change in the future. - - In this operation, :math:`M` is the random matrix with a connection probability - `conn_prob`, and at each connection the value is the same scalar `weight`. - - When ``transpose=True``, we perform an operation of :math:`y=M^T@v`. - - .. note:: - - Note that the just-in-time generated :math:`M` (`transpose=False`) is - different from the generated :math:`M^T` (`transpose=True`). - - If you pursue the same :math:`M` and :math:`M^T` when performing the just-in-time - matrix generation, you should set ``outdim_parallel=True``, with the sacrifice of - the speed compared with ``outdim_parallel=False``. - - Generally, the :math:`M` in ``f(outdim_parallel=True, transpose=False)`` is the same of - the :math:`M^T` used in ``f(outdim_parallel=False, transpose=True)``. - - Similarly, the :math:`M^T` in ``f(outdim_parallel=True, transpose=True)`` is the same - of the :math:`M` used in ``f(outdim_parallel=False, transpose=False)``. - - Parameters - ---------- - vector: Array, ndarray - The vector. - weight: float - The value of the random matrix. - conn_prob: float - The connection probability. - shape: tuple of int - The matrix shape. - seed: int - The random number generation seed. - transpose: bool - Transpose the random matrix or not. - outdim_parallel: bool - Perform the parallel random generations along the out dimension or not. - It can be used to set the just-in-time generated :math:M^T: is the same - as the just-in-time generated :math:`M` when ``transpose=True``. - - Returns - ------- - out: Array, ndarray - The output of :math:`y = M @ v`. - """ if ti is None: - raise PackageMissingError(name='taichi==1.7.0', purpose='customized operators') - - vector = as_jax(vector) - if isinstance(weight, float): - weight = as_jax(weight, dtype=vector.dtype) - weight = jnp.atleast_1d(as_jax(weight)) - conn_len = jnp.ceil(1 / conn_prob) * 2 - 1 - clen = jnp.asarray(jnp.atleast_1d(conn_len), dtype=jnp.int32) - if seed is None: - with jax.ensure_compile_time_eval(): - seed = np.random.randint(0, int(1e8), 1) - seed = jnp.asarray(seed, dtype=jnp.uint32) - seed = jnp.atleast_1d(seed) - return raw_mv_prob_homo(vector, weight, clen, seed, shape=shape, - transpose=transpose, outdim_parallel=outdim_parallel)[0] - - -def mv_prob_uniform_taichi( - vector: jax.Array, - w_low: float, - w_high: float, - conn_prob: float, - seed: Optional[int] = None, - *, - shape: Tuple[int, int], - transpose: bool = False, - outdim_parallel: bool = True, -) -> jax.Array: - r"""Perform the :math:`y=M@v` operation, - where :math:`M` is just-in-time randomly generated with a uniform distribution for its value. - - This operator support ``jit()``, ``vmap()``, ``grad()`` and ``pmap()`` etc. transformations - on CPU and GPU devices. - - .. warning:: - - This API may change in the future. - - In this operation, :math:`M` is the random matrix with a connection probability - `conn_prob`, and at each connection the value is the same scalar `weight`. - - When ``transpose=True``, we perform an operation of :math:`y=M^T@v`. - - .. note:: - - Note that the just-in-time generated :math:`M` (`transpose=False`) is - different from the generated :math:`M^T` (`transpose=True`). - - If you pursue the same :math:`M` and :math:`M^T` when performing the just-in-time - matrix generation, you should set ``outdim_parallel=True``, with the sacrifice of - the speed compared with ``outdim_parallel=False``. - - Parameters - ---------- - vector: Array, ndarray - The vector. - w_low: float - Lower boundary of the output interval. - w_high: float - Upper boundary of the output interval. - conn_prob: float - The connection probability. - shape: tuple of int - The matrix shape. - seed: int - The random number generation seed. - transpose: bool - Transpose the random matrix or not. - outdim_parallel: bool - Perform the parallel random generations along the out dimension or not. - It can be used to set the just-in-time generated :math:M^T: is the same - as the just-in-time generated :math:`M` when ``transpose=True``. - - Returns - ------- - out: Array, ndarray - The output of :math:`y = M @ v`. - """ - if ti is None: - raise PackageMissingError(name='taichi==1.7.0', purpose='customized operators') - - vector = as_jax(vector) - if isinstance(w_low, float): w_low = as_jax(w_low, dtype=vector.dtype) - if isinstance(w_high, float): w_high = as_jax(w_high, dtype=vector.dtype) - w_low = jnp.atleast_1d(as_jax(w_low)) - w_high = jnp.atleast_1d(as_jax(w_high)) - conn_len = jnp.ceil(1 / conn_prob) * 2 - 1 - conn_len = jnp.asarray(jnp.atleast_1d(conn_len), dtype=jnp.int32) - if seed is None: - with jax.ensure_compile_time_eval(): - seed = np.random.randint(0, int(1e8), 1) - seed = jnp.atleast_1d(jnp.asarray(seed, dtype=jnp.uint32)) - return raw_mv_prob_uniform(vector, w_low, w_high, conn_len, seed, shape=shape, - transpose=transpose, outdim_parallel=outdim_parallel)[0] - - -def mv_prob_normal_taichi( - vector: jax.Array, - w_mu: float, - w_sigma: float, - conn_prob: float, - seed: Optional[int] = None, - *, - shape: Tuple[int, int], - transpose: bool = False, - outdim_parallel: bool = True, -) -> jax.Array: - r"""Perform the :math:`y=M@v` operation, - where :math:`M` is just-in-time randomly generated with a normal distribution for its value. - - This operator support ``jit()``, ``vmap()``, ``grad()`` and ``pmap()`` etc. transformations - on CPU and GPU devices. - - .. warning:: - - This API may change in the future. - - In this operation, :math:`M` is the random matrix with a connection probability - `conn_prob`, and at each connection the value is the same scalar `weight`. - - When ``transpose=True``, we perform an operation of :math:`y=M^T@v`. - - .. note:: - - Note that the just-in-time generated :math:`M` (`transpose=False`) is - different from the generated :math:`M^T` (`transpose=True`). - - If you pursue the same :math:`M` and :math:`M^T` when performing the just-in-time - matrix generation, you should set ``outdim_parallel=True``, with the sacrifice of - the speed compared with ``outdim_parallel=False``. - - Parameters - ---------- - vector: Array, ndarray - The vector. - w_mu: float - Mean (centre) of the distribution. - w_sigma: float - Standard deviation (spread or “width”) of the distribution. Must be non-negative. - conn_prob: float - The connection probability. - shape: tuple of int - The matrix shape. - seed: int - The random number generation seed. - transpose: bool - Transpose the random matrix or not. - outdim_parallel: bool - Perform the parallel random generations along the out dimension or not. - It can be used to set the just-in-time generated :math:M^T: is the same - as the just-in-time generated :math:`M` when ``transpose=True``. - - Returns - ------- - out: Array, ndarray - The output of :math:`y = M @ v`. - """ - if ti is None: - raise PackageMissingError(name='taichi==1.7.0', purpose='customized operators') + raise PackageMissingError.by_purpose('taichi', purpose='customized operators') vector = as_jax(vector) if isinstance(w_mu, float): w_mu = as_jax(w_mu, dtype=vector.dtype) @@ -585,6 +342,47 @@ def raw_mv_prob_normal( outdim_parallel=outdim_parallel) +def _general_checking(vector, clen, seed, shape, outdim_parallel, transpose, *weights): + if vector.ndim != 1: + raise ValueError('vector should be a 1D vector.') + if len(shape) != 2: + raise ValueError('shape should be a length-2 tuple.') + if seed.ndim != 1: + raise ValueError('seed must be a 1D scalar.') + if clen.ndim != 1: + raise ValueError('conn_prob must be a 1D scalar.') + + assert _get_dtype(clen) in [jnp.int16, jnp.int32, jnp.int64, jnp.uint16, jnp.uint32, jnp.uint64] + assert _get_dtype(seed) in [jnp.int16, jnp.int32, jnp.int64, jnp.uint16, jnp.uint32, jnp.uint64] + + for weight in weights: + if weight.ndim != 1: + raise ValueError('weight must be a 1D scalar.') + assert _get_dtype(weight) in [jnp.float16, jnp.float32, jnp.float64], '"weight" must be float valued.' + + if not isinstance(outdim_parallel, bool): + raise ValueError('outdim_parallel must be boolean value.') + if not isinstance(transpose, bool): + raise ValueError('transpose must be boolean value.') + + if transpose: + out_shape = (shape[1],) + if vector.shape[0] != shape[0]: + raise ValueError(f'Shape mismatch, vec {vector.shape} @ mat {shape}.') + shape = _reverse(shape) + else: + if vector.shape[0] != shape[1]: + raise ValueError(f'Shape mismatch, mat {shape} @ vec ({vector.shape[0]},).') + out_shape = (shape[0],) + + return shape, out_shape + + +def _non_event_checking(vector, clen, seed, shape, outdim_parallel, transpose, *weights): + assert _get_dtype(vector) in [jnp.float16, jnp.float32, jnp.float64] + return _general_checking(vector, clen, seed, shape, outdim_parallel, transpose, *weights) + + def _mv_prob_homo_transpose( ct, vector, weight, clen, seed, *, outs, shape, transpose, outdim_parallel ): @@ -644,6 +442,7 @@ def _mv_prob_normal_transpose( assert type(clen) is not ad.UndefinedPrimal, 'Cannot differentiate through clen.' assert type(seed) is not ad.UndefinedPrimal, 'Cannot differentiate through seed.' + def _reverse(shape): return shape[::-1] @@ -774,9 +573,6 @@ def _mv_prob_homo_jvp_weight(w_dot, vector, weight, clen, seed, *, outs, shape, outdim_parallel=outdim_parallel) - - - def _define_mv_prob_homo_prim(cpu_kernel, gpu_kernel): prim = XLACustomOp(cpu_kernel=cpu_kernel, gpu_kernel=gpu_kernel) prim.defjvp(_mv_prob_homo_jvp_vector, _mv_prob_homo_jvp_weight, None, None) @@ -936,9 +732,6 @@ def _mv_prob_uniform_jvp_whigh(w_dot, vector, w_low, w_high, clen, seed, *, transpose=transpose, outdim_parallel=outdim_parallel) - - - def _define_mv_prob_uniform_prim(cpu_kernel, gpu_kernel): prim = XLACustomOp(cpu_kernel=cpu_kernel, gpu_kernel=gpu_kernel) prim.defjvp(_mv_prob_uniform_jvp_vector, @@ -1103,9 +896,6 @@ def _mv_prob_normal_jvp_w_sigma(w_dot, vector, w_mu, w_sigma, clen, seed, *, out transpose=transpose, outdim_parallel=outdim_parallel) - - - def _define_mv_prob_normal_prim(cpu_kernel, gpu_kernel): prim = XLACustomOp(cpu_kernel=cpu_kernel, gpu_kernel=gpu_kernel) prim.defjvp(_mv_prob_normal_jvp_vector, diff --git a/brainpy/_src/math/object_transform/autograd.py b/brainpy/_src/math/object_transform/autograd.py index ad8a5ccf6..f5e091675 100644 --- a/brainpy/_src/math/object_transform/autograd.py +++ b/brainpy/_src/math/object_transform/autograd.py @@ -28,8 +28,10 @@ get_stack_cache, cache_stack) from .base import (BrainPyObject, ObjectTransform) -from .variables import (Variable, VariableStack) -from .tools import eval_shape +from .variables import (Variable, + VariableStack, + current_transform_number, + new_transform) __all__ = [ 'grad', # gradient of scalar function @@ -201,21 +203,36 @@ def __call__(self, *args, **kwargs): elif not self._eval_dyn_vars: # evaluate dynamical variables stack = get_stack_cache(self.target) if stack is None: - with VariableStack() as stack: - rets = eval_shape(self._transform, - [v.value for v in self._grad_vars], # variables for gradients - {}, # dynamical variables - *args, - **kwargs) + with new_transform(self): + with VariableStack() as stack: + if current_transform_number() > 1: + rets = self._transform( + [v.value for v in self._grad_vars], # variables for gradients + {}, # dynamical variables + *args, + **kwargs + ) + else: + rets = jax.eval_shape( + self._transform, + [v.value for v in self._grad_vars], # variables for gradients + {}, # dynamical variables + *args, + **kwargs + ) cache_stack(self.target, stack) - self._dyn_vars = stack - self._dyn_vars.remove_by_id(*[id(v) for v in self._grad_vars]) - self._eval_dyn_vars = True + self._dyn_vars = stack + self._dyn_vars.remove_by_id(*[id(v) for v in self._grad_vars]) + self._eval_dyn_vars = True - # if not the outermost transformation - if not stack.is_first_stack(): - return self._return(rets) + # if not the outermost transformation + if current_transform_number(): + return self._return(rets) + else: + self._dyn_vars = stack + self._dyn_vars.remove_by_id(*[id(v) for v in self._grad_vars]) + self._eval_dyn_vars = True rets = self._transform( [v.value for v in self._grad_vars], # variables for gradients diff --git a/brainpy/_src/math/object_transform/base.py b/brainpy/_src/math/object_transform/base.py index c52845a06..aaf053ae7 100644 --- a/brainpy/_src/math/object_transform/base.py +++ b/brainpy/_src/math/object_transform/base.py @@ -6,6 +6,7 @@ """ import numbers +import os import warnings from collections import namedtuple from typing import Any, Tuple, Callable, Sequence, Dict, Union, Optional @@ -13,13 +14,14 @@ import jax import numpy as np -from brainpy._src.math.modes import Mode +from brainpy import errors from brainpy._src.math.ndarray import (Array, ) from brainpy._src.math.object_transform.collectors import (ArrayCollector, Collector) from brainpy._src.math.object_transform.naming import (get_unique_name, check_name_uniqueness) from brainpy._src.math.object_transform.variables import (Variable, VariableView, TrainVar, VarList, VarDict) +from brainpy._src.math.modes import Mode from brainpy._src.math.sharding import BATCH_AXIS variable_ = None diff --git a/brainpy/_src/math/object_transform/controls.py b/brainpy/_src/math/object_transform/controls.py index 3edeb08e8..032a0fab6 100644 --- a/brainpy/_src/math/object_transform/controls.py +++ b/brainpy/_src/math/object_transform/controls.py @@ -21,12 +21,17 @@ cache_stack ) from .tools import ( - eval_shape, + evaluate_dyn_vars, dynvar_deprecation, node_deprecation, abstract ) -from .variables import (Variable, VariableStack) +from .variables import ( + Variable, + VariableStack, + new_transform, + current_transform_number, +) __all__ = [ 'make_loop', @@ -537,13 +542,15 @@ def cond( node_deprecation(child_objs) dyn_vars = get_stack_cache((true_fun, false_fun)) - if not jax.config.jax_disable_jit and dyn_vars is None: - with VariableStack() as dyn_vars: - rets = eval_shape(true_fun, *operands, with_stack=True)[1] - _ = eval_shape(false_fun, *operands, with_stack=True) - cache_stack((true_fun, false_fun), dyn_vars) - if not dyn_vars.is_first_stack(): - return rets + if not jax.config.jax_disable_jit: + if dyn_vars is None: + with new_transform('cond'): + dyn_vars1, rets = evaluate_dyn_vars(true_fun, *operands, use_eval_shape=current_transform_number() <= 1) + dyn_vars2, rets = evaluate_dyn_vars(false_fun, *operands, use_eval_shape=current_transform_number() <= 1) + dyn_vars = dyn_vars1 + dyn_vars2 + cache_stack((true_fun, false_fun), dyn_vars) + if current_transform_number() > 0: + return rets dyn_vars = VariableStack() if dyn_vars is None else dyn_vars dyn_values, res = _get_cond_transform(dyn_vars, pred, true_fun, false_fun)(operands) for k in dyn_values.keys(): @@ -674,16 +681,20 @@ def ifelse( else: dyn_vars = get_stack_cache(tuple(branches)) if dyn_vars is None: - with VariableStack() as dyn_vars: - rets = [eval_shape(fun, *operands, with_stack=True)[1] for fun in branches] - trees = [jax.tree_util.tree_structure(ret) for ret in rets] - if not _all_equal(trees): - msg = 'All returns in branches should have the same tree structure. But we got:\n' - for tree in trees: - msg += f'- {tree}\n' - raise TypeError(msg) + with new_transform('ifelse'): + with VariableStack() as dyn_vars: + if current_transform_number() > 1: + rets = [branch(*operands) for branch in branches] + else: + rets = [jax.eval_shape(branch, *operands) for branch in branches] + trees = [jax.tree_util.tree_structure(ret) for ret in rets] + if not _all_equal(trees): + msg = 'All returns in branches should have the same tree structure. But we got:\n' + for tree in trees: + msg += f'- {tree}\n' + raise TypeError(msg) cache_stack(tuple(branches), dyn_vars) - if not dyn_vars.is_first_stack(): + if current_transform_number(): return rets[0] branches = [_cond_transform_fun(fun, dyn_vars) for fun in branches] @@ -869,23 +880,28 @@ def for_loop( if jit is None: # jax disable jit jit = not jax.config.jax_disable_jit - stack = get_stack_cache((body_fun, unroll_kwargs)) + dyn_vars = get_stack_cache((body_fun, unroll_kwargs)) if jit: - if stack is None: - transform = _get_for_loop_transform(body_fun, VariableStack(), bar, progress_bar, - remat, reverse, unroll, unroll_kwargs) + if dyn_vars is None: # TODO: better cache mechanism? - with VariableStack() as stack: - rets = eval_shape(transform, operands) - cache_stack((body_fun, unroll_kwargs), stack) # cache - if not stack.is_first_stack(): + with new_transform('for_loop'): + with VariableStack() as dyn_vars: + transform = _get_for_loop_transform(body_fun, VariableStack(), bar, + progress_bar, remat, reverse, unroll, + unroll_kwargs) + if current_transform_number() > 1: + rets = transform(operands) + else: + rets = jax.eval_shape(transform, operands) + cache_stack((body_fun, unroll_kwargs), dyn_vars) # cache + if current_transform_number(): return rets[1] del rets else: - stack = VariableStack() + dyn_vars = VariableStack() # TODO: cache mechanism? - transform = _get_for_loop_transform(body_fun, stack, bar, + transform = _get_for_loop_transform(body_fun, dyn_vars, bar, progress_bar, remat, reverse, unroll, unroll_kwargs) if jit: @@ -893,11 +909,11 @@ def for_loop( else: with jax.disable_jit(): dyn_vals, out_vals = transform(operands) - for key in stack.keys(): - stack[key]._value = dyn_vals[key] + for key in dyn_vars.keys(): + dyn_vars[key]._value = dyn_vals[key] if progress_bar: bar.close() - del dyn_vals, stack + del dyn_vals, dyn_vars return out_vals @@ -995,21 +1011,26 @@ def scan( num_total = min([op.shape[0] for op in jax.tree_util.tree_flatten(operands)[0]]) bar = tqdm(total=num_total) - stack = get_stack_cache(body_fun) - if not jax.config.jax_disable_jit and stack is None: - transform = _get_scan_transform(body_fun, VariableStack(), bar, progress_bar, remat, reverse, unroll) - with VariableStack() as stack: - rets = eval_shape(transform, init, operands) - cache_stack(body_fun, stack) # cache - if not stack.is_first_stack(): - return rets[0][1], rets[1] - del rets - - stack = VariableStack() if stack is None else stack - transform = _get_scan_transform(body_fun, stack, bar, progress_bar, remat, reverse, unroll) + dyn_vars = get_stack_cache(body_fun) + if not jax.config.jax_disable_jit: + if dyn_vars is None: + with new_transform('scan'): + with VariableStack() as dyn_vars: + transform = _get_scan_transform(body_fun, VariableStack(), bar, progress_bar, remat, reverse, unroll) + if current_transform_number() > 1: + rets = transform(init, operands) + else: + rets = jax.eval_shape(transform, init, operands) + cache_stack(body_fun, dyn_vars) # cache + if current_transform_number(): + return rets[0][1], rets[1] + del rets + + dyn_vars = VariableStack() if dyn_vars is None else dyn_vars + transform = _get_scan_transform(body_fun, dyn_vars, bar, progress_bar, remat, reverse, unroll) (dyn_vals, carry), out_vals = transform(init, operands) - for key in stack.keys(): - stack[key]._value = dyn_vals[key] + for key in dyn_vars.keys(): + dyn_vars[key]._value = dyn_vals[key] if progress_bar: bar.close() return carry, out_vals @@ -1108,6 +1129,7 @@ def while_loop( No longer need to provide ``child_objs``. This function is capable of automatically collecting the children objects used in the target ``func``. + """ dynvar_deprecation(dyn_vars) node_deprecation(child_objs) @@ -1115,16 +1137,18 @@ def while_loop( if not isinstance(operands, (list, tuple)): operands = (operands,) - stack = get_stack_cache((body_fun, cond_fun)) - if not jax.config.jax_disable_jit and stack is None: - with VariableStack() as stack: - _ = eval_shape(cond_fun, *operands, with_stack=True) - rets = eval_shape(body_fun, *operands, with_stack=True)[1] - cache_stack((body_fun, cond_fun), stack) - if not stack.is_first_stack(): - return rets - stack = VariableStack() if stack is None else stack - dyn_values, out = _get_while_transform(cond_fun, body_fun, stack)(operands) - for k, v in stack.items(): + dyn_vars = get_stack_cache((body_fun, cond_fun)) + if not jax.config.jax_disable_jit: + if dyn_vars is None: + with new_transform('while_loop'): + dyn_vars1, _ = evaluate_dyn_vars(cond_fun, *operands, use_eval_shape=current_transform_number() <= 1) + dyn_vars2, rets = evaluate_dyn_vars(body_fun, *operands, use_eval_shape=current_transform_number() <= 1) + dyn_vars = dyn_vars1 + dyn_vars2 + cache_stack((body_fun, cond_fun), dyn_vars) + if current_transform_number(): + return rets + dyn_vars = VariableStack() if dyn_vars is None else dyn_vars + dyn_values, out = _get_while_transform(cond_fun, body_fun, dyn_vars)(operands) + for k, v in dyn_vars.items(): v._value = dyn_values[k] return out diff --git a/brainpy/_src/math/object_transform/jit.py b/brainpy/_src/math/object_transform/jit.py index 73eab2f91..7bb36f4e2 100644 --- a/brainpy/_src/math/object_transform/jit.py +++ b/brainpy/_src/math/object_transform/jit.py @@ -11,15 +11,23 @@ from typing import Callable, Union, Optional, Sequence, Dict, Any, Iterable import jax +from jax.sharding import Sharding from brainpy import tools, check -from .base import BrainPyObject, ObjectTransform -from .naming import get_stack_cache, cache_stack from .tools import (dynvar_deprecation, node_deprecation, - eval_shape) -from .variables import (Variable, VariableStack) + evaluate_dyn_vars_with_cache, + evaluate_dyn_vars, + _partial_fun) +from .base import BrainPyObject, ObjectTransform +from .naming import get_stack_cache, cache_stack from ..ndarray import Array +from .variables import (Variable, + VariableStack, + outermost_transform, + transform_stack, + current_transform_number, + new_transform) RandomState = None @@ -143,12 +151,16 @@ def _transform_function(self, variable_data: Dict, *args, **kwargs): return changes, out def _get_transform(self, *args, **kwargs): - with VariableStack() as self._dyn_vars: - rets = eval_shape(self.fun, - *args, - **kwargs, - static_argnums=self._static_argnums, - static_argnames=self._static_argnames) + with new_transform(self): + self._dyn_vars, rets = evaluate_dyn_vars( + self.fun, + *args, + static_argnums=self._static_argnums, + static_argnames=self._static_argnames, + use_eval_shape=current_transform_number() <= 1, + **kwargs + ) + # in_shardings if self._in_shardings is None: in_shardings = None @@ -174,18 +186,18 @@ def _get_transform(self, *args, **kwargs): _dyn_vars_sharing = get_shardings(self._dyn_vars.subset_by_not_instance(RandomState)) out_shardings = (_dyn_vars_sharing,) + out_shardings - # jit - self._transform = jax.jit( - self._transform_function, - static_argnums=jax.tree_util.tree_map(lambda a: a + 1, self._static_argnums), - static_argnames=self._static_argnames, - donate_argnums=self._donate_argnums, - inline=self._inline, - keep_unused=self._keep_unused, - abstracted_axes=self._abstracted_axes, - in_shardings=in_shardings, - out_shardings=out_shardings, - ) + # jit + self._transform = jax.jit( + self._transform_function, + static_argnums=jax.tree_util.tree_map(lambda a: a + 1, self._static_argnums), + static_argnames=self._static_argnames, + donate_argnums=self._donate_argnums, + inline=self._inline, + keep_unused=self._keep_unused, + abstracted_axes=self._abstracted_axes, + in_shardings=in_shardings, + out_shardings=out_shardings, + ) return rets def __call__(self, *args, **kwargs): @@ -195,7 +207,7 @@ def __call__(self, *args, **kwargs): if self._transform is None: # initialize the transformation rets = self._get_transform(*args, **kwargs) # if not the outermost transformation - if not self._dyn_vars.is_first_stack(): + if current_transform_number(): return rets # call the transformed function @@ -465,8 +477,15 @@ def call_fun(self, *args, **kwargs): cache = get_stack_cache(hash_v) # TODO: better cache mechanism if cache is None: fun2 = partial(fun, self) - with VariableStack() as stack: - _ = eval_shape(fun2, *args, **kwargs, static_argnums=static_argnums, static_argnames=static_argnames) + + with jax.ensure_compile_time_eval(): + if len(static_argnums) or len(static_argnames): + fun3, args_, kwargs_ = _partial_fun(fun2, args, kwargs, static_argnums, static_argnames) + else: + args_, kwargs_, fun3 = args, kwargs, fun2 + with VariableStack() as stack: + _ = jax.eval_shape(fun3, *args_, **kwargs_) + del args_, kwargs_ _transform = jax.jit( _make_transform(fun2, stack), static_argnums=jax.tree_util.tree_map(lambda a: a + 1, static_argnums), diff --git a/brainpy/_src/math/object_transform/naming.py b/brainpy/_src/math/object_transform/naming.py index 1181e003b..1c8ca6ef9 100644 --- a/brainpy/_src/math/object_transform/naming.py +++ b/brainpy/_src/math/object_transform/naming.py @@ -41,7 +41,7 @@ def get_unique_name(type_: str): return name -def clear_name_cache(ignore_warn=True): +def clear_name_cache(ignore_warn=False): """Clear the cached names.""" _name2id.clear() _typed_names.clear() @@ -57,7 +57,6 @@ def cache_stack(func, stack): def clear_stack_cache(): - """Clear the cached stack.""" for k in tuple(_fun2stack.keys()): del _fun2stack[k] diff --git a/brainpy/_src/math/object_transform/parallels.py b/brainpy/_src/math/object_transform/parallels.py new file mode 100644 index 000000000..1eddce048 --- /dev/null +++ b/brainpy/_src/math/object_transform/parallels.py @@ -0,0 +1,460 @@ +# -*- coding: utf-8 -*- + +""" +The parallel compilation tools for JAX backend. + +1. Vectorize compilation is implemented by the 'vmap()' function +2. Parallel compilation is implemented by the 'pmap()' function + +""" + + +import functools + +import jax +import jax.numpy as jnp +import numpy as np +from jax.interpreters.partial_eval import DynamicJaxprTracer +from jax.interpreters.partial_eval import JaxprTracer +from jax.interpreters.pxla import ShardedDeviceArray + +try: + from jax.errors import UnexpectedTracerError +except ImportError: + from jax.core import UnexpectedTracerError + +from brainpy import errors +from brainpy._src.math.random import RandomState +from brainpy._src.math.ndarray import Array +from brainpy.tools import change_func_name +from .base import BrainPyObject, ArrayCollector + +__all__ = [ + 'vmap', + 'pmap', +] + + +def _make_vmap(func, nonbatched_vars, batched_vars, in_axes, out_axes, + batch_idx, axis_name, f_name=None): + @functools.partial(jax.vmap, in_axes=in_axes, out_axes=out_axes, axis_name=axis_name) + def vmapped_func(nonbatched_data, batched_data, *args, **kwargs): + nonbatched_vars.assign(nonbatched_data) + batched_vars.assign(batched_data) + out = func(*args, **kwargs) + nonbatched_changes = nonbatched_vars.dict() + batched_changes = batched_vars.dict() + return nonbatched_changes, batched_changes, out + + def call(*args, **kwargs): + n = args[batch_idx[0]].shape[batch_idx[1]] + nonbatched_data = nonbatched_vars.dict() + batched_data = {key: val.split_keys(n) for key, val in batched_vars.items()} + try: + out, dyn_changes, rand_changes = vmapped_func(nonbatched_data, batched_data, *args, **kwargs) + except UnexpectedTracerError as e: + nonbatched_vars.assign(nonbatched_data) + batched_vars.assign(batched_data) + raise errors.JaxTracerError() from e + # for key, v in dyn_changes.items(): + # dyn_vars[key] = reduce_func(v) + # for key, v in rand_changes.items(): + # rand_vars[key] = reduce_func(v) + return out + + return change_func_name(name=f_name, f=call) if f_name else call + + +def vmap(func, dyn_vars=None, batched_vars=None, + in_axes=0, out_axes=0, axis_name=None, + reduce_func=None, auto_infer=False): + """Vectorization compilation for class objects. + + Vectorized compile a function or a module to run in parallel on a single device. + + Examples + -------- + + Parameters + ---------- + func : BrainPyObject, function, callable + The function or the module to compile. + dyn_vars : dict, sequence + batched_vars : dict + in_axes : optional, int, sequence of int + Specify which input array axes to map over. If each positional argument to + ``obj_or_func`` is an array, then ``in_axes`` can be an integer, a None, + or a tuple of integers and Nones with length equal to the number of + positional arguments to ``obj_or_func``. An integer or ``None`` + indicates which array axis to map over for all arguments (with ``None`` + indicating not to map any axis), and a tuple indicates which axis to map + for each corresponding positional argument. Axis integers must be in the + range ``[-ndim, ndim)`` for each array, where ``ndim`` is the number of + dimensions (axes) of the corresponding input array. + + If the positional arguments to ``obj_or_func`` are container types, the + corresponding element of ``in_axes`` can itself be a matching container, + so that distinct array axes can be mapped for different container + elements. ``in_axes`` must be a container tree prefix of the positional + argument tuple passed to ``obj_or_func``. + + At least one positional argument must have ``in_axes`` not None. The sizes + of the mapped input axes for all mapped positional arguments must all be + equal. + + Arguments passed as keywords are always mapped over their leading axis + (i.e. axis index 0). + out_axes : optional, int, tuple/list/dict + Indicate where the mapped axis should appear in the output. All outputs + with a mapped axis must have a non-None ``out_axes`` specification. Axis + integers must be in the range ``[-ndim, ndim)`` for each output array, + where ``ndim`` is the number of dimensions (axes) of the array returned + by the :func:`vmap`-ed function, which is one more than the number of + dimensions (axes) of the corresponding array returned by ``obj_or_func``. + axis_name : optional + + Returns + ------- + obj_or_func : Any + Batched/vectorized version of ``obj_or_func`` with arguments that correspond to + those of ``obj_or_func``, but with extra array axes at positions indicated by + ``in_axes``, and a return value that corresponds to that of ``obj_or_func``, but + with extra array axes at positions indicated by ``out_axes``. + + """ + # if isinstance(func, DynamicalSystem): + # if len(func.steps): # DynamicalSystem has step functions + # + # # dynamical variables + # dyn_vars = (dyn_vars or func.vars().unique()) + # dyn_vars, rand_vars = ArrayCollector(), ArrayCollector() + # for key, val in dyn_vars.items(): + # if isinstance(val, RandomState): + # rand_vars[key] = val + # else: + # dyn_vars[key] = val + # + # # in axes + # if in_axes is None: + # in_axes = {key: (None, 0) for key in func.steps.keys()} + # elif isinstance(in_axes, int): + # in_axes = {key: (None, 0, in_axes) for key in func.steps.keys()} + # elif isinstance(in_axes, (tuple, list)): + # in_axes = {key: (None, 0) + tuple(in_axes) for key in func.steps.keys()} + # elif isinstance(in_axes, dict): + # keys = list(func.steps.keys()) + # if keys[0] not in in_axes: + # in_axes = {key: (None, 0, in_axes) for key in keys} + # else: + # in_axes = {key: (None, 0) + tuple(in_axes[key]) for key in keys} + # assert isinstance(in_axes, dict) + # + # # batch size index + # batch_idx = {} + # for key, axes in in_axes.items(): + # for i, axis in enumerate(axes[2:]): + # if axis is not None: + # batch_idx[key] = (i, axis) + # break + # else: + # raise ValueError(f'Found no batch axis: {axes}.') + # + # # out axes + # if out_axes is None: + # out_axes = {key: 0 for key in func.steps.keys()} + # elif isinstance(out_axes, int): + # out_axes = {key: out_axes for key in func.steps.keys()} + # elif isinstance(out_axes, (tuple, list)): + # out_axes = {key: tuple(out_axes) + (0, 0) for key in func.steps.keys()} + # elif isinstance(out_axes, dict): + # keys = list(func.steps.keys()) + # if keys[0] not in out_axes: + # out_axes = {key: (out_axes, 0, 0) for key in keys} + # else: + # out_axes = {key: tuple(out_axes[key]) + (0, 0) for key in keys} + # assert isinstance(out_axes, dict) + # + # # reduce_func + # if reduce_func is None: + # reduce_func = lambda x: x.mean(axis=0) + # + # # vectorized map functions + # for key in func.steps.keys(): + # func.steps[key] = _make_vmap(func=func.steps[key], + # dyn_vars=dyn_vars, + # rand_vars=rand_vars, + # in_axes=in_axes[key], + # out_axes=out_axes[key], + # axis_name=axis_name, + # batch_idx=batch_idx[key], + # reduce_func=reduce_func, + # f_name=key) + # + # return func + + if callable(func): + if auto_infer: + if dyn_vars is not None: + dyn_vars = dyn_vars + elif isinstance(func, BrainPyObject): # BrainPyObject has '__call__()' implementation + dyn_vars = func.vars().unique() + elif hasattr(func, '__self__'): + if isinstance(func.__self__, BrainPyObject): + dyn_vars = func.__self__.vars().unique() + + if dyn_vars is None: + return jax.vmap(func, + in_axes=in_axes, + out_axes=out_axes, + axis_name=axis_name) + + else: + if isinstance(dyn_vars, Array): + dyn_vars = [dyn_vars] + if isinstance(dyn_vars, (tuple, list)): + dyn_vars = {f'_vmap_v{i}': v for i, v in enumerate(dyn_vars)} + assert isinstance(dyn_vars, dict) + + # dynamical variables + _dyn_vars, _rand_vars = ArrayCollector(), ArrayCollector() + for key, val in dyn_vars.items(): + if isinstance(val, RandomState): + _rand_vars[key] = val + else: + _dyn_vars[key] = val + + # in axes + if in_axes is None: + in_axes = (None, 0) + elif isinstance(in_axes, (int, dict)): + in_axes = (None, 0, in_axes) + elif isinstance(in_axes, (tuple, list)): + in_axes = (None, 0) + tuple(in_axes) + assert isinstance(in_axes, (tuple, list)) + + # batch size index + batch_idx = {} + for key, axes in batch_idx.items(): + for i, axis in enumerate(axes[2:]): + if axis is not None: + batch_idx[key] = (i, axis) + break + else: + raise ValueError(f'Found no batch axis: {axes}.') + + # out axes + if out_axes is None: + out_axes = 0 + elif isinstance(out_axes, (int, dict)): + out_axes = (out_axes, 0, 0) + elif isinstance(out_axes, (tuple, list)): + out_axes = tuple(out_axes) + (0, 0) + assert isinstance(out_axes, (list, tuple)) + + # reduce_func + if reduce_func is None: + reduce_func = lambda x: x.mean(axis=0) + + # jit function + return _make_vmap(func=func, + nonbatched_vars=_dyn_vars, + batched_vars=_rand_vars, + in_axes=in_axes, + out_axes=out_axes, + axis_name=axis_name, + batch_idx=batch_idx) + + else: + raise errors.BrainPyError(f'Only support instance of {BrainPyObject.__name__}, or a callable ' + f'function, but we got {type(func)}.') + + +def _device_reshape(x): + """Reshape an input array in order to broadcast to multiple devices.""" + num_device = jax.local_device_count() + + if not hasattr(x, 'ndim'): + raise errors.BrainPyError(f'Expected Array, got {type(x)}. If you are trying to pass a scalar to ' + f'parallel, first convert it to a Array, for example np.float(0.5)') + if x.ndim == 0: + return np.broadcast_to(x, [num_device]) + if x.shape[0] % num_device != 0: + raise errors.BrainPyError(f'Must be able to equally divide batch {x.shape} among ' + f'{num_device} devices, but does not go equally.') + return x.reshape((num_device, x.shape[0] // num_device) + x.shape[1:]) + + +def _make_pmap(func, dyn_vars, rand_vars, reduce_func, axis_name=None, in_axes=0, + out_axes=0, static_broadcasted_argnums=(), devices=None, backend=None, + axis_size=None, donate_argnums=(), global_arg_shapes=None, f_name=None): + @functools.partial(jax.pmap, in_axes=in_axes, out_axes=out_axes, axis_name=axis_name, + static_broadcasted_argnums=static_broadcasted_argnums, devices=devices, + backend=backend, axis_size=axis_size, donate_argnums=donate_argnums, + global_arg_shapes=global_arg_shapes) + def pmapped_func(dyn_data, rand_data, *args, **kwargs): + dyn_vars.assign(dyn_data) + rand_vars.assign(rand_data) + out = func(*args, **kwargs) + dyn_changes = dyn_vars.dict() + rand_changes = rand_vars.dict() + return out, dyn_changes, rand_changes + + def call(*args): + un_replicated = [k for k, v in dyn_vars.items() + if not isinstance(v.value, (ShardedDeviceArray, JaxprTracer, DynamicJaxprTracer))] + if len(un_replicated): + raise errors.BrainPyError(f'Some variables were not replicated: {un_replicated}.' + f'did you forget to call xx.replicate() on them?') + _args = [] + for i, x in enumerate(args): + if i + 2 in static_broadcasted_argnums: + _args.append(x) + else: + _args.append(jax.tree_map(_device_reshape, [x])[0]) + dyn_data = dyn_vars.dict() + rand_data = rand_vars.dict() + output, dyn_changes, rand_changes = pmapped_func(dyn_data, rand_data, *_args) + dyn_vars.assign(dyn_changes) + rand_vars.assign(rand_changes) + return jax.tree_map(reduce_func, output) + + return change_func_name(name=f_name, f=call) if f_name else call + + +def pmap(func, dyn_vars=None, axis_name=None, in_axes=0, out_axes=0, static_broadcasted_argnums=(), + devices=None, backend=None, axis_size=None, donate_argnums=(), global_arg_shapes=None, + reduce_func=None): + """Parallel compilation for class objects. + + Parallel compile a function or a module to run on multiple devices in parallel. + + Parameters + ---------- + func + axis_name + in_axes + out_axes + static_broadcasted_argnums + devices + backend + axis_size + donate_argnums + global_arg_shapes + + Returns + ------- + + + Examples + -------- + + + """ + + # if isinstance(func, DynamicalSystem): + # if len(func.steps): # DynamicalSystem has step functions + # + # # dynamical variables + # all_vars = (dyn_vars or func.vars().unique()) + # dyn_vars = ArrayCollector() + # rand_vars = ArrayCollector() + # for key, val in all_vars.items(): + # if isinstance(val, RandomState): + # rand_vars[key] = val + # else: + # dyn_vars[key] = val + # + # # reduce function + # if reduce_func is None: + # reduce_func = jnp.concatenate + # + # # static broadcast-ed arguments + # if static_broadcasted_argnums is None: + # static_broadcasted_argnums = () + # elif isinstance(static_broadcasted_argnums, int): + # static_broadcasted_argnums = (static_broadcasted_argnums + 2,) + # elif isinstance(static_broadcasted_argnums, (tuple, list)): + # static_broadcasted_argnums = tuple(argnum + 2 for argnum in static_broadcasted_argnums) + # assert isinstance(static_broadcasted_argnums, (tuple, list)) + # + # # jit functions + # for key in func.steps.keys(): + # step = func.steps[key] + # func.steps[key] = _make_pmap(dyn_vars=dyn_vars, + # rand_vars=rand_vars, + # func=step, + # axis_name=axis_name, + # in_axes=in_axes, + # out_axes=out_axes, + # static_broadcasted_argnums=static_broadcasted_argnums, + # devices=devices, + # backend=backend, + # axis_size=axis_size, + # donate_argnums=donate_argnums, + # global_arg_shapes=global_arg_shapes, + # reduce_func=reduce_func, + # f_name=key) + # return func + + if callable(func): + if dyn_vars is not None: + dyn_vars = dyn_vars + elif isinstance(func, BrainPyObject): # BrainPyObject has '__call__()' implementation + dyn_vars = func.vars().unique() + elif hasattr(func, '__self__'): + if isinstance(func.__self__, BrainPyObject): + dyn_vars = func.__self__.vars().unique() + + if dyn_vars is None: + return jax.pmap(func, + axis_name=axis_name, + in_axes=in_axes, + out_axes=out_axes, + static_broadcasted_argnums=static_broadcasted_argnums, + devices=devices, + backend=backend, + axis_size=axis_size, + donate_argnums=donate_argnums, + global_arg_shapes=global_arg_shapes) + else: + # dynamical variables + dyn_vars = ArrayCollector() + rand_vars = ArrayCollector() + for key, val in dyn_vars.items(): + if isinstance(val, RandomState): + rand_vars[key] = val + else: + dyn_vars[key] = val + + # static broadcast-ed arguments + if static_broadcasted_argnums is None: + static_broadcasted_argnums = () + elif isinstance(static_broadcasted_argnums, int): + static_broadcasted_argnums = (static_broadcasted_argnums + 2,) + elif isinstance(static_broadcasted_argnums, (tuple, list)): + static_broadcasted_argnums = tuple(argnum + 2 for argnum in static_broadcasted_argnums) + assert isinstance(static_broadcasted_argnums, (tuple, list)) + + # reduce function + if reduce_func is None: + reduce_func = jnp.concatenate + + # jit function + func.__call__ = _make_pmap(dyn_vars=dyn_vars, + rand_vars=rand_vars, + func=func, + axis_name=axis_name, + in_axes=in_axes, + out_axes=out_axes, + static_broadcasted_argnums=static_broadcasted_argnums, + devices=devices, + backend=backend, + axis_size=axis_size, + donate_argnums=donate_argnums, + global_arg_shapes=global_arg_shapes, + reduce_func=reduce_func) + return func + + else: + raise errors.BrainPyError(f'Only support instance of {BrainPyObject.__name__}, or a callable function, ' + f'but we got {type(func)}.') diff --git a/brainpy/_src/math/object_transform/tools.py b/brainpy/_src/math/object_transform/tools.py index 632c6d79e..7b519590a 100644 --- a/brainpy/_src/math/object_transform/tools.py +++ b/brainpy/_src/math/object_transform/tools.py @@ -132,65 +132,19 @@ def evaluate_dyn_vars_with_cache( return stack -def _partial_fun2( - fun: Callable, - args: tuple, - kwargs: dict, - static_argnums: Sequence[int] = (), - static_argnames: Sequence[str] = () -): - num_args = len(args) - - # arguments - static_args = dict() - dyn_args = [] - dyn_arg_ids = dict() - static_argnums = list(static_argnums) - dyn_i = 0 - for i in range(num_args): - if i in static_argnums: - static_argnums.remove(i) - static_args[i] = args[i] - else: - dyn_args.append(args[i]) - dyn_arg_ids[i] = dyn_i - dyn_i += 1 - if len(static_argnums) > 0: - raise ValueError(f"Invalid static_argnums: {static_argnums}") - - # keyword arguments - static_kwargs, dyn_kwargs = {}, {} - for k, arg in kwargs.items(): - if k in static_argnames: - static_kwargs[k] = arg - else: - dyn_kwargs[k] = arg - del args, kwargs, static_argnums, static_argnames - - @wraps(fun) - def new_fun(*dynargs, **dynkwargs): - return fun(*[dynargs[dyn_arg_ids[id_]] if id_ in dyn_arg_ids else static_args[id_] for id_ in range(num_args)], - **static_kwargs, - **dynkwargs) - - return new_fun, dyn_args, dyn_kwargs - - def eval_shape( fun: Callable, *args, static_argnums: Sequence[int] = (), static_argnames: Sequence[str] = (), - with_stack: bool = False, **kwargs ): """Compute the shape/dtype of ``fun`` without any FLOPs. Args: fun: The callable function. - *args: The positional arguments. - **kwargs: The keyword arguments. - with_stack: Whether evaluate the function within a local variable stack. + *args: + **kwargs: static_argnums: The static argument indices. static_argnames: The static argument names. @@ -199,30 +153,21 @@ def eval_shape( """ # reorganize the function if len(static_argnums) or len(static_argnames): - f2, args, kwargs = _partial_fun2(fun, args, kwargs, static_argnums=static_argnums, static_argnames=static_argnames) + f2, args, kwargs = _partial_fun(fun, args, kwargs, + static_argnums=static_argnums, + static_argnames=static_argnames) else: - f2 = fun + f2, args, kwargs = fun, args, kwargs # evaluate the function fun_in_eval_shape.append(fun) try: - if with_stack: + with jax.ensure_compile_time_eval(): with VariableStack() as stack: if len(fun_in_eval_shape) > 1: - returns = f2(*args, **kwargs) + returns = fun(*args, **kwargs) else: - returns = jax.eval_shape(f2, *args, **kwargs) - else: - stack = None - if len(fun_in_eval_shape) > 1: - returns = f2(*args, **kwargs) - else: - returns = jax.eval_shape(f2, *args, **kwargs) + returns = jax.eval_shape(fun, *args, **kwargs) finally: fun_in_eval_shape.pop() - del f2 - if with_stack: - return stack, returns - else: - return returns - + return stack, returns diff --git a/brainpy/_src/math/object_transform/variables.py b/brainpy/_src/math/object_transform/variables.py index b7babae8d..5014da0bf 100644 --- a/brainpy/_src/math/object_transform/variables.py +++ b/brainpy/_src/math/object_transform/variables.py @@ -1,3 +1,4 @@ +from contextlib import contextmanager from typing import Optional, Any, List, Callable, Sequence, Union, Dict, Tuple import jax @@ -189,14 +190,6 @@ def remove_by_id(self, *ids, error_when_absent=False): remove_var_by_id = remove_by_id - @classmethod - def num_of_stack(self): - return len(var_stack_list) - - @classmethod - def is_first_stack(self): - return len(var_stack_list) == 0 - def __enter__(self) -> 'VariableStack': self.collect_values() # recollect the original value of each variable var_stack_list.append(self) @@ -217,6 +210,42 @@ def __add__(self, other: dict): var_stack_list: List[VariableStack] = [] +transform_stack: List[Callable] = [] + + +@contextmanager +def new_transform(transform: Any): + transform_stack.append(transform) + try: + yield + finally: + transform_stack.pop() + + +def outermost_stack(): + if len(var_stack_list): + return var_stack_list[0] + else: + return None + + +def outermost_transform(): + if len(transform_stack): + return transform_stack[0] + else: + return None + + +def current_transform_number(): + return len(transform_stack) + + +def _stack_add_read(var: 'Variable'): + pass + + +def _stack_add_write(var: 'Variable'): + pass @register_pytree_node_class diff --git a/brainpy/_src/math/op_register/base.py b/brainpy/_src/math/op_register/base.py index ead0cf00e..ca070a197 100644 --- a/brainpy/_src/math/op_register/base.py +++ b/brainpy/_src/math/op_register/base.py @@ -5,8 +5,7 @@ import numpy as np from jax.interpreters import xla, batching, ad, mlir - -from brainpy._src.dependency_check import import_numba, check_numba_class, check_taichi_class +from brainpy._src.dependency_check import import_numba from brainpy._src.math.ndarray import Array from brainpy._src.math.object_transform.base import BrainPyObject @@ -22,8 +21,6 @@ from brainpy._src.math.op_register.ad_support import defjvp numba = import_numba(error_if_not_found=False) -if numba is not None: - from numba.core.dispatcher import Dispatcher __all__ = [ 'XLACustomOp', @@ -40,8 +37,7 @@ def shape(self) -> Tuple[int, ...]: def dtype(self) -> np.dtype: ... -@check_numba_class -@check_taichi_class + class XLACustomOp(BrainPyObject): """Creating a XLA custom call operator. @@ -110,24 +106,30 @@ def __init__( self.primitive.def_impl(partial(xla.apply_primitive, self.primitive)) # cpu function + cpu_checked = False if cpu_kernel is None: - pass - elif isinstance(cpu_kernel, Dispatcher): # numba - register_numba_cpu_translation_rule(self.primitive, cpu_kernel) - elif hasattr(cpu_kernel, '_is_wrapped_kernel') and cpu_kernel._is_wrapped_kernel: # taichi + cpu_checked = True + if numba is not None: # numba + from numba.core.dispatcher import Dispatcher + if isinstance(cpu_kernel, Dispatcher): + register_numba_cpu_translation_rule(self.primitive, cpu_kernel) + cpu_checked = True + if hasattr(cpu_kernel, '_is_wrapped_kernel') and cpu_kernel._is_wrapped_kernel: # taichi register_taichi_cpu_translation_rule(self.primitive, cpu_kernel) - else: + cpu_checked = True + if not cpu_checked: raise ValueError(f'"cpu_kernel" must be a numba jitted function or a taichi kernel function. ' f'But we got {cpu_kernel}') # gpu function + gpu_checked = False if gpu_kernel is None: - pass - elif hasattr(gpu_kernel, '_is_wrapped_kernel') and gpu_kernel._is_wrapped_kernel: # taichi + gpu_checked = True + if hasattr(gpu_kernel, '_is_wrapped_kernel') and gpu_kernel._is_wrapped_kernel: # taichi register_taichi_gpu_translation_rule(self.primitive, gpu_kernel) - else: - raise ValueError(f'"cpu_kernel" must be a taichi kernel function. ' - f'But we got {gpu_kernel}') + gpu_checked = True + if not gpu_checked: + raise ValueError(f'"cpu_kernel" must be a taichi kernel function. But we got {gpu_kernel}') # batching rule if batching_translation is None: diff --git a/brainpy/_src/math/op_register/numba_approach/__init__.py b/brainpy/_src/math/op_register/numba_approach/__init__.py index 2af5637b4..5bbd04e0c 100644 --- a/brainpy/_src/math/op_register/numba_approach/__init__.py +++ b/brainpy/_src/math/op_register/numba_approach/__init__.py @@ -8,16 +8,14 @@ from jax.interpreters import xla, batching, ad from jax.tree_util import tree_map -from brainpy._src.dependency_check import import_numba, check_numba_func, check_numba_class +from brainpy._src.dependency_check import import_numba from brainpy._src.math.ndarray import Array from brainpy._src.math.object_transform.base import BrainPyObject +from brainpy.errors import PackageMissingError +from .cpu_translation import _cpu_translation, compile_cpu_signature_with_numba numba = import_numba(error_if_not_found=False) -from .cpu_translation import _cpu_translation, compile_cpu_signature_with_numba - -if numba is not None: - from numba.core.dispatcher import Dispatcher __all__ = [ 'CustomOpByNumba', @@ -26,7 +24,6 @@ ] -@check_numba_class class CustomOpByNumba(BrainPyObject): """Creating a XLA custom call operator with Numba JIT on CPU backend. @@ -88,7 +85,6 @@ def __call__(self, *args, **kwargs): return res -@check_numba_func def register_op_with_numba( op_name: str, cpu_func: Callable, @@ -143,6 +139,9 @@ def register_op_with_numba( f'For more information, please refer to the documentation: ' f'https://brainpy.readthedocs.io/en/latest/tutorial_advanced/operator_custom_with_taichi.html.') + if numba is None: + raise PackageMissingError.by_purpose('numba', 'custom op with numba') + if out_shapes is None: raise RuntimeError('out_shapes cannot be None. It can be a `ShapedArray` or ' 'a sequence of `ShapedArray`. If it is a function, it takes as input the argument ' @@ -152,6 +151,7 @@ def register_op_with_numba( prim.multiple_results = multiple_results # user defined function + from numba.core.dispatcher import Dispatcher if not isinstance(cpu_func, Dispatcher): cpu_func = numba.jit(fastmath=True, nopython=True)(cpu_func) diff --git a/brainpy/_src/math/op_register/numba_approach/cpu_translation.py b/brainpy/_src/math/op_register/numba_approach/cpu_translation.py index 759ecc50c..4b06effdf 100644 --- a/brainpy/_src/math/op_register/numba_approach/cpu_translation.py +++ b/brainpy/_src/math/op_register/numba_approach/cpu_translation.py @@ -6,9 +6,15 @@ from jax.core import ShapedArray from jax.lib import xla_client -from brainpy._src.dependency_check import import_numba, check_numba_func +from brainpy._src.dependency_check import import_numba numba = import_numba(error_if_not_found=False) +ctypes.pythonapi.PyCapsule_New.argtypes = [ + ctypes.c_void_p, # void* pointer + ctypes.c_char_p, # const char *name + ctypes.c_void_p, # PyCapsule_Destructor destructor +] +ctypes.pythonapi.PyCapsule_New.restype = ctypes.py_object __all__ = [ '_cpu_translation', @@ -18,14 +24,7 @@ if numba is not None: from numba import types, carray, cfunc -ctypes.pythonapi.PyCapsule_New.argtypes = [ - ctypes.c_void_p, # void* pointer - ctypes.c_char_p, # const char *name - ctypes.c_void_p, # PyCapsule_Destructor destructor -] -ctypes.pythonapi.PyCapsule_New.restype = ctypes.py_object -@check_numba_func def _cpu_translation(func, abs_eval_fn, multiple_results, c, *inputs, **info): target_name, inputs, input_shapes, xla_output_shapes = \ compile_cpu_signature_with_numba(c, func, abs_eval_fn, multiple_results, inputs, info) @@ -102,7 +101,7 @@ def xla_cpu_custom_call_target(output_ptrs, input_ptrs): xla_client.register_custom_call_target(target_name, capsule, "cpu") return target_name -@check_numba_func + def compile_cpu_signature_with_numba( c, func, diff --git a/brainpy/_src/math/op_register/numba_based.py b/brainpy/_src/math/op_register/numba_based.py index 8c56e52aa..f461f4277 100644 --- a/brainpy/_src/math/op_register/numba_based.py +++ b/brainpy/_src/math/op_register/numba_based.py @@ -7,8 +7,9 @@ from jax.lib import xla_client from jaxlib.hlo_helpers import custom_call +from brainpy._src.dependency_check import import_numba +from brainpy.errors import PackageMissingError from .utils import _shape_to_layout -from brainpy._src.dependency_check import import_numba, check_numba_func numba = import_numba(error_if_not_found=False) if numba is not None: @@ -105,8 +106,10 @@ def _numba_xla_cpu_translation_rule(kernel, debug: bool, c, *ins, **kwargs): ) -@check_numba_func def register_numba_xla_cpu_translation_rule(primitive, cpu_kernel, debug=False): + if numba is None: + raise PackageMissingError.by_purpose("numba", 'register numba xla cpu translation rule') + # do not support after jax >= 0.4.24 xla.backend_specific_translations['cpu'][primitive] = partial(_numba_xla_cpu_translation_rule, cpu_kernel, @@ -170,7 +173,9 @@ def numba_cpu_custom_call_target(output_ptrs, input_ptrs): ).results -@check_numba_func def register_numba_mlir_cpu_translation_rule(primitive, cpu_kernel, debug=False): + if numba is None: + raise PackageMissingError.by_purpose("numba", 'register numba xla cpu translation rule') + rule = partial(_numba_mlir_cpu_translation_rule, cpu_kernel, debug) mlir.register_lowering(primitive, rule, platform='cpu') diff --git a/brainpy/_src/math/sparse/_bsr_mm.py b/brainpy/_src/math/sparse/_bsr_mm.py index 43ccac6c8..19800749d 100644 --- a/brainpy/_src/math/sparse/_bsr_mm.py +++ b/brainpy/_src/math/sparse/_bsr_mm.py @@ -10,13 +10,14 @@ from jax.interpreters import ad, xla from jax.lib import xla_client -from brainpy._src.dependency_check import import_brainpylib_gpu_ops, import_numba, check_numba_func +from brainpy._src.dependency_check import import_brainpylib_gpu_ops, import_numba from brainpy._src.math.interoperability import as_jax from brainpy._src.math.op_register import (compile_cpu_signature_with_numba, register_general_batching) from brainpy.errors import GPUOperatorNotFound numba = import_numba(error_if_not_found=False) + __all__ = [ 'bcsrmm', ] @@ -216,7 +217,6 @@ def blocksparse_matmat_multiply(dense_a, raise Exception('Invalid device: ', device) -@check_numba_func def bcsrmm( A_data: jax.Array, B_data: jax.Array, diff --git a/brainpy/_src/math/sparse/_csr_mv.py b/brainpy/_src/math/sparse/_csr_mv.py index dd25ef3d4..42969f435 100644 --- a/brainpy/_src/math/sparse/_csr_mv.py +++ b/brainpy/_src/math/sparse/_csr_mv.py @@ -3,17 +3,16 @@ from typing import Union, Tuple -import brainpy.math as bm import jax from jax import numpy as jnp from jax.experimental.sparse import csr from jax.interpreters import ad -from brainpy._src.dependency_check import import_taichi, check_taichi_func +import brainpy.math as bm +from brainpy._src.dependency_check import import_taichi from brainpy._src.math.interoperability import as_jax from brainpy._src.math.ndarray import Array -from brainpy._src.math.op_register import (register_general_batching, - XLACustomOp) +from brainpy._src.math.op_register import (register_general_batching, XLACustomOp) from brainpy._src.math.sparse._utils import csr_to_coo from brainpy.errors import PackageMissingError @@ -23,7 +22,7 @@ 'csrmv', ] -@check_taichi_func + def csrmv( data: Union[float, jnp.ndarray, Array], indices: Union[jnp.ndarray, Array], @@ -64,48 +63,6 @@ def csrmv( - ``vector``: - ``adaptive``: - Returns - ------- - y : ndarry - The array of shape ``(shape[1] if transpose else shape[0],)`` representing - the matrix vector product. - """ - return csrmv_taichi(data, indices, indptr, vector, shape=shape, transpose=transpose) - - -### TAICHI ### - -def csrmv_taichi( - data: Union[float, jnp.ndarray, Array], - indices: Union[jnp.ndarray, Array], - indptr: Union[jnp.ndarray, Array], - vector: Union[jnp.ndarray, Array], - *, - shape: Tuple[int, int], - transpose: bool = False, -) -> jax.Array: - """Product of CSR sparse matrix and a dense vector using cuSPARSE algorithm. - - This function supports JAX transformations, including `jit()`, `grad()`, - `vmap()` and `pmap()`. - - Parameters - ---------- - data: ndarray, float - An array of shape ``(nse,)``. - indices: ndarray - An array of shape ``(nse,)``. - indptr: ndarray - An array of shape ``(shape[0] + 1,)`` and dtype ``indices.dtype``. - vector: ndarray - An array of shape ``(shape[0] if transpose else shape[1],)`` - and dtype ``data.dtype``. - shape: tuple of int - A length-2 tuple representing the matrix shape. - transpose: bool - A boolean specifying whether to transpose the sparse matrix - before computing. - Returns ------- y : ndarry @@ -150,11 +107,11 @@ def raw_csrmv_taichi( transpose: bool = False, ): if ti is None: - raise PackageMissingError(name='taichi', purpose='customized operators') + raise PackageMissingError.by_purpose('taichi', purpose='customized operators') out_shape = shape[1] if transpose else shape[0] if data.shape[0] != 1: if bm.get_platform() == 'gpu': - return [_csr_matvec_cusparse_p.bind(data, indices, indptr, vector, shape=shape, transpose=transpose), ] + return [_csr_matvec_cusparse_p.bind(data, indices, indptr, vector, shape=shape, transpose=transpose)] else: if transpose: prim = _csr_matvec_transpose_heter_p diff --git a/brainpy/_src/tools/functions.py b/brainpy/_src/tools/functions.py deleted file mode 100644 index cbc710dba..000000000 --- a/brainpy/_src/tools/functions.py +++ /dev/null @@ -1,192 +0,0 @@ -import inspect -from functools import partial -from operator import attrgetter -from types import MethodType - -__all__ = [ - 'compose', 'pipe' -] - - -def identity(x): - """ Identity function. Return x - - >>> identity(3) - 3 - """ - return x - - -def instanceproperty(fget=None, fset=None, fdel=None, doc=None, classval=None): - """ Like @property, but returns ``classval`` when used as a class attribute - - >>> class MyClass(object): - ... '''The class docstring''' - ... @instanceproperty(classval=__doc__) - ... def __doc__(self): - ... return 'An object docstring' - ... @instanceproperty - ... def val(self): - ... return 42 - ... - >>> MyClass.__doc__ - 'The class docstring' - >>> MyClass.val is None - True - >>> obj = MyClass() - >>> obj.__doc__ - 'An object docstring' - >>> obj.val - 42 - """ - if fget is None: - return partial(instanceproperty, fset=fset, fdel=fdel, doc=doc, - classval=classval) - return InstanceProperty(fget=fget, fset=fset, fdel=fdel, doc=doc, - classval=classval) - - -class InstanceProperty(property): - """ Like @property, but returns ``classval`` when used as a class attribute - - Should not be used directly. Use ``instanceproperty`` instead. - """ - - def __init__(self, fget=None, fset=None, fdel=None, doc=None, - classval=None): - self.classval = classval - property.__init__(self, fget=fget, fset=fset, fdel=fdel, doc=doc) - - def __get__(self, obj, type=None): - if obj is None: - return self.classval - return property.__get__(self, obj, type) - - def __reduce__(self): - state = (self.fget, self.fset, self.fdel, self.__doc__, self.classval) - return InstanceProperty, state - - -class Compose(object): - """ A composition of functions - - See Also: - compose - """ - __slots__ = 'first', 'funcs' - - def __init__(self, funcs): - funcs = tuple(reversed(funcs)) - self.first = funcs[0] - self.funcs = funcs[1:] - - def __call__(self, *args, **kwargs): - ret = self.first(*args, **kwargs) - for f in self.funcs: - ret = f(ret) - return ret - - def __getstate__(self): - return self.first, self.funcs - - def __setstate__(self, state): - self.first, self.funcs = state - - @instanceproperty(classval=__doc__) - def __doc__(self): - def composed_doc(*fs): - """Generate a docstring for the composition of fs. - """ - if not fs: - # Argument name for the docstring. - return '*args, **kwargs' - - return '{f}({g})'.format(f=fs[0].__name__, g=composed_doc(*fs[1:])) - - try: - return ( - 'lambda *args, **kwargs: ' + - composed_doc(*reversed((self.first,) + self.funcs)) - ) - except AttributeError: - # One of our callables does not have a `__name__`, whatever. - return 'A composition of functions' - - @property - def __name__(self): - try: - return '_of_'.join( - (f.__name__ for f in reversed((self.first,) + self.funcs)) - ) - except AttributeError: - return type(self).__name__ - - def __repr__(self): - return '{.__class__.__name__}{!r}'.format( - self, tuple(reversed((self.first,) + self.funcs))) - - def __eq__(self, other): - if isinstance(other, Compose): - return other.first == self.first and other.funcs == self.funcs - return NotImplemented - - def __ne__(self, other): - equality = self.__eq__(other) - return NotImplemented if equality is NotImplemented else not equality - - def __hash__(self): - return hash(self.first) ^ hash(self.funcs) - - # Mimic the descriptor behavior of python functions. - # i.e. let Compose be called as a method when bound to a class. - # adapted from - # docs.python.org/3/howto/descriptor.html#functions-and-methods - def __get__(self, obj, objtype=None): - return self if obj is None else MethodType(self, obj) - - # introspection with Signature is only possible from py3.3+ - @instanceproperty - def __signature__(self): - base = inspect.signature(self.first) - last = inspect.signature(self.funcs[-1]) - return base.replace(return_annotation=last.return_annotation) - - __wrapped__ = instanceproperty(attrgetter('first')) - - -def compose(*funcs): - """ Compose functions to operate in series. - - Returns a function that applies other functions in sequence. - - Functions are applied from right to left so that - ``compose(f, g, h)(x, y)`` is the same as ``f(g(h(x, y)))``. - - If no arguments are provided, the identity function (f(x) = x) is returned. - - >>> inc = lambda i: i + 1 - >>> compose(str, inc)(3) - '4' - """ - if not funcs: - return identity - if len(funcs) == 1: - return funcs[0] - else: - return Compose(funcs) - - -def pipe(*funcs): - """ Pipe a value through a sequence of functions - - I.e. ``pipe(f, g, h)(data)`` is equivalent to ``h(g(f(data)))`` - - We think of the value as progressing through a pipe of several - transformations, much like pipes in UNIX - - - >>> double = lambda i: 2 * i - >>> pipe(double, str)(3) - '6' - """ - return compose(*reversed(funcs)) diff --git a/brainpy/_src/tools/progress.py b/brainpy/_src/tools/progress.py new file mode 100644 index 000000000..13b6a1574 --- /dev/null +++ b/brainpy/_src/tools/progress.py @@ -0,0 +1,519 @@ +"""Python utilities required by Keras.""" + +import binascii +import codecs +import importlib +import marshal +import os +import re +import sys +import time +import types as python_types + +import numpy as np + + +# isort: off + + +def func_dump(func): + """Serializes a user defined function. + + Args: + func: the function to serialize. + + Returns: + A tuple `(code, defaults, closure)`. + """ + if os.name == "nt": + raw_code = marshal.dumps(func.__code__).replace(b"\\", b"/") + code = codecs.encode(raw_code, "base64").decode("ascii") + else: + raw_code = marshal.dumps(func.__code__) + code = codecs.encode(raw_code, "base64").decode("ascii") + defaults = func.__defaults__ + if func.__closure__: + closure = tuple(c.cell_contents for c in func.__closure__) + else: + closure = None + return code, defaults, closure + + +def func_load(code, defaults=None, closure=None, globs=None): + """Deserializes a user defined function. + + Args: + code: bytecode of the function. + defaults: defaults of the function. + closure: closure of the function. + globs: dictionary of global objects. + + Returns: + A function object. + """ + if isinstance(code, (tuple, list)): # unpack previous dump + code, defaults, closure = code + if isinstance(defaults, list): + defaults = tuple(defaults) + + def ensure_value_to_cell(value): + """Ensures that a value is converted to a python cell object. + + Args: + value: Any value that needs to be casted to the cell type + + Returns: + A value wrapped as a cell object (see function "func_load") + """ + + def dummy_fn(): + value # just access it so it gets captured in .__closure__ + + cell_value = dummy_fn.__closure__[0] + if not isinstance(value, type(cell_value)): + return cell_value + return value + + if closure is not None: + closure = tuple(ensure_value_to_cell(_) for _ in closure) + try: + raw_code = codecs.decode(code.encode("ascii"), "base64") + except (UnicodeEncodeError, binascii.Error): + raw_code = code.encode("raw_unicode_escape") + code = marshal.loads(raw_code) + if globs is None: + globs = globals() + return python_types.FunctionType( + code, globs, name=code.co_name, argdefs=defaults, closure=closure + ) + + +class Progbar: + """Displays a progress bar. + + Args: + target: Total number of steps expected, None if unknown. + width: Progress bar width on screen. + verbose: Verbosity mode, 0 (silent), 1 (verbose), 2 (semi-verbose) + stateful_metrics: Iterable of string names of metrics that should *not* + be averaged over time. Metrics in this list will be displayed as-is. + All others will be averaged by the progbar before display. + interval: Minimum visual progress update interval (in seconds). + unit_name: Display name for step counts (usually "step" or "sample"). + """ + + def __init__( + self, + target, + width=30, + verbose=1, + interval=0.05, + stateful_metrics=None, + unit_name="step", + ): + self.target = target + self.width = width + self.verbose = verbose + self.interval = interval + self.unit_name = unit_name + if stateful_metrics: + self.stateful_metrics = set(stateful_metrics) + else: + self.stateful_metrics = set() + + self._dynamic_display = ( + (hasattr(sys.stdout, "isatty") and sys.stdout.isatty()) + or "ipykernel" in sys.modules + or "posix" in sys.modules + or "PYCHARM_HOSTED" in os.environ + ) + self._total_width = 0 + self._seen_so_far = 0 + # We use a dict + list to avoid garbage collection + # issues found in OrderedDict + self._values = {} + self._values_order = [] + self._start = time.time() + self._last_update = 0 + self._time_at_epoch_start = self._start + self._time_at_epoch_end = None + self._time_after_first_step = None + + def update(self, current, values=None, finalize=None): + """Updates the progress bar. + + Args: + current: Index of current step. + values: List of tuples: `(name, value_for_last_step)`. If `name` is + in `stateful_metrics`, `value_for_last_step` will be displayed + as-is. Else, an average of the metric over time will be + displayed. + finalize: Whether this is the last update for the progress bar. If + `None`, uses `current >= self.target`. Defaults to `None`. + """ + if finalize is None: + if self.target is None: + finalize = False + else: + finalize = current >= self.target + + values = values or [] + for k, v in values: + if k not in self._values_order: + self._values_order.append(k) + if k not in self.stateful_metrics: + # In the case that progress bar doesn't have a target value in + # the first epoch, both on_batch_end and on_epoch_end will be + # called, which will cause 'current' and 'self._seen_so_far' to + # have the same value. Force the minimal value to 1 here, + # otherwise stateful_metric will be 0s. + value_base = max(current - self._seen_so_far, 1) + if k not in self._values: + self._values[k] = [v * value_base, value_base] + else: + self._values[k][0] += v * value_base + self._values[k][1] += value_base + else: + # Stateful metrics output a numeric value. This representation + # means "take an average from a single value" but keeps the + # numeric formatting. + self._values[k] = [v, 1] + self._seen_so_far = current + + message = "" + now = time.time() + info = f" - {now - self._start:.0f}s" + if current == self.target: + self._time_at_epoch_end = now + if self.verbose == 1: + if now - self._last_update < self.interval and not finalize: + return + + prev_total_width = self._total_width + if self._dynamic_display: + message += "\b" * prev_total_width + message += "\r" + else: + message += "\n" + + if self.target is not None: + numdigits = int(np.log10(self.target)) + 1 + bar = ("%" + str(numdigits) + "d/%d [") % (current, self.target) + prog = float(current) / self.target + prog_width = int(self.width * prog) + if prog_width > 0: + bar += "=" * (prog_width - 1) + if current < self.target: + bar += ">" + else: + bar += "=" + bar += "." * (self.width - prog_width) + bar += "]" + else: + bar = "%7d/Unknown" % current + + self._total_width = len(bar) + message += bar + + time_per_unit = self._estimate_step_duration(current, now) + + if self.target is None or finalize: + info += self._format_time(time_per_unit, self.unit_name) + else: + eta = time_per_unit * (self.target - current) + if eta > 3600: + eta_format = "%d:%02d:%02d" % ( + eta // 3600, + (eta % 3600) // 60, + eta % 60, + ) + elif eta > 60: + eta_format = "%d:%02d" % (eta // 60, eta % 60) + else: + eta_format = "%ds" % eta + + info = f" - ETA: {eta_format}" + + for k in self._values_order: + info += f" - {k}:" + if isinstance(self._values[k], list): + avg = np.mean( + self._values[k][0] / max(1, self._values[k][1]) + ) + if abs(avg) > 1e-3: + info += f" {avg:.4f}" + else: + info += f" {avg:.4e}" + else: + info += f" {self._values[k]}" + + self._total_width += len(info) + if prev_total_width > self._total_width: + info += " " * (prev_total_width - self._total_width) + + if finalize: + info += "\n" + + message += info + print_msg(message, line_break=False) + message = "" + + elif self.verbose == 2: + if finalize: + numdigits = int(np.log10(self.target)) + 1 + count = ("%" + str(numdigits) + "d/%d") % (current, self.target) + info = count + info + for k in self._values_order: + info += f" - {k}:" + avg = np.mean( + self._values[k][0] / max(1, self._values[k][1]) + ) + if avg > 1e-3: + info += f" {avg:.4f}" + else: + info += f" {avg:.4e}" + if self._time_at_epoch_end: + time_per_epoch = ( + self._time_at_epoch_end - self._time_at_epoch_start + ) + avg_time_per_step = time_per_epoch / self.target + self._time_at_epoch_start = now + self._time_at_epoch_end = None + info += " -" + self._format_time(time_per_epoch, "epoch") + info += " -" + self._format_time( + avg_time_per_step, self.unit_name + ) + info += "\n" + message += info + print_msg(message, line_break=False) + message = "" + + self._last_update = now + + def add(self, n, values=None): + self.update(self._seen_so_far + n, values) + + def _format_time(self, time_per_unit, unit_name): + """format a given duration to display to the user. + + Given the duration, this function formats it in either milliseconds + or seconds and displays the unit (i.e. ms/step or s/epoch) + Args: + time_per_unit: the duration to display + unit_name: the name of the unit to display + Returns: + a string with the correctly formatted duration and units + """ + formatted = "" + if time_per_unit >= 1 or time_per_unit == 0: + formatted += f" {time_per_unit:.0f}s/{unit_name}" + elif time_per_unit >= 1e-3: + formatted += f" {time_per_unit * 1000.0:.0f}ms/{unit_name}" + else: + formatted += f" {time_per_unit * 1000000.0:.0f}us/{unit_name}" + return formatted + + def _estimate_step_duration(self, current, now): + """Estimate the duration of a single step. + + Given the step number `current` and the corresponding time `now` this + function returns an estimate for how long a single step takes. If this + is called before one step has been completed (i.e. `current == 0`) then + zero is given as an estimate. The duration estimate ignores the duration + of the (assumed to be non-representative) first step for estimates when + more steps are available (i.e. `current>1`). + + Args: + current: Index of current step. + now: The current time. + + Returns: Estimate of the duration of a single step. + """ + if current: + # there are a few special scenarios here: + # 1) somebody is calling the progress bar without ever supplying + # step 1 + # 2) somebody is calling the progress bar and supplies step one + # multiple times, e.g. as part of a finalizing call + # in these cases, we just fall back to the simple calculation + if self._time_after_first_step is not None and current > 1: + time_per_unit = (now - self._time_after_first_step) / ( + current - 1 + ) + else: + time_per_unit = (now - self._start) / current + + if current == 1: + self._time_after_first_step = now + return time_per_unit + else: + return 0 + + def _update_stateful_metrics(self, stateful_metrics): + self.stateful_metrics = self.stateful_metrics.union(stateful_metrics) + + +def make_batches(size, batch_size): + """Returns a list of batch indices (tuples of indices). + + Args: + size: Integer, total size of the data to slice into batches. + batch_size: Integer, batch size. + + Returns: + A list of tuples of array indices. + """ + num_batches = int(np.ceil(size / float(batch_size))) + return [ + (i * batch_size, min(size, (i + 1) * batch_size)) + for i in range(0, num_batches) + ] + + +def slice_arrays(arrays, start=None, stop=None): + """Slice an array or list of arrays. + + This takes an array-like, or a list of + array-likes, and outputs: + - arrays[start:stop] if `arrays` is an array-like + - [x[start:stop] for x in arrays] if `arrays` is a list + + Can also work on list/array of indices: `slice_arrays(x, indices)` + + Args: + arrays: Single array or list of arrays. + start: can be an integer index (start index) or a list/array of indices + stop: integer (stop index); should be None if `start` was a list. + + Returns: + A slice of the array(s). + + Raises: + ValueError: If the value of start is a list and stop is not None. + """ + if arrays is None: + return [None] + if isinstance(start, list) and stop is not None: + raise ValueError( + "The stop argument has to be None if the value of start " + f"is a list. Received start={start}, stop={stop}" + ) + elif isinstance(arrays, list): + if hasattr(start, "__len__"): + # hdf5 datasets only support list objects as indices + if hasattr(start, "shape"): + start = start.tolist() + return [None if x is None else x[start] for x in arrays] + return [ + None + if x is None + else None + if not hasattr(x, "__getitem__") + else x[start:stop] + for x in arrays + ] + else: + if hasattr(start, "__len__"): + if hasattr(start, "shape"): + start = start.tolist() + return arrays[start] + if hasattr(start, "__getitem__"): + return arrays[start:stop] + return [None] + + +def to_list(x): + """Normalizes a list/tensor into a list. + + If a tensor is passed, we return + a list of size 1 containing the tensor. + + Args: + x: target object to be normalized. + + Returns: + A list. + """ + if isinstance(x, list): + return x + return [x] + + +def to_snake_case(name): + intermediate = re.sub("(.)([A-Z][a-z]+)", r"\1_\2", name) + insecure = re.sub("([a-z])([A-Z])", r"\1_\2", intermediate).lower() + # If the class is private the name starts with "_" which is not secure + # for creating scopes. We prefix the name with "private" in this case. + if insecure[0] != "_": + return insecure + return "private" + insecure + + +def check_for_unexpected_keys(name, input_dict, expected_values): + unknown = set(input_dict.keys()).difference(expected_values) + if unknown: + raise ValueError( + f"Unknown entries in {name} dictionary: {list(unknown)}. " + f"Only expected following keys: {expected_values}" + ) + + +def validate_kwargs( + kwargs, allowed_kwargs, error_message="Keyword argument not understood:" +): + """Checks that all keyword arguments are in the set of allowed keys.""" + for kwarg in kwargs: + if kwarg not in allowed_kwargs: + raise TypeError(error_message, kwarg) + + +def default(method): + """Decorates a method to detect overrides in subclasses.""" + method._is_default = True + return method + + +def is_default(method): + """Check if a method is decorated with the `default` wrapper.""" + return getattr(method, "_is_default", False) + + +def populate_dict_with_module_objects(target_dict, modules, obj_filter): + for module in modules: + for name in dir(module): + obj = getattr(module, name) + if obj_filter(obj): + target_dict[name] = obj + + +class LazyLoader(python_types.ModuleType): + """Lazily import a module, mainly to avoid pulling in large dependencies.""" + + def __init__(self, local_name, parent_module_globals, name): + self._local_name = local_name + self._parent_module_globals = parent_module_globals + super().__init__(name) + + def _load(self): + """Load the module and insert it into the parent's globals.""" + # Import the target module and insert it into the parent's namespace + module = importlib.import_module(self.__name__) + self._parent_module_globals[self._local_name] = module + # Update this object's dict so that if someone keeps a reference to the + # LazyLoader, lookups are efficient (__getattr__ is only called on + # lookups that fail). + self.__dict__.update(module.__dict__) + return module + + def __getattr__(self, item): + module = self._load() + return getattr(module, item) + + +def print_msg(message, line_break=True): + """Print the message to absl logging or stdout.""" + if line_break: + sys.stdout.write(message + "\n") + else: + sys.stdout.write(message) + sys.stdout.flush() diff --git a/brainpy/_src/tools/tests/test_functions.py b/brainpy/_src/tools/tests/test_functions.py deleted file mode 100644 index c285e561a..000000000 --- a/brainpy/_src/tools/tests/test_functions.py +++ /dev/null @@ -1,24 +0,0 @@ - -import unittest - -import brainpy as bp -import brainpy.math as bm - - -class TestFunction(unittest.TestCase): - def test_compose(self): - f = lambda a: a + 1 - g = lambda a: a * 10 - fun1 = bp.tools.compose(f, g) - fun2 = bp.tools.pipe(g, f) - - arr = bm.random.randn(10) - r1 = fun1(arr) - r2 = fun2(arr) - groundtruth = f(g(arr)) - self.assertTrue(bm.allclose(r1, r2)) - self.assertTrue(bm.allclose(r1, groundtruth)) - bm.clear_buffer_memory() - - - diff --git a/brainpy/errors.py b/brainpy/errors.py index 37d4b9488..453c9c818 100644 --- a/brainpy/errors.py +++ b/brainpy/errors.py @@ -39,15 +39,11 @@ class PackageMissingError(BrainPyError): """The package missing error. """ - def __init__(self, name: str = None, purpose: str = None): - - if name is None: - super().__init__() - else: - assert purpose, '"purpose" cannot be None when "name" is provided.' - msg = (f'"{name}" must be installed when the user wants to use {purpose}. \n' - f'Please install through "pip install {name}".') - super().__init__(msg) + @classmethod + def by_purpose(cls, name, purpose): + err = (f'"{name}" must be installed when the user wants to use {purpose}. \n' + f'Please install through "pip install {name}".') + return cls(err) class BackendNotInstalled(BrainPyError): diff --git a/brainpy/math/compat_pytorch.py b/brainpy/math/compat_pytorch.py index 3b0c3f517..e4570f6fd 100644 --- a/brainpy/math/compat_pytorch.py +++ b/brainpy/math/compat_pytorch.py @@ -12,7 +12,7 @@ arccos as arccos, acosh as acosh, arccosh as arccosh, - # add as add, + add as add, addcdiv as addcdiv, addcmul as addcmul, angle as angle, diff --git a/brainpy/math/oo_transform.py b/brainpy/math/oo_transform.py index 7654731d8..548a987d0 100644 --- a/brainpy/math/oo_transform.py +++ b/brainpy/math/oo_transform.py @@ -59,7 +59,3 @@ eval_shape as eval_shape, ) -from brainpy._src.math.object_transform.variables import ( - VariableStack as VariableStack, -) - diff --git a/brainpy/tools.py b/brainpy/tools.py index 233269dc5..0f3a4c0ef 100644 --- a/brainpy/tools.py +++ b/brainpy/tools.py @@ -45,9 +45,4 @@ ) -from brainpy._src.tools.functions import ( - compose as compose, - pipe as pipe, -) - diff --git a/docs/advanced_tutorials.rst b/docs/advanced_tutorials.rst index 0b78315ab..5c8cba0fd 100644 --- a/docs/advanced_tutorials.rst +++ b/docs/advanced_tutorials.rst @@ -3,52 +3,13 @@ Advanced Tutorials This section contains tutorials that illustrate more advanced features of BrainPy. -Advanced Math -------------- .. toctree:: - :maxdepth: 1 - - tutorial_advanced/compilation.ipynb - tutorial_advanced/differentiation.ipynb - - -Interoperation --------------- - -.. toctree:: - :maxdepth: 1 - - tutorial_advanced/integrate_flax_into_brainpy.ipynb - tutorial_advanced/integrate_bp_lif_into_flax.ipynb - tutorial_advanced/integrate_bp_convlstm_into_flax.ipynb - - -Brain Dynamics Dedicated Operators ----------------------------------- - -.. toctree:: - :maxdepth: 1 - - tutorial_advanced/operator_custom_with_numba.ipynb - tutorial_advanced/operator_custom_with_taichi.ipynb - - -Developer Guides ----------------- - -.. toctree:: - :maxdepth: 1 - - tutorial_advanced/contributing.md - - -Others ------- - -.. toctree:: - :maxdepth: 1 - - tutorial_advanced/advanced_lowdim_analysis.ipynb + :maxdepth: 2 + tutorial_advanced/1_advanced_math.rst + tutorial_advanced/2_interoperation.rst + tutorial_advanced/3_dedicated_operators.rst + tutorial_advanced/4_developer_guides.rst + tutorial_advanced/5_others.rst diff --git a/docs/apis/brainpy.math.oo_transform.rst b/docs/apis/brainpy.math.oo_transform.rst index 9ed9cf46a..754e0d81d 100644 --- a/docs/apis/brainpy.math.oo_transform.rst +++ b/docs/apis/brainpy.math.oo_transform.rst @@ -77,5 +77,4 @@ Helpers for Object-oriented Transformations :template: classtemplate.rst eval_shape - VariableStack diff --git a/docs/toolboxes.rst b/docs/toolboxes.rst index cc3a38575..11bf53115 100644 --- a/docs/toolboxes.rst +++ b/docs/toolboxes.rst @@ -1,16 +1,7 @@ BDP Toolboxes ================== - - - This section contains detailed toolboxes BrainPy uses for brain dynamics modeling. - - -Differential Equations ------------------------ - - .. toctree:: :maxdepth: 1 @@ -19,34 +10,11 @@ Differential Equations tutorial_toolbox/fde_numerical_solvers tutorial_toolbox/dde_numerical_solvers tutorial_toolbox/joint_equations - - -Toolbox for Modeling -------------------- - -.. toctree:: - :maxdepth: 1 - tutorial_toolbox/synaptic_connections tutorial_toolbox/synaptic_weights - tutorial_toolbox/inputs - - -Toolbox for Training --------------------- - -.. toctree:: - :maxdepth: 1 - tutorial_toolbox/optimizers + tutorial_toolbox/state_saving_and_loading.ipynb + tutorial_toolbox/state_resetting.ipynb tutorial_toolbox/surrogate_gradient + tutorial_toolbox/inputs - -State Resetting, Saving and Loading ------------------------------------ - -.. toctree:: - :maxdepth: 1 - - tutorial_toolbox/state_saving_and_loading.ipynb - tutorial_toolbox/state_resetting.ipynb \ No newline at end of file diff --git a/docs/tutorials.rst b/docs/tutorials.rst index 57d18332b..7c9a1c876 100644 --- a/docs/tutorials.rst +++ b/docs/tutorials.rst @@ -3,76 +3,11 @@ BDP Tutorials This section contains tutorials on how to use BrainPy to accomplish model building, simulation, training, and analysis. - -Math Foundation ---------------- - -.. toctree:: - :maxdepth: 1 - - tutorial_math/variables - tutorial_math/control_flows - tutorial_math/Numpy_like_Operations.ipynb - tutorial_math/Dedicated_Operators.ipynb - tutorial_math/einops_in_brainpy.ipynb - - -Model Building with Existing Modules ------------------------------------- - -.. toctree:: - :maxdepth: 1 - - tutorial_building/overview_of_dynamic_model - tutorial_building/build_conductance_neurons_v2.ipynb - tutorial_building/phenon_synapse_models.ipynb - tutorial_building/kinetic_synapse_models.ipynb - tutorial_building/build_network_models - - -Model Building by Customizing New Modules ------------------------------------------ - -.. toctree:: - :maxdepth: 1 - - tutorial_building/customize_neuron_models - tutorial_building/customize_synapse_models - tutorial_building/how_to_customze_a_synapse.ipynb - - -Model Simulation ----------------- - -.. toctree:: - :maxdepth: 1 - - tutorial_simulation/simulation_dsrunner.ipynb - tutorial_simulation/parallel_for_parameter_exploration.ipynb - tutorial_simulation/monitor_per_multiple_steps.ipynb - - -Model Training --------------- - -This tutorial shows how to train a dynamical system from data or task. - -.. toctree:: - :maxdepth: 1 - - tutorial_training/build_training_models.ipynb - tutorial_training/offline_training.ipynb - tutorial_training/online_training.ipynb - tutorial_training/bp_training.ipynb - tutorial_training/esn_introduction.ipynb - - -Model Analysis --------------- - .. toctree:: - :maxdepth: 1 + :maxdepth: 2 - tutorial_analysis/lowdim_analysis - tutorial_analysis/highdim_analysis - tutorial_analysis/decision_making_model + tutorial_math/index + tutorial_building/index + tutorial_simulation/index + tutorial_training/index + tutorial_analysis/index diff --git a/examples/dynamics_simulation/ei_nets.py b/examples/dynamics_simulation/ei_nets.py index 9c7daff55..f98527458 100644 --- a/examples/dynamics_simulation/ei_nets.py +++ b/examples/dynamics_simulation/ei_nets.py @@ -228,7 +228,7 @@ def __init__(self): ) def update(self, input): - spk = self.delay.at('delay') + spk = self.delay.at('I') self.E(self.syn1(spk[:3200])) self.I(self.syn2(spk[3200:])) self.delay(self.N(input)) diff --git a/examples/dynamics_training/integrator_rnn.py b/examples/dynamics_training/integrator_rnn.py index fc36845e6..d0dfca11b 100644 --- a/examples/dynamics_training/integrator_rnn.py +++ b/examples/dynamics_training/integrator_rnn.py @@ -30,7 +30,7 @@ def train_data(): class RNN(bp.DynamicalSystem): def __init__(self, num_in, num_hidden): super(RNN, self).__init__() - self.rnn = bp.layers.RNNCell(num_in, num_hidden, train_state=True) + self.rnn = bp.dyn.RNNCell(num_in, num_hidden, train_state=True) self.out = bp.layers.Dense(num_hidden, 1) def update(self, x):