From 055862c62e9fa41196c6542acd3883c2c8d95fce Mon Sep 17 00:00:00 2001 From: Chaoming Wang Date: Sat, 23 Nov 2024 19:22:48 +0800 Subject: [PATCH] update --- brainpy/_src/dependency_check.py | 216 +- brainpy/_src/dnn/linear.py | 2766 +++++++++-------- brainpy/_src/math/defaults.py | 18 +- brainpy/_src/math/environment.py | 17 - brainpy/_src/math/event/csr_matmat.py | 45 +- brainpy/_src/math/event/csr_matvec.py | 72 +- .../tests/event_csr_matmat_VS_csr_matmat.py | 285 -- .../event_csrmv_taichi_VS_event_csrmv.py | 254 -- .../event_csrmv_taichi_VS_event_csrmv_grad.py | 271 -- .../_src/math/event/tests/test_event_csrmm.py | 288 -- .../_src/math/event/tests/test_event_csrmv.py | 236 -- brainpy/_src/math/jitconn/event_matvec.py | 36 +- brainpy/_src/math/jitconn/matvec.py | 491 ++- .../tests/event_matvec_jitconn_performance.py | 245 -- ...t_matvec_taichi_VS_jitconn_event_matvec.py | 573 ---- ...vec_taichi_VS_jitconn_event_matvec_grad.py | 589 ---- ...jitconn_matvec_taichi_VS_jitconn_matvec.py | 560 ---- ...nn_matvec_taichi_VS_jitconn_matvec_grad.py | 736 ----- .../tests/matmat_jitconn_performance.py | 61 - .../math/jitconn/tests/matmat_testcase.py | 140 - .../tests/matvec_jitconn_performance.py | 53 - .../math/jitconn/tests/test_event_matvec.py | 435 --- .../jitconn/tests/test_event_matvec_old.py | 564 ---- .../jitconn/tests/test_get_weight_matrix.py | 172 - .../_src/math/jitconn/tests/test_matvec.py | 403 --- .../math/jitconn/tests/test_matvec_old.py | 551 ---- .../csr_matmat_VS_cusparse_csr_matmat.py | 465 --- .../csr_matvec_VS_cusparse_csr_matvec.py | 668 ---- .../sparse/tests/csrmv_taichi_VS_csrmv.py | 250 -- .../tests/csrmv_taichi_VS_csrmv_grad.py | 273 -- brainpy/_src/math/sparse/tests/test_csrmm.py | 293 -- brainpy/_src/math/sparse/tests/test_csrmv.py | 272 -- brainpy/math/__init__.py | 3 - 33 files changed, 1737 insertions(+), 10564 deletions(-) delete mode 100644 brainpy/_src/math/event/tests/event_csr_matmat_VS_csr_matmat.py delete mode 100644 brainpy/_src/math/event/tests/event_csrmv_taichi_VS_event_csrmv.py delete mode 100644 brainpy/_src/math/event/tests/event_csrmv_taichi_VS_event_csrmv_grad.py delete mode 100644 brainpy/_src/math/event/tests/test_event_csrmm.py delete mode 100644 brainpy/_src/math/event/tests/test_event_csrmv.py delete mode 100644 brainpy/_src/math/jitconn/tests/event_matvec_jitconn_performance.py delete mode 100644 brainpy/_src/math/jitconn/tests/jitconn_event_matvec_taichi_VS_jitconn_event_matvec.py delete mode 100644 brainpy/_src/math/jitconn/tests/jitconn_event_matvec_taichi_VS_jitconn_event_matvec_grad.py delete mode 100644 brainpy/_src/math/jitconn/tests/jitconn_matvec_taichi_VS_jitconn_matvec.py delete mode 100644 brainpy/_src/math/jitconn/tests/jitconn_matvec_taichi_VS_jitconn_matvec_grad.py delete mode 100644 brainpy/_src/math/jitconn/tests/matmat_jitconn_performance.py delete mode 100644 brainpy/_src/math/jitconn/tests/matmat_testcase.py delete mode 100644 brainpy/_src/math/jitconn/tests/matvec_jitconn_performance.py delete mode 100644 brainpy/_src/math/jitconn/tests/test_event_matvec.py delete mode 100644 brainpy/_src/math/jitconn/tests/test_event_matvec_old.py delete mode 100644 brainpy/_src/math/jitconn/tests/test_get_weight_matrix.py delete mode 100644 brainpy/_src/math/jitconn/tests/test_matvec.py delete mode 100644 brainpy/_src/math/jitconn/tests/test_matvec_old.py delete mode 100644 brainpy/_src/math/sparse/tests/csr_matmat_VS_cusparse_csr_matmat.py delete mode 100644 brainpy/_src/math/sparse/tests/csr_matvec_VS_cusparse_csr_matvec.py delete mode 100644 brainpy/_src/math/sparse/tests/csrmv_taichi_VS_csrmv.py delete mode 100644 brainpy/_src/math/sparse/tests/csrmv_taichi_VS_csrmv_grad.py delete mode 100644 brainpy/_src/math/sparse/tests/test_csrmm.py delete mode 100644 brainpy/_src/math/sparse/tests/test_csrmv.py diff --git a/brainpy/_src/dependency_check.py b/brainpy/_src/dependency_check.py index 4c0202315..60a394ce1 100644 --- a/brainpy/_src/dependency_check.py +++ b/brainpy/_src/dependency_check.py @@ -1,81 +1,46 @@ +import importlib.util import os import sys -from jax.lib import xla_client - __all__ = [ - 'import_taichi', - 'raise_taichi_not_found', - 'import_braintaichi', - 'raise_braintaichi_not_found', - 'import_numba', - 'raise_numba_not_found', - 'import_cupy', - 'import_cupy_jit', - 'raise_cupy_not_found', - 'import_brainpylib_cpu_ops', - 'import_brainpylib_gpu_ops', + 'import_taichi', + 'import_braintaichi', + 'raise_braintaichi_not_found', ] -_minimal_brainpylib_version = '0.2.6' -_minimal_taichi_version = (1, 7, 2) - -numba = None taichi = None braintaichi = None -cupy = None -cupy_jit = None -brainpylib_cpu_ops = None -brainpylib_gpu_ops = None - -taichi_install_info = (f'We need taichi>={_minimal_taichi_version}. ' - f'Currently you can install taichi=={_minimal_taichi_version} by pip . \n' - '> pip install taichi -U') -numba_install_info = ('We need numba. Please install numba by pip . \n' - '> pip install numba') -cupy_install_info = ('We need cupy. Please install cupy by pip . \n' - 'For CUDA v11.2 ~ 11.8 > pip install cupy-cuda11x\n' - 'For CUDA v12.x > pip install cupy-cuda12x\n') braintaichi_install_info = ('We need braintaichi. Please install braintaichi by pip . \n' '> pip install braintaichi -U') - os.environ["TI_LOG_LEVEL"] = "error" def import_taichi(error_if_not_found=True): - """Internal API to import taichi. + """Internal API to import taichi. - If taichi is not found, it will raise a ModuleNotFoundError if error_if_not_found is True, - otherwise it will return None. - """ - global taichi - if taichi is None: - with open(os.devnull, 'w') as devnull: - old_stdout = sys.stdout - sys.stdout = devnull - try: - import taichi as taichi # noqa - except ModuleNotFoundError: - if error_if_not_found: - raise raise_taichi_not_found() - finally: - sys.stdout = old_stdout + If taichi is not found, it will raise a ModuleNotFoundError if error_if_not_found is True, + otherwise it will return None. + """ + global taichi + if taichi is None: + if importlib.util.find_spec('taichi') is not None: + with open(os.devnull, 'w') as devnull: + old_stdout = sys.stdout + sys.stdout = devnull + try: + import taichi as taichi # noqa + except ModuleNotFoundError as e: + if error_if_not_found: + raise e + finally: + sys.stdout = old_stdout + else: + taichi = None - if taichi is None: - return None - taichi_version = taichi.__version__[0] * 10000 + taichi.__version__[1] * 100 + taichi.__version__[2] - minimal_taichi_version = _minimal_taichi_version[0] * 10000 + _minimal_taichi_version[1] * 100 + \ - _minimal_taichi_version[2] - if taichi_version >= minimal_taichi_version: return taichi - else: - raise ModuleNotFoundError(taichi_install_info) -def raise_taichi_not_found(*args, **kwargs): - raise ModuleNotFoundError(taichi_install_info) - def import_braintaichi(error_if_not_found=True): """Internal API to import braintaichi. @@ -84,133 +49,18 @@ def import_braintaichi(error_if_not_found=True): """ global braintaichi if braintaichi is None: - try: - import braintaichi as braintaichi - except ModuleNotFoundError: - if error_if_not_found: - raise_braintaichi_not_found() + if importlib.util.find_spec('braintaichi') is not None: + try: + import braintaichi as braintaichi + except ModuleNotFoundError: + if error_if_not_found: + raise_braintaichi_not_found() + else: + braintaichi = None else: - return None + braintaichi = None return braintaichi + def raise_braintaichi_not_found(): raise ModuleNotFoundError(braintaichi_install_info) - - -def import_numba(error_if_not_found=True): - """ - Internal API to import numba. - - If numba is not found, it will raise a ModuleNotFoundError if error_if_not_found is True, - otherwise it will return None. - """ - global numba - if numba is None: - try: - import numba as numba - except ModuleNotFoundError: - if error_if_not_found: - raise_numba_not_found() - else: - return None - return numba - - -def raise_numba_not_found(): - raise ModuleNotFoundError(numba_install_info) - - -def import_cupy(error_if_not_found=True): - """ - Internal API to import cupy. - - If cupy is not found, it will raise a ModuleNotFoundError if error_if_not_found is True, - otherwise it will return None. - """ - global cupy - if cupy is None: - try: - import cupy as cupy - except ModuleNotFoundError: - if error_if_not_found: - raise_cupy_not_found() - else: - return None - return cupy - - -def import_cupy_jit(error_if_not_found=True): - """ - Internal API to import cupy. - - If cupy is not found, it will raise a ModuleNotFoundError if error_if_not_found is True, - otherwise it will return None. - """ - global cupy_jit - if cupy_jit is None: - try: - from cupyx import jit as cupy_jit - except ModuleNotFoundError: - if error_if_not_found: - raise_cupy_not_found() - else: - return None - return cupy_jit - - -def raise_cupy_not_found(): - raise ModuleNotFoundError(cupy_install_info) - - -def is_brainpylib_gpu_installed(): - return False if brainpylib_gpu_ops is None else True - - -def import_brainpylib_cpu_ops(): - """ - Internal API to import brainpylib cpu_ops. - """ - global brainpylib_cpu_ops - if brainpylib_cpu_ops is None: - try: - from brainpylib import cpu_ops as brainpylib_cpu_ops - - for _name, _value in brainpylib_cpu_ops.registrations().items(): - xla_client.register_custom_call_target(_name, _value, platform="cpu") - - import brainpylib - if brainpylib.__version__ < _minimal_brainpylib_version: - raise SystemError(f'This version of brainpy needs brainpylib >= {_minimal_brainpylib_version}.') - if hasattr(brainpylib, 'check_brainpy_version'): - brainpylib.check_brainpy_version() - - except ImportError: - raise ImportError('Please install brainpylib. \n' - 'See https://brainpy.readthedocs.io for installation instructions.') - - return brainpylib_cpu_ops - - -def import_brainpylib_gpu_ops(): - """ - Internal API to import brainpylib gpu_ops. - """ - global brainpylib_gpu_ops - if brainpylib_gpu_ops is None: - try: - from brainpylib import gpu_ops as brainpylib_gpu_ops - - for _name, _value in brainpylib_gpu_ops.registrations().items(): - xla_client.register_custom_call_target(_name, _value, platform="gpu") - - import brainpylib - if brainpylib.__version__ < _minimal_brainpylib_version: - raise SystemError(f'This version of brainpy needs brainpylib >= {_minimal_brainpylib_version}.') - if hasattr(brainpylib, 'check_brainpy_version'): - brainpylib.check_brainpy_version() - - except ImportError: - raise ImportError('Please install GPU version of brainpylib. \n' - 'See https://brainpy.readthedocs.io for installation instructions.') - - return brainpylib_gpu_ops diff --git a/brainpy/_src/dnn/linear.py b/brainpy/_src/dnn/linear.py index e517e5563..06fa9413f 100644 --- a/brainpy/_src/dnn/linear.py +++ b/brainpy/_src/dnn/linear.py @@ -1,6 +1,7 @@ # -*- coding: utf-8 -*- +import importlib.util import numbers from typing import Dict, Optional, Union, Callable @@ -20,1507 +21,1510 @@ from brainpy.initialize import XavierNormal, ZeroInit, Initializer, parameter from brainpy.types import ArrayType, Sharding -bti = import_braintaichi(error_if_not_found=False) -ti = import_taichi(error_if_not_found=False) + +ti = import_taichi() +bti = import_braintaichi() __all__ = [ - 'Dense', 'Linear', - 'Identity', - 'AllToAll', - 'OneToOne', - 'MaskedLinear', - 'CSRLinear', 'EventCSRLinear', - 'JitFPHomoLinear', 'JitFPUniformLinear', 'JitFPNormalLinear', - 'EventJitFPHomoLinear', 'EventJitFPNormalLinear', 'EventJitFPUniformLinear', + 'Dense', 'Linear', + 'Identity', + 'AllToAll', + 'OneToOne', + 'MaskedLinear', + 'CSRLinear', 'EventCSRLinear', + 'JitFPHomoLinear', 'JitFPUniformLinear', 'JitFPNormalLinear', + 'EventJitFPHomoLinear', 'EventJitFPNormalLinear', 'EventJitFPUniformLinear', ] class Dense(Layer, SupportSTDP, SupportOnline, SupportOffline): - r"""A linear transformation applied over the last dimension of the input. - - Mathematically, this node can be defined as: - - .. math:: - - y = x \cdot weight + b - - Parameters - ---------- - num_in: int - The number of the input feature. A positive integer. - num_out: int - The number of the output features. A positive integer. - W_initializer: optional, Initializer - The weight initialization. - b_initializer: optional, Initializer - The bias initialization. - mode: Mode - Enable training this node or not. (default True) - """ - - def __init__( - self, - num_in: int, - num_out: int, - W_initializer: Union[Initializer, Callable, ArrayType] = XavierNormal(), - b_initializer: Optional[Union[Initializer, Callable, ArrayType]] = ZeroInit(), - mode: Optional[bm.Mode] = None, - name: Optional[str] = None, - ): - super(Dense, self).__init__(mode=mode, name=name) - - # shape - self.num_in = num_in - self.num_out = num_out - if num_in < 0: - raise ValueError(f'Received an invalid value for `num_out`, expected ' - f'a positive integer. Received: num_in={num_in}') - if num_out < 0: - raise ValueError(f'Received an invalid value for `num_out`, expected ' - f'a positive integer. Received: num_out={num_out}') - - # weight initializer - self.W_initializer = W_initializer - self.bias_initializer = b_initializer - is_initializer(W_initializer, 'weight_initializer') - is_initializer(b_initializer, 'bias_initializer', allow_none=True) - - # parameter initialization - W = parameter(self.W_initializer, (num_in, self.num_out)) - b = parameter(self.bias_initializer, (self.num_out,)) - if isinstance(self.mode, bm.TrainingMode): - W = bm.TrainVar(W) - b = None if (b is None) else bm.TrainVar(b) - self.W = W - self.b = b - - # fitting parameters - self.online_fit_by = None # support online training - self.offline_fit_by = None # support offline training - self.fit_record = dict() - - def __repr__(self): - return (f'{self.__class__.__name__}(name={self.name}, ' - f'num_in={self.num_in}, ' - f'num_out={self.num_out}, ' - f'mode={self.mode})') - - def update(self, x): - x = bm.as_jax(x) - res = x @ self.W - if self.b is not None: - res += self.b - - # online fitting data - if share.load('fit', False) and self.online_fit_by is not None: - self.fit_record['input'] = x - self.fit_record['output'] = res - - # offline fitting data - if share.load('fit', False) and self.offline_fit_by is not None: - self.fit_record['input'] = x - self.fit_record['output'] = res - return res - - def online_init(self): - if self.b is None: - num_input = self.num_in - else: - num_input = self.num_in + 1 - self.online_fit_by.register_target(feature_in=num_input, identifier=self.name) - - def online_fit(self, - target: ArrayType, - fit_record: Dict[str, ArrayType]): - if not isinstance(target, (bm.ndarray, jnp.ndarray)): - raise MathError(f'"target" must be a tensor, but got {type(target)}') - x = fit_record['input'] - y = fit_record['output'] - if x.ndim != 2: - raise ValueError(f'"ff" must be a 2D tensor with shape of (num_sample, ' - f'num_feature), but we got {x.shape}') - if target.ndim != 2: - raise ValueError(f'"target" must be a 2D tensor with shape of (num_sample, ' - f'num_feature), but we got {target.shape}') - if x.shape[0] != target.shape[0]: - raise ValueError(f'Batch size of the input and target data should be ' - f'the same, while we got {x.shape[0]} != {target.shape[0]}.') - if target.shape[1] != y.shape[1]: - raise MathError(f'The output dimension of output and target data should be ' - f'the same, while we got {target.shape[1]} != {y.shape[1]}') - - # data - if self.b is not None: - x = jnp.concatenate([jnp.ones((x.shape[0], 1)), x], axis=-1) - - # fitting - dW = self.online_fit_by.call(target=target, input=x, output=y, identifier=self.name) - - # assign trained weights - if self.b is None: - self.W += dW - else: - db, dW = jnp.split(dW, [1]) - self.b += db[0] - self.W += dW - - def offline_fit(self, - target: ArrayType, - fit_record: Dict[str, ArrayType]): - """The offline training interface for the Dense node.""" - # data checking - if not isinstance(target, (bm.ndarray, jnp.ndarray)): - raise MathError(f'"targets" must be a tensor, but got {type(target)}') - xs = fit_record['input'] - ys = fit_record['output'] - if xs.ndim != 3: - raise ValueError(f'"ffs" must be a 3D tensor with shape of (num_sample, num_time, ' - f'num_feature), but we got {xs.shape}') - if target.ndim != 3: - raise ValueError(f'"targets" must be a 3D tensor with shape of (num_sample, num_time, ' - f'num_feature), but we got {target.shape}') - if ys.shape != target.shape: - raise ValueError(f'The shapes of output and target data should be ' - f'the same, while we got {ys.shape} != {target.shape}.') - if xs.shape[0] != target.shape[0]: - raise ValueError(f'Batch size of the input and target data should be ' - f'the same, while we got {xs.shape[0]} != {target.shape[0]}.') - if xs.shape[1] != target.shape[1]: - raise MathError(f'The time dimension of input and target data should be ' - f'the same, while we got {xs.shape[1]} != {target.shape[1]}') - - # get input and target training data - if self.b is not None: - xs = jnp.concatenate([jnp.ones(xs.shape[:2] + (1,)), xs], axis=-1) # (..., 1 + num_ff_input) - - # solve weights by offline training methods - weights = self.offline_fit_by(target, xs, ys) - - # assign trained weights - if self.b is None: - self.W.value = weights - else: - bias, Wff = jnp.split(weights, [1]) - self.W.value = Wff - self.b.value = bias[0] - - def stdp_update( - self, - on_pre: Dict = None, - on_post: Dict = None, - w_min: numbers.Number = None, - w_max: numbers.Number = None - ): - if isinstance(self.W, float): - raise ValueError(f'Cannot update the weight of a constant node.') - if not isinstance(self.W, bm.Variable): - self.tracing_variable('W', self.W, self.W.shape) - if on_pre is not None: - spike = on_pre['spike'] - trace = on_pre['trace'] - self.W.value = dense_on_pre(self.W.value, spike, trace, w_min, w_max) - if on_post is not None: - spike = on_post['spike'] - trace = on_post['trace'] - self.W.value = dense_on_post(self.W.value, spike, trace, w_min, w_max) + r"""A linear transformation applied over the last dimension of the input. + + Mathematically, this node can be defined as: + + .. math:: + + y = x \cdot weight + b + + Parameters + ---------- + num_in: int + The number of the input feature. A positive integer. + num_out: int + The number of the output features. A positive integer. + W_initializer: optional, Initializer + The weight initialization. + b_initializer: optional, Initializer + The bias initialization. + mode: Mode + Enable training this node or not. (default True) + """ + + def __init__( + self, + num_in: int, + num_out: int, + W_initializer: Union[Initializer, Callable, ArrayType] = XavierNormal(), + b_initializer: Optional[Union[Initializer, Callable, ArrayType]] = ZeroInit(), + mode: Optional[bm.Mode] = None, + name: Optional[str] = None, + ): + super(Dense, self).__init__(mode=mode, name=name) + + # shape + self.num_in = num_in + self.num_out = num_out + if num_in < 0: + raise ValueError(f'Received an invalid value for `num_out`, expected ' + f'a positive integer. Received: num_in={num_in}') + if num_out < 0: + raise ValueError(f'Received an invalid value for `num_out`, expected ' + f'a positive integer. Received: num_out={num_out}') + + # weight initializer + self.W_initializer = W_initializer + self.bias_initializer = b_initializer + is_initializer(W_initializer, 'weight_initializer') + is_initializer(b_initializer, 'bias_initializer', allow_none=True) + + # parameter initialization + W = parameter(self.W_initializer, (num_in, self.num_out)) + b = parameter(self.bias_initializer, (self.num_out,)) + if isinstance(self.mode, bm.TrainingMode): + W = bm.TrainVar(W) + b = None if (b is None) else bm.TrainVar(b) + self.W = W + self.b = b + + # fitting parameters + self.online_fit_by = None # support online training + self.offline_fit_by = None # support offline training + self.fit_record = dict() + + def __repr__(self): + return (f'{self.__class__.__name__}(name={self.name}, ' + f'num_in={self.num_in}, ' + f'num_out={self.num_out}, ' + f'mode={self.mode})') + + def update(self, x): + x = bm.as_jax(x) + res = x @ self.W + if self.b is not None: + res += self.b + + # online fitting data + if share.load('fit', False) and self.online_fit_by is not None: + self.fit_record['input'] = x + self.fit_record['output'] = res + + # offline fitting data + if share.load('fit', False) and self.offline_fit_by is not None: + self.fit_record['input'] = x + self.fit_record['output'] = res + return res + + def online_init(self): + if self.b is None: + num_input = self.num_in + else: + num_input = self.num_in + 1 + self.online_fit_by.register_target(feature_in=num_input, identifier=self.name) + + def online_fit(self, + target: ArrayType, + fit_record: Dict[str, ArrayType]): + if not isinstance(target, (bm.ndarray, jnp.ndarray)): + raise MathError(f'"target" must be a tensor, but got {type(target)}') + x = fit_record['input'] + y = fit_record['output'] + if x.ndim != 2: + raise ValueError(f'"ff" must be a 2D tensor with shape of (num_sample, ' + f'num_feature), but we got {x.shape}') + if target.ndim != 2: + raise ValueError(f'"target" must be a 2D tensor with shape of (num_sample, ' + f'num_feature), but we got {target.shape}') + if x.shape[0] != target.shape[0]: + raise ValueError(f'Batch size of the input and target data should be ' + f'the same, while we got {x.shape[0]} != {target.shape[0]}.') + if target.shape[1] != y.shape[1]: + raise MathError(f'The output dimension of output and target data should be ' + f'the same, while we got {target.shape[1]} != {y.shape[1]}') + + # data + if self.b is not None: + x = jnp.concatenate([jnp.ones((x.shape[0], 1)), x], axis=-1) + + # fitting + dW = self.online_fit_by.call(target=target, input=x, output=y, identifier=self.name) + + # assign trained weights + if self.b is None: + self.W += dW + else: + db, dW = jnp.split(dW, [1]) + self.b += db[0] + self.W += dW + + def offline_fit(self, + target: ArrayType, + fit_record: Dict[str, ArrayType]): + """The offline training interface for the Dense node.""" + # data checking + if not isinstance(target, (bm.ndarray, jnp.ndarray)): + raise MathError(f'"targets" must be a tensor, but got {type(target)}') + xs = fit_record['input'] + ys = fit_record['output'] + if xs.ndim != 3: + raise ValueError(f'"ffs" must be a 3D tensor with shape of (num_sample, num_time, ' + f'num_feature), but we got {xs.shape}') + if target.ndim != 3: + raise ValueError(f'"targets" must be a 3D tensor with shape of (num_sample, num_time, ' + f'num_feature), but we got {target.shape}') + if ys.shape != target.shape: + raise ValueError(f'The shapes of output and target data should be ' + f'the same, while we got {ys.shape} != {target.shape}.') + if xs.shape[0] != target.shape[0]: + raise ValueError(f'Batch size of the input and target data should be ' + f'the same, while we got {xs.shape[0]} != {target.shape[0]}.') + if xs.shape[1] != target.shape[1]: + raise MathError(f'The time dimension of input and target data should be ' + f'the same, while we got {xs.shape[1]} != {target.shape[1]}') + + # get input and target training data + if self.b is not None: + xs = jnp.concatenate([jnp.ones(xs.shape[:2] + (1,)), xs], axis=-1) # (..., 1 + num_ff_input) + + # solve weights by offline training methods + weights = self.offline_fit_by(target, xs, ys) + + # assign trained weights + if self.b is None: + self.W.value = weights + else: + bias, Wff = jnp.split(weights, [1]) + self.W.value = Wff + self.b.value = bias[0] + + def stdp_update( + self, + on_pre: Dict = None, + on_post: Dict = None, + w_min: numbers.Number = None, + w_max: numbers.Number = None + ): + if isinstance(self.W, float): + raise ValueError(f'Cannot update the weight of a constant node.') + if not isinstance(self.W, bm.Variable): + self.tracing_variable('W', self.W, self.W.shape) + if on_pre is not None: + spike = on_pre['spike'] + trace = on_pre['trace'] + self.W.value = dense_on_pre(self.W.value, spike, trace, w_min, w_max) + if on_post is not None: + spike = on_post['spike'] + trace = on_post['trace'] + self.W.value = dense_on_post(self.W.value, spike, trace, w_min, w_max) Linear = Dense class Identity(Layer): - r"""A placeholder identity operator that is argument-insensitive. - """ - - def __init__(self, *args, **kwargs) -> None: - super(Identity, self).__init__(*args, **kwargs) - - def update(self, x): - return x - - -if ti is not None and bti is not None: - - # @numba.njit(nogil=True, fastmath=True, parallel=False) - # def _cpu_dense_on_post(weight, spike, trace, w_min, w_max, out_w): - # out_w[:] = weight - # for i in numba.prange(spike.shape[0]): - # if spike[i]: - # out_w[:, i] = np.clip(out_w[:, i] + trace, w_min, w_max) - - @ti.kernel - def _dense_on_post( - old_w: ti.types.ndarray(ndim=2), - post_spike: ti.types.ndarray(ndim=1), - pre_trace: ti.types.ndarray(ndim=1), - w_min: ti.types.ndarray(ndim=1), - w_max: ti.types.ndarray(ndim=1), - out_w: ti.types.ndarray(ndim=2) - ): - w_min0 = w_min[0] - w_max0 = w_max[0] - num_pre, num_post = out_w.shape - - for i, j in ti.ndrange(num_pre, num_post): - if post_spike[j]: - new_value = out_w[i, j] + pre_trace[i] - if new_value < w_min0: - out_w[i, j] = w_min0 - elif new_value > w_max0: - out_w[i, j] = w_max0 - else: - out_w[i, j] = new_value - else: - out_w[i, j] = old_w[i, j] - - - dense_on_post_prim = bti.XLACustomOp(cpu_kernel=_dense_on_post, gpu_kernel=_dense_on_post) - - - # @numba.njit(nogil=True, fastmath=True, parallel=False) - # def _cpu_dense_on_pre(weight, spike, trace, w_min, w_max, out_w): - # out_w[:] = weight - # for i in numba.prange(spike.shape[0]): - # if spike[i]: - # out_w[i] = np.clip(out_w[i] + trace, w_min, w_max) - - @ti.kernel - def _dense_on_pre( - old_w: ti.types.ndarray(ndim=2), - pre_spike: ti.types.ndarray(ndim=1), - post_trace: ti.types.ndarray(ndim=1), - w_min: ti.types.ndarray(ndim=1), - w_max: ti.types.ndarray(ndim=1), - out_w: ti.types.ndarray(ndim=2) - ): - w_min0 = w_min[0] - w_max0 = w_max[0] - num_pre, num_post = out_w.shape - - for i, j in ti.ndrange(num_pre, num_post): - if pre_spike[i]: - new_value = out_w[i, j] + post_trace[j] - if new_value < w_min0: - out_w[i, j] = w_min0 - elif new_value > w_max0: - out_w[i, j] = w_max0 - else: - out_w[i, j] = new_value - else: - out_w[i, j] = old_w[i, j] + r"""A placeholder identity operator that is argument-insensitive. + """ + + def __init__(self, *args, **kwargs) -> None: + super(Identity, self).__init__(*args, **kwargs) + + def update(self, x): + return x - dense_on_pre_prim = bti.XLACustomOp(cpu_kernel=_dense_on_pre, gpu_kernel=_dense_on_pre) +if ti is not None: + + # @numba.njit(nogil=True, fastmath=True, parallel=False) + # def _cpu_dense_on_post(weight, spike, trace, w_min, w_max, out_w): + # out_w[:] = weight + # for i in numba.prange(spike.shape[0]): + # if spike[i]: + # out_w[:, i] = np.clip(out_w[:, i] + trace, w_min, w_max) + + @ti.kernel + def _dense_on_post( + old_w: ti.types.ndarray(ndim=2), + post_spike: ti.types.ndarray(ndim=1), + pre_trace: ti.types.ndarray(ndim=1), + w_min: ti.types.ndarray(ndim=1), + w_max: ti.types.ndarray(ndim=1), + out_w: ti.types.ndarray(ndim=2) + ): + w_min0 = w_min[0] + w_max0 = w_max[0] + num_pre, num_post = out_w.shape + + for i, j in ti.ndrange(num_pre, num_post): + if post_spike[j]: + new_value = out_w[i, j] + pre_trace[i] + if new_value < w_min0: + out_w[i, j] = w_min0 + elif new_value > w_max0: + out_w[i, j] = w_max0 + else: + out_w[i, j] = new_value + else: + out_w[i, j] = old_w[i, j] + + + dense_on_post_prim = bti.XLACustomOp(cpu_kernel=_dense_on_post, gpu_kernel=_dense_on_post) + + + # @numba.njit(nogil=True, fastmath=True, parallel=False) + # def _cpu_dense_on_pre(weight, spike, trace, w_min, w_max, out_w): + # out_w[:] = weight + # for i in numba.prange(spike.shape[0]): + # if spike[i]: + # out_w[i] = np.clip(out_w[i] + trace, w_min, w_max) + + @ti.kernel + def _dense_on_pre( + old_w: ti.types.ndarray(ndim=2), + pre_spike: ti.types.ndarray(ndim=1), + post_trace: ti.types.ndarray(ndim=1), + w_min: ti.types.ndarray(ndim=1), + w_max: ti.types.ndarray(ndim=1), + out_w: ti.types.ndarray(ndim=2) + ): + w_min0 = w_min[0] + w_max0 = w_max[0] + num_pre, num_post = out_w.shape + + for i, j in ti.ndrange(num_pre, num_post): + if pre_spike[i]: + new_value = out_w[i, j] + post_trace[j] + if new_value < w_min0: + out_w[i, j] = w_min0 + elif new_value > w_max0: + out_w[i, j] = w_max0 + else: + out_w[i, j] = new_value + else: + out_w[i, j] = old_w[i, j] + + + dense_on_pre_prim = bti.XLACustomOp(cpu_kernel=_dense_on_pre, gpu_kernel=_dense_on_pre) else: - dense_on_pre_prim = None - dense_on_post_prim = None + dense_on_pre_prim = None + dense_on_post_prim = None def dense_on_pre(weight, spike, trace, w_min, w_max): - if dense_on_pre_prim is None: - raise PackageMissingError.by_purpose('taichi', 'custom operators') + if dense_on_pre_prim is None: + raise PackageMissingError.by_purpose('taichi', 'custom operators') - if w_min is None: - w_min = -np.inf - if w_max is None: - w_max = np.inf - w_min = jnp.atleast_1d(w_min) - w_max = jnp.atleast_1d(w_max) + if w_min is None: + w_min = -np.inf + if w_max is None: + w_max = np.inf + w_min = jnp.atleast_1d(w_min) + w_max = jnp.atleast_1d(w_max) - weight = bm.as_jax(weight) - spike = bm.as_jax(spike) - trace = bm.as_jax(trace) - w_min = bm.as_jax(w_min) - w_max = bm.as_jax(w_max) - return dense_on_pre_prim(weight, spike, trace, w_min, w_max, - outs=[jax.ShapeDtypeStruct(weight.shape, weight.dtype)])[0] + weight = bm.as_jax(weight) + spike = bm.as_jax(spike) + trace = bm.as_jax(trace) + w_min = bm.as_jax(w_min) + w_max = bm.as_jax(w_max) + return dense_on_pre_prim(weight, spike, trace, w_min, w_max, + outs=[jax.ShapeDtypeStruct(weight.shape, weight.dtype)])[0] def dense_on_post(weight, spike, trace, w_min, w_max): - if dense_on_post_prim is None: - raise PackageMissingError.by_purpose('taichi', 'custom operators') + if dense_on_post_prim is None: + raise PackageMissingError.by_purpose('taichi', 'custom operators') - if w_min is None: - w_min = -np.inf - if w_max is None: - w_max = np.inf - w_min = jnp.atleast_1d(w_min) - w_max = jnp.atleast_1d(w_max) + if w_min is None: + w_min = -np.inf + if w_max is None: + w_max = np.inf + w_min = jnp.atleast_1d(w_min) + w_max = jnp.atleast_1d(w_max) - weight = bm.as_jax(weight) - spike = bm.as_jax(spike) - trace = bm.as_jax(trace) - w_min = bm.as_jax(w_min) - w_max = bm.as_jax(w_max) - return dense_on_post_prim(weight, spike, trace, w_min, w_max, - outs=[jax.ShapeDtypeStruct(weight.shape, weight.dtype)])[0] + weight = bm.as_jax(weight) + spike = bm.as_jax(spike) + trace = bm.as_jax(trace) + w_min = bm.as_jax(w_min) + w_max = bm.as_jax(w_max) + return dense_on_post_prim(weight, spike, trace, w_min, w_max, + outs=[jax.ShapeDtypeStruct(weight.shape, weight.dtype)])[0] class AllToAll(Layer, SupportSTDP): - """Synaptic matrix multiplication with All2All connections. - - Args: - num_pre: int. The number of neurons in the presynaptic neuron group. - num_post: int. The number of neurons in the postsynaptic neuron group. - weight: The synaptic weights. - sharding: The sharding strategy. - include_self: bool. Whether connect the neuron with at the same position. - mode: Mode. The computing mode. - name: str. The object name. - """ - - def __init__( - self, - num_pre: int, - num_post: int, - weight: Union[float, ArrayType, Callable], - sharding: Optional[Sharding] = None, - include_self: bool = True, - mode: Optional[bm.Mode] = None, - name: Optional[str] = None, - ): - super().__init__(mode=mode, name=name) - - self.num_pre = num_pre - self.num_post = num_post - self.include_self = include_self - self.sharding = sharding - - weight = init.parameter(weight, (self.num_pre, self.num_post), sharding=sharding) - if isinstance(self.mode, bm.TrainingMode): - weight = bm.TrainVar(weight) - self.weight = weight - - def update(self, pre_val): - if bm.ndim(self.weight) == 0: # weight is a scalar - if isinstance(self.mode, bm.BatchingMode): - assert pre_val.ndim == 2, 'Under the batching mode, the input should be a 2D array.' - post_val = bm.sum(pre_val, keepdims=True, axis=1) - else: - assert pre_val.ndim == 1, 'Under the NonBatching mode, the input should be a 1D array.' - post_val = bm.sum(pre_val) - if not self.include_self: - if self.num_pre == self.num_post: - post_val = post_val - pre_val - elif self.num_pre > self.num_post: - val = pre_val[:self.num_post] - post_val = post_val - val - else: - val = bm.concatenate([pre_val, bm.zeros(self.num_post - self.num_pre)]) - post_val = post_val - val - post_val = self.weight * post_val - - else: # weight is a matrix - assert self.weight.ndim == 2, '"weight" must be a 2D matrix.' - if not self.include_self: - post_val = pre_val @ bm.fill_diagonal(self.weight, 0., inplace=False) - else: - post_val = pre_val @ self.weight - return post_val - - def stdp_update( - self, - on_pre: Dict = None, - on_post: Dict = None, - w_min: numbers.Number = None, - w_max: numbers.Number = None - ): - if isinstance(self.weight, float): - raise ValueError(f'Cannot update the weight of a constant node.') - if not isinstance(self.weight, bm.Variable): - self.tracing_variable('weight', self.weight, self.weight.shape) - if on_pre is not None: - spike = on_pre['spike'] - trace = on_pre['trace'] - self.weight.value = dense_on_pre(self.weight.value, spike, trace, w_min, w_max) - if on_post is not None: - spike = on_post['spike'] - trace = on_post['trace'] - self.weight.value = dense_on_post(self.weight.value, spike, trace, w_min, w_max) + """Synaptic matrix multiplication with All2All connections. + + Args: + num_pre: int. The number of neurons in the presynaptic neuron group. + num_post: int. The number of neurons in the postsynaptic neuron group. + weight: The synaptic weights. + sharding: The sharding strategy. + include_self: bool. Whether connect the neuron with at the same position. + mode: Mode. The computing mode. + name: str. The object name. + """ + + def __init__( + self, + num_pre: int, + num_post: int, + weight: Union[float, ArrayType, Callable], + sharding: Optional[Sharding] = None, + include_self: bool = True, + mode: Optional[bm.Mode] = None, + name: Optional[str] = None, + ): + super().__init__(mode=mode, name=name) + + self.num_pre = num_pre + self.num_post = num_post + self.include_self = include_self + self.sharding = sharding + + weight = init.parameter(weight, (self.num_pre, self.num_post), sharding=sharding) + if isinstance(self.mode, bm.TrainingMode): + weight = bm.TrainVar(weight) + self.weight = weight + + def update(self, pre_val): + if bm.ndim(self.weight) == 0: # weight is a scalar + if isinstance(self.mode, bm.BatchingMode): + assert pre_val.ndim == 2, 'Under the batching mode, the input should be a 2D array.' + post_val = bm.sum(pre_val, keepdims=True, axis=1) + else: + assert pre_val.ndim == 1, 'Under the NonBatching mode, the input should be a 1D array.' + post_val = bm.sum(pre_val) + if not self.include_self: + if self.num_pre == self.num_post: + post_val = post_val - pre_val + elif self.num_pre > self.num_post: + val = pre_val[:self.num_post] + post_val = post_val - val + else: + val = bm.concatenate([pre_val, bm.zeros(self.num_post - self.num_pre)]) + post_val = post_val - val + post_val = self.weight * post_val + + else: # weight is a matrix + assert self.weight.ndim == 2, '"weight" must be a 2D matrix.' + if not self.include_self: + post_val = pre_val @ bm.fill_diagonal(self.weight, 0., inplace=False) + else: + post_val = pre_val @ self.weight + return post_val + + def stdp_update( + self, + on_pre: Dict = None, + on_post: Dict = None, + w_min: numbers.Number = None, + w_max: numbers.Number = None + ): + if isinstance(self.weight, float): + raise ValueError(f'Cannot update the weight of a constant node.') + if not isinstance(self.weight, bm.Variable): + self.tracing_variable('weight', self.weight, self.weight.shape) + if on_pre is not None: + spike = on_pre['spike'] + trace = on_pre['trace'] + self.weight.value = dense_on_pre(self.weight.value, spike, trace, w_min, w_max) + if on_post is not None: + spike = on_post['spike'] + trace = on_post['trace'] + self.weight.value = dense_on_post(self.weight.value, spike, trace, w_min, w_max) class OneToOne(Layer, SupportSTDP): - """Synaptic matrix multiplication with One2One connection. - - Args: - num: int. The number of neurons. - weight: The synaptic weight. - sharding: The sharding strategy. - mode: The computing mode. - name: The object name. - - """ - - def __init__( - self, - num: int, - weight: Union[float, ArrayType, Callable], - sharding: Optional[Sharding] = None, - mode: Optional[bm.Mode] = None, - name: Optional[str] = None, - ): - super().__init__(mode=mode, name=name) - - self.num = num - self.sharding = sharding - - weight = init.parameter(weight, (self.num,), sharding=sharding) - if isinstance(self.mode, bm.TrainingMode): - weight = bm.TrainVar(weight) - self.weight = weight - - def update(self, pre_val): - return pre_val * self.weight - - def stdp_update( - self, - on_pre: Dict = None, - on_post: Dict = None, - w_min: numbers.Number = None, - w_max: numbers.Number = None - ): - if isinstance(self.weight, float): - raise ValueError(f'Cannot update the weight of a constant node.') - if not isinstance(self.weight, bm.Variable): - self.tracing_variable('weight', self.weight, self.weight.shape) - if on_pre is not None: - spike = on_pre['spike'] - trace = on_pre['trace'] - self.weight.value += spike * trace - if on_post is not None: - spike = on_post['spike'] - trace = on_post['trace'] - self.weight.value += spike * trace + """Synaptic matrix multiplication with One2One connection. + + Args: + num: int. The number of neurons. + weight: The synaptic weight. + sharding: The sharding strategy. + mode: The computing mode. + name: The object name. + + """ + + def __init__( + self, + num: int, + weight: Union[float, ArrayType, Callable], + sharding: Optional[Sharding] = None, + mode: Optional[bm.Mode] = None, + name: Optional[str] = None, + ): + super().__init__(mode=mode, name=name) + + self.num = num + self.sharding = sharding + + weight = init.parameter(weight, (self.num,), sharding=sharding) + if isinstance(self.mode, bm.TrainingMode): + weight = bm.TrainVar(weight) + self.weight = weight + + def update(self, pre_val): + return pre_val * self.weight + + def stdp_update( + self, + on_pre: Dict = None, + on_post: Dict = None, + w_min: numbers.Number = None, + w_max: numbers.Number = None + ): + if isinstance(self.weight, float): + raise ValueError(f'Cannot update the weight of a constant node.') + if not isinstance(self.weight, bm.Variable): + self.tracing_variable('weight', self.weight, self.weight.shape) + if on_pre is not None: + spike = on_pre['spike'] + trace = on_pre['trace'] + self.weight.value += spike * trace + if on_post is not None: + spike = on_post['spike'] + trace = on_post['trace'] + self.weight.value += spike * trace class MaskedLinear(Layer, SupportSTDP): - r"""Synaptic matrix multiplication with masked dense computation. - - It performs the computation of: - - .. math:: - - y = x @ M - - where :math:`y` is the postsynaptic value, :math:`x` the presynaptic value, - :math:`M` the synaptic weight using a dense matrix. - - >>> import brainpy as bp - >>> l = bp.dnn.MaskedLinear(bp.conn.FixedProb(0.1, pre=100, post=100), - >>> weight=0.1) - - Args: - conn: TwoEndConnector. The connection. - weight: Synaptic weights. Can be a scalar, array, or callable function. - mask_fun: Masking function. - sharding: The sharding strategy. - mode: The synaptic computing mode. - name: The synapse model name. - """ - - def __init__( - self, - conn: connect.TwoEndConnector, - weight: Union[float, ArrayType, Callable], - mask_fun: Callable = Identity(), - sharding: Optional[Sharding] = None, - mode: Optional[bm.Mode] = None, - name: Optional[str] = None, - ): - super().__init__(name=name, mode=mode) - - assert isinstance(conn, connect.TwoEndConnector) - self.conn = conn - self.sharding = sharding - self.mask_fun = mask_fun - - # weight - weight = init.parameter(weight, (conn.pre_num, conn.post_num), sharding=sharding) - if isinstance(self.mode, bm.TrainingMode): - weight = bm.TrainVar(weight) - self.weight = weight - - # connection - self.mask = bm.sharding.partition(self.conn.require('conn_mat'), sharding=sharding) - - def update(self, x): - return x @ self.mask_fun(self.weight * self.mask) - - def stdp_update( - self, - on_pre: Dict = None, - on_post: Dict = None, - w_min: numbers.Number = None, - w_max: numbers.Number = None - ): - if isinstance(self.weight, float): - raise ValueError(f'Cannot update the weight of a constant node.') - if not isinstance(self.weight, bm.Variable): - self.tracing_variable('weight', self.weight, self.weight.shape) - if on_pre is not None: - spike = on_pre['spike'] - trace = on_pre['trace'] - self.weight.value = dense_on_pre(self.weight.value, spike, trace, w_min, w_max) - if on_post is not None: - spike = on_post['spike'] - trace = on_post['trace'] - self.weight.value = dense_on_post(self.weight.value, spike, trace, w_min, w_max) + r"""Synaptic matrix multiplication with masked dense computation. + + It performs the computation of: + + .. math:: + + y = x @ M + + where :math:`y` is the postsynaptic value, :math:`x` the presynaptic value, + :math:`M` the synaptic weight using a dense matrix. + + >>> import brainpy as bp + >>> l = bp.dnn.MaskedLinear(bp.conn.FixedProb(0.1, pre=100, post=100), + >>> weight=0.1) + + Args: + conn: TwoEndConnector. The connection. + weight: Synaptic weights. Can be a scalar, array, or callable function. + mask_fun: Masking function. + sharding: The sharding strategy. + mode: The synaptic computing mode. + name: The synapse model name. + """ + + def __init__( + self, + conn: connect.TwoEndConnector, + weight: Union[float, ArrayType, Callable], + mask_fun: Callable = Identity(), + sharding: Optional[Sharding] = None, + mode: Optional[bm.Mode] = None, + name: Optional[str] = None, + ): + super().__init__(name=name, mode=mode) + + assert isinstance(conn, connect.TwoEndConnector) + self.conn = conn + self.sharding = sharding + self.mask_fun = mask_fun + + # weight + weight = init.parameter(weight, (conn.pre_num, conn.post_num), sharding=sharding) + if isinstance(self.mode, bm.TrainingMode): + weight = bm.TrainVar(weight) + self.weight = weight + + # connection + self.mask = bm.sharding.partition(self.conn.require('conn_mat'), sharding=sharding) + + def update(self, x): + return x @ self.mask_fun(self.weight * self.mask) + + def stdp_update( + self, + on_pre: Dict = None, + on_post: Dict = None, + w_min: numbers.Number = None, + w_max: numbers.Number = None + ): + if isinstance(self.weight, float): + raise ValueError(f'Cannot update the weight of a constant node.') + if not isinstance(self.weight, bm.Variable): + self.tracing_variable('weight', self.weight, self.weight.shape) + if on_pre is not None: + spike = on_pre['spike'] + trace = on_pre['trace'] + self.weight.value = dense_on_pre(self.weight.value, spike, trace, w_min, w_max) + if on_post is not None: + spike = on_post['spike'] + trace = on_post['trace'] + self.weight.value = dense_on_post(self.weight.value, spike, trace, w_min, w_max) class _CSRLayer(Layer, SupportSTDP): - def __init__( - self, - conn: connect.TwoEndConnector, - weight: Union[float, ArrayType, Callable], - sharding: Optional[Sharding] = None, - mode: Optional[bm.Mode] = None, - name: Optional[str] = None, - transpose: bool = True, - ): - super().__init__(name=name, mode=mode) - - assert isinstance(conn, connect.TwoEndConnector) - assert sharding is None, 'Currently this model does not support sharding.' - self.conn = conn - self.sharding = sharding - self.transpose = transpose - - # connection - self.indices, self.indptr = self.conn.require('csr') - - # weight - weight = init.parameter(weight, (self.indices.size,)) - if isinstance(self.mode, bm.TrainingMode): - weight = bm.TrainVar(weight) - self.weight = weight - - def stdp_update( - self, - on_pre: Dict = None, - on_post: Dict = None, - w_min: numbers.Number = None, - w_max: numbers.Number = None - ): - if bm.isscalar(self.weight): - raise ValueError(f'When using STDP to update synaptic weights, the weight cannot be a scalar.') - if self.weight.shape != self.indices.shape: - raise ValueError(f'The shape of weight should be the same as the shape of sparse weight {self.weight.shape}.') - if not isinstance(self.weight, bm.Variable): - self.tracing_variable('weight', self.weight, self.weight.shape) - if on_pre is not None: # update on presynaptic spike - spike = on_pre['spike'] - trace = on_pre['trace'] - self.weight.value = csr_on_pre_update(self.weight.value, self.indices, self.indptr, spike, trace, w_min, w_max) - if on_post is not None: # update on postsynaptic spike - if not hasattr(self, '_pre_ids'): - with jax.ensure_compile_time_eval(): - self._pre_ids, self._post_indptr, self.w_indices = csr2csc( - [self.indices, self.indptr], self.conn.post_num, data=np.arange(self.weight.size) - ) - spike = on_post['spike'] - trace = on_post['trace'] - self.weight.value = csc_on_post_update(self.weight.value, self._pre_ids, self._post_indptr, - self.w_indices, spike, trace, w_min, w_max) + def __init__( + self, + conn: connect.TwoEndConnector, + weight: Union[float, ArrayType, Callable], + sharding: Optional[Sharding] = None, + mode: Optional[bm.Mode] = None, + name: Optional[str] = None, + transpose: bool = True, + ): + super().__init__(name=name, mode=mode) + + assert isinstance(conn, connect.TwoEndConnector) + assert sharding is None, 'Currently this model does not support sharding.' + self.conn = conn + self.sharding = sharding + self.transpose = transpose + + # connection + self.indices, self.indptr = self.conn.require('csr') + + # weight + weight = init.parameter(weight, (self.indices.size,)) + if isinstance(self.mode, bm.TrainingMode): + weight = bm.TrainVar(weight) + self.weight = weight + + def stdp_update( + self, + on_pre: Dict = None, + on_post: Dict = None, + w_min: numbers.Number = None, + w_max: numbers.Number = None + ): + if bm.isscalar(self.weight): + raise ValueError(f'When using STDP to update synaptic weights, the weight cannot be a scalar.') + if self.weight.shape != self.indices.shape: + raise ValueError( + f'The shape of weight should be the same as the shape of sparse weight {self.weight.shape}.') + if not isinstance(self.weight, bm.Variable): + self.tracing_variable('weight', self.weight, self.weight.shape) + if on_pre is not None: # update on presynaptic spike + spike = on_pre['spike'] + trace = on_pre['trace'] + self.weight.value = csr_on_pre_update(self.weight.value, self.indices, self.indptr, spike, trace, w_min, + w_max) + if on_post is not None: # update on postsynaptic spike + if not hasattr(self, '_pre_ids'): + with jax.ensure_compile_time_eval(): + self._pre_ids, self._post_indptr, self.w_indices = csr2csc( + [self.indices, self.indptr], self.conn.post_num, data=np.arange(self.weight.size) + ) + spike = on_post['spike'] + trace = on_post['trace'] + self.weight.value = csc_on_post_update(self.weight.value, self._pre_ids, self._post_indptr, + self.w_indices, spike, trace, w_min, w_max) class CSRLinear(_CSRLayer): - r"""Synaptic matrix multiplication with CSR sparse computation. - - It performs the computation of: - - .. math:: - - y = x @ M - - where :math:`y` is the postsynaptic value, :math:`x` the presynaptic value, - :math:`M` the synaptic weight using a CSR sparse matrix. - - Args: - conn: TwoEndConnector. The connection. - weight: Synaptic weights. Can be a scalar, array, or callable function. - sharding: The sharding strategy. - mode: The synaptic computing mode. - name: The synapse model name. - """ - - def __init__( - self, - conn: connect.TwoEndConnector, - weight: Union[float, ArrayType, Callable], - sharding: Optional[Sharding] = None, - mode: Optional[bm.Mode] = None, - name: Optional[str] = None, - method: str = None, - transpose: bool = True, - ): - super().__init__(name=name, mode=mode, conn=conn, weight=weight, sharding=sharding, transpose=transpose) - self.method = method - - def update(self, x): - if x.ndim == 1: - return bm.sparse.csrmv(self.weight, self.indices, self.indptr, x, - shape=(self.conn.pre_num, self.conn.post_num), transpose=self.transpose) - elif x.ndim > 1: - shapes = x.shape[:-1] - x = bm.flatten(x, end_dim=-2) - y = jax.vmap(self._batch_csrmv)(x) - return bm.reshape(y, shapes + (y.shape[-1],)) - else: - raise ValueError - - def _batch_csrmv(self, x): - return bm.sparse.csrmv(self.weight, self.indices, self.indptr, x, - shape=(self.conn.pre_num, self.conn.post_num), transpose=self.transpose) + r"""Synaptic matrix multiplication with CSR sparse computation. + + It performs the computation of: + + .. math:: + + y = x @ M + + where :math:`y` is the postsynaptic value, :math:`x` the presynaptic value, + :math:`M` the synaptic weight using a CSR sparse matrix. + + Args: + conn: TwoEndConnector. The connection. + weight: Synaptic weights. Can be a scalar, array, or callable function. + sharding: The sharding strategy. + mode: The synaptic computing mode. + name: The synapse model name. + """ + + def __init__( + self, + conn: connect.TwoEndConnector, + weight: Union[float, ArrayType, Callable], + sharding: Optional[Sharding] = None, + mode: Optional[bm.Mode] = None, + name: Optional[str] = None, + method: str = None, + transpose: bool = True, + ): + super().__init__(name=name, mode=mode, conn=conn, weight=weight, sharding=sharding, transpose=transpose) + self.method = method + + def update(self, x): + if x.ndim == 1: + return bm.sparse.csrmv(self.weight, self.indices, self.indptr, x, + shape=(self.conn.pre_num, self.conn.post_num), transpose=self.transpose) + elif x.ndim > 1: + shapes = x.shape[:-1] + x = bm.flatten(x, end_dim=-2) + y = jax.vmap(self._batch_csrmv)(x) + return bm.reshape(y, shapes + (y.shape[-1],)) + else: + raise ValueError + + def _batch_csrmv(self, x): + return bm.sparse.csrmv(self.weight, self.indices, self.indptr, x, + shape=(self.conn.pre_num, self.conn.post_num), transpose=self.transpose) class EventCSRLinear(_CSRLayer): - r"""Synaptic matrix multiplication with event CSR sparse computation. - - It performs the computation of: - - .. math:: - - y = x @ M - - where :math:`y` is the postsynaptic value, :math:`x` the presynaptic spikes, - :math:`M` the synaptic weight using a CSR sparse matrix. - - Args: - conn: TwoEndConnector. The connection. - weight: Synaptic weights. Can be a scalar, array, or callable function. - sharding: The sharding strategy. - mode: The synaptic computing mode. - name: The synapse model name. - """ - - def __init__( - self, - conn: connect.TwoEndConnector, - weight: Union[float, ArrayType, Callable], - sharding: Optional[Sharding] = None, - mode: Optional[bm.Mode] = None, - name: Optional[str] = None, - transpose: bool = True, - ): - super().__init__(name=name, mode=mode, conn=conn, weight=weight, sharding=sharding, transpose=transpose) - - def update(self, x): - if x.ndim == 1: - return bm.event.csrmv(self.weight, self.indices, self.indptr, x, - shape=(self.conn.pre_num, self.conn.post_num), - transpose=self.transpose) - elif x.ndim > 1: - shapes = x.shape[:-1] - x = bm.flatten(x, end_dim=-2) - y = jax.vmap(self._batch_csrmv)(x) - return bm.reshape(y, shapes + (y.shape[-1],)) - else: - raise ValueError - - def _batch_csrmv(self, x): - return bm.event.csrmv(self.weight, self.indices, self.indptr, x, - shape=(self.conn.pre_num, self.conn.post_num), - transpose=self.transpose) + r"""Synaptic matrix multiplication with event CSR sparse computation. + + It performs the computation of: + + .. math:: + + y = x @ M + + where :math:`y` is the postsynaptic value, :math:`x` the presynaptic spikes, + :math:`M` the synaptic weight using a CSR sparse matrix. + + Args: + conn: TwoEndConnector. The connection. + weight: Synaptic weights. Can be a scalar, array, or callable function. + sharding: The sharding strategy. + mode: The synaptic computing mode. + name: The synapse model name. + """ + + def __init__( + self, + conn: connect.TwoEndConnector, + weight: Union[float, ArrayType, Callable], + sharding: Optional[Sharding] = None, + mode: Optional[bm.Mode] = None, + name: Optional[str] = None, + transpose: bool = True, + ): + super().__init__(name=name, mode=mode, conn=conn, weight=weight, sharding=sharding, transpose=transpose) + + def update(self, x): + if x.ndim == 1: + return bm.event.csrmv(self.weight, self.indices, self.indptr, x, + shape=(self.conn.pre_num, self.conn.post_num), + transpose=self.transpose) + elif x.ndim > 1: + shapes = x.shape[:-1] + x = bm.flatten(x, end_dim=-2) + y = jax.vmap(self._batch_csrmv)(x) + return bm.reshape(y, shapes + (y.shape[-1],)) + else: + raise ValueError + + def _batch_csrmv(self, x): + return bm.event.csrmv(self.weight, self.indices, self.indptr, x, + shape=(self.conn.pre_num, self.conn.post_num), + transpose=self.transpose) if ti is not None: - @ti.kernel - def _csr_on_pre_update( - old_w: ti.types.ndarray(ndim=1), # vector with shape of (num_syn) - indices: ti.types.ndarray(ndim=1), # vector with shape of (num_syn) - indptr: ti.types.ndarray(ndim=1), # vector with shape of (num_pre + 1) - spike: ti.types.ndarray(ndim=1), # vector with shape of (num_pre,) - trace: ti.types.ndarray(ndim=1), # vector with shape of (num_post,) - w_min: ti.types.ndarray(ndim=1), # scalar - w_max: ti.types.ndarray(ndim=1), # scalar - out_w: ti.types.ndarray(ndim=1) # vector with shape of (num_syn) - ): - w_min0 = w_min[0] - w_max0 = w_max[0] - num_pre = spike.shape[0] - for i_pre in range(num_pre): - if spike[i_pre]: - for i_syn in range(indptr[i_pre], indptr[i_pre + 1]): - out_w[i_syn] = min(max(old_w[i_syn] + trace[indices[i_syn]], w_min0), w_max0) - else: - for i_syn in range(indptr[i_pre], indptr[i_pre + 1]): - out_w[i_syn] = old_w[i_syn] - - - csr_on_pre_update_prim = bti.XLACustomOp(cpu_kernel=_csr_on_pre_update, gpu_kernel=_csr_on_pre_update) - - - @ti.kernel - def _coo_on_pre_update( - old_w: ti.types.ndarray(ndim=1), # vector with shape of (num_syn) - pre_ids: ti.types.ndarray(ndim=1), # vector with shape of (num_syn) - post_ids: ti.types.ndarray(ndim=1), # vector with shape of (num_syn) - pre_spike: ti.types.ndarray(ndim=1), # vector with shape of (num_pre,) - post_trace: ti.types.ndarray(ndim=1), # vector with shape of (num_post,) - w_min: ti.types.ndarray(ndim=1), # scalar - w_max: ti.types.ndarray(ndim=1), # scalar - out_w: ti.types.ndarray(ndim=1) # vector with shape of (num_syn) - ): - w_min0 = w_min[0] - w_max0 = w_max[0] - num_syn = old_w.shape[0] - for i_syn in range(num_syn): - if pre_spike[pre_ids[i_syn]]: # pre spike - out_w[i_syn] = min(max(old_w[i_syn] + post_trace[post_ids[i_syn]], w_min0), w_max0) - else: - out_w[i_syn] = old_w[i_syn] - - - coo_on_pre_update_prim = bti.XLACustomOp(cpu_kernel=_coo_on_pre_update, gpu_kernel=_coo_on_pre_update) - - - @ti.kernel - def _coo_on_post_update( - old_w: ti.types.ndarray(ndim=1), # vector with shape of (num_syn) - pre_ids: ti.types.ndarray(ndim=1), # vector with shape of (num_syn) - post_ids: ti.types.ndarray(ndim=1), # vector with shape of (num_syn) - post_spike: ti.types.ndarray(ndim=1), # vector with shape of (num_pre,) - pre_trace: ti.types.ndarray(ndim=1), # vector with shape of (num_post,) - w_min: ti.types.ndarray(ndim=1), # scalar - w_max: ti.types.ndarray(ndim=1), # scalar - out_w: ti.types.ndarray(ndim=1) # vector with shape of (num_syn) - ): - w_min0 = w_min[0] - w_max0 = w_max[0] - num_syn = old_w.shape[0] - for i_syn in range(num_syn): - if post_spike[post_ids[i_syn]]: # pre spike - out_w[i_syn] = min(max(old_w[i_syn] + pre_trace[pre_ids[i_syn]], w_min0), w_max0) - else: - out_w[i_syn] = old_w[i_syn] - - - coo_on_post_update_prim = bti.XLACustomOp(cpu_kernel=_coo_on_post_update, gpu_kernel=_coo_on_post_update) - - - # @numba.njit(nogil=True, fastmath=True, parallel=False) - # def _cpu_csc_on_pre_update(w, post_ids, indptr, w_ids, spike, trace, w_min, w_max, out_w): - # out_w[:] = w - # w_min = w_min[()] - # w_max = w_max[()] - # for i in numba.prange(spike.shape[0]): # post id - # if spike[i]: - # for k in range(indptr[i], indptr[i + 1]): - # j = post_ids[k] # pre id - # l = w_ids[k] # syn id - # out_w[l] = np.minimum(np.maximum(out_w[l] + trace[j], w_min), w_max) - - @ti.kernel - def _csc_on_post_update( - old_w: ti.types.ndarray(ndim=1), # vector with shape of (num_syn) - indices: ti.types.ndarray(ndim=1), # vector with shape of (num_syn) - indptr: ti.types.ndarray(ndim=1), # vector with shape of (num_post + 1) - w_ids: ti.types.ndarray(ndim=1), # vector with shape of (num_syn) - post_spike: ti.types.ndarray(ndim=1), # vector with shape of (num_post,) - pre_trace: ti.types.ndarray(ndim=1), # vector with shape of (num_pre,) - w_min: ti.types.ndarray(ndim=1), # scalar - w_max: ti.types.ndarray(ndim=1), # scalar - out_w: ti.types.ndarray(ndim=1), # vector with shape of (num_syn) - ): - w_min0 = w_min[0] - w_max0 = w_max[0] - num_post = post_spike.shape[0] - for i_post in range(num_post): - if post_spike[i_post]: - for k in range(indptr[i_post], indptr[i_post + 1]): - i_syn = w_ids[k] # syn id - out_w[i_syn] = min(max(old_w[i_syn] + pre_trace[indices[k]], w_min0), w_max0) - else: - for k in range(indptr[i_post], indptr[i_post + 1]): - i_syn = w_ids[k] # syn id - out_w[i_syn] = old_w[i_syn] - - - csc_on_post_update_prim = bti.XLACustomOp(cpu_kernel=_csc_on_post_update, gpu_kernel=_csc_on_post_update) + @ti.kernel + def _csr_on_pre_update( + old_w: ti.types.ndarray(ndim=1), # vector with shape of (num_syn) + indices: ti.types.ndarray(ndim=1), # vector with shape of (num_syn) + indptr: ti.types.ndarray(ndim=1), # vector with shape of (num_pre + 1) + spike: ti.types.ndarray(ndim=1), # vector with shape of (num_pre,) + trace: ti.types.ndarray(ndim=1), # vector with shape of (num_post,) + w_min: ti.types.ndarray(ndim=1), # scalar + w_max: ti.types.ndarray(ndim=1), # scalar + out_w: ti.types.ndarray(ndim=1) # vector with shape of (num_syn) + ): + w_min0 = w_min[0] + w_max0 = w_max[0] + num_pre = spike.shape[0] + for i_pre in range(num_pre): + if spike[i_pre]: + for i_syn in range(indptr[i_pre], indptr[i_pre + 1]): + out_w[i_syn] = min(max(old_w[i_syn] + trace[indices[i_syn]], w_min0), w_max0) + else: + for i_syn in range(indptr[i_pre], indptr[i_pre + 1]): + out_w[i_syn] = old_w[i_syn] + + + csr_on_pre_update_prim = bti.XLACustomOp(cpu_kernel=_csr_on_pre_update, gpu_kernel=_csr_on_pre_update) + + + @ti.kernel + def _coo_on_pre_update( + old_w: ti.types.ndarray(ndim=1), # vector with shape of (num_syn) + pre_ids: ti.types.ndarray(ndim=1), # vector with shape of (num_syn) + post_ids: ti.types.ndarray(ndim=1), # vector with shape of (num_syn) + pre_spike: ti.types.ndarray(ndim=1), # vector with shape of (num_pre,) + post_trace: ti.types.ndarray(ndim=1), # vector with shape of (num_post,) + w_min: ti.types.ndarray(ndim=1), # scalar + w_max: ti.types.ndarray(ndim=1), # scalar + out_w: ti.types.ndarray(ndim=1) # vector with shape of (num_syn) + ): + w_min0 = w_min[0] + w_max0 = w_max[0] + num_syn = old_w.shape[0] + for i_syn in range(num_syn): + if pre_spike[pre_ids[i_syn]]: # pre spike + out_w[i_syn] = min(max(old_w[i_syn] + post_trace[post_ids[i_syn]], w_min0), w_max0) + else: + out_w[i_syn] = old_w[i_syn] + + + coo_on_pre_update_prim = bti.XLACustomOp(cpu_kernel=_coo_on_pre_update, gpu_kernel=_coo_on_pre_update) + + + @ti.kernel + def _coo_on_post_update( + old_w: ti.types.ndarray(ndim=1), # vector with shape of (num_syn) + pre_ids: ti.types.ndarray(ndim=1), # vector with shape of (num_syn) + post_ids: ti.types.ndarray(ndim=1), # vector with shape of (num_syn) + post_spike: ti.types.ndarray(ndim=1), # vector with shape of (num_pre,) + pre_trace: ti.types.ndarray(ndim=1), # vector with shape of (num_post,) + w_min: ti.types.ndarray(ndim=1), # scalar + w_max: ti.types.ndarray(ndim=1), # scalar + out_w: ti.types.ndarray(ndim=1) # vector with shape of (num_syn) + ): + w_min0 = w_min[0] + w_max0 = w_max[0] + num_syn = old_w.shape[0] + for i_syn in range(num_syn): + if post_spike[post_ids[i_syn]]: # pre spike + out_w[i_syn] = min(max(old_w[i_syn] + pre_trace[pre_ids[i_syn]], w_min0), w_max0) + else: + out_w[i_syn] = old_w[i_syn] + + + coo_on_post_update_prim = bti.XLACustomOp(cpu_kernel=_coo_on_post_update, gpu_kernel=_coo_on_post_update) + + + # @numba.njit(nogil=True, fastmath=True, parallel=False) + # def _cpu_csc_on_pre_update(w, post_ids, indptr, w_ids, spike, trace, w_min, w_max, out_w): + # out_w[:] = w + # w_min = w_min[()] + # w_max = w_max[()] + # for i in numba.prange(spike.shape[0]): # post id + # if spike[i]: + # for k in range(indptr[i], indptr[i + 1]): + # j = post_ids[k] # pre id + # l = w_ids[k] # syn id + # out_w[l] = np.minimum(np.maximum(out_w[l] + trace[j], w_min), w_max) + + @ti.kernel + def _csc_on_post_update( + old_w: ti.types.ndarray(ndim=1), # vector with shape of (num_syn) + indices: ti.types.ndarray(ndim=1), # vector with shape of (num_syn) + indptr: ti.types.ndarray(ndim=1), # vector with shape of (num_post + 1) + w_ids: ti.types.ndarray(ndim=1), # vector with shape of (num_syn) + post_spike: ti.types.ndarray(ndim=1), # vector with shape of (num_post,) + pre_trace: ti.types.ndarray(ndim=1), # vector with shape of (num_pre,) + w_min: ti.types.ndarray(ndim=1), # scalar + w_max: ti.types.ndarray(ndim=1), # scalar + out_w: ti.types.ndarray(ndim=1), # vector with shape of (num_syn) + ): + w_min0 = w_min[0] + w_max0 = w_max[0] + num_post = post_spike.shape[0] + for i_post in range(num_post): + if post_spike[i_post]: + for k in range(indptr[i_post], indptr[i_post + 1]): + i_syn = w_ids[k] # syn id + out_w[i_syn] = min(max(old_w[i_syn] + pre_trace[indices[k]], w_min0), w_max0) + else: + for k in range(indptr[i_post], indptr[i_post + 1]): + i_syn = w_ids[k] # syn id + out_w[i_syn] = old_w[i_syn] + + + csc_on_post_update_prim = bti.XLACustomOp(cpu_kernel=_csc_on_post_update, gpu_kernel=_csc_on_post_update) else: - csr_on_pre_update_prim = None - coo_on_pre_update_prim = None - csc_on_post_update_prim = None + csr_on_pre_update_prim = None + coo_on_pre_update_prim = None + csc_on_post_update_prim = None def csr_on_pre_update(w, indices, indptr, spike, trace, w_min=None, w_max=None): - if csr_on_pre_update_prim is None: - raise PackageMissingError.by_purpose('taichi', 'customized operators') - - if w_min is None: - w_min = -np.inf - if w_max is None: - w_max = np.inf - w_min = jnp.atleast_1d(w_min) - w_max = jnp.atleast_1d(w_max) - - w = bm.as_jax(w) - indices = bm.as_jax(indices) - indptr = bm.as_jax(indptr) - spike = bm.as_jax(spike) - trace = bm.as_jax(trace) - w_min = bm.as_jax(w_min) - w_max = bm.as_jax(w_max) - return csr_on_pre_update_prim(w, indices, indptr, spike, trace, w_min, w_max, - outs=[jax.ShapeDtypeStruct(w.shape, w.dtype)])[0] + if csr_on_pre_update_prim is None: + raise PackageMissingError.by_purpose('taichi', 'customized operators') + + if w_min is None: + w_min = -np.inf + if w_max is None: + w_max = np.inf + w_min = jnp.atleast_1d(w_min) + w_max = jnp.atleast_1d(w_max) + + w = bm.as_jax(w) + indices = bm.as_jax(indices) + indptr = bm.as_jax(indptr) + spike = bm.as_jax(spike) + trace = bm.as_jax(trace) + w_min = bm.as_jax(w_min) + w_max = bm.as_jax(w_max) + return csr_on_pre_update_prim(w, indices, indptr, spike, trace, w_min, w_max, + outs=[jax.ShapeDtypeStruct(w.shape, w.dtype)])[0] def coo_on_pre_update(w, pre_ids, post_ids, spike, trace, w_min=None, w_max=None): - if coo_on_pre_update_prim is None: - raise PackageMissingError.by_purpose('taichi', 'customized operators') + if coo_on_pre_update_prim is None: + raise PackageMissingError.by_purpose('taichi', 'customized operators') - if w_min is None: - w_min = -np.inf - if w_max is None: - w_max = np.inf - w_min = jnp.atleast_1d(w_min) - w_max = jnp.atleast_1d(w_max) + if w_min is None: + w_min = -np.inf + if w_max is None: + w_max = np.inf + w_min = jnp.atleast_1d(w_min) + w_max = jnp.atleast_1d(w_max) - w = bm.as_jax(w) - pre_ids = bm.as_jax(pre_ids) - post_ids = bm.as_jax(post_ids) - spike = bm.as_jax(spike) - trace = bm.as_jax(trace) - w_min = bm.as_jax(w_min) - w_max = bm.as_jax(w_max) + w = bm.as_jax(w) + pre_ids = bm.as_jax(pre_ids) + post_ids = bm.as_jax(post_ids) + spike = bm.as_jax(spike) + trace = bm.as_jax(trace) + w_min = bm.as_jax(w_min) + w_max = bm.as_jax(w_max) - return coo_on_pre_update_prim(w, pre_ids, post_ids, spike, trace, w_min, w_max, - outs=[jax.ShapeDtypeStruct(w.shape, w.dtype)])[0] + return coo_on_pre_update_prim(w, pre_ids, post_ids, spike, trace, w_min, w_max, + outs=[jax.ShapeDtypeStruct(w.shape, w.dtype)])[0] def csc_on_post_update(w, post_ids, indptr, w_ids, post_spike, pre_trace, w_min=None, w_max=None): - if csc_on_post_update_prim is None: - raise PackageMissingError.by_purpose('taichi', 'customized operators') - - if w_min is None: - w_min = -np.inf - if w_max is None: - w_max = np.inf - w_min = jnp.atleast_1d(w_min) - w_max = jnp.atleast_1d(w_max) - - w = bm.as_jax(w) - post_ids = bm.as_jax(post_ids) - indptr = bm.as_jax(indptr) - w_ids = bm.as_jax(w_ids) - post_spike = bm.as_jax(post_spike) - pre_trace = bm.as_jax(pre_trace) - w_min = bm.as_jax(w_min) - w_max = bm.as_jax(w_max) - return csc_on_post_update_prim(w, post_ids, indptr, w_ids, post_spike, pre_trace, w_min, w_max, - outs=[jax.ShapeDtypeStruct(w.shape, w.dtype)])[0] + if csc_on_post_update_prim is None: + raise PackageMissingError.by_purpose('taichi', 'customized operators') + + if w_min is None: + w_min = -np.inf + if w_max is None: + w_max = np.inf + w_min = jnp.atleast_1d(w_min) + w_max = jnp.atleast_1d(w_max) + + w = bm.as_jax(w) + post_ids = bm.as_jax(post_ids) + indptr = bm.as_jax(indptr) + w_ids = bm.as_jax(w_ids) + post_spike = bm.as_jax(post_spike) + pre_trace = bm.as_jax(pre_trace) + w_min = bm.as_jax(w_min) + w_max = bm.as_jax(w_max) + return csc_on_post_update_prim(w, post_ids, indptr, w_ids, post_spike, pre_trace, w_min, w_max, + outs=[jax.ShapeDtypeStruct(w.shape, w.dtype)])[0] class CSCLinear(Layer): - r"""Synaptic matrix multiplication with CSC sparse computation. + r"""Synaptic matrix multiplication with CSC sparse computation. - It performs the computation of: + It performs the computation of: - .. math:: + .. math:: - y = x @ M + y = x @ M - where :math:`y` is the postsynaptic value, :math:`x` the presynaptic value, - :math:`M` the synaptic weight using a CSC sparse matrix. + where :math:`y` is the postsynaptic value, :math:`x` the presynaptic value, + :math:`M` the synaptic weight using a CSC sparse matrix. - Args: - conn: TwoEndConnector. The connection. - weight: Synaptic weights. Can be a scalar, array, or callable function. - sharding: The sharding strategy. - mode: The synaptic computing mode. - name: The synapse model name. - """ + Args: + conn: TwoEndConnector. The connection. + weight: Synaptic weights. Can be a scalar, array, or callable function. + sharding: The sharding strategy. + mode: The synaptic computing mode. + name: The synapse model name. + """ - def __init__( - self, - conn: connect.TwoEndConnector, - weight: Union[float, ArrayType, Callable], - sharding: Optional[Sharding] = None, - mode: Optional[bm.Mode] = None, - name: Optional[str] = None, - ): - super().__init__(name=name, mode=mode) + def __init__( + self, + conn: connect.TwoEndConnector, + weight: Union[float, ArrayType, Callable], + sharding: Optional[Sharding] = None, + mode: Optional[bm.Mode] = None, + name: Optional[str] = None, + ): + super().__init__(name=name, mode=mode) - assert isinstance(conn, connect.TwoEndConnector) - self.conn = conn - self.sharding = sharding + assert isinstance(conn, connect.TwoEndConnector) + self.conn = conn + self.sharding = sharding class BcsrMM(Layer): - r"""Synaptic matrix multiplication with BCSR sparse computation. + r"""Synaptic matrix multiplication with BCSR sparse computation. - It performs the computation of: + It performs the computation of: - .. math:: + .. math:: - y = x @ M + y = x @ M - where :math:`y` is the postsynaptic value, :math:`x` the presynaptic value, - :math:`M` the synaptic weight using a BCSR sparse matrix. + where :math:`y` is the postsynaptic value, :math:`x` the presynaptic value, + :math:`M` the synaptic weight using a BCSR sparse matrix. - Args: - conn: TwoEndConnector. The connection. - weight: Synaptic weights. Can be a scalar, array, or callable function. - sharding: The sharding strategy. - mode: The synaptic computing mode. - name: The synapse model name. - """ + Args: + conn: TwoEndConnector. The connection. + weight: Synaptic weights. Can be a scalar, array, or callable function. + sharding: The sharding strategy. + mode: The synaptic computing mode. + name: The synapse model name. + """ - def __init__( - self, - conn: connect.TwoEndConnector, - weight: Union[float, ArrayType, Callable], - sharding: Optional[Sharding] = None, - mode: Optional[bm.Mode] = None, - name: Optional[str] = None, - ): - super().__init__(name=name, mode=mode) + def __init__( + self, + conn: connect.TwoEndConnector, + weight: Union[float, ArrayType, Callable], + sharding: Optional[Sharding] = None, + mode: Optional[bm.Mode] = None, + name: Optional[str] = None, + ): + super().__init__(name=name, mode=mode) - assert isinstance(conn, connect.TwoEndConnector) - self.conn = conn - self.sharding = sharding + assert isinstance(conn, connect.TwoEndConnector) + self.conn = conn + self.sharding = sharding class BcscMM(Layer): - r"""Synaptic matrix multiplication with BCSC sparse computation. + r"""Synaptic matrix multiplication with BCSC sparse computation. - It performs the computation of: + It performs the computation of: - .. math:: + .. math:: - y = x @ M + y = x @ M - where :math:`y` is the postsynaptic value, :math:`x` the presynaptic value, - :math:`M` the synaptic weight using a BCSC sparse matrix. + where :math:`y` is the postsynaptic value, :math:`x` the presynaptic value, + :math:`M` the synaptic weight using a BCSC sparse matrix. - Args: - conn: TwoEndConnector. The connection. - weight: Synaptic weights. Can be a scalar, array, or callable function. - sharding: The sharding strategy. - mode: The synaptic computing mode. - name: The synapse model name. - """ + Args: + conn: TwoEndConnector. The connection. + weight: Synaptic weights. Can be a scalar, array, or callable function. + sharding: The sharding strategy. + mode: The synaptic computing mode. + name: The synapse model name. + """ - def __init__( - self, - conn: connect.TwoEndConnector, - weight: Union[float, ArrayType, Callable], - sharding: Optional[Sharding] = None, - mode: Optional[bm.Mode] = None, - name: Optional[str] = None, - ): - super().__init__(name=name, mode=mode) + def __init__( + self, + conn: connect.TwoEndConnector, + weight: Union[float, ArrayType, Callable], + sharding: Optional[Sharding] = None, + mode: Optional[bm.Mode] = None, + name: Optional[str] = None, + ): + super().__init__(name=name, mode=mode) - assert isinstance(conn, connect.TwoEndConnector) - self.conn = conn - self.sharding = sharding + assert isinstance(conn, connect.TwoEndConnector) + self.conn = conn + self.sharding = sharding class JitLinear(Layer): - def get_conn_matrix(self): - pass + def get_conn_matrix(self): + pass class JitFPHomoLayer(JitLinear): - def get_conn_matrix(self): - return bm.jitconn.get_homo_weight_matrix(self.weight, self.prob, self.seed, - shape=(self.num_out, self.num_in), - transpose=self.transpose, - outdim_parallel=not self.atomic) + def get_conn_matrix(self): + return bm.jitconn.get_homo_weight_matrix(self.weight, self.prob, self.seed, + shape=(self.num_out, self.num_in), + transpose=self.transpose, + outdim_parallel=not self.atomic) class JitFPUniformLayer(JitLinear): - def get_conn_matrix(self): - return bm.jitconn.get_uniform_weight_matrix(self.w_low, self.w_high, self.prob, self.seed, - shape=(self.num_out, self.num_in), - transpose=self.transpose, - outdim_parallel=not self.atomic) + def get_conn_matrix(self): + return bm.jitconn.get_uniform_weight_matrix(self.w_low, self.w_high, self.prob, self.seed, + shape=(self.num_out, self.num_in), + transpose=self.transpose, + outdim_parallel=not self.atomic) class JitFPNormalLayer(JitLinear): - def get_conn_matrix(self): - return bm.jitconn.get_normal_weight_matrix(self.w_mu, self.w_sigma, self.prob, self.seed, - shape=(self.num_out, self.num_in), - transpose=self.transpose, - outdim_parallel=not self.atomic) + def get_conn_matrix(self): + return bm.jitconn.get_normal_weight_matrix(self.w_mu, self.w_sigma, self.prob, self.seed, + shape=(self.num_out, self.num_in), + transpose=self.transpose, + outdim_parallel=not self.atomic) class JitFPHomoLinear(JitFPHomoLayer): - r"""Synaptic matrix multiplication with the just-in-time connectivity. - - It performs the computation of: - - .. math:: - - y = x @ M - - where :math:`y` is the postsynaptic value, :math:`x` the presynaptic variable, - :math:`M` the synaptic weights which has the fixed sparse connectivity and weights. - Particularly, the connectivity in :math:`M` is sampled from a fixed probability :math:`prob`, - and at each connection, the synaptic value is the same :math:`weight`. - - Args: - num_in: int. The number of the input feature. A positive integer. - num_out: int. The number of the input feature. A positive integer. - prob: float. The connectivity probability. - weight: float. The synaptic value at each position. - seed: int. The random seed used to keep the reproducibility of the connectivity. - transpose: bool. Transpose the JIT matrix or not. Default False. - atomic: bool. Compute the post-synaptic value with the atomic summation. Default False. - May be changed in the future. - sharding: The sharding strategy. - mode: The synaptic computing mode. - name: The synapse model name. - """ - - def __init__( - self, - num_in: int, - num_out: int, - prob: float, - weight: float, - seed: Optional[int] = None, - sharding: Optional[Sharding] = None, - mode: Optional[bm.Mode] = None, - name: Optional[str] = None, - transpose: bool = False, - atomic: bool = False, - ): - super().__init__(name=name, mode=mode) - - self.prob = prob - self.sharding = sharding - self.transpose = transpose - self.seed = np.random.randint(0, 100000) if seed is None else seed - self.atomic = atomic - self.num_in = num_in - self.num_out = num_out - - # weight - if isinstance(self.mode, bm.TrainingMode): - weight = bm.TrainVar(weight) - self.weight = weight - - def update(self, x): - if x.ndim == 1: - return bm.jitconn.mv_prob_homo(x, self.weight, self.prob, self.seed, - shape=(self.num_out, self.num_in), - transpose=self.transpose, - outdim_parallel=not self.atomic) - elif x.ndim == 2: - return jax.vmap(self._batch_mv)(x) - elif x.ndim > 2: - shapes = x.shape[:-1] - x = bm.flatten(x, end_dim=-2) - y = jax.vmap(self._batch_mv)(x) - return bm.reshape(y, shapes + (y.shape[-1],)) - else: - raise ValueError - - def _batch_mv(self, x): - return bm.jitconn.mv_prob_homo(x, self.weight, self.prob, self.seed, - shape=(self.num_out, self.num_in), - transpose=self.transpose, - outdim_parallel=not self.atomic) + r"""Synaptic matrix multiplication with the just-in-time connectivity. + + It performs the computation of: + + .. math:: + + y = x @ M + + where :math:`y` is the postsynaptic value, :math:`x` the presynaptic variable, + :math:`M` the synaptic weights which has the fixed sparse connectivity and weights. + Particularly, the connectivity in :math:`M` is sampled from a fixed probability :math:`prob`, + and at each connection, the synaptic value is the same :math:`weight`. + + Args: + num_in: int. The number of the input feature. A positive integer. + num_out: int. The number of the input feature. A positive integer. + prob: float. The connectivity probability. + weight: float. The synaptic value at each position. + seed: int. The random seed used to keep the reproducibility of the connectivity. + transpose: bool. Transpose the JIT matrix or not. Default False. + atomic: bool. Compute the post-synaptic value with the atomic summation. Default False. + May be changed in the future. + sharding: The sharding strategy. + mode: The synaptic computing mode. + name: The synapse model name. + """ + + def __init__( + self, + num_in: int, + num_out: int, + prob: float, + weight: float, + seed: Optional[int] = None, + sharding: Optional[Sharding] = None, + mode: Optional[bm.Mode] = None, + name: Optional[str] = None, + transpose: bool = False, + atomic: bool = False, + ): + super().__init__(name=name, mode=mode) + + self.prob = prob + self.sharding = sharding + self.transpose = transpose + self.seed = np.random.randint(0, 100000) if seed is None else seed + self.atomic = atomic + self.num_in = num_in + self.num_out = num_out + + # weight + if isinstance(self.mode, bm.TrainingMode): + weight = bm.TrainVar(weight) + self.weight = weight + + def update(self, x): + if x.ndim == 1: + return bm.jitconn.mv_prob_homo(x, self.weight, self.prob, self.seed, + shape=(self.num_out, self.num_in), + transpose=self.transpose, + outdim_parallel=not self.atomic) + elif x.ndim == 2: + return jax.vmap(self._batch_mv)(x) + elif x.ndim > 2: + shapes = x.shape[:-1] + x = bm.flatten(x, end_dim=-2) + y = jax.vmap(self._batch_mv)(x) + return bm.reshape(y, shapes + (y.shape[-1],)) + else: + raise ValueError + + def _batch_mv(self, x): + return bm.jitconn.mv_prob_homo(x, self.weight, self.prob, self.seed, + shape=(self.num_out, self.num_in), + transpose=self.transpose, + outdim_parallel=not self.atomic) class JitFPUniformLinear(JitFPUniformLayer): - r"""Synaptic matrix multiplication with the just-in-time connectivity. - - It performs the computation of: - - .. math:: - - y = x @ M - - where :math:`y` is the postsynaptic value, :math:`x` the presynaptic variable, - :math:`M` the synaptic weights which has the fixed sparse connectivity and weights. - Particularly, the connectivity in :math:`M` is sampled from a fixed probability :math:`prob`, - and at each connection, the synaptic value is sample from a uniform distribution :math:`U(w_{low}, w_{high})`. - - Args: - num_in: int. The number of the input feature. A positive integer. - num_out: int. The number of the input feature. A positive integer. - prob: float. The connectivity probability. - w_low: float. The lowest value of the uniform distribution. - w_high: float. The highest value of the uniform distribution. - seed: int. The random seed used to keep the reproducibility of the connectivity. - transpose: bool. Transpose the JIT matrix or not. Default False. - atomic: bool. Compute the post-synaptic value with the atomic summation. Default False. - May be changed in the future. - sharding: The sharding strategy. - mode: The synaptic computing mode. - name: The synapse model name. - """ - - def __init__( - self, - num_in: int, - num_out: int, - prob: float, - w_low: float, - w_high: float, - seed: Optional[int] = None, - sharding: Optional[Sharding] = None, - mode: Optional[bm.Mode] = None, - name: Optional[str] = None, - transpose: bool = False, - atomic: bool = False, - ): - super().__init__(name=name, mode=mode) - - self.prob = prob - self.sharding = sharding - self.transpose = transpose - self.seed = np.random.randint(0, 100000) if seed is None else seed - self.atomic = atomic - self.num_in = num_in - self.num_out = num_out - - # weight - self.w_low = w_low - self.w_high = w_high - - def update(self, x): - if x.ndim == 1: - return bm.jitconn.mv_prob_uniform(x, self.w_low, self.w_high, self.prob, self.seed, - shape=(self.num_out, self.num_in), - transpose=self.transpose, - outdim_parallel=not self.atomic) - elif x.ndim == 2: - return jax.vmap(self._batch_mv)(x) - elif x.ndim > 2: - shapes = x.shape[:-1] - x = bm.flatten(x, end_dim=-2) - y = jax.vmap(self._batch_mv)(x) - return bm.reshape(y, shapes + (y.shape[-1],)) - else: - raise ValueError - - def _batch_mv(self, x): - return bm.jitconn.mv_prob_uniform(x, self.w_low, self.w_high, self.prob, self.seed, - shape=(self.num_out, self.num_in), - transpose=self.transpose, - outdim_parallel=not self.atomic) + r"""Synaptic matrix multiplication with the just-in-time connectivity. + + It performs the computation of: + + .. math:: + + y = x @ M + + where :math:`y` is the postsynaptic value, :math:`x` the presynaptic variable, + :math:`M` the synaptic weights which has the fixed sparse connectivity and weights. + Particularly, the connectivity in :math:`M` is sampled from a fixed probability :math:`prob`, + and at each connection, the synaptic value is sample from a uniform distribution :math:`U(w_{low}, w_{high})`. + + Args: + num_in: int. The number of the input feature. A positive integer. + num_out: int. The number of the input feature. A positive integer. + prob: float. The connectivity probability. + w_low: float. The lowest value of the uniform distribution. + w_high: float. The highest value of the uniform distribution. + seed: int. The random seed used to keep the reproducibility of the connectivity. + transpose: bool. Transpose the JIT matrix or not. Default False. + atomic: bool. Compute the post-synaptic value with the atomic summation. Default False. + May be changed in the future. + sharding: The sharding strategy. + mode: The synaptic computing mode. + name: The synapse model name. + """ + + def __init__( + self, + num_in: int, + num_out: int, + prob: float, + w_low: float, + w_high: float, + seed: Optional[int] = None, + sharding: Optional[Sharding] = None, + mode: Optional[bm.Mode] = None, + name: Optional[str] = None, + transpose: bool = False, + atomic: bool = False, + ): + super().__init__(name=name, mode=mode) + + self.prob = prob + self.sharding = sharding + self.transpose = transpose + self.seed = np.random.randint(0, 100000) if seed is None else seed + self.atomic = atomic + self.num_in = num_in + self.num_out = num_out + + # weight + self.w_low = w_low + self.w_high = w_high + + def update(self, x): + if x.ndim == 1: + return bm.jitconn.mv_prob_uniform(x, self.w_low, self.w_high, self.prob, self.seed, + shape=(self.num_out, self.num_in), + transpose=self.transpose, + outdim_parallel=not self.atomic) + elif x.ndim == 2: + return jax.vmap(self._batch_mv)(x) + elif x.ndim > 2: + shapes = x.shape[:-1] + x = bm.flatten(x, end_dim=-2) + y = jax.vmap(self._batch_mv)(x) + return bm.reshape(y, shapes + (y.shape[-1],)) + else: + raise ValueError + def _batch_mv(self, x): + return bm.jitconn.mv_prob_uniform(x, self.w_low, self.w_high, self.prob, self.seed, + shape=(self.num_out, self.num_in), + transpose=self.transpose, + outdim_parallel=not self.atomic) -class JitFPNormalLinear(JitFPNormalLayer): - r"""Synaptic matrix multiplication with the just-in-time connectivity. - - It performs the computation of: - - .. math:: - - y = x @ M - - where :math:`y` is the postsynaptic value, :math:`x` the presynaptic variable, - :math:`M` the synaptic weights which has the fixed sparse connectivity and weights. - Particularly, the connectivity in :math:`M` is sampled from a fixed probability :math:`prob`, - and at each connection, the synaptic value is sample from a normal distribution :math:`N(\mu, \sigma)`. - - Args: - num_in: int. The number of the input feature. A positive integer. - num_out: int. The number of the input feature. A positive integer. - prob: float. The connectivity probability. - w_mu: float. The center of the normal distribution. - w_sigma: float. The standard variance of the normal distribution. - seed: int. The random seed used to keep the reproducibility of the connectivity. - transpose: bool. Transpose the JIT matrix or not. Default False. - atomic: bool. Compute the post-synaptic value with the atomic summation. Default False. - May be changed in the future. - sharding: The sharding strategy. - mode: The synaptic computing mode. - name: The synapse model name. - """ - - def __init__( - self, - num_in: int, - num_out: int, - prob: float, - w_mu: float, - w_sigma: float, - seed: Optional[int] = None, - sharding: Optional[Sharding] = None, - transpose: bool = False, - atomic: bool = False, - mode: Optional[bm.Mode] = None, - name: Optional[str] = None, - ): - super().__init__(name=name, mode=mode) - - self.prob = prob - self.sharding = sharding - self.transpose = transpose - self.seed = np.random.randint(0, 100000) if seed is None else seed - self.atomic = atomic - self.num_in = num_in - self.num_out = num_out - - # weight - self.w_mu = w_mu - self.w_sigma = w_sigma - - def update(self, x): - if x.ndim == 1: - return bm.jitconn.mv_prob_normal(x, self.w_mu, self.w_sigma, self.prob, self.seed, - shape=(self.num_out, self.num_in), - transpose=self.transpose, - outdim_parallel=not self.atomic) - elif x.ndim == 2: - return jax.vmap(self._batch_mv)(x) - elif x.ndim > 2: - shapes = x.shape[:-1] - x = bm.flatten(x, end_dim=-2) - y = jax.vmap(self._batch_mv)(x) - return bm.reshape(y, shapes + (y.shape[-1],)) - else: - raise ValueError - - def _batch_mv(self, x): - return bm.jitconn.mv_prob_normal(x, self.w_mu, self.w_sigma, self.prob, self.seed, - shape=(self.num_out, self.num_in), - transpose=self.transpose, - outdim_parallel=not self.atomic) +class JitFPNormalLinear(JitFPNormalLayer): + r"""Synaptic matrix multiplication with the just-in-time connectivity. + + It performs the computation of: + + .. math:: + + y = x @ M + + where :math:`y` is the postsynaptic value, :math:`x` the presynaptic variable, + :math:`M` the synaptic weights which has the fixed sparse connectivity and weights. + Particularly, the connectivity in :math:`M` is sampled from a fixed probability :math:`prob`, + and at each connection, the synaptic value is sample from a normal distribution :math:`N(\mu, \sigma)`. + + Args: + num_in: int. The number of the input feature. A positive integer. + num_out: int. The number of the input feature. A positive integer. + prob: float. The connectivity probability. + w_mu: float. The center of the normal distribution. + w_sigma: float. The standard variance of the normal distribution. + seed: int. The random seed used to keep the reproducibility of the connectivity. + transpose: bool. Transpose the JIT matrix or not. Default False. + atomic: bool. Compute the post-synaptic value with the atomic summation. Default False. + May be changed in the future. + sharding: The sharding strategy. + mode: The synaptic computing mode. + name: The synapse model name. + """ + + def __init__( + self, + num_in: int, + num_out: int, + prob: float, + w_mu: float, + w_sigma: float, + seed: Optional[int] = None, + sharding: Optional[Sharding] = None, + transpose: bool = False, + atomic: bool = False, + mode: Optional[bm.Mode] = None, + name: Optional[str] = None, + ): + super().__init__(name=name, mode=mode) + + self.prob = prob + self.sharding = sharding + self.transpose = transpose + self.seed = np.random.randint(0, 100000) if seed is None else seed + self.atomic = atomic + self.num_in = num_in + self.num_out = num_out + + # weight + self.w_mu = w_mu + self.w_sigma = w_sigma + + def update(self, x): + if x.ndim == 1: + return bm.jitconn.mv_prob_normal(x, self.w_mu, self.w_sigma, self.prob, self.seed, + shape=(self.num_out, self.num_in), + transpose=self.transpose, + outdim_parallel=not self.atomic) + elif x.ndim == 2: + return jax.vmap(self._batch_mv)(x) + elif x.ndim > 2: + shapes = x.shape[:-1] + x = bm.flatten(x, end_dim=-2) + y = jax.vmap(self._batch_mv)(x) + return bm.reshape(y, shapes + (y.shape[-1],)) + else: + raise ValueError -class EventJitFPHomoLinear(JitFPHomoLayer): - r"""Synaptic matrix multiplication with the just-in-time connectivity. - - It performs the computation of: - - .. math:: - - y = x @ M - - where :math:`y` is the postsynaptic value, :math:`x` the presynaptic spikes, - :math:`M` the synaptic weights which has the fixed sparse connectivity and weights. - Particularly, the connectivity in :math:`M` is sampled from a fixed probability :math:`prob`, - and at each connection, the synaptic value is the same :math:`weight`. - - Args: - num_in: int. The number of the input feature. A positive integer. - num_out: int. The number of the input feature. A positive integer. - prob: float. The connectivity probability. - weight: float. The synaptic value at each position. - seed: int. The random seed used to keep the reproducibility of the connectivity. - transpose: bool. Transpose the JIT matrix or not. Default False. - atomic: bool. Compute the post-synaptic value with the atomic summation. Default False. - May be changed in the future. - sharding: The sharding strategy. - mode: The synaptic computing mode. - name: The synapse model name. - """ - - def __init__( - self, - num_in: int, - num_out: int, - prob: float, - weight: float, - seed: Optional[int] = None, - sharding: Optional[Sharding] = None, - mode: Optional[bm.Mode] = None, - name: Optional[str] = None, - transpose: bool = False, - atomic: bool = True, - ): - super().__init__(name=name, mode=mode) - - self.prob = prob - self.sharding = sharding - self.transpose = transpose - self.seed = np.random.randint(0, 1000000) if seed is None else seed - self.atomic = atomic - self.num_in = num_in - self.num_out = num_out - - # weight - if isinstance(self.mode, bm.TrainingMode): - weight = bm.TrainVar(weight) - self.weight = weight - - def update(self, x): - if x.ndim == 1: - return bm.jitconn.event_mv_prob_homo(x, self.weight, self.prob, self.seed, - shape=(self.num_out, self.num_in), - transpose=self.transpose, - outdim_parallel=not self.atomic) - elif x.ndim == 2: - return jax.vmap(self._batch_mv)(x) - elif x.ndim > 2: - shapes = x.shape[:-1] - x = bm.flatten(x, end_dim=-2) - y = jax.vmap(self._batch_mv)(x) - return bm.reshape(y, shapes + (y.shape[-1],)) - else: - raise ValueError - - def _batch_mv(self, x): - return bm.jitconn.event_mv_prob_homo(x, self.weight, self.prob, self.seed, + def _batch_mv(self, x): + return bm.jitconn.mv_prob_normal(x, self.w_mu, self.w_sigma, self.prob, self.seed, shape=(self.num_out, self.num_in), transpose=self.transpose, outdim_parallel=not self.atomic) -class EventJitFPUniformLinear(JitFPUniformLayer): - r"""Synaptic matrix multiplication with the just-in-time connectivity. - - It performs the computation of: - - .. math:: - - y = x @ M - - where :math:`y` is the postsynaptic value, :math:`x` the presynaptic spikes, - :math:`M` the synaptic weights which has the fixed sparse connectivity and weights. - Particularly, the connectivity in :math:`M` is sampled from a fixed probability :math:`prob`, - and at each connection, the synaptic value is sample from a uniform distribution :math:`U(w_{low}, w_{high})`. - - Args: - num_in: int. The number of the input feature. A positive integer. - num_out: int. The number of the input feature. A positive integer. - prob: float. The connectivity probability. - w_low: float. The lowest value of the uniform distribution. - w_high: float. The highest value of the uniform distribution. - seed: int. The random seed used to keep the reproducibility of the connectivity. - transpose: bool. Transpose the JIT matrix or not. Default False. - atomic: bool. Compute the post-synaptic value with the atomic summation. Default False. - May be changed in the future. - sharding: The sharding strategy. - mode: The synaptic computing mode. - name: The synapse model name. - """ - - def __init__( - self, - num_in: int, - num_out: int, - prob: float, - w_low: float, - w_high: float, - seed: Optional[int] = None, - sharding: Optional[Sharding] = None, - mode: Optional[bm.Mode] = None, - name: Optional[str] = None, - transpose: bool = False, - atomic: bool = True, - ): - super().__init__(name=name, mode=mode) - - self.prob = prob - self.sharding = sharding - self.transpose = transpose - self.seed = np.random.randint(0, 100000) if seed is None else seed - self.atomic = atomic - self.num_in = num_in - self.num_out = num_out - - # weight - self.w_low = w_low - self.w_high = w_high - - def update(self, x): - if x.ndim == 1: - return bm.jitconn.event_mv_prob_uniform(x, self.w_low, self.w_high, self.prob, self.seed, - shape=(self.num_out, self.num_in), - transpose=self.transpose, - outdim_parallel=not self.atomic) - elif x.ndim == 2: - return jax.vmap(self._batch_mv)(x) - elif x.ndim > 2: - shapes = x.shape[:-1] - x = bm.flatten(x, end_dim=-2) - y = jax.vmap(self._batch_mv)(x) - return bm.reshape(y, shapes + (y.shape[-1],)) - else: - raise ValueError - - def _batch_mv(self, x): - return bm.jitconn.event_mv_prob_uniform(x, self.w_low, self.w_high, self.prob, self.seed, - shape=(self.num_out, self.num_in), - transpose=self.transpose, - outdim_parallel=not self.atomic) - +class EventJitFPHomoLinear(JitFPHomoLayer): + r"""Synaptic matrix multiplication with the just-in-time connectivity. + + It performs the computation of: + + .. math:: + + y = x @ M + + where :math:`y` is the postsynaptic value, :math:`x` the presynaptic spikes, + :math:`M` the synaptic weights which has the fixed sparse connectivity and weights. + Particularly, the connectivity in :math:`M` is sampled from a fixed probability :math:`prob`, + and at each connection, the synaptic value is the same :math:`weight`. + + Args: + num_in: int. The number of the input feature. A positive integer. + num_out: int. The number of the input feature. A positive integer. + prob: float. The connectivity probability. + weight: float. The synaptic value at each position. + seed: int. The random seed used to keep the reproducibility of the connectivity. + transpose: bool. Transpose the JIT matrix or not. Default False. + atomic: bool. Compute the post-synaptic value with the atomic summation. Default False. + May be changed in the future. + sharding: The sharding strategy. + mode: The synaptic computing mode. + name: The synapse model name. + """ + + def __init__( + self, + num_in: int, + num_out: int, + prob: float, + weight: float, + seed: Optional[int] = None, + sharding: Optional[Sharding] = None, + mode: Optional[bm.Mode] = None, + name: Optional[str] = None, + transpose: bool = False, + atomic: bool = True, + ): + super().__init__(name=name, mode=mode) + + self.prob = prob + self.sharding = sharding + self.transpose = transpose + self.seed = np.random.randint(0, 1000000) if seed is None else seed + self.atomic = atomic + self.num_in = num_in + self.num_out = num_out + + # weight + if isinstance(self.mode, bm.TrainingMode): + weight = bm.TrainVar(weight) + self.weight = weight + + def update(self, x): + if x.ndim == 1: + return bm.jitconn.event_mv_prob_homo(x, self.weight, self.prob, self.seed, + shape=(self.num_out, self.num_in), + transpose=self.transpose, + outdim_parallel=not self.atomic) + elif x.ndim == 2: + return jax.vmap(self._batch_mv)(x) + elif x.ndim > 2: + shapes = x.shape[:-1] + x = bm.flatten(x, end_dim=-2) + y = jax.vmap(self._batch_mv)(x) + return bm.reshape(y, shapes + (y.shape[-1],)) + else: + raise ValueError -class EventJitFPNormalLinear(JitFPNormalLayer): - r"""Synaptic matrix multiplication with the just-in-time connectivity. - - It performs the computation of: - - .. math:: - - y = x @ M - - where :math:`y` is the postsynaptic value, :math:`x` the presynaptic spikes, - :math:`M` the synaptic weights which has the fixed sparse connectivity and weights. - Particularly, the connectivity in :math:`M` is sampled from a fixed probability :math:`prob`, - and at each connection, the synaptic value is sample from a normal distribution :math:`N(\mu, \sigma)`. - - Args: - num_in: int. The number of the input feature. A positive integer. - num_out: int. The number of the input feature. A positive integer. - prob: float. The connectivity probability. - w_mu: float. The center of the normal distribution. - w_sigma: float. The standard variance of the normal distribution. - seed: int. The random seed used to keep the reproducibility of the connectivity. - transpose: bool. Transpose the JIT matrix or not. Default False. - atomic: bool. Compute the post-synaptic value with the atomic summation. Default False. - May be changed in the future. - sharding: The sharding strategy. - mode: The synaptic computing mode. - name: The synapse model name. - """ - - def __init__( - self, - num_in: int, - num_out: int, - prob: float, - w_mu: float, - w_sigma: float, - seed: Optional[int] = None, - sharding: Optional[Sharding] = None, - transpose: bool = False, - atomic: bool = True, - mode: Optional[bm.Mode] = None, - name: Optional[str] = None, - ): - super().__init__(name=name, mode=mode) - - self.prob = prob - self.sharding = sharding - self.transpose = transpose - self.seed = np.random.randint(0, 100000) if seed is None else seed - self.atomic = atomic - self.num_in = num_in - self.num_out = num_out - - # weight - self.w_mu = w_mu - self.w_sigma = w_sigma - - def update(self, x): - if x.ndim == 1: - return bm.jitconn.event_mv_prob_normal(x, self.w_mu, self.w_sigma, self.prob, self.seed, + def _batch_mv(self, x): + return bm.jitconn.event_mv_prob_homo(x, self.weight, self.prob, self.seed, shape=(self.num_out, self.num_in), transpose=self.transpose, outdim_parallel=not self.atomic) - elif x.ndim == 2: - return jax.vmap(self._batch_mv)(x) - elif x.ndim > 2: - shapes = x.shape[:-1] - x = bm.flatten(x, end_dim=-2) - y = jax.vmap(self._batch_mv)(x) - return bm.reshape(y, shapes + (y.shape[-1],)) - else: - raise ValueError - - def _batch_mv(self, x): - return bm.jitconn.event_mv_prob_normal(x, self.w_mu, self.w_sigma, self.prob, self.seed, - shape=(self.num_out, self.num_in), - transpose=self.transpose, - outdim_parallel=not self.atomic) + + +class EventJitFPUniformLinear(JitFPUniformLayer): + r"""Synaptic matrix multiplication with the just-in-time connectivity. + + It performs the computation of: + + .. math:: + + y = x @ M + + where :math:`y` is the postsynaptic value, :math:`x` the presynaptic spikes, + :math:`M` the synaptic weights which has the fixed sparse connectivity and weights. + Particularly, the connectivity in :math:`M` is sampled from a fixed probability :math:`prob`, + and at each connection, the synaptic value is sample from a uniform distribution :math:`U(w_{low}, w_{high})`. + + Args: + num_in: int. The number of the input feature. A positive integer. + num_out: int. The number of the input feature. A positive integer. + prob: float. The connectivity probability. + w_low: float. The lowest value of the uniform distribution. + w_high: float. The highest value of the uniform distribution. + seed: int. The random seed used to keep the reproducibility of the connectivity. + transpose: bool. Transpose the JIT matrix or not. Default False. + atomic: bool. Compute the post-synaptic value with the atomic summation. Default False. + May be changed in the future. + sharding: The sharding strategy. + mode: The synaptic computing mode. + name: The synapse model name. + """ + + def __init__( + self, + num_in: int, + num_out: int, + prob: float, + w_low: float, + w_high: float, + seed: Optional[int] = None, + sharding: Optional[Sharding] = None, + mode: Optional[bm.Mode] = None, + name: Optional[str] = None, + transpose: bool = False, + atomic: bool = True, + ): + super().__init__(name=name, mode=mode) + + self.prob = prob + self.sharding = sharding + self.transpose = transpose + self.seed = np.random.randint(0, 100000) if seed is None else seed + self.atomic = atomic + self.num_in = num_in + self.num_out = num_out + + # weight + self.w_low = w_low + self.w_high = w_high + + def update(self, x): + if x.ndim == 1: + return bm.jitconn.event_mv_prob_uniform(x, self.w_low, self.w_high, self.prob, self.seed, + shape=(self.num_out, self.num_in), + transpose=self.transpose, + outdim_parallel=not self.atomic) + elif x.ndim == 2: + return jax.vmap(self._batch_mv)(x) + elif x.ndim > 2: + shapes = x.shape[:-1] + x = bm.flatten(x, end_dim=-2) + y = jax.vmap(self._batch_mv)(x) + return bm.reshape(y, shapes + (y.shape[-1],)) + else: + raise ValueError + + def _batch_mv(self, x): + return bm.jitconn.event_mv_prob_uniform(x, self.w_low, self.w_high, self.prob, self.seed, + shape=(self.num_out, self.num_in), + transpose=self.transpose, + outdim_parallel=not self.atomic) + + +class EventJitFPNormalLinear(JitFPNormalLayer): + r"""Synaptic matrix multiplication with the just-in-time connectivity. + + It performs the computation of: + + .. math:: + + y = x @ M + + where :math:`y` is the postsynaptic value, :math:`x` the presynaptic spikes, + :math:`M` the synaptic weights which has the fixed sparse connectivity and weights. + Particularly, the connectivity in :math:`M` is sampled from a fixed probability :math:`prob`, + and at each connection, the synaptic value is sample from a normal distribution :math:`N(\mu, \sigma)`. + + Args: + num_in: int. The number of the input feature. A positive integer. + num_out: int. The number of the input feature. A positive integer. + prob: float. The connectivity probability. + w_mu: float. The center of the normal distribution. + w_sigma: float. The standard variance of the normal distribution. + seed: int. The random seed used to keep the reproducibility of the connectivity. + transpose: bool. Transpose the JIT matrix or not. Default False. + atomic: bool. Compute the post-synaptic value with the atomic summation. Default False. + May be changed in the future. + sharding: The sharding strategy. + mode: The synaptic computing mode. + name: The synapse model name. + """ + + def __init__( + self, + num_in: int, + num_out: int, + prob: float, + w_mu: float, + w_sigma: float, + seed: Optional[int] = None, + sharding: Optional[Sharding] = None, + transpose: bool = False, + atomic: bool = True, + mode: Optional[bm.Mode] = None, + name: Optional[str] = None, + ): + super().__init__(name=name, mode=mode) + + self.prob = prob + self.sharding = sharding + self.transpose = transpose + self.seed = np.random.randint(0, 100000) if seed is None else seed + self.atomic = atomic + self.num_in = num_in + self.num_out = num_out + + # weight + self.w_mu = w_mu + self.w_sigma = w_sigma + + def update(self, x): + if x.ndim == 1: + return bm.jitconn.event_mv_prob_normal(x, self.w_mu, self.w_sigma, self.prob, self.seed, + shape=(self.num_out, self.num_in), + transpose=self.transpose, + outdim_parallel=not self.atomic) + elif x.ndim == 2: + return jax.vmap(self._batch_mv)(x) + elif x.ndim > 2: + shapes = x.shape[:-1] + x = bm.flatten(x, end_dim=-2) + y = jax.vmap(self._batch_mv)(x) + return bm.reshape(y, shapes + (y.shape[-1],)) + else: + raise ValueError + + def _batch_mv(self, x): + return bm.jitconn.event_mv_prob_normal(x, self.w_mu, self.w_sigma, self.prob, self.seed, + shape=(self.num_out, self.num_in), + transpose=self.transpose, + outdim_parallel=not self.atomic) diff --git a/brainpy/_src/math/defaults.py b/brainpy/_src/math/defaults.py index eab8b9b66..f82a90ad7 100644 --- a/brainpy/_src/math/defaults.py +++ b/brainpy/_src/math/defaults.py @@ -1,13 +1,10 @@ import jax.numpy as jnp from jax import config -from brainpy._src.dependency_check import import_taichi from .modes import NonBatchingMode from .scales import IdScaling -__all__ = ['mode', 'membrane_scaling', 'dt', 'bool_', 'int_', 'ti_int', 'float_', 'ti_float', 'complex_'] - -ti = import_taichi(error_if_not_found=False) +__all__ = ['mode', 'membrane_scaling', 'dt', 'bool_', 'int_', 'float_', 'complex_'] # Default computation mode. mode = NonBatchingMode() @@ -36,16 +33,3 @@ # default return array type numpy_func_return = 'bp_array' # 'bp_array','jax_array' - - -if ti is not None: - # Default integer data type in Taichi. - ti_int = ti.int64 if config.read('jax_enable_x64') else ti.int32 - - # Default float data type in Taichi. - ti_float = ti.float64 if config.read('jax_enable_x64') else ti.float32 - -else: - ti_int = None - ti_float = None - diff --git a/brainpy/_src/math/environment.py b/brainpy/_src/math/environment.py index ebbb8b6a3..984f3137e 100644 --- a/brainpy/_src/math/environment.py +++ b/brainpy/_src/math/environment.py @@ -18,9 +18,6 @@ from . import scales from . import defaults from .object_transform import naming -from brainpy._src.dependency_check import import_taichi - -ti = import_taichi(error_if_not_found=False) __all__ = [ # context manage for environment setting @@ -459,16 +456,10 @@ def set_float(dtype: type): """ if dtype in [jnp.float16, 'float16', 'f16']: defaults.__dict__['float_'] = jnp.float16 - if ti is not None: - defaults.__dict__['ti_float'] = ti.float16 elif dtype in [jnp.float32, 'float32', 'f32']: defaults.__dict__['float_'] = jnp.float32 - if ti is not None: - defaults.__dict__['ti_float'] = ti.float32 elif dtype in [jnp.float64, 'float64', 'f64']: defaults.__dict__['float_'] = jnp.float64 - if ti is not None: - defaults.__dict__['ti_float'] = ti.float64 else: raise NotImplementedError @@ -494,20 +485,12 @@ def set_int(dtype: type): """ if dtype in [jnp.int8, 'int8', 'i8']: defaults.__dict__['int_'] = jnp.int8 - if ti is not None: - defaults.__dict__['ti_int'] = ti.int8 elif dtype in [jnp.int16, 'int16', 'i16']: defaults.__dict__['int_'] = jnp.int16 - if ti is not None: - defaults.__dict__['ti_int'] = ti.int16 elif dtype in [jnp.int32, 'int32', 'i32']: defaults.__dict__['int_'] = jnp.int32 - if ti is not None: - defaults.__dict__['ti_int'] = ti.int32 elif dtype in [jnp.int64, 'int64', 'i64']: defaults.__dict__['int_'] = jnp.int64 - if ti is not None: - defaults.__dict__['ti_int'] = ti.int64 else: raise NotImplementedError diff --git a/brainpy/_src/math/event/csr_matmat.py b/brainpy/_src/math/event/csr_matmat.py index b78afad70..0db589ae1 100644 --- a/brainpy/_src/math/event/csr_matmat.py +++ b/brainpy/_src/math/event/csr_matmat.py @@ -3,16 +3,15 @@ from typing import Union, Tuple - from jax import numpy as jnp -from brainpy._src.math.ndarray import Array from brainpy._src.dependency_check import import_braintaichi, raise_braintaichi_not_found +from brainpy._src.math.ndarray import Array bti = import_braintaichi(error_if_not_found=False) __all__ = [ - 'csrmm', + 'csrmm', ] @@ -25,23 +24,23 @@ def csrmm( shape: Tuple[int, int], transpose: bool = False, ): - """Product of CSR sparse matrix and a dense event matrix. - - Args: - data : array of shape ``(nse,)``, float. - indices : array of shape ``(nse,)`` - indptr : array of shape ``(shape[0] + 1,)`` and dtype ``indices.dtype`` - B : array of shape ``(shape[0] if transpose else shape[1], cols)`` and - dtype ``data.dtype`` - shape : length-2 tuple representing the matrix shape - transpose : boolean specifying whether to transpose the sparse matrix - before computing. - - Returns: - C : array of shape ``(shape[1] if transpose else shape[0], cols)`` - representing the matrix-matrix product product. - """ - if bti is None: - raise_braintaichi_not_found() - - return bti.event_csrmm(data, indices, indptr, matrix, shape=shape, transpose=transpose) + """Product of CSR sparse matrix and a dense event matrix. + + Args: + data : array of shape ``(nse,)``, float. + indices : array of shape ``(nse,)`` + indptr : array of shape ``(shape[0] + 1,)`` and dtype ``indices.dtype`` + B : array of shape ``(shape[0] if transpose else shape[1], cols)`` and + dtype ``data.dtype`` + shape : length-2 tuple representing the matrix shape + transpose : boolean specifying whether to transpose the sparse matrix + before computing. + + Returns: + C : array of shape ``(shape[1] if transpose else shape[0], cols)`` + representing the matrix-matrix product product. + """ + if bti is None: + raise_braintaichi_not_found() + + return bti.event_csrmm(data, indices, indptr, matrix, shape=shape, transpose=transpose) diff --git a/brainpy/_src/math/event/csr_matvec.py b/brainpy/_src/math/event/csr_matvec.py index 3969ee6bf..d9c39370e 100644 --- a/brainpy/_src/math/event/csr_matvec.py +++ b/brainpy/_src/math/event/csr_matvec.py @@ -18,13 +18,11 @@ bti = import_braintaichi(error_if_not_found=False) - __all__ = [ - 'csrmv' + 'csrmv' ] - def csrmv( data: Union[float, jax.Array], indices: jax.Array, @@ -34,37 +32,37 @@ def csrmv( shape: Tuple[int, int], transpose: bool = False, ) -> jax.Array: - """Product of a sparse CSR matrix and a dense event vector. - - This function supports JAX transformations, including `jit()`, `grad()`, - `vmap()` and `pmap()`. - - Parameters - ---------- - data: ndarray, float - An array of shape ``(nse,)``. - indices: ndarray - An array of shape ``(nse,)``. - indptr: ndarray - An array of shape ``(shape[0] + 1,)`` and dtype ``indices.dtype``. - events: ndarray - An array of shape ``(shape[0] if transpose else shape[1],)`` - and dtype ``data.dtype``. - shape: tuple - A length-2 tuple representing the matrix shape. - transpose: bool - A boolean specifying whether to transpose the sparse matrix - before computing. - If ``transpose=True``, the operator will compute based on the - event-driven property of the ``events`` vector. - - Returns - ------- - y : Array - The array of shape ``(shape[1] if transpose else shape[0],)`` representing - the matrix vector product. - """ - if bti is None: - raise_braintaichi_not_found() - - return bti.event_csrmv(data, indices, indptr, events, shape=shape, transpose=transpose) + """Product of a sparse CSR matrix and a dense event vector. + + This function supports JAX transformations, including `jit()`, `grad()`, + `vmap()` and `pmap()`. + + Parameters + ---------- + data: ndarray, float + An array of shape ``(nse,)``. + indices: ndarray + An array of shape ``(nse,)``. + indptr: ndarray + An array of shape ``(shape[0] + 1,)`` and dtype ``indices.dtype``. + events: ndarray + An array of shape ``(shape[0] if transpose else shape[1],)`` + and dtype ``data.dtype``. + shape: tuple + A length-2 tuple representing the matrix shape. + transpose: bool + A boolean specifying whether to transpose the sparse matrix + before computing. + If ``transpose=True``, the operator will compute based on the + event-driven property of the ``events`` vector. + + Returns + ------- + y : Array + The array of shape ``(shape[1] if transpose else shape[0],)`` representing + the matrix vector product. + """ + if bti is None: + raise_braintaichi_not_found() + + return bti.event_csrmv(data, indices, indptr, events, shape=shape, transpose=transpose) diff --git a/brainpy/_src/math/event/tests/event_csr_matmat_VS_csr_matmat.py b/brainpy/_src/math/event/tests/event_csr_matmat_VS_csr_matmat.py deleted file mode 100644 index 872c69e14..000000000 --- a/brainpy/_src/math/event/tests/event_csr_matmat_VS_csr_matmat.py +++ /dev/null @@ -1,285 +0,0 @@ -# from jax_taichi import jax_taichi_call - -import time -from functools import partial -import os - -import brainpy as bp -import brainpy.math as bm -import jax -import jax.numpy as jnp -import numpy as np -import pandas as pd -import taichi as ti - -bm.set_platform('cpu') - -size = [ - (100, 100, 100), - (100, 1000, 100), - (1000, 1000, 100), - (1000, 1000, 1000), - (100, 10000, 100), - (10000, 100, 1000), - (1000, 100, 10000), - (10000, 10000, 1000), - (10000, 1000, 10000), - (10000, 10000, 10000), - (20000, 20000, 20000), -] - -values_type = [ - 'heter', - 'homo', -] -events_type = ['bool', - 'float', - ] -transpose = [ - # True, - False -] - -ITERATION = 100 -if bm.get_platform() == 'cpu': - ITERATION = 10 - -print(bm.get_platform()) - - -@partial(jax.jit, static_argnums=(4, 5)) -def csrmm(weight, indices, indptr, matrix, shape, transpose): - r = 0 - for i in range(ITERATION): - r += bm.sparse.csrmm(weight, indices, indptr, matrix, shape=shape, transpose=transpose) - return r - - -@partial(jax.jit, static_argnums=(4, 5)) -def event_csrmm(weight, indices, indptr, matrix, shape, transpose): - r = 0 - for i in range(ITERATION): - r += bm.event.csrmm(weight, indices, indptr, matrix, shape=shape, transpose=transpose) - return r - - -def test_sparse_csrmm(shape, values_type, events_type, transpose): - rng = bm.random.RandomState(seed=1234) - matrix1_shape = (shape[1], shape[0]) if transpose else (shape[0], shape[1]) - matrix2_shape = (shape[1], shape[2]) - indices, indptr = bp.conn.FixedProb(0.05, seed=1234, allow_multi_conn=True)(*matrix1_shape).require('pre2post') - matrix = rng.random(matrix2_shape) - matrix = bm.as_jax(matrix) - weight = 1. - - if events_type == 'float': - matrix = matrix.astype(bm.float32) - if values_type == 'heter': - heter_data = bm.ones(indices.shape) * weight - weight = heter_data - - result = jax.block_until_ready(csrmm(weight, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) - result = jax.block_until_ready(csrmm(weight, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) - result = jax.block_until_ready(csrmm(weight, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) - result = jax.block_until_ready(csrmm(weight, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) - result = jax.block_until_ready(csrmm(weight, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) - - time0 = time.time() - result = jax.block_until_ready(csrmm(weight, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) - time1 = time.time() - - time2 = time.time() - result = jax.block_until_ready(csrmm(weight, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) - time3 = time.time() - - time4 = time.time() - result = jax.block_until_ready(csrmm(weight, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) - time5 = time.time() - - time6 = time.time() - result = jax.block_until_ready(csrmm(weight, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) - time7 = time.time() - - time8 = time.time() - result = jax.block_until_ready(csrmm(weight, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) - time9 = time.time() - - time10 = time.time() - result = jax.block_until_ready(csrmm(weight, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) - time11 = time.time() - - time12 = time.time() - result = jax.block_until_ready(csrmm(weight, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) - time13 = time.time() - - time14 = time.time() - result = jax.block_until_ready(csrmm(weight, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) - time15 = time.time() - - time16 = time.time() - result = jax.block_until_ready(csrmm(weight, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) - time17 = time.time() - - time18 = time.time() - result = jax.block_until_ready(csrmm(weight, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) - time19 = time.time() - - result1 = result - - result = jax.block_until_ready(event_csrmm(weight, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) - result = jax.block_until_ready(event_csrmm(weight, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) - result = jax.block_until_ready(event_csrmm(weight, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) - result = jax.block_until_ready(event_csrmm(weight, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) - result = jax.block_until_ready(event_csrmm(weight, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) - - time20 = time.time() - result = jax.block_until_ready(event_csrmm(weight, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) - time21 = time.time() - - result2 = result - - time22 = time.time() - result = jax.block_until_ready(event_csrmm(weight, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) - time23 = time.time() - - time24 = time.time() - result = jax.block_until_ready(event_csrmm(weight, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) - time25 = time.time() - - time26 = time.time() - result = jax.block_until_ready(event_csrmm(weight, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) - time27 = time.time() - - time28 = time.time() - result = jax.block_until_ready(event_csrmm(weight, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) - time29 = time.time() - - time30 = time.time() - result = jax.block_until_ready(event_csrmm(weight, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) - time31 = time.time() - - time32 = time.time() - result = jax.block_until_ready(event_csrmm(weight, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) - time33 = time.time() - - time34 = time.time() - result = jax.block_until_ready(event_csrmm(weight, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) - time35 = time.time() - - time36 = time.time() - result = jax.block_until_ready(event_csrmm(weight, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) - time37 = time.time() - - time38 = time.time() - result = jax.block_until_ready(event_csrmm(weight, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) - time39 = time.time() - - csrmm_time1 = (time1 - time0) * 1000 - csrmm_time2 = (time3 - time2) * 1000 - csrmm_time3 = (time5 - time4) * 1000 - csrmm_time4 = (time7 - time6) * 1000 - csrmm_time5 = (time9 - time8) * 1000 - csrmm_time6 = (time11 - time10) * 1000 - csrmm_time7 = (time13 - time12) * 1000 - csrmm_time8 = (time15 - time14) * 1000 - csrmm_time9 = (time17 - time16) * 1000 - csrmm_time10 = (time19 - time18) * 1000 - event_csrmm_time1 = (time21 - time20) * 1000 - event_csrmm_time2 = (time23 - time22) * 1000 - event_csrmm_time3 = (time25 - time24) * 1000 - event_csrmm_time4 = (time27 - time26) * 1000 - event_csrmm_time5 = (time29 - time28) * 1000 - event_csrmm_time6 = (time31 - time30) * 1000 - event_csrmm_time7 = (time33 - time32) * 1000 - event_csrmm_time8 = (time35 - time34) * 1000 - event_csrmm_time9 = (time37 - time36) * 1000 - event_csrmm_time10 = (time39 - time38) * 1000 - print('shape: ', shape, 'values_type: ', values_type, 'events_type: ', events_type, 'transpose: ', transpose) - print('csrmm_1: ', csrmm_time1, 'ms') - print('csrmm_3: ', csrmm_time3, 'ms') - print('csrmm_5: ', csrmm_time5, 'ms') - print('csrmm_7: ', csrmm_time7, 'ms') - print('csrmm_9: ', csrmm_time9, 'ms') - print('event_csrmm_1: ', event_csrmm_time1, 'ms') - print('event_csrmm_3: ', event_csrmm_time3, 'ms') - print('event_csrmm_5: ', event_csrmm_time5, 'ms') - print('event_csrmm_7: ', event_csrmm_time7, 'ms') - print('event_csrmm_9: ', event_csrmm_time9, 'ms') - - r = bm.allclose(result1, result2) - if not r: - print('result1: ', result1) - print('result2: ', result2) - - return csrmm_time1, csrmm_time2, csrmm_time3, csrmm_time4, csrmm_time5, \ - csrmm_time6, csrmm_time7, csrmm_time8, csrmm_time9, csrmm_time10, \ - event_csrmm_time1, event_csrmm_time2, event_csrmm_time3, event_csrmm_time4, event_csrmm_time5, \ - event_csrmm_time6, event_csrmm_time7, event_csrmm_time8, event_csrmm_time9, event_csrmm_time10 - - -PATH = os.path.dirname(os.path.abspath(__file__)) - -# init dataframe -df = pd.DataFrame( - columns=['shape', 'p', 'shape[0]', 'shape[1]', 'shape[2]', 'backend', 'values type', 'events type', 'transpose', - 'csrmm time1(ms)', 'csrmm time2(ms)', 'csrmm time3(ms)', 'csrmm time4(ms)', - 'csrmm time5(ms)', - 'csrmm time6(ms)', 'csrmm time7(ms)', 'csrmm time8(ms)', 'csrmm time9(ms)', - 'csrmm time10(ms)', - 'event_csrmm time1(ms)', 'event_csrmm time2(ms)', 'event_csrmm time3(ms)', 'event_csrmm time4(ms)', - 'event_csrmm time5(ms)', - 'event_csrmm time6(ms)', 'event_csrmm time7(ms)', 'event_csrmm time8(ms)', 'event_csrmm time9(ms)', - 'event_csrmm time10(ms)']) - -### RECTANGULAR MATRIX -if (bm.get_platform() == 'cpu'): - for shape in size: - for _values_type in values_type: - for _events_type in events_type: - for _transpose in transpose: - csrmm_time_1, csrmm_time_2, csrmm_time_3, csrmm_time_4, csrmm_time_5, \ - csrmm_time_6, csrmm_time_7, csrmm_time_8, csrmm_time_9, csrmm_time_10, \ - event_csrmm_time_1, event_csrmm_time_2, event_csrmm_time_3, event_csrmm_time_4, event_csrmm_time_5, \ - event_csrmm_time_6, event_csrmm_time_7, event_csrmm_time_8, event_csrmm_time_9, event_csrmm_time_10 = test_sparse_csrmm( - shape, - _values_type, - _events_type, - _transpose) - # append to dataframe - df.loc[df.shape[0]] = [shape, 0.05, shape[0], shape[1], shape[2], 'cpu', _values_type, _events_type, - _transpose, - csrmm_time_1, csrmm_time_2, csrmm_time_3, csrmm_time_4, - csrmm_time_5, - csrmm_time_6, csrmm_time_7, csrmm_time_8, csrmm_time_9, - csrmm_time_10, - event_csrmm_time_1, event_csrmm_time_2, event_csrmm_time_3, event_csrmm_time_4, - event_csrmm_time_5, - event_csrmm_time_6, event_csrmm_time_7, event_csrmm_time_8, event_csrmm_time_9, - event_csrmm_time_10] - df.to_csv(f'{PATH}/csrmm_cpu.csv', index=False) - -if (bm.get_platform() == 'gpu'): - for shape in size: - for _values_type in values_type: - for _events_type in events_type: - for _transpose in transpose: - csrmm_time_1, csrmm_time_2, csrmm_time_3, csrmm_time_4, csrmm_time_5, \ - csrmm_time_6, csrmm_time_7, csrmm_time_8, csrmm_time_9, csrmm_time_10, \ - event_csrmm_time_1, event_csrmm_time_2, event_csrmm_time_3, event_csrmm_time_4, event_csrmm_time_5, \ - event_csrmm_time_6, event_csrmm_time_7, event_csrmm_time_8, event_csrmm_time_9, event_csrmm_time_10 = test_sparse_csrmm( - shape, - _values_type, - _events_type, - _transpose) - # append to dataframe - df.loc[df.shape[0]] = [shape, 0.05, shape[0], shape[1], shape[2], 'gpu', _values_type, _events_type, - _transpose, - csrmm_time_1, csrmm_time_2, csrmm_time_3, csrmm_time_4, - csrmm_time_5, - csrmm_time_6, csrmm_time_7, csrmm_time_8, csrmm_time_9, - csrmm_time_10, - event_csrmm_time_1, event_csrmm_time_2, event_csrmm_time_3, event_csrmm_time_4, - event_csrmm_time_5, - event_csrmm_time_6, event_csrmm_time_7, event_csrmm_time_8, event_csrmm_time_9, - event_csrmm_time_10] - df.to_csv(f'{PATH}/csrmm_gpu.csv', index=False) diff --git a/brainpy/_src/math/event/tests/event_csrmv_taichi_VS_event_csrmv.py b/brainpy/_src/math/event/tests/event_csrmv_taichi_VS_event_csrmv.py deleted file mode 100644 index 3ac1e0ee2..000000000 --- a/brainpy/_src/math/event/tests/event_csrmv_taichi_VS_event_csrmv.py +++ /dev/null @@ -1,254 +0,0 @@ -# from jax_taichi import jax_taichi_call - -import time -from functools import partial -import os - -import brainpy as bp -import brainpy.math as bm -import jax -import jax.numpy as jnp -import numpy as np -import pandas as pd -import taichi as ti - -bm.set_platform('cpu') - -s = [1000, 5000, 10000, 20000, 25000, 30000] -p = [0.1, 0.2, 0.3, 0.4, 0.5] - -shape = [ - 1000, - 2500, - 5000, - 10000, - 25000, - 37500, - 50000 -] - - - -values_type = [ - 'homo', - 'heter' - ] -events_type = [ - 'bool', - 'float', - ] -transpose = [ - True, - False - ] - -ITERATION = 100 -if bm.get_platform() == 'cpu': - ITERATION = 10 - -print(bm.get_platform()) - -@partial(jax.jit, static_argnums=(4, 5)) -def event_csrmv_taichi(weight, indices, indptr, vector, shape, transpose): - r = 0 - for i in range(ITERATION): - r += bm.event.csrmv_taichi(weight, indices, indptr, vector, shape=shape, transpose=transpose)[0] - return r - -@partial(jax.jit, static_argnums=(4, 5)) -def event_csrmv(weight, indices, indptr, vector, shape, transpose): - r = 0 - for i in range(ITERATION): - r += bm.event.csrmv(weight, indices, indptr, vector, shape=shape, transpose=transpose) - return r - -def test_event_csrmv(shape, values_type, events_type, transpose): - rng = bm.random.RandomState(seed=1234) - indices, indptr = bp.conn.FixedProb(0.05, seed=1234, allow_multi_conn=True)(*shape).require('pre2post') - vector = rng.random(shape[0] if transpose else shape[1]) < 0.1 - weight = 1. - - - if events_type == 'float': - vector = vector.astype(bm.float32) - if values_type == 'heter': - heter_data = bm.ones(indices.shape) * weight - weight = heter_data - - jax.block_until_ready(event_csrmv_taichi(weight, indices, indptr, vector, shape=shape, transpose=transpose)) - jax.block_until_ready(event_csrmv_taichi(weight, indices, indptr, vector, shape=shape, transpose=transpose)) - jax.block_until_ready(event_csrmv_taichi(weight, indices, indptr, vector, shape=shape, transpose=transpose)) - jax.block_until_ready(event_csrmv_taichi(weight, indices, indptr, vector, shape=shape, transpose=transpose)) - jax.block_until_ready(event_csrmv_taichi(weight, indices, indptr, vector, shape=shape, transpose=transpose)) - - time0 = time.time() - jax.block_until_ready(event_csrmv_taichi(weight, indices, indptr, vector, shape=shape, transpose=transpose)) - time1 = time.time() - - time2 = time.time() - jax.block_until_ready(event_csrmv_taichi(weight, indices, indptr, vector, shape=shape, transpose=transpose)) - time3 = time.time() - - time4 = time.time() - jax.block_until_ready(event_csrmv_taichi(weight, indices, indptr, vector, shape=shape, transpose=transpose)) - time5 = time.time() - - time6 = time.time() - jax.block_until_ready(event_csrmv_taichi(weight, indices, indptr, vector, shape=shape, transpose=transpose)) - time7 = time.time() - - time8 = time.time() - jax.block_until_ready(event_csrmv_taichi(weight, indices, indptr, vector, shape=shape, transpose=transpose)) - time9 = time.time() - - time10 = time.time() - jax.block_until_ready(event_csrmv_taichi(weight, indices, indptr, vector, shape=shape, transpose=transpose)) - time11 = time.time() - - time12 = time.time() - jax.block_until_ready(event_csrmv_taichi(weight, indices, indptr, vector, shape=shape, transpose=transpose)) - time13 = time.time() - - time14 = time.time() - jax.block_until_ready(event_csrmv_taichi(weight, indices, indptr, vector, shape=shape, transpose=transpose)) - time15 = time.time() - - time16 = time.time() - jax.block_until_ready(event_csrmv_taichi(weight, indices, indptr, vector, shape=shape, transpose=transpose)) - time17 = time.time() - - time18 = time.time() - jax.block_until_ready(event_csrmv_taichi(weight, indices, indptr, vector, shape=shape, transpose=transpose)) - time19 = time.time() - - - jax.block_until_ready(event_csrmv(weight, indices, indptr, vector, shape=shape, transpose=transpose)) - jax.block_until_ready(event_csrmv(weight, indices, indptr, vector, shape=shape, transpose=transpose)) - jax.block_until_ready(event_csrmv(weight, indices, indptr, vector, shape=shape, transpose=transpose)) - jax.block_until_ready(event_csrmv(weight, indices, indptr, vector, shape=shape, transpose=transpose)) - jax.block_until_ready(event_csrmv(weight, indices, indptr, vector, shape=shape, transpose=transpose)) - - time20 = time.time() - jax.block_until_ready(event_csrmv(weight, indices, indptr, vector, shape=shape, transpose=transpose)) - time21 = time.time() - - time22 = time.time() - jax.block_until_ready(event_csrmv(weight, indices, indptr, vector, shape=shape, transpose=transpose)) - time23 = time.time() - - time24 = time.time() - jax.block_until_ready(event_csrmv(weight, indices, indptr, vector, shape=shape, transpose=transpose)) - time25 = time.time() - - time26 = time.time() - jax.block_until_ready(event_csrmv(weight, indices, indptr, vector, shape=shape, transpose=transpose)) - time27 = time.time() - - time28 = time.time() - jax.block_until_ready(event_csrmv(weight, indices, indptr, vector, shape=shape, transpose=transpose)) - time29 = time.time() - - time30 = time.time() - jax.block_until_ready(event_csrmv(weight, indices, indptr, vector, shape=shape, transpose=transpose)) - time31 = time.time() - - time32 = time.time() - jax.block_until_ready(event_csrmv(weight, indices, indptr, vector, shape=shape, transpose=transpose)) - time33 = time.time() - - time34 = time.time() - jax.block_until_ready(event_csrmv(weight, indices, indptr, vector, shape=shape, transpose=transpose)) - time35 = time.time() - - time36 = time.time() - jax.block_until_ready(event_csrmv(weight, indices, indptr, vector, shape=shape, transpose=transpose)) - time37 = time.time() - - time38 = time.time() - jax.block_until_ready(event_csrmv(weight, indices, indptr, vector, shape=shape, transpose=transpose)) - time39 = time.time() - - taichi_aot_time1 = (time1 - time0) * 1000 - taichi_aot_time2 = (time3 - time2) * 1000 - taichi_aot_time3 = (time5 - time4) * 1000 - taichi_aot_time4 = (time7 - time6) * 1000 - taichi_aot_time5 = (time9 - time8) * 1000 - taichi_aot_time6 = (time11 - time10) * 1000 - taichi_aot_time7 = (time13 - time12) * 1000 - taichi_aot_time8 = (time15 - time14) * 1000 - taichi_aot_time9 = (time17 - time16) * 1000 - taichi_aot_time10 = (time19 - time18) * 1000 - brainpy_time1 = (time21 - time20) * 1000 - brainpy_time2 = (time23 - time22) * 1000 - brainpy_time3 = (time25 - time24) * 1000 - brainpy_time4 = (time27 - time26) * 1000 - brainpy_time5 = (time29 - time28) * 1000 - brainpy_time6 = (time31 - time30) * 1000 - brainpy_time7 = (time33 - time32) * 1000 - brainpy_time8 = (time35 - time34) * 1000 - brainpy_time9 = (time37 - time36) * 1000 - brainpy_time10 = (time39 - time38) * 1000 - print('shape: ', shape, 'values_type: ', values_type, 'events_type: ', events_type, 'transpose: ', transpose) - print('taichi_aot_1: ', taichi_aot_time1, 'ms') - print('taichi_aot_3: ', taichi_aot_time3, 'ms') - print('taichi_aot_5: ', taichi_aot_time5, 'ms') - print('taichi_aot_7: ', taichi_aot_time7, 'ms') - print('taichi_aot_9: ', taichi_aot_time9, 'ms') - print('brainpylib_1: ', brainpy_time1, 'ms') - print('brainpylib_3: ', brainpy_time3, 'ms') - print('brainpylib_5: ', brainpy_time5, 'ms') - print('brainpylib_7: ', brainpy_time7, 'ms') - print('brainpylib_9: ', brainpy_time9, 'ms') - - # assert(jnp.allclose(result1[0], result2)) - - return taichi_aot_time1, taichi_aot_time2, taichi_aot_time3, taichi_aot_time4, taichi_aot_time5,\ - taichi_aot_time6, taichi_aot_time7, taichi_aot_time8, taichi_aot_time9, taichi_aot_time10,\ - brainpy_time1, brainpy_time2, brainpy_time3, brainpy_time4, brainpy_time5, \ - brainpy_time6, brainpy_time7, brainpy_time8, brainpy_time9, brainpy_time10 - -PATH = os.path.dirname(os.path.abspath(__file__)) - -# init dataframe -df = pd.DataFrame(columns=['s', 'p', 'shape[0]', 'shape[1]', 'backend', 'values type', 'events type', 'transpose', - 'taichi aot time1(ms)', 'taichi aot time2(ms)', 'taichi aot time3(ms)', 'taichi aot time4(ms)', 'taichi aot time5(ms)', - 'taichi aot time6(ms)', 'taichi aot time7(ms)', 'taichi aot time8(ms)', 'taichi aot time9(ms)', 'taichi aot time10(ms)', - 'brainpy time1(ms)', 'brainpy time2(ms)', 'brainpy time3(ms)', 'brainpy time4(ms)', 'brainpy time5(ms)', - 'brainpy time6(ms)', 'brainpy time7(ms)', 'brainpy time8(ms)', 'brainpy time9(ms)', 'brainpy time10(ms)']) - -### RECTANGULAR MATRIX -if (bm.get_platform() == 'cpu'): - for shape1 in shape: - for shape2 in shape: - for _values_type in values_type: - for _events_type in events_type: - for _transpose in transpose: - taichi_aot_time_1, taichi_aot_time_2, taichi_aot_time_3, taichi_aot_time_4, taichi_aot_time_5,\ - taichi_aot_time_6, taichi_aot_time_7, taichi_aot_time_8, taichi_aot_time_9, taichi_aot_time_10,\ - brainpy_time_1, brainpy_time_2, brainpy_time_3, brainpy_time_4, brainpy_time_5, \ - brainpy_time_6, brainpy_time_7, brainpy_time_8, brainpy_time_9, brainpy_time_10 = test_event_csrmv((shape1, shape2), _values_type, _events_type, _transpose) - # append to dataframe - df.loc[df.shape[0]] = [(shape1, shape2), 0.5 , shape1, shape2, 'cpu', _values_type, _events_type, _transpose, - taichi_aot_time_1, taichi_aot_time_2, taichi_aot_time_3, taichi_aot_time_4, taichi_aot_time_5, - taichi_aot_time_6, taichi_aot_time_7, taichi_aot_time_8, taichi_aot_time_9, taichi_aot_time_10, - brainpy_time_1, brainpy_time_2, brainpy_time_3, brainpy_time_4, brainpy_time_5, - brainpy_time_6, brainpy_time_7, brainpy_time_8, brainpy_time_9, brainpy_time_10] - df.to_csv(f'{PATH}/event_csrmv_cpu.csv', index=False) - -if (bm.get_platform() == 'gpu'): - for shape1 in shape: - for shape2 in shape: - for _values_type in values_type: - for _events_type in events_type: - for _transpose in transpose: - taichi_aot_time_1, taichi_aot_time_2, taichi_aot_time_3, taichi_aot_time_4, taichi_aot_time_5,\ - taichi_aot_time_6, taichi_aot_time_7, taichi_aot_time_8, taichi_aot_time_9, taichi_aot_time_10,\ - brainpy_time_1, brainpy_time_2, brainpy_time_3, brainpy_time_4, brainpy_time_5, \ - brainpy_time_6, brainpy_time_7, brainpy_time_8, brainpy_time_9, brainpy_time_10 = test_event_csrmv((shape1, shape2), _values_type, _events_type, _transpose) - # append to dataframe - df.loc[df.shape[0]] = [(shape1, shape2), 0.5 , shape1, shape2, 'gpu', _values_type, _events_type, _transpose, - taichi_aot_time_1, taichi_aot_time_2, taichi_aot_time_3, taichi_aot_time_4, taichi_aot_time_5, - taichi_aot_time_6, taichi_aot_time_7, taichi_aot_time_8, taichi_aot_time_9, taichi_aot_time_10, - brainpy_time_1, brainpy_time_2, brainpy_time_3, brainpy_time_4, brainpy_time_5, - brainpy_time_6, brainpy_time_7, brainpy_time_8, brainpy_time_9, brainpy_time_10] - df.to_csv(f'{PATH}/event_csrmv_gpu.csv', index=False) diff --git a/brainpy/_src/math/event/tests/event_csrmv_taichi_VS_event_csrmv_grad.py b/brainpy/_src/math/event/tests/event_csrmv_taichi_VS_event_csrmv_grad.py deleted file mode 100644 index 98793e600..000000000 --- a/brainpy/_src/math/event/tests/event_csrmv_taichi_VS_event_csrmv_grad.py +++ /dev/null @@ -1,271 +0,0 @@ -# from jax_taichi import jax_taichi_call - -import time -from functools import partial -import os - -import brainpy as bp -import brainpy.math as bm -import jax -import jax.numpy as jnp -import numpy as np -import pandas as pd -import taichi as ti - -bm.set_platform('cpu') - -s = [1000, 5000, 10000, 20000, 25000, 30000] -p = [0.1, 0.2, 0.3, 0.4, 0.5] - -shape = [ - 1000, - 2500, - 5000, - 10000, - 25000, - 37500, - 50000 -] - - - -values_type = [ - 'homo', - 'heter' - ] -events_type = [ - 'bool', - 'float', - ] -transpose = [ - True, - False - ] - -ITERATION = 100 -if bm.get_platform() == 'cpu': - ITERATION = 10 - -print(bm.get_platform()) - -def sum_op(op): - def func(*args, **kwargs): - r = op(*args, **kwargs) - return r.sum() - - return func - - -def sum_op2(op): - def func(*args, **kwargs): - r = op(*args, **kwargs)[0] - return r.sum() - - return func - -@partial(jax.jit, static_argnums=(4, 5)) -def event_csrmv_taichi_grad(weight, indices, indptr, vector, shape, transpose): - r = 0 - for i in range(ITERATION): - r += jax.grad(sum_op2(bm.event.csrmv_taichi), argnums=3)( - weight, indices, indptr, vector.astype(float), shape=shape, transpose=transpose) - return r - -@partial(jax.jit, static_argnums=(4, 5)) -def event_csrmv_grad(weight, indices, indptr, vector, shape, transpose): - r = 0 - for i in range(ITERATION): - r += jax.grad(sum_op(bm.event.csrmv), argnums=3)( - weight, indices, indptr, vector.astype(float), shape=shape, transpose=transpose) - return r - - -def test_event_csrmv(shape, values_type, events_type, transpose): - rng = bm.random.RandomState(seed=1234) - indices, indptr = bp.conn.FixedProb(0.05, seed=1234, allow_multi_conn=True)(*shape).require('pre2post') - vector = rng.random(shape[0] if transpose else shape[1]) < 0.1 - weight = 1. - - - if events_type == 'float': - vector = vector.astype(bm.float32) - if values_type == 'heter': - heter_data = bm.ones(indices.shape) * weight - weight = heter_data - - result = jax.block_until_ready(event_csrmv_taichi_grad(weight, indices, indptr, vector, shape=shape, transpose=transpose)) - result = jax.block_until_ready(event_csrmv_taichi_grad(weight, indices, indptr, vector, shape=shape, transpose=transpose)) - result = jax.block_until_ready(event_csrmv_taichi_grad(weight, indices, indptr, vector, shape=shape, transpose=transpose)) - result = jax.block_until_ready(event_csrmv_taichi_grad(weight, indices, indptr, vector, shape=shape, transpose=transpose)) - result = jax.block_until_ready(event_csrmv_taichi_grad(weight, indices, indptr, vector, shape=shape, transpose=transpose)) - - time0 = time.time() - result = jax.block_until_ready(event_csrmv_taichi_grad(weight, indices, indptr, vector, shape=shape, transpose=transpose)) - time1 = time.time() - - time2 = time.time() - result = jax.block_until_ready(event_csrmv_taichi_grad(weight, indices, indptr, vector, shape=shape, transpose=transpose)) - time3 = time.time() - - time4 = time.time() - result = jax.block_until_ready(event_csrmv_taichi_grad(weight, indices, indptr, vector, shape=shape, transpose=transpose)) - time5 = time.time() - - time6 = time.time() - result = jax.block_until_ready(event_csrmv_taichi_grad(weight, indices, indptr, vector, shape=shape, transpose=transpose)) - time7 = time.time() - - time8 = time.time() - result = jax.block_until_ready(event_csrmv_taichi_grad(weight, indices, indptr, vector, shape=shape, transpose=transpose)) - time9 = time.time() - - time10 = time.time() - result = jax.block_until_ready(event_csrmv_taichi_grad(weight, indices, indptr, vector, shape=shape, transpose=transpose)) - time11 = time.time() - - time12 = time.time() - result = jax.block_until_ready(event_csrmv_taichi_grad(weight, indices, indptr, vector, shape=shape, transpose=transpose)) - time13 = time.time() - - time14 = time.time() - result = jax.block_until_ready(event_csrmv_taichi_grad(weight, indices, indptr, vector, shape=shape, transpose=transpose)) - time15 = time.time() - - time16 = time.time() - result = jax.block_until_ready(event_csrmv_taichi_grad(weight, indices, indptr, vector, shape=shape, transpose=transpose)) - time17 = time.time() - - time18 = time.time() - result = jax.block_until_ready(event_csrmv_taichi_grad(weight, indices, indptr, vector, shape=shape, transpose=transpose)) - time19 = time.time() - - - result = jax.block_until_ready(event_csrmv_grad(weight, indices, indptr, vector, shape=shape, transpose=transpose)) - result = jax.block_until_ready(event_csrmv_grad(weight, indices, indptr, vector, shape=shape, transpose=transpose)) - result = jax.block_until_ready(event_csrmv_grad(weight, indices, indptr, vector, shape=shape, transpose=transpose)) - result = jax.block_until_ready(event_csrmv_grad(weight, indices, indptr, vector, shape=shape, transpose=transpose)) - result = jax.block_until_ready(event_csrmv_grad(weight, indices, indptr, vector, shape=shape, transpose=transpose)) - - time20 = time.time() - result = jax.block_until_ready(event_csrmv_grad(weight, indices, indptr, vector, shape=shape, transpose=transpose)) - time21 = time.time() - - time22 = time.time() - result = jax.block_until_ready(event_csrmv_grad(weight, indices, indptr, vector, shape=shape, transpose=transpose)) - time23 = time.time() - - time24 = time.time() - result = jax.block_until_ready(event_csrmv_grad(weight, indices, indptr, vector, shape=shape, transpose=transpose)) - time25 = time.time() - - time26 = time.time() - result = jax.block_until_ready(event_csrmv_grad(weight, indices, indptr, vector, shape=shape, transpose=transpose)) - time27 = time.time() - - time28 = time.time() - result = jax.block_until_ready(event_csrmv_grad(weight, indices, indptr, vector, shape=shape, transpose=transpose)) - time29 = time.time() - - time30 = time.time() - result = jax.block_until_ready(event_csrmv_grad(weight, indices, indptr, vector, shape=shape, transpose=transpose)) - time31 = time.time() - - time32 = time.time() - result = jax.block_until_ready(event_csrmv_grad(weight, indices, indptr, vector, shape=shape, transpose=transpose)) - time33 = time.time() - - time34 = time.time() - result = jax.block_until_ready(event_csrmv_grad(weight, indices, indptr, vector, shape=shape, transpose=transpose)) - time35 = time.time() - - time36 = time.time() - result = jax.block_until_ready(event_csrmv_grad(weight, indices, indptr, vector, shape=shape, transpose=transpose)) - time37 = time.time() - - time38 = time.time() - result = jax.block_until_ready(event_csrmv_grad(weight, indices, indptr, vector, shape=shape, transpose=transpose)) - time39 = time.time() - - taichi_aot_time1 = (time1 - time0) * 1000 - taichi_aot_time2 = (time3 - time2) * 1000 - taichi_aot_time3 = (time5 - time4) * 1000 - taichi_aot_time4 = (time7 - time6) * 1000 - taichi_aot_time5 = (time9 - time8) * 1000 - taichi_aot_time6 = (time11 - time10) * 1000 - taichi_aot_time7 = (time13 - time12) * 1000 - taichi_aot_time8 = (time15 - time14) * 1000 - taichi_aot_time9 = (time17 - time16) * 1000 - taichi_aot_time10 = (time19 - time18) * 1000 - brainpy_time1 = (time21 - time20) * 1000 - brainpy_time2 = (time23 - time22) * 1000 - brainpy_time3 = (time25 - time24) * 1000 - brainpy_time4 = (time27 - time26) * 1000 - brainpy_time5 = (time29 - time28) * 1000 - brainpy_time6 = (time31 - time30) * 1000 - brainpy_time7 = (time33 - time32) * 1000 - brainpy_time8 = (time35 - time34) * 1000 - brainpy_time9 = (time37 - time36) * 1000 - brainpy_time10 = (time39 - time38) * 1000 - print('shape: ', shape, 'values_type: ', values_type, 'events_type: ', events_type, 'transpose: ', transpose) - print('taichi_aot_1: ', taichi_aot_time1, 'ms') - print('taichi_aot_3: ', taichi_aot_time3, 'ms') - print('taichi_aot_5: ', taichi_aot_time5, 'ms') - print('taichi_aot_7: ', taichi_aot_time7, 'ms') - print('taichi_aot_9: ', taichi_aot_time9, 'ms') - print('brainpylib_1: ', brainpy_time1, 'ms') - print('brainpylib_3: ', brainpy_time3, 'ms') - print('brainpylib_5: ', brainpy_time5, 'ms') - print('brainpylib_7: ', brainpy_time7, 'ms') - print('brainpylib_9: ', brainpy_time9, 'ms') - - - return taichi_aot_time1, taichi_aot_time2, taichi_aot_time3, taichi_aot_time4, taichi_aot_time5,\ - taichi_aot_time6, taichi_aot_time7, taichi_aot_time8, taichi_aot_time9, taichi_aot_time10,\ - brainpy_time1, brainpy_time2, brainpy_time3, brainpy_time4, brainpy_time5, \ - brainpy_time6, brainpy_time7, brainpy_time8, brainpy_time9, brainpy_time10 - -PATH = os.path.dirname(os.path.abspath(__file__)) - -# init dataframe -df = pd.DataFrame(columns=['s', 'p', 'shape[0]', 'shape[1]', 'backend', 'values type', 'events type', 'transpose', - 'taichi aot time1(ms)', 'taichi aot time2(ms)', 'taichi aot time3(ms)', 'taichi aot time4(ms)', 'taichi aot time5(ms)', - 'taichi aot time6(ms)', 'taichi aot time7(ms)', 'taichi aot time8(ms)', 'taichi aot time9(ms)', 'taichi aot time10(ms)', - 'brainpy time1(ms)', 'brainpy time2(ms)', 'brainpy time3(ms)', 'brainpy time4(ms)', 'brainpy time5(ms)', - 'brainpy time6(ms)', 'brainpy time7(ms)', 'brainpy time8(ms)', 'brainpy time9(ms)', 'brainpy time10(ms)']) - -### RECTANGULAR MATRIX -if (bm.get_platform() == 'cpu'): - for shape1 in shape: - for shape2 in shape: - for _values_type in values_type: - for _events_type in events_type: - for _transpose in transpose: - taichi_aot_time_1, taichi_aot_time_2, taichi_aot_time_3, taichi_aot_time_4, taichi_aot_time_5,\ - taichi_aot_time_6, taichi_aot_time_7, taichi_aot_time_8, taichi_aot_time_9, taichi_aot_time_10,\ - brainpy_time_1, brainpy_time_2, brainpy_time_3, brainpy_time_4, brainpy_time_5, \ - brainpy_time_6, brainpy_time_7, brainpy_time_8, brainpy_time_9, brainpy_time_10 = test_event_csrmv((shape1, shape2), _values_type, _events_type, _transpose) - # append to dataframe - df.loc[df.shape[0]] = [(shape1, shape2), 0.5 , shape1, shape2, 'cpu', _values_type, _events_type, _transpose, - taichi_aot_time_1, taichi_aot_time_2, taichi_aot_time_3, taichi_aot_time_4, taichi_aot_time_5, - taichi_aot_time_6, taichi_aot_time_7, taichi_aot_time_8, taichi_aot_time_9, taichi_aot_time_10, - brainpy_time_1, brainpy_time_2, brainpy_time_3, brainpy_time_4, brainpy_time_5, - brainpy_time_6, brainpy_time_7, brainpy_time_8, brainpy_time_9, brainpy_time_10] - df.to_csv(f'{PATH}/event_csrmv_grad_cpu.csv', index=False) - -if (bm.get_platform() == 'gpu'): - for shape1 in shape: - for shape2 in shape: - for _values_type in values_type: - for _events_type in events_type: - for _transpose in transpose: - taichi_aot_time_1, taichi_aot_time_2, taichi_aot_time_3, taichi_aot_time_4, taichi_aot_time_5,\ - taichi_aot_time_6, taichi_aot_time_7, taichi_aot_time_8, taichi_aot_time_9, taichi_aot_time_10,\ - brainpy_time_1, brainpy_time_2, brainpy_time_3, brainpy_time_4, brainpy_time_5, \ - brainpy_time_6, brainpy_time_7, brainpy_time_8, brainpy_time_9, brainpy_time_10 = test_event_csrmv((shape1, shape2), _values_type, _events_type, _transpose) - # append to dataframe - df.loc[df.shape[0]] = [(shape1, shape2), 0.5 , shape1, shape2, 'gpu', _values_type, _events_type, _transpose, - taichi_aot_time_1, taichi_aot_time_2, taichi_aot_time_3, taichi_aot_time_4, taichi_aot_time_5, - taichi_aot_time_6, taichi_aot_time_7, taichi_aot_time_8, taichi_aot_time_9, taichi_aot_time_10, - brainpy_time_1, brainpy_time_2, brainpy_time_3, brainpy_time_4, brainpy_time_5, - brainpy_time_6, brainpy_time_7, brainpy_time_8, brainpy_time_9, brainpy_time_10] - df.to_csv(f'{PATH}/event_csrmv_grad_gpu.csv', index=False) diff --git a/brainpy/_src/math/event/tests/test_event_csrmm.py b/brainpy/_src/math/event/tests/test_event_csrmm.py deleted file mode 100644 index 0df7fc8ff..000000000 --- a/brainpy/_src/math/event/tests/test_event_csrmm.py +++ /dev/null @@ -1,288 +0,0 @@ -# -*- coding: utf-8 -*- -import os -from functools import partial - -import jax -import pytest -from absl.testing import parameterized - -import brainpy as bp -import brainpy.math as bm - -# bm.set_platform('gpu') - -import platform -force_test = False # turn on to force test on windows locally -if platform.system() == 'Windows' and not force_test: - pytest.skip('skip windows', allow_module_level=True) - - -# Skip the test in Github Actions -IS_GITHUB_ACTIONS = os.getenv('IS_GITHUB_ACTIONS', '0') -if IS_GITHUB_ACTIONS == '1': - pytest.skip('Skip the test in Github Actions', allow_module_level=True) - -seed = 1234 - - -def sum_op(op): - def func(*args, **kwargs): - r = op(*args, **kwargs) - return r.sum() - - return func - - -class Test_csrmm(parameterized.TestCase): - def __init__(self, *args, platform='cpu', **kwargs): - super(Test_csrmm, self).__init__(*args, **kwargs) - - print() - bm.set_platform(platform) - - @parameterized.product( - transpose=[True, False], - shape=[(50, 50, 50), (100, 50, 100), (10, 1000, 10), (2, 2000, 2)], - homo_data=[-1., 1.] - ) - def test_homo(self, transpose, shape, homo_data): - print(f'test_homo: transpose: {transpose} shape = {shape}') - conn = bp.conn.FixedProb(0.3) - - # csr matrix - indices, indptr = conn(shape[1], shape[0]).require('pre2post') if transpose else conn(shape[0], - shape[1]).require( - 'pre2post') - indices = bm.as_jax(indices) - indptr = bm.as_jax(indptr) - # matrix - rng = bm.random.RandomState(seed=seed) - matrix = rng.random((shape[1], shape[2])) < 0.1 - matrix = bm.as_jax(matrix) - - heter_data = bm.ones(indices.shape) * homo_data - - dense = bm.sparse.csr_to_dense(heter_data, indices, indptr, - shape=(shape[1], shape[0]) if transpose else (shape[0], shape[1])) - - r1 = (dense.T @ matrix) if transpose else (dense @ matrix) - r2 = bm.event.csrmm(homo_data, indices, indptr, matrix, - shape=(shape[1], shape[0]) if transpose else (shape[0], shape[1]), transpose=transpose) - c = bm.allclose(r1, r2, equal_nan=True) - if not c: - print(r1 - r2) - self.assertTrue(c) - - bm.clear_buffer_memory() - - @parameterized.product( - transpose=[True, False], - shape=[(50, 50, 50), (100, 50, 100), (10, 1000, 10), (2, 2000, 2)], - homo_data=[-1., 1.] - ) - def test_homo_vmap(self, transpose, shape, homo_data): - print(f'test_homo_vmap: transpose: {transpose} shape = {shape}') - conn = bp.conn.FixedProb(0.3) - - # csr matrix - indices, indptr = conn(shape[1], shape[0]).require('pre2post') if transpose else conn(shape[0], - shape[1]).require( - 'pre2post') - indices = bm.as_jax(indices) - indptr = bm.as_jax(indptr) - # matrix - rng = bm.random.RandomState(seed=seed) - matrix = rng.random((shape[1], shape[2])) < 0.1 - matrix = bm.as_jax(matrix) - - # vmap 'data' - f1 = jax.vmap(partial(bm.sparse.csrmm, indices=indices, indptr=indptr, matrix=matrix, - shape=(shape[1], shape[0]) if transpose else (shape[0], shape[1]), transpose=transpose)) - f2 = jax.vmap(partial(bm.event.csrmm, indices=indices, indptr=indptr, matrix=matrix, - shape=(shape[1], shape[0]) if transpose else (shape[0], shape[1]), transpose=transpose)) - vmap_data = bm.as_jax([homo_data] * 10) - heter_data = bm.ones((10, indices.shape[0])) * homo_data - r1 = f1(heter_data) - r2 = f2(vmap_data) - self.assertTrue(bm.allclose(r1, r2)) - - # vmap 'events' - heter_data = bm.ones(indices.shape) * homo_data - f3 = jax.vmap(partial(bm.sparse.csrmm, heter_data, indices, indptr, - shape=(shape[1], shape[0]) if transpose else (shape[0], shape[1]), transpose=transpose)) - f4 = jax.vmap(partial(bm.event.csrmm, homo_data, indices, indptr, - shape=(shape[1], shape[0]) if transpose else (shape[0], shape[1]), transpose=transpose)) - matrix = bm.as_jax(rng.random((10, shape[1], shape[2])) < 0.1) - r3 = f3(matrix) - r4 = f4(matrix) - self.assertTrue(bm.allclose(r3, r4)) - - bm.clear_buffer_memory() - - @parameterized.product( - transpose=[True, False], - shape=[(50, 50, 50), (100, 50, 100), (10, 1000, 10), (2, 2000, 2)], - homo_data=[-1., 1.] - ) - def test_homo_grad(self, transpose, shape, homo_data): - print(f'test_homo_grad: transpose: {transpose} shape = {shape}') - rng = bm.random.RandomState(seed=seed) - conn = bp.conn.FixedProb(0.3) - - # csr matrix - indices, indptr = conn(shape[1], shape[0]).require('pre2post') if transpose else conn(shape[0], - shape[1]).require( - 'pre2post') - indices = bm.as_jax(indices) - indptr = bm.as_jax(indptr) - dense = bm.sparse.csr_to_dense(bm.ones(indices.shape).value, - indices, - indptr, - shape=(shape[1], shape[0]) if transpose else (shape[0], shape[1])) - - heter_data = bm.as_jax(rng.random((indices.shape))) - # matrix - matrix = rng.random((shape[1], shape[2])) < 0.1 - matrix = bm.as_jax(matrix) - - # grad data - dense_f1 = jax.grad(lambda a: (((dense.T * a) @ matrix).sum() - if transpose else - ((dense * a) @ matrix).sum()), - argnums=0) - r1 = dense_f1(homo_data) - r2 = jax.grad(sum_op(bm.event.csrmm))( - bm.asarray([homo_data]), indices, indptr, matrix, shape=(shape[1], shape[0]) if transpose else (shape[0], shape[1]), - transpose=transpose) - - self.assertTrue(bm.allclose(r1, r2)) - - # grad events matrix - dense_f2 = jax.grad(lambda m: (((dense.T * homo_data) @ m).sum() - if transpose else - ((dense * homo_data) @ m).sum()), - argnums=0) - r3 = dense_f2(matrix.astype(float)) - r4 = jax.grad(sum_op(bm.event.csrmm), argnums=3)( - bm.asarray([homo_data]), indices, indptr, matrix.astype(float), - shape=(shape[1], shape[0]) if transpose else (shape[0], shape[1]), transpose=transpose) - - self.assertTrue(bm.allclose(r3, r4)) - - bm.clear_buffer_memory() - - @parameterized.product( - transpose=[True, False], - shape=[(50, 50, 50), (100, 50, 100), (10, 1000, 10), (2, 2000, 2)], - ) - def test_heter(self, transpose, shape): - print(f'test_homo: transpose: {transpose} shape = {shape}') - conn = bp.conn.FixedProb(0.3) - - # csr matrix - indices, indptr = conn(shape[1], shape[0]).require('pre2post') if transpose else conn(shape[0], - shape[1]).require( - 'pre2post') - indices = bm.as_jax(indices) - indptr = bm.as_jax(indptr) - # matrix - rng = bm.random.RandomState(seed=seed) - matrix = rng.random((shape[1], shape[2])) < 0.1 - matrix = bm.as_jax(matrix) - - heter_data = bm.as_jax(rng.random(indices.shape)) - - r1 = bm.sparse.csrmm(heter_data, indices, indptr, matrix, - shape=(shape[1], shape[0]) if transpose else (shape[0], shape[1]), transpose=transpose) - r2 = bm.event.csrmm(heter_data, indices, indptr, matrix, - shape=(shape[1], shape[0]) if transpose else (shape[0], shape[1]), transpose=transpose) - - self.assertTrue(bm.allclose(r1, r2)) - - bm.clear_buffer_memory() - - @parameterized.product( - transpose=[True, False], - shape=[(50, 50, 50), (100, 50, 100), (10, 1000, 10), (2, 2000, 2)], - ) - def test_heter_vmap(self, transpose, shape): - print(f'test_homo_vmap: transpose: {transpose} shape = {shape}') - conn = bp.conn.FixedProb(0.3) - - # csr matrix - indices, indptr = conn(shape[1], shape[0]).require('pre2post') if transpose else conn(shape[0], - shape[1]).require( - 'pre2post') - indices = bm.as_jax(indices) - indptr = bm.as_jax(indptr) - # matrix - rng = bm.random.RandomState(seed=seed) - matrix = rng.random((shape[1], shape[2])) < 0.1 - matrix = bm.as_jax(matrix) - - # vmap 'data' - f1 = jax.vmap(partial(bm.sparse.csrmm, indices=indices, indptr=indptr, matrix=matrix, - shape=(shape[1], shape[0]) if transpose else (shape[0], shape[1]), transpose=transpose)) - f2 = jax.vmap(partial(bm.event.csrmm, indices=indices, indptr=indptr, matrix=matrix, - shape=(shape[1], shape[0]) if transpose else (shape[0], shape[1]), transpose=transpose)) - vmap_data = bm.as_jax(rng.random((10, indices.shape[0]))) - r1 = f1(vmap_data) - r2 = f2(vmap_data) - self.assertTrue(bm.allclose(r1, r2)) - - # vmap 'events' - heter_data = bm.ones(indices.shape) - f3 = jax.vmap(partial(bm.sparse.csrmm, heter_data, indices, indptr, - shape=(shape[1], shape[0]) if transpose else (shape[0], shape[1]), transpose=transpose)) - f4 = jax.vmap(partial(bm.event.csrmm, heter_data, indices, indptr, - shape=(shape[1], shape[0]) if transpose else (shape[0], shape[1]), transpose=transpose)) - matrix = bm.as_jax(rng.random((10, shape[1], shape[2])) < 0.1) - r3 = f3(matrix) - r4 = f4(matrix) - self.assertTrue(bm.allclose(r3, r4)) - - @parameterized.product( - transpose=[True, False], - shape=[(50, 50, 50), (100, 50, 100), (10, 1000, 10), (2, 2000, 2)], - ) - def test_heter_grad(self, transpose, shape): - print(f'test_homo_grad: transpose: {transpose} shape = {shape}') - rng = bm.random.RandomState(seed=seed) - conn = bp.conn.FixedProb(0.3) - - # csr matrix - indices, indptr = conn(shape[1], shape[0]).require('pre2post') if transpose else conn(shape[0], - shape[1]).require( - 'pre2post') - indices = bm.as_jax(indices) - indptr = bm.as_jax(indptr) - dense = bm.sparse.csr_to_dense(bm.ones(indices.shape).value, - indices, - indptr, - shape=(shape[1], shape[0]) if transpose else (shape[0], shape[1])) - - heter_data = bm.as_jax(rng.random((indices.shape))) - # matrix - matrix = rng.random((shape[1], shape[2])) < 0.1 - matrix = bm.as_jax(matrix) - - # grad data - r1 = jax.grad(sum_op(bm.sparse.csrmm))( - heter_data, indices, indptr, matrix, shape=(shape[1], shape[0]) if transpose else (shape[0], shape[1]), - transpose=transpose) - r2 = jax.grad(sum_op(bm.event.csrmm))( - heter_data, indices, indptr, matrix, shape=(shape[1], shape[0]) if transpose else (shape[0], shape[1]), - transpose=transpose) - self.assertTrue(bm.allclose(r1, r2)) - - # grad events matrix - r3 = jax.grad(sum_op(bm.sparse.csrmm), argnums=3)( - heter_data, indices, indptr, matrix.astype(float), - shape=(shape[1], shape[0]) if transpose else (shape[0], shape[1]), transpose=transpose) - r4 = jax.grad(sum_op(bm.event.csrmm), argnums=3)( - heter_data, indices, indptr, matrix.astype(float), - shape=(shape[1], shape[0]) if transpose else (shape[0], shape[1]), transpose=transpose) - - self.assertTrue(bm.allclose(r3, r4)) - - bm.clear_buffer_memory() diff --git a/brainpy/_src/math/event/tests/test_event_csrmv.py b/brainpy/_src/math/event/tests/test_event_csrmv.py deleted file mode 100644 index ea8303476..000000000 --- a/brainpy/_src/math/event/tests/test_event_csrmv.py +++ /dev/null @@ -1,236 +0,0 @@ -# -*- coding: utf-8 -*- -import os -from functools import partial - -import jax -import pytest -from absl.testing import parameterized - -import brainpy as bp -import brainpy.math as bm - -import platform -force_test = False # turn on to force test on windows locally -if platform.system() == 'Windows' and not force_test: - pytest.skip('skip windows', allow_module_level=True) - -# Skip the test in Github Actions -IS_GITHUB_ACTIONS = os.getenv('IS_GITHUB_ACTIONS', '0') -if IS_GITHUB_ACTIONS == '1': - pytest.skip('Skip the test in Github Actions', allow_module_level=True) - -seed = 1234 - - -def sum_op(op): - def func(*args, **kwargs): - r = op(*args, **kwargs) - return r.sum() - - return func - - -class Test_event_csr_matvec_taichi(parameterized.TestCase): - def __init__(self, *args, platform='cpu', **kwargs): - super(Test_event_csr_matvec_taichi, self).__init__(*args, **kwargs) - - print() - bm.set_platform(platform) - - @parameterized.product( - transpose=[True, False], - shape=[(100, 200), (10, 1000)], - homo_data=[1.], - ) - def test_homo(self, transpose, shape, homo_data): - print(f'test_homo: shape = {shape}, transpose = {transpose}, homo_data = {homo_data}') - - homo_data = bm.asarray([homo_data]) - - rng = bm.random.RandomState(seed) - indices, indptr = bp.conn.FixedProb(0.4)(*shape).require('pre2post') - events = rng.random(shape[0] if transpose else shape[1]) < 0.1 - heter_data = bm.ones(indices.shape) * homo_data - - dense = bm.sparse.csr_to_dense(heter_data, indices, indptr, shape=shape) - r1 = (events @ dense) if transpose else (dense @ events) - r2 = bm.event.csrmv(homo_data, indices, indptr, events, shape=shape, transpose=transpose) - - assert (bm.allclose(r1, r2)) - - bm.clear_buffer_memory() - - @parameterized.product( - transpose=[True, False], - shape=[(100, 200), (10, 1000)], - homo_data=[1.], - ) - def test_homo_vmap(self, shape, transpose, homo_data): - print(f'test_homo_vamp: shape = {shape}, transpose = {transpose}, homo_data = {homo_data}') - - homo_data = bm.asarray([homo_data]) - - rng = bm.random.RandomState(seed) - indices, indptr = bp.conn.FixedProb(0.4)(*shape).require('pre2post') - - # vmap 'data' - events = bm.as_jax(rng.random(shape[0] if transpose else shape[1])) < 0.1 - f1 = jax.vmap(partial(bm.sparse.csrmv, indices=indices, indptr=indptr, vector=events, - shape=shape, transpose=transpose)) - f2 = jax.vmap(partial(bm.event.csrmv, indices=indices, indptr=indptr, events=events, - shape=shape, transpose=transpose)) - vmap_data = bm.as_jax([homo_data] * 10) - self.assertTrue(bm.allclose(f1(vmap_data), f2(vmap_data))) - - # vmap 'events' - f3 = jax.vmap(partial(bm.sparse.csrmv, homo_data, indices, indptr, - shape=shape, transpose=transpose)) - f4 = jax.vmap(partial(bm.event.csrmv, homo_data, indices, indptr, - shape=shape, transpose=transpose)) - vmap_data = bm.as_jax(rng.random((10, shape[0] if transpose else shape[1]))) < 0.1 - self.assertTrue(bm.allclose(f3(vmap_data), f4(vmap_data))) - - # vmap 'data' and 'events' - f5 = jax.vmap(lambda dd, ee: bm.sparse.csrmv(dd, indices, indptr, ee, shape=shape, transpose=transpose)) - f6 = jax.vmap(lambda dd, ee: bm.event.csrmv(dd, indices, indptr, ee, shape=shape, transpose=transpose)) - - vmap_data1 = bm.as_jax([homo_data] * 10) - vmap_data2 = bm.as_jax(rng.random((10, shape[0] if transpose else shape[1]))) < 0.2 - self.assertTrue(bm.allclose(f5(vmap_data1, vmap_data2), - f6(vmap_data1, vmap_data2))) - - bm.clear_buffer_memory() - - @parameterized.product( - transpose=[True, False], - shape=[(100, 200), (10, 1000)], - homo_data=[1.], - ) - def test_homo_grad(self, shape, transpose, homo_data): - print(f'test_homo_grad: shape = {shape}, transpose = {transpose}, homo_data = {homo_data}') - - homo_data = bm.asarray([homo_data]) - - rng = bm.random.RandomState(seed) - indices, indptr = bp.conn.FixedProb(0.4)(*shape).require('pre2post') - indices = bm.as_jax(indices) - indptr = bm.as_jax(indptr) - events = bm.as_jax(rng.random(shape[0] if transpose else shape[1])) < 0.1 - dense_conn = bm.sparse.csr_to_dense(bm.ones(indices.shape).value, indices, indptr, shape=shape) - - # grad 'data' - r1 = jax.grad(sum_op(bm.sparse.csrmv))(homo_data, indices, indptr, events, shape=shape, transpose=transpose) - r2 = jax.grad(sum_op(bm.event.csrmv))(homo_data, indices, indptr, events, shape=shape, transpose=transpose) - self.assertTrue(bm.allclose(r1, r2)) - - # grad 'events' - r3 = jax.grad(sum_op(bm.sparse.csrmv), argnums=3)(homo_data, indices, indptr, events.astype(float), shape=shape, - transpose=transpose) - r4 = jax.grad(sum_op(bm.event.csrmv), argnums=3)(homo_data, indices, indptr, events.astype(float), shape=shape, - transpose=transpose) - self.assertTrue(bm.allclose(r3, r4)) - - bm.clear_buffer_memory() - - @parameterized.product( - transpose=[True, False], - shape=[(100, 200), (10, 1000), ] - ) - def test_heter(self, shape, transpose): - print(f'test_heter: shape = {shape}, transpose = {transpose}') - rng = bm.random.RandomState(seed) - indices, indptr = bp.conn.FixedProb(0.4)(*shape).require('pre2post') - indices = bm.as_jax(indices) - indptr = bm.as_jax(indptr) - events = bm.as_jax(rng.random(shape[0] if transpose else shape[1])) < 0.1 - heter_data = bm.as_jax(rng.random(indices.shape)) - - r1 = bm.sparse.csrmv(heter_data, indices, indptr, events, - shape=shape, transpose=transpose) - r2 = bm.event.csrmv(heter_data, indices, indptr, events, - shape=shape, transpose=transpose) - - assert (bm.allclose(r1, r2)) - - bm.clear_buffer_memory() - - @parameterized.product( - transpose=[True, False], - shape=[(100, 200), (10, 1000)] - ) - def test_heter_vmap(self, shape, transpose): - print(f'test_heter_vamp: shape = {shape}, transpose = {transpose}') - - rng = bm.random.RandomState(seed) - indices, indptr = bp.conn.FixedProb(0.4)(*shape).require('pre2post') - indices = bm.as_jax(indices) - indptr = bm.as_jax(indptr) - - # vmap 'data' - events = bm.as_jax(rng.random(shape[0] if transpose else shape[1])) < 0.1 - f1 = jax.vmap(partial(bm.sparse.csrmv, indices=indices, indptr=indptr, vector=events, - shape=shape, transpose=transpose)) - f2 = jax.vmap(partial(bm.event.csrmv, indices=indices, indptr=indptr, events=events, - shape=shape, transpose=transpose)) - vmap_data = bm.as_jax(rng.random((10, indices.shape[0]))) - self.assertTrue(bm.allclose(f1(vmap_data), f2(vmap_data))) - - # vmap 'events' - data = bm.as_jax(rng.random(indices.shape)) - f3 = jax.vmap(partial(bm.sparse.csrmv, data, indices, indptr, - shape=shape, transpose=transpose)) - f4 = jax.vmap(partial(bm.event.csrmv, data, indices, indptr, - shape=shape, transpose=transpose)) - vmap_data = bm.as_jax(rng.random((10, shape[0] if transpose else shape[1]))) < 0.1 - self.assertTrue(bm.allclose(f3(vmap_data), f4(vmap_data))) - - # vmap 'data' and 'events' - f5 = jax.vmap(lambda dd, ee: bm.sparse.csrmv(dd, indices, indptr, ee, - shape=shape, transpose=transpose)) - f6 = jax.vmap(lambda dd, ee: bm.event.csrmv(dd, indices, indptr, ee, - shape=shape, transpose=transpose)) - vmap_data1 = bm.as_jax(rng.random((10, indices.shape[0]))) - vmap_data2 = bm.as_jax(rng.random((10, shape[0] if transpose else shape[1]))) < 0.2 - self.assertTrue(bm.allclose(f5(vmap_data1, vmap_data2), - f6(vmap_data1, vmap_data2))) - - bm.clear_buffer_memory() - - @parameterized.product( - transpose=[True, False], - shape=[(100, 200), (10, 1000)] - ) - def test_heter_grad(self, shape, transpose): - print(f'test_heter_grad: shape = {shape}, transpose = {transpose}') - - rng = bm.random.RandomState(seed) - indices, indptr = bp.conn.FixedProb(0.4)(*shape).require('pre2post') - indices = bm.as_jax(indices) - indptr = bm.as_jax(indptr) - events = rng.random(shape[0] if transpose else shape[1]) < 0.1 - events = bm.as_jax(events) - dense_conn = bm.sparse.csr_to_dense(bm.ones(indices.shape).value, indices, indptr, shape=shape) - - # grad 'data' - data = bm.as_jax(rng.random(indices.shape)) - r1 = jax.grad(sum_op(bm.sparse.csrmv))( - data, indices, indptr, events, shape=shape, transpose=transpose) - r2 = jax.grad(sum_op(bm.event.csrmv))( - data, indices, indptr, events, shape=shape, transpose=transpose) - self.assertTrue(bm.allclose(r1, r2)) - - # grad 'events' - r3 = jax.grad(sum_op(bm.sparse.csrmv), argnums=3)( - data, indices, indptr, events.astype(float), shape=shape, transpose=transpose) - r4 = jax.grad(sum_op(bm.event.csrmv), argnums=3)( - data, indices, indptr, events.astype(float), shape=shape, transpose=transpose) - self.assertTrue(bm.allclose(r3, r4)) - - r5 = jax.grad(sum_op(bm.sparse.csrmv), argnums=(0, 3))( - data, indices, indptr, events.astype(float), shape=shape, transpose=transpose) - r6 = jax.grad(sum_op(bm.event.csrmv), argnums=(0, 3))( - data, indices, indptr, events.astype(float), shape=shape, transpose=transpose) - self.assertTrue(bm.allclose(r5[0], r6[0])) - self.assertTrue(bm.allclose(r5[1], r6[1])) - - bm.clear_buffer_memory() diff --git a/brainpy/_src/math/jitconn/event_matvec.py b/brainpy/_src/math/jitconn/event_matvec.py index 80bba29b9..e4a33ce0c 100644 --- a/brainpy/_src/math/jitconn/event_matvec.py +++ b/brainpy/_src/math/jitconn/event_matvec.py @@ -4,17 +4,17 @@ import jax +from brainpy._src.dependency_check import import_braintaichi, raise_braintaichi_not_found from brainpy._src.math.jitconn.matvec import (mv_prob_homo, mv_prob_uniform, mv_prob_normal) -from brainpy._src.dependency_check import import_braintaichi, raise_braintaichi_not_found bti = import_braintaichi(error_if_not_found=False) __all__ = [ - 'event_mv_prob_homo', - 'event_mv_prob_uniform', - 'event_mv_prob_normal', + 'event_mv_prob_homo', + 'event_mv_prob_uniform', + 'event_mv_prob_normal', ] @@ -28,12 +28,12 @@ def event_mv_prob_homo( transpose: bool = False, outdim_parallel: bool = True, ) -> jax.Array: - if bti is None: - raise_braintaichi_not_found() - return bti.jitc_event_mv_prob_homo(events, weight, conn_prob, seed, - shape=shape, - transpose=transpose, - outdim_parallel=outdim_parallel) + if bti is None: + raise_braintaichi_not_found() + return bti.jitc_event_mv_prob_homo(events, weight, conn_prob, seed, + shape=shape, + transpose=transpose, + outdim_parallel=outdim_parallel) event_mv_prob_homo.__doc__ = mv_prob_homo.__doc__ @@ -50,10 +50,10 @@ def event_mv_prob_uniform( transpose: bool = False, outdim_parallel: bool = True, ) -> jax.Array: - if bti is None: - raise_braintaichi_not_found() - return bti.jitc_event_mv_prob_uniform(events, w_low, w_high, conn_prob, seed, shape=shape, - transpose=transpose, outdim_parallel=outdim_parallel) + if bti is None: + raise_braintaichi_not_found() + return bti.jitc_event_mv_prob_uniform(events, w_low, w_high, conn_prob, seed, shape=shape, + transpose=transpose, outdim_parallel=outdim_parallel) event_mv_prob_uniform.__doc__ = mv_prob_uniform.__doc__ @@ -70,10 +70,10 @@ def event_mv_prob_normal( transpose: bool = False, outdim_parallel: bool = True, ) -> jax.Array: - if bti is None: - raise_braintaichi_not_found() - return bti.jitc_event_mv_prob_normal(events, w_mu, w_sigma, conn_prob, seed, shape=shape, - transpose=transpose, outdim_parallel=outdim_parallel) + if bti is None: + raise_braintaichi_not_found() + return bti.jitc_event_mv_prob_normal(events, w_mu, w_sigma, conn_prob, seed, shape=shape, + transpose=transpose, outdim_parallel=outdim_parallel) event_mv_prob_normal.__doc__ = mv_prob_normal.__doc__ diff --git a/brainpy/_src/math/jitconn/matvec.py b/brainpy/_src/math/jitconn/matvec.py index be4b19d19..4481e6fd6 100644 --- a/brainpy/_src/math/jitconn/matvec.py +++ b/brainpy/_src/math/jitconn/matvec.py @@ -1,25 +1,20 @@ # -*- coding: utf-8 -*- -import numbers from typing import Tuple, Optional, Union import jax -import numpy as np + from brainpy._src.dependency_check import import_braintaichi, raise_braintaichi_not_found -from brainpy._src.math import defaults -from brainpy._src.math.interoperability import as_jax from brainpy._src.math.ndarray import Array -from brainpy.errors import PackageMissingError -from jax import numpy as jnp bti = import_braintaichi(error_if_not_found=False) __all__ = [ - 'mv_prob_homo', - 'mv_prob_uniform', - 'mv_prob_normal', - 'get_homo_weight_matrix', - 'get_uniform_weight_matrix', - 'get_normal_weight_matrix' + 'mv_prob_homo', + 'mv_prob_uniform', + 'mv_prob_normal', + 'get_homo_weight_matrix', + 'get_uniform_weight_matrix', + 'get_normal_weight_matrix' ] @@ -33,59 +28,59 @@ def mv_prob_homo( transpose: bool = False, outdim_parallel: bool = True, ) -> jax.Array: - r"""Perform the :math:`y=M@v` operation, - where :math:`M` is just-in-time randomly generated with a scalar `weight` at each position. + r"""Perform the :math:`y=M@v` operation, + where :math:`M` is just-in-time randomly generated with a scalar `weight` at each position. - This operator support ``jit()``, ``vmap()``, ``grad()`` and ``pmap()`` etc. transformations - on CPU and GPU devices. + This operator support ``jit()``, ``vmap()``, ``grad()`` and ``pmap()`` etc. transformations + on CPU and GPU devices. - .. warning:: + .. warning:: - This API may change in the future. + This API may change in the future. - In this operation, :math:`M` is the random matrix with a connection probability - `conn_prob`, and at each connection the value is the same scalar `weight`. + In this operation, :math:`M` is the random matrix with a connection probability + `conn_prob`, and at each connection the value is the same scalar `weight`. - When ``transpose=True``, we perform an operation of :math:`y=M^T@v`. + When ``transpose=True``, we perform an operation of :math:`y=M^T@v`. - .. note:: + .. note:: - Note that the just-in-time generated :math:`M` (`transpose=False`) is - different from the generated :math:`M^T` (`transpose=True`). + Note that the just-in-time generated :math:`M` (`transpose=False`) is + different from the generated :math:`M^T` (`transpose=True`). - If you pursue the same :math:`M` and :math:`M^T` when performing the just-in-time - matrix generation, you should set ``outdim_parallel=True``, with the sacrifice of - the speed compared with ``outdim_parallel=False``. + If you pursue the same :math:`M` and :math:`M^T` when performing the just-in-time + matrix generation, you should set ``outdim_parallel=True``, with the sacrifice of + the speed compared with ``outdim_parallel=False``. - Parameters - ---------- - vector: Array, ndarray - The vector. - weight: float - The value of the random matrix. - conn_prob: float - The connection probability. - shape: tuple of int - The matrix shape. - seed: int - The random number generation seed. - transpose: bool - Transpose the random matrix or not. - outdim_parallel: bool - Perform the parallel random generations along the out dimension or not. - It can be used to set the just-in-time generated :math:M^T: is the same - as the just-in-time generated :math:`M` when ``transpose=True``. + Parameters + ---------- + vector: Array, ndarray + The vector. + weight: float + The value of the random matrix. + conn_prob: float + The connection probability. + shape: tuple of int + The matrix shape. + seed: int + The random number generation seed. + transpose: bool + Transpose the random matrix or not. + outdim_parallel: bool + Perform the parallel random generations along the out dimension or not. + It can be used to set the just-in-time generated :math:M^T: is the same + as the just-in-time generated :math:`M` when ``transpose=True``. - Returns - ------- - out: Array, ndarray - The output of :math:`y = M @ v`. - """ - if bti is None: - raise_braintaichi_not_found() + Returns + ------- + out: Array, ndarray + The output of :math:`y = M @ v`. + """ + if bti is None: + raise_braintaichi_not_found() - return bti.jitc_mv_prob_homo(vector, weight, conn_prob, seed, shape=shape, - transpose=transpose, outdim_parallel=outdim_parallel) + return bti.jitc_mv_prob_homo(vector, weight, conn_prob, seed, shape=shape, + transpose=transpose, outdim_parallel=outdim_parallel) def mv_prob_uniform( @@ -99,61 +94,61 @@ def mv_prob_uniform( transpose: bool = False, outdim_parallel: bool = True, ) -> jax.Array: - r"""Perform the :math:`y=M@v` operation, - where :math:`M` is just-in-time randomly generated with a uniform distribution for its value. - - This operator support ``jit()``, ``vmap()``, ``grad()`` and ``pmap()`` etc. transformations - on CPU and GPU devices. - - .. warning:: - - This API may change in the future. - - In this operation, :math:`M` is the random matrix with a connection probability - `conn_prob`, and at each connection the value is the same scalar `weight`. - - When ``transpose=True``, we perform an operation of :math:`y=M^T@v`. - - .. note:: - - Note that the just-in-time generated :math:`M` (`transpose=False`) is - different from the generated :math:`M^T` (`transpose=True`). - - If you pursue the same :math:`M` and :math:`M^T` when performing the just-in-time - matrix generation, you should set ``outdim_parallel=True``, with the sacrifice of - the speed compared with ``outdim_parallel=False``. - - Parameters - ---------- - vector: Array, ndarray - The vector. - w_low: float - Lower boundary of the output interval. - w_high: float - Upper boundary of the output interval. - conn_prob: float - The connection probability. - shape: tuple of int - The matrix shape. - seed: int - The random number generation seed. - transpose: bool - Transpose the random matrix or not. - outdim_parallel: bool - Perform the parallel random generations along the out dimension or not. - It can be used to set the just-in-time generated :math:M^T: is the same - as the just-in-time generated :math:`M` when ``transpose=True``. - - Returns - ------- - out: Array, ndarray - The output of :math:`y = M @ v`. - """ - if bti is None: - raise_braintaichi_not_found() - - return bti.jitc_mv_prob_uniform(vector, w_low, w_high, conn_prob, seed, shape=shape, - transpose=transpose, outdim_parallel=outdim_parallel) + r"""Perform the :math:`y=M@v` operation, + where :math:`M` is just-in-time randomly generated with a uniform distribution for its value. + + This operator support ``jit()``, ``vmap()``, ``grad()`` and ``pmap()`` etc. transformations + on CPU and GPU devices. + + .. warning:: + + This API may change in the future. + + In this operation, :math:`M` is the random matrix with a connection probability + `conn_prob`, and at each connection the value is the same scalar `weight`. + + When ``transpose=True``, we perform an operation of :math:`y=M^T@v`. + + .. note:: + + Note that the just-in-time generated :math:`M` (`transpose=False`) is + different from the generated :math:`M^T` (`transpose=True`). + + If you pursue the same :math:`M` and :math:`M^T` when performing the just-in-time + matrix generation, you should set ``outdim_parallel=True``, with the sacrifice of + the speed compared with ``outdim_parallel=False``. + + Parameters + ---------- + vector: Array, ndarray + The vector. + w_low: float + Lower boundary of the output interval. + w_high: float + Upper boundary of the output interval. + conn_prob: float + The connection probability. + shape: tuple of int + The matrix shape. + seed: int + The random number generation seed. + transpose: bool + Transpose the random matrix or not. + outdim_parallel: bool + Perform the parallel random generations along the out dimension or not. + It can be used to set the just-in-time generated :math:M^T: is the same + as the just-in-time generated :math:`M` when ``transpose=True``. + + Returns + ------- + out: Array, ndarray + The output of :math:`y = M @ v`. + """ + if bti is None: + raise_braintaichi_not_found() + + return bti.jitc_mv_prob_uniform(vector, w_low, w_high, conn_prob, seed, shape=shape, + transpose=transpose, outdim_parallel=outdim_parallel) def mv_prob_normal( @@ -167,60 +162,60 @@ def mv_prob_normal( transpose: bool = False, outdim_parallel: bool = True, ) -> jax.Array: - r"""Perform the :math:`y=M@v` operation, - where :math:`M` is just-in-time randomly generated with a normal distribution for its value. - - This operator support ``jit()``, ``vmap()``, ``grad()`` and ``pmap()`` etc. transformations - on CPU and GPU devices. - - .. warning:: - - This API may change in the future. - - In this operation, :math:`M` is the random matrix with a connection probability - `conn_prob`, and at each connection the value is the same scalar `weight`. - - When ``transpose=True``, we perform an operation of :math:`y=M^T@v`. - - .. note:: - - Note that the just-in-time generated :math:`M` (`transpose=False`) is - different from the generated :math:`M^T` (`transpose=True`). - - If you pursue the same :math:`M` and :math:`M^T` when performing the just-in-time - matrix generation, you should set ``outdim_parallel=True``, with the sacrifice of - the speed compared with ``outdim_parallel=False``. - - Parameters - ---------- - vector: Array, ndarray - The vector. - w_mu: float - Mean (centre) of the distribution. - w_sigma: float - Standard deviation (spread or “width”) of the distribution. Must be non-negative. - conn_prob: float - The connection probability. - shape: tuple of int - The matrix shape. - seed: int - The random number generation seed. - transpose: bool - Transpose the random matrix or not. - outdim_parallel: bool - Perform the parallel random generations along the out dimension or not. - It can be used to set the just-in-time generated :math:M^T: is the same - as the just-in-time generated :math:`M` when ``transpose=True``. - - Returns - ------- - out: Array, ndarray - The output of :math:`y = M @ v`. - """ - if bti is None: - raise_braintaichi_not_found() - return bti.jitc_mv_prob_normal(vector, w_mu, w_sigma, conn_prob, seed, shape=shape, - transpose=transpose, outdim_parallel=outdim_parallel) + r"""Perform the :math:`y=M@v` operation, + where :math:`M` is just-in-time randomly generated with a normal distribution for its value. + + This operator support ``jit()``, ``vmap()``, ``grad()`` and ``pmap()`` etc. transformations + on CPU and GPU devices. + + .. warning:: + + This API may change in the future. + + In this operation, :math:`M` is the random matrix with a connection probability + `conn_prob`, and at each connection the value is the same scalar `weight`. + + When ``transpose=True``, we perform an operation of :math:`y=M^T@v`. + + .. note:: + + Note that the just-in-time generated :math:`M` (`transpose=False`) is + different from the generated :math:`M^T` (`transpose=True`). + + If you pursue the same :math:`M` and :math:`M^T` when performing the just-in-time + matrix generation, you should set ``outdim_parallel=True``, with the sacrifice of + the speed compared with ``outdim_parallel=False``. + + Parameters + ---------- + vector: Array, ndarray + The vector. + w_mu: float + Mean (centre) of the distribution. + w_sigma: float + Standard deviation (spread or “width”) of the distribution. Must be non-negative. + conn_prob: float + The connection probability. + shape: tuple of int + The matrix shape. + seed: int + The random number generation seed. + transpose: bool + Transpose the random matrix or not. + outdim_parallel: bool + Perform the parallel random generations along the out dimension or not. + It can be used to set the just-in-time generated :math:M^T: is the same + as the just-in-time generated :math:`M` when ``transpose=True``. + + Returns + ------- + out: Array, ndarray + The output of :math:`y = M @ v`. + """ + if bti is None: + raise_braintaichi_not_found() + return bti.jitc_mv_prob_normal(vector, w_mu, w_sigma, conn_prob, seed, shape=shape, + transpose=transpose, outdim_parallel=outdim_parallel) def get_homo_weight_matrix( @@ -232,31 +227,32 @@ def get_homo_weight_matrix( transpose: bool = False, outdim_parallel: bool = True, ) -> jax.Array: - r"""Get the connection matrix :math:`M` with a connection probability `conn_prob`. - - Parameters - ---------- - conn_prob: float - The connection probability. - shape: tuple of int - The matrix shape. - seed: int - The random number generation seed. - transpose: bool - Transpose the random matrix or not. - outdim_parallel: bool - Perform the parallel random generations along the out dimension or not. - It can be used to set the just-in-time generated :math:M^T: is the same - as the just-in-time generated :math:`M` when ``transpose=True``. - - Returns - ------- - out: Array, ndarray - The connection matrix :math:`M`. - """ - if bti is None: - raise_braintaichi_not_found() - return bti.get_homo_weight_matrix(weight, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel) + r"""Get the connection matrix :math:`M` with a connection probability `conn_prob`. + + Parameters + ---------- + conn_prob: float + The connection probability. + shape: tuple of int + The matrix shape. + seed: int + The random number generation seed. + transpose: bool + Transpose the random matrix or not. + outdim_parallel: bool + Perform the parallel random generations along the out dimension or not. + It can be used to set the just-in-time generated :math:M^T: is the same + as the just-in-time generated :math:`M` when ``transpose=True``. + + Returns + ------- + out: Array, ndarray + The connection matrix :math:`M`. + """ + if bti is None: + raise_braintaichi_not_found() + return bti.get_homo_weight_matrix(weight, conn_prob, seed, shape=shape, transpose=transpose, + outdim_parallel=outdim_parallel) def get_uniform_weight_matrix( @@ -269,36 +265,36 @@ def get_uniform_weight_matrix( transpose: bool = False, outdim_parallel: bool = True, ) -> jax.Array: - r"""Get the weight matrix :math:`M` with a uniform distribution for its value. - - Parameters - ---------- - w_low: float - Lower boundary of the output interval. - w_high: float - Upper boundary of the output interval. - conn_prob: float - The connection probability. - shape: tuple of int - The matrix shape. - seed: int - The random number generation seed. - transpose: bool - Transpose the random matrix or not. - outdim_parallel: bool - Perform the parallel random generations along the out dimension or not. - It can be used to set the just-in-time generated :math:M^T: is the same - as the just-in-time generated :math:`M` when ``transpose=True``. - - Returns - ------- - out: Array, ndarray - The weight matrix :math:`M`. - """ - if bti is None: - raise_braintaichi_not_found() - return bti.get_uniform_weight_matrix(w_low, w_high, conn_prob, seed, shape=shape, - transpose=transpose, outdim_parallel=outdim_parallel) + r"""Get the weight matrix :math:`M` with a uniform distribution for its value. + + Parameters + ---------- + w_low: float + Lower boundary of the output interval. + w_high: float + Upper boundary of the output interval. + conn_prob: float + The connection probability. + shape: tuple of int + The matrix shape. + seed: int + The random number generation seed. + transpose: bool + Transpose the random matrix or not. + outdim_parallel: bool + Perform the parallel random generations along the out dimension or not. + It can be used to set the just-in-time generated :math:M^T: is the same + as the just-in-time generated :math:`M` when ``transpose=True``. + + Returns + ------- + out: Array, ndarray + The weight matrix :math:`M`. + """ + if bti is None: + raise_braintaichi_not_found() + return bti.get_uniform_weight_matrix(w_low, w_high, conn_prob, seed, shape=shape, + transpose=transpose, outdim_parallel=outdim_parallel) def get_normal_weight_matrix( @@ -311,33 +307,32 @@ def get_normal_weight_matrix( transpose: bool = False, outdim_parallel: bool = True, ) -> jax.Array: - r"""Get the weight matrix :math:`M` with a normal distribution for its value. - - Parameters - ---------- - w_mu: float - Mean (centre) of the distribution. - w_sigma: float - Standard deviation (spread or “width”) of the distribution. Must be non-negative. - shape: tuple of int - The matrix shape. - seed: int - The random number generation seed. - transpose: bool - Transpose the random matrix or not. - outdim_parallel: bool - Perform the parallel random generations along the out dimension or not. - It can be used to set the just-in-time generated :math:M^T: is the same - as the just-in-time generated :math:`M` when ``transpose=True``. - - Returns - ------- - out: Array, ndarray - The weight matrix :math:`M`. - """ - if bti is None: - raise_braintaichi_not_found() - return bti.get_normal_weight_matrix(w_mu, w_sigma, conn_prob, seed, - shape=shape, - transpose=transpose, outdim_parallel=outdim_parallel) - + r"""Get the weight matrix :math:`M` with a normal distribution for its value. + + Parameters + ---------- + w_mu: float + Mean (centre) of the distribution. + w_sigma: float + Standard deviation (spread or “width”) of the distribution. Must be non-negative. + shape: tuple of int + The matrix shape. + seed: int + The random number generation seed. + transpose: bool + Transpose the random matrix or not. + outdim_parallel: bool + Perform the parallel random generations along the out dimension or not. + It can be used to set the just-in-time generated :math:M^T: is the same + as the just-in-time generated :math:`M` when ``transpose=True``. + + Returns + ------- + out: Array, ndarray + The weight matrix :math:`M`. + """ + if bti is None: + raise_braintaichi_not_found() + return bti.get_normal_weight_matrix(w_mu, w_sigma, conn_prob, seed, + shape=shape, + transpose=transpose, outdim_parallel=outdim_parallel) diff --git a/brainpy/_src/math/jitconn/tests/event_matvec_jitconn_performance.py b/brainpy/_src/math/jitconn/tests/event_matvec_jitconn_performance.py deleted file mode 100644 index 2c1ca7110..000000000 --- a/brainpy/_src/math/jitconn/tests/event_matvec_jitconn_performance.py +++ /dev/null @@ -1,245 +0,0 @@ -from time import time - -from jax import jit - -import brainpy as bp -import brainpy.math as bm - - -def compare_sparse_ops(platform='cpu'): - """ - - GPU - --- - shape = (1000, 1000), prob = 0.1, transpose = True - csr sparse 0.09568500518798828 s - jit conn 0.12936949729919434 s - - shape = (1000, 1000), prob = 0.1, transpose = False - csr sparse 0.09957313537597656 s - jit conn 0.1456453800201416 s - - shape = (1000, 1000), prob = 0.2, transpose = True - csr sparse 0.1014559268951416 s - jit conn 0.16193556785583496 s - - shape = (1000, 1000), prob = 0.2, transpose = False - csr sparse 0.10938715934753418 s - jit conn 0.14464354515075684 s - - shape = (1000, 1000), prob = 0.4, transpose = True - csr sparse 0.14374589920043945 s - jit conn 0.1551048755645752 s - - shape = (1000, 1000), prob = 0.4, transpose = False - csr sparse 0.14356279373168945 s - jit conn 0.15198969841003418 s - - shape = (1000, 1000), prob = 0.6, transpose = True - csr sparse 0.1429135799407959 s - jit conn 0.15459179878234863 s - - shape = (1000, 1000), prob = 0.6, transpose = False - csr sparse 0.14870882034301758 s - jit conn 0.15899157524108887 s - - shape = (1000, 1000), prob = 0.8, transpose = True - csr sparse 0.1489548683166504 s - jit conn 0.1636965274810791 s - - shape = (1000, 1000), prob = 0.8, transpose = False - csr sparse 0.09073925018310547 s - jit conn 0.17296433448791504 s - - shape = (1000, 10000), prob = 0.1, transpose = True - csr sparse 0.14572954177856445 s - jit conn 0.15570378303527832 s - - shape = (1000, 10000), prob = 0.1, transpose = False - csr sparse 0.14201974868774414 s - jit conn 0.2694075107574463 s - - shape = (1000, 10000), prob = 0.2, transpose = True - csr sparse 0.1480388641357422 s - jit conn 0.14784669876098633 s - - shape = (1000, 10000), prob = 0.2, transpose = False - csr sparse 0.14451289176940918 s - jit conn 0.4144716262817383 s - - shape = (1000, 10000), prob = 0.4, transpose = True - csr sparse 0.14377927780151367 s - jit conn 0.15256381034851074 s - - shape = (1000, 10000), prob = 0.4, transpose = False - csr sparse 0.1487278938293457 s - jit conn 0.41004467010498047 s - - shape = (1000, 10000), prob = 0.6, transpose = True - csr sparse 0.1689896583557129 s - jit conn 0.18367314338684082 s - - shape = (1000, 10000), prob = 0.6, transpose = False - csr sparse 0.15153169631958008 s - jit conn 0.4159865379333496 s - - shape = (1000, 10000), prob = 0.8, transpose = True - csr sparse 0.15267014503479004 s - jit conn 0.16814088821411133 s - - shape = (1000, 10000), prob = 0.8, transpose = False - csr sparse 0.1320178508758545 s - jit conn 0.5114090442657471 s - - shape = (10000, 10000), prob = 0.1, transpose = True - csr sparse 0.15414834022521973 s - jit conn 0.15847539901733398 s - - shape = (10000, 10000), prob = 0.1, transpose = False - csr sparse 0.1557462215423584 s - jit conn 0.18897342681884766 s - - shape = (10000, 10000), prob = 0.2, transpose = True - csr sparse 0.28719663619995117 s - jit conn 0.3945181369781494 s - - shape = (10000, 10000), prob = 0.2, transpose = False - csr sparse 0.29045557975769043 s - jit conn 0.2662692070007324 s - - shape = (10000, 10000), prob = 0.4, transpose = True - csr sparse 0.26814866065979004 s - jit conn 0.41262269020080566 s - - shape = (10000, 10000), prob = 0.4, transpose = False - csr sparse 0.14010882377624512 s - jit conn 0.30821704864501953 s - - shape = (10000, 10000), prob = 0.6, transpose = True - csr sparse 0.34110474586486816 s - jit conn 0.44765257835388184 s - - shape = (10000, 10000), prob = 0.6, transpose = False - csr sparse 0.14516901969909668 s - jit conn 0.42423462867736816 s - - shape = (10000, 10000), prob = 0.8, transpose = True - csr sparse 0.38806986808776855 s - jit conn 0.5052323341369629 s - - shape = (10000, 10000), prob = 0.8, transpose = False - csr sparse 0.13016152381896973 s - jit conn 0.4791419506072998 s - - shape = (50000, 50000), prob = 0.1, transpose = True - csr sparse 0.1485145092010498 s - jit conn 0.6013796329498291 s - - shape = (50000, 50000), prob = 0.1, transpose = False - csr sparse 0.2520942687988281 s - jit conn 0.5886740684509277 s - - shape = (50000, 50000), prob = 0.2, transpose = True - csr sparse 0.41227173805236816 s - jit conn 1.0801291465759277 s - - shape = (50000, 50000), prob = 0.2, transpose = False - csr sparse 0.5962152481079102 s - jit conn 1.1053071022033691 s - - shape = (50000, 50000), prob = 0.4, transpose = True - Killed - """ - - bm.set_platform(platform) - - weight = 1. - seed = 1234 - - all_shapes = [ - (int(1e3), int(1e3)), - (int(1e3), int(1e4)), - (int(1e4), int(1e4)), - (int(5e4), int(5e4)), - (int(5e4), int(1e5)), - ] - - for shape in all_shapes: - for prob in [0.1, 0.2, 0.4, 0.6, 0.8]: - indices, indptr = bp.conn.FixedProb(prob, pre=shape[0], post=shape[1]).require('csr') - indices = bm.as_jax(indices) - indptr = bm.as_jax(indptr) - for transpose in [True, False]: - print(f'shape = {shape}, prob = {prob}, transpose = {transpose}') - f_sparse = jit(lambda e: bm.event.csrmv(weight, indices, indptr, e, - shape=shape, transpose=transpose)) - f_jitconn = jit(lambda e: bm.jitconn.event_mv_prob_homo( - e, weight, conn_prob=prob, shape=shape, seed=seed, transpose=transpose)) - - rng = bm.random.RandomState() - events = rng.random(shape[0] if transpose else shape[1]).value < prob - f_sparse(events).block_until_ready() - f_jitconn(events).block_until_ready() - - t0 = time() - for _ in range(100): - f_sparse(events).block_until_ready() - print(f'csr sparse {time() - t0} s') - - t0 = time() - for _ in range(100): - f_jitconn(events).block_until_ready() - print(f'jit conn {time() - t0} s') - - print() - bm.clear_buffer_memory() - - -def compare_jitconn_imp(platform='gpu'): - bm.set_platform(platform) - - weight = 1. - seed = 1234 - - all_shapes = [ - (int(1e3), int(1e3)), - (int(1e3), int(1e4)), - (int(1e4), int(1e4)), - (int(5e4), int(5e4)), - (int(5e4), int(1e5)), - (int(5e5), int(1e5)), - (int(5e5), int(5e5)), - ] - - for shape in all_shapes: - for prob in [0.01, 0.05, 0.1, 0.2, 0.4, 0.8]: - for transpose in [True, False]: - print(f'shape = {shape}, prob = {prob}, transpose = {transpose}') - # f1 = jit(lambda e: event_matvec_prob_conn_homo_weight_v1( - # e, weight, conn_prob=prob, shape=shape, seed=seed, transpose=transpose)) - f2 = jit(lambda e: bm.jitconn.event_mv_prob_homo( - e, weight, conn_prob=prob, shape=shape, seed=seed, transpose=transpose)) - - rng = bm.random.RandomState() - events = rng.random(shape[0] if transpose else shape[1]).value < prob - # f1(events).block_until_ready() - f2(events).block_until_ready() - - # t0 = time() - # for _ in range(100): - # f1(events).block_until_ready() - # print(f'event_matvec_v1 {time() - t0} s') - - t0 = time() - for _ in range(100): - f2(events).block_until_ready() - print(f'event_matvec_v2 {time() - t0} s') - print() - bm.clear_buffer_memory() - - -if __name__ == '__main__': - pass - # compare_where('cpu') - # compare_sparse_ops('gpu') - # compare_jitconn_imp('gpu') diff --git a/brainpy/_src/math/jitconn/tests/jitconn_event_matvec_taichi_VS_jitconn_event_matvec.py b/brainpy/_src/math/jitconn/tests/jitconn_event_matvec_taichi_VS_jitconn_event_matvec.py deleted file mode 100644 index 21a246650..000000000 --- a/brainpy/_src/math/jitconn/tests/jitconn_event_matvec_taichi_VS_jitconn_event_matvec.py +++ /dev/null @@ -1,573 +0,0 @@ -# from jax_taichi import jax_taichi_call - -import time -from functools import partial -import os - -import brainpy as bp -import brainpy.math as bm -import jax -import jax.numpy as jnp -import numpy as np -import pandas as pd -import taichi as ti - -bm.set_platform('cpu') - -seed = 1234 - -shape = [ - 1000, - 2500, - 5000, - 10000, - 25000, - 37500, - 50000 - ] -types = [ - 'homo', - 'uniform', - 'normal' - ] -transpose = [ - True, - False - ] -outdim_parallel = [ - True, - False, - ] -bool_event = [ - True, - False - ] -conn_prob = 0.05 -homo_data = 1. -w_low = 0. -w_high = 1. -w_mu = 0. -w_sigma = 0.1 - -ITERATION = 100 -if bm.get_platform() == 'cpu': - ITERATION = 10 - -print(bm.get_platform()) - -@partial(jax.jit, static_argnums=(4, 5, 6)) -def jitconn_event_matvec_homo_taichi(vector, homo_data, conn_prob, seed, shape, transpose, outdim_parallel): - r = 0 - for i in range(ITERATION): - r += bm.jitconn.event_mv_prob_homo_taichi(vector, homo_data, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel)[0] - return r - -@partial(jax.jit, static_argnums=(4, 5, 6)) -def jitconn_event_matvec_homo(vector, homo_data, conn_prob, seed, shape, transpose, outdim_parallel): - r = 0 - for i in range(ITERATION): - r += bm.jitconn.event_mv_prob_homo(vector, homo_data, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel)[0] - return r - -@partial(jax.jit, static_argnums=(5, 6, 7)) -def jitconn_event_matvec_uniform_taichi(vector, w_low, w_high, conn_prob, seed, shape, transpose, outdim_parallel): - r = 0 - for i in range(ITERATION): - r += bm.jitconn.event_mv_prob_uniform_taichi(vector, w_low, w_high, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel) - return r - -@partial(jax.jit, static_argnums=(5, 6, 7)) -def jitconn_event_matvec_uniform(vector, w_low, w_high, conn_prob, seed, shape, transpose, outdim_parallel): - r = 0 - for i in range(ITERATION): - r += bm.jitconn.event_mv_prob_uniform(vector, w_low, w_high, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel) - return r - -@partial(jax.jit, static_argnums=(5, 6, 7)) -def jitconn_event_matvec_normal_taichi(vector, w_mu, w_sigma, conn_prob, seed, shape, transpose, outdim_parallel): - r = 0 - for i in range(ITERATION): - r += bm.jitconn.event_mv_prob_normal_taichi(vector, w_mu, w_sigma, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel) - return r - -@partial(jax.jit, static_argnums=(5, 6, 7)) -def jitconn_event_matvec_normal(vector, w_mu, w_sigma, conn_prob, seed, shape, transpose, outdim_parallel): - r = 0 - for i in range(ITERATION): - r += bm.jitconn.event_mv_prob_normal(vector, w_mu, w_sigma, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel) - return r - - -def test_jitconn_matvec_homo(shape, transpose, outdim_parallel, bool_event): - rng = bm.random.RandomState(seed=seed) - events = bm.as_jax(rng.random(shape[0] if transpose else shape[1])) < 0.1 - if not bool_event: - events = events.astype(float) - - # groundtruth = bm.as_jax(events, dtype=float) @ bm.as_jax(dense) - - result = jax.block_until_ready(jitconn_event_matvec_homo_taichi(events, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) - result = jax.block_until_ready(jitconn_event_matvec_homo_taichi(events, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) - result = jax.block_until_ready(jitconn_event_matvec_homo_taichi(events, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) - result = jax.block_until_ready(jitconn_event_matvec_homo_taichi(events, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) - result = jax.block_until_ready(jitconn_event_matvec_homo_taichi(events, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) - - time0 = time.time() - result = jax.block_until_ready(jitconn_event_matvec_homo_taichi(events, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) - time1 = time.time() - - time2 = time.time() - result = jax.block_until_ready(jitconn_event_matvec_homo_taichi(events, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) - time3 = time.time() - - time4 = time.time() - result = jax.block_until_ready(jitconn_event_matvec_homo_taichi(events, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) - time5 = time.time() - - time6 = time.time() - result = jax.block_until_ready(jitconn_event_matvec_homo_taichi(events, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) - time7 = time.time() - - time8 = time.time() - result = jax.block_until_ready(jitconn_event_matvec_homo_taichi(events, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) - time9 = time.time() - - time10 = time.time() - result = jax.block_until_ready(jitconn_event_matvec_homo_taichi(events, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) - time11 = time.time() - - time12 = time.time() - result = jax.block_until_ready(jitconn_event_matvec_homo_taichi(events, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) - time13 = time.time() - - time14 = time.time() - result = jax.block_until_ready(jitconn_event_matvec_homo_taichi(events, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) - time15 = time.time() - - time16 = time.time() - result = jax.block_until_ready(jitconn_event_matvec_homo_taichi(events, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) - time17 = time.time() - - time18 = time.time() - result = jax.block_until_ready(jitconn_event_matvec_homo_taichi(events, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) - time19 = time.time() - - - result = jax.block_until_ready(jitconn_event_matvec_homo(events, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) - result = jax.block_until_ready(jitconn_event_matvec_homo(events, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) - result = jax.block_until_ready(jitconn_event_matvec_homo(events, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) - result = jax.block_until_ready(jitconn_event_matvec_homo(events, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) - result = jax.block_until_ready(jitconn_event_matvec_homo(events, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) - - time20 = time.time() - result = jax.block_until_ready(jitconn_event_matvec_homo(events, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) - time21 = time.time() - - time22 = time.time() - result = jax.block_until_ready(jitconn_event_matvec_homo(events, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) - time23 = time.time() - - time24 = time.time() - result = jax.block_until_ready(jitconn_event_matvec_homo(events, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) - time25 = time.time() - - time26 = time.time() - result = jax.block_until_ready(jitconn_event_matvec_homo(events, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) - time27 = time.time() - - time28 = time.time() - result = jax.block_until_ready(jitconn_event_matvec_homo(events, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) - time29 = time.time() - - time30 = time.time() - result = jax.block_until_ready(jitconn_event_matvec_homo(events, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) - time31 = time.time() - - time32 = time.time() - result = jax.block_until_ready(jitconn_event_matvec_homo(events, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) - time33 = time.time() - - time34 = time.time() - result = jax.block_until_ready(jitconn_event_matvec_homo(events, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) - time35 = time.time() - - time36 = time.time() - result = jax.block_until_ready(jitconn_event_matvec_homo(events, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) - time37 = time.time() - - time38 = time.time() - result = jax.block_until_ready(jitconn_event_matvec_homo(events, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) - time39 = time.time() - - taichi_aot_time1 = (time1 - time0) * 1000 - taichi_aot_time2 = (time3 - time2) * 1000 - taichi_aot_time3 = (time5 - time4) * 1000 - taichi_aot_time4 = (time7 - time6) * 1000 - taichi_aot_time5 = (time9 - time8) * 1000 - taichi_aot_time6 = (time11 - time10) * 1000 - taichi_aot_time7 = (time13 - time12) * 1000 - taichi_aot_time8 = (time15 - time14) * 1000 - taichi_aot_time9 = (time17 - time16) * 1000 - taichi_aot_time10 = (time19 - time18) * 1000 - brainpy_time1 = (time21 - time20) * 1000 - brainpy_time2 = (time23 - time22) * 1000 - brainpy_time3 = (time25 - time24) * 1000 - brainpy_time4 = (time27 - time26) * 1000 - brainpy_time5 = (time29 - time28) * 1000 - brainpy_time6 = (time31 - time30) * 1000 - brainpy_time7 = (time33 - time32) * 1000 - brainpy_time8 = (time35 - time34) * 1000 - brainpy_time9 = (time37 - time36) * 1000 - brainpy_time10 = (time39 - time38) * 1000 - print('taichi_aot_1: ', taichi_aot_time1, 'ms') - print('taichi_aot_3: ', taichi_aot_time3, 'ms') - print('taichi_aot_5: ', taichi_aot_time5, 'ms') - print('taichi_aot_7: ', taichi_aot_time7, 'ms') - print('taichi_aot_9: ', taichi_aot_time9, 'ms') - print('brainpylib_1: ', brainpy_time1, 'ms') - print('brainpylib_3: ', brainpy_time3, 'ms') - print('brainpylib_5: ', brainpy_time5, 'ms') - print('brainpylib_7: ', brainpy_time7, 'ms') - print('brainpylib_9: ', brainpy_time9, 'ms') - - - return taichi_aot_time1, taichi_aot_time2, taichi_aot_time3, taichi_aot_time4, taichi_aot_time5,\ - taichi_aot_time6, taichi_aot_time7, taichi_aot_time8, taichi_aot_time9, taichi_aot_time10,\ - brainpy_time1, brainpy_time2, brainpy_time3, brainpy_time4, brainpy_time5, \ - brainpy_time6, brainpy_time7, brainpy_time8, brainpy_time9, brainpy_time10 - -def test_jitconn_matvec_uniform(shape, transpose, outdim_parallel, bool_event): - rng = bm.random.RandomState(seed=seed) - events = bm.as_jax(rng.random(shape[0] if transpose else shape[1])) < 0.1 - if not bool_event: - events = events.astype(float) - - # groundtruth = bm.as_jax(vector, dtype=float) @ bm.as_jax(dense) - - result = jax.block_until_ready(jitconn_event_matvec_uniform_taichi(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) - result = jax.block_until_ready(jitconn_event_matvec_uniform_taichi(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) - result = jax.block_until_ready(jitconn_event_matvec_uniform_taichi(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) - result = jax.block_until_ready(jitconn_event_matvec_uniform_taichi(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) - result = jax.block_until_ready(jitconn_event_matvec_uniform_taichi(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) - - time0 = time.time() - result = jax.block_until_ready(jitconn_event_matvec_uniform_taichi(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) - time1 = time.time() - - time2 = time.time() - result = jax.block_until_ready(jitconn_event_matvec_uniform_taichi(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) - time3 = time.time() - - time4 = time.time() - result = jax.block_until_ready(jitconn_event_matvec_uniform_taichi(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) - time5 = time.time() - - time6 = time.time() - result = jax.block_until_ready(jitconn_event_matvec_uniform_taichi(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) - time7 = time.time() - - time8 = time.time() - result = jax.block_until_ready(jitconn_event_matvec_uniform_taichi(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) - time9 = time.time() - - time10 = time.time() - result = jax.block_until_ready(jitconn_event_matvec_uniform_taichi(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) - time11 = time.time() - - time12 = time.time() - result = jax.block_until_ready(jitconn_event_matvec_uniform_taichi(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) - time13 = time.time() - - time14 = time.time() - result = jax.block_until_ready(jitconn_event_matvec_uniform_taichi(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) - time15 = time.time() - - time16 = time.time() - result = jax.block_until_ready(jitconn_event_matvec_uniform_taichi(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) - time17 = time.time() - - time18 = time.time() - result = jax.block_until_ready(jitconn_event_matvec_uniform_taichi(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) - time19 = time.time() - - - result = jax.block_until_ready(jitconn_event_matvec_uniform(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) - result = jax.block_until_ready(jitconn_event_matvec_uniform(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) - result = jax.block_until_ready(jitconn_event_matvec_uniform(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) - result = jax.block_until_ready(jitconn_event_matvec_uniform(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) - result = jax.block_until_ready(jitconn_event_matvec_uniform(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) - - time20 = time.time() - result = jax.block_until_ready(jitconn_event_matvec_uniform(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) - time21 = time.time() - - time22 = time.time() - result = jax.block_until_ready(jitconn_event_matvec_uniform(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) - time23 = time.time() - - time24 = time.time() - result = jax.block_until_ready(jitconn_event_matvec_uniform(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) - time25 = time.time() - - time26 = time.time() - result = jax.block_until_ready(jitconn_event_matvec_uniform(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) - time27 = time.time() - - time28 = time.time() - result = jax.block_until_ready(jitconn_event_matvec_uniform(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) - time29 = time.time() - - time30 = time.time() - result = jax.block_until_ready(jitconn_event_matvec_uniform(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) - time31 = time.time() - - time32 = time.time() - result = jax.block_until_ready(jitconn_event_matvec_uniform(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) - time33 = time.time() - - time34 = time.time() - result = jax.block_until_ready(jitconn_event_matvec_uniform(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) - time35 = time.time() - - time36 = time.time() - result = jax.block_until_ready(jitconn_event_matvec_uniform(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) - time37 = time.time() - - time38 = time.time() - result = jax.block_until_ready(jitconn_event_matvec_uniform(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) - time39 = time.time() - - taichi_aot_time1 = (time1 - time0) * 1000 - taichi_aot_time2 = (time3 - time2) * 1000 - taichi_aot_time3 = (time5 - time4) * 1000 - taichi_aot_time4 = (time7 - time6) * 1000 - taichi_aot_time5 = (time9 - time8) * 1000 - taichi_aot_time6 = (time11 - time10) * 1000 - taichi_aot_time7 = (time13 - time12) * 1000 - taichi_aot_time8 = (time15 - time14) * 1000 - taichi_aot_time9 = (time17 - time16) * 1000 - taichi_aot_time10 = (time19 - time18) * 1000 - brainpy_time1 = (time21 - time20) * 1000 - brainpy_time2 = (time23 - time22) * 1000 - brainpy_time3 = (time25 - time24) * 1000 - brainpy_time4 = (time27 - time26) * 1000 - brainpy_time5 = (time29 - time28) * 1000 - brainpy_time6 = (time31 - time30) * 1000 - brainpy_time7 = (time33 - time32) * 1000 - brainpy_time8 = (time35 - time34) * 1000 - brainpy_time9 = (time37 - time36) * 1000 - brainpy_time10 = (time39 - time38) * 1000 - print('taichi_aot_1: ', taichi_aot_time1, 'ms') - print('taichi_aot_3: ', taichi_aot_time3, 'ms') - print('taichi_aot_5: ', taichi_aot_time5, 'ms') - print('taichi_aot_7: ', taichi_aot_time7, 'ms') - print('taichi_aot_9: ', taichi_aot_time9, 'ms') - print('brainpylib_1: ', brainpy_time1, 'ms') - print('brainpylib_3: ', brainpy_time3, 'ms') - print('brainpylib_5: ', brainpy_time5, 'ms') - print('brainpylib_7: ', brainpy_time7, 'ms') - print('brainpylib_9: ', brainpy_time9, 'ms') - - - return taichi_aot_time1, taichi_aot_time2, taichi_aot_time3, taichi_aot_time4, taichi_aot_time5,\ - taichi_aot_time6, taichi_aot_time7, taichi_aot_time8, taichi_aot_time9, taichi_aot_time10,\ - brainpy_time1, brainpy_time2, brainpy_time3, brainpy_time4, brainpy_time5, \ - brainpy_time6, brainpy_time7, brainpy_time8, brainpy_time9, brainpy_time10 - -def test_jitconn_matvec_normal(shape, transpose, outdim_parallel, bool_event): - rng = bm.random.RandomState(seed=seed) - events = bm.as_jax(rng.random(shape[0] if transpose else shape[1])) < 0.1 - if not bool_event: - events = events.astype(float) - # groundtruth = bm.as_jax(vector, dtype=float) @ bm.as_jax(dense) - - result = jax.block_until_ready(jitconn_event_matvec_normal_taichi(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) - result = jax.block_until_ready(jitconn_event_matvec_normal_taichi(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) - result = jax.block_until_ready(jitconn_event_matvec_normal_taichi(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) - result = jax.block_until_ready(jitconn_event_matvec_normal_taichi(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) - result = jax.block_until_ready(jitconn_event_matvec_normal_taichi(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) - - time0 = time.time() - result = jax.block_until_ready(jitconn_event_matvec_normal_taichi(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) - time1 = time.time() - - time2 = time.time() - result = jax.block_until_ready(jitconn_event_matvec_normal_taichi(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) - time3 = time.time() - - time4 = time.time() - result = jax.block_until_ready(jitconn_event_matvec_normal_taichi(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) - time5 = time.time() - - time6 = time.time() - result = jax.block_until_ready(jitconn_event_matvec_normal_taichi(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) - time7 = time.time() - - time8 = time.time() - result = jax.block_until_ready(jitconn_event_matvec_normal_taichi(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) - time9 = time.time() - - time10 = time.time() - result = jax.block_until_ready(jitconn_event_matvec_normal_taichi(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) - time11 = time.time() - - time12 = time.time() - result = jax.block_until_ready(jitconn_event_matvec_normal_taichi(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) - time13 = time.time() - - time14 = time.time() - result = jax.block_until_ready(jitconn_event_matvec_normal_taichi(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) - time15 = time.time() - - time16 = time.time() - result = jax.block_until_ready(jitconn_event_matvec_normal_taichi(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) - time17 = time.time() - - time18 = time.time() - result = jax.block_until_ready(jitconn_event_matvec_normal_taichi(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) - time19 = time.time() - - - result = jax.block_until_ready(jitconn_event_matvec_normal(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) - result = jax.block_until_ready(jitconn_event_matvec_normal(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) - result = jax.block_until_ready(jitconn_event_matvec_normal(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) - result = jax.block_until_ready(jitconn_event_matvec_normal(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) - result = jax.block_until_ready(jitconn_event_matvec_normal(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) - - time20 = time.time() - result = jax.block_until_ready(jitconn_event_matvec_normal(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) - time21 = time.time() - - time22 = time.time() - result = jax.block_until_ready(jitconn_event_matvec_normal(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) - time23 = time.time() - - time24 = time.time() - result = jax.block_until_ready(jitconn_event_matvec_normal(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) - time25 = time.time() - - time26 = time.time() - result = jax.block_until_ready(jitconn_event_matvec_normal(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) - time27 = time.time() - - time28 = time.time() - result = jax.block_until_ready(jitconn_event_matvec_normal(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) - time29 = time.time() - - time30 = time.time() - result = jax.block_until_ready(jitconn_event_matvec_normal(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) - time31 = time.time() - - time32 = time.time() - result = jax.block_until_ready(jitconn_event_matvec_normal(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) - time33 = time.time() - - time34 = time.time() - result = jax.block_until_ready(jitconn_event_matvec_normal(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) - time35 = time.time() - - time36 = time.time() - result = jax.block_until_ready(jitconn_event_matvec_normal(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) - time37 = time.time() - - time38 = time.time() - result = jax.block_until_ready(jitconn_event_matvec_normal(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) - time39 = time.time() - - taichi_aot_time1 = (time1 - time0) * 1000 - taichi_aot_time2 = (time3 - time2) * 1000 - taichi_aot_time3 = (time5 - time4) * 1000 - taichi_aot_time4 = (time7 - time6) * 1000 - taichi_aot_time5 = (time9 - time8) * 1000 - taichi_aot_time6 = (time11 - time10) * 1000 - taichi_aot_time7 = (time13 - time12) * 1000 - taichi_aot_time8 = (time15 - time14) * 1000 - taichi_aot_time9 = (time17 - time16) * 1000 - taichi_aot_time10 = (time19 - time18) * 1000 - brainpy_time1 = (time21 - time20) * 1000 - brainpy_time2 = (time23 - time22) * 1000 - brainpy_time3 = (time25 - time24) * 1000 - brainpy_time4 = (time27 - time26) * 1000 - brainpy_time5 = (time29 - time28) * 1000 - brainpy_time6 = (time31 - time30) * 1000 - brainpy_time7 = (time33 - time32) * 1000 - brainpy_time8 = (time35 - time34) * 1000 - brainpy_time9 = (time37 - time36) * 1000 - brainpy_time10 = (time39 - time38) * 1000 - print('taichi_aot_1: ', taichi_aot_time1, 'ms') - print('taichi_aot_3: ', taichi_aot_time3, 'ms') - print('taichi_aot_5: ', taichi_aot_time5, 'ms') - print('taichi_aot_7: ', taichi_aot_time7, 'ms') - print('taichi_aot_9: ', taichi_aot_time9, 'ms') - print('brainpylib_1: ', brainpy_time1, 'ms') - print('brainpylib_3: ', brainpy_time3, 'ms') - print('brainpylib_5: ', brainpy_time5, 'ms') - print('brainpylib_7: ', brainpy_time7, 'ms') - print('brainpylib_9: ', brainpy_time9, 'ms') - - - return taichi_aot_time1, taichi_aot_time2, taichi_aot_time3, taichi_aot_time4, taichi_aot_time5,\ - taichi_aot_time6, taichi_aot_time7, taichi_aot_time8, taichi_aot_time9, taichi_aot_time10,\ - brainpy_time1, brainpy_time2, brainpy_time3, brainpy_time4, brainpy_time5, \ - brainpy_time6, brainpy_time7, brainpy_time8, brainpy_time9, brainpy_time10 - - -def test_jitconn_matvec(shape, _type, transpose, outdim_parallel, bool_event): - print('shape: ', shape, ' type: ', _type, ' transpose: ', transpose, ' outdim_parallel: ', outdim_parallel) - if _type == 'homo': - return test_jitconn_matvec_homo(shape, transpose, outdim_parallel, bool_event) - elif _type == 'uniform': - return test_jitconn_matvec_uniform(shape, transpose, outdim_parallel, bool_event) - elif _type == 'normal': - return test_jitconn_matvec_normal(shape, transpose, outdim_parallel, bool_event) - else: - raise ValueError - - -PATH = os.path.dirname(os.path.abspath(__file__)) - -# init dataframe -df = pd.DataFrame(columns=['shape[0]', 'shape[1]', 'backend', 'type', 'transpose', 'outdim_parallel', 'bool_event', - 'taichi aot time1(ms)', 'taichi aot time2(ms)', 'taichi aot time3(ms)', 'taichi aot time4(ms)', 'taichi aot time5(ms)', - 'taichi aot time6(ms)', 'taichi aot time7(ms)', 'taichi aot time8(ms)', 'taichi aot time9(ms)', 'taichi aot time10(ms)', - 'brainpy time1(ms)', 'brainpy time2(ms)', 'brainpy time3(ms)', 'brainpy time4(ms)', 'brainpy time5(ms)', - 'brainpy time6(ms)', 'brainpy time7(ms)', 'brainpy time8(ms)', 'brainpy time9(ms)', 'brainpy time10(ms)']) - -### RECTANGULAR MATRIX -if (bm.get_platform() == 'cpu'): - for shape1 in shape: - for shape2 in shape: - for _type in types: - for _outdim_parallel in outdim_parallel: - for _transpose in transpose: - for _bool_event in bool_event: - taichi_aot_time_1, taichi_aot_time_2, taichi_aot_time_3, taichi_aot_time_4, taichi_aot_time_5,\ - taichi_aot_time_6, taichi_aot_time_7, taichi_aot_time_8, taichi_aot_time_9, taichi_aot_time_10,\ - brainpy_time_1, brainpy_time_2, brainpy_time_3, brainpy_time_4, brainpy_time_5, \ - brainpy_time_6, brainpy_time_7, brainpy_time_8, brainpy_time_9, brainpy_time_10 = test_jitconn_matvec((shape1, shape2), _type, _transpose, _outdim_parallel, _bool_event) - # append to dataframe - df.loc[df.shape[0]] = [shape1, shape2, 'cpu', _type, _transpose, _outdim_parallel, _bool_event, - taichi_aot_time_1, taichi_aot_time_2, taichi_aot_time_3, taichi_aot_time_4, taichi_aot_time_5, - taichi_aot_time_6, taichi_aot_time_7, taichi_aot_time_8, taichi_aot_time_9, taichi_aot_time_10, - brainpy_time_1, brainpy_time_2, brainpy_time_3, brainpy_time_4, brainpy_time_5, - brainpy_time_6, brainpy_time_7, brainpy_time_8, brainpy_time_9, brainpy_time_10] - df.to_csv(f'{PATH}/jitconn_event_matvec_cpu.csv', index=False) - -if (bm.get_platform() == 'gpu'): - for shape1 in shape: - for shape2 in shape: - for _type in types: - for _outdim_parallel in outdim_parallel: - for _transpose in transpose: - for _bool_event in bool_event: - taichi_aot_time_1, taichi_aot_time_2, taichi_aot_time_3, taichi_aot_time_4, taichi_aot_time_5,\ - taichi_aot_time_6, taichi_aot_time_7, taichi_aot_time_8, taichi_aot_time_9, taichi_aot_time_10,\ - brainpy_time_1, brainpy_time_2, brainpy_time_3, brainpy_time_4, brainpy_time_5, \ - brainpy_time_6, brainpy_time_7, brainpy_time_8, brainpy_time_9, brainpy_time_10 = test_jitconn_matvec((shape1, shape2), _type, _transpose, _outdim_parallel, _bool_event) - # append to dataframe - df.loc[df.shape[0]] = [shape1, shape2, 'gpu', _type, _transpose, _outdim_parallel, _bool_event, - taichi_aot_time_1, taichi_aot_time_2, taichi_aot_time_3, taichi_aot_time_4, taichi_aot_time_5, - taichi_aot_time_6, taichi_aot_time_7, taichi_aot_time_8, taichi_aot_time_9, taichi_aot_time_10, - brainpy_time_1, brainpy_time_2, brainpy_time_3, brainpy_time_4, brainpy_time_5, - brainpy_time_6, brainpy_time_7, brainpy_time_8, brainpy_time_9, brainpy_time_10] - df.to_csv(f'{PATH}/jitconn_event_matvec_gpu.csv', index=False) diff --git a/brainpy/_src/math/jitconn/tests/jitconn_event_matvec_taichi_VS_jitconn_event_matvec_grad.py b/brainpy/_src/math/jitconn/tests/jitconn_event_matvec_taichi_VS_jitconn_event_matvec_grad.py deleted file mode 100644 index ff4f01afc..000000000 --- a/brainpy/_src/math/jitconn/tests/jitconn_event_matvec_taichi_VS_jitconn_event_matvec_grad.py +++ /dev/null @@ -1,589 +0,0 @@ -# from jax_taichi import jax_taichi_call - -import time -from functools import partial -import os - -import brainpy as bp -import brainpy.math as bm -import jax -import jax.numpy as jnp -import numpy as np -import pandas as pd -import taichi as ti - -bm.set_platform('cpu') -# bm.disable_gpu_memory_preallocation() - -seed = 1234 - -shape = [ - 1000, - 2500, - 5000, - 10000, - 25000, - 37500, - 50000 - ] -types = [ - 'homo', - 'uniform', - 'normal' - ] -transpose = [ - True, - False - ] -outdim_parallel = [ - True, - False, - ] -bool_event = [ - True, - False - ] -conn_prob = 0.05 -homo_data = 1. -w_low = 0. -w_high = 1. -w_mu = 0. -w_sigma = 0.1 - -print(bm.get_platform()) - -def sum_op(op): - def func(*args, **kwargs): - r = op(*args, **kwargs)[0] - return r.sum() - - return func - -ITERATION = 100 -if bm.get_platform() == 'cpu': - ITERATION = 10 - -@partial(jax.jit, static_argnums=(4, 5, 6)) -def jitconn_event_matvec_homo_taichi_grad(vector, homo_data, conn_prob, seed, shape, transpose, outdim_parallel): - r = 0 - for i in range(ITERATION): - r +=jax.grad(sum_op(bm.jitconn.event_mv_prob_homo_taichi), argnums=0)( - vector.astype(float), homo_data, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel - ) - return r - -@partial(jax.jit, static_argnums=(4, 5, 6)) -def jitconn_event_matvec_homo_grad(vector, homo_data, conn_prob, seed, shape, transpose, outdim_parallel): - r = 0 - for i in range(ITERATION): - r += jax.grad(sum_op(bm.jitconn.event_mv_prob_homo), argnums=0)( - vector.astype(float), homo_data, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel - ) - return r - -@partial(jax.jit, static_argnums=(5, 6, 7)) -def jitconn_event_matvec_uniform_taichi_grad(vector, w_low, w_high, conn_prob, seed, shape, transpose, outdim_parallel): - r = 0 - for i in range(ITERATION): - r += jax.grad(sum_op(bm.jitconn.event_mv_prob_uniform_taichi), argnums=0)( - vector.astype(float), w_low, w_high, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel - ) - return r - -@partial(jax.jit, static_argnums=(5, 6, 7)) -def jitconn_event_matvec_uniform_grad(vector, w_low, w_high, conn_prob, seed, shape, transpose, outdim_parallel): - r = 0 - for i in range(ITERATION): - r += jax.grad(sum_op(bm.jitconn.event_mv_prob_uniform), argnums=0)( - vector.astype(float), w_low, w_high, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel - ) - return r - -@partial(jax.jit, static_argnums=(5, 6, 7)) -def jitconn_event_matvec_normal_taichi_grad(vector, w_mu, w_sigma, conn_prob, seed, shape, transpose, outdim_parallel): - r = 0 - for i in range(ITERATION): - r += jax.grad(sum_op(bm.jitconn.event_mv_prob_normal_taichi), argnums=0)( - vector.astype(float), w_mu, w_sigma, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel - ) - return r - -@partial(jax.jit, static_argnums=(5, 6, 7)) -def jitconn_event_matvec_normal_grad(vector, w_mu, w_sigma, conn_prob, seed, shape, transpose, outdim_parallel): - r = 0 - for i in range(ITERATION): - r += jax.grad(sum_op(bm.jitconn.event_mv_prob_normal), argnums=0)( - vector.astype(float), w_mu, w_sigma, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel - ) - return r - -def test_jitconn_matvec_homo(shape, transpose, outdim_parallel, bool_event): - rng = bm.random.RandomState(seed=seed) - events = bm.as_jax(rng.random(shape[0] if transpose else shape[1])) < 0.1 - if not bool_event: - events = events.astype(float) - # groundtruth = bm.as_jax(vector, dtype=float) @ bm.as_jax(dense) - - result = jax.block_until_ready(jitconn_event_matvec_homo_taichi_grad(events, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) - result = jax.block_until_ready(jitconn_event_matvec_homo_taichi_grad(events, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) - result = jax.block_until_ready(jitconn_event_matvec_homo_taichi_grad(events, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) - result = jax.block_until_ready(jitconn_event_matvec_homo_taichi_grad(events, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) - result = jax.block_until_ready(jitconn_event_matvec_homo_taichi_grad(events, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) - - time0 = time.time() - result = jax.block_until_ready(jitconn_event_matvec_homo_taichi_grad(events, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) - time1 = time.time() - - time2 = time.time() - result = jax.block_until_ready(jitconn_event_matvec_homo_taichi_grad(events, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) - time3 = time.time() - - time4 = time.time() - result = jax.block_until_ready(jitconn_event_matvec_homo_taichi_grad(events, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) - time5 = time.time() - - time6 = time.time() - result = jax.block_until_ready(jitconn_event_matvec_homo_taichi_grad(events, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) - time7 = time.time() - - time8 = time.time() - result = jax.block_until_ready(jitconn_event_matvec_homo_taichi_grad(events, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) - time9 = time.time() - - time10 = time.time() - result = jax.block_until_ready(jitconn_event_matvec_homo_taichi_grad(events, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) - time11 = time.time() - - time12 = time.time() - result = jax.block_until_ready(jitconn_event_matvec_homo_taichi_grad(events, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) - time13 = time.time() - - time14 = time.time() - result = jax.block_until_ready(jitconn_event_matvec_homo_taichi_grad(events, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) - time15 = time.time() - - time16 = time.time() - result = jax.block_until_ready(jitconn_event_matvec_homo_taichi_grad(events, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) - time17 = time.time() - - time18 = time.time() - result = jax.block_until_ready(jitconn_event_matvec_homo_taichi_grad(events, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) - time19 = time.time() - - - result = jax.block_until_ready(jitconn_event_matvec_homo_grad(events, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) - result = jax.block_until_ready(jitconn_event_matvec_homo_grad(events, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) - result = jax.block_until_ready(jitconn_event_matvec_homo_grad(events, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) - result = jax.block_until_ready(jitconn_event_matvec_homo_grad(events, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) - result = jax.block_until_ready(jitconn_event_matvec_homo_grad(events, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) - - time20 = time.time() - result = jax.block_until_ready(jitconn_event_matvec_homo_grad(events, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) - time21 = time.time() - - time22 = time.time() - result = jax.block_until_ready(jitconn_event_matvec_homo_grad(events, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) - time23 = time.time() - - time24 = time.time() - result = jax.block_until_ready(jitconn_event_matvec_homo_grad(events, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) - time25 = time.time() - - time26 = time.time() - result = jax.block_until_ready(jitconn_event_matvec_homo_grad(events, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) - time27 = time.time() - - time28 = time.time() - result = jax.block_until_ready(jitconn_event_matvec_homo_grad(events, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) - time29 = time.time() - - time30 = time.time() - result = jax.block_until_ready(jitconn_event_matvec_homo_grad(events, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) - time31 = time.time() - - time32 = time.time() - result = jax.block_until_ready(jitconn_event_matvec_homo_grad(events, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) - time33 = time.time() - - time34 = time.time() - result = jax.block_until_ready(jitconn_event_matvec_homo_grad(events, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) - time35 = time.time() - - time36 = time.time() - result = jax.block_until_ready(jitconn_event_matvec_homo_grad(events, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) - time37 = time.time() - - time38 = time.time() - result = jax.block_until_ready(jitconn_event_matvec_homo_grad(events, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) - time39 = time.time() - - taichi_aot_time1 = (time1 - time0) * 1000 - taichi_aot_time2 = (time3 - time2) * 1000 - taichi_aot_time3 = (time5 - time4) * 1000 - taichi_aot_time4 = (time7 - time6) * 1000 - taichi_aot_time5 = (time9 - time8) * 1000 - taichi_aot_time6 = (time11 - time10) * 1000 - taichi_aot_time7 = (time13 - time12) * 1000 - taichi_aot_time8 = (time15 - time14) * 1000 - taichi_aot_time9 = (time17 - time16) * 1000 - taichi_aot_time10 = (time19 - time18) * 1000 - brainpy_time1 = (time21 - time20) * 1000 - brainpy_time2 = (time23 - time22) * 1000 - brainpy_time3 = (time25 - time24) * 1000 - brainpy_time4 = (time27 - time26) * 1000 - brainpy_time5 = (time29 - time28) * 1000 - brainpy_time6 = (time31 - time30) * 1000 - brainpy_time7 = (time33 - time32) * 1000 - brainpy_time8 = (time35 - time34) * 1000 - brainpy_time9 = (time37 - time36) * 1000 - brainpy_time10 = (time39 - time38) * 1000 - print('taichi_aot_1: ', taichi_aot_time1, 'ms') - print('taichi_aot_3: ', taichi_aot_time3, 'ms') - print('taichi_aot_5: ', taichi_aot_time5, 'ms') - print('taichi_aot_7: ', taichi_aot_time7, 'ms') - print('taichi_aot_9: ', taichi_aot_time9, 'ms') - print('brainpylib_1: ', brainpy_time1, 'ms') - print('brainpylib_3: ', brainpy_time3, 'ms') - print('brainpylib_5: ', brainpy_time5, 'ms') - print('brainpylib_7: ', brainpy_time7, 'ms') - print('brainpylib_9: ', brainpy_time9, 'ms') - - - return taichi_aot_time1, taichi_aot_time2, taichi_aot_time3, taichi_aot_time4, taichi_aot_time5,\ - taichi_aot_time6, taichi_aot_time7, taichi_aot_time8, taichi_aot_time9, taichi_aot_time10,\ - brainpy_time1, brainpy_time2, brainpy_time3, brainpy_time4, brainpy_time5, \ - brainpy_time6, brainpy_time7, brainpy_time8, brainpy_time9, brainpy_time10 - -def test_jitconn_matvec_uniform(shape, transpose, outdim_parallel, bool_event): - rng = bm.random.RandomState(seed=seed) - events = bm.as_jax(rng.random(shape[0] if transpose else shape[1])) < 0.1 - if not bool_event: - events = events.astype(float) - # groundtruth = bm.as_jax(vector, dtype=float) @ bm.as_jax(dense) - - result = jax.block_until_ready(jitconn_event_matvec_uniform_taichi_grad(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) - result = jax.block_until_ready(jitconn_event_matvec_uniform_taichi_grad(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) - result = jax.block_until_ready(jitconn_event_matvec_uniform_taichi_grad(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) - result = jax.block_until_ready(jitconn_event_matvec_uniform_taichi_grad(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) - result = jax.block_until_ready(jitconn_event_matvec_uniform_taichi_grad(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) - - time0 = time.time() - result = jax.block_until_ready(jitconn_event_matvec_uniform_taichi_grad(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) - time1 = time.time() - - time2 = time.time() - result = jax.block_until_ready(jitconn_event_matvec_uniform_taichi_grad(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) - time3 = time.time() - - time4 = time.time() - result = jax.block_until_ready(jitconn_event_matvec_uniform_taichi_grad(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) - time5 = time.time() - - time6 = time.time() - result = jax.block_until_ready(jitconn_event_matvec_uniform_taichi_grad(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) - time7 = time.time() - - time8 = time.time() - result = jax.block_until_ready(jitconn_event_matvec_uniform_taichi_grad(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) - time9 = time.time() - - time10 = time.time() - result = jax.block_until_ready(jitconn_event_matvec_uniform_taichi_grad(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) - time11 = time.time() - - time12 = time.time() - result = jax.block_until_ready(jitconn_event_matvec_uniform_taichi_grad(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) - time13 = time.time() - - time14 = time.time() - result = jax.block_until_ready(jitconn_event_matvec_uniform_taichi_grad(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) - time15 = time.time() - - time16 = time.time() - result = jax.block_until_ready(jitconn_event_matvec_uniform_taichi_grad(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) - time17 = time.time() - - time18 = time.time() - result = jax.block_until_ready(jitconn_event_matvec_uniform_taichi_grad(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) - time19 = time.time() - - - result = jax.block_until_ready(jitconn_event_matvec_uniform_grad(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) - result = jax.block_until_ready(jitconn_event_matvec_uniform_grad(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) - result = jax.block_until_ready(jitconn_event_matvec_uniform_grad(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) - result = jax.block_until_ready(jitconn_event_matvec_uniform_grad(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) - result = jax.block_until_ready(jitconn_event_matvec_uniform_grad(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) - - time20 = time.time() - result = jax.block_until_ready(jitconn_event_matvec_uniform_grad(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) - time21 = time.time() - - time22 = time.time() - result = jax.block_until_ready(jitconn_event_matvec_uniform_grad(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) - time23 = time.time() - - time24 = time.time() - result = jax.block_until_ready(jitconn_event_matvec_uniform_grad(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) - time25 = time.time() - - time26 = time.time() - result = jax.block_until_ready(jitconn_event_matvec_uniform_grad(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) - time27 = time.time() - - time28 = time.time() - result = jax.block_until_ready(jitconn_event_matvec_uniform_grad(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) - time29 = time.time() - - time30 = time.time() - result = jax.block_until_ready(jitconn_event_matvec_uniform_grad(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) - time31 = time.time() - - time32 = time.time() - result = jax.block_until_ready(jitconn_event_matvec_uniform_grad(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) - time33 = time.time() - - time34 = time.time() - result = jax.block_until_ready(jitconn_event_matvec_uniform_grad(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) - time35 = time.time() - - time36 = time.time() - result = jax.block_until_ready(jitconn_event_matvec_uniform_grad(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) - time37 = time.time() - - time38 = time.time() - result = jax.block_until_ready(jitconn_event_matvec_uniform_grad(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) - time39 = time.time() - - taichi_aot_time1 = (time1 - time0) * 1000 - taichi_aot_time2 = (time3 - time2) * 1000 - taichi_aot_time3 = (time5 - time4) * 1000 - taichi_aot_time4 = (time7 - time6) * 1000 - taichi_aot_time5 = (time9 - time8) * 1000 - taichi_aot_time6 = (time11 - time10) * 1000 - taichi_aot_time7 = (time13 - time12) * 1000 - taichi_aot_time8 = (time15 - time14) * 1000 - taichi_aot_time9 = (time17 - time16) * 1000 - taichi_aot_time10 = (time19 - time18) * 1000 - brainpy_time1 = (time21 - time20) * 1000 - brainpy_time2 = (time23 - time22) * 1000 - brainpy_time3 = (time25 - time24) * 1000 - brainpy_time4 = (time27 - time26) * 1000 - brainpy_time5 = (time29 - time28) * 1000 - brainpy_time6 = (time31 - time30) * 1000 - brainpy_time7 = (time33 - time32) * 1000 - brainpy_time8 = (time35 - time34) * 1000 - brainpy_time9 = (time37 - time36) * 1000 - brainpy_time10 = (time39 - time38) * 1000 - print('taichi_aot_1: ', taichi_aot_time1, 'ms') - print('taichi_aot_3: ', taichi_aot_time3, 'ms') - print('taichi_aot_5: ', taichi_aot_time5, 'ms') - print('taichi_aot_7: ', taichi_aot_time7, 'ms') - print('taichi_aot_9: ', taichi_aot_time9, 'ms') - print('brainpylib_1: ', brainpy_time1, 'ms') - print('brainpylib_3: ', brainpy_time3, 'ms') - print('brainpylib_5: ', brainpy_time5, 'ms') - print('brainpylib_7: ', brainpy_time7, 'ms') - print('brainpylib_9: ', brainpy_time9, 'ms') - - - return taichi_aot_time1, taichi_aot_time2, taichi_aot_time3, taichi_aot_time4, taichi_aot_time5,\ - taichi_aot_time6, taichi_aot_time7, taichi_aot_time8, taichi_aot_time9, taichi_aot_time10,\ - brainpy_time1, brainpy_time2, brainpy_time3, brainpy_time4, brainpy_time5, \ - brainpy_time6, brainpy_time7, brainpy_time8, brainpy_time9, brainpy_time10 - -def test_jitconn_matvec_normal(shape, transpose, outdim_parallel, bool_event): - rng = bm.random.RandomState(seed=seed) - events = bm.as_jax(rng.random(shape[0] if transpose else shape[1])) < 0.1 - if not bool_event: - events = events.astype(float) - # groundtruth = bm.as_jax(vector, dtype=float) @ bm.as_jax(dense) - - result = jax.block_until_ready(jitconn_event_matvec_normal_taichi_grad(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) - result = jax.block_until_ready(jitconn_event_matvec_normal_taichi_grad(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) - result = jax.block_until_ready(jitconn_event_matvec_normal_taichi_grad(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) - result = jax.block_until_ready(jitconn_event_matvec_normal_taichi_grad(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) - result = jax.block_until_ready(jitconn_event_matvec_normal_taichi_grad(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) - - time0 = time.time() - result = jax.block_until_ready(jitconn_event_matvec_normal_taichi_grad(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) - time1 = time.time() - - time2 = time.time() - result = jax.block_until_ready(jitconn_event_matvec_normal_taichi_grad(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) - time3 = time.time() - - time4 = time.time() - result = jax.block_until_ready(jitconn_event_matvec_normal_taichi_grad(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) - time5 = time.time() - - time6 = time.time() - result = jax.block_until_ready(jitconn_event_matvec_normal_taichi_grad(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) - time7 = time.time() - - time8 = time.time() - result = jax.block_until_ready(jitconn_event_matvec_normal_taichi_grad(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) - time9 = time.time() - - time10 = time.time() - result = jax.block_until_ready(jitconn_event_matvec_normal_taichi_grad(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) - time11 = time.time() - - time12 = time.time() - result = jax.block_until_ready(jitconn_event_matvec_normal_taichi_grad(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) - time13 = time.time() - - time14 = time.time() - result = jax.block_until_ready(jitconn_event_matvec_normal_taichi_grad(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) - time15 = time.time() - - time16 = time.time() - result = jax.block_until_ready(jitconn_event_matvec_normal_taichi_grad(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) - time17 = time.time() - - time18 = time.time() - result = jax.block_until_ready(jitconn_event_matvec_normal_taichi_grad(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) - time19 = time.time() - - - result = jax.block_until_ready(jitconn_event_matvec_normal_grad(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) - result = jax.block_until_ready(jitconn_event_matvec_normal_grad(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) - result = jax.block_until_ready(jitconn_event_matvec_normal_grad(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) - result = jax.block_until_ready(jitconn_event_matvec_normal_grad(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) - result = jax.block_until_ready(jitconn_event_matvec_normal_grad(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) - - time20 = time.time() - result = jax.block_until_ready(jitconn_event_matvec_normal_grad(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) - time21 = time.time() - - time22 = time.time() - result = jax.block_until_ready(jitconn_event_matvec_normal_grad(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) - time23 = time.time() - - time24 = time.time() - result = jax.block_until_ready(jitconn_event_matvec_normal_grad(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) - time25 = time.time() - - time26 = time.time() - result = jax.block_until_ready(jitconn_event_matvec_normal_grad(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) - time27 = time.time() - - time28 = time.time() - result = jax.block_until_ready(jitconn_event_matvec_normal_grad(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) - time29 = time.time() - - time30 = time.time() - result = jax.block_until_ready(jitconn_event_matvec_normal_grad(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) - time31 = time.time() - - time32 = time.time() - result = jax.block_until_ready(jitconn_event_matvec_normal_grad(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) - time33 = time.time() - - time34 = time.time() - result = jax.block_until_ready(jitconn_event_matvec_normal_grad(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) - time35 = time.time() - - time36 = time.time() - result = jax.block_until_ready(jitconn_event_matvec_normal_grad(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) - time37 = time.time() - - time38 = time.time() - result = jax.block_until_ready(jitconn_event_matvec_normal_grad(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) - time39 = time.time() - - taichi_aot_time1 = (time1 - time0) * 1000 - taichi_aot_time2 = (time3 - time2) * 1000 - taichi_aot_time3 = (time5 - time4) * 1000 - taichi_aot_time4 = (time7 - time6) * 1000 - taichi_aot_time5 = (time9 - time8) * 1000 - taichi_aot_time6 = (time11 - time10) * 1000 - taichi_aot_time7 = (time13 - time12) * 1000 - taichi_aot_time8 = (time15 - time14) * 1000 - taichi_aot_time9 = (time17 - time16) * 1000 - taichi_aot_time10 = (time19 - time18) * 1000 - brainpy_time1 = (time21 - time20) * 1000 - brainpy_time2 = (time23 - time22) * 1000 - brainpy_time3 = (time25 - time24) * 1000 - brainpy_time4 = (time27 - time26) * 1000 - brainpy_time5 = (time29 - time28) * 1000 - brainpy_time6 = (time31 - time30) * 1000 - brainpy_time7 = (time33 - time32) * 1000 - brainpy_time8 = (time35 - time34) * 1000 - brainpy_time9 = (time37 - time36) * 1000 - brainpy_time10 = (time39 - time38) * 1000 - print('taichi_aot_1: ', taichi_aot_time1, 'ms') - print('taichi_aot_3: ', taichi_aot_time3, 'ms') - print('taichi_aot_5: ', taichi_aot_time5, 'ms') - print('taichi_aot_7: ', taichi_aot_time7, 'ms') - print('taichi_aot_9: ', taichi_aot_time9, 'ms') - print('brainpylib_1: ', brainpy_time1, 'ms') - print('brainpylib_3: ', brainpy_time3, 'ms') - print('brainpylib_5: ', brainpy_time5, 'ms') - print('brainpylib_7: ', brainpy_time7, 'ms') - print('brainpylib_9: ', brainpy_time9, 'ms') - - - return taichi_aot_time1, taichi_aot_time2, taichi_aot_time3, taichi_aot_time4, taichi_aot_time5,\ - taichi_aot_time6, taichi_aot_time7, taichi_aot_time8, taichi_aot_time9, taichi_aot_time10,\ - brainpy_time1, brainpy_time2, brainpy_time3, brainpy_time4, brainpy_time5, \ - brainpy_time6, brainpy_time7, brainpy_time8, brainpy_time9, brainpy_time10 - -def test_jitconn_matvec(shape, _type, transpose, outdim_parallel, bool_event): - print('shape: ', shape, ' type: ', _type, ' transpose: ', transpose, ' outdim_parallel: ', outdim_parallel) - if _type == 'homo': - return test_jitconn_matvec_homo(shape, transpose, outdim_parallel, bool_event) - elif _type == 'uniform': - return test_jitconn_matvec_uniform(shape, transpose, outdim_parallel, bool_event) - elif _type == 'normal': - return test_jitconn_matvec_normal(shape, transpose, outdim_parallel, bool_event) - else: - raise ValueError - -PATH = os.path.dirname(os.path.abspath(__file__)) - -# init dataframe -df = pd.DataFrame(columns=['shape[0]', 'shape[1]', 'backend', 'type', 'transpose', 'outdim_parallel', 'bool_event', - 'taichi aot time1(ms)', 'taichi aot time2(ms)', 'taichi aot time3(ms)', 'taichi aot time4(ms)', 'taichi aot time5(ms)', - 'taichi aot time6(ms)', 'taichi aot time7(ms)', 'taichi aot time8(ms)', 'taichi aot time9(ms)', 'taichi aot time10(ms)', - 'brainpy time1(ms)', 'brainpy time2(ms)', 'brainpy time3(ms)', 'brainpy time4(ms)', 'brainpy time5(ms)', - 'brainpy time6(ms)', 'brainpy time7(ms)', 'brainpy time8(ms)', 'brainpy time9(ms)', 'brainpy time10(ms)']) - - -### RECTANGULAR MATRIX -if (bm.get_platform() == 'cpu'): - for shape1 in shape: - for shape2 in shape: - for _type in types: - for _outdim_parallel in outdim_parallel: - for _transpose in transpose: - for _bool_event in bool_event: - taichi_aot_time_1, taichi_aot_time_2, taichi_aot_time_3, taichi_aot_time_4, taichi_aot_time_5,\ - taichi_aot_time_6, taichi_aot_time_7, taichi_aot_time_8, taichi_aot_time_9, taichi_aot_time_10,\ - brainpy_time_1, brainpy_time_2, brainpy_time_3, brainpy_time_4, brainpy_time_5, \ - brainpy_time_6, brainpy_time_7, brainpy_time_8, brainpy_time_9, brainpy_time_10 = test_jitconn_matvec((shape1, shape2), _type, _transpose, _outdim_parallel, _bool_event) - # append to dataframe - df.loc[df.shape[0]] = [shape1, shape2, 'cpu', _type, _transpose, _outdim_parallel, _bool_event, - taichi_aot_time_1, taichi_aot_time_2, taichi_aot_time_3, taichi_aot_time_4, taichi_aot_time_5, - taichi_aot_time_6, taichi_aot_time_7, taichi_aot_time_8, taichi_aot_time_9, taichi_aot_time_10, - brainpy_time_1, brainpy_time_2, brainpy_time_3, brainpy_time_4, brainpy_time_5, - brainpy_time_6, brainpy_time_7, brainpy_time_8, brainpy_time_9, brainpy_time_10] - df.to_csv(f'{PATH}/jitconn_event_matvec_grad_cpu.csv', index=False) - -if (bm.get_platform() == 'gpu'): - for shape1 in shape: - for shape2 in shape: - for _type in types: - for _outdim_parallel in outdim_parallel: - for _transpose in transpose: - for _bool_event in bool_event: - taichi_aot_time_1, taichi_aot_time_2, taichi_aot_time_3, taichi_aot_time_4, taichi_aot_time_5,\ - taichi_aot_time_6, taichi_aot_time_7, taichi_aot_time_8, taichi_aot_time_9, taichi_aot_time_10,\ - brainpy_time_1, brainpy_time_2, brainpy_time_3, brainpy_time_4, brainpy_time_5, \ - brainpy_time_6, brainpy_time_7, brainpy_time_8, brainpy_time_9, brainpy_time_10 = test_jitconn_matvec((shape1, shape2), _type, _transpose, _outdim_parallel, _bool_event) - # append to dataframe - df.loc[df.shape[0]] = [shape1, shape2, 'gpu', _type, _transpose, _outdim_parallel, _bool_event, - taichi_aot_time_1, taichi_aot_time_2, taichi_aot_time_3, taichi_aot_time_4, taichi_aot_time_5, - taichi_aot_time_6, taichi_aot_time_7, taichi_aot_time_8, taichi_aot_time_9, taichi_aot_time_10, - brainpy_time_1, brainpy_time_2, brainpy_time_3, brainpy_time_4, brainpy_time_5, - brainpy_time_6, brainpy_time_7, brainpy_time_8, brainpy_time_9, brainpy_time_10] - df.to_csv(f'{PATH}/jitconn_event_matvec_grad_gpu.csv', index=False) diff --git a/brainpy/_src/math/jitconn/tests/jitconn_matvec_taichi_VS_jitconn_matvec.py b/brainpy/_src/math/jitconn/tests/jitconn_matvec_taichi_VS_jitconn_matvec.py deleted file mode 100644 index 14a19aefb..000000000 --- a/brainpy/_src/math/jitconn/tests/jitconn_matvec_taichi_VS_jitconn_matvec.py +++ /dev/null @@ -1,560 +0,0 @@ -# from jax_taichi import jax_taichi_call - -import time -from functools import partial -import os - -import brainpy as bp -import brainpy.math as bm -import jax -import jax.numpy as jnp -import numpy as np -import pandas as pd -import taichi as ti - -bm.set_platform('gpu') - -seed = 1234 - -shape = [ - 1000, - 2500, - 5000, - 10000, - 25000, - 37500, - 50000 - ] -types = [ - 'homo', - 'uniform', - 'normal' - ] -transpose = [ - True, - False - ] -outdim_parallel = [ - True, - False, - ] -bool_event = False -conn_prob = 0.05 -homo_data = 1. -w_low = 0. -w_high = 1. -w_mu = 0. -w_sigma = 0.1 - -ITERATION = 100 -if bm.get_platform() == 'cpu': - ITERATION = 10 - -print(bm.get_platform()) - -@partial(jax.jit, static_argnums=(4, 5, 6)) -def jitconn_matvec_homo_taichi(vector, homo_data, conn_prob, seed, shape, transpose, outdim_parallel): - r = 0 - for i in range(ITERATION): - r += bm.jitconn.mv_prob_homo_taichi(vector, homo_data, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel) - return r - -@partial(jax.jit, static_argnums=(4, 5, 6)) -def jitconn_matvec_homo(vector, homo_data, conn_prob, seed, shape, transpose, outdim_parallel): - r = 0 - for i in range(ITERATION): - r += bm.jitconn.mv_prob_homo(vector, homo_data, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel) - return r - -@partial(jax.jit, static_argnums=(5, 6, 7)) -def jitconn_matvec_uniform_taichi(vector, w_low, w_high, conn_prob, seed, shape, transpose, outdim_parallel): - r = 0 - for i in range(ITERATION): - r += bm.jitconn.mv_prob_uniform_taichi(vector, w_low, w_high, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel) - return r - -@partial(jax.jit, static_argnums=(5, 6, 7)) -def jitconn_matvec_uniform(vector, w_low, w_high, conn_prob, seed, shape, transpose, outdim_parallel): - r = 0 - for i in range(ITERATION): - r += bm.jitconn.mv_prob_uniform(vector, w_low, w_high, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel) - return r - -@partial(jax.jit, static_argnums=(5, 6, 7)) -def jitconn_matvec_normal_taichi(vector, w_mu, w_sigma, conn_prob, seed, shape, transpose, outdim_parallel): - r = 0 - for i in range(ITERATION): - r += bm.jitconn.mv_prob_normal_taichi(vector, w_mu, w_sigma, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel) - return r - -@partial(jax.jit, static_argnums=(5, 6, 7)) -def jitconn_matvec_normal(vector, w_mu, w_sigma, conn_prob, seed, shape, transpose, outdim_parallel): - r = 0 - for i in range(ITERATION): - r += bm.jitconn.mv_prob_normal(vector, w_mu, w_sigma, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel) - return r - -def test_jitconn_matvec_homo(shape, transpose, outdim_parallel): - rng = bm.random.RandomState(seed=seed) - vector = bm.as_jax(rng.random(shape[0] if transpose else shape[1])) - - # groundtruth = bm.as_jax(vector, dtype=float) @ bm.as_jax(dense) - - result = jax.block_until_ready(jitconn_matvec_homo_taichi(vector, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) - result = jax.block_until_ready(jitconn_matvec_homo_taichi(vector, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) - result = jax.block_until_ready(jitconn_matvec_homo_taichi(vector, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) - result = jax.block_until_ready(jitconn_matvec_homo_taichi(vector, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) - result = jax.block_until_ready(jitconn_matvec_homo_taichi(vector, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) - - time0 = time.time() - result = jax.block_until_ready(jitconn_matvec_homo_taichi(vector, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) - time1 = time.time() - - time2 = time.time() - result = jax.block_until_ready(jitconn_matvec_homo_taichi(vector, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) - time3 = time.time() - - time4 = time.time() - result = jax.block_until_ready(jitconn_matvec_homo_taichi(vector, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) - time5 = time.time() - - time6 = time.time() - result = jax.block_until_ready(jitconn_matvec_homo_taichi(vector, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) - time7 = time.time() - - time8 = time.time() - result = jax.block_until_ready(jitconn_matvec_homo_taichi(vector, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) - time9 = time.time() - - time10 = time.time() - result = jax.block_until_ready(jitconn_matvec_homo_taichi(vector, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) - time11 = time.time() - - time12 = time.time() - result = jax.block_until_ready(jitconn_matvec_homo_taichi(vector, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) - time13 = time.time() - - time14 = time.time() - result = jax.block_until_ready(jitconn_matvec_homo_taichi(vector, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) - time15 = time.time() - - time16 = time.time() - result = jax.block_until_ready(jitconn_matvec_homo_taichi(vector, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) - time17 = time.time() - - time18 = time.time() - result = jax.block_until_ready(jitconn_matvec_homo_taichi(vector, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) - time19 = time.time() - - - result = jax.block_until_ready(jitconn_matvec_homo(vector, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) - result = jax.block_until_ready(jitconn_matvec_homo(vector, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) - result = jax.block_until_ready(jitconn_matvec_homo(vector, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) - result = jax.block_until_ready(jitconn_matvec_homo(vector, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) - result = jax.block_until_ready(jitconn_matvec_homo(vector, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) - - time20 = time.time() - result = jax.block_until_ready(jitconn_matvec_homo(vector, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) - time21 = time.time() - - time22 = time.time() - result = jax.block_until_ready(jitconn_matvec_homo(vector, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) - time23 = time.time() - - time24 = time.time() - result = jax.block_until_ready(jitconn_matvec_homo(vector, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) - time25 = time.time() - - time26 = time.time() - result = jax.block_until_ready(jitconn_matvec_homo(vector, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) - time27 = time.time() - - time28 = time.time() - result = jax.block_until_ready(jitconn_matvec_homo(vector, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) - time29 = time.time() - - time30 = time.time() - result = jax.block_until_ready(jitconn_matvec_homo(vector, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) - time31 = time.time() - - time32 = time.time() - result = jax.block_until_ready(jitconn_matvec_homo(vector, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) - time33 = time.time() - - time34 = time.time() - result = jax.block_until_ready(jitconn_matvec_homo(vector, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) - time35 = time.time() - - time36 = time.time() - result = jax.block_until_ready(jitconn_matvec_homo(vector, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) - time37 = time.time() - - time38 = time.time() - result = jax.block_until_ready(jitconn_matvec_homo(vector, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) - time39 = time.time() - - taichi_aot_time1 = (time1 - time0) * 1000 - taichi_aot_time2 = (time3 - time2) * 1000 - taichi_aot_time3 = (time5 - time4) * 1000 - taichi_aot_time4 = (time7 - time6) * 1000 - taichi_aot_time5 = (time9 - time8) * 1000 - taichi_aot_time6 = (time11 - time10) * 1000 - taichi_aot_time7 = (time13 - time12) * 1000 - taichi_aot_time8 = (time15 - time14) * 1000 - taichi_aot_time9 = (time17 - time16) * 1000 - taichi_aot_time10 = (time19 - time18) * 1000 - brainpy_time1 = (time21 - time20) * 1000 - brainpy_time2 = (time23 - time22) * 1000 - brainpy_time3 = (time25 - time24) * 1000 - brainpy_time4 = (time27 - time26) * 1000 - brainpy_time5 = (time29 - time28) * 1000 - brainpy_time6 = (time31 - time30) * 1000 - brainpy_time7 = (time33 - time32) * 1000 - brainpy_time8 = (time35 - time34) * 1000 - brainpy_time9 = (time37 - time36) * 1000 - brainpy_time10 = (time39 - time38) * 1000 - print('taichi_aot_1: ', taichi_aot_time1, 'ms') - print('taichi_aot_3: ', taichi_aot_time3, 'ms') - print('taichi_aot_5: ', taichi_aot_time5, 'ms') - print('taichi_aot_7: ', taichi_aot_time7, 'ms') - print('taichi_aot_9: ', taichi_aot_time9, 'ms') - print('brainpylib_1: ', brainpy_time1, 'ms') - print('brainpylib_3: ', brainpy_time3, 'ms') - print('brainpylib_5: ', brainpy_time5, 'ms') - print('brainpylib_7: ', brainpy_time7, 'ms') - print('brainpylib_9: ', brainpy_time9, 'ms') - - - return taichi_aot_time1, taichi_aot_time2, taichi_aot_time3, taichi_aot_time4, taichi_aot_time5,\ - taichi_aot_time6, taichi_aot_time7, taichi_aot_time8, taichi_aot_time9, taichi_aot_time10,\ - brainpy_time1, brainpy_time2, brainpy_time3, brainpy_time4, brainpy_time5, \ - brainpy_time6, brainpy_time7, brainpy_time8, brainpy_time9, brainpy_time10 - -def test_jitconn_matvec_uniform(shape, transpose, outdim_parallel): - rng = bm.random.RandomState(seed=seed) - events = bm.as_jax(rng.random(shape[0] if transpose else shape[1])) - - # groundtruth = bm.as_jax(vector, dtype=float) @ bm.as_jax(dense) - - result1 = jax.block_until_ready(jitconn_matvec_uniform_taichi(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) - result = jax.block_until_ready(jitconn_matvec_uniform_taichi(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) - result = jax.block_until_ready(jitconn_matvec_uniform_taichi(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) - result = jax.block_until_ready(jitconn_matvec_uniform_taichi(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) - result = jax.block_until_ready(jitconn_matvec_uniform_taichi(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) - - time0 = time.time() - result = jax.block_until_ready(jitconn_matvec_uniform_taichi(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) - time1 = time.time() - - time2 = time.time() - result = jax.block_until_ready(jitconn_matvec_uniform_taichi(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) - time3 = time.time() - - time4 = time.time() - result = jax.block_until_ready(jitconn_matvec_uniform_taichi(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) - time5 = time.time() - - time6 = time.time() - result = jax.block_until_ready(jitconn_matvec_uniform_taichi(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) - time7 = time.time() - - time8 = time.time() - result = jax.block_until_ready(jitconn_matvec_uniform_taichi(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) - time9 = time.time() - - time10 = time.time() - result = jax.block_until_ready(jitconn_matvec_uniform_taichi(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) - time11 = time.time() - - time12 = time.time() - result = jax.block_until_ready(jitconn_matvec_uniform_taichi(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) - time13 = time.time() - - time14 = time.time() - result = jax.block_until_ready(jitconn_matvec_uniform_taichi(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) - time15 = time.time() - - time16 = time.time() - result = jax.block_until_ready(jitconn_matvec_uniform_taichi(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) - time17 = time.time() - - time18 = time.time() - result = jax.block_until_ready(jitconn_matvec_uniform_taichi(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) - time19 = time.time() - - - result = jax.block_until_ready(jitconn_matvec_uniform(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) - result = jax.block_until_ready(jitconn_matvec_uniform(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) - result = jax.block_until_ready(jitconn_matvec_uniform(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) - result = jax.block_until_ready(jitconn_matvec_uniform(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) - result = jax.block_until_ready(jitconn_matvec_uniform(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) - - time20 = time.time() - result = jax.block_until_ready(jitconn_matvec_uniform(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) - time21 = time.time() - - time22 = time.time() - result = jax.block_until_ready(jitconn_matvec_uniform(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) - time23 = time.time() - - time24 = time.time() - result = jax.block_until_ready(jitconn_matvec_uniform(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) - time25 = time.time() - - time26 = time.time() - result = jax.block_until_ready(jitconn_matvec_uniform(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) - time27 = time.time() - - time28 = time.time() - result = jax.block_until_ready(jitconn_matvec_uniform(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) - time29 = time.time() - - time30 = time.time() - result = jax.block_until_ready(jitconn_matvec_uniform(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) - time31 = time.time() - - time32 = time.time() - result = jax.block_until_ready(jitconn_matvec_uniform(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) - time33 = time.time() - - time34 = time.time() - result = jax.block_until_ready(jitconn_matvec_uniform(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) - time35 = time.time() - - time36 = time.time() - result = jax.block_until_ready(jitconn_matvec_uniform(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) - time37 = time.time() - - time38 = time.time() - result = jax.block_until_ready(jitconn_matvec_uniform(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) - time39 = time.time() - - taichi_aot_time1 = (time1 - time0) * 1000 - taichi_aot_time2 = (time3 - time2) * 1000 - taichi_aot_time3 = (time5 - time4) * 1000 - taichi_aot_time4 = (time7 - time6) * 1000 - taichi_aot_time5 = (time9 - time8) * 1000 - taichi_aot_time6 = (time11 - time10) * 1000 - taichi_aot_time7 = (time13 - time12) * 1000 - taichi_aot_time8 = (time15 - time14) * 1000 - taichi_aot_time9 = (time17 - time16) * 1000 - taichi_aot_time10 = (time19 - time18) * 1000 - brainpy_time1 = (time21 - time20) * 1000 - brainpy_time2 = (time23 - time22) * 1000 - brainpy_time3 = (time25 - time24) * 1000 - brainpy_time4 = (time27 - time26) * 1000 - brainpy_time5 = (time29 - time28) * 1000 - brainpy_time6 = (time31 - time30) * 1000 - brainpy_time7 = (time33 - time32) * 1000 - brainpy_time8 = (time35 - time34) * 1000 - brainpy_time9 = (time37 - time36) * 1000 - brainpy_time10 = (time39 - time38) * 1000 - print('taichi_aot_1: ', taichi_aot_time1, 'ms') - print('taichi_aot_3: ', taichi_aot_time3, 'ms') - print('taichi_aot_5: ', taichi_aot_time5, 'ms') - print('taichi_aot_7: ', taichi_aot_time7, 'ms') - print('taichi_aot_9: ', taichi_aot_time9, 'ms') - print('brainpylib_1: ', brainpy_time1, 'ms') - print('brainpylib_3: ', brainpy_time3, 'ms') - print('brainpylib_5: ', brainpy_time5, 'ms') - print('brainpylib_7: ', brainpy_time7, 'ms') - print('brainpylib_9: ', brainpy_time9, 'ms') - - - return taichi_aot_time1, taichi_aot_time2, taichi_aot_time3, taichi_aot_time4, taichi_aot_time5,\ - taichi_aot_time6, taichi_aot_time7, taichi_aot_time8, taichi_aot_time9, taichi_aot_time10,\ - brainpy_time1, brainpy_time2, brainpy_time3, brainpy_time4, brainpy_time5, \ - brainpy_time6, brainpy_time7, brainpy_time8, brainpy_time9, brainpy_time10 - -def test_jitconn_matvec_normal(shape, transpose, outdim_parallel): - rng = bm.random.RandomState(seed=seed) - events = bm.as_jax(rng.random(shape[0] if transpose else shape[1])) - - # groundtruth = bm.as_jax(vector, dtype=float) @ bm.as_jax(dense) - - result = jax.block_until_ready(jitconn_matvec_normal_taichi(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) - result = jax.block_until_ready(jitconn_matvec_normal_taichi(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) - result = jax.block_until_ready(jitconn_matvec_normal_taichi(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) - result = jax.block_until_ready(jitconn_matvec_normal_taichi(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) - result = jax.block_until_ready(jitconn_matvec_normal_taichi(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) - - time0 = time.time() - result = jax.block_until_ready(jitconn_matvec_normal_taichi(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) - time1 = time.time() - - time2 = time.time() - result = jax.block_until_ready(jitconn_matvec_normal_taichi(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) - time3 = time.time() - - time4 = time.time() - result = jax.block_until_ready(jitconn_matvec_normal_taichi(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) - time5 = time.time() - - time6 = time.time() - result = jax.block_until_ready(jitconn_matvec_normal_taichi(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) - time7 = time.time() - - time8 = time.time() - result = jax.block_until_ready(jitconn_matvec_normal_taichi(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) - time9 = time.time() - - time10 = time.time() - result = jax.block_until_ready(jitconn_matvec_normal_taichi(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) - time11 = time.time() - - time12 = time.time() - result = jax.block_until_ready(jitconn_matvec_normal_taichi(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) - time13 = time.time() - - time14 = time.time() - result = jax.block_until_ready(jitconn_matvec_normal_taichi(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) - time15 = time.time() - - time16 = time.time() - result = jax.block_until_ready(jitconn_matvec_normal_taichi(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) - time17 = time.time() - - time18 = time.time() - result = jax.block_until_ready(jitconn_matvec_normal_taichi(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) - time19 = time.time() - - - result = jax.block_until_ready(jitconn_matvec_normal(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) - result = jax.block_until_ready(jitconn_matvec_normal(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) - result = jax.block_until_ready(jitconn_matvec_normal(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) - result = jax.block_until_ready(jitconn_matvec_normal(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) - result = jax.block_until_ready(jitconn_matvec_normal(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) - - time20 = time.time() - result = jax.block_until_ready(jitconn_matvec_normal(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) - time21 = time.time() - - time22 = time.time() - result = jax.block_until_ready(jitconn_matvec_normal(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) - time23 = time.time() - - time24 = time.time() - result = jax.block_until_ready(jitconn_matvec_normal(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) - time25 = time.time() - - time26 = time.time() - result = jax.block_until_ready(jitconn_matvec_normal(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) - time27 = time.time() - - time28 = time.time() - result = jax.block_until_ready(jitconn_matvec_normal(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) - time29 = time.time() - - time30 = time.time() - result = jax.block_until_ready(jitconn_matvec_normal(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) - time31 = time.time() - - time32 = time.time() - result = jax.block_until_ready(jitconn_matvec_normal(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) - time33 = time.time() - - time34 = time.time() - result = jax.block_until_ready(jitconn_matvec_normal(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) - time35 = time.time() - - time36 = time.time() - result = jax.block_until_ready(jitconn_matvec_normal(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) - time37 = time.time() - - time38 = time.time() - result = jax.block_until_ready(jitconn_matvec_normal(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) - time39 = time.time() - - taichi_aot_time1 = (time1 - time0) * 1000 - taichi_aot_time2 = (time3 - time2) * 1000 - taichi_aot_time3 = (time5 - time4) * 1000 - taichi_aot_time4 = (time7 - time6) * 1000 - taichi_aot_time5 = (time9 - time8) * 1000 - taichi_aot_time6 = (time11 - time10) * 1000 - taichi_aot_time7 = (time13 - time12) * 1000 - taichi_aot_time8 = (time15 - time14) * 1000 - taichi_aot_time9 = (time17 - time16) * 1000 - taichi_aot_time10 = (time19 - time18) * 1000 - brainpy_time1 = (time21 - time20) * 1000 - brainpy_time2 = (time23 - time22) * 1000 - brainpy_time3 = (time25 - time24) * 1000 - brainpy_time4 = (time27 - time26) * 1000 - brainpy_time5 = (time29 - time28) * 1000 - brainpy_time6 = (time31 - time30) * 1000 - brainpy_time7 = (time33 - time32) * 1000 - brainpy_time8 = (time35 - time34) * 1000 - brainpy_time9 = (time37 - time36) * 1000 - brainpy_time10 = (time39 - time38) * 1000 - print('taichi_aot_1: ', taichi_aot_time1, 'ms') - print('taichi_aot_3: ', taichi_aot_time3, 'ms') - print('taichi_aot_5: ', taichi_aot_time5, 'ms') - print('taichi_aot_7: ', taichi_aot_time7, 'ms') - print('taichi_aot_9: ', taichi_aot_time9, 'ms') - print('brainpylib_1: ', brainpy_time1, 'ms') - print('brainpylib_3: ', brainpy_time3, 'ms') - print('brainpylib_5: ', brainpy_time5, 'ms') - print('brainpylib_7: ', brainpy_time7, 'ms') - print('brainpylib_9: ', brainpy_time9, 'ms') - - - return taichi_aot_time1, taichi_aot_time2, taichi_aot_time3, taichi_aot_time4, taichi_aot_time5,\ - taichi_aot_time6, taichi_aot_time7, taichi_aot_time8, taichi_aot_time9, taichi_aot_time10,\ - brainpy_time1, brainpy_time2, brainpy_time3, brainpy_time4, brainpy_time5, \ - brainpy_time6, brainpy_time7, brainpy_time8, brainpy_time9, brainpy_time10 - -def test_jitconn_matvec(shape, _type, transpose, outdim_parallel): - print('shape: ', shape, ' type: ', _type, ' transpose: ', transpose, ' outdim_parallel: ', outdim_parallel) - if _type == 'homo': - return test_jitconn_matvec_homo(shape, transpose, outdim_parallel) - elif _type == 'uniform': - return test_jitconn_matvec_uniform(shape, transpose, outdim_parallel) - elif _type == 'normal': - return test_jitconn_matvec_normal(shape, transpose, outdim_parallel) - else: - raise ValueError - -PATH = os.path.dirname(os.path.abspath(__file__)) - -# init dataframe -df = pd.DataFrame(columns=['shape[0]', 'shape[1]', 'backend', 'type', 'transpose', 'outdim_parallel', 'bool_event', - 'taichi aot time1(ms)', 'taichi aot time2(ms)', 'taichi aot time3(ms)', 'taichi aot time4(ms)', 'taichi aot time5(ms)', - 'taichi aot time6(ms)', 'taichi aot time7(ms)', 'taichi aot time8(ms)', 'taichi aot time9(ms)', 'taichi aot time10(ms)', - 'brainpy time1(ms)', 'brainpy time2(ms)', 'brainpy time3(ms)', 'brainpy time4(ms)', 'brainpy time5(ms)', - 'brainpy time6(ms)', 'brainpy time7(ms)', 'brainpy time8(ms)', 'brainpy time9(ms)', 'brainpy time10(ms)']) - -### RECTANGULAR MATRIX -if (bm.get_platform() == 'cpu'): - for shape1 in shape: - for shape2 in shape: - for _type in types: - for _outdim_parallel in outdim_parallel: - for _transpose in transpose: - taichi_aot_time_1, taichi_aot_time_2, taichi_aot_time_3, taichi_aot_time_4, taichi_aot_time_5,\ - taichi_aot_time_6, taichi_aot_time_7, taichi_aot_time_8, taichi_aot_time_9, taichi_aot_time_10,\ - brainpy_time_1, brainpy_time_2, brainpy_time_3, brainpy_time_4, brainpy_time_5, \ - brainpy_time_6, brainpy_time_7, brainpy_time_8, brainpy_time_9, brainpy_time_10 = test_jitconn_matvec((shape1, shape2), _type, _transpose, _outdim_parallel) - # append to dataframe - df.loc[df.shape[0]] = [shape1, shape2, 'cpu', _type, _transpose, _outdim_parallel, bool_event, - taichi_aot_time_1, taichi_aot_time_2, taichi_aot_time_3, taichi_aot_time_4, taichi_aot_time_5, - taichi_aot_time_6, taichi_aot_time_7, taichi_aot_time_8, taichi_aot_time_9, taichi_aot_time_10, - brainpy_time_1, brainpy_time_2, brainpy_time_3, brainpy_time_4, brainpy_time_5, - brainpy_time_6, brainpy_time_7, brainpy_time_8, brainpy_time_9, brainpy_time_10] - df.to_csv(f'{PATH}/jitconn_matvec_cpu.csv', index=False) - -if (bm.get_platform() == 'gpu'): - for shape1 in shape: - for shape2 in shape: - for _type in types: - for _outdim_parallel in outdim_parallel: - for _transpose in transpose: - taichi_aot_time_1, taichi_aot_time_2, taichi_aot_time_3, taichi_aot_time_4, taichi_aot_time_5,\ - taichi_aot_time_6, taichi_aot_time_7, taichi_aot_time_8, taichi_aot_time_9, taichi_aot_time_10,\ - brainpy_time_1, brainpy_time_2, brainpy_time_3, brainpy_time_4, brainpy_time_5, \ - brainpy_time_6, brainpy_time_7, brainpy_time_8, brainpy_time_9, brainpy_time_10 = test_jitconn_matvec((shape1, shape2), _type, _transpose, _outdim_parallel) - # append to dataframe - df.loc[df.shape[0]] = [shape1, shape2, 'cpu', _type, _transpose, _outdim_parallel, bool_event, - taichi_aot_time_1, taichi_aot_time_2, taichi_aot_time_3, taichi_aot_time_4, taichi_aot_time_5, - taichi_aot_time_6, taichi_aot_time_7, taichi_aot_time_8, taichi_aot_time_9, taichi_aot_time_10, - brainpy_time_1, brainpy_time_2, brainpy_time_3, brainpy_time_4, brainpy_time_5, - brainpy_time_6, brainpy_time_7, brainpy_time_8, brainpy_time_9, brainpy_time_10] - df.to_csv(f'{PATH}/jitconn_matvec_gpu.csv', index=False) diff --git a/brainpy/_src/math/jitconn/tests/jitconn_matvec_taichi_VS_jitconn_matvec_grad.py b/brainpy/_src/math/jitconn/tests/jitconn_matvec_taichi_VS_jitconn_matvec_grad.py deleted file mode 100644 index 165c9b19b..000000000 --- a/brainpy/_src/math/jitconn/tests/jitconn_matvec_taichi_VS_jitconn_matvec_grad.py +++ /dev/null @@ -1,736 +0,0 @@ -# from jax_taichi import jax_taichi_call - -import time -from functools import partial -import os - -import brainpy as bp -import brainpy.math as bm -import jax -import jax.numpy as jnp -import numpy as np -import pandas as pd -import taichi as ti - -bm.set_platform('cpu') - -seed = 1234 - -shape = [ - 1000, - 2500, - 5000, - 10000, - 25000, - 37500, - 50000 - ] -bool_event = False -types = [ - 'homo', - 'uniform', - 'normal' - ] -transpose = [ - True, - False - ] -outdim_parallel = [ - True, - False, - ] -conn_prob = 0.05 -homo_data = 1. -w_low = 0. -w_high = 1. -w_mu = 0. -w_sigma = 0.1 - -ITERATION = 100 -if bm.get_platform() == 'cpu': - ITERATION = 10 - -print(bm.get_platform()) - -def sum_op(op): - def func(*args, **kwargs): - r = op(*args, **kwargs)[0] - return r.sum() - - return func - -@partial(jax.jit, static_argnums=(4, 5, 6)) -def jitconn_matvec_homo_taichi_grad(vector, homo_data, conn_prob, seed, shape, transpose, outdim_parallel): - r = 0 - for i in range(ITERATION): - r += jax.grad(sum_op(bm.jitconn.mv_prob_homo_taichi), argnums=0)( - vector, homo_data, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel - ) - return r - -@partial(jax.jit, static_argnums=(4, 5, 6)) -def jitconn_matvec_homo_grad(vector, homo_data, conn_prob, seed, shape, transpose, outdim_parallel): - r = 0 - for i in range(ITERATION): - r += jax.grad(sum_op(bm.jitconn.mv_prob_homo), argnums=0)( - vector, homo_data, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel - ) - return r - -@partial(jax.jit, static_argnums=(5, 6, 7)) -def jitconn_matvec_uniform_taichi_grad(vector, w_low, w_high, conn_prob, seed, shape, transpose, outdim_parallel): - r = 0 - for i in range(ITERATION): - r += jax.grad(sum_op(bm.jitconn.mv_prob_uniform_taichi), argnums=0)( - vector, w_low, w_high, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel - ) - return r - -@partial(jax.jit, static_argnums=(5, 6, 7)) -def jitconn_matvec_uniform_grad(vector, w_low, w_high, conn_prob, seed, shape, transpose, outdim_parallel): - r = 0 - for i in range(ITERATION): - r += jax.grad(sum_op(bm.jitconn.mv_prob_uniform), argnums=0)( - vector, w_low, w_high, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel - ) - return r - -@partial(jax.jit, static_argnums=(5, 6, 7)) -def jitconn_matvec_normal_taichi_grad(vector, w_mu, w_sigma, conn_prob, seed, shape, transpose, outdim_parallel): - r = 0 - for i in range(ITERATION): - r += jax.grad(sum_op(bm.jitconn.mv_prob_normal_taichi), argnums=0)( - vector, w_mu, w_sigma, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel - ) - return r - -@partial(jax.jit, static_argnums=(5, 6, 7)) -def jitconn_matvec_normal_grad(vector, w_mu, w_sigma, conn_prob, seed, shape, transpose, outdim_parallel): - r = 0 - for i in range(ITERATION): - r += jax.grad(sum_op(bm.jitconn.mv_prob_normal), argnums=0)( - vector, w_mu, w_sigma, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel - ) - return r - -def test_jitconn_matvec_homo_cpu(shape, transpose, outdim_parallel): - rng = bm.random.RandomState(seed=seed) - vector = bm.as_jax(rng.random(shape[0] if transpose else shape[1])) - - # groundtruth = bm.as_jax(vector, dtype=float) @ bm.as_jax(dense) - - result1 = jax.block_until_ready(jitconn_matvec_homo_taichi_grad(vector, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) - # time.sleep(2) - - time0 = time.time() - result1 = jax.block_until_ready(jitconn_matvec_homo_taichi_grad(vector, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) - time1 = time.time() - # time.sleep(2) - - time2 = time.time() - result1 = jax.block_until_ready(jitconn_matvec_homo_taichi_grad(vector, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) - time3 = time.time() - # time.sleep(2) - - time4 = time.time() - result1 = jax.block_until_ready(jitconn_matvec_homo_taichi_grad(vector, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) - time5 = time.time() - # time.sleep(2) - - time6 = time.time() - result1 = jax.block_until_ready(jitconn_matvec_homo_taichi_grad(vector, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) - time7 = time.time() - - time8 = time.time() - result1 = jax.block_until_ready(jitconn_matvec_homo_taichi_grad(vector, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) - time9 = time.time() - - result2 = jax.block_until_ready(jitconn_matvec_homo_grad(vector, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) - - time12 = time.time() - result2 = jax.block_until_ready(jitconn_matvec_homo_grad(vector, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) - time13 = time.time() - # time.sleep(2) - - time14 = time.time() - result2 = jax.block_until_ready(jitconn_matvec_homo_grad(vector, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) - time15 = time.time() - # time.sleep(2) - - time16 = time.time() - result2 = jax.block_until_ready(jitconn_matvec_homo_grad(vector, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) - time17 = time.time() - # time.sleep(2) - - time18 = time.time() - result2 = jax.block_until_ready(jitconn_matvec_homo_grad(vector, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) - time19 = time.time() - - time20 = time.time() - result2 = jax.block_until_ready(jitconn_matvec_homo_grad(vector, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) - time21 = time.time() - - taichi_aot_time1 = (time1 - time0) * 1000 - taichi_aot_time2 = (time3 - time2) * 1000 - taichi_aot_time3 = (time5 - time4) * 1000 - taichi_aot_time4 = (time7 - time6) * 1000 - taichi_aot_time5 = (time9 - time8) * 1000 - brainpy_time1 = (time13 - time12) * 1000 - brainpy_time2 = (time15 - time14) * 1000 - brainpy_time3 = (time17 - time16) * 1000 - brainpy_time4 = (time19 - time18) * 1000 - brainpy_time5 = (time21 - time20) * 1000 - - print('taichi_aot_1: ', taichi_aot_time1, 'ms') - print('taichi_aot_2: ', taichi_aot_time2, 'ms') - print('taichi_aot_3: ', taichi_aot_time3, 'ms') - print('taichi_aot_4: ', taichi_aot_time4, 'ms') - print('taichi_aot_5: ', taichi_aot_time5, 'ms') - print('brainpylib_cpu_1: ', brainpy_time1, 'ms') - print('brainpylib_cpu_2: ', brainpy_time2, 'ms') - print('brainpylib_cpu_3: ', brainpy_time3, 'ms') - print('brainpylib_cpu_4: ', brainpy_time4, 'ms') - print('brainpylib_cpu_5: ', brainpy_time5, 'ms') - # assert(jnp.allclose(result1[0], result2)) - - speedup = (brainpy_time1 + brainpy_time2 + brainpy_time3 + brainpy_time4 + brainpy_time5) / \ - (taichi_aot_time1 + taichi_aot_time2 + taichi_aot_time3 + taichi_aot_time4 + taichi_aot_time5) - 1 - - return taichi_aot_time1, taichi_aot_time2, taichi_aot_time3, taichi_aot_time4, taichi_aot_time5,\ - brainpy_time1, brainpy_time2, brainpy_time3, brainpy_time4, brainpy_time5, speedup - -def test_jitconn_matvec_uniform_cpu(shape, transpose, outdim_parallel): - rng = bm.random.RandomState(seed=seed) - events = bm.as_jax(rng.random(shape[0] if transpose else shape[1])) - - # groundtruth = bm.as_jax(vector, dtype=float) @ bm.as_jax(dense) - - result1 = jax.block_until_ready(jitconn_matvec_uniform_taichi_grad(events, w_low, w_high, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel)) - # time.sleep(2) - - time0 = time.time() - result1 = jax.block_until_ready(jitconn_matvec_uniform_taichi_grad(events, w_low, w_high, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel)) - time1 = time.time() - # time.sleep(2) - - time2 = time.time() - result1 = jax.block_until_ready(jitconn_matvec_uniform_taichi_grad(events, w_low, w_high, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel)) - time3 = time.time() - # time.sleep(2) - - time4 = time.time() - result1 = jax.block_until_ready(jitconn_matvec_uniform_taichi_grad(events, w_low, w_high, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel)) - time5 = time.time() - # time.sleep(2) - - time6 = time.time() - result1 = jax.block_until_ready(jitconn_matvec_uniform_taichi_grad(events, w_low, w_high, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel)) - time7 = time.time() - - time8 = time.time() - result1 = jax.block_until_ready(jitconn_matvec_uniform_taichi_grad(events, w_low, w_high, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel)) - time9 = time.time() - - result2 = jax.block_until_ready(jitconn_matvec_uniform_grad(events, w_low, w_high, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel)) -# print(result1[0]) -# print(result2) -# print(groundtruth - result1[0]) -# print(groundtruth - result2) - - # print(result1[0] - result2) - # print(bm.allclose(groundtruth, result1[0])) - # print(bm.allclose(groundtruth, result2)) - # assert bm.allclose(result1[0], result2) - - time12 = time.time() - result2 = jax.block_until_ready(jitconn_matvec_uniform_grad(events, w_low, w_high, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel)) - time13 = time.time() - # time.sleep(2) - - time14 = time.time() - result2 = jax.block_until_ready(jitconn_matvec_uniform_grad(events, w_low, w_high, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel)) - time15 = time.time() - # time.sleep(2) - - time16 = time.time() - result2 = jax.block_until_ready(jitconn_matvec_uniform_grad(events, w_low, w_high, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel)) - time17 = time.time() - # time.sleep(2) - - time18 = time.time() - result2 = jax.block_until_ready(jitconn_matvec_uniform_grad(events, w_low, w_high, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel)) - time19 = time.time() - - time20 = time.time() - result2 = jax.block_until_ready(jitconn_matvec_uniform_grad(events, w_low, w_high, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel)) - time21 = time.time() - - taichi_aot_time1 = (time1 - time0) * 1000 - taichi_aot_time2 = (time3 - time2) * 1000 - taichi_aot_time3 = (time5 - time4) * 1000 - taichi_aot_time4 = (time7 - time6) * 1000 - taichi_aot_time5 = (time9 - time8) * 1000 - brainpy_time1 = (time13 - time12) * 1000 - brainpy_time2 = (time15 - time14) * 1000 - brainpy_time3 = (time17 - time16) * 1000 - brainpy_time4 = (time19 - time18) * 1000 - brainpy_time5 = (time21 - time20) * 1000 - - print('taichi_aot_1: ', taichi_aot_time1, 'ms') - print('taichi_aot_2: ', taichi_aot_time2, 'ms') - print('taichi_aot_3: ', taichi_aot_time3, 'ms') - print('taichi_aot_4: ', taichi_aot_time4, 'ms') - print('taichi_aot_5: ', taichi_aot_time5, 'ms') - print('brainpylib_cpu_1: ', brainpy_time1, 'ms') - print('brainpylib_cpu_2: ', brainpy_time2, 'ms') - print('brainpylib_cpu_3: ', brainpy_time3, 'ms') - print('brainpylib_cpu_4: ', brainpy_time4, 'ms') - print('brainpylib_cpu_5: ', brainpy_time5, 'ms') - # assert(jnp.allclose(result1[0], result2)) - - speedup = (brainpy_time1 + brainpy_time2 + brainpy_time3 + brainpy_time4 + brainpy_time5) / \ - (taichi_aot_time1 + taichi_aot_time2 + taichi_aot_time3 + taichi_aot_time4 + taichi_aot_time5) - 1 - - return taichi_aot_time1, taichi_aot_time2, taichi_aot_time3, taichi_aot_time4, taichi_aot_time5,\ - brainpy_time1, brainpy_time2, brainpy_time3, brainpy_time4, brainpy_time5, speedup - -def test_jitconn_matvec_normal_cpu(shape, transpose, outdim_parallel): - rng = bm.random.RandomState(seed=seed) - events = bm.as_jax(rng.random(shape[0] if transpose else shape[1])) - - # groundtruth = bm.as_jax(vector, dtype=float) @ bm.as_jax(dense) - - result1 = jax.block_until_ready(jitconn_matvec_normal_taichi_grad(events, w_mu, w_sigma, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel)) - # time.sleep(2) - - time0 = time.time() - result1 = jax.block_until_ready(jitconn_matvec_normal_taichi_grad(events, w_mu, w_sigma, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel)) - time1 = time.time() - # time.sleep(2) - - time2 = time.time() - result1 = jax.block_until_ready(jitconn_matvec_normal_taichi_grad(events, w_mu, w_sigma, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel)) - time3 = time.time() - # time.sleep(2) - - time4 = time.time() - result1 = jax.block_until_ready(jitconn_matvec_normal_taichi_grad(events, w_mu, w_sigma, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel)) - time5 = time.time() - # time.sleep(2) - - time6 = time.time() - result1 = jax.block_until_ready(jitconn_matvec_normal_taichi_grad(events, w_mu, w_sigma, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel)) - time7 = time.time() - - time8 = time.time() - result1 = jax.block_until_ready(jitconn_matvec_normal_taichi_grad(events, w_mu, w_sigma, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel)) - time9 = time.time() - - result2 = jax.block_until_ready(jitconn_matvec_normal_grad(events, w_mu, w_sigma, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel)) -# print(result1[0]) -# print(result2) -# print(groundtruth - result1[0]) -# print(groundtruth - result2) - - # print(result1[0] - result2) - # print(bm.allclose(groundtruth, result1[0])) - # print(bm.allclose(groundtruth, result2)) - # assert bm.allclose(result1[0], result2) - - time12 = time.time() - result2 = jax.block_until_ready(jitconn_matvec_normal_grad(events, w_mu, w_sigma, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel)) - time13 = time.time() - # time.sleep(2) - - time14 = time.time() - result2 = jax.block_until_ready(jitconn_matvec_normal_grad(events, w_mu, w_sigma, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel)) - time15 = time.time() - # time.sleep(2) - - time16 = time.time() - result2 = jax.block_until_ready(jitconn_matvec_normal_grad(events, w_mu, w_sigma, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel)) - time17 = time.time() - # time.sleep(2) - - time18 = time.time() - result2 = jax.block_until_ready(jitconn_matvec_normal_grad(events, w_mu, w_sigma, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel)) - time19 = time.time() - - time20 = time.time() - result2 = jax.block_until_ready(jitconn_matvec_normal_grad(events, w_mu, w_sigma, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel)) - time21 = time.time() - - taichi_aot_time1 = (time1 - time0) * 1000 - taichi_aot_time2 = (time3 - time2) * 1000 - taichi_aot_time3 = (time5 - time4) * 1000 - taichi_aot_time4 = (time7 - time6) * 1000 - taichi_aot_time5 = (time9 - time8) * 1000 - brainpy_time1 = (time13 - time12) * 1000 - brainpy_time2 = (time15 - time14) * 1000 - brainpy_time3 = (time17 - time16) * 1000 - brainpy_time4 = (time19 - time18) * 1000 - brainpy_time5 = (time21 - time20) * 1000 - - print('taichi_aot_1: ', taichi_aot_time1, 'ms') - print('taichi_aot_2: ', taichi_aot_time2, 'ms') - print('taichi_aot_3: ', taichi_aot_time3, 'ms') - print('taichi_aot_4: ', taichi_aot_time4, 'ms') - print('taichi_aot_5: ', taichi_aot_time5, 'ms') - print('brainpylib_cpu_1: ', brainpy_time1, 'ms') - print('brainpylib_cpu_2: ', brainpy_time2, 'ms') - print('brainpylib_cpu_3: ', brainpy_time3, 'ms') - print('brainpylib_cpu_4: ', brainpy_time4, 'ms') - print('brainpylib_cpu_5: ', brainpy_time5, 'ms') - # assert(jnp.allclose(result1[0], result2)) - - speedup = (brainpy_time1 + brainpy_time2 + brainpy_time3 + brainpy_time4 + brainpy_time5) / \ - (taichi_aot_time1 + taichi_aot_time2 + taichi_aot_time3 + taichi_aot_time4 + taichi_aot_time5) - 1 - - return taichi_aot_time1, taichi_aot_time2, taichi_aot_time3, taichi_aot_time4, taichi_aot_time5,\ - brainpy_time1, brainpy_time2, brainpy_time3, brainpy_time4, brainpy_time5, speedup - -def test_jitconn_matvec_homo_gpu(shape, transpose, outdim_parallel): - rng = bm.random.RandomState(seed=seed) - vector = bm.as_jax(rng.random(shape[0] if transpose else shape[1])) - - # groundtruth = bm.as_jax(vector, dtype=float) @ bm.as_jax(dense) - - result1 = jax.block_until_ready(jitconn_matvec_homo_taichi_grad(vector, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) - # time.sleep(2) - - time0 = time.time() - result1 = jax.block_until_ready(jitconn_matvec_homo_taichi_grad(vector, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) - time1 = time.time() - # time.sleep(2) - - time2 = time.time() - result1 = jax.block_until_ready(jitconn_matvec_homo_taichi_grad(vector, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) - time3 = time.time() - # time.sleep(2) - - time4 = time.time() - result1 = jax.block_until_ready(jitconn_matvec_homo_taichi_grad(vector, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) - time5 = time.time() - # time.sleep(2) - - time6 = time.time() - result1 = jax.block_until_ready(jitconn_matvec_homo_taichi_grad(vector, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) - time7 = time.time() - - time8 = time.time() - result1 = jax.block_until_ready(jitconn_matvec_homo_taichi_grad(vector, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) - time9 = time.time() - - result2 = jax.block_until_ready(jitconn_matvec_homo_grad(vector, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) -# print(result1[0]) -# print(result2) -# print(groundtruth - result1[0]) -# print(groundtruth - result2) - - # print(result1[0] - result2) - # print(bm.allclose(groundtruth, result1[0])) - # print(bm.allclose(groundtruth, result2)) - # assert bm.allclose(result1[0], result2) - - time12 = time.time() - result2 = jax.block_until_ready(jitconn_matvec_homo_grad(vector, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) - time13 = time.time() - # time.sleep(2) - - time14 = time.time() - result2 = jax.block_until_ready(jitconn_matvec_homo_grad(vector, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) - time15 = time.time() - # time.sleep(2) - - time16 = time.time() - result2 = jax.block_until_ready(jitconn_matvec_homo_grad(vector, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) - time17 = time.time() - # time.sleep(2) - - time18 = time.time() - result2 = jax.block_until_ready(jitconn_matvec_homo_grad(vector, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) - time19 = time.time() - - time20 = time.time() - result2 = jax.block_until_ready(jitconn_matvec_homo_grad(vector, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) - time21 = time.time() - - taichi_aot_time1 = (time1 - time0) * 1000 - taichi_aot_time2 = (time3 - time2) * 1000 - taichi_aot_time3 = (time5 - time4) * 1000 - taichi_aot_time4 = (time7 - time6) * 1000 - taichi_aot_time5 = (time9 - time8) * 1000 - brainpy_time1 = (time13 - time12) * 1000 - brainpy_time2 = (time15 - time14) * 1000 - brainpy_time3 = (time17 - time16) * 1000 - brainpy_time4 = (time19 - time18) * 1000 - brainpy_time5 = (time21 - time20) * 1000 - - print('taichi_aot_1: ', taichi_aot_time1, 'ms') - print('taichi_aot_2: ', taichi_aot_time2, 'ms') - print('taichi_aot_3: ', taichi_aot_time3, 'ms') - print('taichi_aot_4: ', taichi_aot_time4, 'ms') - print('taichi_aot_5: ', taichi_aot_time5, 'ms') - print('brainpylib_1: ', brainpy_time1, 'ms') - print('brainpylib_2: ', brainpy_time2, 'ms') - print('brainpylib_3: ', brainpy_time3, 'ms') - print('brainpylib_4: ', brainpy_time4, 'ms') - print('brainpylib_5: ', brainpy_time5, 'ms') - # assert(jnp.allclose(result1[0], result2)) - - speedup = (brainpy_time1 + brainpy_time2 + brainpy_time3 + brainpy_time4 + brainpy_time5) / \ - (taichi_aot_time1 + taichi_aot_time2 + taichi_aot_time3 + taichi_aot_time4 + taichi_aot_time5) - 1 - - return taichi_aot_time1, taichi_aot_time2, taichi_aot_time3, taichi_aot_time4, taichi_aot_time5,\ - brainpy_time1, brainpy_time2, brainpy_time3, brainpy_time4, brainpy_time5, speedup - -def test_jitconn_matvec_uniform_gpu(shape, transpose, outdim_parallel): - rng = bm.random.RandomState(seed=seed) - events = bm.as_jax(rng.random(shape[0] if transpose else shape[1])) - - # groundtruth = bm.as_jax(vector, dtype=float) @ bm.as_jax(dense) - - result1 = jax.block_until_ready(jitconn_matvec_uniform_taichi_grad(events, w_low, w_high, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel)) - # time.sleep(2) - - time0 = time.time() - result1 = jax.block_until_ready(jitconn_matvec_uniform_taichi_grad(events, w_low, w_high, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel)) - time1 = time.time() - # time.sleep(2) - - time2 = time.time() - result1 = jax.block_until_ready(jitconn_matvec_uniform_taichi_grad(events, w_low, w_high, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel)) - time3 = time.time() - # time.sleep(2) - - time4 = time.time() - result1 = jax.block_until_ready(jitconn_matvec_uniform_taichi_grad(events, w_low, w_high, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel)) - time5 = time.time() - # time.sleep(2) - - time6 = time.time() - result1 = jax.block_until_ready(jitconn_matvec_uniform_taichi_grad(events, w_low, w_high, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel)) - time7 = time.time() - - time8 = time.time() - result1 = jax.block_until_ready(jitconn_matvec_uniform_taichi_grad(events, w_low, w_high, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel)) - time9 = time.time() - - result2 = jax.block_until_ready(jitconn_matvec_uniform_grad(events, w_low, w_high, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel)) -# print(result1[0]) -# print(result2) -# print(groundtruth - result1[0]) -# print(groundtruth - result2) - - # print(result1[0] - result2) - # print(bm.allclose(groundtruth, result1[0])) - # print(bm.allclose(groundtruth, result2)) - # assert bm.allclose(result1[0], result2) - - time12 = time.time() - result2 = jax.block_until_ready(jitconn_matvec_uniform_grad(events, w_low, w_high, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel)) - time13 = time.time() - # time.sleep(2) - - time14 = time.time() - result2 = jax.block_until_ready(jitconn_matvec_uniform_grad(events, w_low, w_high, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel)) - time15 = time.time() - # time.sleep(2) - - time16 = time.time() - result2 = jax.block_until_ready(jitconn_matvec_uniform_grad(events, w_low, w_high, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel)) - time17 = time.time() - # time.sleep(2) - - time18 = time.time() - result2 = jax.block_until_ready(jitconn_matvec_uniform_grad(events, w_low, w_high, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel)) - time19 = time.time() - - time20 = time.time() - result2 = jax.block_until_ready(jitconn_matvec_uniform_grad(events, w_low, w_high, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel)) - time21 = time.time() - - taichi_aot_time1 = (time1 - time0) * 1000 - taichi_aot_time2 = (time3 - time2) * 1000 - taichi_aot_time3 = (time5 - time4) * 1000 - taichi_aot_time4 = (time7 - time6) * 1000 - taichi_aot_time5 = (time9 - time8) * 1000 - brainpy_time1 = (time13 - time12) * 1000 - brainpy_time2 = (time15 - time14) * 1000 - brainpy_time3 = (time17 - time16) * 1000 - brainpy_time4 = (time19 - time18) * 1000 - brainpy_time5 = (time21 - time20) * 1000 - - print('taichi_aot_1: ', taichi_aot_time1, 'ms') - print('taichi_aot_2: ', taichi_aot_time2, 'ms') - print('taichi_aot_3: ', taichi_aot_time3, 'ms') - print('taichi_aot_4: ', taichi_aot_time4, 'ms') - print('taichi_aot_5: ', taichi_aot_time5, 'ms') - print('brainpylib_1: ', brainpy_time1, 'ms') - print('brainpylib_2: ', brainpy_time2, 'ms') - print('brainpylib_3: ', brainpy_time3, 'ms') - print('brainpylib_4: ', brainpy_time4, 'ms') - print('brainpylib_5: ', brainpy_time5, 'ms') - # assert(jnp.allclose(result1[0], result2)) - - speedup = (brainpy_time1 + brainpy_time2 + brainpy_time3 + brainpy_time4 + brainpy_time5) / \ - (taichi_aot_time1 + taichi_aot_time2 + taichi_aot_time3 + taichi_aot_time4 + taichi_aot_time5) - 1 - - return taichi_aot_time1, taichi_aot_time2, taichi_aot_time3, taichi_aot_time4, taichi_aot_time5,\ - brainpy_time1, brainpy_time2, brainpy_time3, brainpy_time4, brainpy_time5, speedup - -def test_jitconn_matvec_normal_gpu(shape, transpose, outdim_parallel): - rng = bm.random.RandomState(seed=seed) - events = bm.as_jax(rng.random(shape[0] if transpose else shape[1])) - - # groundtruth = bm.as_jax(vector, dtype=float) @ bm.as_jax(dense) - - result1 = jax.block_until_ready(jitconn_matvec_normal_taichi_grad(events, w_mu, w_sigma, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel)) - # time.sleep(2) - - time0 = time.time() - result1 = jax.block_until_ready(jitconn_matvec_normal_taichi_grad(events, w_mu, w_sigma, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel)) - time1 = time.time() - # time.sleep(2) - - time2 = time.time() - result1 = jax.block_until_ready(jitconn_matvec_normal_taichi_grad(events, w_mu, w_sigma, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel)) - time3 = time.time() - # time.sleep(2) - - time4 = time.time() - result1 = jax.block_until_ready(jitconn_matvec_normal_taichi_grad(events, w_mu, w_sigma, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel)) - time5 = time.time() - # time.sleep(2) - - time6 = time.time() - result1 = jax.block_until_ready(jitconn_matvec_normal_taichi_grad(events, w_mu, w_sigma, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel)) - time7 = time.time() - - time8 = time.time() - result1 = jax.block_until_ready(jitconn_matvec_normal_taichi_grad(events, w_mu, w_sigma, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel)) - time9 = time.time() - - result2 = jax.block_until_ready(jitconn_matvec_normal_grad(events, w_mu, w_sigma, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel)) -# print(result1[0]) -# print(result2) -# print(groundtruth - result1[0]) -# print(groundtruth - result2) - - # print(result1[0] - result2) - # print(bm.allclose(groundtruth, result1[0])) - # print(bm.allclose(groundtruth, result2)) - # assert bm.allclose(result1[0], result2) - - time12 = time.time() - result2 = jax.block_until_ready(jitconn_matvec_normal_grad(events, w_mu, w_sigma, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel)) - time13 = time.time() - # time.sleep(2) - - time14 = time.time() - result2 = jax.block_until_ready(jitconn_matvec_normal_grad(events, w_mu, w_sigma, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel)) - time15 = time.time() - # time.sleep(2) - - time16 = time.time() - result2 = jax.block_until_ready(jitconn_matvec_normal_grad(events, w_mu, w_sigma, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel)) - time17 = time.time() - # time.sleep(2) - - time18 = time.time() - result2 = jax.block_until_ready(jitconn_matvec_normal_grad(events, w_mu, w_sigma, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel)) - time19 = time.time() - - time20 = time.time() - result2 = jax.block_until_ready(jitconn_matvec_normal_grad(events, w_mu, w_sigma, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel)) - time21 = time.time() - - taichi_aot_time1 = (time1 - time0) * 1000 - taichi_aot_time2 = (time3 - time2) * 1000 - taichi_aot_time3 = (time5 - time4) * 1000 - taichi_aot_time4 = (time7 - time6) * 1000 - taichi_aot_time5 = (time9 - time8) * 1000 - brainpy_time1 = (time13 - time12) * 1000 - brainpy_time2 = (time15 - time14) * 1000 - brainpy_time3 = (time17 - time16) * 1000 - brainpy_time4 = (time19 - time18) * 1000 - brainpy_time5 = (time21 - time20) * 1000 - - print('taichi_aot_1: ', taichi_aot_time1, 'ms') - print('taichi_aot_2: ', taichi_aot_time2, 'ms') - print('taichi_aot_3: ', taichi_aot_time3, 'ms') - print('taichi_aot_4: ', taichi_aot_time4, 'ms') - print('taichi_aot_5: ', taichi_aot_time5, 'ms') - print('brainpylib_1: ', brainpy_time1, 'ms') - print('brainpylib_2: ', brainpy_time2, 'ms') - print('brainpylib_3: ', brainpy_time3, 'ms') - print('brainpylib_4: ', brainpy_time4, 'ms') - print('brainpylib_5: ', brainpy_time5, 'ms') - # assert(jnp.allclose(result1[0], result2)) - - speedup = (brainpy_time1 + brainpy_time2 + brainpy_time3 + brainpy_time4 + brainpy_time5) / \ - (taichi_aot_time1 + taichi_aot_time2 + taichi_aot_time3 + taichi_aot_time4 + taichi_aot_time5) - 1 - - return taichi_aot_time1, taichi_aot_time2, taichi_aot_time3, taichi_aot_time4, taichi_aot_time5,\ - brainpy_time1, brainpy_time2, brainpy_time3, brainpy_time4, brainpy_time5, speedup - - -def test_jitconn_matvec_cpu(shape, _type, transpose, outdim_parallel): - print('shape: ', shape, ' type: ', _type, ' transpose: ', transpose, ' outdim_parallel: ', outdim_parallel) - if _type == 'homo': - return test_jitconn_matvec_homo_cpu(shape, transpose, outdim_parallel) - elif _type == 'uniform': - return test_jitconn_matvec_uniform_cpu(shape, transpose, outdim_parallel) - elif _type == 'normal': - return test_jitconn_matvec_normal_cpu(shape, transpose, outdim_parallel) - else: - raise ValueError - - -def test_jitconn_matvec_gpu(shape, _type, transpose, outdim_parallel): - print('shape: ', shape, ' type: ', _type, ' transpose: ', transpose, ' outdim_parallel: ', outdim_parallel) - if _type == 'homo': - return test_jitconn_matvec_homo_gpu(shape, transpose, outdim_parallel) - elif _type == 'uniform': - return test_jitconn_matvec_uniform_gpu(shape, transpose, outdim_parallel) - elif _type == 'normal': - return test_jitconn_matvec_normal_gpu(shape, transpose, outdim_parallel) - else: - raise ValueError - -PATH = os.path.dirname(os.path.abspath(__file__)) - -# init dataframe -df = pd.DataFrame(columns=['shape[0]', 'shape[1]', 'backend', 'type', 'transpose', 'outdim_parallel', - 'taichi aot time1(ms)', 'taichi aot time2(ms)', 'taichi aot time3(ms)', 'taichi aot time4(ms)', 'taichi aot time5(ms)', - 'brainpy time1(ms)', 'brainpy time2(ms)', 'brainpy time3(ms)', 'brainpy time4(ms)', 'brainpy time5(ms)', - 'speedup']) - -### RECTANGULAR MATRIX -if (bm.get_platform() == 'cpu'): - for shape1 in shape: - for shape2 in shape: - for _type in types: - for _outdim_parallel in outdim_parallel: - for _transpose in transpose: - taichi_aot_time_1, taichi_aot_time_2, taichi_aot_time_3, taichi_aot_time_4, taichi_aot_time_5,\ - brainpy_time_1, brainpy_time_2, brainpy_time_3, brainpy_time_4, brainpy_time_5, speedup = test_jitconn_matvec_cpu((shape1, shape2), _type, _transpose, _outdim_parallel) - # append to dataframe - df.loc[df.shape[0]] = [shape1, shape2, 'cpu', _type, _transpose, _outdim_parallel, - taichi_aot_time_1, taichi_aot_time_2, taichi_aot_time_3, taichi_aot_time_4, taichi_aot_time_5, - brainpy_time_1, brainpy_time_2, brainpy_time_3, brainpy_time_4, brainpy_time_5, speedup] - df.to_csv(f'{PATH}/jitconn_matvec_grad_cpu.csv', index=False) - -if (bm.get_platform() == 'gpu'): - for shape1 in shape: - for shape2 in shape: - for _type in types: - for _outdim_parallel in outdim_parallel: - for _transpose in transpose: - taichi_aot_time_1, taichi_aot_time_2, taichi_aot_time_3, taichi_aot_time_4, taichi_aot_time_5,\ - brainpy_time_1, brainpy_time_2, brainpy_time_3, brainpy_time_4, brainpy_time_5, speedup = test_jitconn_matvec_gpu((shape1, shape2), _type, _transpose, _outdim_parallel) - # append to dataframe - df.loc[df.shape[0]] = [shape1, shape2, 'gpu', _type, _transpose, _outdim_parallel, - taichi_aot_time_1, taichi_aot_time_2, taichi_aot_time_3, taichi_aot_time_4, taichi_aot_time_5, - brainpy_time_1, brainpy_time_2, brainpy_time_3, brainpy_time_4, brainpy_time_5, speedup] - df.to_csv(f'{PATH}/jitconn_matvec_grad_gpu.csv', index=False) diff --git a/brainpy/_src/math/jitconn/tests/matmat_jitconn_performance.py b/brainpy/_src/math/jitconn/tests/matmat_jitconn_performance.py deleted file mode 100644 index e23bd5741..000000000 --- a/brainpy/_src/math/jitconn/tests/matmat_jitconn_performance.py +++ /dev/null @@ -1,61 +0,0 @@ -from time import time - -import brainpy.math as bm -import jax.numpy as jnp -from jax import jit, vmap - - -def compare_jitconn_imp(platform='gpu'): - bm.set_platform(platform) - - seed = 1234 - num_loop = 1 - - all_shapes = [ - # (int(1e3), int(1e3)), - # (int(1e3), int(1e4)), - # (int(1e4), int(1e4)), - # (int(5e4), int(5e4)), - # (int(5e4), int(1e5)), - # (int(5e5), int(1e5)), - (int(5e5), int(5e5)), - # (int(1e5), int(1e5)), - ] - - for m in [32, 64, 128, 256]: - for shape in all_shapes: - for prob in [0.01]: - print(f'm = {m}, shape = {shape}, prob = {prob}') - f1 = jit( - vmap(lambda a: bm.jitconn.mv_prob_normal( - a, w_mu=0., w_sigma=0.01, conn_prob=prob, shape=shape, seed=seed, transpose=True - )) - ) - f2 = jit(lambda e: bm.jitconn.mm_prob_normal( - e, w_mu=0., w_sigma=0.01, conn_prob=prob, shape=shape, seed=seed, version='v2' - )) - - rng = bm.random.RandomState() - mat = bm.as_jax(rng.random((m, shape[0]))) - r1 = f1(mat).block_until_ready() - r2 = f2(mat).block_until_ready() - assert r1.shape == r2.shape - print(jnp.allclose(r1, r2)) - - t0 = time() - for _ in range(num_loop): - f1(mat).block_until_ready() - print(f'matvec vmap {time() - t0} s') - - t0 = time() - for _ in range(num_loop): - f2(mat).block_until_ready() - print(f'matmat {time() - t0} s') - - print() - bm.clear_buffer_memory() - - -if __name__ == '__main__': - pass - compare_jitconn_imp('gpu') diff --git a/brainpy/_src/math/jitconn/tests/matmat_testcase.py b/brainpy/_src/math/jitconn/tests/matmat_testcase.py deleted file mode 100644 index cfd6e5369..000000000 --- a/brainpy/_src/math/jitconn/tests/matmat_testcase.py +++ /dev/null @@ -1,140 +0,0 @@ -# -*- coding: utf-8 -*- - -import brainpy.math as bm -import jax -import jax.numpy as jnp -from absl.testing import parameterized - -shapes = [(100, 200), - (200, 200), - (10, 1000), - (2, 1000), - (1000, 10), - (1000, 2)] - - -class Test_matmat_prob_conn(parameterized.TestCase): - def __init__(self, *args, platform, **kwargs): - super(Test_matmat_prob_conn, self).__init__(*args, **kwargs) - bm.set_platform(platform) - print() - - @parameterized.named_parameters( - dict(testcase_name=(f'shape = {shape}, ' - f'm={m}, ' - f'prob={prob}, ' - f'w_low = {w_low}, ' - f'w_high = {w_high}' - f'x64 = {x64}'), - shape=shape, - prob=prob, - w_low=w_low, - w_high=w_high, - x64=x64, - m=m, - seed=1234 - ) - for x64 in [True, False] - for shape in shapes - for prob in [0.01, 0.05, 0.1, 0.4] - for w_low, w_high in [(-1., 0.), (0., 1.), (-1., 1.)] - for m in [5, 8, 15, 33] - ) - def test_uniform(self, shape, prob, w_low, w_high, m, seed=None, x64=False): - print(f'test_uniform: ' - f'shape = {shape}, ' - f'm = {m}, ' - f'prob={prob}, ' - f'w_low = {w_low}, ' - f'w_high = {w_high}, ' - f'x64 = {x64}') - - if x64: - bm.enable_x64() - rng = bm.random.RandomState() - matrix = bm.as_jax(rng.random((m, shape[0]))) - - r1 = bm.jitconn.matmat_prob_conn_uniform_weight(matrix, - w_low=w_low, - w_high=w_high, - conn_prob=prob, - shape=shape, - seed=seed, - version='v1') - r2 = bm.jitconn.matmat_prob_conn_uniform_weight(matrix, - w_low=w_low, - w_high=w_high, - conn_prob=prob, - shape=shape, - seed=seed, - version='v1') - self.assertTrue(jnp.allclose(r1, r2)) - - f = jax.vmap(lambda a: bm.jitconn.matvec_prob_conn_uniform_weight( - a, w_low=w_low, w_high=w_high, conn_prob=prob, shape=shape, seed=seed, transpose=True)) - r3 = f(matrix) - self.assertTrue(jnp.allclose(r1, r3)) - - if x64: - bm.disable_x64() - bm.clear_buffer_memory() - - @parameterized.named_parameters( - dict( - testcase_name=(f'test_normal, shape = {shape}, ' - f'm={m}, ' - f'prob={prob}, ' - f'w_mu = {w_mu}, ' - f'w_sigma = {w_sigma},' - f'x64={x64}'), - shape=shape, - prob=prob, - w_mu=w_mu, - w_sigma=w_sigma, - seed=1234, - m=m, - ) - for x64 in [True, False] - for shape in shapes - for prob in [0.01, 0.05, 0.1, 0.2] - for w_mu, w_sigma in [(-1., 1.), (0., 0.1), (0., 0.5)] - for m in [5, 8, 15, 33] - ) - def test_normal(self, shape, prob, w_mu, w_sigma, m, seed=None, x64=False): - print(f'_test_normal: ' - f'shape = {shape}, ' - f'm = {m}, ' - f'prob={prob}, ' - f'w_mu = {w_mu}, ' - f'w_sigma = {w_sigma}') - - if x64: - bm.enable_x64() - - rng = bm.random.RandomState() - matrix = bm.as_jax(rng.random((m, shape[0]))) - - r1 = bm.jitconn.matmat_prob_conn_normal_weight(matrix, - w_mu=w_mu, - w_sigma=w_sigma, - conn_prob=prob, - shape=shape, - seed=seed) - r2 = bm.jitconn.matmat_prob_conn_normal_weight(matrix, - w_mu=w_mu, - w_sigma=w_sigma, - conn_prob=prob, - shape=shape, - seed=seed) - self.assertTrue(jnp.allclose(r1, r2)) - - f = jax.vmap( - lambda a: bm.jitconn.matvec_prob_conn_normal_weight( - a, w_mu=w_mu, w_sigma=w_sigma, conn_prob=prob, shape=shape, seed=seed, transpose=True) - ) - r3 = f(matrix) - self.assertTrue(jnp.allclose(r1, r3)) - - if x64: - bm.disable_x64() - bm.clear_buffer_memory() diff --git a/brainpy/_src/math/jitconn/tests/matvec_jitconn_performance.py b/brainpy/_src/math/jitconn/tests/matvec_jitconn_performance.py deleted file mode 100644 index ddeb30c21..000000000 --- a/brainpy/_src/math/jitconn/tests/matvec_jitconn_performance.py +++ /dev/null @@ -1,53 +0,0 @@ -from time import time - -import brainpy.math as bm -from jax import jit - - -def compare_jitconn_imp(platform='gpu'): - bm.set_platform(platform) - - weight = 1. - seed = 1234 - - all_shapes = [ - # (int(1e3), int(1e3)), - # (int(1e3), int(1e4)), - # (int(1e4), int(1e4)), - # (int(5e4), int(5e4)), - # (int(5e4), int(1e5)), - (int(5e5), int(1e5)), - (int(5e5), int(5e5)), - ] - - for shape in all_shapes: - for prob in [0.01, 0.05, 0.1, 0.2, 0.4, 0.8]: - for transpose in [True, False]: - print(f'shape = {shape}, prob = {prob}, transpose = {transpose}') - f1 = jit(lambda e: bm.jitconn.mv_prob_homo( - e, weight, conn_prob=prob, shape=shape, seed=seed, transpose=transpose)) - f2 = jit(lambda e: bm.jitconn.mv_prob_homo( - e, weight, conn_prob=prob, shape=shape, seed=seed, transpose=transpose)) - - rng = bm.random.RandomState() - events = bm.as_jax(rng.random(shape[0] if transpose else shape[1])) - f1(events).block_until_ready() - f2(events).block_until_ready() - - t0 = time() - for _ in range(100): - f1(events).block_until_ready() - print(f'event_matvec_v1 {time() - t0} s') - - t0 = time() - for _ in range(100): - f2(events).block_until_ready() - print(f'event_matvec_v2 {time() - t0} s') - - print() - bm.clear_buffer_memory() - - -if __name__ == '__main__': - pass - compare_jitconn_imp('gpu') diff --git a/brainpy/_src/math/jitconn/tests/test_event_matvec.py b/brainpy/_src/math/jitconn/tests/test_event_matvec.py deleted file mode 100644 index e2c91493a..000000000 --- a/brainpy/_src/math/jitconn/tests/test_event_matvec.py +++ /dev/null @@ -1,435 +0,0 @@ -# -*- coding: utf-8 -*- -import os - -import jax -import jax.numpy as jnp -import pytest -from absl.testing import parameterized - -import brainpy.math as bm - -import platform -force_test = False # turn on to force test on windows locally -if platform.system() == 'Windows' and not force_test: - pytest.skip('skip windows', allow_module_level=True) - -# Skip the test in Github Actions -IS_GITHUB_ACTIONS = os.getenv('IS_GITHUB_ACTIONS', '0') -if IS_GITHUB_ACTIONS == '1': - pytest.skip('Skip the test in Github Actions', allow_module_level=True) - -shapes = [(100, 200), (1000, 10)] - - -class Test_event_matvec_prob_conn(parameterized.TestCase): - def __init__(self, *args, platform='cpu', **kwargs): - super(Test_event_matvec_prob_conn, self).__init__(*args, **kwargs) - bm.set_platform(platform) - print() - - @parameterized.product( - transpose=[True, False], - x64=[True, False], - outdim_parallel=[True, False], - shape=shapes, - prob=[0.1], - homo_data=[-1.], - bool_event=[True, False], - seed=[1234], - ) - def test_homo(self, shape, transpose, outdim_parallel, prob, homo_data, bool_event=True, seed=1234, x64=False): - print(f'_test_homo: ' - f'shape = {shape}, ' - f'transpose = {transpose}, ' - f'outdim_parallel = {outdim_parallel}, ' - f'prob={prob}, ' - f'homo_data = {homo_data}, ' - f'bool_event = {bool_event}, ' - f'x64={x64}') - - if x64: - bm.enable_x64() - rng = bm.random.RandomState() - events = bm.as_jax(rng.random(shape[0] if transpose else shape[1])) < 0.1 - if not bool_event: - events = events.astype(float) - - r1 = bm.jitconn.event_mv_prob_homo(events, - homo_data, - conn_prob=prob, - shape=shape, - seed=seed, - outdim_parallel=outdim_parallel, - transpose=transpose) - r1 = jax.block_until_ready(r1) - - r2 = bm.jitconn.event_mv_prob_homo(events, - homo_data, - conn_prob=prob, - shape=shape, - seed=seed, - outdim_parallel=outdim_parallel, - transpose=transpose) - r2 = jax.block_until_ready(r2) - self.assertTrue(jnp.allclose(r1, r2, atol=1e-6)) - - # indices, indptr = bp.conn.FixedProb(prob)(*shape).require('pre2post') - # indices = bm.as_jax(indices) - # indptr = bm.as_jax(indptr) - # r3 = event_ops.event_csr_matvec(homo_data, indices, indptr, events, - # shape=shape, transpose=transpose) - # print('Homo difference: ', bm.abs(r1 - r3).sum() / r1.size) - - if x64: - bm.disable_x64() - bm.clear_buffer_memory() - - @parameterized.product( - transpose=[True, False], - x64=[True, False], - outdim_parallel=[True, False], - shape=shapes, - prob=[0.1], - bool_event=[True, False], - seed=[1234], - ) - def test_homo_vmap(self, shape, transpose, outdim_parallel, prob, bool_event=True, seed=1234, x64=False): - print(f'_test_homo_vmap: ' - f'shape = {shape}, ' - f'transpose = {transpose}, ' - f'outdim_parallel = {outdim_parallel}, ' - f'prob={prob}, ' - f'bool_event = {bool_event}, ' - f'x64={x64}') - if x64: - bm.enable_x64() - rng = bm.random.RandomState() - events = rng.random((10, shape[0] if transpose else shape[1])) < 0.1 - events = bm.as_jax(events) - if not bool_event: - events = events.astype(float) - weights = bm.as_jax(rng.random(10)) - - f1 = jax.vmap( - lambda event, data: bm.jitconn.event_mv_prob_homo( - event, data, conn_prob=prob, shape=shape, seed=seed, - transpose=transpose, outdim_parallel=outdim_parallel - )[0] - ) - r1 = f1(events, weights) - r1 = jax.block_until_ready(r1) - r2 = f1(events, weights) - r2 = jax.block_until_ready(r2) - self.assertTrue(jnp.allclose(r1, r2, atol=1e-6)) - if x64: - bm.disable_x64() - bm.clear_buffer_memory() - - @parameterized.product( - transpose=[True, False], - x64=[True, False], - outdim_parallel=[True, False], - shape=shapes, - prob=[0.1] - ) - def test_homo_grad(self, shape, transpose, outdim_parallel, prob, seed=1234, x64=False): - print(f'_test_homo_grad: ' - f'shape = {shape}, ' - f'transpose = {transpose}, ' - f'outdim_parallel = {outdim_parallel}, ' - f'prob={prob}, x64={x64}') - if x64: - bm.enable_x64() - rng = bm.random.RandomState() - events = rng.random(shape[0] if transpose else shape[1]) < 0.5 - events = bm.as_jax(events) - events = events.astype(float) - - f1 = jax.grad( - lambda event, data: bm.jitconn.event_mv_prob_homo( - event, data, conn_prob=prob, shape=shape, seed=seed, - outdim_parallel=outdim_parallel, transpose=transpose)[0].sum(), - argnums=0 - ) - r1 = f1(events, 1.) - r1 = jax.block_until_ready(r1) - - r2 = f1(events, 2.) - r2 = jax.block_until_ready(r2) - - r3 = f1(events, 3.) - r3 = jax.block_until_ready(r3) - - self.assertTrue(jnp.allclose(r1 * 3., r3, atol=1e-6)) - self.assertTrue(jnp.allclose(r1 * 2., r2, atol=1e-6)) - if x64: - bm.disable_x64() - bm.clear_buffer_memory() - - @parameterized.product( - transpose=[True, False], - x64=[True, False], - outdim_parallel=[True, False], - shape=shapes, - prob=[0.1], - w_low=[-1.], - w_high=[1.], - bool_event=[True, False] - ) - def test_uniform(self, shape, transpose, outdim_parallel, prob, w_low, w_high, - bool_event=True, seed=1234, x64=False): - print(f'_test_uniform: ' - f'shape = {shape}, ' - f'transpose = {transpose}, ' - f'outdim_parallel = {outdim_parallel}, ' - f'prob={prob}, ' - f'w_low = {w_low}, ' - f'w_high = {w_high}, ' - f'x64={x64}') - if x64: - bm.enable_x64() - rng = bm.random.RandomState() - events = rng.random(shape[0] if transpose else shape[1]) < 0.1 - events = bm.as_jax(events) - if not bool_event: - events = events.astype(float) - - r1 = bm.jitconn.event_mv_prob_uniform(events, - w_low=w_low, - w_high=w_high, - conn_prob=prob, - shape=shape, - seed=seed, - outdim_parallel=outdim_parallel, - transpose=transpose) - r1 = jax.block_until_ready(r1) - - r2 = bm.jitconn.event_mv_prob_uniform(events, - w_low=w_low, - w_high=w_high, - conn_prob=prob, - shape=shape, - seed=seed, - outdim_parallel=outdim_parallel, - transpose=transpose) - r2 = jax.block_until_ready(r2) - self.assertTrue(jnp.allclose(r1, r2, atol=1e-6)) - - if x64: - bm.disable_x64() - bm.clear_buffer_memory() - - @parameterized.product( - transpose=[True, False], - x64=[True, False], - outdim_parallel=[True, False], - shape=shapes, - prob=[0.1], - bool_event=[True, False], - ) - def test_uniform_vmap(self, shape, transpose, outdim_parallel, prob, - bool_event=True, seed=1234, x64=False): - print(f'_test_uniform_vmap: ' - f'shape = {shape}, ' - f'transpose = {transpose}, ' - f'outdim_parallel = {outdim_parallel}, ' - f'prob={prob}, x64={x64}') - if x64: - bm.enable_x64() - rng = bm.random.RandomState() - events = rng.random((10, shape[0] if transpose else shape[1])) < 0.1 - events = bm.as_jax(events) - if not bool_event: - events = events.astype(float) - - f1 = jax.vmap( - lambda e: bm.jitconn.event_mv_prob_uniform(e, - w_low=0., - w_high=1., - conn_prob=prob, - shape=shape, - seed=seed, - outdim_parallel=outdim_parallel, - transpose=transpose) - ) - - r1 = f1(events) - r1 = jax.block_until_ready(r1) - r2 = f1(events) - r2 = jax.block_until_ready(r2) - self.assertTrue(jnp.allclose(r1, r2, atol=1e-6)) - if x64: - bm.disable_x64() - bm.clear_buffer_memory() - - @parameterized.product( - transpose=[True, False], - x64=[True, False], - outdim_parallel=[True, False], - shape=shapes, - prob=[0.1], - ) - def test_uniform_grad(self, shape, transpose, outdim_parallel, prob, seed=1234, x64=False): - print(f'_test_uniform_grad: ' - f'shape = {shape}, ' - f'transpose = {transpose}, ' - f'outdim_parallel = {outdim_parallel}, ' - f'prob={prob}, x64={x64}') - if x64: - bm.enable_x64() - rng = bm.random.RandomState() - events = rng.random(shape[0] if transpose else shape[1]) < 0.1 - events = bm.as_jax(events) - events = events.astype(float) - - f1 = jax.grad( - lambda e, w_high: bm.jitconn.event_mv_prob_uniform( - e, - w_low=0., - w_high=w_high, - conn_prob=prob, - shape=shape, - seed=seed, - outdim_parallel=outdim_parallel, - transpose=transpose).sum() - ) - - r1 = f1(events, 1.) - r1 = jax.block_until_ready(r1) - r2 = f1(events, 2.) - r2 = jax.block_until_ready(r2) - self.assertTrue(bm.allclose(r1 * 2., r2, atol=1e-6)) - # print(r1) - if x64: - bm.disable_x64() - bm.clear_buffer_memory() - - @parameterized.product( - transpose=[True, False], - x64=[True, False], - outdim_parallel=[True, False], - shape=shapes, - prob=[0.1, ], - w_mu=[0.], - w_sigma=[0.1], - bool_event=[True, False], - ) - def test_normal(self, shape, transpose, outdim_parallel, prob, w_mu, w_sigma, - bool_event=True, seed=1234, x64=False): - print(f'_test_normal: shape = {shape}, ' - f'transpose = {transpose}, outdim_parallel = {outdim_parallel}, prob={prob}, ' - f'w_mu = {w_mu}, w_sigma = {w_sigma}, x64={x64}') - if x64: - bm.enable_x64() - rng = bm.random.RandomState() - events = rng.random(shape[0] if transpose else shape[1]) < 0.1 - events = bm.as_jax(events) - if not bool_event: - events = events.astype(float) - - r1 = bm.jitconn.event_mv_prob_normal(events, - w_mu=w_mu, - w_sigma=w_sigma, - conn_prob=prob, - shape=shape, - seed=seed, - outdim_parallel=outdim_parallel, - transpose=transpose) - r1 = jax.block_until_ready(r1) - - r2 = bm.jitconn.event_mv_prob_normal(events, - w_mu=w_mu, - w_sigma=w_sigma, - conn_prob=prob, - shape=shape, - seed=seed, - outdim_parallel=outdim_parallel, - transpose=transpose) - r2 = jax.block_until_ready(r2) - self.assertTrue(jnp.allclose(r1, r2, atol=1e-6)) - - if x64: - bm.disable_x64() - bm.clear_buffer_memory() - - @parameterized.product( - transpose = [True, False], - x64 = [True, False], - outdim_parallel = [True, False], - shape = shapes, - prob = [0.1], - bool_event = [True, False], - ) - def test_normal_vmap(self, shape, transpose, outdim_parallel, prob, - bool_event=True, seed=1234, x64=False): - print(f'_test_normal_vmap: ' - f'shape = {shape}, ' - f'transpose = {transpose}, ' - f'outdim_parallel = {outdim_parallel}, ' - f'prob={prob}, x64={x64}') - if x64: - bm.enable_x64() - rng = bm.random.RandomState() - events = rng.random((10, shape[0] if transpose else shape[1])) < 0.1 - events = bm.as_jax(events) - if not bool_event: - events = events.astype(float) - - f1 = jax.vmap(lambda e: bm.jitconn.event_mv_prob_normal(e, - w_mu=0., - w_sigma=1., - conn_prob=prob, - shape=shape, - seed=seed, - outdim_parallel=outdim_parallel, - transpose=transpose)) - r1 = f1(events) - r1 = jax.block_until_ready(r1) - r2 = f1(events) - r2 = jax.block_until_ready(r2) - self.assertTrue(jnp.allclose(r1, r2, atol=1e-6)) - if x64: - bm.disable_x64() - bm.clear_buffer_memory() - - @parameterized.product( - transpose = [True, False], - x64 = [True, False], - outdim_parallel = [True, False], - shape = shapes, - prob = [0.1] - ) - def test_normal_grad(self, shape, transpose, outdim_parallel, prob, seed=1234, x64=False): - print(f'_test_normal_grad: ' - f'shape = {shape}, ' - f'transpose = {transpose}, ' - f'outdim_parallel = {outdim_parallel}, ' - f'prob={prob}, x64={x64}') - if x64: - bm.enable_x64() - rng = bm.random.RandomState() - events = rng.random(shape[0] if transpose else shape[1]) < 0.1 - events = bm.as_jax(events) - events = events.astype(float) - - f1 = jax.jit( - jax.grad( - lambda e, w_sigma: bm.jitconn.event_mv_prob_normal( - e, - w_mu=0., - w_sigma=w_sigma, - conn_prob=prob, - shape=shape, - seed=seed, - outdim_parallel=outdim_parallel, - transpose=transpose).sum() - ) - ) - r1 = f1(events, 1.) - r1 = jax.block_until_ready(r1) - r2 = f1(events, 2.) - r2 = jax.block_until_ready(r2) - self.assertTrue(bm.allclose(r1 * 2, r2, atol=1e-6)) - if x64: - bm.disable_x64() - bm.clear_buffer_memory() diff --git a/brainpy/_src/math/jitconn/tests/test_event_matvec_old.py b/brainpy/_src/math/jitconn/tests/test_event_matvec_old.py deleted file mode 100644 index b2fa77229..000000000 --- a/brainpy/_src/math/jitconn/tests/test_event_matvec_old.py +++ /dev/null @@ -1,564 +0,0 @@ -# -*- coding: utf-8 -*- -from functools import partial - -import jax -import jax.numpy as jnp -from absl.testing import parameterized - -import platform -import brainpy.math as bm - -import pytest -pytest.skip('Old implementation.', allow_module_level=True) -is_manual_test = False -if platform.system() == 'Windows' and not is_manual_test: - pytest.skip('Under windows, brainpy.math package may need manual tests.', allow_module_level=True) - -shapes = [(100, 200), - # (10, 1000), - (2, 1000), - # (1000, 10), - (1000, 2)] - -brainpylib_mv_prob_homo = partial(bm.jitconn.event_mv_prob_homo, method='brainpylib') -taichi_mv_prob_homo = partial(bm.jitconn.event_mv_prob_homo, method='taichi') -brainpylib_mv_prob_uniform = partial(bm.jitconn.event_mv_prob_uniform, method='brainpylib') -taichi_mv_prob_uniform = partial(bm.jitconn.event_mv_prob_uniform, method='taichi') -brainpylib_mv_prob_normal = partial(bm.jitconn.event_mv_prob_normal, method='brainpylib') -taichi_mv_prob_normal = partial(bm.jitconn.event_mv_prob_normal, method='taichi') - -class Test_event_matvec_prob_conn(parameterized.TestCase): - def __init__(self, *args, platform='cpu', **kwargs): - super(Test_event_matvec_prob_conn, self).__init__(*args, **kwargs) - bm.set_platform(platform) - print() - - @parameterized.product( - transpose=[True, False], - x64=[True, False], - outdim_parallel=[True, False], - shape=shapes, - prob=[0.01, 0.1, 0.5], - homo_data=[-1., ], - bool_event=[True, False], - seed=[1234], - ) - def test_homo(self, shape, transpose, outdim_parallel, prob, homo_data, bool_event=True, seed=None, x64=False): - print(f'_test_homo: ' - f'shape = {shape}, ' - f'transpose = {transpose}, ' - f'outdim_parallel = {outdim_parallel}, ' - f'prob={prob}, ' - f'homo_data = {homo_data}, ' - f'bool_event = {bool_event}, ' - f'x64={x64}') - - if x64: - bm.enable_x64() - rng = bm.random.RandomState() - events = bm.as_jax(rng.random(shape[0] if transpose else shape[1])) < 0.1 - if not bool_event: - events = events.astype(float) - - r1 = brainpylib_mv_prob_homo(events, - homo_data, - conn_prob=prob, - shape=shape, - seed=seed, - outdim_parallel=outdim_parallel, - transpose=transpose) - r1 = jax.block_until_ready(r1) - - r2 = brainpylib_mv_prob_homo(events, - homo_data, - conn_prob=prob, - shape=shape, - seed=seed, - outdim_parallel=outdim_parallel, - transpose=transpose) - r2 = jax.block_until_ready(r2) - self.assertTrue(jnp.allclose(r1, r2)) - - r3 = brainpylib_mv_prob_homo(events, - homo_data, - conn_prob=prob, - shape=(shape[1], shape[0]), - seed=seed, - outdim_parallel=outdim_parallel, - transpose=not transpose) - r3 = jax.block_until_ready(r3) - self.assertTrue(jnp.allclose(r1, r3)) - - # indices, indptr = bp.conn.FixedProb(prob)(*shape).require('pre2post') - # indices = bm.as_jax(indices) - # indptr = bm.as_jax(indptr) - # r3 = event_ops.event_csr_matvec(homo_data, indices, indptr, events, - # shape=shape, transpose=transpose) - # print('Homo difference: ', bm.abs(r1 - r3).sum() / r1.size) - - if x64: - bm.disable_x64() - bm.clear_buffer_memory() - - @parameterized.product( - transpose=[True, False], - x64=[True, False], - outdim_parallel=[True, False], - shape=shapes, - prob=[0.01, 0.1, 0.5], - bool_event=[True, False], - seed=[1234], - ) - def test_homo_vmap(self, shape, transpose, outdim_parallel, prob, bool_event=True, seed=None, x64=False): - print(f'_test_homo_vmap: ' - f'shape = {shape}, ' - f'transpose = {transpose}, ' - f'outdim_parallel = {outdim_parallel}, ' - f'prob={prob}, ' - f'bool_event = {bool_event}, ' - f'x64={x64}') - if x64: - bm.enable_x64() - rng = bm.random.RandomState() - events = rng.random((10, shape[0] if transpose else shape[1])) < 0.1 - events = bm.as_jax(events) - if not bool_event: - events = events.astype(float) - weights = bm.as_jax(rng.random(10)) - - f1 = jax.vmap( - lambda event, data: brainpylib_mv_prob_homo( - event, data, conn_prob=prob, shape=shape, seed=seed, - transpose=transpose, outdim_parallel=outdim_parallel - ) - ) - r1 = f1(events, weights) - r1 = jax.block_until_ready(r1) - r2 = f1(events, weights) - r2 = jax.block_until_ready(r2) - self.assertTrue(jnp.allclose(r1, r2)) - if x64: - bm.disable_x64() - bm.clear_buffer_memory() - - @parameterized.named_parameters( - dict(testcase_name=f'_test_homo_grad: ' - f'shape = {shape}, ' - f'transpose = {transpose}, ' - f'outdim_parallel = {outdim_parallel}, ' - f'prob={prob}, x64={x64}', - shape=shape, transpose=transpose, - outdim_parallel=outdim_parallel, - prob=prob, seed=1234, - x64=x64) - for transpose in [True, False] - for x64 in [True, False] - for outdim_parallel in [True, False] - for shape in shapes - for prob in [0.01, 0.1, 0.5] - ) - def test_homo_grad(self, shape, transpose, outdim_parallel, prob, seed=None, x64=False): - print(f'_test_homo_grad: ' - f'shape = {shape}, ' - f'transpose = {transpose}, ' - f'outdim_parallel = {outdim_parallel}, ' - f'prob={prob}, x64={x64}') - if x64: - bm.enable_x64() - rng = bm.random.RandomState() - events = rng.random(shape[0] if transpose else shape[1]) < 0.5 - events = bm.as_jax(events) - events = events.astype(float) - - f1 = jax.grad( - lambda event, data: brainpylib_mv_prob_homo( - event, data, conn_prob=prob, shape=shape, seed=seed, - outdim_parallel=outdim_parallel, transpose=transpose - ).sum(), - argnums=0 - ) - r1 = f1(events, 1.) - r1 = jax.block_until_ready(r1) - - r2 = f1(events, 2.) - r2 = jax.block_until_ready(r2) - - r3 = f1(events, 3.) - r3 = jax.block_until_ready(r3) - - self.assertTrue(jnp.allclose(r1 * 3., r3)) - self.assertTrue(jnp.allclose(r1 * 2., r2)) - if x64: - bm.disable_x64() - bm.clear_buffer_memory() - - @parameterized.named_parameters( - dict(testcase_name=f'test_uniform: ' - f'shape = {shape}, ' - f'transpose = {transpose}, ' - f'outdim_parallel = {outdim_parallel}, ' - f'prob={prob}, ' - f'w_low = {w_low}, ' - f'w_high = {w_high}, ' - f'bool_event = {bool_event}, ' - f'x64={x64}', - shape=shape, - transpose=transpose, - outdim_parallel=outdim_parallel, - prob=prob, - w_low=w_low, - w_high=w_high, - bool_event=bool_event, - seed=1234, - x64=x64 - ) - for transpose in [True, False] - for x64 in [True, False] - for outdim_parallel in [True, False] - for shape in shapes - for prob in [0.01, 0.1, 0.4] - for w_low, w_high in [(-1., 0.), (0., 1.), (-1., 1.)] - for bool_event in [True, False] - ) - def test_uniform(self, shape, transpose, outdim_parallel, prob, w_low, w_high, - bool_event=True, seed=None, x64=False): - print(f'_test_uniform: ' - f'shape = {shape}, ' - f'transpose = {transpose}, ' - f'outdim_parallel = {outdim_parallel}, ' - f'prob={prob}, ' - f'w_low = {w_low}, ' - f'w_high = {w_high}, ' - f'x64={x64}') - if x64: - bm.enable_x64() - rng = bm.random.RandomState() - events = rng.random(shape[0] if transpose else shape[1]) < 0.1 - events = bm.as_jax(events) - if not bool_event: - events = events.astype(float) - - r1 = brainpylib_mv_prob_uniform(events, - w_low=w_low, - w_high=w_high, - conn_prob=prob, - shape=shape, - seed=seed, - outdim_parallel=outdim_parallel, - transpose=transpose) - r1 = jax.block_until_ready(r1) - - r2 = brainpylib_mv_prob_uniform(events, - w_low=w_low, - w_high=w_high, - conn_prob=prob, - shape=shape, - seed=seed, - outdim_parallel=outdim_parallel, - transpose=transpose) - r2 = jax.block_until_ready(r2) - self.assertTrue(jnp.allclose(r1, r2)) - - r3 = brainpylib_mv_prob_uniform(events, - w_low=w_low, - w_high=w_high, - conn_prob=prob, - shape=(shape[1], shape[0]), - seed=seed, - outdim_parallel=outdim_parallel, - transpose=not transpose) - r3 = jax.block_until_ready(r3) - self.assertTrue(jnp.allclose(r1, r3)) - if x64: - bm.disable_x64() - bm.clear_buffer_memory() - - @parameterized.named_parameters( - dict(shape=shape, transpose=transpose, - outdim_parallel=outdim_parallel, prob=prob, - bool_event=bool_event, - x64=x64, - seed=1234, - testcase_name=f'_test_uniform_vmap: ' - f'shape={shape}, ' - f'transpose={transpose}, ' - f'bool_event={bool_event}, ' - f'outdim_parallel={outdim_parallel}, ' - f'prob={prob}, ' - f'x64={x64}') - for transpose in [True, False] - for x64 in [True, False] - for outdim_parallel in [True, False] - for shape in shapes - for prob in [0.01, 0.1] - for bool_event in [True, False] - ) - def test_uniform_vmap(self, shape, transpose, outdim_parallel, prob, - bool_event=True, seed=None, x64=False): - print(f'_test_uniform_vmap: ' - f'shape = {shape}, ' - f'transpose = {transpose}, ' - f'outdim_parallel = {outdim_parallel}, ' - f'prob={prob}, x64={x64}') - if x64: - bm.enable_x64() - rng = bm.random.RandomState() - events = rng.random((10, shape[0] if transpose else shape[1])) < 0.1 - events = bm.as_jax(events) - if not bool_event: - events = events.astype(float) - - f1 = jax.vmap( - lambda e: brainpylib_mv_prob_uniform(e, - w_low=0., - w_high=1., - conn_prob=prob, - shape=shape, - seed=seed, - outdim_parallel=outdim_parallel, - transpose=transpose) - ) - - r1 = f1(events) - r1 = jax.block_until_ready(r1) - r2 = f1(events) - r2 = jax.block_until_ready(r2) - self.assertTrue(jnp.allclose(r1, r2)) - if x64: - bm.disable_x64() - bm.clear_buffer_memory() - - @parameterized.named_parameters( - dict(shape=shape, - transpose=transpose, - outdim_parallel=outdim_parallel, - prob=prob, - seed=1234, - testcase_name=f'_test_uniform_grad: ' - f'shape = {shape}, ' - f'transpose = {transpose}, ' - f'outdim_parallel = {outdim_parallel}, ' - f'prob={prob}, x64={x64}') - for transpose in [True, False] - for x64 in [True, False] - for outdim_parallel in [True, False] - for shape in shapes - for prob in [0.01, 0.1] - ) - def test_uniform_grad(self, shape, transpose, outdim_parallel, prob, seed=None, x64=False): - print(f'_test_uniform_grad: ' - f'shape = {shape}, ' - f'transpose = {transpose}, ' - f'outdim_parallel = {outdim_parallel}, ' - f'prob={prob}, x64={x64}') - if x64: - bm.enable_x64() - rng = bm.random.RandomState() - events = rng.random(shape[0] if transpose else shape[1]) < 0.1 - events = bm.as_jax(events) - events = events.astype(float) - - f1 = jax.grad( - lambda e, w_high: brainpylib_mv_prob_uniform( - e, - w_low=0., - w_high=w_high, - conn_prob=prob, - shape=shape, - seed=seed, - outdim_parallel=outdim_parallel, - transpose=transpose).sum() - ) - - r1 = f1(events, 1.) - r1 = jax.block_until_ready(r1) - r2 = f1(events, 2.) - r2 = jax.block_until_ready(r2) - self.assertTrue(bm.allclose(r1 * 2., r2)) - # print(r1) - if x64: - bm.disable_x64() - bm.clear_buffer_memory() - - @parameterized.named_parameters( - dict(shape=shape, - transpose=transpose, - outdim_parallel=outdim_parallel, - prob=prob, - w_mu=w_mu, - w_sigma=w_sigma, - bool_event=bool_event, - x64=x64, - seed=1234, - testcase_name=f'_test_normal: ' - f'shape={shape}, ' - f'transpose={transpose}, ' - f'outdim_parallel={outdim_parallel}, ' - f'prob={prob}, ' - f'w_mu={w_mu}, ' - f'w_sigma={w_sigma}, ' - f'bool_event={bool_event}, ' - f'x64={x64}') - for transpose in [True, False] - for x64 in [True, False] - for outdim_parallel in [True, False] - for shape in shapes - for prob in [0.01, 0.1, ] - for w_mu, w_sigma in [(-1., 1.), (0., 0.1), (0., 0.5)] - for bool_event in [True, False] - ) - def test_normal(self, shape, transpose, outdim_parallel, prob, w_mu, w_sigma, - bool_event=True, seed=None, x64=False): - print(f'_test_normal: shape = {shape}, ' - f'transpose = {transpose}, outdim_parallel = {outdim_parallel}, prob={prob}, ' - f'w_mu = {w_mu}, w_sigma = {w_sigma}, x64={x64}') - if x64: - bm.enable_x64() - rng = bm.random.RandomState() - events = rng.random(shape[0] if transpose else shape[1]) < 0.1 - events = bm.as_jax(events) - if not bool_event: - events = events.astype(float) - - r1 = brainpylib_mv_prob_normal(events, - w_mu=w_mu, - w_sigma=w_sigma, - conn_prob=prob, - shape=shape, - seed=seed, - outdim_parallel=outdim_parallel, - transpose=transpose) - r1 = jax.block_until_ready(r1) - - r2 = brainpylib_mv_prob_normal(events, - w_mu=w_mu, - w_sigma=w_sigma, - conn_prob=prob, - shape=shape, - seed=seed, - outdim_parallel=outdim_parallel, - transpose=transpose) - r2 = jax.block_until_ready(r2) - self.assertTrue(jnp.allclose(r1, r2)) - - r3 = brainpylib_mv_prob_normal(events, - w_mu=w_mu, - w_sigma=w_sigma, - conn_prob=prob, - shape=(shape[1], shape[0]), - seed=seed, - outdim_parallel=outdim_parallel, - transpose=not transpose) - r3 = jax.block_until_ready(r3) - self.assertTrue(jnp.allclose(r1, r3)) - - if x64: - bm.disable_x64() - bm.clear_buffer_memory() - - @parameterized.named_parameters( - dict(shape=shape, - transpose=transpose, - outdim_parallel=outdim_parallel, - prob=prob, - bool_event=bool_event, - x64=x64, - seed=1234, - testcase_name=f'_test_normal_vmap: ' - f'shape={shape}, ' - f'transpose={transpose}, ' - f'outdim_parallel={outdim_parallel}, ' - f'prob={prob}, ' - f'bool_event={bool_event}, ' - f'x64={x64}') - for transpose in [True, False] - for x64 in [True, False] - for outdim_parallel in [True, False] - for shape in shapes - for prob in [0.01, 0.1] - for bool_event in [True, False] - ) - def test_normal_vmap(self, shape, transpose, outdim_parallel, prob, - bool_event=True, seed=None, x64=False): - print(f'_test_normal_vmap: ' - f'shape = {shape}, ' - f'transpose = {transpose}, ' - f'outdim_parallel = {outdim_parallel}, ' - f'prob={prob}, x64={x64}') - if x64: - bm.enable_x64() - rng = bm.random.RandomState() - events = rng.random((10, shape[0] if transpose else shape[1])) < 0.1 - events = bm.as_jax(events) - if not bool_event: - events = events.astype(float) - - f1 = jax.vmap(lambda e: brainpylib_mv_prob_normal(e, - w_mu=0., - w_sigma=1., - conn_prob=prob, - shape=shape, - seed=seed, - outdim_parallel=outdim_parallel, - transpose=transpose)) - r1 = f1(events) - r1 = jax.block_until_ready(r1) - r2 = f1(events) - r2 = jax.block_until_ready(r2) - self.assertTrue(jnp.allclose(r1, r2)) - if x64: - bm.disable_x64() - bm.clear_buffer_memory() - - @parameterized.named_parameters( - dict(shape=shape, - transpose=transpose, - outdim_parallel=outdim_parallel, - prob=prob, - x64=x64, - seed=1234, - testcase_name=f'_test_normal_grad: ' - f'shape = {shape}, ' - f'transpose = {transpose}, ' - f'outdim_parallel = {outdim_parallel}, ' - f'prob={prob}, x64={x64}') - for transpose in [True, False] - for x64 in [True, False] - for outdim_parallel in [True, False] - for shape in shapes - for prob in [0.01, 0.1] - ) - def test_normal_grad(self, shape, transpose, outdim_parallel, prob, seed=None, x64=False): - print(f'_test_normal_grad: ' - f'shape = {shape}, ' - f'transpose = {transpose}, ' - f'outdim_parallel = {outdim_parallel}, ' - f'prob={prob}, x64={x64}') - if x64: - bm.enable_x64() - rng = bm.random.RandomState() - events = rng.random(shape[0] if transpose else shape[1]) < 0.1 - events = bm.as_jax(events) - events = events.astype(float) - - f1 = jax.jit( - jax.grad( - lambda e, w_sigma: brainpylib_mv_prob_normal( - e, - w_mu=0., - w_sigma=w_sigma, - conn_prob=prob, - shape=shape, - seed=seed, - outdim_parallel=outdim_parallel, - transpose=transpose).sum() - ) - ) - r1 = f1(events, 1.) - r1 = jax.block_until_ready(r1) - r2 = f1(events, 2.) - r2 = jax.block_until_ready(r2) - self.assertTrue(bm.allclose(r1 * 2, r2)) - if x64: - bm.disable_x64() - bm.clear_buffer_memory() diff --git a/brainpy/_src/math/jitconn/tests/test_get_weight_matrix.py b/brainpy/_src/math/jitconn/tests/test_get_weight_matrix.py deleted file mode 100644 index 0cf000715..000000000 --- a/brainpy/_src/math/jitconn/tests/test_get_weight_matrix.py +++ /dev/null @@ -1,172 +0,0 @@ -# -*- coding: utf-8 -*- -import os - -import jax.numpy as jnp -import pytest -from absl.testing import parameterized - -import brainpy.math as bm -import platform - -force_test = False # turn on to force test on windows locally -if platform.system() == 'Windows' and not force_test: - pytest.skip('skip windows', allow_module_level=True) - -# Skip the test in Github Actions -IS_GITHUB_ACTIONS = os.getenv('IS_GITHUB_ACTIONS', '0') -if IS_GITHUB_ACTIONS == '1': - pytest.skip('Skip the test in Github Actions', allow_module_level=True) - - -shapes = [ - (2, 2), - # (1000, 10) -] - -SEED = 1234 - - -class TestGetHomoWeightMatrix(parameterized.TestCase): - def __init__(self, *args, platform='cpu', **kwargs): - super(TestGetHomoWeightMatrix, self).__init__(*args, **kwargs) - bm.set_platform(platform) - print() - - @parameterized.product( - transpose=[True, False], - outdim_parallel=[True, False], - shape=shapes, - prob=[0.1], - ) - def test_get_homo_weight_matrix(self, transpose, outdim_parallel, shape, prob): - homo_data = 1. - print( - f'test_get_homo_weight_matrix: transpose={transpose}, outdim_parallel={outdim_parallel}, shape={shape}, prob={prob}') - conn = bm.jitconn.get_homo_weight_matrix(homo_data, prob, SEED, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel) - shape = (shape[1], shape[0]) if transpose else shape - print(conn.shape) - assert conn.shape == shape - # assert conn.dtype == jnp.float_ - # sum all true values - print( - f'jnp.sum(conn): {jnp.sum(conn)}, jnp.round(prob * shape[0] * shape[1]): {jnp.round(prob * shape[0] * shape[1])}') - - # compare with jitconn op - - print(f'conn: {conn}') - rng = bm.random.RandomState() - vector = bm.as_jax(rng.random(shape[0] if transpose else shape[1])) - - r1 = bm.jitconn.mv_prob_homo(vector, - homo_data, - conn_prob=prob, - shape=shape, - seed=SEED, - outdim_parallel=outdim_parallel, - transpose=transpose) - - r2 = vector @ conn if transpose else conn @ vector - self.assertTrue(jnp.allclose(r1, r2, atol=1e-6)) - - bm.clear_buffer_memory() - - -class TestGetUniformWeightMatrix(parameterized.TestCase): - def __init__(self, *args, platform='cpu', **kwargs): - super(TestGetUniformWeightMatrix, self).__init__(*args, **kwargs) - bm.set_platform(platform) - print() - - @parameterized.product( - transpose=[True, False], - outdim_parallel=[True, False], - shape=shapes, - prob=[0.5], - w_low=[0.1], - w_high=[0.9], - ) - def test_get_uniform_weight_matrix(self, transpose, outdim_parallel, shape, prob, w_low, w_high): - print( - f'test_get_uniform_weight_matrix: transpose={transpose}, outdim_parallel={outdim_parallel}, shape={shape}, prob={prob}, w_low={w_low}, w_high={w_high}') - weight = bm.jitconn.get_uniform_weight_matrix(w_low, w_high, prob, SEED, shape=shape, transpose=transpose, - outdim_parallel=outdim_parallel) - shape = (shape[1], shape[0]) if transpose else shape - assert weight.shape == shape - assert weight.dtype == jnp.float32 - - weight_true = weight > 0. - - print( - f'jnp.sum(conn): {jnp.sum(weight_true)}, jnp.round(prob * shape[0] * shape[1]): {jnp.round(prob * shape[0] * shape[1])}') - - # compare with jitconn op - - print(f'weight: {weight}') - - rng = bm.random.RandomState() - events = bm.as_jax(rng.random(shape[0] if transpose else shape[1])) - - r1 = bm.jitconn.mv_prob_uniform(events, - w_low=w_low, - w_high=w_high, - conn_prob=prob, - shape=shape, - seed=SEED, - outdim_parallel=outdim_parallel, - transpose=transpose) - - # r2 = weight @ events if transpose else events @ weight - r2 = events @ weight if transpose else weight @ events - print(f'r1: {r1}\n r2: {r2}') - self.assertTrue(jnp.allclose(r1, r2, atol=1e-6)) - - bm.clear_buffer_memory() - - -class TestGetNormalWeightMatrix(parameterized.TestCase): - def __init__(self, *args, platform='cpu', **kwargs): - super(TestGetNormalWeightMatrix, self).__init__(*args, **kwargs) - bm.set_platform(platform) - print() - - @parameterized.product( - transpose=[True, False], - outdim_parallel=[True, False], - shape=shapes, - prob=[0.1], - w_mu=[0.0], - w_sigma=[1.0], - ) - def test_get_normal_weight_matrix(self, transpose, outdim_parallel, shape, prob, w_mu, w_sigma): - print( - f'test_get_normal_weight_matrix: transpose={transpose}, outdim_parallel={outdim_parallel}, shape={shape}, prob={prob}, w_mu={w_mu}, w_sigma={w_sigma}') - weight = bm.jitconn.get_normal_weight_matrix(w_mu, w_sigma, prob, SEED, shape=shape, transpose=transpose, - outdim_parallel=outdim_parallel) - shape = (shape[1], shape[0]) if transpose else shape - assert weight.shape == shape - assert weight.dtype == jnp.float32 - - weight_true = weight > 0. - - print( - f'jnnp.sum(conn): {jnp.sum(weight_true)}, jnp.round(prob * shape[0] * shape[1]): {jnp.round(prob * shape[0] * shape[1])}') - - # compare with jitconn op - - rng = bm.random.RandomState() - vector = bm.as_jax(rng.random(shape[0] if transpose else shape[1])) - - r1 = bm.jitconn.mv_prob_normal(vector, - w_mu=w_mu, - w_sigma=w_sigma, - conn_prob=prob, - shape=shape, - seed=SEED, - outdim_parallel=outdim_parallel, - transpose=transpose) - - r2 = vector @ weight if transpose else weight @ vector - print(f'r1: {r1}\n r2: {r2}') - self.assertTrue(jnp.allclose(r1, r2, atol=1e-6)) - - bm.clear_buffer_memory() diff --git a/brainpy/_src/math/jitconn/tests/test_matvec.py b/brainpy/_src/math/jitconn/tests/test_matvec.py deleted file mode 100644 index d69def9a9..000000000 --- a/brainpy/_src/math/jitconn/tests/test_matvec.py +++ /dev/null @@ -1,403 +0,0 @@ -# -*- coding: utf-8 -*- -import os - -import jax -import jax.numpy as jnp -import pytest -from absl.testing import parameterized - -import brainpy.math as bm - -import platform -force_test = False # turn on to force test on windows locally -if platform.system() == 'Windows' and not force_test: - pytest.skip('skip windows', allow_module_level=True) - -# Skip the test in Github Actions -IS_GITHUB_ACTIONS = os.getenv('IS_GITHUB_ACTIONS', '0') -if IS_GITHUB_ACTIONS == '1': - pytest.skip('Skip the test in Github Actions', allow_module_level=True) - -shapes = [(100, 200), (1000, 10)] - - -class Test_matvec_prob_conn(parameterized.TestCase): - def __init__(self, *args, platform='cpu', **kwargs): - super(Test_matvec_prob_conn, self).__init__(*args, **kwargs) - bm.set_platform(platform) - print() - - @parameterized.product( - x64=[True, False], - transpose=[True, False], - outdim_parallel=[True, False], - shape=shapes, - prob=[0.1], - homo_data=[1.] - ) - def test_homo(self, shape, transpose, outdim_parallel, prob, homo_data, seed=1234, x64=False): - print(f'test_homo: ' - f'shape = {shape}, ' - f'transpose = {transpose}, ' - f'outdim_parallel = {outdim_parallel}, ' - f'prob={prob}, ' - f'homo_data = {homo_data}') - - if x64: - bm.enable_x64() - - rng = bm.random.RandomState() - vector = bm.as_jax(rng.random(shape[0] if transpose else shape[1])) - - r1 = bm.jitconn.mv_prob_homo(vector, - homo_data, - conn_prob=prob, - shape=shape, - seed=seed, - outdim_parallel=outdim_parallel, - transpose=transpose) - - r2 = bm.jitconn.mv_prob_homo(vector, - homo_data, - conn_prob=prob, - shape=shape, - seed=seed, - outdim_parallel=outdim_parallel, - transpose=transpose) - self.assertTrue(jnp.allclose(r1, r2, atol=1e-6)) - - bm.clear_buffer_memory() - - @parameterized.product( - transpose=[True, False], - x64=[True, False], - outdim_parallel=[True, False], - shape=shapes, - prob=[0.1], - ) - def test_homo_vmap(self, shape, transpose, outdim_parallel, prob, seed=1234, x64=False): - print(f'test_homo_vmap: ' - f'shape = {shape}, ' - f'transpose = {transpose}, ' - f'outdim_parallel = {outdim_parallel}, ' - f'prob={prob}') - - if x64: - bm.enable_x64() - - rng = bm.random.RandomState() - events = bm.as_jax(rng.random((10, shape[0] if transpose else shape[1]))) - weights = bm.as_jax(rng.random(10)) - - f1 = jax.vmap( - lambda event, data: bm.jitconn.mv_prob_homo( - event, data, - conn_prob=prob, shape=shape, seed=seed, - outdim_parallel=outdim_parallel, transpose=transpose - )[0] - ) - r1 = f1(events, weights) - r2 = f1(events, weights) - self.assertTrue(jnp.allclose(r1, r2, atol=1e-6)) - - if x64: - bm.disable_x64() - bm.clear_buffer_memory() - - @parameterized.product( - transpose=[True, False], - x64=[True, False], - outdim_parallel=[True, False], - shape=shapes, - prob=[0.1], - ) - def test_homo_grad(self, shape, transpose, outdim_parallel, prob, seed=1234, x64=False): - print(f'_test_homo_grad: ' - f'shape = {shape}, ' - f'transpose = {transpose}, ' - f'outdim_parallel = {outdim_parallel}, ' - f'prob={prob}') - - if x64: - bm.enable_x64() - rng = bm.random.RandomState() - events = bm.as_jax(rng.random(shape[0] if transpose else shape[1])) < 0.5 - events = events.astype(float) - - f1 = jax.grad( - lambda event, data: bm.jitconn.mv_prob_homo( - event, data, - conn_prob=prob, - shape=shape, - seed=seed, - outdim_parallel=outdim_parallel, - transpose=transpose - )[0].sum(), - argnums=0 - ) - r1 = f1(events, 1.) - r2 = f1(events, 2.) - - self.assertTrue(jnp.allclose(r1 * 2., r2, atol=1e-6)) - - if x64: - bm.disable_x64() - bm.clear_buffer_memory() - - @parameterized.product( - x64=[True, False], - transpose=[True, False], - outdim_parallel=[True, False], - shape=shapes, - prob=[0.1], - w_low=[-0.1], - w_high=[1.0], - ) - def test_uniform(self, shape, transpose, outdim_parallel, prob, w_low, w_high, seed=1234, x64=False): - print(f'test_uniform: ' - f'shape = {shape}, ' - f'transpose = {transpose}, ' - f'outdim_parallel = {outdim_parallel}, ' - f'prob={prob}, ' - f'w_low = {w_low}, ' - f'w_high = {w_high}, ' - f'x64 = {x64}') - - if x64: - bm.enable_x64() - rng = bm.random.RandomState() - events = bm.as_jax(rng.random(shape[0] if transpose else shape[1])) - - r1 = bm.jitconn.mv_prob_uniform(events, - w_low=w_low, - w_high=w_high, - conn_prob=prob, - shape=shape, - seed=seed, - outdim_parallel=outdim_parallel, - transpose=transpose) - - r2 = bm.jitconn.mv_prob_uniform(events, - w_low=w_low, - w_high=w_high, - conn_prob=prob, - shape=shape, - seed=seed, - outdim_parallel=outdim_parallel, - transpose=transpose) - c = jnp.allclose(r1, r2, atol=1e-6) - if not c: - print(r1, r2) - self.assertTrue(c) - - if x64: - bm.disable_x64() - bm.clear_buffer_memory() - - @parameterized.product( - transpose=[True, False], - x64=[True, False], - outdim_parallel=[True, False], - shape=shapes, - prob=[0.1], - ) - def test_uniform_vmap(self, shape, transpose, outdim_parallel, prob, seed=1234, x64=False): - print(f'test_uniform_vmap: ' - f'shape = {shape}, ' - f'transpose = {transpose}, ' - f'outdim_parallel = {outdim_parallel}, ' - f'prob={prob}') - - if x64: - bm.enable_x64() - rng = bm.random.RandomState() - events = bm.as_jax(rng.random((10, shape[0] if transpose else shape[1]))) - - f1 = jax.vmap(lambda e: bm.jitconn.mv_prob_uniform(e, - w_low=0., - w_high=1., - conn_prob=prob, - shape=shape, - seed=seed, - outdim_parallel=outdim_parallel, - transpose=transpose)) - - r1 = f1(events) - r2 = f1(events) - self.assertTrue(jnp.allclose(r1, r2, atol=1e-6)) - - if x64: - bm.disable_x64() - bm.clear_buffer_memory() - - @parameterized.product( - x64=[True, False], - transpose=[True, False], - outdim_parallel=[True, False], - shape=shapes, - prob=[0.1], - ) - def test_uniform_grad(self, shape, transpose, outdim_parallel, prob, seed=1234, x64=False): - print(f'_test_uniform_grad: ' - f'shape = {shape}, ' - f'transpose = {transpose}, ' - f'outdim_parallel = {outdim_parallel}, ' - f'prob={prob}') - - if x64: - bm.enable_x64() - - rng = bm.random.RandomState() - events = bm.as_jax(rng.random(shape[0] if transpose else shape[1])) - - f1 = jax.grad( - lambda e, w_low, w_high: bm.jitconn.mv_prob_uniform( - e, - w_low=w_low, - w_high=w_high, - conn_prob=prob, - shape=shape, - seed=seed, - outdim_parallel=outdim_parallel, - transpose=transpose - )[0].sum() - ) - - r1 = f1(events, 0., 1.) - r2 = f1(events, 0., 2.) - - self.assertTrue(bm.allclose(r1 * 2., r2, atol=1e-6)) - - if x64: - bm.disable_x64() - bm.clear_buffer_memory() - - @parameterized.product( - transpose=[True, False], - x64=[True, False], - outdim_parallel=[True, False], - shape=shapes, - prob=[0.1], - w_mu=[0.], - w_sigma=[0.2] - ) - def test_normal(self, shape, transpose, outdim_parallel, prob, w_mu, w_sigma, seed=1234, x64=False): - print(f'_test_normal: ' - f'shape = {shape}, ' - f'transpose = {transpose}, ' - f'outdim_parallel = {outdim_parallel}, ' - f'prob={prob}, ' - f'w_mu = {w_mu}, ' - f'w_sigma = {w_sigma}') - - if x64: - bm.enable_x64() - - rng = bm.random.RandomState() - events = bm.as_jax(rng.random(shape[0] if transpose else shape[1])) - - r1 = bm.jitconn.mv_prob_normal(events, - w_mu=w_mu, - w_sigma=w_sigma, - conn_prob=prob, - shape=shape, - seed=seed, - outdim_parallel=outdim_parallel, - transpose=transpose) - - r2 = bm.jitconn.mv_prob_normal(events, - w_mu=w_mu, - w_sigma=w_sigma, - conn_prob=prob, - shape=shape, - seed=seed, - outdim_parallel=outdim_parallel, - transpose=transpose) - c = jnp.allclose(r1, r2, atol=1e-6) - if not c: - print(r1, r2) - self.assertTrue(c) - - if x64: - bm.disable_x64() - bm.clear_buffer_memory() - - @parameterized.product( - transpose=[True, False], - x64=[True, False], - outdim_parallel=[True, False], - shape=shapes, - prob=[0.1] - ) - def test_normal_vmap(self, shape, transpose, outdim_parallel, prob, seed=1234, x64=False): - print(f'_test_normal_vmap: ' - f'shape = {shape}, ' - f'transpose = {transpose}, ' - f'outdim_parallel = {outdim_parallel}, ' - f'prob={prob}') - - if x64: - bm.enable_x64() - - rng = bm.random.RandomState() - events = bm.as_jax(rng.random((10, shape[0] if transpose else shape[1]))) - - f1 = jax.vmap(lambda e: bm.jitconn.mv_prob_normal(e, - w_mu=0., - w_sigma=1., - conn_prob=prob, - shape=shape, - seed=seed, - outdim_parallel=outdim_parallel, - transpose=transpose)) - r1 = f1(events) - r2 = f1(events) - c = jnp.allclose(r1, r2, atol=1e-6) - if not c: - print(r1, r2) - print(r1 - r2) - self.assertTrue(c) - - if x64: - bm.disable_x64() - bm.clear_buffer_memory() - - @parameterized.product( - transpose=[True, False], - x64=[True, False], - outdim_parallel=[True, False], - shape=shapes, - prob=[0.1] - ) - def test_normal_grad(self, shape, transpose, outdim_parallel, prob, seed=1234, x64=False): - print(f'_test_normal_grad: ' - f'shape = {shape}, ' - f'transpose = {transpose}, ' - f'outdim_parallel = {outdim_parallel}, ' - f'prob={prob}') - - if x64: - bm.enable_x64() - rng = bm.random.RandomState() - events = bm.as_jax(rng.random(shape[0] if transpose else shape[1])) < 0.1 - events = events.astype(float) - - f1 = jax.grad( - lambda e, w_sigma: bm.jitconn.mv_prob_normal( - e, - w_mu=0., - w_sigma=w_sigma, - conn_prob=prob, - shape=shape, - seed=seed, - outdim_parallel=outdim_parallel, - transpose=transpose - )[0].sum() - ) - r1 = f1(events, 1.) - r2 = f1(events, 2.) - self.assertTrue(bm.allclose(r1 * 2., r2, atol=1e-6)) - - if x64: - bm.disable_x64() - bm.clear_buffer_memory() diff --git a/brainpy/_src/math/jitconn/tests/test_matvec_old.py b/brainpy/_src/math/jitconn/tests/test_matvec_old.py deleted file mode 100644 index 360711e7b..000000000 --- a/brainpy/_src/math/jitconn/tests/test_matvec_old.py +++ /dev/null @@ -1,551 +0,0 @@ -# -*- coding: utf-8 -*- -from functools import partial - -import jax -import jax.numpy as jnp -from absl.testing import parameterized - -import brainpy.math as bm -import platform -import pytest - -pytest.skip('Old implementation.', allow_module_level=True) -is_manual_test = False -if platform.system() == 'Windows' and not is_manual_test: - pytest.skip('brainpy.math package may need manual tests.', allow_module_level=True) - -shapes = [(100, 200), - (10, 1000), - (2, 1000), - (1000, 10), - (1000, 2)] - -brainpylib_mv_prob_homo = partial(bm.jitconn.mv_prob_homo, method='brainpylib') -taichi_mv_prob_homo = partial(bm.jitconn.mv_prob_homo, method='taichi') -brainpylib_mv_prob_uniform = partial(bm.jitconn.mv_prob_uniform, method='brainpylib') -taichi_mv_prob_uniform = partial(bm.jitconn.mv_prob_uniform, method='taichi') -brainpylib_mv_prob_normal = partial(bm.jitconn.mv_prob_normal, method='brainpylib') -taichi_mv_prob_normal = partial(bm.jitconn.mv_prob_normal, method='taichi') - -class Test_matvec_prob_conn(parameterized.TestCase): - def __init__(self, *args, platform='cpu', **kwargs): - super(Test_matvec_prob_conn, self).__init__(*args, **kwargs) - bm.set_platform(platform) - print() - - @parameterized.named_parameters( - dict(testcase_name=(f'test_homo, shape = {shape}, ' - f'transpose = {transpose}, ' - f'outdim_parallel = {outdim_parallel}, ' - f'prob={prob}, ' - f'homo_data = {homo_data}, ' - f'x64 = {x64}'), - shape=shape, - transpose=transpose, - outdim_parallel=outdim_parallel, - prob=prob, - homo_data=homo_data, - seed=1234) - for x64 in [True, False] - for transpose in [True, False] - for outdim_parallel in [True, False] - for shape in shapes - for prob in [0.01, 0.1] - for homo_data in [-1., 1.] - ) - def test_homo(self, shape, transpose, outdim_parallel, prob, homo_data, seed=None, x64=False): - print(f'test_homo: ' - f'shape = {shape}, ' - f'transpose = {transpose}, ' - f'outdim_parallel = {outdim_parallel}, ' - f'prob={prob}, ' - f'homo_data = {homo_data}') - - if x64: - bm.enable_x64() - - rng = bm.random.RandomState() - vector = bm.as_jax(rng.random(shape[0] if transpose else shape[1])) - - r1 = brainpylib_mv_prob_homo(vector, - homo_data, - conn_prob=prob, - shape=shape, - seed=seed, - outdim_parallel=outdim_parallel, - transpose=transpose) - - r2 = brainpylib_mv_prob_homo(vector, - homo_data, - conn_prob=prob, - shape=shape, - seed=seed, - outdim_parallel=outdim_parallel, - transpose=transpose) - self.assertTrue(jnp.allclose(r1, r2)) - - r2 = brainpylib_mv_prob_homo(vector, - homo_data, - conn_prob=prob, - shape=(shape[1], shape[0]), - seed=seed, - outdim_parallel=outdim_parallel, - transpose=not transpose) - self.assertTrue(jnp.allclose(r1, r2)) - - if x64: - bm.disable_x64() - bm.clear_buffer_memory() - - @parameterized.named_parameters( - dict(testcase_name=(f'test_homo_vmap, shape = {shape}, ' - f'transpose = {transpose}, ' - f'outdim_parallel = {outdim_parallel}, ' - f'prob={prob}, x64={x64}'), - shape=shape, - transpose=transpose, - outdim_parallel=outdim_parallel, - prob=prob, - seed=1234, - x64=x64) - for transpose in [True, False] - for x64 in [True, False] - for outdim_parallel in [True, False] - for shape in shapes - for prob in [0.01, 0.1] - ) - def test_homo_vmap(self, shape, transpose, outdim_parallel, prob, seed=None, x64=False): - print(f'test_homo_vmap: ' - f'shape = {shape}, ' - f'transpose = {transpose}, ' - f'outdim_parallel = {outdim_parallel}, ' - f'prob={prob}') - - if x64: - bm.enable_x64() - - rng = bm.random.RandomState() - events = bm.as_jax(rng.random((10, shape[0] if transpose else shape[1]))) - weights = bm.as_jax(rng.random(10)) - - f1 = jax.vmap( - lambda event, data: brainpylib_mv_prob_homo( - event, data, - conn_prob=prob, shape=shape, seed=seed, - outdim_parallel=outdim_parallel, transpose=transpose - ) - ) - r1 = f1(events, weights) - r2 = f1(events, weights) - self.assertTrue(jnp.allclose(r1, r2)) - - if x64: - bm.disable_x64() - bm.clear_buffer_memory() - - @parameterized.named_parameters( - dict(testcase_name=(f'test_homo_grad, shape = {shape}, ' - f'transpose = {transpose}, ' - f'outdim_parallel = {outdim_parallel}, ' - f'prob={prob}, x64={x64}'), - shape=shape, - transpose=transpose, - outdim_parallel=outdim_parallel, - prob=prob, - seed=1234, - x64=x64) - for transpose in [True, False] - for x64 in [True, False] - for outdim_parallel in [True, False] - for shape in shapes - for prob in [0.01, 0.1] - ) - def test_homo_grad(self, shape, transpose, outdim_parallel, prob, seed=None, x64=False): - print(f'_test_homo_grad: ' - f'shape = {shape}, ' - f'transpose = {transpose}, ' - f'outdim_parallel = {outdim_parallel}, ' - f'prob={prob}') - - if x64: - bm.enable_x64() - rng = bm.random.RandomState() - events = bm.as_jax(rng.random(shape[0] if transpose else shape[1])) < 0.5 - events = events.astype(float) - - f1 = jax.grad( - lambda event, data: brainpylib_mv_prob_homo( - event, data, - conn_prob=prob, - shape=shape, - seed=seed, - outdim_parallel=outdim_parallel, - transpose=transpose - ).sum(), - argnums=0 - ) - r1 = f1(events, 1.) - r2 = f1(events, 2.) - - self.assertTrue(jnp.allclose(r1 * 2., r2)) - - if x64: - bm.disable_x64() - bm.clear_buffer_memory() - - @parameterized.named_parameters( - dict(testcase_name=(f'test_uniform, shape = {shape}, ' - f'transpose = {transpose}, ' - f'outdim_parallel = {outdim_parallel}, ' - f'prob={prob}, ' - f'w_low = {w_low}, ' - f'w_high = {w_high}' - f'x64 = {x64}'), - shape=shape, - transpose=transpose, - outdim_parallel=outdim_parallel, - prob=prob, - w_low=w_low, - w_high=w_high, - x64=x64, - seed=1234) - for x64 in [True, False] - for transpose in [True, False] - for outdim_parallel in [True, False] - for shape in shapes - for prob in [0.01, 0.1] - for w_low, w_high in [(-1., 0.), (0., 1.), (-1., 1.)] - ) - def test_uniform(self, shape, transpose, outdim_parallel, prob, w_low, w_high, seed=None, x64=False): - print(f'test_uniform: ' - f'shape = {shape}, ' - f'transpose = {transpose}, ' - f'outdim_parallel = {outdim_parallel}, ' - f'prob={prob}, ' - f'w_low = {w_low}, ' - f'w_high = {w_high}, ' - f'x64 = {x64}') - - if x64: - bm.enable_x64() - rng = bm.random.RandomState() - events = bm.as_jax(rng.random(shape[0] if transpose else shape[1])) - - r1 = brainpylib_mv_prob_uniform(events, - w_low=w_low, - w_high=w_high, - conn_prob=prob, - shape=shape, - seed=seed, - outdim_parallel=outdim_parallel, - transpose=transpose) - - r2 = brainpylib_mv_prob_uniform(events, - w_low=w_low, - w_high=w_high, - conn_prob=prob, - shape=shape, - seed=seed, - outdim_parallel=outdim_parallel, - transpose=transpose) - c = jnp.allclose(r1, r2) - if not c: - print(r1, r2) - self.assertTrue(c) - - r2 = brainpylib_mv_prob_uniform(events, - w_low=w_low, - w_high=w_high, - conn_prob=prob, - shape=(shape[1], shape[0]), - seed=seed, - outdim_parallel=outdim_parallel, - transpose=not transpose) - c = jnp.allclose(r1, r2) - if not c: - print(r1, r2) - self.assertTrue(c) - - if x64: - bm.disable_x64() - bm.clear_buffer_memory() - - @parameterized.named_parameters( - dict(testcase_name=f'test_uniform_vmap, shape = {shape}, ' - f'transpose = {transpose}, ' - f'outdim_parallel = {outdim_parallel}, ' - f'prob={prob}, x64={x64}', - shape=shape, - transpose=transpose, - outdim_parallel=outdim_parallel, - prob=prob, - seed=1234, - x64=x64) - for transpose in [True, False] - for x64 in [True, False] - for outdim_parallel in [True, False] - for shape in shapes - for prob in [0.01, 0.1] - ) - def test_uniform_vmap(self, shape, transpose, outdim_parallel, prob, seed=None, x64=False): - print(f'test_uniform_vmap: ' - f'shape = {shape}, ' - f'transpose = {transpose}, ' - f'outdim_parallel = {outdim_parallel}, ' - f'prob={prob}') - - if x64: - bm.enable_x64() - rng = bm.random.RandomState() - events = bm.as_jax(rng.random((10, shape[0] if transpose else shape[1]))) - - f1 = jax.vmap(lambda e: brainpylib_mv_prob_uniform(e, - w_low=0., - w_high=1., - conn_prob=prob, - shape=shape, - seed=seed, - outdim_parallel=outdim_parallel, - transpose=transpose)) - - r1 = f1(events) - r2 = f1(events) - self.assertTrue(jnp.allclose(r1, r2)) - - if x64: - bm.disable_x64() - bm.clear_buffer_memory() - - @parameterized.named_parameters( - dict(testcase_name=(f'test_uniform_grad, shape = {shape}, ' - f'transpose = {transpose}, ' - f'outdim_parallel = {outdim_parallel}, ' - f'prob={prob}, ' - f'x64={x64}'), - shape=shape, - transpose=transpose, - outdim_parallel=outdim_parallel, - prob=prob, - seed=1234, - x64=x64) - for x64 in [True, False] - for transpose in [True, False] - for outdim_parallel in [True, False] - for shape in shapes - for prob in [0.01, 0.1] - ) - def test_uniform_grad(self, shape, transpose, outdim_parallel, prob, seed=None, x64=False): - print(f'_test_uniform_grad: ' - f'shape = {shape}, ' - f'transpose = {transpose}, ' - f'outdim_parallel = {outdim_parallel}, ' - f'prob={prob}') - - if x64: - bm.enable_x64() - - rng = bm.random.RandomState() - events = bm.as_jax(rng.random(shape[0] if transpose else shape[1])) - - f1 = jax.grad( - lambda e, w_low, w_high: brainpylib_mv_prob_uniform( - e, - w_low=w_low, - w_high=w_high, - conn_prob=prob, - shape=shape, - seed=seed, - outdim_parallel=outdim_parallel, - transpose=transpose - ).sum() - ) - - r1 = f1(events, 0., 1.) - r2 = f1(events, 0., 2.) - - self.assertTrue(bm.allclose(r1 * 2., r2)) - - if x64: - bm.disable_x64() - bm.clear_buffer_memory() - - @parameterized.named_parameters( - dict( - testcase_name=(f'test_normal, shape = {shape}, ' - f'transpose = {transpose}, ' - f'outdim_parallel = {outdim_parallel}, ' - f'prob={prob}, ' - f'w_mu = {w_mu}, ' - f'w_sigma = {w_sigma},' - f'x64={x64}'), - shape=shape, - transpose=transpose, - outdim_parallel=outdim_parallel, - prob=prob, - w_mu=w_mu, - w_sigma=w_sigma, - seed=1234 - ) - for transpose in [True, False] - for x64 in [True, False] - for outdim_parallel in [True, False] - for shape in shapes - for prob in [0.01, 0.1] - for w_mu, w_sigma in [(-1., 1.), (0., 0.1), (0., 0.5)] - ) - def test_normal(self, shape, transpose, outdim_parallel, prob, w_mu, w_sigma, seed=None, x64=False): - print(f'_test_normal: ' - f'shape = {shape}, ' - f'transpose = {transpose}, ' - f'outdim_parallel = {outdim_parallel}, ' - f'prob={prob}, ' - f'w_mu = {w_mu}, ' - f'w_sigma = {w_sigma}') - - if x64: - bm.enable_x64() - - rng = bm.random.RandomState() - events = bm.as_jax(rng.random(shape[0] if transpose else shape[1])) - - r1 = brainpylib_mv_prob_normal(events, - w_mu=w_mu, - w_sigma=w_sigma, - conn_prob=prob, - shape=shape, - seed=seed, - outdim_parallel=outdim_parallel, - transpose=transpose) - - r2 = brainpylib_mv_prob_normal(events, - w_mu=w_mu, - w_sigma=w_sigma, - conn_prob=prob, - shape=shape, - seed=seed, - outdim_parallel=outdim_parallel, - transpose=transpose) - c = jnp.allclose(r1, r2) - if not c: - print(r1, r2) - self.assertTrue(c) - - r2 = brainpylib_mv_prob_normal(events, - w_mu=w_mu, - w_sigma=w_sigma, - conn_prob=prob, - shape=(shape[1], shape[0]), - seed=seed, - outdim_parallel=outdim_parallel, - transpose=not transpose) - c = jnp.allclose(r1, r2) - if not c: - print(r1, r2) - self.assertTrue(c) - - if x64: - bm.disable_x64() - bm.clear_buffer_memory() - - @parameterized.named_parameters( - dict(testcase_name=f'test_normal_vmap, shape = {shape}, ' - f'transpose = {transpose}, ' - f'outdim_parallel = {outdim_parallel}, ' - f'prob={prob}, ' - f'x64={x64}', - shape=shape, - transpose=transpose, - outdim_parallel=outdim_parallel, - prob=prob, - seed=1234) - for transpose in [True, False] - for x64 in [True, False] - for outdim_parallel in [True, False] - for shape in shapes - for prob in [0.01, 0.1] - ) - def test_normal_vmap(self, shape, transpose, outdim_parallel, prob, seed=None, x64=False): - print(f'_test_normal_vmap: ' - f'shape = {shape}, ' - f'transpose = {transpose}, ' - f'outdim_parallel = {outdim_parallel}, ' - f'prob={prob}') - - if x64: - bm.enable_x64() - - rng = bm.random.RandomState() - events = bm.as_jax(rng.random((10, shape[0] if transpose else shape[1]))) - - f1 = jax.vmap(lambda e: brainpylib_mv_prob_normal(e, - w_mu=0., - w_sigma=1., - conn_prob=prob, - shape=shape, - seed=seed, - outdim_parallel=outdim_parallel, - transpose=transpose)) - r1 = f1(events) - r2 = f1(events) - c = jnp.allclose(r1, r2) - if not c: - print(r1, r2) - self.assertTrue(c) - - if x64: - bm.disable_x64() - bm.clear_buffer_memory() - - @parameterized.named_parameters( - dict(shape=shape, - transpose=transpose, - outdim_parallel=outdim_parallel, - prob=prob, - seed=1234, - x64=x64, - testcase_name=f'test_normal_grad: ' - f'shape = {shape}, ' - f'transpose = {transpose}, ' - f'outdim_parallel = {outdim_parallel}, ' - f'prob={prob}, ' - f'x64={x64}') - for transpose in [True, False] - for x64 in [True, False] - for outdim_parallel in [True, False] - for shape in shapes - for prob in [0.01, 0.1] - ) - def test_normal_grad(self, shape, transpose, outdim_parallel, prob, seed=None, x64=False): - print(f'_test_normal_grad: ' - f'shape = {shape}, ' - f'transpose = {transpose}, ' - f'outdim_parallel = {outdim_parallel}, ' - f'prob={prob}') - - if x64: - bm.enable_x64() - rng = bm.random.RandomState() - events = bm.as_jax(rng.random(shape[0] if transpose else shape[1])) < 0.1 - events = events.astype(float) - - f1 = jax.grad( - lambda e, w_sigma: brainpylib_mv_prob_normal( - e, - w_mu=0., - w_sigma=w_sigma, - conn_prob=prob, - shape=shape, - seed=seed, - outdim_parallel=outdim_parallel, - transpose=transpose - ).sum() - ) - r1 = f1(events, 1.) - r2 = f1(events, 2.) - print('r1:', r1) - print('r2:', r2) - self.assertTrue(bm.allclose(r1 * 2., r2)) - - if x64: - bm.disable_x64() - bm.clear_buffer_memory() diff --git a/brainpy/_src/math/sparse/tests/csr_matmat_VS_cusparse_csr_matmat.py b/brainpy/_src/math/sparse/tests/csr_matmat_VS_cusparse_csr_matmat.py deleted file mode 100644 index d40a93247..000000000 --- a/brainpy/_src/math/sparse/tests/csr_matmat_VS_cusparse_csr_matmat.py +++ /dev/null @@ -1,465 +0,0 @@ -import os -import time - -import numpy as np - -os.environ["CUDA_VISIBLE_DEVICES"] = "2" - -import jax -import jax.numpy as jnp -import taichi as ti -from jax.experimental.sparse import csr - -import brainpy as bp -import brainpy.math as bm - -bm.set_platform('gpu') - -size = [ - (100, 100, 100), - (100, 1000, 100), - (1000, 1000, 100), - (1000, 1000, 1000), - (100, 10000, 100), - (10000, 100, 1000), - (1000, 100, 10000), - (10000, 10000, 1000), - (10000, 1000, 10000), - (10000, 10000, 10000), - (20000, 20000, 20000), -] - -values_type = [ - 'homo', - # 'heter' -] -events_type = ['float'] -transpose = [ - # True, - False -] -ITERATION = 10 -SPARSITY = 0.05 - -print(bm.get_platform()) - - -@ti.kernel -def _csr_matmat_transpose_homo_cpu(col_indices: ti.types.ndarray(ndim=1), - row_ptr: ti.types.ndarray(ndim=1), - matrix: ti.types.ndarray(ndim=2), - out: ti.types.ndarray(ndim=2)): - # matrix: (k, n) - # sparse matrix: (m, k) - n = out.shape[1] - m = row_ptr.shape[0] - 1 - for j in range(n): # parallize along the n dimension - for row_i in range(m): # loop along the m dimension - for i in range(row_ptr[row_i], row_ptr[row_i + 1]): - out[col_indices[i], j] += matrix[row_i, j] - - -@ti.kernel -def _csr_matmat_transpose_homo_gpu(col_indices: ti.types.ndarray(ndim=1), - row_ptr: ti.types.ndarray(ndim=1), - matrix: ti.types.ndarray(ndim=2), - out: ti.types.ndarray(ndim=2)): - m = row_ptr.shape[0] - 1 - n = matrix.shape[1] - for j, row_i in ti.ndrange(n, m): # paralleize along the (n and m) dimensions - for i in range(row_ptr[row_i], row_ptr[row_i + 1]): - out[col_indices[i], j] += matrix[row_i, j] - - -@ti.kernel -def _csr_matmat_homo(col_indices: ti.types.ndarray(ndim=1), - row_ptr: ti.types.ndarray(ndim=1), - matrix: ti.types.ndarray(ndim=2), - out: ti.types.ndarray(ndim=2)): - # matrix: (k, n) - # sparse matrix: (m, k) - m, n = out.shape - for row_i, col_k in ti.ndrange(m, n): - r = 0. - for row_j in range(row_ptr[row_i], row_ptr[row_i + 1]): - r += matrix[col_indices[row_j], col_k] - out[row_i, col_k] = r - - -# transpose homo -_csr_matmat_transpose_homo_p = bm.XLACustomOp(cpu_kernel=_csr_matmat_transpose_homo_cpu, - gpu_kernel=_csr_matmat_transpose_homo_gpu) - -# no transpose homo -_csr_matmat_homo_p = bm.XLACustomOp(cpu_kernel=_csr_matmat_homo, gpu_kernel=_csr_matmat_homo) - - -def taichi_csrmm(weight, indices, indptr, matrix, shape, transpose): - indices = bm.as_jax(indices) - indptr = bm.as_jax(indptr) - matrix = bm.as_jax(matrix) - weight = jnp.atleast_1d(weight) - out_shape = shape[1] if transpose else shape[0] - result_shape = (out_shape, matrix.shape[1]) - if transpose: - prim = _csr_matmat_transpose_homo_p - else: - prim = _csr_matmat_homo_p - r = prim(indices, - indptr, - matrix, - outs=[jax.ShapeDtypeStruct(result_shape, dtype=matrix.dtype)], - transpose=transpose, - shape=shape) - return r[0] - - -SHARED_MEM_SIZE = 256 - - -# @ti.kernel -# def _csr_matmat_homo2(col_indices: ti.types.ndarray(ndim=1), -# row_ptr: ti.types.ndarray(ndim=1), -# matrix: ti.types.ndarray(ndim=2), -# out: ti.types.ndarray(ndim=2)): -# m, n = out.shape -# l = col_indices.shape[0] -# ti.loop_config(block_dim=SHARED_MEM_SIZE) -# # for i_col, i_row in ti.ndrange(n, m): -# for i in range(m * n): -# indices_sm = ti.simt.block.SharedArray((SHARED_MEM_SIZE,), ti.int32) -# -# # one block threads compute will SHARED_MEM_SIZE columns -# i_row = i // SHARED_MEM_SIZE -# i_col = i % SHARED_MEM_SIZE -# -# index_start = row_ptr[i_row] -# end_border = row_ptr[i_row + 1] -# n_share = (end_border - index_start) // SHARED_MEM_SIZE -# n_last = end_border - index_start - n_share * SHARED_MEM_SIZE -# -# r = 0. -# for i_share in range(n_share): -# indices_sm[i_col] = col_indices[i_col + i_share * SHARED_MEM_SIZE] -# ti.simt.block.sync() -# # compute -# for j in range(SHARED_MEM_SIZE): -# r += matrix[indices_sm[j], i_col] -# indices_sm[i_col] = col_indices[ti.min(i_col + n_share * SHARED_MEM_SIZE, l)] -# ti.simt.block.sync() -# for j in range(n_last): -# r += matrix[indices_sm[j], i_col] -# out[i_row, i_col] += r - - -@ti.kernel -def _csr_matmat_homo2(col_indices: ti.types.ndarray(ndim=1), - row_ptr: ti.types.ndarray(ndim=1), - matrix: ti.types.ndarray(ndim=2), - out: ti.types.ndarray(ndim=2)): - m, n = out.shape - l = col_indices.shape[0] - ti.loop_config(block_dim=SHARED_MEM_SIZE) - - indices_sm = ti.simt.block.SharedArray((SHARED_MEM_SIZE,), ti.int32) - # for i_col, i_row in ti.ndrange(n, m): - for i in ti.ndrange(n * m): - # i_col = ti.global_thread_idx() % n - # i_row = ti.global_thread_idx() // n - i_col = i % n - i_row = i // n - i_share = i_col % SHARED_MEM_SIZE - - index_start = row_ptr[i_row] - end_border = row_ptr[i_row + 1] - n_share = (end_border - index_start) // SHARED_MEM_SIZE - n_last = end_border - index_start - n_share * SHARED_MEM_SIZE - - r = 0. - for k in range(n_share): - indices_sm[i_share] = col_indices[index_start + i_share + k * SHARED_MEM_SIZE] - ti.simt.block.sync() - for j in range(SHARED_MEM_SIZE): - r += matrix[indices_sm[j], i_col] - indices_sm[i_share] = col_indices[ti.min(index_start + i_share + n_share * SHARED_MEM_SIZE, l)] - ti.simt.block.sync() - for j in range(n_last): - r += matrix[indices_sm[j], i_col] - - # final results - out[i_row, i_col] += r - - -# no transpose homo -_csr_matmat_homo2_p = bm.XLACustomOp(gpu_kernel=_csr_matmat_homo2) - - -def taichi_csrmm2(weight, indices, indptr, matrix, shape, transpose): - indices = bm.as_jax(indices) - indptr = bm.as_jax(indptr) - matrix = bm.as_jax(matrix) - weight = jnp.atleast_1d(weight) - result_shape = (shape[1] if transpose else shape[0], matrix.shape[1]) - return _csr_matmat_homo2_p(indices, indptr, matrix, transpose=transpose, shape=shape, - outs=[jax.ShapeDtypeStruct(result_shape, dtype=matrix.dtype)])[0] - - -def jaxlib_csrmm(weight, indices, indptr, matrix, shape, transpose): - indices = bm.as_jax(indices) - indptr = bm.as_jax(indptr) - matrix = bm.as_jax(matrix) - weight = jnp.atleast_1d(weight) - return csr.csr_matmat_p.bind(weight, indices, indptr, matrix, shape=shape, transpose=transpose) - - -def generate_op(op): - def csrmm(weight, indices, indptr, matrix, shape, transpose): - r = 0 - for i in range(ITERATION): - t = op(weight, indices, indptr, matrix, shape=shape, transpose=transpose) - r += t - return r - - return jax.jit(csrmm, static_argnames=('shape', 'transpose')) - - -def run_spmm_homo(op, shape, transpose, use_heter_data=False): - bm.random.seed(1234) - matrix1_shape = (shape[1], shape[0]) if transpose else (shape[0], shape[1]) - matrix2_shape = (shape[1], shape[2]) - indices, indptr = bp.conn.FixedProb(SPARSITY, seed=1234, allow_multi_conn=True)(*matrix1_shape).require('pre2post') - matrix = bm.as_jax(bm.random.random(matrix2_shape)) - weight = 1. - if use_heter_data: - weight = bm.ones(indices.shape) * weight - - result = jax.block_until_ready(op(weight, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) - times = [] - for i in range(10): - time0 = time.time() - result = jax.block_until_ready(op(weight, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) - time1 = time.time() - times.append(time1 - time0) - return np.asarray(times).mean(), result - - -bm.clear_taichi_aot_caches() -for shape in size: - for _transpose in transpose: - cusparse_times, cusparse_r = run_spmm_homo(generate_op(jaxlib_csrmm), shape, _transpose, use_heter_data=True) - homo1_times, homo1_r = run_spmm_homo(generate_op(taichi_csrmm), shape, _transpose) - homo2_times, homo2_r = run_spmm_homo(generate_op(taichi_csrmm2), shape, _transpose) - print(jnp.allclose(cusparse_r, homo1_r), jnp.allclose(cusparse_r, homo2_r)) - print(f'shape={shape}, transpose={_transpose}, cusparse/homo1 = {cusparse_times / homo1_times}, ' - f'cusparse/homo2 = {cusparse_times / homo2_times}') - print(homo2_r) - -# def test_sparse_csrmm(shape, values_type, events_type, transpose): -# rng = bm.random.RandomState(seed=1234) -# matrix1_shape = (shape[1], shape[0]) if transpose else (shape[0], shape[1]) -# matrix2_shape = (shape[1], shape[2]) -# indices, indptr = bp.conn.FixedProb(SPARSITY, seed=1234, allow_multi_conn=True)(*matrix1_shape).require('pre2post') -# matrix = rng.random(matrix2_shape) -# matrix = bm.as_jax(matrix) -# weight = 1. -# -# heter_data = bm.ones(indices.shape) * weight -# -# if events_type == 'float': -# matrix = matrix.astype(bm.float32) -# # if values_type == 'heter': -# # weight = heter_data -# -# result = jax.block_until_ready( -# csrmm_taichi(weight, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) -# result = jax.block_until_ready( -# csrmm_taichi(weight, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) -# result = jax.block_until_ready( -# csrmm_taichi(weight, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) -# result = jax.block_until_ready( -# csrmm_taichi(weight, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) -# result = jax.block_until_ready( -# csrmm_taichi(weight, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) -# -# time0 = time.time() -# result = jax.block_until_ready( -# csrmm_taichi(weight, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) -# time1 = time.time() -# -# time2 = time.time() -# result = jax.block_until_ready( -# csrmm_taichi(weight, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) -# time3 = time.time() -# -# time4 = time.time() -# result = jax.block_until_ready( -# csrmm_taichi(weight, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) -# time5 = time.time() -# -# time6 = time.time() -# result = jax.block_until_ready( -# csrmm_taichi(weight, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) -# time7 = time.time() -# -# time8 = time.time() -# result = jax.block_until_ready( -# csrmm_taichi(weight, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) -# time9 = time.time() -# -# time10 = time.time() -# result = jax.block_until_ready( -# csrmm_taichi(weight, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) -# time11 = time.time() -# -# time12 = time.time() -# result = jax.block_until_ready( -# csrmm_taichi(weight, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) -# time13 = time.time() -# -# time14 = time.time() -# result = jax.block_until_ready( -# csrmm_taichi(weight, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) -# time15 = time.time() -# -# time16 = time.time() -# result = jax.block_until_ready( -# csrmm_taichi(weight, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) -# time17 = time.time() -# -# time18 = time.time() -# result = jax.block_until_ready( -# csrmm_taichi(weight, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) -# time19 = time.time() -# -# result1 = result -# -# result = jax.block_until_ready(csrmm(heter_data, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) -# result = jax.block_until_ready(csrmm(heter_data, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) -# result = jax.block_until_ready(csrmm(heter_data, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) -# result = jax.block_until_ready(csrmm(heter_data, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) -# result = jax.block_until_ready(csrmm(heter_data, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) -# -# time20 = time.time() -# result = jax.block_until_ready(csrmm(heter_data, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) -# time21 = time.time() -# -# result2 = result -# -# time22 = time.time() -# result = jax.block_until_ready(csrmm(heter_data, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) -# time23 = time.time() -# -# time24 = time.time() -# result = jax.block_until_ready(csrmm(heter_data, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) -# time25 = time.time() -# -# time26 = time.time() -# result = jax.block_until_ready(csrmm(heter_data, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) -# time27 = time.time() -# -# time28 = time.time() -# result = jax.block_until_ready(csrmm(heter_data, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) -# time29 = time.time() -# -# time30 = time.time() -# result = jax.block_until_ready(csrmm(heter_data, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) -# time31 = time.time() -# -# time32 = time.time() -# result = jax.block_until_ready(csrmm(heter_data, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) -# time33 = time.time() -# -# time34 = time.time() -# result = jax.block_until_ready(csrmm(heter_data, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) -# time35 = time.time() -# -# time36 = time.time() -# result = jax.block_until_ready(csrmm(heter_data, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) -# time37 = time.time() -# -# time38 = time.time() -# result = jax.block_until_ready(csrmm(heter_data, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) -# time39 = time.time() -# -# taichi_aot_time1 = (time1 - time0) * 1000 -# taichi_aot_time2 = (time3 - time2) * 1000 -# taichi_aot_time3 = (time5 - time4) * 1000 -# taichi_aot_time4 = (time7 - time6) * 1000 -# taichi_aot_time5 = (time9 - time8) * 1000 -# taichi_aot_time6 = (time11 - time10) * 1000 -# taichi_aot_time7 = (time13 - time12) * 1000 -# taichi_aot_time8 = (time15 - time14) * 1000 -# taichi_aot_time9 = (time17 - time16) * 1000 -# taichi_aot_time10 = (time19 - time18) * 1000 -# brainpy_time1 = (time21 - time20) * 1000 -# brainpy_time2 = (time23 - time22) * 1000 -# brainpy_time3 = (time25 - time24) * 1000 -# brainpy_time4 = (time27 - time26) * 1000 -# brainpy_time5 = (time29 - time28) * 1000 -# brainpy_time6 = (time31 - time30) * 1000 -# brainpy_time7 = (time33 - time32) * 1000 -# brainpy_time8 = (time35 - time34) * 1000 -# brainpy_time9 = (time37 - time36) * 1000 -# brainpy_time10 = (time39 - time38) * 1000 -# print('shape: ', shape, 'values_type: ', values_type, 'events_type: ', events_type, 'transpose: ', transpose) -# print('taichi_aot_1: ', taichi_aot_time1, 'ms') -# print('taichi_aot_3: ', taichi_aot_time3, 'ms') -# print('taichi_aot_5: ', taichi_aot_time5, 'ms') -# print('taichi_aot_7: ', taichi_aot_time7, 'ms') -# print('taichi_aot_9: ', taichi_aot_time9, 'ms') -# print('brainpylib_1: ', brainpy_time1, 'ms') -# print('brainpylib_3: ', brainpy_time3, 'ms') -# print('brainpylib_5: ', brainpy_time5, 'ms') -# print('brainpylib_7: ', brainpy_time7, 'ms') -# print('brainpylib_9: ', brainpy_time9, 'ms') -# print(bm.allclose(result1, result2)) -# -# return taichi_aot_time1, taichi_aot_time2, taichi_aot_time3, taichi_aot_time4, taichi_aot_time5, \ -# taichi_aot_time6, taichi_aot_time7, taichi_aot_time8, taichi_aot_time9, taichi_aot_time10, \ -# brainpy_time1, brainpy_time2, brainpy_time3, brainpy_time4, brainpy_time5, \ -# brainpy_time6, brainpy_time7, brainpy_time8, brainpy_time9, brainpy_time10 - -# PATH = os.path.dirname(os.path.abspath(__file__)) -# -# # init dataframe -# df = pd.DataFrame( -# columns=['s', 'p', 'shape[0]', 'shape[1]', 'shape[2]', 'backend', 'values type', 'events type', 'transpose', -# 'taichi aot time1(ms)', 'taichi aot time2(ms)', 'taichi aot time3(ms)', 'taichi aot time4(ms)', -# 'taichi aot time5(ms)', -# 'taichi aot time6(ms)', 'taichi aot time7(ms)', 'taichi aot time8(ms)', 'taichi aot time9(ms)', -# 'taichi aot time10(ms)', -# 'brainpy time1(ms)', 'brainpy time2(ms)', 'brainpy time3(ms)', 'brainpy time4(ms)', 'brainpy time5(ms)', -# 'brainpy time6(ms)', 'brainpy time7(ms)', 'brainpy time8(ms)', 'brainpy time9(ms)', 'brainpy time10(ms)']) -# -# for shape in size: -# for _values_type in values_type: -# for _events_type in events_type: -# for _transpose in transpose: -# taichi_aot_time_1, taichi_aot_time_2, taichi_aot_time_3, taichi_aot_time_4, taichi_aot_time_5, \ -# taichi_aot_time_6, taichi_aot_time_7, taichi_aot_time_8, taichi_aot_time_9, taichi_aot_time_10, \ -# brainpy_time_1, brainpy_time_2, brainpy_time_3, brainpy_time_4, brainpy_time_5, \ -# brainpy_time_6, brainpy_time_7, brainpy_time_8, brainpy_time_9, brainpy_time_10 = test_sparse_csrmm(shape, -# _values_type, -# _events_type, -# _transpose) -# # append to dataframe -# df.loc[df.shape[0]] = [shape, 0.5, shape[0], shape[1], shape[2], 'gpu', _values_type, _events_type, -# _transpose, -# taichi_aot_time_1, taichi_aot_time_2, taichi_aot_time_3, taichi_aot_time_4, -# taichi_aot_time_5, -# taichi_aot_time_6, taichi_aot_time_7, taichi_aot_time_8, taichi_aot_time_9, -# taichi_aot_time_10, -# brainpy_time_1, brainpy_time_2, brainpy_time_3, brainpy_time_4, brainpy_time_5, -# brainpy_time_6, brainpy_time_7, brainpy_time_8, brainpy_time_9, brainpy_time_10] -# -# print(shape, _values_type, _events_type, _transpose) -# a = np.asarray([taichi_aot_time_1, taichi_aot_time_2, taichi_aot_time_3, taichi_aot_time_4, -# taichi_aot_time_5, -# taichi_aot_time_6, taichi_aot_time_7, taichi_aot_time_8, taichi_aot_time_9, -# taichi_aot_time_10]) -# b = np.asarray([brainpy_time_1, brainpy_time_2, brainpy_time_3, brainpy_time_4, brainpy_time_5, -# brainpy_time_6, brainpy_time_7, brainpy_time_8, brainpy_time_9, brainpy_time_10]) -# print(a) -# print(b) -# print(a.sum() / b.sum()) -# df.to_csv(f'{PATH}/csrmm_{bm.get_platform()}.csv', index=False) diff --git a/brainpy/_src/math/sparse/tests/csr_matvec_VS_cusparse_csr_matvec.py b/brainpy/_src/math/sparse/tests/csr_matvec_VS_cusparse_csr_matvec.py deleted file mode 100644 index 8b4afe21e..000000000 --- a/brainpy/_src/math/sparse/tests/csr_matvec_VS_cusparse_csr_matvec.py +++ /dev/null @@ -1,668 +0,0 @@ - -# -*- coding: utf-8 -*- - -import time - -import brainpy as bp -import brainpy.math as bm -import numpy as np - -from brainpy._src.math.sparse import cusparse_bcsr_matvec -# from brainpy._src.math.sparse import cusparse_csr_matvec -from brainpy._src.math.sparse import csrmv -from scipy.sparse import csr_matrix - -def compare(platform='cpu'): - """ - - CPU - --- - - shape = (1000, 1000) - cuSPARSE 0.02663278579711914 s - brainpylib 0.028490781784057617 s - - shape = (1000, 10000) - cuSPARSE 0.06195855140686035 s - brainpylib 0.04008936882019043 s - - shape = (10000, 1000) - cuSPARSE 0.04706525802612305 s - brainpylib 0.04366803169250488 s - - shape = (10000, 10000) - cuSPARSE 0.1891341209411621 s - brainpylib 0.177717924118042 s - - shape = (100000, 10000) - cuSPARSE 1.3123579025268555 s - brainpylib 1.3357517719268799 s - - shape = (100000, 100000) - cuSPARSE 13.544525384902954 s - brainpylib 14.612009048461914 s - - - GPU - --- - shape = (1000, 1000) - cuSPARSE 0.04015922546386719 s - brainpylib 0.024152517318725586 s - - shape = (1000, 10000) - cuSPARSE 0.04857826232910156 s - brainpylib 0.15707015991210938 s - - shape = (10000, 1000) - cuSPARSE 0.04973483085632324 s - brainpylib 0.14293313026428223 s - - shape = (10000, 10000) - cuSPARSE 0.17399168014526367 s - brainpylib 0.17151856422424316 s - - shape = (100000, 10000) - cuSPARSE 0.5249958038330078 s - brainpylib 0.3427560329437256 s - - shape = (50000, 50000) - cuSPARSE 1.4121572971343994 s - brainpylib 0.9002335071563721 s - - shape = (100000, 50000) - cuSPARSE 2.697688341140747 s - brainpylib 1.6211459636688232 s - """ - - - bm.set_platform(platform) - - for shape in [ - (1000, 1000), - (1000, 10000), - (10000, 1000), - (10000, 10000), - (100000, 10000), - (50000, 50000), - (100000, 50000), - ]: - print(f'shape = {shape}') - - rng = bm.random.RandomState(123) - conn = bp.conn.FixedProb(0.1)(*shape) - indices, indptr = conn.require('pre2post') - indices = bm.as_jax(indices) - indptr = bm.as_jax(indptr) - data = rng.random(indices.shape).value - vector = rng.random(shape[1]).value - - r1 = bm.sparse.csrmv(data, indices, indptr, vector, shape=shape, method='cusparse') - r1.block_until_ready() - r2 = bm.sparse.csrmv(data, indices, indptr, vector, shape=shape, method='vector') - r2.block_until_ready() - - t0 = time.time() - for _ in range(100): - r1 = bm.sparse.csrmv(data, indices, indptr, vector, shape=shape, method='cusparse') - r1.block_until_ready() - print(f'cuSPARSE {time.time() - t0} s') - - t0 = time.time() - for _ in range(100): - r1 = bm.sparse.csrmv(data, indices, indptr, vector, shape=shape, method='vector') - r1.block_until_ready() - print(f'brainpylib {time.time() - t0} s') - print() - - - -def compare2(platform='cpu'): - """ - - CPU - --- - - shape = (1000, 1000) - cuSPARSE 0.02663278579711914 s - brainpylib 0.028490781784057617 s - - shape = (1000, 10000) - cuSPARSE 0.06195855140686035 s - brainpylib 0.04008936882019043 s - - shape = (10000, 1000) - cuSPARSE 0.04706525802612305 s - brainpylib 0.04366803169250488 s - - shape = (10000, 10000) - cuSPARSE 0.1891341209411621 s - brainpylib 0.177717924118042 s - - shape = (100000, 10000) - cuSPARSE 1.3123579025268555 s - brainpylib 1.3357517719268799 s - - shape = (100000, 100000) - cuSPARSE 13.544525384902954 s - brainpylib 14.612009048461914 s - - - GPU - --- - shape = (1000, 1000) - cuSPARSE 0.04015922546386719 s - brainpylib 0.024152517318725586 s - - shape = (1000, 10000) - cuSPARSE 0.04857826232910156 s - brainpylib 0.15707015991210938 s - - shape = (10000, 1000) - cuSPARSE 0.04973483085632324 s - brainpylib 0.14293313026428223 s - - shape = (10000, 10000) - cuSPARSE 0.17399168014526367 s - brainpylib 0.17151856422424316 s - - shape = (100000, 10000) - cuSPARSE 0.5249958038330078 s - brainpylib 0.3427560329437256 s - - shape = (50000, 50000) - cuSPARSE 1.4121572971343994 s - brainpylib 0.9002335071563721 s - - shape = (100000, 50000) - cuSPARSE 2.697688341140747 s - brainpylib 1.6211459636688232 s - """ - - bm.set_platform(platform) - p = 0.1 - - for shape in [ - (1000, 1000), - (1000, 10000), - (10000, 1000), - (10000, 10000), - (100000, 10000), - (50000, 50000), - (100000, 50000), - ]: - print(f'shape = {shape}') - - rng = bm.random.RandomState() - conn = bp.conn.FixedProb(p)(*shape) - indices, indptr = conn.require('pre2post') - data = rng.random(indices.shape) - vector = rng.random(shape[1]) - - - - - bs_bsr = 16 - conn = bp.conn.FixedProb(p)(shape[0] // bs_bsr , shape[1] // bs_bsr) - indices_bsr, indptr_bsr = conn.require('pre2post') - data_bsr = rng.rand(len(indices_bsr)*bs_bsr, bs_bsr ) - shape_bsr = (shape[0] // bs_bsr, shape[1] // bs_bsr) - - # Mcsr = csr_matrix((data, indices, indptr), shape=shape) - # Mbsr = Mcsr.tobsr(blocksize=(8,8)) - # bs_bsr = 8 - # indices_bsr = Mbsr.indices - # indptr_bsr = Mbsr.indptr - # data_bsr_2 = Mbsr.data - # data_bsr = list(np.array(data_bsr_2).flatten()) - # indices_bsr = bm.as_jax(indices_bsr) - # indptr_bsr = bm.as_jax(indptr_bsr) - # data_bsr = bm.as_jax(data_bsr) - # shape_bsr = (shape[0]//bs_bsr,shape[1]//bs_bsr) - - r3 = cusparse_bcsr_matvec(data_bsr, indices_bsr, indptr_bsr, vector, blocksize=bs_bsr,nnzb=len(indices_bsr), shape=shape_bsr) - r3.block_until_ready() - - # r1 = cusparse_csr_matvec(data, indices, indptr, vector, shape=shape) - # r1.block_until_ready() - r2 = csrmv(data, indices, indptr, vector, shape=shape) - r2.block_until_ready() - - # print(r1[980:1000]) - # print(r2[980:1000]) - # print(r3[900:1000]) - # print(len(indptr_bsr)) - # print(shape_bsr) - - t0 = time.time() - r3 = cusparse_bcsr_matvec(data_bsr, indices_bsr, indptr_bsr, vector, blocksize=bs_bsr,nnzb=len(indices_bsr), shape=shape_bsr) - r3.block_until_ready() - r3 = cusparse_bcsr_matvec(data_bsr, indices_bsr, indptr_bsr, vector, blocksize=bs_bsr,nnzb=len(indices_bsr), shape=shape_bsr) - r3.block_until_ready() - r3 = cusparse_bcsr_matvec(data_bsr, indices_bsr, indptr_bsr, vector, blocksize=bs_bsr,nnzb=len(indices_bsr), shape=shape_bsr) - r3.block_until_ready() - r3 = cusparse_bcsr_matvec(data_bsr, indices_bsr, indptr_bsr, vector, blocksize=bs_bsr,nnzb=len(indices_bsr), shape=shape_bsr) - r3.block_until_ready() - r3 = cusparse_bcsr_matvec(data_bsr, indices_bsr, indptr_bsr, vector, blocksize=bs_bsr,nnzb=len(indices_bsr), shape=shape_bsr) - r3.block_until_ready() - r3 = cusparse_bcsr_matvec(data_bsr, indices_bsr, indptr_bsr, vector, blocksize=bs_bsr,nnzb=len(indices_bsr), shape=shape_bsr) - r3.block_until_ready() - r3 = cusparse_bcsr_matvec(data_bsr, indices_bsr, indptr_bsr, vector, blocksize=bs_bsr,nnzb=len(indices_bsr), shape=shape_bsr) - r3.block_until_ready() - r3 = cusparse_bcsr_matvec(data_bsr, indices_bsr, indptr_bsr, vector, blocksize=bs_bsr,nnzb=len(indices_bsr), shape=shape_bsr) - r3.block_until_ready() - r3 = cusparse_bcsr_matvec(data_bsr, indices_bsr, indptr_bsr, vector, blocksize=bs_bsr,nnzb=len(indices_bsr), shape=shape_bsr) - r3.block_until_ready() - r3 = cusparse_bcsr_matvec(data_bsr, indices_bsr, indptr_bsr, vector, blocksize=bs_bsr,nnzb=len(indices_bsr), shape=shape_bsr) - r3.block_until_ready() - r3 = cusparse_bcsr_matvec(data_bsr, indices_bsr, indptr_bsr, vector, blocksize=bs_bsr,nnzb=len(indices_bsr), shape=shape_bsr) - r3.block_until_ready() - r3 = cusparse_bcsr_matvec(data_bsr, indices_bsr, indptr_bsr, vector, blocksize=bs_bsr,nnzb=len(indices_bsr), shape=shape_bsr) - r3.block_until_ready() - r3 = cusparse_bcsr_matvec(data_bsr, indices_bsr, indptr_bsr, vector, blocksize=bs_bsr,nnzb=len(indices_bsr), shape=shape_bsr) - r3.block_until_ready() - r3 = cusparse_bcsr_matvec(data_bsr, indices_bsr, indptr_bsr, vector, blocksize=bs_bsr,nnzb=len(indices_bsr), shape=shape_bsr) - r3.block_until_ready() - r3 = cusparse_bcsr_matvec(data_bsr, indices_bsr, indptr_bsr, vector, blocksize=bs_bsr,nnzb=len(indices_bsr), shape=shape_bsr) - r3.block_until_ready() - r3 = cusparse_bcsr_matvec(data_bsr, indices_bsr, indptr_bsr, vector, blocksize=bs_bsr,nnzb=len(indices_bsr), shape=shape_bsr) - r3.block_until_ready() - r3 = cusparse_bcsr_matvec(data_bsr, indices_bsr, indptr_bsr, vector, blocksize=bs_bsr,nnzb=len(indices_bsr), shape=shape_bsr) - r3.block_until_ready() - r3 = cusparse_bcsr_matvec(data_bsr, indices_bsr, indptr_bsr, vector, blocksize=bs_bsr,nnzb=len(indices_bsr), shape=shape_bsr) - r3.block_until_ready() - r3 = cusparse_bcsr_matvec(data_bsr, indices_bsr, indptr_bsr, vector, blocksize=bs_bsr,nnzb=len(indices_bsr), shape=shape_bsr) - r3.block_until_ready() - r3 = cusparse_bcsr_matvec(data_bsr, indices_bsr, indptr_bsr, vector, blocksize=bs_bsr,nnzb=len(indices_bsr), shape=shape_bsr) - r3.block_until_ready() - r3 = cusparse_bcsr_matvec(data_bsr, indices_bsr, indptr_bsr, vector, blocksize=bs_bsr,nnzb=len(indices_bsr), shape=shape_bsr) - r3.block_until_ready() - r3 = cusparse_bcsr_matvec(data_bsr, indices_bsr, indptr_bsr, vector, blocksize=bs_bsr,nnzb=len(indices_bsr), shape=shape_bsr) - r3.block_until_ready() - r3 = cusparse_bcsr_matvec(data_bsr, indices_bsr, indptr_bsr, vector, blocksize=bs_bsr,nnzb=len(indices_bsr), shape=shape_bsr) - r3.block_until_ready() - r3 = cusparse_bcsr_matvec(data_bsr, indices_bsr, indptr_bsr, vector, blocksize=bs_bsr,nnzb=len(indices_bsr), shape=shape_bsr) - r3.block_until_ready() - r3 = cusparse_bcsr_matvec(data_bsr, indices_bsr, indptr_bsr, vector, blocksize=bs_bsr,nnzb=len(indices_bsr), shape=shape_bsr) - r3.block_until_ready() - r3 = cusparse_bcsr_matvec(data_bsr, indices_bsr, indptr_bsr, vector, blocksize=bs_bsr,nnzb=len(indices_bsr), shape=shape_bsr) - r3.block_until_ready() - r3 = cusparse_bcsr_matvec(data_bsr, indices_bsr, indptr_bsr, vector, blocksize=bs_bsr,nnzb=len(indices_bsr), shape=shape_bsr) - r3.block_until_ready() - r3 = cusparse_bcsr_matvec(data_bsr, indices_bsr, indptr_bsr, vector, blocksize=bs_bsr,nnzb=len(indices_bsr), shape=shape_bsr) - r3.block_until_ready() - r3 = cusparse_bcsr_matvec(data_bsr, indices_bsr, indptr_bsr, vector, blocksize=bs_bsr,nnzb=len(indices_bsr), shape=shape_bsr) - r3.block_until_ready() - r3 = cusparse_bcsr_matvec(data_bsr, indices_bsr, indptr_bsr, vector, blocksize=bs_bsr,nnzb=len(indices_bsr), shape=shape_bsr) - r3.block_until_ready() - r3 = cusparse_bcsr_matvec(data_bsr, indices_bsr, indptr_bsr, vector, blocksize=bs_bsr,nnzb=len(indices_bsr), shape=shape_bsr) - r3.block_until_ready() - r3 = cusparse_bcsr_matvec(data_bsr, indices_bsr, indptr_bsr, vector, blocksize=bs_bsr,nnzb=len(indices_bsr), shape=shape_bsr) - r3.block_until_ready() - r3 = cusparse_bcsr_matvec(data_bsr, indices_bsr, indptr_bsr, vector, blocksize=bs_bsr,nnzb=len(indices_bsr), shape=shape_bsr) - r3.block_until_ready() - r3 = cusparse_bcsr_matvec(data_bsr, indices_bsr, indptr_bsr, vector, blocksize=bs_bsr,nnzb=len(indices_bsr), shape=shape_bsr) - r3.block_until_ready() - r3 = cusparse_bcsr_matvec(data_bsr, indices_bsr, indptr_bsr, vector, blocksize=bs_bsr,nnzb=len(indices_bsr), shape=shape_bsr) - r3.block_until_ready() - r3 = cusparse_bcsr_matvec(data_bsr, indices_bsr, indptr_bsr, vector, blocksize=bs_bsr,nnzb=len(indices_bsr), shape=shape_bsr) - r3.block_until_ready() - r3 = cusparse_bcsr_matvec(data_bsr, indices_bsr, indptr_bsr, vector, blocksize=bs_bsr,nnzb=len(indices_bsr), shape=shape_bsr) - r3.block_until_ready() - r3 = cusparse_bcsr_matvec(data_bsr, indices_bsr, indptr_bsr, vector, blocksize=bs_bsr,nnzb=len(indices_bsr), shape=shape_bsr) - r3.block_until_ready() - r3 = cusparse_bcsr_matvec(data_bsr, indices_bsr, indptr_bsr, vector, blocksize=bs_bsr,nnzb=len(indices_bsr), shape=shape_bsr) - r3.block_until_ready() - r3 = cusparse_bcsr_matvec(data_bsr, indices_bsr, indptr_bsr, vector, blocksize=bs_bsr,nnzb=len(indices_bsr), shape=shape_bsr) - r3.block_until_ready() - r3 = cusparse_bcsr_matvec(data_bsr, indices_bsr, indptr_bsr, vector, blocksize=bs_bsr,nnzb=len(indices_bsr), shape=shape_bsr) - r3.block_until_ready() - r3 = cusparse_bcsr_matvec(data_bsr, indices_bsr, indptr_bsr, vector, blocksize=bs_bsr,nnzb=len(indices_bsr), shape=shape_bsr) - r3.block_until_ready() - r3 = cusparse_bcsr_matvec(data_bsr, indices_bsr, indptr_bsr, vector, blocksize=bs_bsr,nnzb=len(indices_bsr), shape=shape_bsr) - r3.block_until_ready() - r3 = cusparse_bcsr_matvec(data_bsr, indices_bsr, indptr_bsr, vector, blocksize=bs_bsr,nnzb=len(indices_bsr), shape=shape_bsr) - r3.block_until_ready() - r3 = cusparse_bcsr_matvec(data_bsr, indices_bsr, indptr_bsr, vector, blocksize=bs_bsr,nnzb=len(indices_bsr), shape=shape_bsr) - r3.block_until_ready() - r3 = cusparse_bcsr_matvec(data_bsr, indices_bsr, indptr_bsr, vector, blocksize=bs_bsr,nnzb=len(indices_bsr), shape=shape_bsr) - r3.block_until_ready() - r3 = cusparse_bcsr_matvec(data_bsr, indices_bsr, indptr_bsr, vector, blocksize=bs_bsr,nnzb=len(indices_bsr), shape=shape_bsr) - r3.block_until_ready() - r3 = cusparse_bcsr_matvec(data_bsr, indices_bsr, indptr_bsr, vector, blocksize=bs_bsr,nnzb=len(indices_bsr), shape=shape_bsr) - r3.block_until_ready() - r3 = cusparse_bcsr_matvec(data_bsr, indices_bsr, indptr_bsr, vector, blocksize=bs_bsr,nnzb=len(indices_bsr), shape=shape_bsr) - r3.block_until_ready() - r3 = cusparse_bcsr_matvec(data_bsr, indices_bsr, indptr_bsr, vector, blocksize=bs_bsr,nnzb=len(indices_bsr), shape=shape_bsr) - r3.block_until_ready() - r3 = cusparse_bcsr_matvec(data_bsr, indices_bsr, indptr_bsr, vector, blocksize=bs_bsr,nnzb=len(indices_bsr), shape=shape_bsr) - r3.block_until_ready() - r3 = cusparse_bcsr_matvec(data_bsr, indices_bsr, indptr_bsr, vector, blocksize=bs_bsr,nnzb=len(indices_bsr), shape=shape_bsr) - r3.block_until_ready() - r3 = cusparse_bcsr_matvec(data_bsr, indices_bsr, indptr_bsr, vector, blocksize=bs_bsr,nnzb=len(indices_bsr), shape=shape_bsr) - r3.block_until_ready() - r3 = cusparse_bcsr_matvec(data_bsr, indices_bsr, indptr_bsr, vector, blocksize=bs_bsr,nnzb=len(indices_bsr), shape=shape_bsr) - r3.block_until_ready() - r3 = cusparse_bcsr_matvec(data_bsr, indices_bsr, indptr_bsr, vector, blocksize=bs_bsr,nnzb=len(indices_bsr), shape=shape_bsr) - r3.block_until_ready() - r3 = cusparse_bcsr_matvec(data_bsr, indices_bsr, indptr_bsr, vector, blocksize=bs_bsr,nnzb=len(indices_bsr), shape=shape_bsr) - r3.block_until_ready() - r3 = cusparse_bcsr_matvec(data_bsr, indices_bsr, indptr_bsr, vector, blocksize=bs_bsr,nnzb=len(indices_bsr), shape=shape_bsr) - r3.block_until_ready() - r3 = cusparse_bcsr_matvec(data_bsr, indices_bsr, indptr_bsr, vector, blocksize=bs_bsr,nnzb=len(indices_bsr), shape=shape_bsr) - r3.block_until_ready() - r3 = cusparse_bcsr_matvec(data_bsr, indices_bsr, indptr_bsr, vector, blocksize=bs_bsr,nnzb=len(indices_bsr), shape=shape_bsr) - r3.block_until_ready() - r3 = cusparse_bcsr_matvec(data_bsr, indices_bsr, indptr_bsr, vector, blocksize=bs_bsr,nnzb=len(indices_bsr), shape=shape_bsr) - r3.block_until_ready() - r3 = cusparse_bcsr_matvec(data_bsr, indices_bsr, indptr_bsr, vector, blocksize=bs_bsr,nnzb=len(indices_bsr), shape=shape_bsr) - r3.block_until_ready() - r3 = cusparse_bcsr_matvec(data_bsr, indices_bsr, indptr_bsr, vector, blocksize=bs_bsr,nnzb=len(indices_bsr), shape=shape_bsr) - r3.block_until_ready() - r3 = cusparse_bcsr_matvec(data_bsr, indices_bsr, indptr_bsr, vector, blocksize=bs_bsr,nnzb=len(indices_bsr), shape=shape_bsr) - r3.block_until_ready() - r3 = cusparse_bcsr_matvec(data_bsr, indices_bsr, indptr_bsr, vector, blocksize=bs_bsr,nnzb=len(indices_bsr), shape=shape_bsr) - r3.block_until_ready() - r3 = cusparse_bcsr_matvec(data_bsr, indices_bsr, indptr_bsr, vector, blocksize=bs_bsr,nnzb=len(indices_bsr), shape=shape_bsr) - r3.block_until_ready() - r3 = cusparse_bcsr_matvec(data_bsr, indices_bsr, indptr_bsr, vector, blocksize=bs_bsr,nnzb=len(indices_bsr), shape=shape_bsr) - r3.block_until_ready() - r3 = cusparse_bcsr_matvec(data_bsr, indices_bsr, indptr_bsr, vector, blocksize=bs_bsr,nnzb=len(indices_bsr), shape=shape_bsr) - r3.block_until_ready() - r3 = cusparse_bcsr_matvec(data_bsr, indices_bsr, indptr_bsr, vector, blocksize=bs_bsr,nnzb=len(indices_bsr), shape=shape_bsr) - r3.block_until_ready() - r3 = cusparse_bcsr_matvec(data_bsr, indices_bsr, indptr_bsr, vector, blocksize=bs_bsr,nnzb=len(indices_bsr), shape=shape_bsr) - r3.block_until_ready() - r3 = cusparse_bcsr_matvec(data_bsr, indices_bsr, indptr_bsr, vector, blocksize=bs_bsr,nnzb=len(indices_bsr), shape=shape_bsr) - r3.block_until_ready() - r3 = cusparse_bcsr_matvec(data_bsr, indices_bsr, indptr_bsr, vector, blocksize=bs_bsr,nnzb=len(indices_bsr), shape=shape_bsr) - r3.block_until_ready() - r3 = cusparse_bcsr_matvec(data_bsr, indices_bsr, indptr_bsr, vector, blocksize=bs_bsr,nnzb=len(indices_bsr), shape=shape_bsr) - r3.block_until_ready() - r3 = cusparse_bcsr_matvec(data_bsr, indices_bsr, indptr_bsr, vector, blocksize=bs_bsr,nnzb=len(indices_bsr), shape=shape_bsr) - r3.block_until_ready() - r3 = cusparse_bcsr_matvec(data_bsr, indices_bsr, indptr_bsr, vector, blocksize=bs_bsr,nnzb=len(indices_bsr), shape=shape_bsr) - r3.block_until_ready() - r3 = cusparse_bcsr_matvec(data_bsr, indices_bsr, indptr_bsr, vector, blocksize=bs_bsr,nnzb=len(indices_bsr), shape=shape_bsr) - r3.block_until_ready() - r3 = cusparse_bcsr_matvec(data_bsr, indices_bsr, indptr_bsr, vector, blocksize=bs_bsr,nnzb=len(indices_bsr), shape=shape_bsr) - r3.block_until_ready() - r3 = cusparse_bcsr_matvec(data_bsr, indices_bsr, indptr_bsr, vector, blocksize=bs_bsr,nnzb=len(indices_bsr), shape=shape_bsr) - r3.block_until_ready() - r3 = cusparse_bcsr_matvec(data_bsr, indices_bsr, indptr_bsr, vector, blocksize=bs_bsr,nnzb=len(indices_bsr), shape=shape_bsr) - r3.block_until_ready() - r3 = cusparse_bcsr_matvec(data_bsr, indices_bsr, indptr_bsr, vector, blocksize=bs_bsr,nnzb=len(indices_bsr), shape=shape_bsr) - r3.block_until_ready() - r3 = cusparse_bcsr_matvec(data_bsr, indices_bsr, indptr_bsr, vector, blocksize=bs_bsr,nnzb=len(indices_bsr), shape=shape_bsr) - r3.block_until_ready() - r3 = cusparse_bcsr_matvec(data_bsr, indices_bsr, indptr_bsr, vector, blocksize=bs_bsr,nnzb=len(indices_bsr), shape=shape_bsr) - r3.block_until_ready() - r3 = cusparse_bcsr_matvec(data_bsr, indices_bsr, indptr_bsr, vector, blocksize=bs_bsr,nnzb=len(indices_bsr), shape=shape_bsr) - r3.block_until_ready() - r3 = cusparse_bcsr_matvec(data_bsr, indices_bsr, indptr_bsr, vector, blocksize=bs_bsr,nnzb=len(indices_bsr), shape=shape_bsr) - r3.block_until_ready() - r3 = cusparse_bcsr_matvec(data_bsr, indices_bsr, indptr_bsr, vector, blocksize=bs_bsr,nnzb=len(indices_bsr), shape=shape_bsr) - r3.block_until_ready() - r3 = cusparse_bcsr_matvec(data_bsr, indices_bsr, indptr_bsr, vector, blocksize=bs_bsr,nnzb=len(indices_bsr), shape=shape_bsr) - r3.block_until_ready() - r3 = cusparse_bcsr_matvec(data_bsr, indices_bsr, indptr_bsr, vector, blocksize=bs_bsr,nnzb=len(indices_bsr), shape=shape_bsr) - r3.block_until_ready() - r3 = cusparse_bcsr_matvec(data_bsr, indices_bsr, indptr_bsr, vector, blocksize=bs_bsr,nnzb=len(indices_bsr), shape=shape_bsr) - r3.block_until_ready() - r3 = cusparse_bcsr_matvec(data_bsr, indices_bsr, indptr_bsr, vector, blocksize=bs_bsr,nnzb=len(indices_bsr), shape=shape_bsr) - r3.block_until_ready() - r3 = cusparse_bcsr_matvec(data_bsr, indices_bsr, indptr_bsr, vector, blocksize=bs_bsr,nnzb=len(indices_bsr), shape=shape_bsr) - r3.block_until_ready() - r3 = cusparse_bcsr_matvec(data_bsr, indices_bsr, indptr_bsr, vector, blocksize=bs_bsr,nnzb=len(indices_bsr), shape=shape_bsr) - r3.block_until_ready() - r3 = cusparse_bcsr_matvec(data_bsr, indices_bsr, indptr_bsr, vector, blocksize=bs_bsr,nnzb=len(indices_bsr), shape=shape_bsr) - r3.block_until_ready() - r3 = cusparse_bcsr_matvec(data_bsr, indices_bsr, indptr_bsr, vector, blocksize=bs_bsr,nnzb=len(indices_bsr), shape=shape_bsr) - r3.block_until_ready() - r3 = cusparse_bcsr_matvec(data_bsr, indices_bsr, indptr_bsr, vector, blocksize=bs_bsr,nnzb=len(indices_bsr), shape=shape_bsr) - r3.block_until_ready() - r3 = cusparse_bcsr_matvec(data_bsr, indices_bsr, indptr_bsr, vector, blocksize=bs_bsr,nnzb=len(indices_bsr), shape=shape_bsr) - r3.block_until_ready() - r3 = cusparse_bcsr_matvec(data_bsr, indices_bsr, indptr_bsr, vector, blocksize=bs_bsr,nnzb=len(indices_bsr), shape=shape_bsr) - r3.block_until_ready() - r3 = cusparse_bcsr_matvec(data_bsr, indices_bsr, indptr_bsr, vector, blocksize=bs_bsr,nnzb=len(indices_bsr), shape=shape_bsr) - r3.block_until_ready() - r3 = cusparse_bcsr_matvec(data_bsr, indices_bsr, indptr_bsr, vector, blocksize=bs_bsr,nnzb=len(indices_bsr), shape=shape_bsr) - r3.block_until_ready() - r3 = cusparse_bcsr_matvec(data_bsr, indices_bsr, indptr_bsr, vector, blocksize=bs_bsr,nnzb=len(indices_bsr), shape=shape_bsr) - r3.block_until_ready() - r3 = cusparse_bcsr_matvec(data_bsr, indices_bsr, indptr_bsr, vector, blocksize=bs_bsr,nnzb=len(indices_bsr), shape=shape_bsr) - r3.block_until_ready() - r3 = cusparse_bcsr_matvec(data_bsr, indices_bsr, indptr_bsr, vector, blocksize=bs_bsr,nnzb=len(indices_bsr), shape=shape_bsr) - r3.block_until_ready() - print(f'bsrSPARSE {time.time() - t0} s') - - # t0 = time.time() - # for _ in range(100): - # r3 = cusparse_bcsr_matvec(data_bsr, indices_bsr, indptr_bsr, vector, blocksize=bs_bsr,nnzb=len(indices_bsr), shape=shape) - # r3.block_until_ready() - # print(f'bsrSPARSE {time.time() - t0} s') - - - # t0 = time.time() - # r1 = cusparse_csr_matvec(data, indices, indptr, vector, shape=shape) - # r1.block_until_ready() - # r1 = cusparse_csr_matvec(data, indices, indptr, vector, shape=shape) - # r1.block_until_ready() - # r1 = cusparse_csr_matvec(data, indices, indptr, vector, shape=shape) - # r1.block_until_ready() - # r1 = cusparse_csr_matvec(data, indices, indptr, vector, shape=shape) - # r1.block_until_ready() - # r1 = cusparse_csr_matvec(data, indices, indptr, vector, shape=shape) - # r1.block_until_ready() - # r1 = cusparse_csr_matvec(data, indices, indptr, vector, shape=shape) - # r1.block_until_ready() - # r1 = cusparse_csr_matvec(data, indices, indptr, vector, shape=shape) - # r1.block_until_ready() - # r1 = cusparse_csr_matvec(data, indices, indptr, vector, shape=shape) - # r1.block_until_ready() - # r1 = cusparse_csr_matvec(data, indices, indptr, vector, shape=shape) - # r1.block_until_ready() - # r1 = cusparse_csr_matvec(data, indices, indptr, vector, shape=shape) - # r1.block_until_ready() - # r1 = cusparse_csr_matvec(data, indices, indptr, vector, shape=shape) - # r1.block_until_ready() - # r1 = cusparse_csr_matvec(data, indices, indptr, vector, shape=shape) - # r1.block_until_ready() - # r1 = cusparse_csr_matvec(data, indices, indptr, vector, shape=shape) - # r1.block_until_ready() - # r1 = cusparse_csr_matvec(data, indices, indptr, vector, shape=shape) - # r1.block_until_ready() - # r1 = cusparse_csr_matvec(data, indices, indptr, vector, shape=shape) - # r1.block_until_ready() - # r1 = cusparse_csr_matvec(data, indices, indptr, vector, shape=shape) - # r1.block_until_ready() - # r1 = cusparse_csr_matvec(data, indices, indptr, vector, shape=shape) - # r1.block_until_ready() - # r1 = cusparse_csr_matvec(data, indices, indptr, vector, shape=shape) - # r1.block_until_ready() - # r1 = cusparse_csr_matvec(data, indices, indptr, vector, shape=shape) - # r1.block_until_ready() - # r1 = cusparse_csr_matvec(data, indices, indptr, vector, shape=shape) - # r1.block_until_ready() - # r1 = cusparse_csr_matvec(data, indices, indptr, vector, shape=shape) - # r1.block_until_ready() - # r1 = cusparse_csr_matvec(data, indices, indptr, vector, shape=shape) - # r1.block_until_ready() - # r1 = cusparse_csr_matvec(data, indices, indptr, vector, shape=shape) - # r1.block_until_ready() - # r1 = cusparse_csr_matvec(data, indices, indptr, vector, shape=shape) - # r1.block_until_ready() - # r1 = cusparse_csr_matvec(data, indices, indptr, vector, shape=shape) - # r1.block_until_ready() - # r1 = cusparse_csr_matvec(data, indices, indptr, vector, shape=shape) - # r1.block_until_ready() - # r1 = cusparse_csr_matvec(data, indices, indptr, vector, shape=shape) - # r1.block_until_ready() - # r1 = cusparse_csr_matvec(data, indices, indptr, vector, shape=shape) - # r1.block_until_ready() - # r1 = cusparse_csr_matvec(data, indices, indptr, vector, shape=shape) - # r1.block_until_ready() - # r1 = cusparse_csr_matvec(data, indices, indptr, vector, shape=shape) - # r1.block_until_ready() - # r1 = cusparse_csr_matvec(data, indices, indptr, vector, shape=shape) - # r1.block_until_ready() - # r1 = cusparse_csr_matvec(data, indices, indptr, vector, shape=shape) - # r1.block_until_ready() - # r1 = cusparse_csr_matvec(data, indices, indptr, vector, shape=shape) - # r1.block_until_ready() - # r1 = cusparse_csr_matvec(data, indices, indptr, vector, shape=shape) - # r1.block_until_ready() - # r1 = cusparse_csr_matvec(data, indices, indptr, vector, shape=shape) - # r1.block_until_ready() - # r1 = cusparse_csr_matvec(data, indices, indptr, vector, shape=shape) - # r1.block_until_ready() - # r1 = cusparse_csr_matvec(data, indices, indptr, vector, shape=shape) - # r1.block_until_ready() - # r1 = cusparse_csr_matvec(data, indices, indptr, vector, shape=shape) - # r1.block_until_ready() - # r1 = cusparse_csr_matvec(data, indices, indptr, vector, shape=shape) - # r1.block_until_ready() - # r1 = cusparse_csr_matvec(data, indices, indptr, vector, shape=shape) - # r1.block_until_ready() - # r1 = cusparse_csr_matvec(data, indices, indptr, vector, shape=shape) - # r1.block_until_ready() - # r1 = cusparse_csr_matvec(data, indices, indptr, vector, shape=shape) - # r1.block_until_ready() - # r1 = cusparse_csr_matvec(data, indices, indptr, vector, shape=shape) - # r1.block_until_ready() - # r1 = cusparse_csr_matvec(data, indices, indptr, vector, shape=shape) - # r1.block_until_ready() - # r1 = cusparse_csr_matvec(data, indices, indptr, vector, shape=shape) - # r1.block_until_ready() - # r1 = cusparse_csr_matvec(data, indices, indptr, vector, shape=shape) - # r1.block_until_ready() - # r1 = cusparse_csr_matvec(data, indices, indptr, vector, shape=shape) - # r1.block_until_ready() - # r1 = cusparse_csr_matvec(data, indices, indptr, vector, shape=shape) - # r1.block_until_ready() - # r1 = cusparse_csr_matvec(data, indices, indptr, vector, shape=shape) - # r1.block_until_ready() - # r1 = cusparse_csr_matvec(data, indices, indptr, vector, shape=shape) - # r1.block_until_ready() - # r1 = cusparse_csr_matvec(data, indices, indptr, vector, shape=shape) - # r1.block_until_ready() - # r1 = cusparse_csr_matvec(data, indices, indptr, vector, shape=shape) - # r1.block_until_ready() - # r1 = cusparse_csr_matvec(data, indices, indptr, vector, shape=shape) - # r1.block_until_ready() - # r1 = cusparse_csr_matvec(data, indices, indptr, vector, shape=shape) - # r1.block_until_ready() - # r1 = cusparse_csr_matvec(data, indices, indptr, vector, shape=shape) - # r1.block_until_ready() - # r1 = cusparse_csr_matvec(data, indices, indptr, vector, shape=shape) - # r1.block_until_ready() - # r1 = cusparse_csr_matvec(data, indices, indptr, vector, shape=shape) - # r1.block_until_ready() - # r1 = cusparse_csr_matvec(data, indices, indptr, vector, shape=shape) - # r1.block_until_ready() - # r1 = cusparse_csr_matvec(data, indices, indptr, vector, shape=shape) - # r1.block_until_ready() - # r1 = cusparse_csr_matvec(data, indices, indptr, vector, shape=shape) - # r1.block_until_ready() - # r1 = cusparse_csr_matvec(data, indices, indptr, vector, shape=shape) - # r1.block_until_ready() - # r1 = cusparse_csr_matvec(data, indices, indptr, vector, shape=shape) - # r1.block_until_ready() - # r1 = cusparse_csr_matvec(data, indices, indptr, vector, shape=shape) - # r1.block_until_ready() - # r1 = cusparse_csr_matvec(data, indices, indptr, vector, shape=shape) - # r1.block_until_ready() - # r1 = cusparse_csr_matvec(data, indices, indptr, vector, shape=shape) - # r1.block_until_ready() - # r1 = cusparse_csr_matvec(data, indices, indptr, vector, shape=shape) - # r1.block_until_ready() - # r1 = cusparse_csr_matvec(data, indices, indptr, vector, shape=shape) - # r1.block_until_ready() - # r1 = cusparse_csr_matvec(data, indices, indptr, vector, shape=shape) - # r1.block_until_ready() - # r1 = cusparse_csr_matvec(data, indices, indptr, vector, shape=shape) - # r1.block_until_ready() - # r1 = cusparse_csr_matvec(data, indices, indptr, vector, shape=shape) - # r1.block_until_ready() - # r1 = cusparse_csr_matvec(data, indices, indptr, vector, shape=shape) - # r1.block_until_ready() - # r1 = cusparse_csr_matvec(data, indices, indptr, vector, shape=shape) - # r1.block_until_ready() - # r1 = cusparse_csr_matvec(data, indices, indptr, vector, shape=shape) - # r1.block_until_ready() - # r1 = cusparse_csr_matvec(data, indices, indptr, vector, shape=shape) - # r1.block_until_ready() - # r1 = cusparse_csr_matvec(data, indices, indptr, vector, shape=shape) - # r1.block_until_ready() - # r1 = cusparse_csr_matvec(data, indices, indptr, vector, shape=shape) - # r1.block_until_ready() - # r1 = cusparse_csr_matvec(data, indices, indptr, vector, shape=shape) - # r1.block_until_ready() - # r1 = cusparse_csr_matvec(data, indices, indptr, vector, shape=shape) - # r1.block_until_ready() - # r1 = cusparse_csr_matvec(data, indices, indptr, vector, shape=shape) - # r1.block_until_ready() - # r1 = cusparse_csr_matvec(data, indices, indptr, vector, shape=shape) - # r1.block_until_ready() - # r1 = cusparse_csr_matvec(data, indices, indptr, vector, shape=shape) - # r1.block_until_ready() - # r1 = cusparse_csr_matvec(data, indices, indptr, vector, shape=shape) - # r1.block_until_ready() - # r1 = cusparse_csr_matvec(data, indices, indptr, vector, shape=shape) - # r1.block_until_ready() - # r1 = cusparse_csr_matvec(data, indices, indptr, vector, shape=shape) - # r1.block_until_ready() - # r1 = cusparse_csr_matvec(data, indices, indptr, vector, shape=shape) - # r1.block_until_ready() - # r1 = cusparse_csr_matvec(data, indices, indptr, vector, shape=shape) - # r1.block_until_ready() - # r1 = cusparse_csr_matvec(data, indices, indptr, vector, shape=shape) - # r1.block_until_ready() - # r1 = cusparse_csr_matvec(data, indices, indptr, vector, shape=shape) - # r1.block_until_ready() - # r1 = cusparse_csr_matvec(data, indices, indptr, vector, shape=shape) - # r1.block_until_ready() - # r1 = cusparse_csr_matvec(data, indices, indptr, vector, shape=shape) - # r1.block_until_ready() - # r1 = cusparse_csr_matvec(data, indices, indptr, vector, shape=shape) - # r1.block_until_ready() - # r1 = cusparse_csr_matvec(data, indices, indptr, vector, shape=shape) - # r1.block_until_ready() - # r1 = cusparse_csr_matvec(data, indices, indptr, vector, shape=shape) - # r1.block_until_ready() - # r1 = cusparse_csr_matvec(data, indices, indptr, vector, shape=shape) - # r1.block_until_ready() - # r1 = cusparse_csr_matvec(data, indices, indptr, vector, shape=shape) - # r1.block_until_ready() - # r1 = cusparse_csr_matvec(data, indices, indptr, vector, shape=shape) - # r1.block_until_ready() - # r1 = cusparse_csr_matvec(data, indices, indptr, vector, shape=shape) - # r1.block_until_ready() - # r1 = cusparse_csr_matvec(data, indices, indptr, vector, shape=shape) - # r1.block_until_ready() - # r1 = cusparse_csr_matvec(data, indices, indptr, vector, shape=shape) - # r1.block_until_ready() - # r1 = cusparse_csr_matvec(data, indices, indptr, vector, shape=shape) - # r1.block_until_ready() - # r1 = cusparse_csr_matvec(data, indices, indptr, vector, shape=shape) - # r1.block_until_ready() - # print(f'cuSPARSE {time.time() - t0} s') - # t0 = time.time() - # for _ in range(100): - # r1 = cusparse_csr_matvec(data, indices, indptr, vector, shape=shape) - # r1.block_until_ready() - # print(f'cuSPARSE {time.time() - t0} s') - - t0 = time.time() - for _ in range(100): - r1 = csrmv(data, indices, indptr, vector, shape=shape) - r1.block_until_ready() - print(f'brainpylib {time.time() - t0} s') - print() - - bm.clear_buffer_memory() - - -if __name__ == '__main__': - compare('cpu') - # compare('gpu') diff --git a/brainpy/_src/math/sparse/tests/csrmv_taichi_VS_csrmv.py b/brainpy/_src/math/sparse/tests/csrmv_taichi_VS_csrmv.py deleted file mode 100644 index 1db246212..000000000 --- a/brainpy/_src/math/sparse/tests/csrmv_taichi_VS_csrmv.py +++ /dev/null @@ -1,250 +0,0 @@ -# from jax_taichi import jax_taichi_call - -import time -from functools import partial -import os - -import brainpy as bp -import brainpy.math as bm -import jax -import jax.numpy as jnp -import numpy as np -import pandas as pd -import taichi as ti - -bm.set_platform('cpu') - -s = [1000, 5000, 10000, 15000, 20000, 25000, 30000] -p = [0.1, 0.2, 0.3, 0.4, 0.5] - -shape = [ - 1000, - 2500, - 5000, - 10000, - 25000, - 37500, - 50000 -] - -values_type = [ - 'homo', - 'heter' - ] -events_type = ['float'] -transpose = [ - True, - False - ] -method = 'cusparse' - -ITERATION = 100 -if bm.get_platform() == 'cpu': - ITERATION = 10 - -print(bm.get_platform()) - -@partial(jax.jit, static_argnums=(4, 5)) -def csrmv_taichi(weight, indices, indptr, vector, shape, transpose): - r = 0 - for i in range(ITERATION): - r += bm.sparse.csrmv_taichi(weight, indices, indptr, vector, shape=shape, transpose=transpose)[0] - return r - -@partial(jax.jit, static_argnums=(4, 5)) -def csrmv(weight, indices, indptr, vector, shape, transpose): - r = 0 - for i in range(ITERATION): - r += bm.sparse.csrmv(weight, indices, indptr, vector, shape=shape, transpose=transpose) - return r - -def test_sparse_csrmv(shape, values_type, events_type, transpose): - rng = bm.random.RandomState(seed=1234) - indices, indptr = bp.conn.FixedProb(0.05, seed=1234, allow_multi_conn=True)(*shape).require('pre2post') - vector = rng.random(shape[0] if transpose else shape[1]) < 0.1 - weight = 1. - - - if events_type == 'float': - vector = vector.astype(bm.float32) - if values_type == 'heter': - heter_data = bm.ones(indices.shape) * weight - weight = heter_data - - result = jax.block_until_ready(csrmv_taichi(weight, indices, indptr, vector, shape=shape, transpose=transpose)) - result = jax.block_until_ready(csrmv_taichi(weight, indices, indptr, vector, shape=shape, transpose=transpose)) - result = jax.block_until_ready(csrmv_taichi(weight, indices, indptr, vector, shape=shape, transpose=transpose)) - result = jax.block_until_ready(csrmv_taichi(weight, indices, indptr, vector, shape=shape, transpose=transpose)) - result = jax.block_until_ready(csrmv_taichi(weight, indices, indptr, vector, shape=shape, transpose=transpose)) - - time0 = time.time() - result = jax.block_until_ready(csrmv_taichi(weight, indices, indptr, vector, shape=shape, transpose=transpose)) - time1 = time.time() - - time2 = time.time() - result = jax.block_until_ready(csrmv_taichi(weight, indices, indptr, vector, shape=shape, transpose=transpose)) - time3 = time.time() - - time4 = time.time() - result = jax.block_until_ready(csrmv_taichi(weight, indices, indptr, vector, shape=shape, transpose=transpose)) - time5 = time.time() - - time6 = time.time() - result = jax.block_until_ready(csrmv_taichi(weight, indices, indptr, vector, shape=shape, transpose=transpose)) - time7 = time.time() - - time8 = time.time() - result = jax.block_until_ready(csrmv_taichi(weight, indices, indptr, vector, shape=shape, transpose=transpose)) - time9 = time.time() - - time10 = time.time() - result = jax.block_until_ready(csrmv_taichi(weight, indices, indptr, vector, shape=shape, transpose=transpose)) - time11 = time.time() - - time12 = time.time() - result = jax.block_until_ready(csrmv_taichi(weight, indices, indptr, vector, shape=shape, transpose=transpose)) - time13 = time.time() - - time14 = time.time() - result = jax.block_until_ready(csrmv_taichi(weight, indices, indptr, vector, shape=shape, transpose=transpose)) - time15 = time.time() - - time16 = time.time() - result = jax.block_until_ready(csrmv_taichi(weight, indices, indptr, vector, shape=shape, transpose=transpose)) - time17 = time.time() - - time18 = time.time() - result = jax.block_until_ready(csrmv_taichi(weight, indices, indptr, vector, shape=shape, transpose=transpose)) - time19 = time.time() - - - result = jax.block_until_ready(csrmv(weight, indices, indptr, vector, shape=shape, transpose=transpose)) - result = jax.block_until_ready(csrmv(weight, indices, indptr, vector, shape=shape, transpose=transpose)) - result = jax.block_until_ready(csrmv(weight, indices, indptr, vector, shape=shape, transpose=transpose)) - result = jax.block_until_ready(csrmv(weight, indices, indptr, vector, shape=shape, transpose=transpose)) - result = jax.block_until_ready(csrmv(weight, indices, indptr, vector, shape=shape, transpose=transpose)) - - time20 = time.time() - result = jax.block_until_ready(csrmv(weight, indices, indptr, vector, shape=shape, transpose=transpose)) - time21 = time.time() - - time22 = time.time() - result = jax.block_until_ready(csrmv(weight, indices, indptr, vector, shape=shape, transpose=transpose)) - time23 = time.time() - - time24 = time.time() - result = jax.block_until_ready(csrmv(weight, indices, indptr, vector, shape=shape, transpose=transpose)) - time25 = time.time() - - time26 = time.time() - result = jax.block_until_ready(csrmv(weight, indices, indptr, vector, shape=shape, transpose=transpose)) - time27 = time.time() - - time28 = time.time() - result = jax.block_until_ready(csrmv(weight, indices, indptr, vector, shape=shape, transpose=transpose)) - time29 = time.time() - - time30 = time.time() - result = jax.block_until_ready(csrmv(weight, indices, indptr, vector, shape=shape, transpose=transpose)) - time31 = time.time() - - time32 = time.time() - result = jax.block_until_ready(csrmv(weight, indices, indptr, vector, shape=shape, transpose=transpose)) - time33 = time.time() - - time34 = time.time() - result = jax.block_until_ready(csrmv(weight, indices, indptr, vector, shape=shape, transpose=transpose)) - time35 = time.time() - - time36 = time.time() - result = jax.block_until_ready(csrmv(weight, indices, indptr, vector, shape=shape, transpose=transpose)) - time37 = time.time() - - time38 = time.time() - result = jax.block_until_ready(csrmv(weight, indices, indptr, vector, shape=shape, transpose=transpose)) - time39 = time.time() - - taichi_aot_time1 = (time1 - time0) * 1000 - taichi_aot_time2 = (time3 - time2) * 1000 - taichi_aot_time3 = (time5 - time4) * 1000 - taichi_aot_time4 = (time7 - time6) * 1000 - taichi_aot_time5 = (time9 - time8) * 1000 - taichi_aot_time6 = (time11 - time10) * 1000 - taichi_aot_time7 = (time13 - time12) * 1000 - taichi_aot_time8 = (time15 - time14) * 1000 - taichi_aot_time9 = (time17 - time16) * 1000 - taichi_aot_time10 = (time19 - time18) * 1000 - brainpy_time1 = (time21 - time20) * 1000 - brainpy_time2 = (time23 - time22) * 1000 - brainpy_time3 = (time25 - time24) * 1000 - brainpy_time4 = (time27 - time26) * 1000 - brainpy_time5 = (time29 - time28) * 1000 - brainpy_time6 = (time31 - time30) * 1000 - brainpy_time7 = (time33 - time32) * 1000 - brainpy_time8 = (time35 - time34) * 1000 - brainpy_time9 = (time37 - time36) * 1000 - brainpy_time10 = (time39 - time38) * 1000 - print('shape: ', shape, 'values_type: ', values_type, 'events_type: ', events_type, 'transpose: ', transpose) - print('taichi_aot_1: ', taichi_aot_time1, 'ms') - print('taichi_aot_3: ', taichi_aot_time3, 'ms') - print('taichi_aot_5: ', taichi_aot_time5, 'ms') - print('taichi_aot_7: ', taichi_aot_time7, 'ms') - print('taichi_aot_9: ', taichi_aot_time9, 'ms') - print('brainpylib_1: ', brainpy_time1, 'ms') - print('brainpylib_3: ', brainpy_time3, 'ms') - print('brainpylib_5: ', brainpy_time5, 'ms') - print('brainpylib_7: ', brainpy_time7, 'ms') - print('brainpylib_9: ', brainpy_time9, 'ms') - - - return taichi_aot_time1, taichi_aot_time2, taichi_aot_time3, taichi_aot_time4, taichi_aot_time5,\ - taichi_aot_time6, taichi_aot_time7, taichi_aot_time8, taichi_aot_time9, taichi_aot_time10,\ - brainpy_time1, brainpy_time2, brainpy_time3, brainpy_time4, brainpy_time5, \ - brainpy_time6, brainpy_time7, brainpy_time8, brainpy_time9, brainpy_time10 - - -PATH = os.path.dirname(os.path.abspath(__file__)) - -# init dataframe -df = pd.DataFrame(columns=['s', 'p', 'shape[0]', 'shape[1]', 'backend', 'values type', 'events type', 'transpose', - 'taichi aot time1(ms)', 'taichi aot time2(ms)', 'taichi aot time3(ms)', 'taichi aot time4(ms)', 'taichi aot time5(ms)', - 'taichi aot time6(ms)', 'taichi aot time7(ms)', 'taichi aot time8(ms)', 'taichi aot time9(ms)', 'taichi aot time10(ms)', - 'brainpy time1(ms)', 'brainpy time2(ms)', 'brainpy time3(ms)', 'brainpy time4(ms)', 'brainpy time5(ms)', - 'brainpy time6(ms)', 'brainpy time7(ms)', 'brainpy time8(ms)', 'brainpy time9(ms)', 'brainpy time10(ms)']) - -### RECTANGULAR MATRIX -if (bm.get_platform() == 'cpu'): - for shape1 in shape: - for shape2 in shape: - for _values_type in values_type: - for _events_type in events_type: - for _transpose in transpose: - taichi_aot_time_1, taichi_aot_time_2, taichi_aot_time_3, taichi_aot_time_4, taichi_aot_time_5,\ - taichi_aot_time_6, taichi_aot_time_7, taichi_aot_time_8, taichi_aot_time_9, taichi_aot_time_10,\ - brainpy_time_1, brainpy_time_2, brainpy_time_3, brainpy_time_4, brainpy_time_5, \ - brainpy_time_6, brainpy_time_7, brainpy_time_8, brainpy_time_9, brainpy_time_10 = test_sparse_csrmv((shape1, shape2), _values_type, _events_type, _transpose) - # append to dataframe - df.loc[df.shape[0]] = [(shape1, shape2), 0.5 , shape1, shape2, 'cpu', _values_type, _events_type, _transpose, - taichi_aot_time_1, taichi_aot_time_2, taichi_aot_time_3, taichi_aot_time_4, taichi_aot_time_5, - taichi_aot_time_6, taichi_aot_time_7, taichi_aot_time_8, taichi_aot_time_9, taichi_aot_time_10, - brainpy_time_1, brainpy_time_2, brainpy_time_3, brainpy_time_4, brainpy_time_5, - brainpy_time_6, brainpy_time_7, brainpy_time_8, brainpy_time_9, brainpy_time_10] - df.to_csv(f'{PATH}/csrmv_cpu.csv', index=False) - -if (bm.get_platform() == 'gpu'): - for shape1 in shape: - for shape2 in shape: - for _values_type in values_type: - for _events_type in events_type: - for _transpose in transpose: - taichi_aot_time_1, taichi_aot_time_2, taichi_aot_time_3, taichi_aot_time_4, taichi_aot_time_5,\ - taichi_aot_time_6, taichi_aot_time_7, taichi_aot_time_8, taichi_aot_time_9, taichi_aot_time_10,\ - brainpy_time_1, brainpy_time_2, brainpy_time_3, brainpy_time_4, brainpy_time_5, \ - brainpy_time_6, brainpy_time_7, brainpy_time_8, brainpy_time_9, brainpy_time_10 = test_sparse_csrmv((shape1, shape2), _values_type, _events_type, _transpose) - # append to dataframe - df.loc[df.shape[0]] = [(shape1, shape2), 0.5 , shape1, shape2, 'gpu', _values_type, _events_type, _transpose, - taichi_aot_time_1, taichi_aot_time_2, taichi_aot_time_3, taichi_aot_time_4, taichi_aot_time_5, - taichi_aot_time_6, taichi_aot_time_7, taichi_aot_time_8, taichi_aot_time_9, taichi_aot_time_10, - brainpy_time_1, brainpy_time_2, brainpy_time_3, brainpy_time_4, brainpy_time_5, - brainpy_time_6, brainpy_time_7, brainpy_time_8, brainpy_time_9, brainpy_time_10] - df.to_csv(f'{PATH}/csrmv_gpu.csv', index=False) diff --git a/brainpy/_src/math/sparse/tests/csrmv_taichi_VS_csrmv_grad.py b/brainpy/_src/math/sparse/tests/csrmv_taichi_VS_csrmv_grad.py deleted file mode 100644 index d902c9395..000000000 --- a/brainpy/_src/math/sparse/tests/csrmv_taichi_VS_csrmv_grad.py +++ /dev/null @@ -1,273 +0,0 @@ -# from jax_taichi import jax_taichi_call - -import time -from functools import partial -import os - -import brainpy as bp -import brainpy.math as bm -import jax -import jax.numpy as jnp -import numpy as np -import pandas as pd -import taichi as ti - -bm.set_platform('cpu') - -s = [1000, - 5000, - 10000, - 15000, - 20000, - 25000, - 30000] -p = [0.1, 0.2, 0.3, 0.4, 0.5] - -shape = [ - 1000, - 2500, - 5000, - 10000, - 25000, - 37500, - 50000 -] - -values_type = [ - 'homo', - 'heter' - ] -events_type = ['float'] -transpose = [ - True, - False - ] -method = 'cusparse' - -ITERATION = 100 -if bm.get_platform() == 'cpu': - ITERATION = 10 - -print(bm.get_platform()) - -def sum_op(op): - def func(*args, **kwargs): - r = op(*args, **kwargs) - return r.sum() - - return func - - -def sum_op2(op): - def func(*args, **kwargs): - r = op(*args, **kwargs)[0] - return r.sum() - - return func - -@partial(jax.jit, static_argnums=(4, 5)) -def csrmv_taichi_grad(weight, indices, indptr, vector, shape, transpose): - r = 0 - for i in range(ITERATION): - r += jax.grad(sum_op2(bm.sparse.csrmv_taichi), argnums=3)( - weight, indices, indptr, vector.astype(float), shape=shape, transpose=transpose) - return r - -@partial(jax.jit, static_argnums=(4, 5)) -def csrmv_grad(weight, indices, indptr, vector, shape, transpose): - r = 0 - for i in range(ITERATION): - r += jax.grad(sum_op(bm.sparse.csrmv), argnums=3)( - weight, indices, indptr, vector.astype(float), shape=shape, transpose=transpose) - return r - -def test_sparse_csrmv(shape, values_type, events_type, transpose): - rng = bm.random.RandomState(seed=1234) - indices, indptr = bp.conn.FixedProb(0.05, seed=1234, allow_multi_conn=True)(*shape).require('pre2post') - vector = rng.random(shape[0] if transpose else shape[1]) < 0.1 - weight = 1. - - - if events_type == 'float': - vector = vector.astype(bm.float32) - if values_type == 'heter': - heter_data = bm.ones(indices.shape) * weight - weight = heter_data - - result = jax.block_until_ready(csrmv_taichi_grad(weight, indices, indptr, vector, shape=shape, transpose=transpose)) - result = jax.block_until_ready(csrmv_taichi_grad(weight, indices, indptr, vector, shape=shape, transpose=transpose)) - result = jax.block_until_ready(csrmv_taichi_grad(weight, indices, indptr, vector, shape=shape, transpose=transpose)) - result = jax.block_until_ready(csrmv_taichi_grad(weight, indices, indptr, vector, shape=shape, transpose=transpose)) - result = jax.block_until_ready(csrmv_taichi_grad(weight, indices, indptr, vector, shape=shape, transpose=transpose)) - - time0 = time.time() - result = jax.block_until_ready(csrmv_taichi_grad(weight, indices, indptr, vector, shape=shape, transpose=transpose)) - time1 = time.time() - - time2 = time.time() - result = jax.block_until_ready(csrmv_taichi_grad(weight, indices, indptr, vector, shape=shape, transpose=transpose)) - time3 = time.time() - - time4 = time.time() - result = jax.block_until_ready(csrmv_taichi_grad(weight, indices, indptr, vector, shape=shape, transpose=transpose)) - time5 = time.time() - - time6 = time.time() - result = jax.block_until_ready(csrmv_taichi_grad(weight, indices, indptr, vector, shape=shape, transpose=transpose)) - time7 = time.time() - - time8 = time.time() - result = jax.block_until_ready(csrmv_taichi_grad(weight, indices, indptr, vector, shape=shape, transpose=transpose)) - time9 = time.time() - - time10 = time.time() - result = jax.block_until_ready(csrmv_taichi_grad(weight, indices, indptr, vector, shape=shape, transpose=transpose)) - time11 = time.time() - - time12 = time.time() - result = jax.block_until_ready(csrmv_taichi_grad(weight, indices, indptr, vector, shape=shape, transpose=transpose)) - time13 = time.time() - - time14 = time.time() - result = jax.block_until_ready(csrmv_taichi_grad(weight, indices, indptr, vector, shape=shape, transpose=transpose)) - time15 = time.time() - - time16 = time.time() - result = jax.block_until_ready(csrmv_taichi_grad(weight, indices, indptr, vector, shape=shape, transpose=transpose)) - time17 = time.time() - - time18 = time.time() - result = jax.block_until_ready(csrmv_taichi_grad(weight, indices, indptr, vector, shape=shape, transpose=transpose)) - time19 = time.time() - - - result = jax.block_until_ready(csrmv_grad(weight, indices, indptr, vector, shape=shape, transpose=transpose)) - result = jax.block_until_ready(csrmv_grad(weight, indices, indptr, vector, shape=shape, transpose=transpose)) - result = jax.block_until_ready(csrmv_grad(weight, indices, indptr, vector, shape=shape, transpose=transpose)) - result = jax.block_until_ready(csrmv_grad(weight, indices, indptr, vector, shape=shape, transpose=transpose)) - result = jax.block_until_ready(csrmv_grad(weight, indices, indptr, vector, shape=shape, transpose=transpose)) - - time20 = time.time() - result = jax.block_until_ready(csrmv_grad(weight, indices, indptr, vector, shape=shape, transpose=transpose)) - time21 = time.time() - - time22 = time.time() - result = jax.block_until_ready(csrmv_grad(weight, indices, indptr, vector, shape=shape, transpose=transpose)) - time23 = time.time() - - time24 = time.time() - result = jax.block_until_ready(csrmv_grad(weight, indices, indptr, vector, shape=shape, transpose=transpose)) - time25 = time.time() - - time26 = time.time() - result = jax.block_until_ready(csrmv_grad(weight, indices, indptr, vector, shape=shape, transpose=transpose)) - time27 = time.time() - - time28 = time.time() - result = jax.block_until_ready(csrmv_grad(weight, indices, indptr, vector, shape=shape, transpose=transpose)) - time29 = time.time() - - time30 = time.time() - result = jax.block_until_ready(csrmv_grad(weight, indices, indptr, vector, shape=shape, transpose=transpose)) - time31 = time.time() - - time32 = time.time() - result = jax.block_until_ready(csrmv_grad(weight, indices, indptr, vector, shape=shape, transpose=transpose)) - time33 = time.time() - - time34 = time.time() - result = jax.block_until_ready(csrmv_grad(weight, indices, indptr, vector, shape=shape, transpose=transpose)) - time35 = time.time() - - time36 = time.time() - result = jax.block_until_ready(csrmv_grad(weight, indices, indptr, vector, shape=shape, transpose=transpose)) - time37 = time.time() - - time38 = time.time() - result = jax.block_until_ready(csrmv_grad(weight, indices, indptr, vector, shape=shape, transpose=transpose)) - time39 = time.time() - - taichi_aot_time1 = (time1 - time0) * 1000 - taichi_aot_time2 = (time3 - time2) * 1000 - taichi_aot_time3 = (time5 - time4) * 1000 - taichi_aot_time4 = (time7 - time6) * 1000 - taichi_aot_time5 = (time9 - time8) * 1000 - taichi_aot_time6 = (time11 - time10) * 1000 - taichi_aot_time7 = (time13 - time12) * 1000 - taichi_aot_time8 = (time15 - time14) * 1000 - taichi_aot_time9 = (time17 - time16) * 1000 - taichi_aot_time10 = (time19 - time18) * 1000 - brainpy_time1 = (time21 - time20) * 1000 - brainpy_time2 = (time23 - time22) * 1000 - brainpy_time3 = (time25 - time24) * 1000 - brainpy_time4 = (time27 - time26) * 1000 - brainpy_time5 = (time29 - time28) * 1000 - brainpy_time6 = (time31 - time30) * 1000 - brainpy_time7 = (time33 - time32) * 1000 - brainpy_time8 = (time35 - time34) * 1000 - brainpy_time9 = (time37 - time36) * 1000 - brainpy_time10 = (time39 - time38) * 1000 - print('shape: ', shape, 'values_type: ', values_type, 'events_type: ', events_type, 'transpose: ', transpose) - print('taichi_aot_1: ', taichi_aot_time1, 'ms') - print('taichi_aot_3: ', taichi_aot_time3, 'ms') - print('taichi_aot_5: ', taichi_aot_time5, 'ms') - print('taichi_aot_7: ', taichi_aot_time7, 'ms') - print('taichi_aot_9: ', taichi_aot_time9, 'ms') - print('brainpylib_1: ', brainpy_time1, 'ms') - print('brainpylib_3: ', brainpy_time3, 'ms') - print('brainpylib_5: ', brainpy_time5, 'ms') - print('brainpylib_7: ', brainpy_time7, 'ms') - print('brainpylib_9: ', brainpy_time9, 'ms') - - - return taichi_aot_time1, taichi_aot_time2, taichi_aot_time3, taichi_aot_time4, taichi_aot_time5,\ - taichi_aot_time6, taichi_aot_time7, taichi_aot_time8, taichi_aot_time9, taichi_aot_time10,\ - brainpy_time1, brainpy_time2, brainpy_time3, brainpy_time4, brainpy_time5, \ - brainpy_time6, brainpy_time7, brainpy_time8, brainpy_time9, brainpy_time10 - -PATH = os.path.dirname(os.path.abspath(__file__)) - -# init dataframe -df = pd.DataFrame(columns=['s', 'p', 'shape[0]', 'shape[1]', 'backend', 'values type', 'events type', 'transpose', - 'taichi aot time1(ms)', 'taichi aot time2(ms)', 'taichi aot time3(ms)', 'taichi aot time4(ms)', 'taichi aot time5(ms)', - 'taichi aot time6(ms)', 'taichi aot time7(ms)', 'taichi aot time8(ms)', 'taichi aot time9(ms)', 'taichi aot time10(ms)', - 'brainpy time1(ms)', 'brainpy time2(ms)', 'brainpy time3(ms)', 'brainpy time4(ms)', 'brainpy time5(ms)', - 'brainpy time6(ms)', 'brainpy time7(ms)', 'brainpy time8(ms)', 'brainpy time9(ms)', 'brainpy time10(ms)']) - - -### RECTANGULAR MATRIX -if (bm.get_platform() == 'cpu'): - for shape1 in shape: - for shape2 in shape: - for _values_type in values_type: - for _events_type in events_type: - for _transpose in transpose: - taichi_aot_time_1, taichi_aot_time_2, taichi_aot_time_3, taichi_aot_time_4, taichi_aot_time_5,\ - taichi_aot_time_6, taichi_aot_time_7, taichi_aot_time_8, taichi_aot_time_9, taichi_aot_time_10,\ - brainpy_time_1, brainpy_time_2, brainpy_time_3, brainpy_time_4, brainpy_time_5, \ - brainpy_time_6, brainpy_time_7, brainpy_time_8, brainpy_time_9, brainpy_time_10 = test_sparse_csrmv((shape1, shape2), _values_type, _events_type, _transpose) - # append to dataframe - df.loc[df.shape[0]] = [(shape1, shape2), 0.5 , shape1, shape2, 'cpu', _values_type, _events_type, _transpose, - taichi_aot_time_1, taichi_aot_time_2, taichi_aot_time_3, taichi_aot_time_4, taichi_aot_time_5, - taichi_aot_time_6, taichi_aot_time_7, taichi_aot_time_8, taichi_aot_time_9, taichi_aot_time_10, - brainpy_time_1, brainpy_time_2, brainpy_time_3, brainpy_time_4, brainpy_time_5, - brainpy_time_6, brainpy_time_7, brainpy_time_8, brainpy_time_9, brainpy_time_10] - df.to_csv(f'{PATH}/csrmv_grad_cpu.csv', index=False) - -if (bm.get_platform() == 'gpu'): - for shape1 in shape: - for shape2 in shape: - for _values_type in values_type: - for _events_type in events_type: - for _transpose in transpose: - taichi_aot_time_1, taichi_aot_time_2, taichi_aot_time_3, taichi_aot_time_4, taichi_aot_time_5,\ - taichi_aot_time_6, taichi_aot_time_7, taichi_aot_time_8, taichi_aot_time_9, taichi_aot_time_10,\ - brainpy_time_1, brainpy_time_2, brainpy_time_3, brainpy_time_4, brainpy_time_5, \ - brainpy_time_6, brainpy_time_7, brainpy_time_8, brainpy_time_9, brainpy_time_10 = test_sparse_csrmv((shape1, shape2), _values_type, _events_type, _transpose) - # append to dataframe - df.loc[df.shape[0]] = [(shape1, shape2), 0.5 , shape1, shape2, 'gpu', _values_type, _events_type, _transpose, - taichi_aot_time_1, taichi_aot_time_2, taichi_aot_time_3, taichi_aot_time_4, taichi_aot_time_5, - taichi_aot_time_6, taichi_aot_time_7, taichi_aot_time_8, taichi_aot_time_9, taichi_aot_time_10, - brainpy_time_1, brainpy_time_2, brainpy_time_3, brainpy_time_4, brainpy_time_5, - brainpy_time_6, brainpy_time_7, brainpy_time_8, brainpy_time_9, brainpy_time_10] - df.to_csv(f'{PATH}/csrmv_grad_gpu.csv', index=False) diff --git a/brainpy/_src/math/sparse/tests/test_csrmm.py b/brainpy/_src/math/sparse/tests/test_csrmm.py deleted file mode 100644 index f7947089b..000000000 --- a/brainpy/_src/math/sparse/tests/test_csrmm.py +++ /dev/null @@ -1,293 +0,0 @@ -# -*- coding: utf-8 -*- - - -import os -from functools import partial - -import jax -import pytest -from absl.testing import parameterized - -import brainpy as bp -import brainpy.math as bm - -# bm.set_platform('gpu') - -import platform -force_test = False # turn on to force test on windows locally -if platform.system() == 'Windows' and not force_test: - pytest.skip('skip windows', allow_module_level=True) - -# Skip the test in Github Actions -IS_GITHUB_ACTIONS = os.getenv('IS_GITHUB_ACTIONS', '0') -if IS_GITHUB_ACTIONS == '1': - pytest.skip('Skip the test in Github Actions', allow_module_level=True) - -seed = 1234 - - -def sum_op(op): - def func(*args, **kwargs): - r = op(*args, **kwargs) - return r.sum() - - return func - - -class Test_csrmm(parameterized.TestCase): - def __init__(self, *args, platform='cpu', **kwargs): - super(Test_csrmm, self).__init__(*args, **kwargs) - - print() - bm.set_platform(platform) - - @parameterized.product( - transpose=[True, False], - shape=[(50, 50, 50), (100, 50, 100), (10, 1000, 10), (2, 2000, 2)], - homo_data=[-1., 1.] - ) - def test_homo(self, transpose, shape, homo_data): - print(f'test_homo: transpose: {transpose} shape = {shape}') - conn = bp.conn.FixedProb(0.3) - - # csr matrix - indices, indptr = conn(shape[1], shape[0]).require('pre2post') if transpose else conn(shape[0], - shape[1]).require( - 'pre2post') - indices = bm.as_jax(indices) - indptr = bm.as_jax(indptr) - # matrix - rng = bm.random.RandomState(seed=seed) - matrix = rng.random((shape[1], shape[2])) < 0.1 - matrix = bm.as_jax(matrix) - - heter_data = bm.ones(indices.shape) * homo_data - - dense = bm.sparse.csr_to_dense(heter_data, indices, indptr, - shape=(shape[1], shape[0]) if transpose else (shape[0], shape[1])) - - r1 = (dense.T @ matrix) if transpose else (dense @ matrix) - r2 = bm.sparse.csrmm(homo_data, indices, indptr, matrix, - shape=(shape[1], shape[0]) if transpose else (shape[0], shape[1]), transpose=transpose) - c = bm.allclose(r1, r2, equal_nan=True) - if not c: - print(r1 - r2) - self.assertTrue(c) - - bm.clear_buffer_memory() - - @parameterized.product( - transpose=[True, False], - shape=[(50, 50, 50), (100, 50, 100), (10, 1000, 10), (2, 2000, 2)], - homo_data=[-1., 1.] - ) - def test_homo_vmap(self, transpose, shape, homo_data): - print(f'test_homo_vmap: transpose: {transpose} shape = {shape}') - conn = bp.conn.FixedProb(0.3) - - # csr matrix - indices, indptr = conn(shape[1], shape[0]).require('pre2post') if transpose else conn(shape[0], - shape[1]).require( - 'pre2post') - indices = bm.as_jax(indices) - indptr = bm.as_jax(indptr) - # matrix - rng = bm.random.RandomState(seed=seed) - matrix = rng.random((shape[1], shape[2])) < 0.1 - matrix = bm.as_jax(matrix) - - heter_data = bm.ones((10, indices.shape[0])) * homo_data - dense = jax.vmap(lambda a: bm.sparse.csr_to_dense(a, indices, indptr, - shape=(shape[1], shape[0]) if transpose else ( - shape[0], shape[1])))(heter_data) - - # vmap 'data' - f1 = jax.vmap(lambda a: (a.T @ matrix) if transpose else (a @ matrix)) - f2 = jax.vmap(partial(bm.sparse.csrmm, indices=indices, indptr=indptr, matrix=matrix, - shape=(shape[1], shape[0]) if transpose else (shape[0], shape[1]), transpose=transpose)) - vmap_data = bm.as_jax([homo_data] * 10) - - r1 = f1(dense) - r2 = f2(vmap_data) - self.assertTrue(bm.allclose(r1, r2)) - - bm.clear_buffer_memory() - - @parameterized.product( - transpose=[True, False], - shape=[(50, 50, 50), (100, 50, 100), (10, 1000, 10), (2, 2000, 2)], - homo_data=[-1., 1.] - ) - def test_homo_grad(self, transpose, shape, homo_data): - print(f'test_homo_grad: transpose: {transpose} shape = {shape}') - rng = bm.random.RandomState(seed=seed) - conn = bp.conn.FixedProb(0.3) - - # csr matrix - indices, indptr = conn(shape[1], shape[0]).require('pre2post') if transpose else conn(shape[0], - shape[1]).require( - 'pre2post') - indices = bm.as_jax(indices) - indptr = bm.as_jax(indptr) - dense = bm.sparse.csr_to_dense(bm.ones(indices.shape).value, - indices, - indptr, - shape=(shape[1], shape[0]) if transpose else (shape[0], shape[1])) - - heter_data = bm.as_jax(rng.random((indices.shape))) - # matrix - matrix = rng.random((shape[1], shape[2])) < 0.1 - matrix = bm.as_jax(matrix) - - # grad data - dense_f1 = jax.grad(lambda a: (((dense.T * a) @ matrix).sum() - if transpose else - ((dense * a) @ matrix).sum()), - argnums=0) - r1 = dense_f1(homo_data) - r2 = jax.grad(sum_op(bm.sparse.csrmm))( - bm.asarray([homo_data]), indices, indptr, matrix, - shape=(shape[1], shape[0]) if transpose else (shape[0], shape[1]), - transpose=transpose) - - self.assertTrue(bm.allclose(r1, r2)) - - # grad events matrix - dense_f2 = jax.grad(lambda m: (((dense.T * homo_data) @ m).sum() - if transpose else - ((dense * homo_data) @ m).sum()), - argnums=0) - r3 = dense_f2(matrix.astype(float)) - r4 = jax.grad(sum_op(bm.sparse.csrmm), argnums=3)( - bm.asarray([homo_data]), indices, indptr, matrix.astype(float), - shape=(shape[1], shape[0]) if transpose else (shape[0], shape[1]), transpose=transpose) - - self.assertTrue(bm.allclose(r3, r4)) - - bm.clear_buffer_memory() - - @parameterized.product( - transpose=[True, False], - shape=[(50, 50, 50), (100, 50, 100), (10, 1000, 10), (2, 2000, 2)], - ) - def test_heter(self, transpose, shape): - print(f'test_homo: transpose: {transpose} shape = {shape}') - conn = bp.conn.FixedProb(0.3) - - # csr matrix - indices, indptr = conn(shape[1], shape[0]).require('pre2post') if transpose else conn(shape[0], - shape[1]).require( - 'pre2post') - indices = bm.as_jax(indices) - indptr = bm.as_jax(indptr) - # matrix - rng = bm.random.RandomState(seed=seed) - matrix = rng.random((shape[1], shape[2])) - matrix = bm.as_jax(matrix) - - heter_data = bm.as_jax(rng.random(indices.shape)) - - dense = bm.sparse.csr_to_dense(heter_data, indices, indptr, - shape=(shape[1], shape[0]) if transpose else (shape[0], shape[1])) - - r1 = (dense.T @ matrix) if transpose else (dense @ matrix) - r2 = bm.sparse.csrmm(heter_data, indices, indptr, matrix, - shape=(shape[1], shape[0]) if transpose else (shape[0], shape[1]), transpose=transpose) - print(r2) - print(r1.shape, '-', r2.shape) - c = bm.allclose(r1, r2, equal_nan=True) - if not c: - print(r1 - r2) - self.assertTrue(c) - - bm.clear_buffer_memory() - - @parameterized.product( - transpose=[True, False], - shape=[(50, 50, 50), (100, 50, 100), (10, 1000, 10), (2, 2000, 2)], - ) - def test_heter_vmap(self, transpose, shape): - print(f'test_homo_vmap: transpose: {transpose} shape = {shape}') - conn = bp.conn.FixedProb(0.3) - - # csr matrix - indices, indptr = conn(shape[1], shape[0]).require('pre2post') if transpose else conn(shape[0], - shape[1]).require( - 'pre2post') - indices = bm.as_jax(indices) - indptr = bm.as_jax(indptr) - # matrix - rng = bm.random.RandomState(seed=seed) - matrix = rng.random((shape[1], shape[2])) - matrix = bm.as_jax(matrix) - - heter_data = bm.as_jax(rng.random((10, indices.shape[0]))) - dense = jax.vmap(lambda a: bm.sparse.csr_to_dense(a, indices, indptr, - shape=(shape[1], shape[0]) if transpose else ( - shape[0], shape[1])))(heter_data) - - f1 = lambda a: (a.T @ matrix) if transpose else (a @ matrix) - f2 = partial(bm.sparse.csrmm, indices=indices, indptr=indptr, matrix=matrix, - shape=(shape[1], shape[0]) if transpose else (shape[0], shape[1]), transpose=transpose) - r1 = jax.vmap(f1)(dense) - r2 = jax.vmap(f2)(heter_data) - - self.assertTrue(bm.allclose(r1, r2, equal_nan=True)) - - @parameterized.product( - transpose=[True, False], - shape=[(50, 50, 50), (100, 50, 100), (10, 1000, 10), (2, 2000, 2)], - ) - def test_heter_grad(self, transpose, shape): - print(f'test_homo_grad: transpose: {transpose} shape = {shape}') - rng = bm.random.RandomState(seed=seed) - conn = bp.conn.FixedProb(0.3) - - # csr matrix - indices, indptr = conn(shape[1], shape[0]).require('pre2post') if transpose else conn(shape[0], - shape[1]).require( - 'pre2post') - indices = bm.as_jax(indices) - indptr = bm.as_jax(indptr) - - heter_data = bm.as_jax(rng.random((indices.shape))) - dense = bm.sparse.csr_to_dense(heter_data, - indices, - indptr, - shape=(shape[1], shape[0]) if transpose else (shape[0], shape[1])) - # matrix - matrix = rng.random((shape[1], shape[2])) - matrix = bm.as_jax(matrix) - - # grad data - dense_f1 = jax.grad(lambda a: ((a.T @ matrix).sum() - if transpose else - (a @ matrix).sum()), - argnums=0) - r1 = dense_f1(dense) - r2 = jax.grad(sum_op(bm.sparse.csrmm))( - heter_data, indices, indptr, matrix, shape=(shape[1], shape[0]) if transpose else (shape[0], shape[1]), - transpose=transpose - ) - rows, cols = bm.sparse.csr_to_coo(indices, indptr) - if transpose: - r1 = r1[cols, rows] - else: - r1 = r1[rows, cols] - print(r1 - r2) - - self.assertTrue(bm.allclose(r1, r2)) - - # grad matrix - dense_f2 = jax.grad(lambda m: ((dense.T @ m).sum() - if transpose else - (dense @ m).sum())) - r3 = dense_f2(matrix) - r4 = jax.grad(sum_op(bm.sparse.csrmm), argnums=3)( - heter_data, indices, indptr, matrix.astype(float), - shape=(shape[1], shape[0]) if transpose else (shape[0], shape[1]), transpose=transpose - ) - - self.assertTrue(bm.allclose(r3, r4)) - - bm.clear_buffer_memory() diff --git a/brainpy/_src/math/sparse/tests/test_csrmv.py b/brainpy/_src/math/sparse/tests/test_csrmv.py deleted file mode 100644 index 61032cf25..000000000 --- a/brainpy/_src/math/sparse/tests/test_csrmv.py +++ /dev/null @@ -1,272 +0,0 @@ -# -*- coding: utf-8 -*- -import os -from functools import partial - -import jax -import pytest -from absl.testing import parameterized - -import brainpy as bp -import brainpy.math as bm - -import platform -force_test = False # turn on to force test on windows locally -if platform.system() == 'Windows' and not force_test: - pytest.skip('skip windows', allow_module_level=True) - -# Skip the test in Github Actions -IS_GITHUB_ACTIONS = os.getenv('IS_GITHUB_ACTIONS', '0') -if IS_GITHUB_ACTIONS == '1': - pytest.skip('Skip the test in Github Actions', allow_module_level=True) - -seed = 1234 - - -def sum_op(op): - def func(*args, **kwargs): - r = op(*args, **kwargs) - return r.sum() - - return func - - -def compare_with_nan_tolerance(a, b, tol=1e-8): - """ - Compare two arrays with tolerance for NaN values. - - Parameters: - a (np.array): First array to compare. - b (np.array): Second array to compare. - tol (float): Tolerance for comparing non-NaN elements. - - Returns: - bool: True if arrays are similar within the tolerance, False otherwise. - """ - if a.shape != b.shape: - return False - - # Create masks for NaNs in both arrays - nan_mask_a = bm.isnan(a) - nan_mask_b = bm.isnan(b) - - # Check if NaN positions are the same in both arrays - if not bm.array_equal(nan_mask_a, nan_mask_b): - return False - - # Compare non-NaN elements - a_non_nan = a[~nan_mask_a] - b_non_nan = b[~nan_mask_b] - - return bm.allclose(a_non_nan, b_non_nan, atol=tol) - - -class Test_csrmv_taichi(parameterized.TestCase): - def __init__(self, *args, platform='cpu', **kwargs): - super(Test_csrmv_taichi, self).__init__(*args, **kwargs) - - print() - bm.set_platform(platform) - - @parameterized.product( - transpose=[True, False], - shape=[(200, 200), (10, 1000)], - homo_data=[1.] - ) - def test_homo(self, transpose, shape, homo_data): - print(f'test_homo: transpose = {transpose} shape = {shape}, homo_data = {homo_data}') - conn = bp.conn.FixedProb(0.3) - - # matrix - indices, indptr = conn(*shape).require('pre2post') - indices = bm.as_jax(indices) - indptr = bm.as_jax(indptr) - # vector - rng = bm.random.RandomState(seed=seed) - vector = rng.random(shape[0] if transpose else shape[1]) - vector = bm.as_jax(vector) - - heter_data = bm.ones(indices.shape).value * homo_data - - dense = bm.sparse.csr_to_dense(heter_data, indices, indptr, shape=shape) - r1 = (vector @ dense) if transpose else (dense @ vector) - r2 = bm.sparse.csrmv(bm.asarray([homo_data]), indices, indptr, vector, shape=shape, transpose=transpose) - self.assertTrue(bm.allclose(r1, r2)) - - bm.clear_buffer_memory() - - @parameterized.product( - transpose=[True, False], - shape=[(200, 200), (100, 1000)], - v=[1.] - ) - def test_homo_vmap(self, transpose, shape, v): - print(f'test_homo_vmap: transpose = {transpose} shape = {shape}, v = {v}') - rng = bm.random.RandomState(seed=seed) - conn = bp.conn.FixedProb(0.3) - - indices, indptr = conn(*shape).require('pre2post') - indices = bm.as_jax(indices) - indptr = bm.as_jax(indptr) - vector = rng.random(shape[0] if transpose else shape[1]) - vector = bm.as_jax(vector) - - heter_data = bm.ones((10, indices.shape[0])).value * v - homo_data = bm.ones(10).value * v - dense_data = jax.vmap(lambda a: bm.sparse.csr_to_dense(a, indices, indptr, shape=shape))(heter_data) - - f1 = lambda a: (a.T @ vector) if transpose else (a @ vector) - f2 = partial(bm.sparse.csrmv, indices=indices, indptr=indptr, vector=vector, - shape=shape, transpose=transpose) - r1 = jax.vmap(f1)(dense_data) - r2 = jax.vmap(f2)(homo_data) - self.assertTrue(bm.allclose(r1, r2)) - - bm.clear_buffer_memory() - - @parameterized.product( - transpose=[True, False], - shape=[(200, 200), (10, 1000)], - homo_data=[1.] - ) - def test_homo_grad(self, transpose, shape, homo_data): - print(f'test_homo_grad: transpose = {transpose} shape = {shape}, homo_data = {homo_data}') - rng = bm.random.RandomState(seed=seed) - conn = bp.conn.FixedProb(0.3) - - indices, indptr = conn(*shape).require('pre2post') - indices = bm.as_jax(indices) - indptr = bm.as_jax(indptr) - dense = bm.sparse.csr_to_dense(bm.ones(indices.shape).value, - indices, - indptr, - shape=shape) - vector = rng.random(shape[0] if transpose else shape[1]) - vector = bm.as_jax(vector) - - # print('grad data start') - # grad 'data' - dense_f1 = jax.grad(lambda a: ((vector @ (dense * a)).sum() - if transpose else - ((dense * a) @ vector).sum()), - argnums=0) - r1 = dense_f1(homo_data) - r2 = jax.grad(sum_op(bm.sparse.csrmv))(bm.asarray([homo_data]), indices, indptr, vector, shape=shape, transpose=transpose) - - self.assertTrue(bm.allclose(r1, r2)) - - # print('grad vector start') - # grad 'vector' - dense_data = dense * homo_data - dense_f2 = jax.grad(lambda v: ((v @ dense_data).sum() if transpose else (dense_data @ v).sum())) - r3 = dense_f2(vector) - r4 = jax.grad(sum_op(bm.sparse.csrmv), argnums=3)( - bm.asarray([homo_data]), indices, indptr, vector.astype(float), shape=shape, transpose=transpose) - - self.assertTrue(bm.allclose(r3, r4)) - - dense_f3 = jax.grad(lambda a, v: ((v @ (dense * a)).sum() - if transpose else - ((dense * a) @ v).sum()), - argnums=(0, 1)) - r5 = dense_f3(homo_data, vector) - r6 = jax.grad(sum_op(bm.sparse.csrmv), argnums=(0, 3))( - bm.asarray([homo_data]), indices, indptr, vector.astype(float), shape=shape, transpose=transpose) - self.assertTrue(bm.allclose(r5[0], r6[0])) - self.assertTrue(bm.allclose(r5[1], r6[1])) - - bm.clear_buffer_memory() - - @parameterized.product( - transpose=[True, False], - shape=[(200, 200), (2, 2000)], - ) - def test_heter(self, transpose, shape): - print(f'test_homo: transpose = {transpose} shape = {shape}') - rng = bm.random.RandomState(seed=seed) - conn = bp.conn.FixedProb(0.3) - - indices, indptr = conn(*shape).require('pre2post') - indices = bm.as_jax(indices) - indptr = bm.as_jax(indptr) - - heter_data = bm.as_jax(rng.random(indices.shape)) - heter_data = bm.as_jax(heter_data) - - vector = rng.random(shape[0] if transpose else shape[1]) - vector = bm.as_jax(vector) - - dense = bm.sparse.csr_to_dense(heter_data, indices, indptr, shape=shape) - r1 = (vector @ dense) if transpose else (dense @ vector) - r2 = bm.sparse.csrmv(heter_data, indices, indptr, vector, shape=shape, transpose=transpose) - - self.assertTrue(compare_with_nan_tolerance(r1, r2)) - - bm.clear_buffer_memory() - - @parameterized.product( - transpose=[True, False], - shape=[(200, 200), (2, 2000)] - ) - def test_heter_vmap(self, transpose, shape): - rng = bm.random.RandomState(seed=seed) - conn = bp.conn.FixedProb(0.3) - - indices, indptr = conn(*shape).require('pre2post') - indices = bm.as_jax(indices) - indptr = bm.as_jax(indptr) - vector = rng.random(shape[0] if transpose else shape[1]) - vector = bm.as_jax(vector) - - heter_data = rng.random((10, indices.shape[0])) - heter_data = bm.as_jax(heter_data) - dense_data = jax.vmap(lambda a: bm.sparse.csr_to_dense(a, indices, indptr, - shape=shape))(heter_data) - - f1 = lambda a: (a.T @ vector) if transpose else (a @ vector) - f2 = partial(bm.sparse.csrmv, indices=indices, indptr=indptr, vector=vector, - shape=shape, transpose=transpose) - r1 = jax.vmap(f1)(dense_data) - r2 = jax.vmap(f2)(heter_data) - self.assertTrue(compare_with_nan_tolerance(r1, r2)) - - @parameterized.product( - transpose=[True, False], - shape=[(200, 200), (2, 2000)] - ) - def test_heter_grad(self, transpose, shape): - rng = bm.random.RandomState(seed=seed) - conn = bp.conn.FixedProb(0.3) - - indices, indptr = conn(*shape).require('pre2post') - indices = bm.as_jax(indices) - indptr = bm.as_jax(indptr) - heter_data = rng.random(indices.shape) - heter_data = bm.as_jax(heter_data) - dense_data = bm.sparse.csr_to_dense(heter_data, indices, indptr, shape=shape) - vector = rng.random(shape[0] if transpose else shape[1]) - vector = bm.as_jax(vector) - - # grad 'data' - dense_f1 = jax.grad(lambda a: ((vector @ a).sum() if transpose else (a @ vector).sum()), argnums=0) - csr_f1 = jax.grad(lambda a: bm.sparse.csrmv(a, indices, indptr, vector, - shape=shape, - transpose=transpose).sum(), - argnums=0) - r1 = csr_f1(heter_data) - r2 = dense_f1(dense_data) - rows, cols = bm.sparse.csr_to_coo(indices, indptr) - r2 = r2[rows, cols] - print(r1.shape, r2.shape) - self.assertTrue(bm.allclose(r1, r2)) - - # grad 'vector' - dense_f2 = jax.grad(lambda v: ((v @ dense_data).sum() if transpose else (dense_data @ v).sum()), argnums=0) - csr_f2 = jax.grad(lambda v: bm.sparse.csrmv(heter_data, indices, indptr, v, - shape=shape, - transpose=transpose).sum(), - argnums=0) - r3 = dense_f2(vector) - r4 = csr_f2(vector) - self.assertTrue(bm.allclose(r3, r4)) - - bm.clear_buffer_memory() diff --git a/brainpy/math/__init__.py b/brainpy/math/__init__.py index 562c1cc18..624ade1b7 100644 --- a/brainpy/math/__init__.py +++ b/brainpy/math/__init__.py @@ -43,10 +43,7 @@ from brainpy._src.math import defaults from brainpy._src.deprecations import deprecation_getattr -from brainpy._src.dependency_check import import_taichi, import_numba -import_taichi(error_if_not_found=False) -import_numba(error_if_not_found=False) __deprecations = { "sparse_matmul": ("brainpy.math.sparse_matmul is deprecated. Use brainpy.math.sparse.seg_matmul instead.",